Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
112345ce
Commit
112345ce
authored
Feb 29, 2020
by
rusty1s
Browse files
cuda related fixes
parent
8e464c16
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
18 deletions
+16
-18
csrc/cuda/weighting_cuda.cu
csrc/cuda/weighting_cuda.cu
+16
-18
No files found.
csrc/cuda/weighting_cuda.cu
View file @
112345ce
...
@@ -89,7 +89,7 @@ spline_weighting_bw_x_kernel(const scalar_t *grad_out, const scalar_t *weight,
...
@@ -89,7 +89,7 @@ spline_weighting_bw_x_kernel(const scalar_t *grad_out, const scalar_t *weight,
const
int64_t
wi
=
weight_index
[
e
*
S
+
s
];
const
int64_t
wi
=
weight_index
[
e
*
S
+
s
];
for
(
int64_t
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
for
(
int64_t
m_out
=
0
;
m_out
<
M_out
;
m_out
++
)
{
scalar_t
tmp
=
weight
[
wi
*
M_
in
*
M_
out
+
m_out
*
M_
out
+
m_in
];
scalar_t
tmp
=
weight
[
wi
*
M_
out
*
M_
in
+
m_out
*
M_
in
+
m_in
];
tmp
*=
b
*
grad_out
[
e
*
M_out
+
m_out
];
tmp
*=
b
*
grad_out
[
e
*
M_out
+
m_out
];
v
+=
tmp
;
v
+=
tmp
;
}
}
...
@@ -116,7 +116,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
...
@@ -116,7 +116,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
auto
S
=
basis
.
size
(
1
);
auto
S
=
basis
.
size
(
1
);
auto
grad_x
=
at
::
zeros
({
E
,
M_in
},
grad_out
.
options
());
auto
grad_x
=
at
::
zeros
({
E
,
M_in
},
grad_out
.
options
());
weight
=
weight
.
transpose
(
1
,
2
).
contiguous
();
weight
=
weight
.
transpose
(
1
,
2
).
contiguous
();
// Contiguous memory-access.
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
...
@@ -137,11 +137,10 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
...
@@ -137,11 +137,10 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
spline_weighting_bw_weight_kernel
(
const
scalar_t
*
grad_out
,
const
scalar_t
*
x
,
__global__
void
spline_weighting_bw_weight_kernel
(
const
scalar_t
*
basis
,
const
scalar_t
*
grad_out
,
const
scalar_t
*
x
,
const
scalar_t
*
basis
,
const
int64_t
*
weight_index
,
scalar_t
*
grad_x
,
const
int64_t
*
weight_index
,
scalar_t
*
grad_weight
,
int64_t
E
,
int64_t
M_in
,
int64_t
E
,
int64_t
M_in
,
int64_t
M_out
,
int64_t
M_out
,
int64_t
S
,
int64_t
numel
)
{
int64_t
S
,
int64_t
numel
)
{
const
int64_t
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int64_t
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int64_t
e
=
thread_idx
/
M_out
;
const
int64_t
e
=
thread_idx
/
M_out
;
...
@@ -198,15 +197,14 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
...
@@ -198,15 +197,14 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
spline_weighting_bw_basis_kernel
(
const
scalar_t
*
grad_out
,
const
scalar_t
*
x
,
__global__
void
spline_weighting_bw_basis_kernel
(
const
scalar_t
*
weight
,
const
scalar_t
*
grad_out
,
const
scalar_t
*
x
,
const
scalar_t
*
weight
,
const
int64_t
*
weight_index
,
const
int64_t
*
weight_index
,
scalar_t
*
grad_basis
,
int64_t
E
,
int64_t
M_in
,
scalar_t
*
grad_basis
,
int64_t
E
,
int64_t
M_in
,
int64_t
M_out
,
int64_t
S
,
int64_t
numel
)
{
int64_t
M_out
,
int64_t
S
,
int64_t
numel
)
{
const
size_t
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int64_t
e
=
i
/
M_out
;
const
int64_t
e
=
thread_idx
/
M_out
;
const
int64_t
m_out
=
i
%
M_out
;
const
int64_t
m_out
=
thread_idx
%
M_out
;
if
(
thread_idx
<
numel
)
{
if
(
thread_idx
<
numel
)
{
const
scalar_t
g
=
grad_out
[
e
*
M_out
+
m_out
];
const
scalar_t
g
=
grad_out
[
e
*
M_out
+
m_out
];
...
@@ -228,10 +226,10 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
...
@@ -228,10 +226,10 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch
::
Tensor
x
,
torch
::
Tensor
x
,
torch
::
Tensor
weight
,
torch
::
Tensor
weight
,
torch
::
Tensor
weight_index
)
{
torch
::
Tensor
weight_index
)
{
CHECK_C
P
U
(
grad_out
);
CHECK_CU
DA
(
grad_out
);
CHECK_C
P
U
(
x
);
CHECK_CU
DA
(
x
);
CHECK_C
P
U
(
weight
);
CHECK_CU
DA
(
weight
);
CHECK_C
P
U
(
weight_index
);
CHECK_CU
DA
(
weight_index
);
cudaSetDevice
(
grad_out
.
get_device
());
cudaSetDevice
(
grad_out
.
get_device
());
CHECK_INPUT
(
x
.
size
(
1
)
==
weight
.
size
(
1
));
CHECK_INPUT
(
x
.
size
(
1
)
==
weight
.
size
(
1
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment