Commit efac08dd authored by rusty1s

Merge branch 'master' of github.com:rusty1s/pytorch_scatter

parents 481e81fe 18bb5b1d
......@@ -5,3 +5,4 @@ exclude_lines =
pragma: no cover
cuda
backward
raise
......@@ -17,13 +17,14 @@ before_install:
- export CXX="g++-4.9"
install:
- pip install numpy
- pip install -q torch
- pip install -q torch -f https://download.pytorch.org/whl/nightly/cpu/torch.html
- pip install pycodestyle
- pip install flake8
- pip install codecov
- pip install sphinx
- pip install sphinx_rtd_theme
script:
- python -c "import torch; print(torch.__version__)"
- pycodestyle .
- flake8 .
- python setup.py install
......
......@@ -35,7 +35,7 @@ The package consists of the following operations:
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
All included operations work on varying data types, are implemented both for CPU and GPU and include a backwards implementation.
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
## Installation
......
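The broadcasting behavior added to the README line above can be checked with a short snippet (a minimal sketch; the shapes below are arbitrary illustration values, not part of this commit):

import torch
from torch_scatter import scatter_add

# index lacks the channel dimension and is broadcast against src.
src = torch.randn(4, 3, 8)              # (B, C, H)
index = torch.randint(0, 8, (4, 1, 8))  # (B, 1, H)
out = scatter_add(src, index, dim=2, dim_size=8)
print(out.size())  # torch.Size([4, 3, 8])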
......@@ -2,5 +2,6 @@ Scatter Add
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_add
......@@ -2,5 +2,6 @@ Scatter Div
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_div
......@@ -2,5 +2,6 @@ Scatter Max
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_max
......@@ -2,5 +2,6 @@ Scatter Mean
============
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_mean
......@@ -2,5 +2,6 @@ Scatter Mul
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_mul
......@@ -2,5 +2,6 @@ Scatter Std
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_std
......@@ -2,5 +2,6 @@ Scatter Sub
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: scatter_sub
......@@ -6,7 +6,7 @@ PyTorch Scatter Documentation
This package consists of a small extension library of highly optimized sparse update (scatter) operations for use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package.
Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
All included operations work on varying data types, are implemented both for CPU and GPU and include a backwards implementation.
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
.. toctree::
:glob:
......
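The "group-index" description above amounts to: every element of src is reduced into the output slot named by the matching element of index. A one-dimensional sketch using scatter_add (values chosen for illustration):

import torch
from torch_scatter import scatter_add

src = torch.tensor([1., 2., 3., 4.])
index = torch.tensor([0, 1, 0, 1])  # elements 0, 2 form group 0; elements 1, 3 form group 1
out = scatter_add(src, index)
print(out)  # tensor([4., 6.])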
import pytest
import torch
from torch_scatter import scatter_add

from .utils import devices


@pytest.mark.parametrize('device', devices)
def test_broadcasting(device):
    B, C, H, W = (4, 3, 8, 8)

    # index lacks the channel dimension and is broadcast over it.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, C, H, W)

    # src has a singleton channel dimension and is broadcast over it.
    src = torch.randn((B, 1, H, W), device=device)
    index = torch.randint(0, H, (B, C, H, W)).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, C, H, W)

    # Both share the singleton dimension, so no expansion takes place.
    src = torch.randn((B, 1, H, W), device=device)
    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, 1, H, W)

    # A one-dimensional index is expanded to the shape of src.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (H, )).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, C, H, W)
......@@ -24,6 +24,8 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
:attr:`dim` = `i`, then :attr:`out` must be an n-dimensional tensor with
size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. Moreover, the
values of :attr:`index` must be between `0` and `out.size(dim) - 1`.
Both :attr:`src` and :attr:`index` are broadcasted in case their dimensions
do not match.
For one-dimensional tensors, the operation computes
......
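The docstring's index-range constraint can be exercised directly; a minimal sketch with an explicit out tensor (shapes and values are illustrative, not part of the commit):

import torch
from torch_scatter import scatter_add

src = torch.tensor([2., 3., 5.])
index = torch.tensor([0, 2, 2])  # all values lie in [0, out.size(0) - 1]
out = torch.ones(4)              # additions accumulate into the given out
out = scatter_add(src, index, out=out)
print(out)  # tensor([3., 1., 9., 1.])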
from __future__ import division
from itertools import repeat
import torch
......@@ -21,6 +23,18 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
        else:  # PyTorch has a bug when view is used on zero-element tensors.
            index = src.new_empty(index_size, dtype=torch.long)

    # Broadcasting capabilities: Expand dimensions to match.
    if src.dim() != index.dim():
        raise ValueError(
            ('Number of dimensions of src and index tensor do not match, '
             'got {} and {}').format(src.dim(), index.dim()))

    expand_size = []
    for s, i in zip(src.size(), index.size()):
        expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
    src = src.expand(expand_size)
    index = index.expand_as(src)

    # Generate output tensor if not given.
    if out is None:
        out_size = list(src.size())
......
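The expansion logic in the hunk above can be exercised on its own. The helper name broadcast_pair below is hypothetical; it only mirrors the hunk's logic under the same same-rank assumption (size-1 dimensions are expanded, equal sizes are kept via -1):

import torch

def broadcast_pair(src, index):
    # Ranks must match; singleton dimensions expand to the larger size.
    if src.dim() != index.dim():
        raise ValueError(
            ('Number of dimensions of src and index tensor do not match, '
             'got {} and {}').format(src.dim(), index.dim()))
    expand_size = []
    for s, i in zip(src.size(), index.size()):
        expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
    src = src.expand(expand_size)
    index = index.expand_as(src)
    return src, index

src, index = broadcast_pair(torch.randn(4, 1, 8),
                            torch.zeros(4, 3, 8, dtype=torch.long))
print(src.size(), index.size())  # torch.Size([4, 3, 8]) for both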