Commit 61eb4d03 authored by rusty1s's avatar rusty1s
Browse files

no coverage

parent 8099c537
@@ -2,7 +2,7 @@ from .scatter import Scatter, scatter
from .utils import gen_output

class ScatterDiv(Scatter):  # pragma: no cover
    def __init__(self, dim):
        super(ScatterDiv, self).__init__('div', dim)
...
@@ -11,7 +11,7 @@ class ScatterMax(Scatter):
        output, index, input, arg = data
        self.save_for_backward(index, arg)

    def backward_step(self, *data):  # pragma: no cover
        grad, index, arg = data
        return index_backward(self.dim, index.data, grad, arg.data)
...
@@ -12,7 +12,7 @@ class ScatterMean(Scatter):
        output, index, input, count = data
        self.save_for_backward(index)

    def backward_step(self, *data):  # pragma: no cover
        grad, index = data
        return grad.gather(self.dim, index.data)
...
@@ -11,7 +11,7 @@ class ScatterMin(Scatter):
        output, index, input, arg = data
        self.save_for_backward(index, arg)

    def backward_step(self, *data):  # pragma: no cover
        grad, index, arg = data
        return index_backward(self.dim, index.data, grad, arg.data)
...
@@ -10,7 +10,7 @@ class ScatterMul(Scatter):
        output, index, input = data
        self.save_for_backward(output, index, input)

    def backward_step(self, *data):  # pragma: no cover
        grad, output, index, input = data
        return (grad * output.data).gather(self.dim, index.data) / input.data
...
@@ -10,7 +10,7 @@ class Scatter(Function):
        self.name = name
        self.dim = dim

    def save_for_backward_step(self, *data):  # pragma: no cover
        raise NotImplementedError

    def forward(self, *data):
@@ -37,7 +37,7 @@ class Scatter(Function):
        # Return and fill with empty grads for non-differentiable arguments.
        return (grad_output, None, grad_input) + (None, ) * (self.len - 3)

    def backward_step(self, *data):  # pragma: no cover
        raise NotImplementedError
@@ -46,37 +46,3 @@ def scatter(Clx, name, dim, *data):
        return ffi_scatter(name, dim, *data)
    else:
        return Clx(dim)(*data)
# def index_backward(dim, index, grad, arg): # pragma: no cover
# typename = type(grad).__name__.replace('Tensor', '')
# cuda = 'cuda_' if grad.is_cuda else ''
# func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
# output = grad.new(index.size()).fill_(0)
# func(dim, output, index, grad, arg)
# return output
# def _scatter_backward(name, dim, saved, *data):
# # saved = (index, ), (index, arg) or (index, count)
# print(name)
# print(len(data))
# print(len(saved))
# print(saved[1].size())
# # data = (grad, )
# # index, = seved
# if has_arg(name):
# return index_backward(dim, saved[0].data, data[0], saved[1].data)
# if has_count(name):
# return (data[0] / saved[1]).gather(dim, saved[0].data)
# # Different grad computation of `input` if `scatter_max` or
# # `scatter_min` was used.
# # if self.needs_input_grad[2] and not has_arg(self.name):
# # index, = self.saved_variables
# # grad_input = data[0].gather(self.dim, index.data)
# # if self.needs_input_grad[2] and has_arg(self.name):
# # index, arg = self.saved_variables
# # data = (index.data, data[0], arg.data)
# grad_input = index_backward(self.dim, *data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment