Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
34045b9a
Commit
34045b9a
authored
Jan 10, 2020
by
rusty1s
Browse files
linting
parent
4e4b69bd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
16 deletions
+30
-16
benchmark/gather.py
benchmark/gather.py
+11
-6
benchmark/scatter_segment.py
benchmark/scatter_segment.py
+19
-10
No files found.
benchmark/gather.py
View file @
34045b9a
# flake8: noqa
import
time
import
itertools
...
...
@@ -66,10 +64,17 @@ def timing(dataset):
dim_size
=
rowptr
.
size
(
0
)
-
1
avg_row_len
=
row
.
size
(
0
)
/
dim_size
select
=
lambda
x
:
x
.
index_select
(
0
,
row
)
gather
=
lambda
x
:
x
.
gather
(
0
,
row
.
view
(
-
1
,
1
).
expand
(
-
1
,
x
.
size
(
1
)))
gat_coo
=
lambda
x
:
gather_coo
(
x
,
row
)
gat_csr
=
lambda
x
:
gather_csr
(
x
,
rowptr
)
def
select
(
x
):
return
x
.
index_select
(
0
,
row
)
def
gather
(
x
):
return
x
.
gather
(
0
,
row
.
view
(
-
1
,
1
).
expand
(
-
1
,
x
.
size
(
1
)))
def
gat_coo
(
x
):
return
gather_coo
(
x
,
row
)
def
gat_csr
(
x
):
return
gather_csr
(
x
,
rowptr
)
t1
,
t2
,
t3
,
t4
=
[],
[],
[],
[]
for
size
in
sizes
:
...
...
benchmark/scatter_segment.py
View file @
34045b9a
# flake8: noqa
import
time
import
os.path
as
osp
import
itertools
...
...
@@ -120,14 +118,25 @@ def timing(dataset):
dim_size
=
rowptr
.
size
(
0
)
-
1
avg_row_len
=
row
.
size
(
0
)
/
dim_size
sca_row
=
lambda
x
:
getattr
(
torch_scatter
,
f
'scatter_
{
args
.
reduce
}
'
)(
x
,
row
,
dim
=
0
,
dim_size
=
dim_size
)
sca_col
=
lambda
x
:
getattr
(
torch_scatter
,
f
'scatter_
{
args
.
reduce
}
'
)(
x
,
row_perm
,
dim
=
0
,
dim_size
=
dim_size
)
seg_coo
=
lambda
x
:
segment_coo
(
x
,
row
,
reduce
=
args
.
reduce
)
seg_csr
=
lambda
x
:
segment_csr
(
x
,
rowptr
,
reduce
=
args
.
reduce
)
dense1
=
lambda
x
:
getattr
(
torch
,
args
.
dense_reduce
)(
x
,
dim
=-
2
)
dense2
=
lambda
x
:
getattr
(
torch
,
args
.
dense_reduce
)(
x
,
dim
=-
1
)
def
sca_row
(
x
):
op
=
getattr
(
torch_scatter
,
f
'scatter_
{
args
.
reduce
}
'
)
return
op
(
x
,
row
,
dim
=
0
,
dim_size
=
dim_size
)
def
sca_col
(
x
):
op
=
getattr
(
torch_scatter
,
f
'scatter_
{
args
.
reduce
}
'
)
return
op
(
x
,
row_perm
,
dim
=
0
,
dim_size
=
dim_size
)
def
seg_coo
(
x
):
return
segment_coo
(
x
,
row
,
reduce
=
args
.
reduce
)
def
seg_csr
(
x
):
return
segment_csr
(
x
,
rowptr
,
reduce
=
args
.
reduce
)
def
dense1
(
x
):
return
getattr
(
torch
,
args
.
dense_reduce
)(
x
,
dim
=-
2
)
def
dense2
(
x
):
return
getattr
(
torch
,
args
.
dense_reduce
)(
x
,
dim
=-
1
)
t1
,
t2
,
t3
,
t4
,
t5
,
t6
=
[],
[],
[],
[],
[],
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment