Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
cebec48f
Commit
cebec48f
authored
Jan 09, 2020
by
rusty1s
Browse files
possible assertion fix
parent
f82bfbac
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
7 deletions
+8
-7
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+8
-7
No files found.
cuda/segment_kernel.cu
View file @
cebec48f
...
...
@@ -182,9 +182,9 @@ std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
at
::
optional
<
at
::
Tensor
>
out_opt
,
std
::
string
reduce
)
{
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
());
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
()
,
"Input mismatch"
);
for
(
int
i
=
0
;
i
<
indptr
.
dim
()
-
1
;
i
++
)
AT_ASSERTM
(
src
.
size
(
i
)
==
indptr
.
size
(
i
));
AT_ASSERTM
(
src
.
size
(
i
)
==
indptr
.
size
(
i
)
,
"Input mismatch"
);
src
=
src
.
contiguous
();
auto
reduce_dim
=
indptr
.
dim
()
-
1
;
...
...
@@ -194,8 +194,9 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
out
=
out_opt
.
value
().
contiguous
();
for
(
int
i
=
0
;
i
<
out
.
dim
();
i
++
)
if
(
i
!=
reduce_dim
)
AT_ASSERTM
(
src
.
size
(
i
)
==
out
.
size
(
i
));
AT_ASSERTM
(
out
.
size
(
reduce_dim
)
==
indptr
.
size
(
reduce_dim
)
-
1
);
AT_ASSERTM
(
src
.
size
(
i
)
==
out
.
size
(
i
),
"Input mismatch"
);
AT_ASSERTM
(
out
.
size
(
reduce_dim
)
==
indptr
.
size
(
reduce_dim
)
-
1
,
"Input
mismatch"
);
}
else
{
auto
sizes
=
src
.
sizes
().
vec
();
sizes
[
reduce_dim
]
=
indptr
.
size
(
reduce_dim
)
-
1
;
...
...
@@ -341,9 +342,9 @@ __global__ void segment_coo_broadcast_kernel(
std
::
tuple
<
at
::
Tensor
,
at
::
optional
<
at
::
Tensor
>>
segment_coo_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
,
std
::
string
reduce
)
{
AT_ASSERTM
(
src
.
dim
()
>=
index
.
dim
());
AT_ASSERTM
(
src
.
dim
()
>=
index
.
dim
()
,
"Input mismatch"
);
for
(
int
i
=
0
;
i
<
index
.
dim
();
i
++
)
AT_ASSERTM
(
src
.
size
(
i
)
==
index
.
size
(
i
));
AT_ASSERTM
(
src
.
size
(
i
)
==
index
.
size
(
i
)
,
"Input mismatch"
);
src
=
src
.
contiguous
();
out
=
out
.
contiguous
();
...
...
@@ -351,7 +352,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
for
(
int
i
=
0
;
i
<
out
.
dim
();
i
++
)
if
(
i
!=
reduce_dim
)
AT_ASSERTM
(
src
.
size
(
i
)
==
out
.
size
(
i
));
AT_ASSERTM
(
src
.
size
(
i
)
==
out
.
size
(
i
)
,
,
"Input mismatch"
);
at
::
optional
<
at
::
Tensor
>
arg_out
=
at
::
nullopt
;
int64_t
*
arg_out_data
=
nullptr
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment