Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
d42e2629
Commit
d42e2629
authored
Apr 09, 2018
by
rusty1s
Browse files
clean up
parent
c99021ca
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
120 additions
and
225 deletions
+120
-225
aten/THC/THCBasis.cu
aten/THC/THCBasis.cu
+1
-105
aten/THC/THCBasisForward.cuh
aten/THC/THCBasisForward.cuh
+110
-0
test/basis.json
test/basis.json
+0
-65
test/test_basis.py
test/test_basis.py
+9
-47
test/utils.py
test/utils.py
+0
-8
No files found.
aten/THC/THCBasis.cu
View file @
d42e2629
#include "THCBasis.h"
#include "common.cuh"
#include "THCNumerics.cuh"
#define THC_TENSOR_BASIS_FORWARD(NAME, state, basis, weightIndex, pseudo, kernelSize, \
isOpenSpline) { \
THCAssertSameGPU( \
THCTensor_(checkGPU)(state, 5, basis, weightIndex, pseudo, kernelSize, isOpenSpline)); \
\
TensorInfo<real> basisInfo = THCTensor_(getTensorInfo)(state, basis); \
TensorInfo<int64_t> weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex); \
TensorInfo<real> pseudoInfo = THCTensor_(getTensorInfo)(state, pseudo); \
int64_t *kernelSizeData = THCudaLongTensor_data(state, kernelSize); \
uint8_t *isOpenSplineData = THCudaByteTensor_data(state, isOpenSpline); \
\
KERNEL_REAL_RUN(NAME, THCTensor_(nElement)(state, basis), basisInfo, \
weightIndexInfo, pseudoInfo, kernelSizeData, isOpenSplineData); \
}
#define THC_TENSOR_BASIS_FORWARD_KERNEL(M, basis, weightIndex, pseudo, kernelSize, isOpenSpline, \
N, CODE) { \
KERNEL_LOOP(i, N) { \
ptrdiff_t e = i / basis.size[1], s = i % basis.size[1], d; \
int64_t k = s, kMod, wi = 0, wiOffset = 1; \
T b = ScalarConvert<int, T>::to(1), v; \
\
for (d = 0; d < pseudo.size[1]; d++) { \
kMod = k % (M + 1); \
k /= M + 1; \
\
v = pseudo.data[e * pseudo.stride[0] + d * pseudo.stride[1]]; \
v = THCNumerics<T>::mul(v, ScalarConvert<int64_t, T>::to(kernelSize[d] - M * isOpenSpline[d])); \
\
wi += ((ScalarConvert<T, int64_t>::to(v) + kMod) % kernelSize[d]) * wiOffset; \
wiOffset *= kernelSize[d]; \
\
v = THCNumerics<T>::sub(v, ScalarConvert<int64_t, T>::to(ScalarConvert<T, int64_t>::to(v))); \
CODE \
b = THCNumerics<T>::mul(b, v); \
} \
\
basis.data[e * basis.stride[0] + s * basis.stride[1]] = b; \
weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]] = wi; \
} \
}
template
<
typename
T
>
struct
BasisForward
{
static
inline
__device__
T
linear
(
T
v
,
int64_t
kMod
)
{
// 1 - v - kMod + 2 * v * kMod
T
tmp1
=
THCNumerics
<
T
>::
sub
(
ScalarConvert
<
int
,
T
>::
to
(
1
),
v
);
tmp1
=
THCNumerics
<
T
>::
sub
(
tmp1
,
ScalarConvert
<
int64_t
,
T
>::
to
(
kMod
));
T
tmp2
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
2
),
v
);
tmp2
=
THCNumerics
<
T
>::
mul
(
tmp2
,
ScalarConvert
<
int64_t
,
T
>::
to
(
kMod
));
return
THCNumerics
<
T
>::
add
(
tmp1
,
tmp2
);
}
static
inline
__device__
T
quadratic
(
T
v
,
int64_t
kMod
)
{
if
(
kMod
==
0
)
{
// 0.5 * v * v - v + 0.5
T
tmp
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
float
,
T
>::
to
(
0.5
),
v
),
v
);
return
THCNumerics
<
T
>::
sub
(
tmp
,
THCNumerics
<
T
>::
add
(
v
,
ScalarConvert
<
float
,
T
>::
to
(
0.5
)));
}
else
if
(
kMod
==
1
)
{
// -v * v + v + 0.5
T
tmp
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
neg
(
v
),
v
);
return
THCNumerics
<
T
>::
add
(
THCNumerics
<
T
>::
add
(
tmp
,
v
),
ScalarConvert
<
float
,
T
>::
to
(
0.5
));
}
else
{
// 0.5 * v * v
return
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
float
,
T
>::
to
(
0.5
),
THCNumerics
<
T
>::
mul
(
v
,
v
));
}
}
static
inline
__device__
T
cubic
(
T
v
,
int64_t
kMod
)
{
if
(
kMod
==
0
)
{
// (1 - v) * (1 -v) * (1 - v) / 6
T
tmp
=
THCNumerics
<
T
>::
sub
(
ScalarConvert
<
int
,
T
>::
to
(
1
),
v
);
tmp
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
tmp
,
tmp
),
tmp
);
return
THCNumerics
<
T
>::
div
(
tmp
,
ScalarConvert
<
int
,
T
>::
to
(
6
));
}
else
if
(
kMod
==
1
)
{
// (3 * v * v * v - 6 * v * v + 4) / 6
T
tmp1
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
v
,
v
),
v
);
tmp1
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
3
),
tmp1
);
T
tmp2
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
6
),
THCNumerics
<
T
>::
mul
(
v
,
v
));
tmp1
=
THCNumerics
<
T
>::
add
(
THCNumerics
<
T
>::
sub
(
tmp1
,
tmp2
),
ScalarConvert
<
int
,
T
>::
to
(
4
));
return
THCNumerics
<
T
>::
div
(
tmp1
,
ScalarConvert
<
int
,
T
>::
to
(
6
));
}
else
if
(
kMod
==
2
)
{
// (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6
T
tmp1
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
v
,
v
),
v
);
tmp1
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
-
3
),
tmp1
);
T
tmp2
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
3
),
THCNumerics
<
T
>::
mul
(
v
,
v
));
T
tmp3
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
3
),
v
);
tmp1
=
THCNumerics
<
T
>::
add
(
THCNumerics
<
T
>::
add
(
tmp1
,
tmp2
),
tmp3
);
tmp1
=
THCNumerics
<
T
>::
add
(
tmp1
,
ScalarConvert
<
int
,
T
>::
to
(
1
));
return
THCNumerics
<
T
>::
div
(
tmp1
,
ScalarConvert
<
int
,
T
>::
to
(
6
));
}
else
{
// v * v * v / 6
T
tmp
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
v
,
v
),
v
);
return
THCNumerics
<
T
>::
div
(
tmp
,
ScalarConvert
<
int
,
T
>::
to
(
6
));
}
}
};
#include "THCBasisForward.cuh"
template
<
typename
T
>
__global__
void
linearBasisForwardKernel
(
TensorInfo
<
T
>
basis
,
TensorInfo
<
int64_t
>
weightIndex
,
...
...
aten/THC/THCBasisForward.cuh
0 → 100644
View file @
d42e2629
#ifndef THC_BASIS_FORWARD_INC
#define THC_BASIS_FORWARD_INC
#include "common.cuh"
#include "THCNumerics.cuh"
#define THC_TENSOR_BASIS_FORWARD(NAME, state, basis, weightIndex, pseudo, kernelSize, \
isOpenSpline) { \
THCAssertSameGPU( \
THCTensor_(checkGPU)(state, 5, basis, weightIndex, pseudo, kernelSize, isOpenSpline)); \
\
TensorInfo<real> basisInfo = THCTensor_(getTensorInfo)(state, basis); \
TensorInfo<int64_t> weightIndexInfo = THCudaLongTensor_getTensorInfo(state, weightIndex); \
TensorInfo<real> pseudoInfo = THCTensor_(getTensorInfo)(state, pseudo); \
int64_t *kernelSizeData = THCudaLongTensor_data(state, kernelSize); \
uint8_t *isOpenSplineData = THCudaByteTensor_data(state, isOpenSpline); \
\
KERNEL_REAL_RUN(NAME, THCTensor_(nElement)(state, basis), basisInfo, \
weightIndexInfo, pseudoInfo, kernelSizeData, isOpenSplineData); \
}
#define THC_TENSOR_BASIS_FORWARD_KERNEL(M, basis, weightIndex, pseudo, kernelSize, isOpenSpline, \
N, CODE) { \
KERNEL_LOOP(i, N) { \
ptrdiff_t e = i / basis.size[1], s = i % basis.size[1], d; \
int64_t k = s, kMod, wi = 0, wiOffset = 1; \
T b = ScalarConvert<int, T>::to(1), v; \
\
for (d = 0; d < pseudo.size[1]; d++) { \
kMod = k % (M + 1); \
k /= M + 1; \
\
v = pseudo.data[e * pseudo.stride[0] + d * pseudo.stride[1]]; \
v = THCNumerics<T>::mul(v, ScalarConvert<int64_t, T>::to(kernelSize[d] - M * isOpenSpline[d])); \
\
wi += ((ScalarConvert<T, int64_t>::to(v) + kMod) % kernelSize[d]) * wiOffset; \
wiOffset *= kernelSize[d]; \
\
v = THCNumerics<T>::sub(v, ScalarConvert<int64_t, T>::to(ScalarConvert<T, int64_t>::to(v))); \
CODE \
b = THCNumerics<T>::mul(b, v); \
} \
\
basis.data[e * basis.stride[0] + s * basis.stride[1]] = b; \
weightIndex.data[e * weightIndex.stride[0] + s * weightIndex.stride[1]] = wi; \
} \
}
template
<
typename
T
>
struct
BasisForward
{
static
inline
__device__
T
linear
(
T
v
,
int64_t
kMod
)
{
// 1 - v - kMod + 2 * v * kMod
T
tmp1
=
THCNumerics
<
T
>::
sub
(
ScalarConvert
<
int
,
T
>::
to
(
1
),
v
);
tmp1
=
THCNumerics
<
T
>::
sub
(
tmp1
,
ScalarConvert
<
int64_t
,
T
>::
to
(
kMod
));
T
tmp2
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
2
),
v
);
tmp2
=
THCNumerics
<
T
>::
mul
(
tmp2
,
ScalarConvert
<
int64_t
,
T
>::
to
(
kMod
));
return
THCNumerics
<
T
>::
add
(
tmp1
,
tmp2
);
}
static
inline
__device__
T
quadratic
(
T
v
,
int64_t
kMod
)
{
if
(
kMod
==
0
)
{
// 0.5 * v * v - v + 0.5
T
tmp
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
float
,
T
>::
to
(
0.5
),
v
),
v
);
return
THCNumerics
<
T
>::
add
(
THCNumerics
<
T
>::
sub
(
tmp
,
v
),
ScalarConvert
<
float
,
T
>::
to
(
0.5
));
}
else
if
(
kMod
==
1
)
{
// -v * v + v + 0.5
T
tmp
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
neg
(
v
),
v
);
return
THCNumerics
<
T
>::
add
(
THCNumerics
<
T
>::
add
(
tmp
,
v
),
ScalarConvert
<
float
,
T
>::
to
(
0.5
));
}
else
{
// 0.5 * v * v
return
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
float
,
T
>::
to
(
0.5
),
THCNumerics
<
T
>::
mul
(
v
,
v
));
}
}
static
inline
__device__
T
cubic
(
T
v
,
int64_t
kMod
)
{
if
(
kMod
==
0
)
{
// (1 - v) * (1 -v) * (1 - v) / 6
T
tmp
=
THCNumerics
<
T
>::
sub
(
ScalarConvert
<
int
,
T
>::
to
(
1
),
v
);
tmp
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
tmp
,
tmp
),
tmp
);
return
THCNumerics
<
T
>::
div
(
tmp
,
ScalarConvert
<
int
,
T
>::
to
(
6
));
}
else
if
(
kMod
==
1
)
{
// (3 * v * v * v - 6 * v * v + 4) / 6
T
tmp1
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
v
,
v
),
v
);
tmp1
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
3
),
tmp1
);
T
tmp2
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
6
),
THCNumerics
<
T
>::
mul
(
v
,
v
));
tmp1
=
THCNumerics
<
T
>::
add
(
THCNumerics
<
T
>::
sub
(
tmp1
,
tmp2
),
ScalarConvert
<
int
,
T
>::
to
(
4
));
return
THCNumerics
<
T
>::
div
(
tmp1
,
ScalarConvert
<
int
,
T
>::
to
(
6
));
}
else
if
(
kMod
==
2
)
{
// (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6
T
tmp1
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
v
,
v
),
v
);
tmp1
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
-
3
),
tmp1
);
T
tmp2
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
3
),
THCNumerics
<
T
>::
mul
(
v
,
v
));
T
tmp3
=
THCNumerics
<
T
>::
mul
(
ScalarConvert
<
int
,
T
>::
to
(
3
),
v
);
tmp1
=
THCNumerics
<
T
>::
add
(
THCNumerics
<
T
>::
add
(
tmp1
,
tmp2
),
tmp3
);
tmp1
=
THCNumerics
<
T
>::
add
(
tmp1
,
ScalarConvert
<
int
,
T
>::
to
(
1
));
return
THCNumerics
<
T
>::
div
(
tmp1
,
ScalarConvert
<
int
,
T
>::
to
(
6
));
}
else
{
// v * v * v / 6
T
tmp
=
THCNumerics
<
T
>::
mul
(
THCNumerics
<
T
>::
mul
(
v
,
v
),
v
);
return
THCNumerics
<
T
>::
div
(
tmp
,
ScalarConvert
<
int
,
T
>::
to
(
6
));
}
}
};
#endif // THC_BASIS_FORWARD_INC
test/basis.json
deleted
100644 → 0
View file @
c99021ca
[
{
"name"
:
"Linear and open B-splines"
,
"degree"
:
1
,
"pseudo"
:
[
0
,
0.05
,
0.25
,
0.5
,
0.75
,
0.95
,
1
],
"kernel_size"
:
[
5
],
"is_open_spline"
:
[
1
],
"expected_basis"
:
[[
1
,
0
],
[
0.8
,
0.2
],
[
1
,
0
],
[
1
,
0
],
[
1
,
0
],
[
0.2
,
0.8
],
[
1
,
0
]],
"expected_index"
:
[[
0
,
1
],
[
0
,
1
],
[
1
,
2
],
[
2
,
3
],
[
3
,
4
],
[
3
,
4
],
[
4
,
0
]]
},
{
"name"
:
"Linear and closed B-splines"
,
"degree"
:
1
,
"pseudo"
:
[
0
,
0.05
,
0.25
,
0.5
,
0.75
,
0.95
,
1
],
"kernel_size"
:
[
4
],
"is_open_spline"
:
[
0
],
"expected_basis"
:
[[
1
,
0
],
[
0.8
,
0.2
],
[
1
,
0
],
[
1
,
0
],
[
1
,
0
],
[
0.2
,
0.8
],
[
1
,
0
]],
"expected_index"
:
[[
0
,
1
],
[
0
,
1
],
[
1
,
2
],
[
2
,
3
],
[
3
,
0
],
[
3
,
0
],
[
0
,
1
]]
},
{
"name"
:
"Quadratic and open B-splines"
,
"degree"
:
2
,
"pseudo"
:
[
0
,
0.05
,
0.25
,
0.5
,
0.75
,
0.95
,
1
],
"kernel_size"
:
[
6
],
"is_open_spline"
:
[
1
],
"expected_basis"
:
[[
0.5
,
0.5
,
0
],
[
0.32
,
0.66
,
0.02
],
[
0.5
,
0.5
,
0
],
[
0.5
,
0.5
,
0
],
[
0.5
,
0.5
,
0
],
[
0.02
,
0.66
,
0.32
],
[
0.5
,
0.5
,
0
]],
"expected_index"
:
[[
0
,
1
,
2
],
[
0
,
1
,
2
],
[
1
,
2
,
3
],
[
2
,
3
,
4
],
[
3
,
4
,
5
],
[
3
,
4
,
5
],
[
4
,
5
,
0
]]
},
{
"name"
:
"Quadratic and closed B-splines"
,
"degree"
:
2
,
"pseudo"
:
[
0
,
0.05
,
0.25
,
0.5
,
0.75
,
0.95
,
1
],
"kernel_size"
:
[
4
],
"is_open_spline"
:
[
0
],
"expected_basis"
:
[[
0.5
,
0.5
,
0
],
[
0.32
,
0.66
,
0.02
],
[
0.5
,
0.5
,
0
],
[
0.5
,
0.5
,
0
],
[
0.5
,
0.5
,
0
],
[
0.02
,
0.66
,
0.32
],
[
0.5
,
0.5
,
0
]],
"expected_index"
:
[[
0
,
1
,
2
],
[
0
,
1
,
2
],
[
1
,
2
,
3
],
[
2
,
3
,
0
],
[
3
,
0
,
1
],
[
3
,
0
,
1
],
[
0
,
1
,
2
]]
},
{
"name"
:
"Cubic and open B-splines"
,
"degree"
:
3
,
"pseudo"
:
[
0
,
0.05
,
0.25
,
0.5
,
0.75
,
0.95
,
1
],
"kernel_size"
:
[
7
],
"is_open_spline"
:
[
1
],
"expected_basis"
:
[[
0.16667
,
0.6667
,
0.1667
,
0
],
[
0.0853
,
0.6307
,
0.2827
,
0.00133
],
[
0.1667
,
0.6667
,
0.1667
,
0
],
[
0.1667
,
0.6667
,
0.1667
,
0
],
[
0.1667
,
0.6667
,
0.1667
,
0
],
[
0.00133
,
0.2827
,
0.6307
,
0.0853
],
[
0.1667
,
0.6667
,
0.1667
,
0
]],
"expected_index"
:
[[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
1
,
2
,
3
,
4
],
[
2
,
3
,
4
,
5
],
[
3
,
4
,
5
,
6
],
[
3
,
4
,
5
,
6
],
[
4
,
5
,
6
,
0
]]
},
{
"name"
:
"Cubic and closed B-splines"
,
"degree"
:
3
,
"pseudo"
:
[
0
,
0.05
,
0.25
,
0.5
,
0.75
,
0.95
,
1
],
"kernel_size"
:
[
4
],
"is_open_spline"
:
[
0
],
"expected_basis"
:
[[
0.16667
,
0.6667
,
0.1667
,
0
],
[
0.0853
,
0.6307
,
0.2827
,
0.00133
],
[
0.1667
,
0.6667
,
0.1667
,
0
],
[
0.1667
,
0.6667
,
0.1667
,
0
],
[
0.1667
,
0.6667
,
0.1667
,
0
],
[
0.00133
,
0.2827
,
0.6307
,
0.0853
],
[
0.1667
,
0.6667
,
0.1667
,
0
]],
"expected_index"
:
[[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
1
,
2
,
3
,
0
],
[
2
,
3
,
0
,
1
],
[
3
,
0
,
1
,
2
],
[
3
,
0
,
1
,
2
],
[
0
,
1
,
2
,
3
]]
},
{
"name"
:
"Two-dimensional pseudo-coordinates"
,
"degree"
:
1
,
"pseudo"
:
[[
0.125
,
0.5
],
[
0.5
,
0.5
],
[
0.75
,
0.125
]],
"kernel_size"
:
[
5
,
5
],
"is_open_spline"
:
[
1
,
1
],
"expected_basis"
:
[[
0.5
,
0.5
,
0
,
0
],
[
1
,
0
,
0
,
0
],
[
0.5
,
0
,
0.5
,
0
]],
"expected_index"
:
[[
2
,
7
,
3
,
8
],
[
12
,
17
,
13
,
18
],
[
15
,
20
,
16
,
21
]]
}
]
test/test_basis.py
View file @
d42e2629
...
...
@@ -41,52 +41,14 @@ def test_basis_forward_cpu(tensor, i):
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
def
test_basis_forward_gpu
():
# pragma: no cover
pseudo
=
torch
.
cuda
.
FloatTensor
([
0
,
0.0625
,
0.25
,
0.75
,
0.9375
,
1
])
kernel_size
=
torch
.
cuda
.
LongTensor
([
5
])
is_open_spline
=
torch
.
cuda
.
ByteTensor
([
1
])
basis
,
weight_index
=
basis_forward
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
print
(
basis
.
cpu
().
tolist
())
print
(
weight_index
.
cpu
().
tolist
())
# 'basis': [[1, 0], [0.75, 0.25], [1, 0], [1, 0], [0.25, 0.75], [1, 0]],
# 'weight_index': [[0, 1], [0, 1], [1, 2], [3, 4], [3, 4], [4, 0]],
# @pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
# def test_spline_basis_cpu(tensor, i):
# degree = data[i].get('degree')
# pseudo = Tensor(tensor, data[i]['pseudo'])
# pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
# kernel_size = torch.LongTensor(data[i]['kernel_size'])
# is_open_spline = torch.ByteTensor(data[i]['is_open_spline'])
# K = kernel_size.prod()
# expected_basis = Tensor(tensor, data[i]['expected_basis'])
# expected_index = torch.LongTensor(data[i]['expected_index'])
# basis, index = spline_basis_forward(degree, pseudo, kernel_size,
# is_open_spline, K)
# basis = [pytest.approx(b, 0.01) for b in basis.view(-1).tolist()]
# assert basis == expected_basis.view(-1).tolist()
# assert index.tolist() == expected_index.tolist()
# @pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
# @pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
# def test_spline_basis_gpu(tensor, i): # pragma: no cover
# degree = data[i].get('degree')
# pseudo = Tensor(tensor, data[i]['pseudo']).cuda()
# pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
# kernel_size = torch.cuda.LongTensor(data[i]['kernel_size'])
# is_open_spline = torch.cuda.ByteTensor(data[i]['is_open_spline'])
# K = kernel_size.prod()
# expected_basis = Tensor(tensor, data[i]['expected_basis'])
# expected_index = torch.LongTensor(data[i]['expected_index'])
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
tests
))))
def
test_basis_forward_gpu
(
tensor
,
i
):
# pragma: no cover
data
=
tests
[
i
]
# basis, index = spline_basis_forward(degree, pseudo, kernel_size,
# is_open_spline, K)
# basis, index = basis.cpu(), index.cpu()
# basis = [pytest.approx(b, 0.01) for b in basis.view(-1).tolist()]
pseudo
=
getattr
(
torch
.
cuda
,
tensor
)(
data
[
'pseudo'
])
kernel_size
=
torch
.
cuda
.
LongTensor
(
data
[
'kernel_size'
])
is_open_spline
=
torch
.
cuda
.
ByteTensor
(
data
[
'is_open_spline'
])
# assert basis == expected_basis.view(-1).tolist()
# assert index.tolist() == expected_index.tolist()
basis
,
weight_index
=
basis_forward
(
1
,
pseudo
,
kernel_size
,
is_open_spline
)
assert
basis
.
cpu
().
tolist
()
==
data
[
'basis'
]
assert
weight_index
.
cpu
().
tolist
()
==
data
[
'weight_index'
]
test/utils.py
deleted
100644 → 0
View file @
c99021ca
import
torch
tensors
=
[
'FloatTensor'
,
'DoubleTensor'
]
def
Tensor
(
str
,
x
):
tensor
=
getattr
(
torch
,
str
)
return
tensor
(
x
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment