Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
ea94e546
Commit
ea94e546
authored
Jan 10, 2020
by
rusty1s
Browse files
cpu boilerplate
parent
d824c8be
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
121 additions
and
32 deletions
+121
-32
benchmark/gather.py
benchmark/gather.py
+13
-5
benchmark/scatter_segment.py
benchmark/scatter_segment.py
+21
-8
cpu/gather.cpp
cpu/gather.cpp
+28
-0
cpu/segment.cpp
cpu/segment.cpp
+29
-0
setup.py
setup.py
+5
-3
torch_scatter/gather.py
torch_scatter/gather.py
+9
-4
torch_scatter/segment.py
torch_scatter/segment.py
+16
-12
No files found.
benchmark/gather.py
View file @
ea94e546
...
@@ -30,13 +30,16 @@ def correctness(dataset):
...
@@ -30,13 +30,16 @@ def correctness(dataset):
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
except
RuntimeError
:
except
RuntimeError
as
e
:
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
time_func
(
func
,
x
):
def
time_func
(
func
,
x
):
try
:
try
:
torch
.
cuda
.
synchronize
()
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
if
not
args
.
with_backward
:
if
not
args
.
with_backward
:
...
@@ -49,9 +52,12 @@ def time_func(func, x):
...
@@ -49,9 +52,12 @@ def time_func(func, x):
out
=
func
(
x
)
out
=
func
(
x
)
torch
.
autograd
.
grad
(
out
,
x
,
out
,
only_inputs
=
True
)
torch
.
autograd
.
grad
(
out
,
x
,
out
,
only_inputs
=
True
)
torch
.
cuda
.
synchronize
()
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
return
time
.
perf_counter
()
-
t
return
time
.
perf_counter
()
-
t
except
RuntimeError
:
except
RuntimeError
as
e
:
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
float
(
'inf'
)
return
float
(
'inf'
)
...
@@ -88,7 +94,9 @@ def timing(dataset):
...
@@ -88,7 +94,9 @@ def timing(dataset):
del
x
del
x
except
RuntimeError
:
except
RuntimeError
as
e
:
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
for
t
in
(
t1
,
t2
,
t3
,
t4
):
for
t
in
(
t1
,
t2
,
t3
,
t4
):
t
.
append
(
float
(
'inf'
))
t
.
append
(
float
(
'inf'
))
...
...
benchmark/scatter_segment.py
View file @
ea94e546
...
@@ -82,13 +82,16 @@ def correctness(dataset):
...
@@ -82,13 +82,16 @@ def correctness(dataset):
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
except
RuntimeError
:
except
RuntimeError
as
e
:
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
time_func
(
func
,
x
):
def
time_func
(
func
,
x
):
try
:
try
:
torch
.
cuda
.
synchronize
()
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
if
not
args
.
with_backward
:
if
not
args
.
with_backward
:
...
@@ -102,9 +105,12 @@ def time_func(func, x):
...
@@ -102,9 +105,12 @@ def time_func(func, x):
out
=
out
[
0
]
if
isinstance
(
out
,
tuple
)
else
out
out
=
out
[
0
]
if
isinstance
(
out
,
tuple
)
else
out
torch
.
autograd
.
grad
(
out
,
x
,
out
,
only_inputs
=
True
)
torch
.
autograd
.
grad
(
out
,
x
,
out
,
only_inputs
=
True
)
torch
.
cuda
.
synchronize
()
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
return
time
.
perf_counter
()
-
t
return
time
.
perf_counter
()
-
t
except
RuntimeError
:
except
RuntimeError
as
e
:
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
return
float
(
'inf'
)
return
float
(
'inf'
)
...
@@ -152,7 +158,9 @@ def timing(dataset):
...
@@ -152,7 +158,9 @@ def timing(dataset):
del
x
del
x
except
RuntimeError
:
except
RuntimeError
as
e
:
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
for
t
in
(
t1
,
t2
,
t3
,
t4
):
for
t
in
(
t1
,
t2
,
t3
,
t4
):
t
.
append
(
float
(
'inf'
))
t
.
append
(
float
(
'inf'
))
...
@@ -167,7 +175,9 @@ def timing(dataset):
...
@@ -167,7 +175,9 @@ def timing(dataset):
del
x
del
x
except
RuntimeError
:
except
RuntimeError
as
e
:
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
for
t
in
(
t5
,
t6
):
for
t
in
(
t5
,
t6
):
t
.
append
(
float
(
'inf'
))
t
.
append
(
float
(
'inf'
))
...
@@ -197,8 +207,11 @@ def timing(dataset):
...
@@ -197,8 +207,11 @@ def timing(dataset):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--reduce'
,
type
=
str
,
required
=
True
,
parser
.
add_argument
(
choices
=
[
'add'
,
'mean'
,
'min'
,
'max'
])
'--reduce'
,
type
=
str
,
required
=
True
,
choices
=
[
'add'
,
'mean'
,
'min'
,
'max'
])
parser
.
add_argument
(
'--with_backward'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--with_backward'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cuda'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cuda'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
cpu/gather.cpp
0 → 100644
View file @
ea94e546
#include <torch/extension.h>
// Asserts that the tensor expression `x` is not located on a CUDA device.
// The argument is parenthesized so that compound expressions expand as one
// unit, and the device is queried through Tensor::device() instead of the
// deprecated Tensor::type(). The failure message is "<x> must be CPU tensor",
// where <x> is the argument text exactly as written at the call site.
#define CHECK_CPU(x) AT_ASSERTM(!(x).device().is_cuda(), #x " must be CPU tensor")
// CPU placeholder for the CSR gather operator.
//
// Runs CHECK_CPU on `src`, on `indptr`, and on the optional output tensor
// when one is supplied, then fails unconditionally with
// "Not yet implemented" because no CPU kernel exists yet.
at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
                      at::optional<at::Tensor> out_opt) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);

  // The output buffer is optional; only validate it when it was passed in.
  const bool has_out = out_opt.has_value();
  if (has_out) {
    CHECK_CPU(out_opt.value());
  }

  // Unconditional failure: the assertion condition is the literal `false`.
  AT_ASSERTM(false, "Not yet implemented");

  // Placed after the always-failing assertion; gives the function a return
  // statement of the declared type.
  return src;
}
// CPU placeholder for the COO gather operator.
//
// Runs CHECK_CPU on `src`, on `index`, and on the optional output tensor
// when one is supplied, then fails unconditionally with
// "Not yet implemented" because no CPU kernel exists yet.
at::Tensor gather_coo(at::Tensor src, at::Tensor index,
                      at::optional<at::Tensor> out_opt) {
  CHECK_CPU(src);
  CHECK_CPU(index);

  // The output buffer is optional; only validate it when it was passed in.
  const bool has_out = out_opt.has_value();
  if (has_out) {
    CHECK_CPU(out_opt.value());
  }

  // Unconditional failure: the assertion condition is the literal `false`.
  AT_ASSERTM(false, "Not yet implemented");

  // Placed after the always-failing assertion; gives the function a return
  // statement of the declared type.
  return src;
}
// Python bindings for the CPU gather extension: registers `gather_csr` and
// `gather_coo` on the module under those same names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // def() returns the module itself, so the two registrations are chained.
  m.def("gather_csr", &gather_csr, "Gather CSR (CPU)")
      .def("gather_coo", &gather_coo, "Gather COO (CPU)");
}
cpu/segment.cpp
0 → 100644
View file @
ea94e546
#include <torch/extension.h>
// Asserts that the tensor expression `x` is not located on a CUDA device.
// The argument is parenthesized so that compound expressions expand as one
// unit, and the device is queried through Tensor::device() instead of the
// deprecated Tensor::type(). The failure message is "<x> must be CPU tensor",
// where <x> is the argument text exactly as written at the call site.
#define CHECK_CPU(x) AT_ASSERTM(!(x).device().is_cuda(), #x " must be CPU tensor")
// CPU placeholder for the CSR segment reduction.
//
// Runs CHECK_CPU on `src`, on `indptr`, and on the optional output tensor
// when one is supplied, then fails unconditionally with
// "Not yet implemented". `reduce` is accepted but not read by this stub.
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr(at::Tensor src, at::Tensor indptr,
            at::optional<at::Tensor> out_opt, std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);

  // The output buffer is optional; only validate it when it was passed in.
  const bool has_out = out_opt.has_value();
  if (has_out) {
    CHECK_CPU(out_opt.value());
  }

  // Unconditional failure: the assertion condition is the literal `false`.
  AT_ASSERTM(false, "Not yet implemented");

  // Placed after the always-failing assertion; gives the function a return
  // statement matching the declared (tensor, optional tensor) pair.
  return std::make_tuple(src, at::nullopt);
}
// CPU placeholder for the COO segment reduction.
//
// Runs CHECK_CPU on `src`, `index` and `out` (the output tensor is mandatory
// here), then fails unconditionally with "Not yet implemented". `reduce` is
// accepted but not read by this stub.
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
            std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);

  // Unconditional failure: the assertion condition is the literal `false`.
  AT_ASSERTM(false, "Not yet implemented");

  // Placed after the always-failing assertion; gives the function a return
  // statement matching the declared (tensor, optional tensor) pair.
  return std::make_tuple(src, at::nullopt);
}
// Python bindings for the CPU segment extension: registers `segment_csr` and
// `segment_coo` on the module under those same names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // def() returns the module itself, so the two registrations are chained.
  m.def("segment_csr", &segment_csr, "Segment CSR (CPU)")
      .def("segment_coo", &segment_coo, "Segment COO (CPU)");
}
setup.py
View file @
ea94e546
...
@@ -25,8 +25,9 @@ cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
...
@@ -25,8 +25,9 @@ cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
ext_modules
=
[]
ext_modules
=
[]
exts
=
[
e
.
split
(
osp
.
sep
)[
-
1
][:
-
4
]
for
e
in
glob
(
osp
.
join
(
'cpu'
,
'*.cpp'
))]
exts
=
[
e
.
split
(
osp
.
sep
)[
-
1
][:
-
4
]
for
e
in
glob
(
osp
.
join
(
'cpu'
,
'*.cpp'
))]
ext_modules
+=
[
ext_modules
+=
[
CppExtension
(
f
'torch_scatter.
{
ext
}
_cpu'
,
[
f
'cpu/
{
ext
}
.cpp'
],
CppExtension
(
extra_compile_args
=
cxx_extra_compile_args
)
for
ext
in
exts
f
'torch_scatter.
{
ext
}
_cpu'
,
[
f
'cpu/
{
ext
}
.cpp'
],
extra_compile_args
=
cxx_extra_compile_args
)
for
ext
in
exts
]
]
if
CUDA_HOME
is
not
None
and
USE_GPU
:
if
CUDA_HOME
is
not
None
and
USE_GPU
:
...
@@ -34,7 +35,8 @@ if CUDA_HOME is not None and USE_GPU:
...
@@ -34,7 +35,8 @@ if CUDA_HOME is not None and USE_GPU:
ext_modules
+=
[
ext_modules
+=
[
CUDAExtension
(
CUDAExtension
(
f
'torch_scatter.
{
ext
}
_cuda'
,
f
'torch_scatter.
{
ext
}
_cuda'
,
[
f
'cuda/
{
ext
}
.cpp'
,
f
'cuda/
{
ext
}
_kernel.cu'
],
extra_compile_args
=
{
[
f
'cuda/
{
ext
}
.cpp'
,
f
'cuda/
{
ext
}
_kernel.cu'
],
extra_compile_args
=
{
'cxx'
:
cxx_extra_compile_args
,
'cxx'
:
cxx_extra_compile_args
,
'nvcc'
:
nvcc_extra_compile_args
,
'nvcc'
:
nvcc_extra_compile_args
,
})
for
ext
in
exts
})
for
ext
in
exts
...
...
torch_scatter/gather.py
View file @
ea94e546
import
torch
import
torch
from
torch_scatter
import
segment_cpu
,
gather_cpu
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
from
torch_scatter
import
gather_cuda
,
segment_cuda
from
torch_scatter
import
gather_cuda
,
segment_cuda
gat
=
lambda
is_cuda
:
gather_cuda
if
is_cuda
else
gather_cpu
# noqa
seg
=
lambda
is_cuda
:
segment_cuda
if
is_cuda
else
segment_cpu
# noqa
class
GatherCOO
(
torch
.
autograd
.
Function
):
class
GatherCOO
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
...
@@ -12,7 +17,7 @@ class GatherCOO(torch.autograd.Function):
...
@@ -12,7 +17,7 @@ class GatherCOO(torch.autograd.Function):
ctx
.
src_size
=
list
(
src
.
size
())
ctx
.
src_size
=
list
(
src
.
size
())
ctx
.
save_for_backward
(
index
)
ctx
.
save_for_backward
(
index
)
return
gat
her
_cuda
.
gather_coo
(
src
,
index
,
out
)
return
gat
(
src
.
is
_cuda
)
.
gather_coo
(
src
,
index
,
out
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_out
):
def
backward
(
ctx
,
grad_out
):
...
@@ -20,7 +25,7 @@ class GatherCOO(torch.autograd.Function):
...
@@ -20,7 +25,7 @@ class GatherCOO(torch.autograd.Function):
grad_src
=
None
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
grad_src
,
_
=
seg
ment
_cuda
.
segment_coo
(
grad_src
,
_
=
seg
(
grad_out
.
is
_cuda
)
.
segment_coo
(
grad_out
,
index
,
grad_out
.
new_zeros
(
src_size
),
'add'
)
grad_out
,
index
,
grad_out
.
new_zeros
(
src_size
),
'add'
)
return
grad_src
,
None
,
None
return
grad_src
,
None
,
None
...
@@ -34,7 +39,7 @@ class GatherCSR(torch.autograd.Function):
...
@@ -34,7 +39,7 @@ class GatherCSR(torch.autograd.Function):
ctx
.
src_size
=
list
(
src
.
size
())
ctx
.
src_size
=
list
(
src
.
size
())
ctx
.
save_for_backward
(
indptr
)
ctx
.
save_for_backward
(
indptr
)
return
gat
her
_cuda
.
gather_csr
(
src
,
indptr
,
out
)
return
gat
(
src
.
is
_cuda
)
.
gather_csr
(
src
,
indptr
,
out
)
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
grad_out
):
def
backward
(
ctx
,
grad_out
):
...
@@ -42,7 +47,7 @@ class GatherCSR(torch.autograd.Function):
...
@@ -42,7 +47,7 @@ class GatherCSR(torch.autograd.Function):
grad_src
=
None
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
grad_src
,
_
=
seg
ment
_cuda
.
segment_csr
(
grad_src
,
_
=
seg
(
grad_out
.
is
_cuda
)
.
segment_csr
(
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
),
'add'
)
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
),
'add'
)
return
grad_src
,
None
,
None
return
grad_src
,
None
,
None
...
...
torch_scatter/segment.py
View file @
ea94e546
import
torch
import
torch
from
torch_scatter
import
segment_cpu
,
gather_cpu
from
torch_scatter.helpers
import
min_value
,
max_value
from
torch_scatter.helpers
import
min_value
,
max_value
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
from
torch_scatter
import
segment_cuda
,
gather_cuda
from
torch_scatter
import
segment_cuda
,
gather_cuda
seg
=
lambda
is_cuda
:
segment_cuda
if
is_cuda
else
segment_cpu
# noqa
gat
=
lambda
is_cuda
:
gather_cuda
if
is_cuda
else
gather_cpu
# noqa
class
SegmentCOO
(
torch
.
autograd
.
Function
):
class
SegmentCOO
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
...
@@ -28,7 +32,7 @@ class SegmentCOO(torch.autograd.Function):
...
@@ -28,7 +32,7 @@ class SegmentCOO(torch.autograd.Function):
out
=
src
.
new_full
(
size
,
fill_value
)
out
=
src
.
new_full
(
size
,
fill_value
)
out
,
arg_out
=
seg
ment
_cuda
.
segment_coo
(
src
,
index
,
out
,
reduce
)
out
,
arg_out
=
seg
(
src
.
is
_cuda
)
.
segment_coo
(
src
,
index
,
out
,
reduce
)
if
fill_value
!=
0
:
if
fill_value
!=
0
:
out
.
masked_fill_
(
out
==
fill_value
,
0
)
out
.
masked_fill_
(
out
==
fill_value
,
0
)
...
@@ -47,13 +51,13 @@ class SegmentCOO(torch.autograd.Function):
...
@@ -47,13 +51,13 @@ class SegmentCOO(torch.autograd.Function):
grad_src
=
None
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
reduce
==
'add'
:
if
ctx
.
reduce
==
'add'
:
grad_src
=
gat
her_cuda
.
gather_coo
(
grad_out
,
index
,
grad_src
=
gat
(
grad_out
)
.
gather_coo
(
grad_out
.
new_empty
(
src_size
))
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
elif
ctx
.
reduce
==
'mean'
:
elif
ctx
.
reduce
==
'mean'
:
grad_src
=
gat
her_cuda
.
gather_coo
(
grad_out
,
index
,
grad_src
=
gat
(
grad_out
)
.
gather_coo
(
grad_out
.
new_empty
(
src_size
))
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
count
=
arg_out
count
=
arg_out
count
=
gat
her
_cuda
.
gather_coo
(
count
=
gat
(
grad_out
.
is
_cuda
)
.
gather_coo
(
count
,
index
,
count
.
new_empty
(
src_size
[:
index
.
dim
()]))
count
,
index
,
count
.
new_empty
(
src_size
[:
index
.
dim
()]))
for
_
in
range
(
grad_out
.
dim
()
-
index
.
dim
()):
for
_
in
range
(
grad_out
.
dim
()
-
index
.
dim
()):
count
=
count
.
unsqueeze
(
-
1
)
count
=
count
.
unsqueeze
(
-
1
)
...
@@ -78,7 +82,7 @@ class SegmentCSR(torch.autograd.Function):
...
@@ -78,7 +82,7 @@ class SegmentCSR(torch.autograd.Function):
ctx
.
reduce
=
reduce
ctx
.
reduce
=
reduce
ctx
.
src_size
=
list
(
src
.
size
())
ctx
.
src_size
=
list
(
src
.
size
())
out
,
arg_out
=
seg
ment
_cuda
.
segment_csr
(
src
,
indptr
,
out
,
reduce
)
out
,
arg_out
=
seg
(
src
.
is
_cuda
)
.
segment_csr
(
src
,
indptr
,
out
,
reduce
)
ctx
.
save_for_backward
(
indptr
,
arg_out
)
ctx
.
save_for_backward
(
indptr
,
arg_out
)
return
out
if
arg_out
is
None
else
(
out
,
arg_out
)
return
out
if
arg_out
is
None
else
(
out
,
arg_out
)
...
@@ -89,15 +93,15 @@ class SegmentCSR(torch.autograd.Function):
...
@@ -89,15 +93,15 @@ class SegmentCSR(torch.autograd.Function):
grad_src
=
None
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
reduce
==
'add'
:
if
ctx
.
reduce
==
'add'
:
grad_src
=
gat
her
_cuda
.
gather_csr
(
grad_out
,
indptr
,
grad_src
=
gat
(
grad_out
.
is
_cuda
)
.
gather_csr
(
grad_out
.
new_empty
(
src_size
))
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
elif
ctx
.
reduce
==
'mean'
:
elif
ctx
.
reduce
==
'mean'
:
grad_src
=
gat
her
_cuda
.
gather_csr
(
grad_out
,
indptr
,
grad_src
=
gat
(
grad_out
.
is
_cuda
)
.
gather_csr
(
grad_out
.
new_empty
(
src_size
))
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
indptr1
=
indptr
.
narrow
(
-
1
,
0
,
indptr
.
size
(
-
1
)
-
1
)
indptr1
=
indptr
.
narrow
(
-
1
,
0
,
indptr
.
size
(
-
1
)
-
1
)
indptr2
=
indptr
.
narrow
(
-
1
,
1
,
indptr
.
size
(
-
1
)
-
1
)
indptr2
=
indptr
.
narrow
(
-
1
,
1
,
indptr
.
size
(
-
1
)
-
1
)
count
=
(
indptr2
-
indptr1
).
to
(
grad_src
.
dtype
)
count
=
(
indptr2
-
indptr1
).
to
(
grad_src
.
dtype
)
count
=
gat
her
_cuda
.
gather_csr
(
count
=
gat
(
grad_out
.
is
_cuda
)
.
gather_csr
(
count
,
indptr
,
count
.
new_empty
(
src_size
[:
indptr
.
dim
()]))
count
,
indptr
,
count
.
new_empty
(
src_size
[:
indptr
.
dim
()]))
for
_
in
range
(
grad_out
.
dim
()
-
indptr
.
dim
()):
for
_
in
range
(
grad_out
.
dim
()
-
indptr
.
dim
()):
count
=
count
.
unsqueeze
(
-
1
)
count
=
count
.
unsqueeze
(
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment