Commit 61eb4d03 authored by rusty1s's avatar rusty1s
Browse files

no coverage

parent 8099c537
......@@ -2,7 +2,7 @@ from .scatter import Scatter, scatter
from .utils import gen_output
class ScatterDiv(Scatter):
class ScatterDiv(Scatter): # pragma: no cover
def __init__(self, dim):
super(ScatterDiv, self).__init__('div', dim)
......
......@@ -11,7 +11,7 @@ class ScatterMax(Scatter):
output, index, input, arg = data
self.save_for_backward(index, arg)
def backward_step(self, *data):
def backward_step(self, *data): # pragma: no cover
grad, index, arg = data
return index_backward(self.dim, index.data, grad, arg.data)
......
......@@ -12,7 +12,7 @@ class ScatterMean(Scatter):
output, index, input, count = data
self.save_for_backward(index)
def backward_step(self, *data):
def backward_step(self, *data): # pragma: no cover
grad, index = data
return grad.gather(self.dim, index.data)
......
......@@ -11,7 +11,7 @@ class ScatterMin(Scatter):
output, index, input, arg = data
self.save_for_backward(index, arg)
def backward_step(self, *data):
def backward_step(self, *data): # pragma: no cover
grad, index, arg = data
return index_backward(self.dim, index.data, grad, arg.data)
......
......@@ -10,7 +10,7 @@ class ScatterMul(Scatter):
output, index, input = data
self.save_for_backward(output, index, input)
def backward_step(self, *data):
def backward_step(self, *data): # pragma: no cover
grad, output, index, input = data
return (grad * output.data).gather(self.dim, index.data) / input.data
......
......@@ -10,7 +10,7 @@ class Scatter(Function):
self.name = name
self.dim = dim
def save_for_backward_step(self, *data):
def save_for_backward_step(self, *data): # pragma: no cover
raise NotImplementedError
def forward(self, *data):
......@@ -37,7 +37,7 @@ class Scatter(Function):
# Return and fill with empty grads for non-differentiable arguments.
return (grad_output, None, grad_input) + (None, ) * (self.len - 3)
def backward_step(self, *data):
def backward_step(self, *data): # pragma: no cover
raise NotImplementedError
......@@ -46,37 +46,3 @@ def scatter(Clx, name, dim, *data):
return ffi_scatter(name, dim, *data)
else:
return Clx(dim)(*data)
# def index_backward(dim, index, grad, arg): # pragma: no cover
# typename = type(grad).__name__.replace('Tensor', '')
# cuda = 'cuda_' if grad.is_cuda else ''
# func = getattr(ffi, 'index_backward_{}{}'.format(cuda, typename))
# output = grad.new(index.size()).fill_(0)
# func(dim, output, index, grad, arg)
# return output
# def _scatter_backward(name, dim, saved, *data):
# # saved = (index, ), (index, arg) or (index, count)
# print(name)
# print(len(data))
# print(len(saved))
# print(saved[1].size())
# # data = (grad, )
# # index, = seved
# if has_arg(name):
# return index_backward(dim, saved[0].data, data[0], saved[1].data)
# if has_count(name):
# return (data[0] / saved[1]).gather(dim, saved[0].data)
# # Different grad computation of `input` if `scatter_max` or
# # `scatter_min` was used.
# # if self.needs_input_grad[2] and not has_arg(self.name):
# # index, = self.saved_variables
# # grad_input = data[0].gather(self.dim, index.data)
# # if self.needs_input_grad[2] and has_arg(self.name):
# # index, arg = self.saved_variables
# # data = (index.data, data[0], arg.data)
# grad_input = index_backward(self.dim, *data)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment