Commit 8ec6d0c6 authored by rusty1s's avatar rusty1s
Browse files

convert size to a list

parent 6a2cc2e7
......@@ -11,7 +11,7 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
dim_size: Optional[int] = None) -> torch.Tensor:
index = broadcast(index, src, dim)
if out is None:
size = src.size()
size = list(src.size())
if dim_size is not None:
size[dim] = dim_size
elif index.numel() == 0:
......@@ -57,18 +57,18 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
@torch.jit.script
def scatter_min(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
def scatter_min(
src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)
@torch.jit.script
def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
def scatter_max(
src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment