Commit c3bacf3b authored by rusty1s's avatar rusty1s
Browse files

added assertions

parent 0466dd06
from itertools import chain
import torch
from torch.autograd import Function
......@@ -5,6 +7,25 @@ from .._ext import ffi
def _scatter(name, dim, *data):
a, b, c = data[:3]
# Assert same dimensionality across all inputs.
assert dim >= 0 and dim < a.dim(), 'Index dimension is out of bounds'
assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
'input tensor')
assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
'output tensor')
# Assert same tensor length across index and input.
assert b.numel() == c.numel(), ('Index tensor must have same size as '
'input tensor')
# Assert same tensor sizes across input and output apart from `dim`.
for d in chain(range(dim), range(dim + 1, a.dim())):
assert a.size(d) == c.size(d), (
'Input tensor must have same size as output tensor apart from the '
'specified dimension')
typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(dim, *data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment