Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
f98ff7e8
Commit
f98ff7e8
authored
Jan 09, 2020
by
rusty1s
Browse files
possible assertion fix
parent
4f6fe911
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
8 deletions
+9
-8
cuda/gather_kernel.cu
cuda/gather_kernel.cu
+9
-8
No files found.
cuda/gather_kernel.cu
View file @
f98ff7e8
...
@@ -61,20 +61,21 @@ __global__ void gather_csr_broadcast_kernel(
...
@@ -61,20 +61,21 @@ __global__ void gather_csr_broadcast_kernel(
at
::
Tensor
gather_csr_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
at
::
Tensor
gather_csr_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
at
::
optional
<
at
::
Tensor
>
out_opt
)
{
at
::
optional
<
at
::
Tensor
>
out_opt
)
{
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
());
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
()
,
"Input mismatch"
);
for
(
int
i
=
0
;
i
<
indptr
.
dim
()
-
1
;
i
++
)
for
(
int
i
=
0
;
i
<
indptr
.
dim
()
-
1
;
i
++
)
AT_ASSERTM
(
src
.
size
(
i
)
==
indptr
.
size
(
i
));
AT_ASSERTM
(
src
.
size
(
i
)
==
indptr
.
size
(
i
)
,
"Input mismatch"
);
src
=
src
.
contiguous
();
src
=
src
.
contiguous
();
auto
gather_dim
=
indptr
.
dim
()
-
1
;
auto
gather_dim
=
indptr
.
dim
()
-
1
;
AT_ASSERTM
(
src
.
size
(
gather_dim
)
==
indptr
.
size
(
gather_dim
)
-
1
);
AT_ASSERTM
(
src
.
size
(
gather_dim
)
==
indptr
.
size
(
gather_dim
)
-
1
,
"Input mismatch"
);
at
::
Tensor
out
;
at
::
Tensor
out
;
if
(
out_opt
.
has_value
())
{
if
(
out_opt
.
has_value
())
{
out
=
out_opt
.
value
().
contiguous
();
out
=
out_opt
.
value
().
contiguous
();
for
(
int
i
=
0
;
i
<
out
.
dim
();
i
++
)
for
(
int
i
=
0
;
i
<
out
.
dim
();
i
++
)
if
(
i
!=
gather_dim
)
if
(
i
!=
gather_dim
)
AT_ASSERTM
(
src
.
size
(
i
)
==
out
.
size
(
i
));
AT_ASSERTM
(
src
.
size
(
i
)
==
out
.
size
(
i
)
,
"Input mismatch"
);
}
else
{
}
else
{
auto
d_gather_size
=
indptr
.
flatten
()[
-
1
].
DATA_PTR
<
int64_t
>
();
auto
d_gather_size
=
indptr
.
flatten
()[
-
1
].
DATA_PTR
<
int64_t
>
();
auto
h_gather_size
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
auto
h_gather_size
=
(
int64_t
*
)
malloc
(
sizeof
(
int64_t
));
...
@@ -154,9 +155,9 @@ __global__ void gather_coo_broadcast_kernel(
...
@@ -154,9 +155,9 @@ __global__ void gather_coo_broadcast_kernel(
at
::
Tensor
gather_coo_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
gather_coo_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
optional
<
at
::
Tensor
>
out_opt
)
{
at
::
optional
<
at
::
Tensor
>
out_opt
)
{
AT_ASSERTM
(
src
.
dim
()
>=
index
.
dim
());
AT_ASSERTM
(
src
.
dim
()
>=
index
.
dim
()
,
"Input mismatch"
);
for
(
int
i
=
0
;
i
<
index
.
dim
()
-
1
;
i
++
)
for
(
int
i
=
0
;
i
<
index
.
dim
()
-
1
;
i
++
)
AT_ASSERTM
(
src
.
size
(
i
)
==
index
.
size
(
i
));
AT_ASSERTM
(
src
.
size
(
i
)
==
index
.
size
(
i
)
,
"Input mismatch"
);
src
=
src
.
contiguous
();
src
=
src
.
contiguous
();
auto
gather_dim
=
index
.
dim
()
-
1
;
auto
gather_dim
=
index
.
dim
()
-
1
;
...
@@ -165,9 +166,9 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
...
@@ -165,9 +166,9 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
if
(
out_opt
.
has_value
())
{
if
(
out_opt
.
has_value
())
{
out
=
out_opt
.
value
().
contiguous
();
out
=
out_opt
.
value
().
contiguous
();
for
(
int
i
=
0
;
i
<
index
.
dim
();
i
++
)
for
(
int
i
=
0
;
i
<
index
.
dim
();
i
++
)
AT_ASSERTM
(
out
.
size
(
i
)
==
index
.
size
(
i
));
AT_ASSERTM
(
out
.
size
(
i
)
==
index
.
size
(
i
)
,
"Input mismatch"
);
for
(
int
i
=
index
.
dim
()
+
1
;
i
<
src
.
dim
();
i
++
)
for
(
int
i
=
index
.
dim
()
+
1
;
i
<
src
.
dim
();
i
++
)
AT_ASSERTM
(
out
.
size
(
i
)
==
src
.
size
(
i
));
AT_ASSERTM
(
out
.
size
(
i
)
==
src
.
size
(
i
)
,
"Input mismatch"
);
}
else
{
}
else
{
auto
sizes
=
src
.
sizes
().
vec
();
auto
sizes
=
src
.
sizes
().
vec
();
sizes
[
gather_dim
]
=
index
.
size
(
gather_dim
);
sizes
[
gather_dim
]
=
index
.
size
(
gather_dim
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment