Commit 6a41c3e7 authored by zhuww's avatar zhuww
Browse files

add scatter param

parent 16401cf5
......@@ -136,11 +136,11 @@ class Copy(torch.autograd.Function):
return _reduce(grad_output)
def scatter(input: Tensor, dim: int = -1) -> Tensor:
def scatter(input: Tensor, dim: int = -1, drop_unused=False) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Scatter.apply(input, dim)
else:
input = _split(input, dim=dim)
input = _split(input, dim=dim, drop_unused=drop_unused)
return input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment