Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
ff285368
Commit
ff285368
authored
Mar 25, 2021
by
rusty1s
Browse files
remove @torch.jit.script annotations (move jit compatibility to the test suite)
parent
3341dbeb
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
54 additions
and
47 deletions
+54
-47
test/composite/test_logsumexp.py
test/composite/test_logsumexp.py
+3
-0
test/composite/test_softmax.py
test/composite/test_softmax.py
+6
-0
test/composite/test_std.py
test/composite/test_std.py
+3
-0
test/test_scatter.py
test/test_scatter.py
+11
-5
test/test_segment.py
test/test_segment.py
+23
-11
torch_scatter/composite/logsumexp.py
torch_scatter/composite/logsumexp.py
+0
-1
torch_scatter/composite/softmax.py
torch_scatter/composite/softmax.py
+0
-2
torch_scatter/composite/std.py
torch_scatter/composite/std.py
+0
-1
torch_scatter/scatter.py
torch_scatter/scatter.py
+0
-6
torch_scatter/segment_coo.py
torch_scatter/segment_coo.py
+8
-14
torch_scatter/segment_csr.py
torch_scatter/segment_csr.py
+0
-6
torch_scatter/utils.py
torch_scatter/utils.py
+0
-1
No files found.
test/composite/test_logsumexp.py
View file @
ff285368
...
...
@@ -18,3 +18,6 @@ def test_logsumexp():
assert
out
.
tolist
()
==
torch
.
logsumexp
(
src
,
dim
=
0
).
tolist
()
outputs
.
backward
(
torch
.
randn_like
(
outputs
))
jit
=
torch
.
jit
.
script
(
scatter_logsumexp
)
assert
jit
(
inputs
,
index
).
tolist
()
==
outputs
.
tolist
()
test/composite/test_softmax.py
View file @
ff285368
...
...
@@ -22,6 +22,9 @@ def test_softmax():
out
.
backward
(
torch
.
randn_like
(
out
))
jit
=
torch
.
jit
.
script
(
scatter_softmax
)
assert
jit
(
src
,
index
).
tolist
()
==
out
.
tolist
()
def
test_log_softmax
():
src
=
torch
.
tensor
([
0.2
,
0
,
0.2
,
-
2.1
,
3.2
,
7
,
-
1
,
float
(
'-inf'
)])
...
...
@@ -42,3 +45,6 @@ def test_log_softmax():
assert
torch
.
allclose
(
out
,
expected
)
out
.
backward
(
torch
.
randn_like
(
out
))
jit
=
torch
.
jit
.
script
(
scatter_log_softmax
)
assert
jit
(
src
,
index
).
tolist
()
==
out
.
tolist
()
test/composite/test_std.py
View file @
ff285368
...
...
@@ -13,3 +13,6 @@ def test_std():
assert
torch
.
allclose
(
out
,
expected
)
out
.
backward
(
torch
.
randn_like
(
out
))
jit
=
torch
.
jit
.
script
(
scatter_std
)
assert
jit
(
src
,
index
,
dim
=-
1
,
unbiased
=
True
).
tolist
()
==
out
.
tolist
()
test/test_scatter.py
View file @
ff285368
...
...
@@ -99,12 +99,18 @@ def test_forward(test, reduce, dtype, device):
dim
=
test
[
'dim'
]
expected
=
tensor
(
test
[
reduce
],
dtype
,
device
)
out
=
getattr
(
torch_scatter
,
'scatter_'
+
reduce
)(
src
,
index
,
dim
)
if
isinstance
(
out
,
tuple
):
out
,
arg_out
=
out
fn
=
getattr
(
torch_scatter
,
'scatter_'
+
reduce
)
jit
=
torch
.
jit
.
script
(
fn
)
out1
=
fn
(
src
,
index
,
dim
)
out2
=
jit
(
src
,
index
,
dim
)
if
isinstance
(
out1
,
tuple
):
out1
,
arg_out1
=
out1
out2
,
arg_out2
=
out2
arg_expected
=
tensor
(
test
[
'arg_'
+
reduce
],
torch
.
long
,
device
)
assert
torch
.
all
(
arg_out
==
arg_expected
)
assert
torch
.
all
(
out
==
expected
)
assert
torch
.
all
(
arg_out1
==
arg_expected
)
assert
arg_out1
.
tolist
()
==
arg_out1
.
tolist
()
assert
torch
.
all
(
out1
==
expected
)
assert
out1
.
tolist
()
==
out2
.
tolist
()
@
pytest
.
mark
.
parametrize
(
'test,reduce,device'
,
...
...
test/test_segment.py
View file @
ff285368
...
...
@@ -91,19 +91,31 @@ def test_forward(test, reduce, dtype, device):
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
expected
=
tensor
(
test
[
reduce
],
dtype
,
device
)
out
=
getattr
(
torch_scatter
,
'segment_'
+
reduce
+
'_csr'
)(
src
,
indptr
)
if
isinstance
(
out
,
tuple
):
out
,
arg_out
=
out
fn
=
getattr
(
torch_scatter
,
'segment_'
+
reduce
+
'_csr'
)
jit
=
torch
.
jit
.
script
(
fn
)
out1
=
fn
(
src
,
indptr
)
out2
=
jit
(
src
,
indptr
)
if
isinstance
(
out1
,
tuple
):
out1
,
arg_out1
=
out1
out2
,
arg_out2
=
out2
arg_expected
=
tensor
(
test
[
'arg_'
+
reduce
],
torch
.
long
,
device
)
assert
torch
.
all
(
arg_out
==
arg_expected
)
assert
torch
.
all
(
out
==
expected
)
out
=
getattr
(
torch_scatter
,
'segment_'
+
reduce
+
'_coo'
)(
src
,
index
)
if
isinstance
(
out
,
tuple
):
out
,
arg_out
=
out
assert
torch
.
all
(
arg_out1
==
arg_expected
)
assert
arg_out1
.
tolist
()
==
arg_out2
.
tolist
()
assert
torch
.
all
(
out1
==
expected
)
assert
out1
.
tolist
()
==
out2
.
tolist
()
fn
=
getattr
(
torch_scatter
,
'segment_'
+
reduce
+
'_coo'
)
jit
=
torch
.
jit
.
script
(
fn
)
out1
=
fn
(
src
,
index
)
out2
=
jit
(
src
,
index
)
if
isinstance
(
out1
,
tuple
):
out1
,
arg_out1
=
out1
out2
,
arg_out2
=
out2
arg_expected
=
tensor
(
test
[
'arg_'
+
reduce
],
torch
.
long
,
device
)
assert
torch
.
all
(
arg_out
==
arg_expected
)
assert
torch
.
all
(
out
==
expected
)
assert
torch
.
all
(
arg_out1
==
arg_expected
)
assert
arg_out1
.
tolist
()
==
arg_out2
.
tolist
()
assert
torch
.
all
(
out1
==
expected
)
assert
out1
.
tolist
()
==
out2
.
tolist
()
@
pytest
.
mark
.
parametrize
(
'test,reduce,device'
,
...
...
torch_scatter/composite/logsumexp.py
View file @
ff285368
...
...
@@ -6,7 +6,6 @@ from torch_scatter import scatter_sum, scatter_max
from
torch_scatter.utils
import
broadcast
@
torch
.
jit
.
script
def
scatter_logsumexp
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
,
...
...
torch_scatter/composite/softmax.py
View file @
ff285368
...
...
@@ -4,7 +4,6 @@ from torch_scatter import scatter_sum, scatter_max
from
torch_scatter.utils
import
broadcast
@
torch
.
jit
.
script
def
scatter_softmax
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
eps
:
float
=
1e-12
)
->
torch
.
Tensor
:
if
not
torch
.
is_floating_point
(
src
):
...
...
@@ -25,7 +24,6 @@ def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
return
recentered_scores_exp
.
div
(
normalizing_constants
)
@
torch
.
jit
.
script
def
scatter_log_softmax
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
eps
:
float
=
1e-12
)
->
torch
.
Tensor
:
if
not
torch
.
is_floating_point
(
src
):
...
...
torch_scatter/composite/std.py
View file @
ff285368
...
...
@@ -5,7 +5,6 @@ from torch_scatter import scatter_sum
from
torch_scatter.utils
import
broadcast
@
torch
.
jit
.
script
def
scatter_std
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
,
...
...
torch_scatter/scatter.py
View file @
ff285368
...
...
@@ -5,7 +5,6 @@ import torch
from
.utils
import
broadcast
@
torch
.
jit
.
script
def
scatter_sum
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
...
...
@@ -24,21 +23,18 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
return
out
.
scatter_add_
(
dim
,
index
,
src
)
@
torch
.
jit
.
script
def
scatter_add
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
scatter_sum
(
src
,
index
,
dim
,
out
,
dim_size
)
@
torch
.
jit
.
script
def
scatter_mul
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
scatter_mul
(
src
,
index
,
dim
,
out
,
dim_size
)
@
torch
.
jit
.
script
def
scatter_mean
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
...
...
@@ -63,7 +59,6 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
return
out
@
torch
.
jit
.
script
def
scatter_min
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -71,7 +66,6 @@ def scatter_min(
return
torch
.
ops
.
torch_scatter
.
scatter_min
(
src
,
index
,
dim
,
out
,
dim_size
)
@
torch
.
jit
.
script
def
scatter_max
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
torch_scatter/segment_coo.py
View file @
ff285368
...
...
@@ -3,40 +3,35 @@ from typing import Optional, Tuple
import
torch
@
torch
.
jit
.
script
def
segment_sum_coo
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
segment_sum_coo
(
src
,
index
,
out
,
dim_size
)
@
torch
.
jit
.
script
def
segment_add_coo
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
segment_sum_coo
(
src
,
index
,
out
,
dim_size
)
@
torch
.
jit
.
script
def
segment_mean_coo
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
segment_mean_coo
(
src
,
index
,
out
,
dim_size
)
@
torch
.
jit
.
script
def
segment_min_coo
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
segment_min_coo
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
torch_scatter
.
segment_min_coo
(
src
,
index
,
out
,
dim_size
)
@
torch
.
jit
.
script
def
segment_max_coo
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
segment_max_coo
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
ops
.
torch_scatter
.
segment_max_coo
(
src
,
index
,
out
,
dim_size
)
...
...
@@ -137,7 +132,6 @@ def segment_coo(src: torch.Tensor, index: torch.Tensor,
raise
ValueError
@
torch
.
jit
.
script
def
gather_coo
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
gather_coo
(
src
,
index
,
out
)
torch_scatter/segment_csr.py
View file @
ff285368
...
...
@@ -3,25 +3,21 @@ from typing import Optional, Tuple
import
torch
@
torch
.
jit
.
script
def
segment_sum_csr
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
segment_sum_csr
(
src
,
indptr
,
out
)
@
torch
.
jit
.
script
def
segment_add_csr
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
segment_sum_csr
(
src
,
indptr
,
out
)
@
torch
.
jit
.
script
def
segment_mean_csr
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
segment_mean_csr
(
src
,
indptr
,
out
)
@
torch
.
jit
.
script
def
segment_min_csr
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
...
...
@@ -29,7 +25,6 @@ def segment_min_csr(
return
torch
.
ops
.
torch_scatter
.
segment_min_csr
(
src
,
indptr
,
out
)
@
torch
.
jit
.
script
def
segment_max_csr
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
...
...
@@ -114,7 +109,6 @@ def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
raise
ValueError
@
torch
.
jit
.
script
def
gather_csr
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
ops
.
torch_scatter
.
gather_csr
(
src
,
indptr
,
out
)
torch_scatter/utils.py
View file @
ff285368
import
torch
@
torch
.
jit
.
script
def
broadcast
(
src
:
torch
.
Tensor
,
other
:
torch
.
Tensor
,
dim
:
int
):
if
dim
<
0
:
dim
=
other
.
dim
()
+
dim
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment