Commit bf1f1014 authored by rusty1s

use scatter add pytorch implementation

parent 1006514c
@@ -217,7 +217,7 @@ if __name__ == '__main__':
     parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
-    iters = 1 if args.device == 'cpu' else 50
+    iters = 1 if args.device == 'cpu' else 20
     sizes = [1, 16, 32, 64, 128, 256, 512]
     sizes = sizes[:3] if args.device == 'cpu' else sizes
......
import torch
from torch_scatter import scatter


def test_zero_elements():
    x = torch.randn(0, 16)
    # The index tensor must have an integer (long) dtype for scatter.
    index = torch.tensor([], dtype=torch.long).view(0, 16)
    print(x)
    print(index)
    scatter(x, index, dim=0, dim_size=0, reduce="add")
@@ -3,7 +3,7 @@ from typing import Optional
 import torch

 from torch_scatter import scatter_sum, scatter_max
-from .utils import broadcast
+from torch_scatter.utils import broadcast


 @torch.jit.script
......
 import torch

 from torch_scatter import scatter_sum, scatter_max
 from torch_scatter.utils import broadcast
-from .utils import broadcast


 @torch.jit.script
......
@@ -2,8 +2,7 @@ from typing import Optional
 import torch

 from torch_scatter import scatter_sum
 from torch_scatter.utils import broadcast
-from .utils import broadcast


 @torch.jit.script
......
@@ -4,6 +4,8 @@ from typing import Optional, Tuple
 import torch

+from .utils import broadcast
+
 try:
     torch.ops.load_library(
         osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))
@@ -23,7 +25,6 @@ except OSError:
         raise ImportError
         return src, index

-    torch.ops.torch_scatter.scatter_sum = scatter_placeholder
     torch.ops.torch_scatter.scatter_mean = scatter_placeholder
     torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
     torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
@@ -33,14 +34,24 @@ except OSError:
 @torch.jit.script
 def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
-    return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)
+    index = broadcast(index, src, dim)
+    if out is None:
+        size = src.size()
+        if dim_size is None:
+            size[dim] = int(index.max()) + 1
+        else:
+            size[dim] = dim_size
+        out = src.new_zeros(size)
+        return out.scatter_add_(dim, index, src)
+    else:
+        return out.scatter_add_(dim, index, src)


 @torch.jit.script
 def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
-    return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)
+    return scatter_sum(src, index, dim, out, dim_size)


 @torch.jit.script
......
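For context, the change above expresses scatter-sum through PyTorch's built-in Tensor.scatter_add_ instead of the custom `_scatter.so` op: the index is broadcast to the shape of `src`, an output buffer with `dim_size` slots along `dim` is allocated, and `scatter_add_` accumulates into it. Below is a minimal eager-mode sketch of that pattern; the `expand_index` helper is illustrative only (the commit relies on torch_scatter's own `broadcast` utility) and it only handles a 1-D index running along dim 0.

import torch


def expand_index(index, src):
    # Illustrative stand-in for torch_scatter's broadcast(): give a 1-D index
    # (running along dim 0) the same number of dimensions and shape as src,
    # since scatter_add_ expects index and src to have matching shapes here.
    for _ in range(index.dim(), src.dim()):
        index = index.unsqueeze(-1)
    return index.expand_as(src)


src = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
index = torch.tensor([0, 1, 0])     # row i of src is summed into bucket index[i]
idx = expand_index(index, src)      # shape (3, 2), same as src
out = src.new_zeros((2, 2))         # dim_size = 2 output buckets
out.scatter_add_(0, idx, src)
print(out)                          # tensor([[6., 8.], [3., 4.]])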