Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
124c8115
Commit 124c8115 authored Dec 28, 2019 by rusty1s
Browse files
no need to enforce stride=1
parent
d762df6a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing 2 changed files with 5 additions and 5 deletions
+5
-5
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+4
-3
test/test_segment.py
test/test_segment.py
+1
-2
No files found.
cuda/segment_kernel.cu
View file @
124c8115
...
...
@@ -41,6 +41,7 @@ template <typename T, typename I> struct IndexPtrToOffset {
static
__host__
__device__
I
get
(
I
idx
,
const
at
::
cuda
::
detail
::
TensorInfo
<
T
,
I
>
&
info
)
{
I
offset
=
idx
%
(
info
.
sizes
[
info
.
dims
-
1
]
-
1
);
offset
*=
info
.
strides
[
info
.
dims
-
1
];
idx
/=
info
.
sizes
[
info
.
dims
-
1
]
-
1
;
for
(
int
i
=
info
.
dims
-
2
;
i
>=
0
;
--
i
)
{
offset
+=
(
idx
%
info
.
sizes
[
i
])
*
info
.
strides
[
i
];
...
...
@@ -63,7 +64,8 @@ __global__ void segment_add_csr_kernel(
if
(
warp_idx
<
N
)
{
auto
offset
=
IndexPtrToOffset
<
int64_t
,
int
>::
get
(
warp_idx
,
indptr_info
);
int
row_start
=
__ldg
(
indptr_info
.
data
+
offset
);
int
row_end
=
__ldg
(
indptr_info
.
data
+
offset
+
1
);
int
row_end
=
__ldg
(
indptr_info
.
data
+
offset
+
indptr_info
.
strides
[
indptr_info
.
dims
-
1
]);
scalar_t
val
=
(
scalar_t
)
0
;
offset
=
(
warp_idx
/
(
indptr_info
.
sizes
[
indptr_info
.
dims
-
1
]
-
1
))
*
E
;
...
...
@@ -82,9 +84,8 @@ __global__ void segment_add_csr_kernel(
}
at
::
Tensor
segment_add_csr_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
)
{
src
=
src
.
contiguous
();
AT_ASSERTM
(
indptr
.
stride
(
-
1
)
==
1
);
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
());
src
=
src
.
contiguous
();
auto
reduce_dim
=
indptr
.
dim
()
-
1
;
auto
sizes
=
src
.
sizes
().
vec
();
...
...
test/test_segment.py
View file @
124c8115
...
...
@@ -25,8 +25,7 @@ def test_forward2(dtype, device):
src
=
tensor
([[
1
,
2
,
3
,
4
,
5
,
6
],
[
1
,
3
,
5
,
7
,
9
,
11
]],
dtype
,
device
)
indptr
=
tensor
([[
0
,
2
,
5
,
5
,
6
]],
torch
.
long
,
device
)
indptr
=
indptr
.
view
(
1
,
-
1
).
expand
(
2
,
-
1
)
assert
indptr
.
stride
(
-
1
)
==
1
indptr
=
indptr
.
view
(
1
,
-
1
).
expand
(
2
,
-
1
).
t
().
contiguous
().
t
()
out
=
segment_add_csr
(
src
,
indptr
)
print
(
'CSR'
,
out
)
...
...
Write
Preview
Markdown
is supported
0%
Try again or attach a new file.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment