"src/array/cuda/uvm/array_index_select_uvm.hip" did not exist on "8ae50c422b81de5734b971722d580c433d669fe8"
Commit 1fb5fa4f authored by rusty1s's avatar rusty1s
Browse files

torch-scatter=2.0 support

parent 2984f288
...@@ -28,11 +28,11 @@ Note that only `value` comes with autograd support, as `index` is discrete and t ...@@ -28,11 +28,11 @@ Note that only `value` comes with autograd support, as `index` is discrete and t
## Installation ## Installation
Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*: Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
``` ```
$ python -c "import torch; print(torch.__version__)" $ python -c "import torch; print(torch.__version__)"
>>> 1.1.0 >>> 1.4.0
$ echo $PATH $ echo $PATH
>>> /usr/local/cuda/bin:... >>> /usr/local/cuda/bin:...
...@@ -53,7 +53,7 @@ Be sure to import `torch` first before using this package to resolve symbols the ...@@ -53,7 +53,7 @@ Be sure to import `torch` first before using this package to resolve symbols the
## Coalesce ## Coalesce
``` ```
torch_sparse.coalesce(index, value, m, n, op="add", fill_value=0) -> (torch.LongTensor, torch.Tensor) torch_sparse.coalesce(index, value, m, n, op="add") -> (torch.LongTensor, torch.Tensor)
``` ```
Row-wise sorts `index` and removes duplicate entries. Row-wise sorts `index` and removes duplicate entries.
...@@ -67,7 +67,6 @@ For scattering, any operation of [`torch_scatter`](https://github.com/rusty1s/py ...@@ -67,7 +67,6 @@ For scattering, any operation of [`torch_scatter`](https://github.com/rusty1s/py
* **m** *(int)* - The first dimension of corresponding dense matrix. * **m** *(int)* - The first dimension of corresponding dense matrix.
* **n** *(int)* - The second dimension of corresponding dense matrix. * **n** *(int)* - The second dimension of corresponding dense matrix.
* **op** *(string, optional)* - The scatter operation to use. (default: `"add"`) * **op** *(string, optional)* - The scatter operation to use. (default: `"add"`)
* **fill_value** *(int, optional)* - The initial fill value of scatter operation. (default: `0`)
### Returns ### Returns
......
...@@ -39,7 +39,7 @@ if CUDA_HOME is not None and GPU: ...@@ -39,7 +39,7 @@ if CUDA_HOME is not None and GPU:
extra_compile_args=extra_compile_args), extra_compile_args=extra_compile_args),
] ]
__version__ = '0.4.3' __version__ = '0.4.4'
url = 'https://github.com/rusty1s/pytorch_sparse' url = 'https://github.com/rusty1s/pytorch_sparse'
install_requires = ['scipy'] install_requires = ['scipy']
......
...@@ -5,7 +5,7 @@ from .eye import eye ...@@ -5,7 +5,7 @@ from .eye import eye
from .spmm import spmm from .spmm import spmm
from .spspmm import spspmm from .spspmm import spspmm
__version__ = '0.4.3' __version__ = '0.4.4'
__all__ = [ __all__ = [
'__version__', '__version__',
......
...@@ -4,7 +4,7 @@ import torch_scatter ...@@ -4,7 +4,7 @@ import torch_scatter
from .utils.unique import unique from .utils.unique import unique
def coalesce(index, value, m, n, op='add', fill_value=0): def coalesce(index, value, m, n, op='add'):
"""Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate """Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
entries are removed by scattering them together. For scattering, any entries are removed by scattering them together. For scattering, any
operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_ operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_
...@@ -17,8 +17,6 @@ def coalesce(index, value, m, n, op='add', fill_value=0): ...@@ -17,8 +17,6 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
n (int): The second dimension of corresponding dense matrix. n (int): The second dimension of corresponding dense matrix.
op (string, optional): The scatter operation to use. (default: op (string, optional): The scatter operation to use. (default:
:obj:`"add"`) :obj:`"add"`)
fill_value (int, optional): The initial fill value of scatter
operation. (default: :obj:`0`)
:rtype: (:class:`LongTensor`, :class:`Tensor`) :rtype: (:class:`LongTensor`, :class:`Tensor`)
""" """
...@@ -37,8 +35,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0): ...@@ -37,8 +35,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
index = torch.stack([row[perm], col[perm]], dim=0) index = torch.stack([row[perm], col[perm]], dim=0)
op = getattr(torch_scatter, 'scatter_{}'.format(op)) op = getattr(torch_scatter, 'scatter_{}'.format(op))
value = op(value, inv, 0, None, perm.size(0), fill_value) value = op(value, inv, 0, None, perm.size(0))
if isinstance(value, tuple): value = value[0] if isinstance(value, tuple) else value
value = value[0]
return index, value return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment