"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9c1d4e3be1580b3174cb0eb099a135aeb55a807c"
Commit c3bacf3b authored by rusty1s's avatar rusty1s
Browse files

added assertions

parent 0466dd06
from itertools import chain
import torch import torch
from torch.autograd import Function from torch.autograd import Function
...@@ -5,6 +7,25 @@ from .._ext import ffi ...@@ -5,6 +7,25 @@ from .._ext import ffi
def _scatter(name, dim, *data): def _scatter(name, dim, *data):
a, b, c = data[:3]
# Assert same dimensionality across all inputs.
assert dim >= 0 and dim < a.dim(), 'Index dimension is out of bounds'
assert b.dim() == c.dim(), ('Index tensor must have same dimensions as '
'input tensor')
assert a.dim() == c.dim(), ('Input tensor must have same dimensions as '
'output tensor')
# Assert same tensor length across index and input.
assert b.numel() == c.numel(), ('Index tensor must have same size as '
'input tensor')
# Assert same tensor sizes across input and output apart from `dim`.
for d in chain(range(dim), range(dim + 1, a.dim())):
assert a.size(d) == c.size(d), (
'Input tensor must have same size as output tensor apart from the '
'specified dimension')
typename = type(data[0]).__name__.replace('Tensor', '') typename = type(data[0]).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename)) func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(dim, *data) func(dim, *data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment