Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f65bd2d6
Unverified
Commit
f65bd2d6
authored
Jan 13, 2023
by
Rhett Ying
Committed by
GitHub
Jan 13, 2023
Browse files
[CI] fix device configure when run on GPU (#5154)
parent
cdfd1e38
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
9 deletions
+14
-9
benchmarks/benchmarks/api/bench_edge_subgraph.py
benchmarks/benchmarks/api/bench_edge_subgraph.py
+4
-4
benchmarks/benchmarks/api/bench_in_subgraph.py
benchmarks/benchmarks/api/bench_in_subgraph.py
+4
-2
benchmarks/benchmarks/api/bench_node_subgraph.py
benchmarks/benchmarks/api/bench_node_subgraph.py
+4
-2
benchmarks/benchmarks/api/bench_sample_neighbors.py
benchmarks/benchmarks/api/bench_sample_neighbors.py
+2
-1
No files found.
benchmarks/benchmarks/api/bench_edge_subgraph.py
View file @
f65bd2d6
...
@@ -9,7 +9,6 @@ import dgl.function as fn
...
@@ -9,7 +9,6 @@ import dgl.function as fn
from
..
import
utils
from
..
import
utils
@
utils
.
skip_if_gpu
()
@
utils
.
benchmark
(
"time"
)
@
utils
.
benchmark
(
"time"
)
@
utils
.
parametrize
(
"graph_name"
,
[
"livejournal"
,
"reddit"
])
@
utils
.
parametrize
(
"graph_name"
,
[
"livejournal"
,
"reddit"
])
@
utils
.
parametrize
(
"format"
,
[
"coo"
])
@
utils
.
parametrize
(
"format"
,
[
"coo"
])
...
@@ -20,15 +19,16 @@ def track_time(graph_name, format, seed_egdes_num):
...
@@ -20,15 +19,16 @@ def track_time(graph_name, format, seed_egdes_num):
graph
=
graph
.
to
(
device
)
graph
=
graph
.
to
(
device
)
seed_edges
=
np
.
random
.
randint
(
0
,
graph
.
num_edges
(),
seed_egdes_num
)
seed_edges
=
np
.
random
.
randint
(
0
,
graph
.
num_edges
(),
seed_egdes_num
)
seed_edges
=
torch
.
from_numpy
(
seed_edges
).
to
(
device
)
# dry run
# dry run
for
i
in
range
(
3
):
for
i
in
range
(
3
):
dgl
.
edge_subgraph
(
graph
,
seed_edges
)
dgl
.
edge_subgraph
(
graph
,
seed_edges
)
# timing
# timing
num_iters
=
50
with
utils
.
Timer
()
as
t
:
with
utils
.
Timer
()
as
t
:
for
i
in
range
(
3
):
for
i
in
range
(
num_iters
):
dgl
.
edge_subgraph
(
graph
,
seed_edges
)
dgl
.
edge_subgraph
(
graph
,
seed_edges
)
return
t
.
elapsed_secs
/
3
return
t
.
elapsed_secs
/
num_iters
benchmarks/benchmarks/api/bench_in_subgraph.py
View file @
f65bd2d6
...
@@ -19,14 +19,16 @@ def track_time(graph_name, format, seed_nodes_num):
...
@@ -19,14 +19,16 @@ def track_time(graph_name, format, seed_nodes_num):
graph
=
graph
.
to
(
device
)
graph
=
graph
.
to
(
device
)
seed_nodes
=
np
.
random
.
randint
(
0
,
graph
.
num_nodes
(),
seed_nodes_num
)
seed_nodes
=
np
.
random
.
randint
(
0
,
graph
.
num_nodes
(),
seed_nodes_num
)
seed_nodes
=
torch
.
from_numpy
(
seed_nodes
).
to
(
device
)
# dry run
# dry run
for
i
in
range
(
3
):
for
i
in
range
(
3
):
dgl
.
in_subgraph
(
graph
,
seed_nodes
)
dgl
.
in_subgraph
(
graph
,
seed_nodes
)
# timing
# timing
num_iters
=
50
with
utils
.
Timer
()
as
t
:
with
utils
.
Timer
()
as
t
:
for
i
in
range
(
3
):
for
i
in
range
(
num_iters
):
dgl
.
in_subgraph
(
graph
,
seed_nodes
)
dgl
.
in_subgraph
(
graph
,
seed_nodes
)
return
t
.
elapsed_secs
/
3
return
t
.
elapsed_secs
/
num_iters
benchmarks/benchmarks/api/bench_node_subgraph.py
View file @
f65bd2d6
...
@@ -19,14 +19,16 @@ def track_time(graph_name, format, seed_nodes_num):
...
@@ -19,14 +19,16 @@ def track_time(graph_name, format, seed_nodes_num):
graph
=
graph
.
to
(
device
)
graph
=
graph
.
to
(
device
)
seed_nodes
=
np
.
random
.
randint
(
0
,
graph
.
num_nodes
(),
seed_nodes_num
)
seed_nodes
=
np
.
random
.
randint
(
0
,
graph
.
num_nodes
(),
seed_nodes_num
)
seed_nodes
=
torch
.
from_numpy
(
seed_nodes
).
to
(
device
)
# dry run
# dry run
for
i
in
range
(
3
):
for
i
in
range
(
3
):
dgl
.
node_subgraph
(
graph
,
seed_nodes
)
dgl
.
node_subgraph
(
graph
,
seed_nodes
)
# timing
# timing
num_iters
=
50
with
utils
.
Timer
()
as
t
:
with
utils
.
Timer
()
as
t
:
for
i
in
range
(
3
):
for
i
in
range
(
num_iters
):
dgl
.
node_subgraph
(
graph
,
seed_nodes
)
dgl
.
node_subgraph
(
graph
,
seed_nodes
)
return
t
.
elapsed_secs
/
3
return
t
.
elapsed_secs
/
num_iters
benchmarks/benchmarks/api/bench_sample_neighbors.py
View file @
f65bd2d6
...
@@ -17,10 +17,11 @@ from .. import utils
...
@@ -17,10 +17,11 @@ from .. import utils
@
utils
.
parametrize
(
"fanout"
,
[
5
,
20
,
40
])
@
utils
.
parametrize
(
"fanout"
,
[
5
,
20
,
40
])
def
track_time
(
graph_name
,
format
,
seed_nodes_num
,
fanout
):
def
track_time
(
graph_name
,
format
,
seed_nodes_num
,
fanout
):
device
=
utils
.
get_bench_device
()
device
=
utils
.
get_bench_device
()
graph
=
utils
.
get_graph
(
graph_name
,
format
)
graph
=
utils
.
get_graph
(
graph_name
,
format
)
.
to
(
device
)
edge_dir
=
"in"
edge_dir
=
"in"
seed_nodes
=
np
.
random
.
randint
(
0
,
graph
.
num_nodes
(),
seed_nodes_num
)
seed_nodes
=
np
.
random
.
randint
(
0
,
graph
.
num_nodes
(),
seed_nodes_num
)
seed_nodes
=
torch
.
from_numpy
(
seed_nodes
).
to
(
device
)
# dry run
# dry run
for
i
in
range
(
3
):
for
i
in
range
(
3
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment