Commit 7d5034ba authored by rusty1s's avatar rusty1s
Browse files
parents b5c89536 efac08dd
...@@ -5,3 +5,4 @@ exclude_lines = ...@@ -5,3 +5,4 @@ exclude_lines =
pragma: no cover pragma: no cover
cuda cuda
backward backward
raise
...@@ -35,7 +35,7 @@ The package consists of the following operations: ...@@ -35,7 +35,7 @@ The package consists of the following operations:
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html) * [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html) * [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
All included operations work on varying data types, are implemented both for CPU and GPU and include a backwards implementation. All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
## Installation ## Installation
......
...@@ -6,7 +6,7 @@ PyTorch Scatter Documentation ...@@ -6,7 +6,7 @@ PyTorch Scatter Documentation
This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package. This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package.
Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor. Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
All included operations work on varying data types, are implemented both for CPU and GPU and include a backwards implementation. All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
.. toctree:: .. toctree::
:glob: :glob:
......
import pytest
import torch
from torch_scatter import scatter_add
from .utils import devices
@pytest.mark.parametrize('device', devices)
def test_broadcasting(device):
B, C, H, W = (4, 3, 8, 8)
src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H)
assert out.size() == (B, C, H, W)
src = torch.randn((B, 1, H, W), device=device)
index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H)
assert out.size() == (B, C, H, W)
src = torch.randn((B, 1, H, W), device=device)
index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H)
assert out.size() == (B, 1, H, W)
src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (H, )).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H)
assert out.size() == (B, C, H, W)
...@@ -24,6 +24,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -24,6 +24,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
:attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with :attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
values of :attr:`index` must be between `0` and `out.size(dim) - 1`. values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
do not match.
For one-dimensional tensors, the operation computes For one-dimensional tensors, the operation computes
......
from __future__ import division
from itertools import repeat from itertools import repeat
import torch
def maybe_dim_size(index, dim_size=None): def maybe_dim_size(index, dim_size=None):
if dim_size is not None: if dim_size is not None:
...@@ -14,7 +18,22 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -14,7 +18,22 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
if index.dim() == 1: if index.dim() == 1:
index_size = list(repeat(1, src.dim())) index_size = list(repeat(1, src.dim()))
index_size[dim] = src.size(dim) index_size[dim] = src.size(dim)
index = index.view(index_size).expand_as(src) if index.numel() > 0:
index = index.view(index_size).expand_as(src)
else: # PyTorch has a bug when view is used on zero-element tensors.
index = src.new_empty(index_size, dtype=torch.long)
# Broadcasting capabilties: Expand dimensions to match.
if src.dim() != index.dim():
raise ValueError(
('Number of dimensions of src and index tensor do not match, '
'got {} and {}').format(src.dim(), index.dim()))
expand_size = []
for s, i in zip(src.size(), index.size()):
expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
src = src.expand(expand_size)
index = index.expand_as(src)
# Generate output tensor if not given. # Generate output tensor if not given.
if out is None: if out is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment