"""TorchScript wrappers around the custom ``torch_scatter`` operators.

Importing this module loads ``_scatter.so`` from the directory containing this
file, which makes the operators available under ``torch.ops.torch_scatter``.
"""
import os.path as osp
from typing import Optional, Tuple

import torch

# Load the compiled extension that sits next to this file; the wrappers below
# call the ops it registers in the ``torch.ops.torch_scatter`` namespace.
torch.ops.load_library(
    osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))


@torch.jit.script
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Forward all arguments to the ``torch_scatter.scatter_sum`` custom op.

    NOTE(review): the reduction itself lives in the compiled extension;
    presumably entries of ``src`` are summed into the positions given by
    ``index`` along ``dim`` -- confirm against the C++ implementation.

    Args:
        src: Source tensor.
        index: Index tensor.
        dim: Dimension argument passed to the op (defaults to ``-1``).
        out: Optional tensor passed through to the op.
        dim_size: Optional size passed through to the op.

    Returns:
        The tensor returned by the custom op.
    """
    return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of :func:`scatter_sum`; takes the same arguments and returns the
    same result."""
    # Delegate instead of repeating the op call so the alias cannot drift.
    return scatter_sum(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """Forward all arguments to the ``torch_scatter.scatter_mean`` custom op.

    NOTE(review): presumably averages the entries of ``src`` that share an
    index along ``dim`` -- confirm against the C++ implementation.

    Args:
        src: Source tensor.
        index: Index tensor.
        dim: Dimension argument passed to the op (defaults to ``-1``).
        out: Optional tensor passed through to the op.
        dim_size: Optional size passed through to the op.

    Returns:
        The tensor returned by the custom op.
    """
    return torch.ops.torch_scatter.scatter_mean(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_min(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward all arguments to the ``torch_scatter.scatter_min`` custom op.

    Args:
        src: Source tensor.
        index: Index tensor.
        dim: Dimension argument passed to the op (defaults to ``-1``).
        out: Optional tensor passed through to the op.
        dim_size: Optional size passed through to the op.

    Returns:
        The pair of tensors returned by the custom op.
        NOTE(review): presumably ``(minimum values, argmin positions)`` --
        confirm against the C++ implementation.
    """
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward all arguments to the ``torch_scatter.scatter_max`` custom op.

    Args:
        src: Source tensor.
        index: Index tensor.
        dim: Dimension argument passed to the op (defaults to ``-1``).
        out: Optional tensor passed through to the op.
        dim_size: Optional size passed through to the op.

    Returns:
        The pair of tensors returned by the custom op.
        NOTE(review): presumably ``(maximum values, argmax positions)`` --
        confirm against the C++ implementation.
    """
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)


@torch.jit.script
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    """Dispatch to the scatter wrapper selected by ``reduce``.

    Args:
        src: Source tensor.
        index: Index tensor.
        dim: Dimension argument passed to the selected wrapper.
        out: Optional tensor passed through to the selected wrapper.
        dim_size: Optional size passed through to the selected wrapper.
        reduce: One of ``'sum'``, ``'add'``, ``'mean'``, ``'min'`` or
            ``'max'``. ``'sum'`` and ``'add'`` both select
            :func:`scatter_sum`.

    Returns:
        A single tensor. For ``'min'`` and ``'max'`` only the first element of
        the pair returned by the underlying wrapper is returned; the second
        is discarded.

    Raises:
        ValueError: If ``reduce`` is not one of the supported names.
    """
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        # Constant message keeps the raise trivially scriptable while still
        # telling the caller which values are accepted.
        raise ValueError(
            "reduce must be one of 'sum', 'add', 'mean', 'min' or 'max'")