Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
8e464c16
Commit
8e464c16
authored
Feb 28, 2020
by
rusty1s
Browse files
complete spline cuda
parent
751593df
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
115 additions
and
99 deletions
+115
-99
csrc/cuda/weighting_cuda.cu
csrc/cuda/weighting_cuda.cu
+115
-3
csrc/cuda/weighting_kernel.cu
csrc/cuda/weighting_kernel.cu
+0
-96
No files found.
csrc/cuda/weighting_cuda.cu
View file @
8e464c16
...
@@ -19,7 +19,7 @@ spline_weighting_fw_kernel(const scalar_t *x, const scalar_t *weight,
...
@@ -19,7 +19,7 @@ spline_weighting_fw_kernel(const scalar_t *x, const scalar_t *weight,
const
int64_t
m_out
=
thread_idx
%
M_out
;
const
int64_t
m_out
=
thread_idx
%
M_out
;
if
(
thread_idx
<
numel
)
{
if
(
thread_idx
<
numel
)
{
scalar_t
v
=
0
;
scalar_t
v
=
(
scalar_t
)
0.
;
for
(
ptrdiff_t
s
=
0
;
s
<
S
;
s
++
)
{
for
(
ptrdiff_t
s
=
0
;
s
<
S
;
s
++
)
{
const
scalar_t
b
=
basis
[
e
*
S
+
s
];
const
scalar_t
b
=
basis
[
e
*
S
+
s
];
...
@@ -116,6 +116,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
...
@@ -116,6 +116,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
auto
S
=
basis
.
size
(
1
);
auto
S
=
basis
.
size
(
1
);
auto
grad_x
=
at
::
zeros
({
E
,
M_in
},
grad_out
.
options
());
auto
grad_x
=
at
::
zeros
({
E
,
M_in
},
grad_out
.
options
());
weight
=
weight
.
transpose
(
1
,
2
).
contiguous
();
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
...
@@ -135,17 +136,128 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
...
@@ -135,17 +136,128 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
return
grad_x
;
return
grad_x
;
}
}
template
<
typename
scalar_t
>
spline_weighting_bw_weight_kernel
(
const
scalar_t
*
grad_out
,
const
scalar_t
*
x
,
const
scalar_t
*
basis
,
const
int64_t
*
weight_index
,
scalar_t
*
grad_x
,
int64_t
E
,
int64_t
M_in
,
int64_t
M_out
,
int64_t
S
,
int64_t
numel
)
{
const
int64_t
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int64_t
e
=
thread_idx
/
M_out
;
const
int64_t
m_out
=
thread_idx
%
M_out
;
if
(
thread_idx
<
numel
)
{
auto
g
=
grad_out
[
e
*
M_out
+
m_out
];
for
(
int64_t
s
=
0
;
s
<
S
;
s
++
)
{
const
scalar_t
b
=
basis
[
e
*
S
+
s
];
const
int64_t
wi
=
weight_index
[
e
*
S
+
s
];
for
(
int64_t
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
auto
v
=
g
*
b
*
x
[
e
*
M_in
+
m_in
];
atomicAdd
(
&
grad_weight
[
wi
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
],
v
);
}
}
}
}
// Host launcher for the weight-gradient kernel.
//
// Validates that all inputs live on the GPU, selects the device of grad_out,
// allocates a zero-initialized grad_weight of shape
// (kernel_size, M_in, M_out), and launches one thread per element of
// grad_out (E * M_out threads). Returns grad_weight.
//
// NOTE(review): unlike spline_weighting_bw_basis_cuda, no shape
// cross-checks (CHECK_INPUT) are performed here — confirm callers guarantee
// consistent E/M_in/M_out/S across the inputs.
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
                                              torch::Tensor x,
                                              torch::Tensor basis,
                                              torch::Tensor weight_index,
                                              int64_t kernel_size) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  // zeros (not empty): the kernel accumulates with atomicAdd.
  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
    spline_weighting_bw_weight_kernel<scalar_t>
        <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
            grad_out_data, x_data, basis_data, weight_index_data,
            grad_weight_data, E, M_in, M_out, S, grad_out.numel());
  });

  return grad_weight;
}
template
<
typename
scalar_t
>
spline_weighting_bw_basis_kernel
(
const
scalar_t
*
grad_out
,
const
scalar_t
*
x
,
const
scalar_t
*
weight
,
const
int64_t
*
weight_index
,
scalar_t
*
grad_basis
,
int64_t
E
,
int64_t
M_in
,
int64_t
M_out
,
int64_t
S
,
int64_t
numel
)
{
const
size_t
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int64_t
e
=
i
/
M_out
;
const
int64_t
m_out
=
i
%
M_out
;
if
(
thread_idx
<
numel
)
{
const
scalar_t
g
=
grad_out
[
e
*
M_out
+
m_out
];
for
(
int64_t
s
=
0
;
s
<
S
;
s
++
)
{
scalar_t
v
=
(
scalar_t
)
0.
;
const
int64_t
wi
=
weight_index
[
e
*
S
+
s
];
for
(
int64_t
m_in
=
0
;
m_in
<
M_in
;
m_in
++
)
{
const
scalar_t
w
=
weight
[
wi
*
M_in
*
M_out
+
m_in
*
M_out
+
m_out
];
v
+=
g
*
w
*
x
[
e
*
M_in
+
m_in
];
}
atomicAdd
(
&
grad_basis
[
e
*
S
+
s
],
v
);
}
}
}
}
// Host launcher for the basis-gradient kernel.
//
// Validates that all inputs live on the GPU (fixed: the mangled source used
// CHECK_CPU here, which contradicts the following cudaSetDevice and the
// sibling CUDA launchers), cross-checks the channel dimensions against the
// weight tensor, allocates a zero-initialized grad_basis of shape (E, S),
// and launches one thread per element of grad_out. Returns grad_basis.
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor weight,
                                             torch::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(weight_index);
  cudaSetDevice(grad_out.get_device());

  CHECK_INPUT(x.size(1) == weight.size(1));
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = weight_index.size(1);

  // zeros (not empty): the kernel accumulates with atomicAdd.
  auto grad_basis = at::zeros({E, S}, grad_out.options());
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
    spline_weighting_bw_basis_kernel<scalar_t>
        <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
            grad_out_data, x_data, weight_data, weight_index_data,
            grad_basis_data, E, M_in, M_out, S, grad_out.numel());
  });

  return grad_basis;
}
csrc/cuda/weighting_kernel.cu
View file @
8e464c16
...
@@ -5,102 +5,6 @@
...
@@ -5,102 +5,6 @@
#define THREADS 1024
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLOCKS(N) (N + THREADS - 1) / THREADS
// Forward kernel: for every (edge, output-channel) element of `out`,
// contracts x with the spline-indexed weight slices:
//   out[e, m_out] = sum_s basis[e, s] *
//                   sum_{m_in} weight[wi(e,s), m_in, m_out] * x[e, m_in]
// Iterates all elements of `out` with a block/grid-strided loop, so any
// grid size covers the full range.
template <typename scalar_t>
__global__ void
weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                    size_t numel) {
  const size_t first = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t step = blockDim.x * gridDim.x;

  for (ptrdiff_t idx = first; idx < numel; idx += step) {
    const int64_t e = idx / out.sizes[1];
    const int64_t m_out = idx % out.sizes[1];
    const auto S = basis.sizes[1];

    scalar_t acc = 0;
    for (ptrdiff_t s = 0; s < S; s++) {
      const auto b = basis.data[e * S + s];
      const auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_in = 0; m_in < x.sizes[1]; m_in++) {
        auto contrib = weight.data[wi * weight.strides[0] +
                                   m_in * weight.strides[1] +
                                   m_out * weight.strides[2]];
        contrib *= b * x.data[e * x.strides[0] + m_in * x.strides[1]];
        acc += contrib;
      }
    }
    out.data[idx] = acc;
  }
}
// Host launcher for the forward pass: allocates out with shape (E, M_out)
// and runs weighting_fw_kernel over out.numel() elements.
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                             at::Tensor weight_index) {
  cudaSetDevice(x.get_device());

  const auto E = x.size(0);
  const auto M_out = weight.size(2);
  auto out = at::empty({E, M_out}, x.options());

  AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "weighting_fw", [&] {
    weighting_fw_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        out.numel());
  });

  return out;
}
// Backward kernel w.r.t. x: for every (edge, input-channel) element of
// grad_x, contracts grad_out with the (pre-transposed) weight:
//   grad_x[e, m_in] = sum_s basis[e, s] *
//                     sum_{m_out} weight[wi(e,s), m_out, m_in] *
//                                 grad_out[e, m_out]
// NOTE: `weight` is expected with its last two dims swapped relative to the
// forward pass — the launcher transposes it before the call (see the
// m_out * strides[1] + m_in * strides[2] indexing below).
// Uses a block/grid-strided loop to cover all elements of grad_x.
template <typename scalar_t>
__global__ void weighting_bw_x_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
    size_t numel) {
  const size_t first = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t step = blockDim.x * gridDim.x;

  for (ptrdiff_t idx = first; idx < numel; idx += step) {
    const int64_t e = idx / grad_x.sizes[1];
    const int64_t m_in = idx % grad_x.sizes[1];
    const auto S = basis.sizes[1];

    scalar_t acc = 0;
    for (ptrdiff_t s = 0; s < S; s++) {
      const auto b = basis.data[e * S + s];
      const auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_out = 0; m_out < grad_out.sizes[1]; m_out++) {
        auto contrib = weight.data[wi * weight.strides[0] +
                                   m_out * weight.strides[1] +
                                   m_in * weight.strides[2]];
        contrib *= b * grad_out.data[e * grad_out.strides[0] +
                                     m_out * grad_out.strides[1]];
        acc += contrib;
      }
    }
    grad_x.data[idx] = acc;
  }
}
// Host launcher for the x-gradient: allocates grad_x with shape (E, M_in),
// transposes the last two dims of weight (so the kernel reads it as
// (K, M_out, M_in) contiguous), and runs weighting_bw_x_kernel over
// grad_x.numel() elements.
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                               at::Tensor basis, at::Tensor weight_index) {
  cudaSetDevice(grad_out.get_device());

  const auto E = grad_out.size(0);
  const auto M_in = weight.size(1);
  auto grad_x = at::empty({E, M_in}, grad_out.options());

  // Swap (M_in, M_out) -> (M_out, M_in) to match the kernel's indexing.
  weight = weight.transpose(1, 2).contiguous();

  AT_DISPATCH_FLOATING_TYPES(grad_x.scalar_type(), "weighting_bw_x", [&] {
    weighting_bw_x_kernel<scalar_t><<<BLOCKS(grad_x.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_x.numel());
  });

  return grad_x;
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
weighting_bw_w_kernel
(
__global__
void
weighting_bw_w_kernel
(
at
::
cuda
::
detail
::
TensorInfo
<
scalar_t
,
int64_t
>
grad_weight
,
at
::
cuda
::
detail
::
TensorInfo
<
scalar_t
,
int64_t
>
grad_weight
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment