Commit 7c46799e authored by rusty1s's avatar rusty1s
Browse files

doc fixes

parent 85940068
......@@ -21,7 +21,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
dim_size = out.size(dim)
else:
if dim_size is None:
dim_size = int(index.max().item() + 1)
dim_size = int(index.max()) + 1
size = src.size()
size[dim] = dim_size
......
......@@ -10,20 +10,23 @@ try:
except OSError:
warnings.warn('Failed to load `scatter` binaries.')
def placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
out: Optional[torch.Tensor],
dim_size: Optional[int]) -> torch.Tensor:
def scatter_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
out: Optional[torch.Tensor],
dim_size: Optional[int]) -> torch.Tensor:
raise ImportError
return src
def arg_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
out: Optional[torch.Tensor], dim_size: Optional[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
dim: int, out: Optional[torch.Tensor],
dim_size: Optional[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError
return src, index
torch.ops.torch_scatter.scatter_sum = placeholder
torch.ops.torch_scatter.scatter_mean = placeholder
torch.ops.torch_scatter.scatter_min = arg_placeholder
torch.ops.torch_scatter.scatter_max = arg_placeholder
torch.ops.torch_scatter.scatter_sum = scatter_placeholder
torch.ops.torch_scatter.scatter_mean = scatter_placeholder
torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
@torch.jit.script
......
......@@ -14,16 +14,19 @@ except OSError:
out: Optional[torch.Tensor],
dim_size: Optional[int]) -> torch.Tensor:
raise ImportError
return src
def segment_coo_with_arg_placeholder(
src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor],
dim_size: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError
return src, index
def gather_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError
return src
torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
......
......@@ -13,15 +13,18 @@ except OSError:
def segment_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError
return src
def segment_csr_with_arg_placeholder(
src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError
return src, indptr
def gather_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError
return src
torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment