Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
61eb4d03
Commit
61eb4d03
authored
Feb 11, 2018
by
rusty1s
Browse files
no coverage
parent
8099c537
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
7 additions
and
41 deletions
+7
-41
torch_scatter/functions/div.py
torch_scatter/functions/div.py
+1
-1
torch_scatter/functions/max.py
torch_scatter/functions/max.py
+1
-1
torch_scatter/functions/mean.py
torch_scatter/functions/mean.py
+1
-1
torch_scatter/functions/min.py
torch_scatter/functions/min.py
+1
-1
torch_scatter/functions/mul.py
torch_scatter/functions/mul.py
+1
-1
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+2
-36
No files found.
torch_scatter/functions/div.py
View file @
61eb4d03
...
@@ -2,7 +2,7 @@ from .scatter import Scatter, scatter
...
@@ -2,7 +2,7 @@ from .scatter import Scatter, scatter
from
.utils
import
gen_output
from
.utils
import
gen_output
class
ScatterDiv
(
Scatter
):
class
ScatterDiv
(
Scatter
):
# pragma: no cover
def
__init__
(
self
,
dim
):
def
__init__
(
self
,
dim
):
super
(
ScatterDiv
,
self
).
__init__
(
'div'
,
dim
)
super
(
ScatterDiv
,
self
).
__init__
(
'div'
,
dim
)
...
...
torch_scatter/functions/max.py
View file @
61eb4d03
...
@@ -11,7 +11,7 @@ class ScatterMax(Scatter):
...
@@ -11,7 +11,7 @@ class ScatterMax(Scatter):
output
,
index
,
input
,
arg
=
data
output
,
index
,
input
,
arg
=
data
self
.
save_for_backward
(
index
,
arg
)
self
.
save_for_backward
(
index
,
arg
)
def
backward_step
(
self
,
*
data
):
def
backward_step
(
self
,
*
data
):
# pragma: no cover
grad
,
index
,
arg
=
data
grad
,
index
,
arg
=
data
return
index_backward
(
self
.
dim
,
index
.
data
,
grad
,
arg
.
data
)
return
index_backward
(
self
.
dim
,
index
.
data
,
grad
,
arg
.
data
)
...
...
torch_scatter/functions/mean.py
View file @
61eb4d03
...
@@ -12,7 +12,7 @@ class ScatterMean(Scatter):
...
@@ -12,7 +12,7 @@ class ScatterMean(Scatter):
output
,
index
,
input
,
count
=
data
output
,
index
,
input
,
count
=
data
self
.
save_for_backward
(
index
)
self
.
save_for_backward
(
index
)
def
backward_step
(
self
,
*
data
):
def
backward_step
(
self
,
*
data
):
# pragma: no cover
grad
,
index
=
data
grad
,
index
=
data
return
grad
.
gather
(
self
.
dim
,
index
.
data
)
return
grad
.
gather
(
self
.
dim
,
index
.
data
)
...
...
torch_scatter/functions/min.py
View file @
61eb4d03
...
@@ -11,7 +11,7 @@ class ScatterMin(Scatter):
...
@@ -11,7 +11,7 @@ class ScatterMin(Scatter):
output
,
index
,
input
,
arg
=
data
output
,
index
,
input
,
arg
=
data
self
.
save_for_backward
(
index
,
arg
)
self
.
save_for_backward
(
index
,
arg
)
def
backward_step
(
self
,
*
data
):
def
backward_step
(
self
,
*
data
):
# pragma: no cover
grad
,
index
,
arg
=
data
grad
,
index
,
arg
=
data
return
index_backward
(
self
.
dim
,
index
.
data
,
grad
,
arg
.
data
)
return
index_backward
(
self
.
dim
,
index
.
data
,
grad
,
arg
.
data
)
...
...
torch_scatter/functions/mul.py
View file @
61eb4d03
...
@@ -10,7 +10,7 @@ class ScatterMul(Scatter):
...
@@ -10,7 +10,7 @@ class ScatterMul(Scatter):
output
,
index
,
input
=
data
output
,
index
,
input
=
data
self
.
save_for_backward
(
output
,
index
,
input
)
self
.
save_for_backward
(
output
,
index
,
input
)
def
backward_step
(
self
,
*
data
):
def
backward_step
(
self
,
*
data
):
# pragma: no cover
grad
,
output
,
index
,
input
=
data
grad
,
output
,
index
,
input
=
data
return
(
grad
*
output
.
data
).
gather
(
self
.
dim
,
index
.
data
)
/
input
.
data
return
(
grad
*
output
.
data
).
gather
(
self
.
dim
,
index
.
data
)
/
input
.
data
...
...
torch_scatter/functions/scatter.py
View file @
61eb4d03
...
@@ -10,7 +10,7 @@ class Scatter(Function):
...
@@ -10,7 +10,7 @@ class Scatter(Function):
self
.
name
=
name
self
.
name
=
name
self
.
dim
=
dim
self
.
dim
=
dim
def
save_for_backward_step
(
self
,
*
data
):
def
save_for_backward_step
(
self
,
*
data
):
# pragma: no cover
raise
NotImplementedError
raise
NotImplementedError
def
forward
(
self
,
*
data
):
def
forward
(
self
,
*
data
):
...
@@ -37,7 +37,7 @@ class Scatter(Function):
...
@@ -37,7 +37,7 @@ class Scatter(Function):
# Return and fill with empty grads for non-differentiable arguments.
# Return and fill with empty grads for non-differentiable arguments.
return
(
grad_output
,
None
,
grad_input
)
+
(
None
,
)
*
(
self
.
len
-
3
)
return
(
grad_output
,
None
,
grad_input
)
+
(
None
,
)
*
(
self
.
len
-
3
)
def
backward_step
(
self
,
*
data
):
def
backward_step
(
self
,
*
data
):
# pragma: no cover
raise
NotImplementedError
raise
NotImplementedError
...
@@ -46,37 +46,3 @@ def scatter(Clx, name, dim, *data):
...
@@ -46,37 +46,3 @@ def scatter(Clx, name, dim, *data):
return
ffi_scatter
(
name
,
dim
,
*
data
)
return
ffi_scatter
(
name
,
dim
,
*
data
)
else
:
else
:
return
Clx
(
dim
)(
*
data
)
return
Clx
(
dim
)(
*
data
)
# def index_backward(dim, index, grad, arg): # pragma: no cover
# typename = type(grad).__name__.replace('Tensor', '')
# cuda = 'cuda_' if grad.is_cuda else ''
# func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
# output = grad.new(index.size()).fill_(0)
# func(dim, output, index, grad, arg)
# return output
# def _scatter_backward(name, dim, saved, *data):
# # saved = (index, ), (index, arg) or (index, count)
# print(name)
# print(len(data))
# print(len(saved))
# print(saved[1].size())
# # data = (grad, )
# # index, = seved
# if has_arg(name):
# return index_backward(dim, saved[0].data, data[0], saved[1].data)
# if has_count(name):
# return (data[0] / saved[1]).gather(dim, saved[0].data)
# # Different grad computation of `input` if `scatter_max` or
# # `scatter_min` was used.
# # if self.needs_input_grad[2] and not has_arg(self.name):
# # index, = self.saved_variables
# # grad_input = data[0].gather(self.dim, index.data)
# # if self.needs_input_grad[2] and has_arg(self.name):
# # index, arg = self.saved_variables
# # data = (index.data, data[0], arg.data)
# grad_input = index_backward(self.dim, *data)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment