Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
f0fdfe20
Commit
f0fdfe20
authored
Feb 11, 2018
by
rusty1s
Browse files
correct gradient computations
parent
43432f49
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
180 additions
and
84 deletions
+180
-84
torch_scatter/__init__.py
torch_scatter/__init__.py
+1
-1
torch_scatter/functions/div.py
torch_scatter/functions/div.py
+15
-2
torch_scatter/functions/ffi.py
torch_scatter/functions/ffi.py
+46
-0
torch_scatter/functions/max.py
torch_scatter/functions/max.py
+16
-2
torch_scatter/functions/mean.py
torch_scatter/functions/mean.py
+18
-3
torch_scatter/functions/min.py
torch_scatter/functions/min.py
+16
-2
torch_scatter/functions/mul.py
torch_scatter/functions/mul.py
+15
-2
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+53
-72
No files found.
torch_scatter/__init__.py
View file @
f0fdfe20
...
...
@@ -6,7 +6,7 @@ from .functions.mean import scatter_mean_, scatter_mean
from
.functions.max
import
scatter_max_
,
scatter_max
from
.functions.min
import
scatter_min_
,
scatter_min
__version__
=
'0.
2.3
'
__version__
=
'0.
3.0
'
__all__
=
[
'scatter_add_'
,
'scatter_add'
,
'scatter_sub_'
,
'scatter_sub'
,
...
...
torch_scatter/functions/div.py
View file @
f0fdfe20
from
.scatter
import
scatter
from
.scatter
import
Scatter
,
scatter
from
.utils
import
gen_output
class
ScatterDiv
(
Scatter
):
def
__init__
(
self
,
dim
):
super
(
ScatterDiv
,
self
).
__init__
(
'div'
,
dim
)
def
save_for_backward_step
(
self
,
*
data
):
output
,
index
,
input
=
data
self
.
save_for_backward
(
output
,
index
,
input
)
def
backward_step
(
self
,
*
data
):
grad
,
output
,
index
,
input
=
data
return
(
grad
/
output
.
data
).
gather
(
self
.
dim
,
index
.
data
)
*
input
.
data
def
scatter_div_
(
output
,
index
,
input
,
dim
=
0
):
r
"""
|
...
...
@@ -53,7 +66,7 @@ def scatter_div_(output, index, input, dim=0):
0.5000 0.2500 0.1667 1.0000 1.0000 1.0000
[torch.FloatTensor of size 2x6]
"""
return
scatter
(
'div'
,
dim
,
output
,
index
,
input
)
return
scatter
(
ScatterDiv
,
'div'
,
dim
,
output
,
index
,
input
)
def
scatter_div
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
1
):
...
...
torch_scatter/functions/ffi.py
0 → 100644
View file @
f0fdfe20
from
itertools
import
chain
from
.._ext
import
ffi
def
scatter
(
name
,
dim
,
*
data
):
# data = output, index, input, additional data
a
,
b
,
c
=
data
[:
3
]
# Assert index dimension is valid.
assert
dim
>=
0
and
dim
<
b
.
dim
(),
'Index dimension is out of bounds'
# Assert same dimensionality across all inputs.
assert
b
.
dim
()
==
c
.
dim
(),
(
'Index tensor must have same dimensions as '
'input tensor'
)
assert
a
.
dim
()
==
c
.
dim
(),
(
'Input tensor must have same dimensions as '
'output tensor'
)
# Assert same tensor length across index and input.
assert
b
.
numel
()
==
c
.
numel
(),
(
'Index tensor must have same size as '
'input tensor'
)
# Assert same tensor sizes across input and output apart from `dim`.
for
d
in
chain
(
range
(
dim
),
range
(
dim
+
1
,
a
.
dim
())):
assert
a
.
size
(
d
)
==
c
.
size
(
d
),
(
'Input tensor must have same size as output tensor apart from the '
'specified dimension'
)
typename
=
type
(
data
[
0
]).
__name__
.
replace
(
'Tensor'
,
''
)
cuda
=
'cuda_'
if
data
[
0
].
is_cuda
else
''
func
=
getattr
(
ffi
,
'scatter_{}_{}{}'
.
format
(
name
,
cuda
,
typename
))
func
(
dim
,
*
data
)
if
len
(
data
)
<=
3
:
return
data
[
0
]
return
(
data
[
0
],
)
+
tuple
(
data
[
3
:])
def
index_backward
(
dim
,
index
,
grad
,
arg
):
# pragma: no cover
typename
=
type
(
grad
).
__name__
.
replace
(
'Tensor'
,
''
)
cuda
=
'cuda_'
if
grad
.
is_cuda
else
''
func
=
getattr
(
ffi
,
'index_backward_{}{}'
.
format
(
cuda
,
typename
))
output
=
grad
.
new
(
index
.
size
()).
fill_
(
0
)
func
(
dim
,
output
,
index
,
grad
,
arg
)
return
output
torch_scatter/functions/max.py
View file @
f0fdfe20
from
.scatter
import
scatter
from
.scatter
import
Scatter
,
scatter
from
.ffi
import
index_backward
from
.utils
import
gen_filled_tensor
,
gen_output
class
ScatterMax
(
Scatter
):
def
__init__
(
self
,
dim
):
super
(
ScatterMax
,
self
).
__init__
(
'max'
,
dim
)
def
save_for_backward_step
(
self
,
*
data
):
output
,
index
,
input
,
arg
=
data
self
.
save_for_backward
(
index
,
arg
)
def
backward_step
(
self
,
*
data
):
grad
,
index
,
arg
=
data
return
index_backward
(
self
.
dim
,
index
.
data
,
grad
,
arg
.
data
)
def
scatter_max_
(
output
,
index
,
input
,
dim
=
0
):
r
"""
|
...
...
@@ -61,7 +75,7 @@ def scatter_max_(output, index, input, dim=0):
)
"""
arg
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
return
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
arg
)
return
scatter
(
ScatterMax
,
'max'
,
dim
,
output
,
index
,
input
,
arg
)
def
scatter_max
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
...
...
torch_scatter/functions/mean.py
View file @
f0fdfe20
from
__future__
import
division
from
.scatter
import
scatter
from
.scatter
import
Scatter
,
scatter
from
.utils
import
gen_filled_tensor
,
gen_output
class
ScatterMean
(
Scatter
):
def
__init__
(
self
,
dim
):
super
(
ScatterMean
,
self
).
__init__
(
'mean'
,
dim
)
def
save_for_backward_step
(
self
,
*
data
):
output
,
index
,
input
,
count
=
data
self
.
save_for_backward
(
index
)
def
backward_step
(
self
,
*
data
):
grad
,
index
=
data
return
grad
.
gather
(
self
.
dim
,
index
.
data
)
def
scatter_mean_
(
output
,
index
,
input
,
dim
=
0
):
r
"""
|
...
...
@@ -56,10 +69,12 @@ def scatter_mean_(output, index, input, dim=0):
1.0000 4.0000 2.0000 0.0000 0.0000 0.0000
[torch.FloatTensor of size 2x6]
"""
init
=
gen_filled_tensor
(
output
,
output
.
size
(),
fill_value
=
0
)
count
=
gen_filled_tensor
(
output
,
output
.
size
(),
fill_value
=
0
)
scatter
(
'mean'
,
dim
,
outpu
t
,
index
,
input
,
count
)
scatter
(
ScatterMean
,
'mean'
,
dim
,
ini
t
,
index
,
input
,
count
)
count
[
count
==
0
]
=
1
output
/=
count
init
/=
count
output
+=
init
return
output
...
...
torch_scatter/functions/min.py
View file @
f0fdfe20
from
.scatter
import
scatter
from
.scatter
import
Scatter
,
scatter
from
.ffi
import
index_backward
from
.utils
import
gen_filled_tensor
,
gen_output
class
ScatterMin
(
Scatter
):
def
__init__
(
self
,
dim
):
super
(
ScatterMin
,
self
).
__init__
(
'min'
,
dim
)
def
save_for_backward_step
(
self
,
*
data
):
output
,
index
,
input
,
arg
=
data
self
.
save_for_backward
(
index
,
arg
)
def
backward_step
(
self
,
*
data
):
grad
,
index
,
arg
=
data
return
index_backward
(
self
.
dim
,
index
.
data
,
grad
,
arg
.
data
)
def
scatter_min_
(
output
,
index
,
input
,
dim
=
0
):
r
"""
|
...
...
@@ -61,7 +75,7 @@ def scatter_min_(output, index, input, dim=0):
)
"""
arg
=
gen_filled_tensor
(
index
,
output
.
size
(),
fill_value
=-
1
)
return
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
arg
)
return
scatter
(
ScatterMin
,
'min'
,
dim
,
output
,
index
,
input
,
arg
)
def
scatter_min
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
...
...
torch_scatter/functions/mul.py
View file @
f0fdfe20
from
.scatter
import
scatter
from
.scatter
import
Scatter
,
scatter
from
.utils
import
gen_output
class
ScatterMul
(
Scatter
):
def
__init__
(
self
,
dim
):
super
(
ScatterMul
,
self
).
__init__
(
'mul'
,
dim
)
def
save_for_backward_step
(
self
,
*
data
):
output
,
index
,
input
=
data
self
.
save_for_backward
(
output
,
index
,
input
)
def
backward_step
(
self
,
*
data
):
grad
,
output
,
index
,
input
=
data
return
(
grad
*
output
.
data
).
gather
(
self
.
dim
,
index
.
data
)
/
input
.
data
def
scatter_mul_
(
output
,
index
,
input
,
dim
=
0
):
r
"""
|
...
...
@@ -52,7 +65,7 @@ def scatter_mul_(output, index, input, dim=0):
6 4 8 1 1 1
[torch.FloatTensor of size 2x6]
"""
return
scatter
(
'mul'
,
dim
,
output
,
index
,
input
)
return
scatter
(
ScatterMul
,
'mul'
,
dim
,
output
,
index
,
input
)
def
scatter_mul
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
1
):
...
...
torch_scatter/functions/scatter.py
View file @
f0fdfe20
from
itertools
import
chain
import
torch
from
torch.autograd
import
Function
from
.._ext
import
ffi
def
has_arg
(
name
):
return
name
in
[
'max'
,
'min'
]
def
_scatter
(
name
,
dim
,
*
data
):
a
,
b
,
c
=
data
[:
3
]
# Assert index dimension is valid.
assert
dim
>=
0
and
dim
<
a
.
dim
(),
'Index dimension is out of bounds'
# Assert same dimensionality across all inputs.
assert
b
.
dim
()
==
c
.
dim
(),
(
'Index tensor must have same dimensions as '
'input tensor'
)
assert
a
.
dim
()
==
c
.
dim
(),
(
'Input tensor must have same dimensions as '
'output tensor'
)
# Assert same tensor length across index and input.
assert
b
.
numel
()
==
c
.
numel
(),
(
'Index tensor must have same size as '
'input tensor'
)
from
.ffi
import
scatter
as
ffi_scatter
# Assert same tensor sizes across input and output apart from `dim`.
for
d
in
chain
(
range
(
dim
),
range
(
dim
+
1
,
a
.
dim
())):
assert
a
.
size
(
d
)
==
c
.
size
(
d
),
(
'Input tensor must have same size as output tensor apart from the '
'specified dimension'
)
typename
=
type
(
data
[
0
]).
__name__
.
replace
(
'Tensor'
,
''
)
cuda
=
'cuda_'
if
data
[
0
].
is_cuda
else
''
func
=
getattr
(
ffi
,
'scatter_{}_{}{}'
.
format
(
name
,
cuda
,
typename
))
func
(
dim
,
*
data
)
return
(
data
[
0
],
data
[
3
])
if
has_arg
(
name
)
else
data
[
0
]
def
index_backward
(
dim
,
index
,
grad
,
arg
):
# pragma: no cover
typename
=
type
(
grad
).
__name__
.
replace
(
'Tensor'
,
''
)
cuda
=
'cuda_'
if
grad
.
is_cuda
else
''
func
=
getattr
(
ffi
,
'index_backward_{}{}'
.
format
(
cuda
,
typename
))
output
=
grad
.
new
(
index
.
size
()).
fill_
(
0
)
func
(
dim
,
output
,
index
,
grad
,
arg
)
return
output
class
_Scatter
(
Function
):
class
Scatter
(
Function
):
def
__init__
(
self
,
name
,
dim
):
super
(
_
Scatter
,
self
).
__init__
()
super
(
Scatter
,
self
).
__init__
()
self
.
name
=
name
self
.
dim
=
dim
def
save_for_backward_step
(
self
,
*
data
):
raise
NotImplementedError
def
forward
(
self
,
*
data
):
assert
not
self
.
needs_input_grad
[
1
],
'Can
\'
t differentiate the index'
self
.
mark_dirty
(
data
[
0
])
# Mark output as dirty.
self
.
len
=
len
(
data
)
# Save number of arguments for backward step.
_scatter
(
self
.
name
,
self
.
dim
,
*
data
)
output
=
ffi_scatter
(
self
.
name
,
self
.
dim
,
*
data
)
self
.
save_for_backward_step
(
*
data
)
# `scatter_min` and `scatter_max` additionally return the `argmax`
# respectively `argmin`. Therefore, we need to save the `arg` for the
# backward pass.
if
has_arg
(
self
.
name
):
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
return
data
[
0
],
data
[
3
]
else
:
self
.
save_for_backward
(
data
[
1
])
return
data
[
0
]
return
output
def
backward
(
self
,
*
data
):
# pragma: no cover
grad_output
=
grad_input
=
None
...
...
@@ -78,24 +30,53 @@ class _Scatter(Function):
if
self
.
needs_input_grad
[
0
]:
grad_output
=
data
[
0
]
# Different grad computation of `input` if `scatter_max` or
# `scatter_min` was used.
if
self
.
needs_input_grad
[
2
]
and
not
has_arg
(
self
.
name
):
index
,
=
self
.
saved_variables
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
if
self
.
needs_input_grad
[
2
]
and
has_arg
(
self
.
name
):
index
,
arg
=
self
.
saved_variables
data
=
(
index
.
data
,
data
[
0
],
arg
.
data
)
grad_input
=
index_backward
(
self
.
dim
,
*
data
)
# Call grad computation of `input` for the specific scatter operation.
if
self
.
needs_input_grad
[
2
]:
grad_input
=
self
.
backward_step
(
data
[
0
],
*
self
.
saved_variables
)
# Return and fill with empty grads for non-differentiable passed
# arguments in forward pass.
# Return and fill with empty grads for non-differentiable arguments.
return
(
grad_output
,
None
,
grad_input
)
+
(
None
,
)
*
(
self
.
len
-
3
)
def
backward_step
(
self
,
*
data
):
raise
NotImplementedError
def
scatter
(
name
,
dim
,
*
data
):
def
scatter
(
Clx
,
name
,
dim
,
*
data
):
if
torch
.
is_tensor
(
data
[
0
]):
return
_scatter
(
name
,
dim
,
*
data
)
return
ffi
_scatter
(
name
,
dim
,
*
data
)
else
:
return
_Scatter
(
name
,
dim
)(
*
data
)
return
Clx
(
dim
)(
*
data
)
# def index_backward(dim, index, grad, arg): # pragma: no cover
# typename = type(grad).__name__.replace('Tensor', '')
# cuda = 'cuda_' if grad.is_cuda else ''
# func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
# output = grad.new(index.size()).fill_(0)
# func(dim, output, index, grad, arg)
# return output
# def _scatter_backward(name, dim, saved, *data):
# # saved = (index, ), (index, arg) or (index, count)
# print(name)
# print(len(data))
# print(len(saved))
# print(saved[1].size())
# # data = (grad, )
# # index, = seved
# if has_arg(name):
# return index_backward(dim, saved[0].data, data[0], saved[1].data)
# if has_count(name):
# return (data[0] / saved[1]).gather(dim, saved[0].data)
# # Different grad computation of `input` if `scatter_max` or
# # `scatter_min` was used.
# # if self.needs_input_grad[2] and not has_arg(self.name):
# # index, = self.saved_variables
# # grad_input = data[0].gather(self.dim, index.data)
# # if self.needs_input_grad[2] and has_arg(self.name):
# # index, arg = self.saved_variables
# # data = (index.data, data[0], arg.data)
# grad_input = index_backward(self.dim, *data)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment