Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
bf1f1014
Commit
bf1f1014
authored
Feb 03, 2020
by
rusty1s
Browse files
use scatter add pytorch implementation
parent
1006514c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
29 additions
and
9 deletions
+29
-9
benchmark/scatter_segment.py
benchmark/scatter_segment.py
+1
-1
test/test_zero_tensors.py
test/test_zero_tensors.py
+11
-0
torch_scatter/composite/logsumexp.py
torch_scatter/composite/logsumexp.py
+1
-1
torch_scatter/composite/softmax.py
torch_scatter/composite/softmax.py
+1
-2
torch_scatter/composite/std.py
torch_scatter/composite/std.py
+1
-2
torch_scatter/scatter.py
torch_scatter/scatter.py
+14
-3
torch_scatter/utils.py
torch_scatter/utils.py
+0
-0
No files found.
benchmark/scatter_segment.py
View file @
bf1f1014
...
...
@@ -217,7 +217,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--with_backward'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cuda'
)
args
=
parser
.
parse_args
()
iters
=
1
if
args
.
device
==
'cpu'
else
5
0
iters
=
1
if
args
.
device
==
'cpu'
else
2
0
sizes
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
sizes
=
sizes
[:
3
]
if
args
.
device
==
'cpu'
else
sizes
...
...
test/test_zero_tensors.py
0 → 100644
View file @
bf1f1014
import
torch
from
torch_scatter
import
scatter
def
test_zero_elements
():
x
=
torch
.
randn
(
0
,
16
)
index
=
torch
.
tensor
([]).
view
(
0
,
16
)
print
(
x
)
print
(
index
)
scatter
(
x
,
index
,
dim
=
0
,
dim_size
=
0
,
reduce
=
"add"
)
torch_scatter/composite/logsumexp.py
View file @
bf1f1014
...
...
@@ -3,7 +3,7 @@ from typing import Optional
import
torch
from
torch_scatter
import
scatter_sum
,
scatter_max
from
.utils
import
broadcast
from
torch_scatter
.utils
import
broadcast
@
torch
.
jit
.
script
...
...
torch_scatter/composite/softmax.py
View file @
bf1f1014
import
torch
from
torch_scatter
import
scatter_sum
,
scatter_max
from
.utils
import
broadcast
from
torch_scatter.utils
import
broadcast
@
torch
.
jit
.
script
...
...
torch_scatter/composite/std.py
View file @
bf1f1014
...
...
@@ -2,8 +2,7 @@ from typing import Optional
import
torch
from
torch_scatter
import
scatter_sum
from
.utils
import
broadcast
from
torch_scatter.utils
import
broadcast
@
torch
.
jit
.
script
...
...
torch_scatter/scatter.py
View file @
bf1f1014
...
...
@@ -4,6 +4,8 @@ from typing import Optional, Tuple
import
torch
from
.utils
import
broadcast
try
:
torch
.
ops
.
load_library
(
osp
.
join
(
osp
.
dirname
(
osp
.
abspath
(
__file__
)),
'_scatter.so'
))
...
...
@@ -23,7 +25,6 @@ except OSError:
raise
ImportError
return
src
,
index
torch
.
ops
.
torch_scatter
.
scatter_sum
=
scatter_placeholder
torch
.
ops
.
torch_scatter
.
scatter_mean
=
scatter_placeholder
torch
.
ops
.
torch_scatter
.
scatter_min
=
scatter_with_arg_placeholder
torch
.
ops
.
torch_scatter
.
scatter_max
=
scatter_with_arg_placeholder
...
...
@@ -33,14 +34,24 @@ except OSError:
def
scatter_sum
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
scatter_sum
(
src
,
index
,
dim
,
out
,
dim_size
)
index
=
broadcast
(
index
,
src
,
dim
)
if
out
is
None
:
size
=
src
.
size
()
if
dim_size
is
None
:
size
[
dim
]
=
int
(
index
.
max
())
+
1
else
:
size
[
dim
]
=
dim_size
out
=
src
.
new_zeros
(
size
)
return
out
.
scatter_add_
(
dim
,
index
,
src
)
else
:
return
out
.
scatter_add_
(
dim
,
index
,
src
)
@
torch
.
jit
.
script
def
scatter_add
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
scatter_sum
(
src
,
index
,
dim
,
out
,
dim_size
)
return
scatter_sum
(
src
,
index
,
dim
,
out
,
dim_size
)
@
torch
.
jit
.
script
...
...
torch_scatter/
composite/
utils.py
→
torch_scatter/utils.py
View file @
bf1f1014
File moved
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment