Commit 7c46799e authored by rusty1s's avatar rusty1s
Browse files

doc fixes

parent 85940068
...@@ -21,7 +21,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -21,7 +21,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
dim_size = out.size(dim) dim_size = out.size(dim)
else: else:
if dim_size is None: if dim_size is None:
dim_size = int(index.max().item() + 1) dim_size = int(index.max()) + 1
size = src.size() size = src.size()
size[dim] = dim_size size[dim] = dim_size
......
...@@ -10,20 +10,23 @@ try: ...@@ -10,20 +10,23 @@ try:
except OSError: except OSError:
warnings.warn('Failed to load `scatter` binaries.') warnings.warn('Failed to load `scatter` binaries.')
def placeholder(src: torch.Tensor, index: torch.Tensor, dim: int, def scatter_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
out: Optional[torch.Tensor], out: Optional[torch.Tensor],
dim_size: Optional[int]) -> torch.Tensor: dim_size: Optional[int]) -> torch.Tensor:
raise ImportError raise ImportError
return src
def arg_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int, def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor], dim_size: Optional[int] dim: int, out: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: dim_size: Optional[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError raise ImportError
return src, index
torch.ops.torch_scatter.scatter_sum = placeholder torch.ops.torch_scatter.scatter_sum = scatter_placeholder
torch.ops.torch_scatter.scatter_mean = placeholder torch.ops.torch_scatter.scatter_mean = scatter_placeholder
torch.ops.torch_scatter.scatter_min = arg_placeholder torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
torch.ops.torch_scatter.scatter_max = arg_placeholder torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
@torch.jit.script @torch.jit.script
......
...@@ -14,16 +14,19 @@ except OSError: ...@@ -14,16 +14,19 @@ except OSError:
out: Optional[torch.Tensor], out: Optional[torch.Tensor],
dim_size: Optional[int]) -> torch.Tensor: dim_size: Optional[int]) -> torch.Tensor:
raise ImportError raise ImportError
return src
def segment_coo_with_arg_placeholder( def segment_coo_with_arg_placeholder(
src: torch.Tensor, index: torch.Tensor, src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor], out: Optional[torch.Tensor],
dim_size: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]: dim_size: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError raise ImportError
return src, index
def gather_coo_placeholder(src: torch.Tensor, index: torch.Tensor, def gather_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor: out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError raise ImportError
return src
torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
......
...@@ -13,15 +13,18 @@ except OSError: ...@@ -13,15 +13,18 @@ except OSError:
def segment_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor, def segment_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor: out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError raise ImportError
return src
def segment_csr_with_arg_placeholder( def segment_csr_with_arg_placeholder(
src: torch.Tensor, indptr: torch.Tensor, src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: out: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError raise ImportError
return src, indptr
def gather_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor, def gather_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor: out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError raise ImportError
return src
torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment