Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
4a119480
"tests/vscode:/vscode.git/clone" did not exist on "d2e7a19fd52e13ad2d9036eefd4c7bafbb8c4303"
Commit
4a119480
authored
May 07, 2019
by
rusty1s
Browse files
fixed fill value for min and max
parent
e4d17fe8
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
47 additions
and
11 deletions
+47
-11
test/test_max_min.py
test/test_max_min.py
+22
-0
torch_scatter/add.py
torch_scatter/add.py
+1
-1
torch_scatter/div.py
torch_scatter/div.py
+1
-1
torch_scatter/max.py
torch_scatter/max.py
+9
-3
torch_scatter/mean.py
torch_scatter/mean.py
+1
-1
torch_scatter/min.py
torch_scatter/min.py
+11
-3
torch_scatter/mul.py
torch_scatter/mul.py
+1
-1
torch_scatter/sub.py
torch_scatter/sub.py
+1
-1
No files found.
test/test_max_min.py
0 → 100644
View file @
4a119480
import
torch
from
torch_scatter
import
scatter_max
,
scatter_min
def
test_max_fill_value
():
src
=
torch
.
Tensor
([[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]])
index
=
torch
.
tensor
([[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]])
out
,
_
=
scatter_max
(
src
,
index
)
v
=
torch
.
finfo
(
torch
.
float
).
min
assert
out
.
tolist
()
==
[[
v
,
v
,
4
,
3
,
2
,
0
],
[
2
,
4
,
3
,
v
,
v
,
v
]]
def
test_min_fill_value
():
src
=
torch
.
Tensor
([[
-
2
,
0
,
-
1
,
-
4
,
-
3
],
[
0
,
-
2
,
-
1
,
-
3
,
-
4
]])
index
=
torch
.
tensor
([[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]])
out
,
_
=
scatter_min
(
src
,
index
)
v
=
torch
.
finfo
(
torch
.
float
).
max
assert
out
.
tolist
()
==
[[
v
,
v
,
-
4
,
-
3
,
-
2
,
0
],
[
-
2
,
-
4
,
-
3
,
v
,
v
,
v
]]
torch_scatter/add.py
View file @
4a119480
...
...
@@ -56,7 +56,7 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
from torch_scatter import scatter_add
src = torch.
t
ensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
src = torch.
T
ensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
...
...
torch_scatter/div.py
View file @
4a119480
...
...
@@ -75,7 +75,7 @@ def scatter_div(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
from torch_scatter import scatter_div
src = torch.
t
ensor([[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]]).float()
src = torch.
T
ensor([[2, 1, 1, 4, 2], [1, 2, 1, 2, 4]]).float()
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_ones((2, 6))
...
...
torch_scatter/max.py
View file @
4a119480
import
torch
from
torch.autograd
import
Function
from
torch_scatter.utils.ext
import
get_func
...
...
@@ -30,7 +31,7 @@ class ScatterMax(Function):
return
None
,
grad_src
,
None
,
None
def
scatter_max
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
0
):
def
scatter_max
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
None
):
r
"""
|
...
...
@@ -67,7 +68,9 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the smallest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: (:class:`Tensor`, :class:`LongTensor`)
...
...
@@ -79,7 +82,7 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
from torch_scatter import scatter_max
src = torch.
t
ensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
src = torch.
T
ensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
...
...
@@ -95,6 +98,9 @@ def scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
tensor([[-1, -1, 3, 4, 0, 1],
[ 1, 4, 3, -1, -1, -1]])
"""
if
fill_value
is
None
:
op
=
torch
.
finfo
if
torch
.
is_floating_point
(
src
)
else
torch
.
iinfo
fill_value
=
op
(
src
.
dtype
).
min
src
,
out
,
index
,
dim
=
gen
(
src
,
index
,
dim
,
out
,
dim_size
,
fill_value
)
if
src
.
size
(
dim
)
==
0
:
# pragma: no cover
return
out
,
index
.
new_full
(
out
.
size
(),
-
1
)
...
...
torch_scatter/mean.py
View file @
4a119480
...
...
@@ -52,7 +52,7 @@ def scatter_mean(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
from torch_scatter import scatter_mean
src = torch.
t
ensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
.float()
src = torch.
T
ensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
...
...
torch_scatter/min.py
View file @
4a119480
import
torch
from
torch.autograd
import
Function
from
torch_scatter.utils.ext
import
get_func
...
...
@@ -30,7 +31,7 @@ class ScatterMin(Function):
return
None
,
grad_src
,
None
,
None
def
scatter_min
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
0
):
def
scatter_min
(
src
,
index
,
dim
=-
1
,
out
=
None
,
dim_size
=
None
,
fill_value
=
None
):
r
"""
|
...
...
@@ -67,7 +68,11 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
If :attr:`dim_size` is not given, a minimal sized output tensor is
returned. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. (default: :obj:`0`)
fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
fill_value (int, optional): If :attr:`out` is not given, automatically
fill output tensor with :attr:`fill_value`. If set to :obj:`None`,
the output tensor is filled with the greatest possible value of
:obj:`src.dtype`. (default: :obj:`None`)
:rtype: (:class:`Tensor`, :class:`LongTensor`)
...
...
@@ -79,7 +84,7 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
from torch_scatter import scatter_min
src = torch.
t
ensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
src = torch.
T
ensor([[-2, 0, -1, -4, -3], [0, -2, -1, -3, -4]])
index = torch.tensor([[ 4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
...
...
@@ -95,6 +100,9 @@ def scatter_min(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
tensor([[-1, -1, 3, 4, 0, 1],
[ 1, 4, 3, -1, -1, -1]])
"""
if
fill_value
is
None
:
op
=
torch
.
finfo
if
torch
.
is_floating_point
(
src
)
else
torch
.
iinfo
fill_value
=
op
(
src
.
dtype
).
max
src
,
out
,
index
,
dim
=
gen
(
src
,
index
,
dim
,
out
,
dim_size
,
fill_value
)
if
src
.
size
(
dim
)
==
0
:
# pragma: no cover
return
out
,
index
.
new_full
(
out
.
size
(),
-
1
)
...
...
torch_scatter/mul.py
View file @
4a119480
...
...
@@ -74,7 +74,7 @@ def scatter_mul(src, index, dim=-1, out=None, dim_size=None, fill_value=1):
from torch_scatter import scatter_mul
src = torch.
t
ensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]])
src = torch.
T
ensor([[2, 0, 3, 4, 3], [2, 3, 4, 2, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_ones((2, 6))
...
...
torch_scatter/sub.py
View file @
4a119480
...
...
@@ -48,7 +48,7 @@ def scatter_sub(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
from torch_scatter import scatter_sub
src = torch.
t
ensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
src = torch.
T
ensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment