Commit d80cd24d authored by rusty1s

broadcasting capabilities

parent d8470875
@@ -35,7 +35,7 @@ The package consists of the following operations:
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
## Installation
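As a minimal sketch of what the broadcastable claim buys, assuming small made-up shapes (`scatter_add` and its `dim_size` argument are from the library itself):

import torch
from torch_scatter import scatter_add

# `index` lacks the channel dimension of `src`; it is broadcast along
# dim 1 before the scatter is applied.
src = torch.randn(2, 3, 4)                                 # (B, C, N)
index = torch.randint(0, 2, (2, 1, 4), dtype=torch.long)  # (B, 1, N)
out = scatter_add(src, index, dim=2, dim_size=2)
assert out.size() == (2, 3, 2)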
@@ -6,7 +6,7 @@ PyTorch Scatter Documentation
This package consists of a small extension library of highly optimized sparse update (scatter) operations for use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package.
Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
.. toctree::
   :glob:
import pytest
import torch
from torch_scatter import scatter_add

from .utils import devices


@pytest.mark.parametrize('device', devices)
def test_broadcasting(device):
    B, C, H, W = (4, 3, 8, 8)

    # A singleton channel dimension in `index` is broadcast against `src`.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, C, H, W)

    # A singleton channel dimension in `src` is broadcast against `index`.
    src = torch.randn((B, 1, H, W), device=device)
    index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, C, H, W)

    # Matching singleton dimensions are left untouched.
    src = torch.randn((B, 1, H, W), device=device)
    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, 1, H, W)

    # A one-dimensional `index` is viewed and expanded to the shape of `src`.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (H, )).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, C, H, W)
@@ -24,6 +24,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
:attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
do not match.
For one-dimensional tensors, the operation computes
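To illustrate the documented behavior (the tensor shapes below are arbitrary), a singleton dimension on either argument is expanded to match the other:

import torch
from torch_scatter import scatter_add

# `src` has a singleton dim 1 while `index` carries five rows; `src`
# is expanded to (2, 5, 4) before the scatter, mirroring the docs above.
src = torch.randn(2, 1, 4)
index = torch.randint(0, 3, (2, 5, 4), dtype=torch.long)
out = scatter_add(src, index, dim=2, dim_size=3)
assert out.size() == (2, 5, 3)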
from __future__ import division
from itertools import repeat
@@ -16,6 +18,18 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
        index_size[dim] = src.size(dim)
        index = index.view(index_size).expand_as(src)
    # Broadcasting capabilities: expand dimensions to match.
    if src.dim() != index.dim():
        raise ValueError(
            ('Number of dimensions of src and index tensors do not match, '
             'got {} and {}').format(src.dim(), index.dim()))

    # Keep -1 where sizes already agree; otherwise expand the singleton
    # dimension to the larger size.
    expand_size = []
    for s, i in zip(src.size(), index.size()):
        expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
    src = src.expand(expand_size)
    index = index.expand_as(src)
    # Generate output tensor if not given.
    if out is None:
        out_size = list(src.size())
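A standalone sketch of the same expansion rule, handy for checking broadcast sizes by hand; `broadcast_sizes` is an illustrative helper, not part of the library:

import torch

def broadcast_sizes(a, b):
    # Keep -1 where the sizes already agree; otherwise expand the
    # singleton side up to the larger size, as in gen() above.
    return [-1 if s == i and s != 1 and i != 1 else max(i, s)
            for s, i in zip(a.size(), b.size())]

src = torch.randn(4, 1, 8)
index = torch.zeros(4, 3, 8, dtype=torch.long)
print(broadcast_sizes(src, index))                      # [-1, 3, -1]
print(src.expand(broadcast_sizes(src, index)).size())   # torch.Size([4, 3, 8])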