Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
c99021ca
Commit
c99021ca
authored
Apr 09, 2018
by
rusty1s
Browse files
added quadratic and cubic impl
parent
36ed7951
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
71 additions
and
7 deletions
+71
-7
aten/THC/THCBasis.cu
aten/THC/THCBasis.cu
+63
-7
aten/THC/THCNumerics.cuh
aten/THC/THCNumerics.cuh
+4
-0
aten/THC/generic/THCBasis.h
aten/THC/generic/THCBasis.h
+4
-0
No files found.
aten/THC/THCBasis.cu
View file @
c99021ca
...
...
@@ -46,16 +46,72 @@
}
template
<
typename
T
>
__global__
void
linearBasisForwardKernel
(
TensorInfo
<
T
>
basis
,
TensorInfo
<
int64_t
>
weightIndex
,
TensorInfo
<
T
>
pseudo
,
int64_t
*
kernelSize
,
uint8_t
*
isOpenSpline
,
ptrdiff_t
n
)
{
THC_TENSOR_BASIS_FORWARD_KERNEL
(
1
,
basis
,
weightIndex
,
pseudo
,
kernelSize
,
isOpenSpline
,
n
,
struct
BasisForward
{
// Evaluates the linear basis polynomial
//     1 - v - kMod + 2 * v * kMod
// which reduces to (1 - v) for kMod == 0 and to v for kMod == 1.
//
//   v    - point at which the polynomial is evaluated
//          (NOTE(review): presumably a fractional offset in [0, 1] — confirm
//          against THC_TENSOR_BASIS_FORWARD_KERNEL, which supplies it).
//   kMod - selector converted to T and used arithmetically in the formula.
//
// All arithmetic goes through THCNumerics<T> / ScalarConvert so the same code
// works for every scalar type T those helpers support.
static inline __device__ T linear(T v, int64_t kMod) {
  // 1 - v - kMod
  T tmp1 = THCNumerics<T>::sub(ScalarConvert<int, T>::to(1), v);
  tmp1 = THCNumerics<T>::sub(tmp1, ScalarConvert<int64_t, T>::to(kMod));

  // 2 * v * kMod
  T tmp2 = THCNumerics<T>::mul(ScalarConvert<int, T>::to(2), v);
  tmp2 = THCNumerics<T>::mul(tmp2, ScalarConvert<int64_t, T>::to(kMod));

  // The sum is computed once and returned directly; `v` is a by-value
  // parameter, so writing the result back into it would have no effect.
  return THCNumerics<T>::add(tmp1, tmp2);
}
// Evaluates one of the three polynomial pieces of the quadratic basis at v.
//
//   kMod == 0 :  0.5 * v * v - v + 0.5      ( == (1 - v)^2 / 2 )
//   kMod == 1 : -v * v + v + 0.5
//   otherwise :  0.5 * v * v
//
// The three pieces sum to 1 for every v, which is a quick sanity check on the
// signs used below.
//
//   v    - point at which the piece is evaluated
//          (NOTE(review): presumably a fractional offset in [0, 1] — confirm
//          against THC_TENSOR_BASIS_FORWARD_KERNEL, which supplies it).
//   kMod - selects the piece as listed above.
static inline __device__ T quadratic(T v, int64_t kMod) {
  const T oneHalf = ScalarConvert<float, T>::to(0.5);

  if (kMod == 0) {
    // 0.5 * v * v - v + 0.5
    // The constant is added after subtracting v; grouping it as
    // `tmp - (v + 0.5)` would flip its sign and yield 0.5*v*v - v - 0.5.
    T tmp = THCNumerics<T>::mul(THCNumerics<T>::mul(oneHalf, v), v);
    return THCNumerics<T>::add(THCNumerics<T>::sub(tmp, v), oneHalf);
  }
  else if (kMod == 1) {
    // -v * v + v + 0.5
    T tmp = THCNumerics<T>::mul(THCNumerics<T>::neg(v), v);
    return THCNumerics<T>::add(THCNumerics<T>::add(tmp, v), oneHalf);
  }
  else {
    // 0.5 * v * v
    return THCNumerics<T>::mul(oneHalf, THCNumerics<T>::mul(v, v));
  }
}
// Evaluates one of the four polynomial pieces of the cubic basis at v.
//
//   kMod == 0 : (1 - v)^3 / 6
//   kMod == 1 : (3 v^3 - 6 v^2 + 4) / 6
//   kMod == 2 : (-3 v^3 + 3 v^2 + 3 v + 1) / 6
//   otherwise : v^3 / 6
//
// Each branch builds its numerator with THCNumerics<T> operations and divides
// by 6 at the end.
static inline __device__ T cubic(T v, int64_t kMod) {
  const T six = ScalarConvert<int, T>::to(6);

  switch (kMod) {
    case 0: {
      // (1 - v) * (1 - v) * (1 - v) / 6
      const T oneMinusV = THCNumerics<T>::sub(ScalarConvert<int, T>::to(1), v);
      const T cubed = THCNumerics<T>::mul(THCNumerics<T>::mul(oneMinusV, oneMinusV), oneMinusV);
      return THCNumerics<T>::div(cubed, six);
    }
    case 1: {
      // (3 * v * v * v - 6 * v * v + 4) / 6
      const T vCubed = THCNumerics<T>::mul(THCNumerics<T>::mul(v, v), v);
      const T cubicTerm = THCNumerics<T>::mul(ScalarConvert<int, T>::to(3), vCubed);
      const T squareTerm = THCNumerics<T>::mul(ScalarConvert<int, T>::to(6), THCNumerics<T>::mul(v, v));
      const T numerator = THCNumerics<T>::add(THCNumerics<T>::sub(cubicTerm, squareTerm),
                                              ScalarConvert<int, T>::to(4));
      return THCNumerics<T>::div(numerator, six);
    }
    case 2: {
      // (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6
      const T vCubed = THCNumerics<T>::mul(THCNumerics<T>::mul(v, v), v);
      const T cubicTerm = THCNumerics<T>::mul(ScalarConvert<int, T>::to(-3), vCubed);
      const T squareTerm = THCNumerics<T>::mul(ScalarConvert<int, T>::to(3), THCNumerics<T>::mul(v, v));
      const T linearTerm = THCNumerics<T>::mul(ScalarConvert<int, T>::to(3), v);
      const T partial = THCNumerics<T>::add(THCNumerics<T>::add(cubicTerm, squareTerm), linearTerm);
      const T numerator = THCNumerics<T>::add(partial, ScalarConvert<int, T>::to(1));
      return THCNumerics<T>::div(numerator, six);
    }
    default: {
      // v * v * v / 6
      const T vCubed = THCNumerics<T>::mul(THCNumerics<T>::mul(v, v), v);
      return THCNumerics<T>::div(vCubed, six);
    }
  }
}
};
// CUDA kernel for the linear basis forward pass.
//
// The whole body is produced by THC_TENSOR_BASIS_FORWARD_KERNEL (defined
// elsewhere in this file). This kernel passes the literal 1 as the macro's
// first argument, forwards its own parameters unchanged, and supplies as the
// final argument the statement that maps `v` through
// BasisForward<T>::linear(v, kMod).
// NOTE(review): `v` and `kMod` are not declared here — presumably the macro
// expansion declares them; confirm against the macro definition.
template <typename T>
__global__ void linearBasisForwardKernel(TensorInfo<T> basis,
                                         TensorInfo<int64_t> weightIndex,
                                         TensorInfo<T> pseudo,
                                         int64_t *kernelSize,
                                         uint8_t *isOpenSpline,
                                         ptrdiff_t n) {
  THC_TENSOR_BASIS_FORWARD_KERNEL(1, basis, weightIndex, pseudo, kernelSize, isOpenSpline, n,
    v = BasisForward<T>::linear(v, kMod);
  )
}
...
...
@@ -64,7 +120,7 @@ __global__ void quadraticBasisForwardKernel(TensorInfo<T> basis, TensorInfo<int6
TensorInfo
<
T
>
pseudo
,
int64_t
*
kernelSize
,
uint8_t
*
isOpenSpline
,
ptrdiff_t
n
)
{
THC_TENSOR_BASIS_FORWARD_KERNEL
(
2
,
basis
,
weightIndex
,
pseudo
,
kernelSize
,
isOpenSpline
,
n
,
/* printf("DRIN"); */
v
=
BasisForward
<
T
>::
quadratic
(
v
,
kMod
);
)
}
...
...
@@ -73,7 +129,7 @@ __global__ void cubicBasisForwardKernel(TensorInfo<T> basis, TensorInfo<int64_t>
TensorInfo
<
T
>
pseudo
,
int64_t
*
kernelSize
,
uint8_t
*
isOpenSpline
,
ptrdiff_t
n
)
{
THC_TENSOR_BASIS_FORWARD_KERNEL
(
3
,
basis
,
weightIndex
,
pseudo
,
kernelSize
,
isOpenSpline
,
n
,
/* printf("DRIN"); */
v
=
BasisForward
<
T
>::
cubic
(
v
,
kMod
);
)
}
...
...
aten/THC/THCNumerics.cuh
View file @
c99021ca
...
...
@@ -18,6 +18,8 @@ struct THCNumerics {
// Returns a + b using T's built-in addition; callable from host and device.
static inline __host__ __device__ T add(T a, T b) {
  return a + b;
}
// Returns a - b using T's built-in subtraction; callable from host and device.
static inline __host__ __device__ T sub(T a, T b) {
  return a - b;
}
// Returns a * b using T's built-in multiplication; callable from host and device.
static inline __host__ __device__ T mul(T a, T b) {
  return a * b;
}
// Returns a / b using T's built-in division; callable from host and device.
// No check is made on b, so the result for b == 0 is whatever T's operator/
// produces.
static inline __host__ __device__ T div(T a, T b) {
  return a / b;
}
// Returns -a using T's built-in unary minus; callable from host and device.
static inline __host__ __device__ T neg(T a) {
  return -a;
}
};
#ifdef CUDA_HALF_TENSOR
...
...
@@ -26,6 +28,8 @@ struct THCNumerics<half> {
// half addition: both operands go through h2f, are added, and the sum is
// converted back with f2h.
// NOTE(review): h2f / f2h are defined elsewhere — presumably half <-> float
// conversions; confirm.
static inline __host__ __device__ half add(half a, half b) {
  return f2h(h2f(a) + h2f(b));
}
// half subtraction: both operands go through h2f, b is subtracted from a, and
// the difference is converted back with f2h.
// NOTE(review): h2f / f2h are defined elsewhere — presumably half <-> float
// conversions; confirm.
static inline __host__ __device__ half sub(half a, half b) {
  return f2h(h2f(a) - h2f(b));
}
// half multiplication: both operands go through h2f, are multiplied, and the
// product is converted back with f2h.
// NOTE(review): h2f / f2h are defined elsewhere — presumably half <-> float
// conversions; confirm.
static inline __host__ __device__ half mul(half a, half b) {
  return f2h(h2f(a) * h2f(b));
}
// half division: both operands go through h2f, a is divided by b, and the
// quotient is converted back with f2h. No check is made on b.
// NOTE(review): h2f / f2h are defined elsewhere — presumably half <-> float
// conversions; confirm.
static inline __host__ __device__ half div(half a, half b) {
  return f2h(h2f(a) / h2f(b));
}
// half negation: the operand goes through h2f, is negated, and the result is
// converted back with f2h.
// NOTE(review): h2f / f2h are defined elsewhere — presumably half <-> float
// conversions; confirm.
static inline __host__ __device__ half neg(half a) {
  return f2h(-h2f(a));
}
};
#endif // CUDA_HALF_TENSOR
...
...
aten/THC/generic/THCBasis.h
View file @
c99021ca
...
...
@@ -15,4 +15,8 @@ void THCTensor_(cubicBasisForward)(THCState *state, THCTensor *basis,
THCudaLongTensor
*
weightIndex
,
THCTensor
*
pseudo
,
THCudaLongTensor
*
kernelSize
,
THCudaByteTensor
*
isOpenSpline
);
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably the backward counterpart of the linear basis
// forward routine, and presumably `basis` is the tensor written to — neither
// is shown here; confirm against the implementation.
void THCTensor_(linearBasisBackward)(THCState *state,
                                     THCTensor *basis,
                                     THCudaLongTensor *weightIndex,
                                     THCTensor *pseudo,
                                     THCudaLongTensor *kernelSize,
                                     THCudaByteTensor *isOpenSpline);
#endif // THC_GENERIC_FILE
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment