Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
ff3be8e3
Commit
ff3be8e3
authored
Jan 31, 2020
by
rusty1s
Browse files
pytorch 1.4 support: toIntVector -> to IntList
parent
7ef77d92
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
14 deletions
+29
-14
csrc/scatter.cpp
csrc/scatter.cpp
+5
-4
csrc/segment_coo.cpp
csrc/segment_coo.cpp
+6
-5
csrc/segment_csr.cpp
csrc/segment_csr.cpp
+6
-5
csrc/utils.h
csrc/utils.h
+12
-0
No files found.
csrc/scatter.cpp
View file @
ff3be8e3
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/scatter_cuda.h"
...
...
@@ -58,7 +59,7 @@ public:
auto
saved
=
ctx
->
get_saved_variables
();
auto
index
=
saved
[
0
];
auto
dim
=
ctx
->
saved_data
[
"dim"
].
toInt
();
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
auto
grad_in
=
torch
::
gather
(
grad_out
,
dim
,
index
,
false
);
return
{
grad_in
,
Variable
(),
Variable
(),
Variable
(),
Variable
()};
}
...
...
@@ -100,7 +101,7 @@ public:
auto
index
=
saved
[
0
];
auto
count
=
saved
[
1
];
auto
dim
=
ctx
->
saved_data
[
"dim"
].
toInt
();
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
count
=
torch
::
gather
(
count
,
dim
,
index
,
false
);
auto
grad_in
=
torch
::
gather
(
grad_out
,
dim
,
index
,
false
);
grad_in
.
div_
(
count
);
...
...
@@ -134,7 +135,7 @@ public:
auto
index
=
saved
[
0
];
auto
arg_out
=
saved
[
1
];
auto
dim
=
ctx
->
saved_data
[
"dim"
].
toInt
();
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
src_shape
[
dim
]
+=
1
;
auto
grad_in
=
torch
::
zeros
(
src_shape
,
grad_out
.
options
());
grad_in
.
scatter_
(
dim
,
arg_out
,
grad_out
);
...
...
@@ -169,7 +170,7 @@ public:
auto
index
=
saved
[
0
];
auto
arg_out
=
saved
[
1
];
auto
dim
=
ctx
->
saved_data
[
"dim"
].
toInt
();
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
src_shape
[
dim
]
+=
1
;
auto
grad_in
=
torch
::
zeros
(
src_shape
,
grad_out
.
options
());
grad_in
.
scatter_
(
dim
,
arg_out
,
grad_out
);
...
...
csrc/segment_coo.cpp
View file @
ff3be8e3
#include <torch/script.h>
#include "cpu/segment_coo_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/segment_coo_cuda.h"
...
...
@@ -57,7 +58,7 @@ public:
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
index
=
saved
[
0
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
auto
grad_in
=
torch
::
empty
(
src_shape
,
grad_out
.
options
());
gather_coo_fw
(
grad_out
,
index
,
grad_in
);
return
{
grad_in
,
Variable
(),
Variable
(),
Variable
()};
...
...
@@ -85,7 +86,7 @@ public:
auto
saved
=
ctx
->
get_saved_variables
();
auto
index
=
saved
[
0
];
auto
count
=
saved
[
1
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
auto
grad_in
=
torch
::
empty
(
src_shape
,
grad_out
.
options
());
gather_coo_fw
(
grad_out
,
index
,
grad_in
);
count
=
gather_coo_fw
(
count
,
index
,
torch
::
nullopt
);
...
...
@@ -118,7 +119,7 @@ public:
auto
saved
=
ctx
->
get_saved_variables
();
auto
index
=
saved
[
0
];
auto
arg_out
=
saved
[
1
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
src_shape
[
index
.
dim
()
-
1
]
+=
1
;
auto
grad_in
=
torch
::
zeros
(
src_shape
,
grad_out
.
options
());
grad_in
.
scatter_
(
index
.
dim
()
-
1
,
arg_out
,
grad_out
);
...
...
@@ -150,7 +151,7 @@ public:
auto
saved
=
ctx
->
get_saved_variables
();
auto
index
=
saved
[
0
];
auto
arg_out
=
saved
[
1
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
src_shape
[
index
.
dim
()
-
1
]
+=
1
;
auto
grad_in
=
torch
::
zeros
(
src_shape
,
grad_out
.
options
());
grad_in
.
scatter_
(
index
.
dim
()
-
1
,
arg_out
,
grad_out
);
...
...
@@ -177,7 +178,7 @@ public:
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
index
=
saved
[
0
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
auto
grad_in
=
torch
::
zeros
(
src_shape
,
grad_out
.
options
());
segment_coo_fw
(
grad_out
,
index
,
grad_in
,
torch
::
nullopt
,
"sum"
);
...
...
csrc/segment_csr.cpp
View file @
ff3be8e3
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/segment_csr_cuda.h"
...
...
@@ -55,7 +56,7 @@ public:
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
indptr
=
saved
[
0
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
auto
grad_in
=
torch
::
empty
(
src_shape
,
grad_out
.
options
());
gather_csr_fw
(
grad_out
,
indptr
,
grad_in
);
return
{
grad_in
,
Variable
(),
Variable
()};
...
...
@@ -79,7 +80,7 @@ public:
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
indptr
=
saved
[
0
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
auto
grad_in
=
torch
::
empty
(
src_shape
,
grad_out
.
options
());
gather_csr_fw
(
grad_out
,
indptr
,
grad_in
);
auto
indptr1
=
indptr
.
narrow
(
-
1
,
0
,
indptr
.
size
(
-
1
)
-
1
);
...
...
@@ -114,7 +115,7 @@ public:
auto
saved
=
ctx
->
get_saved_variables
();
auto
indptr
=
saved
[
0
];
auto
arg_out
=
saved
[
1
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
src_shape
[
indptr
.
dim
()
-
1
]
+=
1
;
auto
grad_in
=
torch
::
zeros
(
src_shape
,
grad_out
.
options
());
grad_in
.
scatter_
(
indptr
.
dim
()
-
1
,
arg_out
,
grad_out
);
...
...
@@ -145,7 +146,7 @@ public:
auto
saved
=
ctx
->
get_saved_variables
();
auto
indptr
=
saved
[
0
];
auto
arg_out
=
saved
[
1
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
src_shape
[
indptr
.
dim
()
-
1
]
+=
1
;
auto
grad_in
=
torch
::
zeros
(
src_shape
,
grad_out
.
options
());
grad_in
.
scatter_
(
indptr
.
dim
()
-
1
,
arg_out
,
grad_out
);
...
...
@@ -172,7 +173,7 @@ public:
auto
grad_out
=
grad_outs
[
0
];
auto
saved
=
ctx
->
get_saved_variables
();
auto
indptr
=
saved
[
0
];
auto
src_shape
=
ctx
->
saved_data
[
"src_shape"
].
toInt
Vector
(
);
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toInt
List
()
);
auto
grad_in
=
torch
::
empty
(
src_shape
,
grad_out
.
options
());
segment_csr_fw
(
grad_out
,
indptr
,
grad_in
,
"sum"
);
...
...
csrc/utils.h
0 → 100644
View file @
ff3be8e3
#pragma once
#include <torch/script.h>
#include <vector>
inline
std
::
vector
<
int64_t
>
list2vec
(
const
c10
::
List
<
int64_t
>
list
)
{
std
::
vector
<
int64_t
>
result
;
result
.
reserve
(
list
.
size
());
for
(
size_t
i
=
0
;
i
<
list
.
size
();
i
++
)
result
.
push_back
(
list
[
i
]);
return
result
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment