Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
24de7391
You need to sign in or sign up before continuing.
Unverified
Commit
24de7391
authored
Nov 27, 2023
by
Ramon Zhou
Committed by
GitHub
Nov 27, 2023
Browse files
[GraphBolt] Modify multi-gpu example to make use of the persistent_workers (#6603)
parent
c63a926d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
43 deletions
+44
-43
examples/multigpu/graphbolt/node_classification.py
examples/multigpu/graphbolt/node_classification.py
+44
-43
No files found.
examples/multigpu/graphbolt/node_classification.py
View file @
24de7391
...
@@ -148,20 +148,10 @@ def create_dataloader(
...
@@ -148,20 +148,10 @@ def create_dataloader(
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
evaluate
(
rank
,
args
,
model
,
graph
,
features
,
itemset
,
num_classes
,
device
):
def
evaluate
(
rank
,
model
,
dataloader
,
num_classes
,
device
):
model
.
eval
()
model
.
eval
()
y
=
[]
y
=
[]
y_hats
=
[]
y_hats
=
[]
dataloader
=
create_dataloader
(
args
,
graph
,
features
,
itemset
,
drop_last
=
False
,
shuffle
=
False
,
drop_uneven_inputs
=
False
,
device
=
device
,
)
for
step
,
data
in
(
for
step
,
data
in
(
tqdm
.
tqdm
(
enumerate
(
dataloader
))
if
rank
==
0
else
enumerate
(
dataloader
)
tqdm
.
tqdm
(
enumerate
(
dataloader
))
if
rank
==
0
else
enumerate
(
dataloader
)
...
@@ -185,26 +175,13 @@ def train(
...
@@ -185,26 +175,13 @@ def train(
world_size
,
world_size
,
rank
,
rank
,
args
,
args
,
graph
,
train_dataloader
,
features
,
valid_dataloader
,
train_set
,
valid_set
,
num_classes
,
num_classes
,
model
,
model
,
device
,
device
,
):
):
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
optimizer
=
torch
.
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
# Create training data loader.
dataloader
=
create_dataloader
(
args
,
graph
,
features
,
train_set
,
device
,
drop_last
=
False
,
shuffle
=
True
,
drop_uneven_inputs
=
False
,
)
for
epoch
in
range
(
args
.
epochs
):
for
epoch
in
range
(
args
.
epochs
):
epoch_start
=
time
.
time
()
epoch_start
=
time
.
time
()
...
@@ -227,9 +204,9 @@ def train(
...
@@ -227,9 +204,9 @@ def train(
########################################################################
########################################################################
with
Join
([
model
]):
with
Join
([
model
]):
for
step
,
data
in
(
for
step
,
data
in
(
tqdm
.
tqdm
(
enumerate
(
dataloader
))
tqdm
.
tqdm
(
enumerate
(
train_
dataloader
))
if
rank
==
0
if
rank
==
0
else
enumerate
(
dataloader
)
else
enumerate
(
train_
dataloader
)
):
):
# The input features are from the source nodes in the first
# The input features are from the source nodes in the first
# layer's computation graph.
# layer's computation graph.
...
@@ -258,11 +235,8 @@ def train(
...
@@ -258,11 +235,8 @@ def train(
acc
=
(
acc
=
(
evaluate
(
evaluate
(
rank
,
rank
,
args
,
model
,
model
,
graph
,
valid_dataloader
,
features
,
valid_set
,
num_classes
,
num_classes
,
device
,
device
,
)
)
...
@@ -305,6 +279,7 @@ def run(rank, world_size, args, devices, dataset):
...
@@ -305,6 +279,7 @@ def run(rank, world_size, args, devices, dataset):
features
=
dataset
.
feature
features
=
dataset
.
feature
train_set
=
dataset
.
tasks
[
0
].
train_set
train_set
=
dataset
.
tasks
[
0
].
train_set
valid_set
=
dataset
.
tasks
[
0
].
validation_set
valid_set
=
dataset
.
tasks
[
0
].
validation_set
test_set
=
dataset
.
tasks
[
0
].
test_set
args
.
fanout
=
list
(
map
(
int
,
args
.
fanout
.
split
(
","
)))
args
.
fanout
=
list
(
map
(
int
,
args
.
fanout
.
split
(
","
)))
num_classes
=
dataset
.
tasks
[
0
].
metadata
[
"num_classes"
]
num_classes
=
dataset
.
tasks
[
0
].
metadata
[
"num_classes"
]
...
@@ -316,6 +291,38 @@ def run(rank, world_size, args, devices, dataset):
...
@@ -316,6 +291,38 @@ def run(rank, world_size, args, devices, dataset):
model
=
SAGE
(
in_size
,
hidden_size
,
out_size
).
to
(
device
)
model
=
SAGE
(
in_size
,
hidden_size
,
out_size
).
to
(
device
)
model
=
DDP
(
model
)
model
=
DDP
(
model
)
# Create data loaders.
train_dataloader
=
create_dataloader
(
args
,
graph
,
features
,
train_set
,
device
,
drop_last
=
False
,
shuffle
=
True
,
drop_uneven_inputs
=
False
,
)
valid_dataloader
=
create_dataloader
(
args
,
graph
,
features
,
valid_set
,
device
,
drop_last
=
False
,
shuffle
=
False
,
drop_uneven_inputs
=
False
,
)
test_dataloader
=
create_dataloader
(
args
,
graph
,
features
,
test_set
,
device
,
drop_last
=
False
,
shuffle
=
False
,
drop_uneven_inputs
=
False
,
)
# Model training.
# Model training.
if
rank
==
0
:
if
rank
==
0
:
print
(
"Training..."
)
print
(
"Training..."
)
...
@@ -323,10 +330,8 @@ def run(rank, world_size, args, devices, dataset):
...
@@ -323,10 +330,8 @@ def run(rank, world_size, args, devices, dataset):
world_size
,
world_size
,
rank
,
rank
,
args
,
args
,
graph
,
train_dataloader
,
features
,
valid_dataloader
,
train_set
,
valid_set
,
num_classes
,
num_classes
,
model
,
model
,
device
,
device
,
...
@@ -335,17 +340,13 @@ def run(rank, world_size, args, devices, dataset):
...
@@ -335,17 +340,13 @@ def run(rank, world_size, args, devices, dataset):
# Test the model.
# Test the model.
if
rank
==
0
:
if
rank
==
0
:
print
(
"Testing..."
)
print
(
"Testing..."
)
test_set
=
dataset
.
tasks
[
0
].
test_set
test_acc
=
(
test_acc
=
(
evaluate
(
evaluate
(
rank
,
rank
,
args
,
model
,
model
,
graph
,
test_dataloader
,
features
,
num_classes
,
itemset
=
test_set
,
device
,
num_classes
=
num_classes
,
device
=
device
,
)
)
/
world_size
/
world_size
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment