Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
6569c84e
Commit
6569c84e
authored
Aug 05, 2018
by
rusty1s
Browse files
cpu port to new extension api
parent
aa80bb88
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
250 additions
and
29 deletions
+250
-29
.coveragerc
.coveragerc
+1
-1
.gitignore
.gitignore
+1
-1
.travis.yml
.travis.yml
+5
-4
README.md
README.md
+1
-1
cpu/dim_apply.h
cpu/dim_apply.h
+114
-0
cpu/scatter.cpp
cpu/scatter.cpp
+87
-0
setup.py
setup.py
+18
-7
test/test_backward.py
test/test_backward.py
+2
-1
test/utils.py
test/utils.py
+0
-3
torch_scatter/__init__.py
torch_scatter/__init__.py
+1
-1
torch_scatter/div.py
torch_scatter/div.py
+2
-2
torch_scatter/max.py
torch_scatter/max.py
+3
-3
torch_scatter/min.py
torch_scatter/min.py
+3
-3
torch_scatter/mul.py
torch_scatter/mul.py
+2
-2
torch_scatter/utils/ext.py
torch_scatter/utils/ext.py
+10
-0
No files found.
.coveragerc
View file @
6569c84e
[report]
exclude_lines =
pragma: no cover
def backward
cuda
.gitignore
View file @
6569c84e
__pycache__/
_ext/
build/
dist/
.cache/
.eggs/
*.egg-info/
.coverage
*.so
*.aux
*.log
*.pdf
...
...
.travis.yml
View file @
6569c84e
...
...
@@ -6,14 +6,15 @@ matrix:
-
python
:
2.7
-
python
:
3.5
-
python
:
3.6
-
python
:
3.7
install
:
-
if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl; fi
-
if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp35-cp35m-linux_x86_64.whl; fi
-
if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.0-cp36-cp36m-linux_x86_64.whl; fi
-
if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl; fi
-
if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl; fi
-
if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl; fi
-
if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1.post2-cp37-cp37m-linux_x86_64.whl; fi
-
pip install pycodestyle
-
pip install flake8
-
pip install codecov
-
pip install cffi
script
:
-
pycodestyle .
-
flake8 .
...
...
README.md
View file @
6569c84e
...
...
@@ -40,7 +40,7 @@ If not, add cuda (`/usr/local/cuda/bin`) to your `$PATH`.
Then run:
```
pip install
cffi
torch-scatter
pip install torch-scatter
```
## Example
...
...
cpu/dim_apply.h
0 → 100644
View file @
6569c84e
#pragma once
#include <torch/torch.h>
// DIM_APPLY3 — run CODE once for every 1-D "row" taken along dimension DIM,
// iterating three same-shaped tensors in lockstep (modeled on TH_TENSOR_APPLY).
//
// For each TENSORx the macro introduces, in the enclosing scope:
//   TENSORx_data   — raw pointer to the current row's first element
//   TENSORx_size   — extent of TENSORx along DIM
//   TENSORx_stride — element stride of TENSORx along DIM
// CODE indexes a row as TENSORx_data[i * TENSORx_stride] for i in
// [0, TENSORx_size).
//
// After each CODE invocation the data pointers are advanced odometer-style
// over every dimension except DIM, using a heap-allocated int64 counter
// tensor (`zeros`) as the multi-dimensional index.
//
// NOTE(review): sizes/strides of TENSOR2/TENSOR3 along the non-DIM dims are
// assumed to match TENSOR1 (only TENSOR1.size(cur_dim) drives the carry) —
// callers must guarantee this. Comments use // above the #define on purpose:
// a // comment inside the macro body would swallow the continuation
// backslash after line splicing.
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
  TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>();                              \
  auto TENSOR1##_size = TENSOR1.size(DIM);                                    \
  auto TENSOR1##_stride = TENSOR1.stride(DIM);                                \
                                                                              \
  TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>();                              \
  auto TENSOR2##_size = TENSOR2.size(DIM);                                    \
  auto TENSOR2##_stride = TENSOR2.stride(DIM);                                \
                                                                              \
  TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>();                              \
  auto TENSOR3##_size = TENSOR3.size(DIM);                                    \
  auto TENSOR3##_stride = TENSOR3.stride(DIM);                                \
                                                                              \
  auto dims = TENSOR1.dim();                                                  \
  auto zeros = at::zeros(torch::CPU(at::kLong), {dims});                      \
  auto counter = zeros.data<int64_t>();                                       \
  bool has_finished = false;                                                  \
                                                                              \
  while (!has_finished) {                                                     \
    CODE;                                                                     \
    if (dims == 1)                                                            \
      break;                                                                  \
                                                                              \
    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                    \
      if (cur_dim == DIM) {                                                   \
        if (cur_dim == dims - 1) {                                            \
          has_finished = true;                                                \
          break;                                                              \
        }                                                                     \
        continue;                                                             \
      }                                                                       \
                                                                              \
      counter[cur_dim]++;                                                     \
      TENSOR1##_data += TENSOR1.stride(cur_dim);                              \
      TENSOR2##_data += TENSOR2.stride(cur_dim);                              \
      TENSOR3##_data += TENSOR3.stride(cur_dim);                              \
                                                                              \
      if (counter[cur_dim] == TENSOR1.size(cur_dim)) {                        \
        if (cur_dim == dims - 1) {                                            \
          has_finished = true;                                                \
          break;                                                              \
        } else {                                                              \
          TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim);       \
          TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim);       \
          TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim);       \
          counter[cur_dim] = 0;                                               \
        }                                                                     \
      } else                                                                  \
        break;                                                                \
    }                                                                         \
  }
// DIM_APPLY4 — four-tensor variant of DIM_APPLY3: runs CODE once per 1-D
// "row" along dimension DIM, walking four same-shaped tensors in lockstep.
//
// For each TENSORx it declares TENSORx_data / TENSORx_size / TENSORx_stride
// in the enclosing scope; CODE addresses a row element as
// TENSORx_data[i * TENSORx_stride]. Between CODE invocations the pointers
// advance odometer-style over every dimension except DIM, with an int64
// counter tensor tracking the multi-dimensional position.
//
// NOTE(review): as in DIM_APPLY3, only TENSOR1.size(cur_dim) drives the
// carry — the other tensors are assumed to share TENSOR1's shape on the
// non-DIM dimensions. Documentation lives above the #define because a //
// comment inside the body would swallow the continuation backslash.
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4,     \
                   TENSOR4, DIM, CODE)                                        \
  TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>();                              \
  auto TENSOR1##_size = TENSOR1.size(DIM);                                    \
  auto TENSOR1##_stride = TENSOR1.stride(DIM);                                \
                                                                              \
  TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>();                              \
  auto TENSOR2##_size = TENSOR2.size(DIM);                                    \
  auto TENSOR2##_stride = TENSOR2.stride(DIM);                                \
                                                                              \
  TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>();                              \
  auto TENSOR3##_size = TENSOR3.size(DIM);                                    \
  auto TENSOR3##_stride = TENSOR3.stride(DIM);                                \
                                                                              \
  TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>();                              \
  auto TENSOR4##_size = TENSOR4.size(DIM);                                    \
  auto TENSOR4##_stride = TENSOR4.stride(DIM);                                \
                                                                              \
  auto dims = TENSOR1.dim();                                                  \
  auto zeros = at::zeros(torch::CPU(at::kLong), {dims});                      \
  auto counter = zeros.data<int64_t>();                                       \
  bool has_finished = false;                                                  \
                                                                              \
  while (!has_finished) {                                                     \
    CODE;                                                                     \
    if (dims == 1)                                                            \
      break;                                                                  \
                                                                              \
    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) {                    \
      if (cur_dim == DIM) {                                                   \
        if (cur_dim == dims - 1) {                                            \
          has_finished = true;                                                \
          break;                                                              \
        }                                                                     \
        continue;                                                             \
      }                                                                       \
                                                                              \
      counter[cur_dim]++;                                                     \
      TENSOR1##_data += TENSOR1.stride(cur_dim);                              \
      TENSOR2##_data += TENSOR2.stride(cur_dim);                              \
      TENSOR3##_data += TENSOR3.stride(cur_dim);                              \
      TENSOR4##_data += TENSOR4.stride(cur_dim);                              \
                                                                              \
      if (counter[cur_dim] == TENSOR1.size(cur_dim)) {                        \
        if (cur_dim == dims - 1) {                                            \
          has_finished = true;                                                \
          break;                                                              \
        } else {                                                              \
          TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim);       \
          TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim);       \
          TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim);       \
          TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim);       \
          counter[cur_dim] = 0;                                               \
        }                                                                     \
      } else                                                                  \
        break;                                                                \
    }                                                                         \
  }
cpu/scatter.cpp
0 → 100644
View file @
6569c84e
#include <torch/torch.h>
#include "dim_apply.h"
// In-place scatter-multiply: for every row position i along `dim`,
// out[index[i]] *= src[i]. `index` holds int64 target offsets along `dim`
// of `out`; all other dimensions are iterated in lockstep by DIM_APPLY3.
void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
                 int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul", [&] {
    DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
      for (int64_t i = 0; i < elems_per_row; i++) {
        const int64_t idx = index_data[i * index_stride];
        out_data[idx * out_stride] *= src_data[i * src_stride];
      }
    })
  });
}
// In-place scatter-divide: for every row position i along `dim`,
// out[index[i]] /= src[i]. Mirrors scatter_mul with division as the
// reduction; duplicate indices apply successive divisions.
void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
                 int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_div", [&] {
    DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
      for (int64_t i = 0; i < elems_per_row; i++) {
        const int64_t idx = index_data[i * index_stride];
        out_data[idx * out_stride] /= src_data[i * src_stride];
      }
    })
  });
}
// In-place scatter-max: out[index[i]] = max(out[index[i]], src[i]) along
// `dim`, recording in `arg` the winning source position i (argmax) for the
// backward pass. `>=` means a later tie overwrites an earlier winner.
void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
                 at::Tensor arg, int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_max", [&] {
    DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg,
               dim, {
      for (int64_t i = 0; i < elems_per_row; i++) {
        const int64_t idx = index_data[i * index_stride];
        if (src_data[i * src_stride] >= out_data[idx * out_stride]) {
          out_data[idx * out_stride] = src_data[i * src_stride];
          arg_data[idx * arg_stride] = i;
        }
      }
    })
  });
}
// In-place scatter-min: out[index[i]] = min(out[index[i]], src[i]) along
// `dim`, recording in `arg` the winning source position i (argmin) for the
// backward pass. `<=` means a later tie overwrites an earlier winner.
void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
                 at::Tensor arg, int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_min", [&] {
    DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg,
               dim, {
      for (int64_t i = 0; i < elems_per_row; i++) {
        const int64_t idx = index_data[i * index_stride];
        if (src_data[i * src_stride] <= out_data[idx * out_stride]) {
          out_data[idx * out_stride] = src_data[i * src_stride];
          arg_data[idx * arg_stride] = i;
        }
      }
    })
  });
}
// Backward pass for scatter_max/scatter_min: routes the incoming gradient
// back to the source positions that won the forward reduction. For each row
// position i along `dim`, if arg[index[i]] == i (i.e. source element i was
// the argmax/argmin for its target), then out[i] = grad[index[i]];
// non-winning positions keep their (zero-initialized) gradient.
void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
                    at::Tensor out, int64_t dim) {
  const int64_t elems_per_row = index.size(dim);
  AT_DISPATCH_ALL_TYPES(grad.type(), "index_backward", [&] {
    DIM_APPLY4(scalar_t, grad, int64_t, index, int64_t, arg, scalar_t, out,
               dim, {
      for (int64_t i = 0; i < elems_per_row; i++) {
        const int64_t idx = index_data[i * index_stride];
        if (arg_data[idx * arg_stride] == i) {
          out_data[i * out_stride] = grad_data[idx * grad_stride];
        }
      }
    })
  });
}
// Python bindings: expose the CPU scatter kernels under the extension
// module name chosen at build time (TORCH_EXTENSION_NAME, set by setup.py).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CPU)");
  m.def("scatter_div", &scatter_div, "Scatter Div (CPU)");
  m.def("scatter_max", &scatter_max, "Scatter Max (CPU)");
  m.def("scatter_min", &scatter_min, "Scatter Min (CPU)");
  m.def("index_backward", &index_backward, "Index Backward (CPU)");
}
setup.py
View file @
6569c84e
from
os
import
path
as
osp
import
glob
from
setuptools
import
setup
,
find_packages
__version__
=
'1.0.3'
import
torch.cuda
from
torch.utils.cpp_extension
import
CppExtension
,
CUDAExtension
ext_modules
=
[
CppExtension
(
'scatter_cpu'
,
[
'cpu/scatter.cpp'
])]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
if
torch
.
cuda
.
is_available
():
ext_modules
+=
[
CUDAExtension
(
'scatter_cuda'
,
[
'cuda/scatter.cpp'
]
+
glob
.
glob
(
'cuda/*.cu'
))
]
__version__
=
'1.0.4'
url
=
'https://github.com/rusty1s/pytorch_scatter'
install_requires
=
[
'cffi'
]
install_requires
=
[]
setup_requires
=
[
'pytest-runner'
,
'cffi'
]
tests_require
=
[
'pytest'
,
'pytest-cov'
]
...
...
@@ -21,7 +32,7 @@ setup(
install_requires
=
install_requires
,
setup_requires
=
setup_requires
,
tests_require
=
tests_require
,
packages
=
find_packag
es
(
ex
clude
=
[
'build'
])
,
ext_package
=
''
,
cffi_modules
=
[
osp
.
join
(
osp
.
dirname
(
__file__
),
'build.py:ffi'
)]
,
ext_modul
es
=
ex
t_modules
,
cmdclass
=
cmdclass
,
packages
=
find_packages
()
,
)
test/test_backward.py
View file @
6569c84e
...
...
@@ -5,8 +5,9 @@ import torch
from
torch.autograd
import
gradcheck
import
torch_scatter
from
.utils
import
dtypes
,
devices
,
tensor
from
.utils
import
devices
,
tensor
dtypes
=
[
torch
.
float
,
torch
.
double
]
funcs
=
[
'add'
,
'sub'
,
'mul'
,
'div'
,
'mean'
]
indices
=
[
2
,
0
,
1
,
1
,
0
]
...
...
test/utils.py
View file @
6569c84e
...
...
@@ -3,9 +3,6 @@ from torch.testing import get_all_dtypes
dtypes
=
get_all_dtypes
()
dtypes
.
remove
(
torch
.
half
)
dtypes
.
remove
(
torch
.
short
)
# TODO: PyTorch `atomicAdd` bug with short type.
dtypes
.
remove
(
torch
.
uint8
)
# We cannot properly test unsigned values.
dtypes
.
remove
(
torch
.
int8
)
# Overflow on gradient computations :(
devices
=
[
torch
.
device
(
'cpu'
)]
if
torch
.
cuda
.
is_available
():
# pragma: no cover
...
...
torch_scatter/__init__.py
View file @
6569c84e
...
...
@@ -6,7 +6,7 @@ from .mean import scatter_mean
from
.max
import
scatter_max
from
.min
import
scatter_min
__version__
=
'1.0.
3
'
__version__
=
'1.0.
4
'
__all__
=
[
'scatter_add'
,
'scatter_sub'
,
'scatter_mul'
,
'scatter_div'
,
'scatter_mean'
,
...
...
torch_scatter/div.py
View file @
6569c84e
from
torch.autograd
import
Function
from
.utils.
ffi
import
get_func
from
.utils.
ext
import
get_func
from
.utils.gen
import
gen
...
...
@@ -8,7 +8,7 @@ class ScatterDiv(Function):
@
staticmethod
def
forward
(
ctx
,
out
,
src
,
index
,
dim
):
func
=
get_func
(
'scatter_div'
,
src
)
func
(
dim
,
out
,
index
,
src
)
func
(
src
,
index
,
out
,
dim
)
ctx
.
mark_dirty
(
out
)
ctx
.
save_for_backward
(
out
,
src
,
index
)
...
...
torch_scatter/max.py
View file @
6569c84e
from
torch.autograd
import
Function
from
.utils.
ffi
import
get_func
from
.utils.
ext
import
get_func
from
.utils.gen
import
gen
...
...
@@ -9,7 +9,7 @@ class ScatterMax(Function):
def
forward
(
ctx
,
out
,
src
,
index
,
dim
):
arg
=
index
.
new_full
(
out
.
size
(),
-
1
)
func
=
get_func
(
'scatter_max'
,
src
)
func
(
dim
,
out
,
index
,
src
,
arg
)
func
(
src
,
index
,
out
,
arg
,
dim
)
ctx
.
mark_dirty
(
out
)
ctx
.
dim
=
dim
...
...
@@ -25,7 +25,7 @@ class ScatterMax(Function):
if
ctx
.
needs_input_grad
[
1
]:
grad_src
=
grad_out
.
new_zeros
(
index
.
size
())
func
=
get_func
(
'index_backward'
,
grad_out
)
func
(
ctx
.
dim
,
grad_
src
,
index
,
grad_out
,
arg
)
func
(
grad_
out
,
index
,
arg
,
grad_src
,
ctx
.
dim
)
return
None
,
grad_src
,
None
,
None
...
...
torch_scatter/min.py
View file @
6569c84e
from
torch.autograd
import
Function
from
.utils.
ffi
import
get_func
from
.utils.
ext
import
get_func
from
.utils.gen
import
gen
...
...
@@ -9,7 +9,7 @@ class ScatterMin(Function):
def
forward
(
ctx
,
out
,
src
,
index
,
dim
):
arg
=
index
.
new_full
(
out
.
size
(),
-
1
)
func
=
get_func
(
'scatter_min'
,
src
)
func
(
dim
,
out
,
index
,
src
,
arg
)
func
(
src
,
index
,
out
,
arg
,
dim
)
ctx
.
mark_dirty
(
out
)
ctx
.
dim
=
dim
...
...
@@ -25,7 +25,7 @@ class ScatterMin(Function):
if
ctx
.
needs_input_grad
[
1
]:
grad_src
=
grad_out
.
new_zeros
(
index
.
size
())
func
=
get_func
(
'index_backward'
,
grad_out
)
func
(
ctx
.
dim
,
grad_
src
,
index
,
grad_out
,
arg
)
func
(
grad_
out
,
index
,
arg
,
grad_src
,
ctx
.
dim
)
return
None
,
grad_src
,
None
,
None
...
...
torch_scatter/mul.py
View file @
6569c84e
from
torch.autograd
import
Function
from
.utils.
ffi
import
get_func
from
.utils.
ext
import
get_func
from
.utils.gen
import
gen
...
...
@@ -8,7 +8,7 @@ class ScatterMul(Function):
@
staticmethod
def
forward
(
ctx
,
out
,
src
,
index
,
dim
):
func
=
get_func
(
'scatter_mul'
,
src
)
func
(
dim
,
out
,
index
,
src
)
func
(
src
,
index
,
out
,
dim
)
ctx
.
mark_dirty
(
out
)
ctx
.
save_for_backward
(
out
,
src
,
index
)
...
...
torch_scatter/utils/ext.py
0 → 100644
View file @
6569c84e
import
torch
import
scatter_cpu
if
torch
.
cuda
.
is_available
():
import
scatter_cuda
def get_func(name, tensor):
    """Return the extension kernel ``name`` for ``tensor``'s device.

    Dispatches to the ``scatter_cuda`` module when ``tensor`` lives on a
    GPU, otherwise to ``scatter_cpu`` (``scatter_cuda`` is only imported
    at module load when CUDA is available).
    """
    module = scatter_cuda if tensor.is_cuda else scatter_cpu
    return getattr(module, name)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment