Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
a9f9266b
Commit
a9f9266b
authored
Jan 10, 2020
by
rusty1s
Browse files
index expand
parent
2743b291
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
5 deletions
+16
-5
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+16
-5
No files found.
cuda/segment_kernel.cu
View file @
a9f9266b
...
@@ -178,8 +178,13 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
...
@@ -178,8 +178,13 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at
::
optional
<
at
::
Tensor
>
out_opt
,
std
::
string
reduce
)
{
at
::
optional
<
at
::
Tensor
>
out_opt
,
std
::
string
reduce
)
{
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
(),
"Input mismatch"
);
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
(),
"Input mismatch"
);
for
(
int
i
=
0
;
i
<
indptr
.
dim
()
-
1
;
i
++
)
AT_ASSERTM
(
src
.
size
(
i
)
==
indptr
.
size
(
i
),
"Input mismatch"
);
// Broadcast `indptr` across the leading dimensions of `src` via `expand`.
auto
sizes
=
indptr
.
sizes
().
vec
();
for
(
int
i
=
0
;
i
<
indptr
.
dim
()
-
1
;
i
++
)
{
sizes
[
i
]
=
src
.
size
(
i
);
}
indptr
=
indptr
.
expand
(
sizes
);
src
=
src
.
contiguous
();
src
=
src
.
contiguous
();
auto
reduce_dim
=
indptr
.
dim
()
-
1
;
auto
reduce_dim
=
indptr
.
dim
()
-
1
;
...
@@ -193,7 +198,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
...
@@ -193,7 +198,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
AT_ASSERTM
(
out
.
size
(
reduce_dim
)
==
indptr
.
size
(
reduce_dim
)
-
1
,
AT_ASSERTM
(
out
.
size
(
reduce_dim
)
==
indptr
.
size
(
reduce_dim
)
-
1
,
"Input mismatch"
);
"Input mismatch"
);
}
else
{
}
else
{
auto
sizes
=
src
.
sizes
().
vec
();
sizes
=
src
.
sizes
().
vec
();
sizes
[
reduce_dim
]
=
indptr
.
size
(
reduce_dim
)
-
1
;
sizes
[
reduce_dim
]
=
indptr
.
size
(
reduce_dim
)
-
1
;
out
=
at
::
empty
(
sizes
,
src
.
options
());
out
=
at
::
empty
(
sizes
,
src
.
options
());
}
}
...
@@ -370,9 +375,15 @@ __global__ void segment_coo_arg_broadcast_kernel(
...
@@ -370,9 +375,15 @@ __global__ void segment_coo_arg_broadcast_kernel(
std
::
tuple
<
at
::
Tensor
,
at
::
optional
<
at
::
Tensor
>>
std
::
tuple
<
at
::
Tensor
,
at
::
optional
<
at
::
Tensor
>>
segment_coo_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
segment_coo_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
std
::
string
reduce
)
{
std
::
string
reduce
)
{
AT_ASSERTM
(
src
.
dim
()
>=
index
.
dim
(),
"Input mismatch"
);
AT_ASSERTM
(
src
.
dim
()
>=
index
.
dim
(),
"Input mismatch"
);
for
(
int
i
=
0
;
i
<
index
.
dim
();
i
++
)
AT_ASSERTM
(
src
.
size
(
i
)
==
index
.
size
(
i
),
"Input mismatch"
);
// Broadcasting across `index` via `expand`.
auto
sizes
=
index
.
sizes
().
vec
();
for
(
int
i
=
0
;
i
<
index
.
dim
();
i
++
)
{
sizes
[
i
]
=
src
.
size
(
i
);
}
index
=
index
.
expand
(
sizes
);
src
=
src
.
contiguous
();
src
=
src
.
contiguous
();
out
=
out
.
contiguous
();
out
=
out
.
contiguous
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment