Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
1006514c
"docs/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "72b8e9f7971e397b38999efb6955bbe64d156de2"
Commit
1006514c
authored
Feb 03, 2020
by
rusty1s
Browse files
test with pytorch scatter
parent
f056396b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
20 deletions
+36
-20
benchmark/scatter_segment.py
benchmark/scatter_segment.py
+36
-20
No files found.
benchmark/scatter_segment.py
View file @
1006514c
...
@@ -115,10 +115,20 @@ def timing(dataset):
...
@@ -115,10 +115,20 @@ def timing(dataset):
dim_size
=
rowptr
.
size
(
0
)
-
1
dim_size
=
rowptr
.
size
(
0
)
-
1
avg_row_len
=
row
.
size
(
0
)
/
dim_size
avg_row_len
=
row
.
size
(
0
)
/
dim_size
def
sca_row
(
x
):
def
sca1_row
(
x
):
out
=
x
.
new_zeros
(
dim_size
,
*
x
.
size
()[
1
:])
row_tmp
=
row
.
view
(
-
1
,
1
).
expand_as
(
x
)
if
x
.
dim
()
>
1
else
row
return
out
.
scatter_add_
(
0
,
row_tmp
,
x
)
def
sca1_col
(
x
):
out
=
x
.
new_zeros
(
dim_size
,
*
x
.
size
()[
1
:])
row2_tmp
=
row2
.
view
(
-
1
,
1
).
expand_as
(
x
)
if
x
.
dim
()
>
1
else
row2
return
out
.
scatter_add_
(
0
,
row2_tmp
,
x
)
def
sca2_row
(
x
):
return
scatter
(
x
,
row
,
dim
=
0
,
dim_size
=
dim_size
,
reduce
=
args
.
reduce
)
return
scatter
(
x
,
row
,
dim
=
0
,
dim_size
=
dim_size
,
reduce
=
args
.
reduce
)
def
sca_col
(
x
):
def
sca
2
_col
(
x
):
return
scatter
(
x
,
row2
,
dim
=
0
,
dim_size
=
dim_size
,
reduce
=
args
.
reduce
)
return
scatter
(
x
,
row2
,
dim
=
0
,
dim_size
=
dim_size
,
reduce
=
args
.
reduce
)
def
seg_coo
(
x
):
def
seg_coo
(
x
):
...
@@ -133,17 +143,19 @@ def timing(dataset):
...
@@ -133,17 +143,19 @@ def timing(dataset):
def
dense2
(
x
):
def
dense2
(
x
):
return
getattr
(
torch
,
args
.
reduce
)(
x
,
dim
=-
1
)
return
getattr
(
torch
,
args
.
reduce
)(
x
,
dim
=-
1
)
t1
,
t2
,
t3
,
t4
,
t5
,
t6
=
[],
[],
[],
[],
[],
[]
t1
,
t2
,
t3
,
t4
,
t5
,
t6
,
t7
,
t8
=
[],
[],
[],
[],
[],
[],
[],
[]
for
size
in
sizes
:
for
size
in
sizes
:
try
:
try
:
x
=
torch
.
randn
((
row
.
size
(
0
),
size
),
device
=
args
.
device
)
x
=
torch
.
randn
((
row
.
size
(
0
),
size
),
device
=
args
.
device
)
x
=
x
.
squeeze
(
-
1
)
if
size
==
1
else
x
x
=
x
.
squeeze
(
-
1
)
if
size
==
1
else
x
t1
+=
[
time_func
(
sca_row
,
x
)]
t1
+=
[
time_func
(
sca1_row
,
x
)]
t2
+=
[
time_func
(
sca_col
,
x
)]
t2
+=
[
time_func
(
sca1_col
,
x
)]
t3
+=
[
time_func
(
seg_coo
,
x
)]
t3
+=
[
time_func
(
sca2_row
,
x
)]
t4
+=
[
time_func
(
seg_csr
,
x
)]
t4
+=
[
time_func
(
sca2_col
,
x
)]
t5
+=
[
time_func
(
seg_coo
,
x
)]
t6
+=
[
time_func
(
seg_csr
,
x
)]
del
x
del
x
...
@@ -151,16 +163,16 @@ def timing(dataset):
...
@@ -151,16 +163,16 @@ def timing(dataset):
if
'out of memory'
not
in
str
(
e
):
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
for
t
in
(
t1
,
t2
,
t3
,
t4
):
for
t
in
(
t1
,
t2
,
t3
,
t4
,
t5
,
t6
):
t
.
append
(
float
(
'inf'
))
t
.
append
(
float
(
'inf'
))
try
:
try
:
x
=
torch
.
randn
((
dim_size
,
int
(
avg_row_len
+
1
),
size
),
x
=
torch
.
randn
((
dim_size
,
int
(
avg_row_len
+
1
),
size
),
device
=
args
.
device
)
device
=
args
.
device
)
t
5
+=
[
time_func
(
dense1
,
x
)]
t
7
+=
[
time_func
(
dense1
,
x
)]
x
=
x
.
view
(
dim_size
,
size
,
int
(
avg_row_len
+
1
))
x
=
x
.
view
(
dim_size
,
size
,
int
(
avg_row_len
+
1
))
t
6
+=
[
time_func
(
dense2
,
x
)]
t
8
+=
[
time_func
(
dense2
,
x
)]
del
x
del
x
...
@@ -168,10 +180,10 @@ def timing(dataset):
...
@@ -168,10 +180,10 @@ def timing(dataset):
if
'out of memory'
not
in
str
(
e
):
if
'out of memory'
not
in
str
(
e
):
raise
RuntimeError
(
e
)
raise
RuntimeError
(
e
)
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
for
t
in
(
t
5
,
t
6
):
for
t
in
(
t
7
,
t
8
):
t
.
append
(
float
(
'inf'
))
t
.
append
(
float
(
'inf'
))
ts
=
torch
.
tensor
([
t1
,
t2
,
t3
,
t4
,
t5
,
t6
])
ts
=
torch
.
tensor
([
t1
,
t2
,
t3
,
t4
,
t5
,
t6
,
t7
,
t8
])
winner
=
torch
.
zeros_like
(
ts
,
dtype
=
torch
.
bool
)
winner
=
torch
.
zeros_like
(
ts
,
dtype
=
torch
.
bool
)
winner
[
ts
.
argmin
(
dim
=
0
),
torch
.
arange
(
len
(
sizes
))]
=
1
winner
[
ts
.
argmin
(
dim
=
0
),
torch
.
arange
(
len
(
sizes
))]
=
1
winner
=
winner
.
tolist
()
winner
=
winner
.
tolist
()
...
@@ -179,29 +191,33 @@ def timing(dataset):
...
@@ -179,29 +191,33 @@ def timing(dataset):
name
=
f
'
{
group
}
/
{
name
}
'
name
=
f
'
{
group
}
/
{
name
}
'
print
(
f
'
{
bold
(
name
)
}
(avg row length:
{
avg_row_len
:.
2
f
}
):'
)
print
(
f
'
{
bold
(
name
)
}
(avg row length:
{
avg_row_len
:.
2
f
}
):'
)
print
(
'
\t
'
.
join
([
' '
]
+
[
f
'
{
size
:
>
5
}
'
for
size
in
sizes
]))
print
(
'
\t
'
.
join
([
' '
]
+
[
f
'
{
size
:
>
5
}
'
for
size
in
sizes
]))
print
(
'
\t
'
.
join
([
bold
(
'SCA_R
OW
'
)]
+
print
(
'
\t
'
.
join
([
bold
(
'SCA
1
_R
'
)]
+
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t1
,
winner
[
0
])]))
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t1
,
winner
[
0
])]))
print
(
'
\t
'
.
join
([
bold
(
'SCA_C
OL
'
)]
+
print
(
'
\t
'
.
join
([
bold
(
'SCA
1
_C
'
)]
+
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t2
,
winner
[
1
])]))
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t2
,
winner
[
1
])]))
print
(
'
\t
'
.
join
([
bold
(
'S
EG_COO
'
)]
+
print
(
'
\t
'
.
join
([
bold
(
'S
CA2_R
'
)]
+
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t3
,
winner
[
2
])]))
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t3
,
winner
[
2
])]))
print
(
'
\t
'
.
join
([
bold
(
'S
EG_CSR
'
)]
+
print
(
'
\t
'
.
join
([
bold
(
'S
CA2_C
'
)]
+
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t4
,
winner
[
3
])]))
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t4
,
winner
[
3
])]))
print
(
'
\t
'
.
join
([
bold
(
'
DENSE1
'
)]
+
print
(
'
\t
'
.
join
([
bold
(
'
SEG_COO
'
)]
+
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t5
,
winner
[
4
])]))
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t5
,
winner
[
4
])]))
print
(
'
\t
'
.
join
([
bold
(
'
DENSE2
'
)]
+
print
(
'
\t
'
.
join
([
bold
(
'
SEG_CSR
'
)]
+
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t6
,
winner
[
5
])]))
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t6
,
winner
[
5
])]))
print
(
'
\t
'
.
join
([
bold
(
'DENSE1 '
)]
+
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t7
,
winner
[
6
])]))
print
(
'
\t
'
.
join
([
bold
(
'DENSE2 '
)]
+
[
bold
(
f
'
{
t
:.
5
f
}
'
,
f
)
for
t
,
f
in
zip
(
t8
,
winner
[
7
])]))
print
()
print
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--reduce'
,
type
=
str
,
required
=
True
,
parser
.
add_argument
(
'--reduce'
,
type
=
str
,
required
=
True
,
choices
=
[
'sum'
,
'add'
,
'mean'
,
'min'
,
'max'
])
choices
=
[
'sum'
,
'mean'
,
'min'
,
'max'
])
parser
.
add_argument
(
'--with_backward'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--with_backward'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cuda'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cuda'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
iters
=
1
if
args
.
device
==
'cpu'
else
2
0
iters
=
1
if
args
.
device
==
'cpu'
else
5
0
sizes
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
sizes
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
sizes
=
sizes
[:
3
]
if
args
.
device
==
'cpu'
else
sizes
sizes
=
sizes
[:
3
]
if
args
.
device
==
'cpu'
else
sizes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment