Commit 1fb5fa4f authored by rusty1s's avatar rusty1s
Browse files

torch-scatter=2.0 support

parent 2984f288
......@@ -28,11 +28,11 @@ Note that only `value` comes with autograd support, as `index` is discrete and t
## Installation
Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
```
$ python -c "import torch; print(torch.__version__)"
>>> 1.1.0
>>> 1.4.0
$ echo $PATH
>>> /usr/local/cuda/bin:...
......@@ -53,7 +53,7 @@ Be sure to import `torch` first before using this package to resolve symbols the
## Coalesce
```
torch_sparse.coalesce(index, value, m, n, op="add", fill_value=0) -> (torch.LongTensor, torch.Tensor)
torch_sparse.coalesce(index, value, m, n, op="add") -> (torch.LongTensor, torch.Tensor)
```
Row-wise sorts `index` and removes duplicate entries.
......@@ -67,7 +67,6 @@ For scattering, any operation of [`torch_scatter`](https://github.com/rusty1s/py
* **m** *(int)* - The first dimension of corresponding dense matrix.
* **n** *(int)* - The second dimension of corresponding dense matrix.
* **op** *(string, optional)* - The scatter operation to use. (default: `"add"`)
* **fill_value** *(int, optional)* - The initial fill value of scatter operation. (default: `0`)
### Returns
......
......@@ -39,7 +39,7 @@ if CUDA_HOME is not None and GPU:
extra_compile_args=extra_compile_args),
]
__version__ = '0.4.3'
__version__ = '0.4.4'
url = 'https://github.com/rusty1s/pytorch_sparse'
install_requires = ['scipy']
......
......@@ -5,7 +5,7 @@ from .eye import eye
from .spmm import spmm
from .spspmm import spspmm
__version__ = '0.4.3'
__version__ = '0.4.4'
__all__ = [
'__version__',
......
......@@ -4,7 +4,7 @@ import torch_scatter
from .utils.unique import unique
def coalesce(index, value, m, n, op='add', fill_value=0):
def coalesce(index, value, m, n, op='add'):
"""Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
entries are removed by scattering them together. For scattering, any
operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_
......@@ -17,8 +17,6 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
n (int): The second dimension of corresponding dense matrix.
op (string, optional): The scatter operation to use. (default:
:obj:`"add"`)
fill_value (int, optional): The initial fill value of scatter
operation. (default: :obj:`0`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
......@@ -37,8 +35,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
index = torch.stack([row[perm], col[perm]], dim=0)
op = getattr(torch_scatter, 'scatter_{}'.format(op))
value = op(value, inv, 0, None, perm.size(0), fill_value)
if isinstance(value, tuple):
value = value[0]
value = op(value, inv, 0, None, perm.size(0))
value = value[0] if isinstance(value, tuple) else value
return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment