OpenDAS / torch-scatter · Commits

Commit aae9e125: update benchmark
Authored Jan 09, 2020 by rusty1s
Parent: d7f9176e
Showing 2 changed files with 49 additions and 67 deletions (+49, -67):

  benchmark/gather.py            +40  -60
  benchmark/scatter_segment.py    +9   -7
benchmark/gather.py (view file @ aae9e125)
 # flake8: noqa
 import time
 import itertools
+import argparse

 import torch
 from scipy.io import loadmat

 from torch_scatter import gather_coo, gather_csr
-from scatter_segment import iters, device, sizes
+from scatter_segment import iters, sizes
 from scatter_segment import short_rows, long_rows, download, bold

...
@@ -14,13 +17,13 @@ from scatter_segment import short_rows, long_rows, download, bold
 def correctness(dataset):
     group, name = dataset
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
-    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
-    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
+    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
+    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
     dim_size = rowptr.size(0) - 1

     for size in sizes[1:]:
         try:
-            x = torch.randn((dim_size, size), device=device)
+            x = torch.randn((dim_size, size), device=args.device)
             x = x.squeeze(-1) if size == 1 else x

             out1 = x.index_select(0, row)

...
@@ -34,75 +37,48 @@ def correctness(dataset):
 @torch.no_grad()
-def timing(dataset):
-    group, name = dataset
-    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
-    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
-    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
-    dim_size = rowptr.size(0) - 1
-    avg_row_len = row.size(0) / dim_size
-
-    t1, t2, t3, t4 = [], [], [], []
-    for size in sizes:
-        try:
-            x = torch.randn((dim_size, size), device=device)
-            row_expand = row.view(-1, 1).expand(-1, x.size(-1))
-            x = x.squeeze(-1) if size == 1 else x
-            row_expand = row_expand.squeeze(-1) if size == 1 else row_expand
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = x.index_select(0, row)
-                del out
-                torch.cuda.synchronize()
-                t1.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t1.append(float('inf'))
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = x.gather(0, row_expand)
-                del out
-                torch.cuda.synchronize()
-                t2.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t2.append(float('inf'))
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = gather_coo(x, row)
-                del out
-                torch.cuda.synchronize()
-                t3.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t3.append(float('inf'))
-
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = gather_csr(x, rowptr)
-                del out
-                torch.cuda.synchronize()
-                t4.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t4.append(float('inf'))
+def time_func(func, x):
+    try:
+        torch.cuda.synchronize()
+        t = time.perf_counter()
+        for _ in range(iters):
+            func(x)
+        torch.cuda.synchronize()
+        return time.perf_counter() - t
+    except RuntimeError:
+        torch.cuda.empty_cache()
+        return float('inf')
+
+
+@torch.no_grad()
+def timing(dataset):
+    group, name = dataset
+    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
+    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
+    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
+    dim_size = rowptr.size(0) - 1
+    avg_row_len = row.size(0) / dim_size
+
+    select = lambda x: x.index_select(0, row)
+    gather = lambda x: x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))
+    gat_coo = lambda x: gather_coo(x, row)
+    gat_csr = lambda x: gather_csr(x, rowptr)
+
+    t1, t2, t3, t4 = [], [], [], []
+    for size in sizes:
+        try:
+            x = torch.randn((dim_size, size), device=args.device)
+
+            t1 += [time_func(select, x)]
+            t2 += [time_func(gather, x)]
+            t3 += [time_func(gat_coo, x)]
+            t4 += [time_func(gat_csr, x)]

             del x
         except RuntimeError:
             torch.cuda.empty_cache()
-            for t in (t1, t2, t3):
+            for t in (t1, t2, t3, t4):
                 t.append(float('inf'))

     ts = torch.tensor([t1, t2, t3, t4])

...
@@ -125,8 +101,12 @@ def timing(dataset):
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default='cuda')
+    args = parser.parse_args()
+
     for _ in range(10):  # Warmup.
-        torch.randn(100, 100, device=device).sum()
+        torch.randn(100, 100, device=args.device).sum()

     for dataset in itertools.chain(short_rows, long_rows):
         download(dataset)
         correctness(dataset)
...
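The core of the gather.py change is that four near-identical synchronize/iterate/append blocks collapse into one reusable time_func helper plus one lambda per benchmarked operation. The snippet below is a minimal, self-contained sketch of that timing pattern, not code from the commit: the tensor shapes, the iteration count, and the CPU fallback guards are illustrative assumptions.

# Sketch of the reusable GPU timing helper adopted by the new gather.py.
# Shapes and the example operations (index_select vs. gather) are
# illustrative; the CPU guards are an assumption so the sketch also runs
# without CUDA.
import time
import torch

iters = 20
device = 'cuda' if torch.cuda.is_available() else 'cpu'


@torch.no_grad()
def time_func(func, x):
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for pending kernels before timing
        t = time.perf_counter()
        for _ in range(iters):
            func(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # ensure all timed kernels have finished
        return time.perf_counter() - t
    except RuntimeError:  # typically CUDA out of memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return float('inf')


row = torch.randint(0, 1000, (100000,), device=device)  # gather indices
x = torch.randn(1000, 64, device=device)                 # source features

select = lambda x: x.index_select(0, row)
gather = lambda x: x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))

print('index_select:', time_func(select, x))
print('gather:      ', time_func(gather, x))

The synchronize calls matter because CUDA kernels launch asynchronously; without them the loop would mostly measure launch overhead rather than kernel execution time.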
benchmark/scatter_segment.py (view file @ aae9e125)

...
@@ -3,8 +3,8 @@
 import time
 import os.path as osp
 import itertools
 import argparse

 import wget
 import torch
 from scipy.io import loadmat

...
@@ -13,12 +13,6 @@ import torch_scatter
 from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
 from torch_scatter import segment_coo, segment_csr

-parser = argparse.ArgumentParser()
-parser.add_argument('--reduce', type=str, required=True)
-parser.add_argument('--device', type=str, default='cuda')
-args = parser.parse_args()
-args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
-
 iters = 20
 sizes = [1, 16, 32, 64, 128, 256, 512]

...
@@ -94,6 +88,7 @@ def correctness(dataset):
             torch.cuda.empty_cache()


+@torch.no_grad()
 def time_func(func, x):
     try:
         torch.cuda.synchronize()

...
@@ -184,6 +179,13 @@ def timing(dataset):
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--reduce', type=str, required=True,
+                        choices=['add', 'mean', 'min', 'max'])
+    parser.add_argument('--device', type=str, default='cuda')
+    args = parser.parse_args()
+    args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
+
     for _ in range(10):  # Warmup.
         torch.randn(100, 100, device=args.device).sum()

     for dataset in itertools.chain(short_rows, long_rows):
...
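A practical effect of the scatter_segment.py change is that the parser is now built only when the script is executed directly, so importing scatter_segment (as gather.py does for iters and sizes) no longer triggers argument parsing that previously required a --reduce flag. The sketch below is illustrative, not code from the commit; it feeds a hard-coded argument list to parse_args to stand in for a real command line.

import argparse

# Mirrors the block the commit moves under the __main__ guard of
# benchmark/scatter_segment.py. The explicit argv list stands in for a
# command line such as: --reduce add --device cpu
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True,
                    choices=['add', 'mean', 'min', 'max'])
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args(['--reduce', 'add', '--device', 'cpu'])

# 'add' maps to the dense 'sum' reduction, as in the moved code.
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
print(args.reduce, args.dense_reduce, args.device)  # add sum cpu

The new choices list also means an unsupported value such as --reduce sum is rejected at parse time with a clear error.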