Commit 6a41c3e7 authored by zhuww

add scatter param

parent 16401cf5
@@ -136,11 +136,11 @@ class Copy(torch.autograd.Function):
         return _reduce(grad_output)

-def scatter(input: Tensor, dim: int = -1) -> Tensor:
+def scatter(input: Tensor, dim: int = -1, drop_unused=False) -> Tensor:
     if torch.is_grad_enabled() and input.requires_grad:
         input = Scatter.apply(input, dim)
     else:
-        input = _split(input, dim=dim)
+        input = _split(input, dim=dim, drop_unused=drop_unused)
     return input
...
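For context, a minimal sketch of how the forwarded parameter might behave. The real `_split`, `Scatter`, and `_reduce` helpers are not shown in this diff, so the `_split` below is a hypothetical stand-in: it assumes `drop_unused=True` means elements that do not divide evenly across ranks are dropped rather than padded, and it takes `world_size` and `rank` as plain arguments instead of reading them from a process group. The actual semantics in the repository may differ.

import torch
from torch import Tensor


def _split(input: Tensor, dim: int = -1, drop_unused: bool = False,
           world_size: int = 4, rank: int = 0) -> Tensor:
    # Hypothetical stand-in for the _split referenced in the diff.
    size = input.size(dim)
    if drop_unused:
        # Assumption: keep only the evenly divisible prefix along `dim`
        # so every rank receives a chunk of identical size.
        size = (size // world_size) * world_size
        input = input.narrow(dim, 0, size)
    chunks = torch.chunk(input, world_size, dim=dim)
    return chunks[rank].contiguous()


# Usage sketch: the new scatter() signature simply forwards drop_unused here.
x = torch.randn(2, 10)
local = _split(x, dim=-1, drop_unused=True)
print(local.shape)  # torch.Size([2, 2]) with the assumed world_size=4, rank=0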