OpenDAS / dgl · Commit e3752754 (unverified)
Authored Nov 29, 2023 by Ramon Zhou; committed by GitHub, Nov 29, 2023.
[Misc] Align dgl multi-gpu example to graphbolt (#6605)
Parent: 096bbf96
Showing 1 changed file with 31 additions and 16 deletions:
examples/multigpu/node_classification_sage.py (+31, -16)
examples/multigpu/node_classification_sage.py (view file @ e3752754)
@@ -123,12 +123,13 @@ class SAGE(nn.Module):
         return y
 
 
-def evaluate(model, g, num_classes, dataloader):
+def evaluate(device, model, g, num_classes, dataloader):
     model.eval()
     ys = []
     y_hats = []
     for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
         with torch.no_grad():
+            blocks = [block.to(device) for block in blocks]
             x = blocks[0].srcdata["feat"]
             ys.append(blocks[-1].dstdata["label"])
             y_hats.append(model(blocks, x))
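The tail of `evaluate` (the accuracy computation after the loop) lies outside the hunk. A minimal sketch of what it presumably does, assuming the `torchmetrics.functional` alias `MF` used by similar DGL examples; the `_accuracy` helper name is ours, not the file's:

```python
import torch
import torchmetrics.functional as MF  # assumed import, as in similar DGL examples


def _accuracy(ys, y_hats, num_classes):
    # Hypothetical helper mirroring the collapsed tail of evaluate():
    # concatenate per-batch predictions/labels, score multiclass accuracy.
    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )
```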
@@ -145,6 +146,8 @@ def layerwise_infer(
 ):
     model.eval()
     with torch.no_grad():
+        if not use_uva:
+            g = g.to(device)
         pred = model.module.inference(g, device, batch_size, use_uva)
         pred = pred[nid]
         labels = g.ndata["label"][nid].to(pred.device)
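Why the new guard helps: with UVA enabled the graph can stay in pinned host memory and GPU kernels read it through unified virtual addressing, so copying `g` to the GPU is only needed on the non-UVA path. A commented sketch of the changed call site, using the example's own names:

```python
# Sketch of the changed call site in layerwise_infer.
if not use_uva:
    # Without UVA, inference kernels run on `device`, so the whole graph
    # must live there; with UVA it stays in (pinned) host memory.
    g = g.to(device)
pred = model.module.inference(g, device, batch_size, use_uva)
```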
@@ -159,17 +162,20 @@ def train(
     proc_id,
     nprocs,
     device,
+    args,
     g,
     num_classes,
     train_idx,
     val_idx,
     model,
     use_uva,
-    num_epochs,
 ):
     # Instantiate a neighbor sampler
     sampler = NeighborSampler(
-        [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
+        [10, 10, 10],
+        prefetch_node_feats=["feat"],
+        prefetch_labels=["label"],
+        fused=(args.mode != "benchmark"),
     )
     train_dataloader = DataLoader(
         g,
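The new `fused` flag toggles DGL's fused neighbor-sampling kernel; the example disables it in the new `benchmark` mode, presumably so timings stay comparable with the GraphBolt variant of this example. A self-contained sketch of the construction (`mode` stands in for `args.mode`):

```python
from dgl.dataloading import NeighborSampler


def make_sampler(mode):
    # Fused sampling stays on except in the new benchmark mode.
    return NeighborSampler(
        [10, 10, 10],                  # fanout per GNN layer
        prefetch_node_feats=["feat"],
        prefetch_labels=["label"],
        fused=(mode != "benchmark"),
    )
```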
@@ -179,7 +185,7 @@ def train(
         batch_size=1024,
         shuffle=True,
         drop_last=False,
-        num_workers=0,
+        num_workers=args.num_workers,
         use_ddp=True,  # To split the set for each process
         use_uva=use_uva,
     )
@@ -191,12 +197,12 @@ def train(
         batch_size=1024,
         shuffle=True,
         drop_last=False,
-        num_workers=0,
+        num_workers=args.num_workers,
         use_ddp=True,
         use_uva=use_uva,
     )
     opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
-    for epoch in range(num_epochs):
+    for epoch in range(args.num_epochs):
         t0 = time.time()
         model.train()
         total_loss = 0
@@ -224,7 +230,8 @@ def train(
         # for more information.
         #####################################################################
         acc = (
-            evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs
+            evaluate(device, model, g, num_classes, val_dataloader).to(device)
+            / nprocs
         )
         t1 = time.time()
         # Reduce `acc` tensors to process 0.
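The trailing context notes that `acc` is reduced to process 0. A hedged sketch of that aggregation: each rank contributes its local accuracy already divided by `nprocs`, so a sum-reduce on rank 0 yields the mean across GPUs.

```python
import torch.distributed as dist

# Each rank holds acc = local_accuracy / nprocs (see the hunk above);
# the default ReduceOp.SUM therefore leaves the mean accuracy on rank 0.
dist.reduce(tensor=acc, dst=0)
if proc_id == 0:
    print(f"Validation accuracy (mean over {nprocs} ranks): {acc.item():.4f}")
```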
@@ -236,7 +243,7 @@ def train(
     )
 
 
-def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
+def run(proc_id, nprocs, devices, g, data, args):
     # Find corresponding device for current process.
     device = devices[proc_id]
     torch.cuda.set_device(device)
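For context, each spawned worker pins itself to one GPU before joining the process group; the `init_process_group` call sits in the collapsed region (only its `rank=proc_id` line is visible below), so the backend and rendezvous address here are stand-ins:

```python
import torch
import torch.distributed as dist


def setup(proc_id, nprocs, devices):
    device = devices[proc_id]
    torch.cuda.set_device(device)     # bind this process to its own GPU
    dist.init_process_group(
        backend="nccl",               # assumption: NCCL for multi-GPU DDP
        init_method="tcp://127.0.0.1:12345",  # hypothetical rendezvous address
        world_size=nprocs,
        rank=proc_id,                 # matches the rank=proc_id context below
    )
    return device
```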
@@ -263,9 +270,10 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
         rank=proc_id,
     )
     num_classes, train_idx, val_idx, test_idx = data
-    train_idx = train_idx.to(device)
-    val_idx = val_idx.to(device)
-    g = g.to(device if mode == "puregpu" else "cpu")
+    if args.mode != "benchmark":
+        train_idx = train_idx.to(device)
+        val_idx = val_idx.to(device)
+    g = g.to(device if args.mode == "puregpu" else "cpu")
     in_size = g.ndata["feat"].shape[1]
     model = SAGE(in_size, 256, num_classes).to(device)
     model = DistributedDataParallel(
@@ -273,20 +281,21 @@ def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
     )
     # Training.
-    use_uva = mode == "mixed"
+    use_uva = args.mode == "mixed"
     if proc_id == 0:
         print("Training...")
     train(
         proc_id,
         nprocs,
         device,
+        args,
         g,
         num_classes,
         train_idx,
         val_idx,
         model,
         use_uva,
-        num_epochs,
     )
     # Testing.
@@ -303,7 +312,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--mode",
         default="mixed",
-        choices=["mixed", "puregpu"],
+        choices=["mixed", "puregpu", "benchmark"],
         help="Training mode. 'mixed' for CPU-GPU mixed training, "
         "'puregpu' for pure-GPU training.",
     )
@@ -317,7 +326,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--num_epochs",
         type=int,
-        default=20,
+        default=10,
         help="Number of epochs for train.",
     )
     parser.add_argument(
@@ -332,6 +341,12 @@ if __name__ == "__main__":
         default="dataset",
         help="Root directory of dataset.",
     )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=0,
+        help="Number of workers",
+    )
     args = parser.parse_args()
     devices = list(map(int, args.gpu.split(",")))
     nprocs = len(devices)
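The new `--num_workers` flag feeds the `num_workers=args.num_workers` arguments in both dataloader hunks above. One caveat worth hedging: DGL's `DataLoader` is believed to reject `num_workers > 0` when `use_uva=True`, so nonzero values mainly apply to the CPU-sampling paths. A sketch of how the parsed flag reaches a dataloader, with names as in the example (`device=device` is assumed from the collapsed context):

```python
from dgl.dataloading import DataLoader

val_dataloader = DataLoader(
    g,
    val_idx,
    sampler,
    device=device,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=args.num_workers,  # new flag; keep 0 when use_uva is True
    use_ddp=True,
    use_uva=use_uva,
)
```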
@@ -364,6 +379,6 @@ if __name__ == "__main__":
     # To use DDP with n GPUs, spawn up n processes.
     mp.spawn(
         run,
-        args=(nprocs, devices, g, data, args.mode, args.num_epochs),
+        args=(nprocs, devices, g, data, args),
         nprocs=nprocs,
     )
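`torch.multiprocessing.spawn` prepends the worker index to the argument tuple, so after this change each process executes `run(proc_id, nprocs, devices, g, data, args)`; passing the whole parsed namespace replaces the old `(args.mode, args.num_epochs)` pair. A minimal sketch:

```python
import torch.multiprocessing as mp

# mp.spawn calls run(i, *args) for i in range(nprocs), so proc_id arrives
# implicitly as the first parameter of run.
mp.spawn(
    run,
    args=(nprocs, devices, g, data, args),
    nprocs=nprocs,
)
```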