Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
07b3b17c
Commit
07b3b17c
authored
Jan 15, 2020
by
Koch
Browse files
fix: use std::vector for dynamically sized array in gather.cpp
parent
b51c22a4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
cpu/gather.cpp
cpu/gather.cpp
+4
-2
No files found.
cpu/gather.cpp
View file @
07b3b17c
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
#include "compat.h"
#include "compat.h"
#include "index_info.h"
#include "index_info.h"
#include <vector>
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
at
::
Tensor
gather_csr
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
at
::
Tensor
gather_csr
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
...
@@ -43,7 +45,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
...
@@ -43,7 +45,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
scalar_t
vals
[
K
]
;
std
::
vector
<
scalar_t
>
vals
(
K
)
;
int64_t
row_start
,
row_end
;
int64_t
row_start
,
row_end
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
n
,
indptr_info
);
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
n
,
indptr_info
);
...
@@ -104,7 +106,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
...
@@ -104,7 +106,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
scalar_t
vals
[
K
]
;
std
::
vector
<
scalar_t
>
vals
(
K
)
;
int64_t
idx
,
next_idx
;
int64_t
idx
,
next_idx
;
for
(
int
e_1
=
0
;
e_1
<
E_1
;
e_1
++
)
{
for
(
int
e_1
=
0
;
e_1
<
E_1
;
e_1
++
)
{
int
offset
=
IndexToOffset
<
int64_t
>::
get
(
e_1
*
E_2
,
index_info
);
int
offset
=
IndexToOffset
<
int64_t
>::
get
(
e_1
*
E_2
,
index_info
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment