Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
f5cb51ae
Commit
f5cb51ae
authored
Jan 21, 2020
by
rusty1s
Browse files
stream to scatter kernels
parent
3cf59da2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
8 deletions
+8
-8
cuda/scatter_kernel.cu
cuda/scatter_kernel.cu
+8
-8
No files found.
cuda/scatter_kernel.cu
View file @
f5cb51ae
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
...
@@ -8,20 +9,23 @@
...
@@ -8,20 +9,23 @@
#define THREADS 1024
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLOCKS(N) (N + THREADS - 1) / THREADS
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
#define KERNEL_RUN(NAME, DIMS, N, ...) \
#define KERNEL_RUN(NAME, DIMS, N, ...) \
[&] { \
[&] { \
auto stream = at::cuda::getCurrentCUDAStream(); \
switch (DIMS) { \
switch (DIMS) { \
case 1: \
case 1: \
NAME<scalar_t, 1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N);
\
NAME<scalar_t, 1><<<BLOCKS(N), THREADS
, 0, stream
>>>(__VA_ARGS__, N); \
break; \
break; \
case 2: \
case 2: \
NAME<scalar_t, 2><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N);
\
NAME<scalar_t, 2><<<BLOCKS(N), THREADS
, 0, stream
>>>(__VA_ARGS__, N); \
break; \
break; \
case 3: \
case 3: \
NAME<scalar_t, 3><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N);
\
NAME<scalar_t, 3><<<BLOCKS(N), THREADS
, 0, stream
>>>(__VA_ARGS__, N); \
break; \
break; \
default: \
default: \
NAME<scalar_t, -1><<<BLOCKS(N), THREADS>>>(__VA_ARGS__, N);
\
NAME<scalar_t, -1><<<BLOCKS(N), THREADS
, 0, stream
>>>(__VA_ARGS__, N); \
} \
} \
}()
}()
...
@@ -43,7 +47,6 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
...
@@ -43,7 +47,6 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void
scatter_mul_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
void
scatter_mul_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
int64_t
dim
)
{
int64_t
dim
)
{
cudaSetDevice
(
src
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"scatter_mul_kernel"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"scatter_mul_kernel"
,
[
&
]
{
KERNEL_RUN
(
scatter_mul_kernel
,
index
.
dim
(),
index
.
numel
(),
KERNEL_RUN
(
scatter_mul_kernel
,
index
.
dim
(),
index
.
numel
(),
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
),
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
),
...
@@ -70,7 +73,6 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
...
@@ -70,7 +73,6 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void
scatter_div_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
void
scatter_div_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
int64_t
dim
)
{
int64_t
dim
)
{
cudaSetDevice
(
src
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"scatter_div_kernel"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"scatter_div_kernel"
,
[
&
]
{
KERNEL_RUN
(
scatter_div_kernel
,
index
.
dim
(),
index
.
numel
(),
KERNEL_RUN
(
scatter_div_kernel
,
index
.
dim
(),
index
.
numel
(),
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
),
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
),
...
@@ -116,7 +118,6 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
...
@@ -116,7 +118,6 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void
scatter_max_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
void
scatter_max_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
at
::
Tensor
arg
,
int64_t
dim
)
{
at
::
Tensor
arg
,
int64_t
dim
)
{
cudaSetDevice
(
src
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"scatter_max_kernel"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"scatter_max_kernel"
,
[
&
]
{
auto
src_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
);
auto
src_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
);
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int64_t
>
(
index
);
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int64_t
>
(
index
);
...
@@ -147,7 +148,6 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
...
@@ -147,7 +148,6 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
void
scatter_min_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
void
scatter_min_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
at
::
Tensor
arg
,
int64_t
dim
)
{
at
::
Tensor
arg
,
int64_t
dim
)
{
cudaSetDevice
(
src
.
get_device
());
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"scatter_min_kernel"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"scatter_min_kernel"
,
[
&
]
{
auto
src_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
);
auto
src_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
scalar_t
,
int64_t
>
(
src
);
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int64_t
>
(
index
);
auto
index_info
=
at
::
cuda
::
detail
::
getTensorInfo
<
int64_t
,
int64_t
>
(
index
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment