Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
1b316a63
Commit
1b316a63
authored
Dec 13, 2019
by
rusty1s
Browse files
basic segment_add functionality
parent
d9565693
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
11 deletions
+21
-11
cuda/segment.cpp
cuda/segment.cpp
+5
-5
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+11
-2
test/test_segment.py
test/test_segment.py
+3
-2
torch_scatter/segment.py
torch_scatter/segment.py
+2
-2
No files found.
cuda/segment.cpp
View file @
1b316a63
...
...
@@ -2,15 +2,15 @@
// Abort unless the tensor lives on a CUDA device; the macro argument name is
// echoed into the error message via the stringizing operator.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

// Forward declaration of the CUDA entry point (defined in
// cuda/segment_kernel.cu). Returns the reduced output tensor together with
// the per-slot segment-key tensor produced by the reduction.
std::tuple<at::Tensor, at::Tensor> segment_add_cuda(at::Tensor src,
                                                    at::Tensor index,
                                                    at::Tensor out);

// Python-facing wrapper around segment_add_cuda.
//
// src:   values to be summed per segment.
// index: int64 segment ids for each entry of `src`. NOTE(review): the kernel
//        uses thrust::reduce_by_key, which only merges *adjacent* equal keys,
//        so `index` is presumably expected to be sorted — confirm with caller.
// out:   pre-allocated tensor the reduction writes into.
//
// Returns (out, key): the reduced values and the segment id written to each
// output slot. All three arguments must be CUDA tensors; a failed check
// raises via AT_ASSERTM.
std::tuple<at::Tensor, at::Tensor> segment_add(at::Tensor src, at::Tensor index,
                                               at::Tensor out) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  return segment_add_cuda(src, index, out);
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
...
...
cuda/segment_kernel.cu
View file @
1b316a63
...
...
@@ -8,16 +8,25 @@
#include "compat.cuh"
void
segment_add_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
int64_t
dim
)
{
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
segment_add_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
)
{
cudaSetDevice
(
src
.
get_device
());
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
allocator
=
THCThrustAllocator
(
at
::
globalContext
().
lazyInitCUDA
());
auto
policy
=
thrust
::
cuda
::
par
(
allocator
).
on
(
stream
);
auto
key
=
at
::
full_like
(
out
,
-
1
,
out
.
options
().
dtype
(
at
::
kLong
));
auto
index_data
=
thrust
::
device_ptr
<
int64_t
>
(
index
.
DATA_PTR
<
int64_t
>
());
auto
key_data
=
thrust
::
device_ptr
<
int64_t
>
(
key
.
DATA_PTR
<
int64_t
>
());
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"segment_add_kernel"
,
[
&
]
{
auto
src_data
=
thrust
::
device_ptr
<
scalar_t
>
(
src
.
DATA_PTR
<
scalar_t
>
());
auto
out_data
=
thrust
::
device_ptr
<
scalar_t
>
(
out
.
DATA_PTR
<
scalar_t
>
());
thrust
::
reduce_by_key
(
policy
,
index_data
,
index_data
+
index
.
size
(
0
),
src_data
,
key_data
,
out_data
);
});
return
std
::
make_tuple
(
out
,
key
);
}
test/test_segment.py
View file @
1b316a63
...
...
@@ -13,7 +13,8 @@ devices = [torch.device('cuda')]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    # Smoke test for the raw CUDA segment_add: entries of `src` sharing a
    # segment id in `index` are summed, and the per-slot segment ids are
    # returned alongside the sums.
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

    out, key = segment_add(src, index, dim=0)

    # NOTE(review): print-only for now — replace with assertions on the
    # expected sums/keys once the output layout is finalized.
    print(out)
    print(key)
torch_scatter/segment.py
View file @
1b316a63
...
...
@@ -11,5 +11,5 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
if
src
.
size
(
dim
)
==
0
:
# pragma: no cover
return
out
assert
src
.
is_cuda
torch_scatter
.
segment_cuda
.
segment_add
(
src
,
index
,
out
,
dim
)
return
out
out
,
key
=
torch_scatter
.
segment_cuda
.
segment_add
(
src
,
index
,
out
)
return
out
,
key
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment