Commit f5c8c0b5 in dgl (unverified)
Authored Nov 06, 2023 by Ramon Zhou; committed by GitHub on Nov 06, 2023
[GraphBolt] Update multi-gpu example for regression benchmark. (#6519)
Parent: 84a01a16

Showing 1 changed file with 17 additions and 3 deletions.

examples/multigpu/graphbolt/node_classification.py (+17, -3) @ f5c8c0b5
@@ -38,6 +38,7 @@ main
 """
 import argparse
 import os
+import time

 import dgl.graphbolt as gb
 import dgl.nn as dglnn
@@ -138,7 +139,9 @@ def create_dataloader(
     # A CopyTo object copying data in the datapipe to a specified device.\
     ############################################################################
     datapipe = datapipe.copy_to(device)
-    dataloader = gb.SingleProcessDataLoader(datapipe)
+    dataloader = gb.MultiProcessDataLoader(
+        datapipe, num_workers=args.num_workers
+    )

     # Return the fully-initialized DataLoader object.
     return dataloader
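Note on the hunk above: the loader construction switches from gb.SingleProcessDataLoader to gb.MultiProcessDataLoader so that data loading can run in worker processes for the regression benchmark. Below is a minimal sketch of that pattern using only the calls visible in this diff; make_dataloader is a hypothetical helper, and the assumption that num_workers=0 falls back to in-process loading mirrors the new flag's default (added further down in this diff) rather than documented behaviour.

    import dgl.graphbolt as gb

    def make_dataloader(datapipe, device, num_workers=0):
        # Copy each mini-batch produced by the datapipe onto the target device.
        datapipe = datapipe.copy_to(device)
        # Assumption: num_workers=0 keeps loading in the calling process
        # (the new flag's default); larger values use worker processes.
        return gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)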
@@ -204,6 +207,8 @@ def train(
     )

     for epoch in range(args.epochs):
+        epoch_start = time.time()
+
         model.train()
         total_loss = torch.tensor(0, dtype=torch.float).to(device)
         ########################################################################
@@ -273,11 +278,14 @@ def train(
         dist.reduce(tensor=acc, dst=0)
         total_loss /= step + 1
         dist.reduce(tensor=total_loss, dst=0)
+        epoch_end = time.time()
+
         if rank == 0:
             print(
                 f"Epoch {epoch:05d} | "
                 f"Average Loss {total_loss.item() / world_size:.4f} | "
-                f"Accuracy {acc.item():.4f} "
+                f"Accuracy {acc.item():.4f} | "
+                f"Time {epoch_end - epoch_start:.4f} "
             )
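The two hunks above add the per-epoch wall-clock measurement the regression benchmark reads from the log: time.time() is sampled before and after each epoch, loss and accuracy are reduced onto rank 0, and only rank 0 prints. Here is a self-contained sketch of the same logging pattern under a single-process "gloo" group with placeholder metric values; it is not the example's actual training loop.

    import os
    import time

    import torch
    import torch.distributed as dist

    def log_epoch(rank, world_size, epoch, total_loss, acc, epoch_start):
        # Sum the metrics onto rank 0; other ranks keep their local copies.
        dist.reduce(tensor=total_loss, dst=0)
        dist.reduce(tensor=acc, dst=0)
        epoch_end = time.time()
        if rank == 0:
            print(
                f"Epoch {epoch:05d} | "
                f"Average Loss {total_loss.item() / world_size:.4f} | "
                f"Accuracy {acc.item():.4f} | "
                f"Time {epoch_end - epoch_start:.4f} "
            )

    if __name__ == "__main__":
        # One-process group just to make the sketch runnable end to end.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
        start = time.time()
        # Placeholder metrics standing in for a real training epoch.
        log_epoch(0, 1, 0, torch.tensor(0.7), torch.tensor(0.9), start)
        dist.destroy_process_group()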
@@ -342,7 +350,7 @@ def run(rank, world_size, args, devices, dataset):
     )
     dist.reduce(tensor=test_acc, dst=0)
     if rank == 0:
-        print(f"Test Accuracy is {test_acc.item():.4f}")
+        print(f"Test Accuracy {test_acc.item():.4f}")


 def parse_args():
@@ -376,6 +384,9 @@ def parse_args():
         help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
         " identical with the number of layers in your model. Default: 15,10,5",
     )
+    parser.add_argument(
+        "--num-workers", type=int, default=0, help="The number of processes."
+    )
     return parser.parse_args()
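The new --num-workers option is a plain argparse integer flag whose value feeds the num_workers argument of the loader change above; the default of 0 preserves the previous single-process behaviour. A standalone sketch of just the argument handling follows (the description string is illustrative, not from the example).

    import argparse

    def parse_args():
        parser = argparse.ArgumentParser(
            description="GraphBolt multi-GPU node classification (sketch)"
        )
        parser.add_argument(
            "--num-workers", type=int, default=0, help="The number of processes."
        )
        return parser.parse_args()

    if __name__ == "__main__":
        args = parse_args()
        # argparse maps "--num-workers" to the attribute name "num_workers".
        print(f"Dataloader workers: {args.num_workers}")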
@@ -393,6 +404,9 @@ if __name__ == "__main__":
     # Load and preprocess dataset.
     dataset = gb.BuiltinDataset("ogbn-products").load()

+    # Thread limiting to avoid resource competition.
+    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // world_size)
+
     mp.set_sharing_strategy("file_system")
     mp.spawn(
         run,
         ...
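The OMP_NUM_THREADS line caps the OpenMP threads available to each spawned trainer: half of the machine's cores are split evenly across the world_size ranks, so the new loading workers and OpenMP-backed kernels do not oversubscribe the CPU. A minimal sketch of that setup with a trivial stand-in for the example's run function; the max(1, ...) guard for small machines is added here and is not part of the diff.

    import os

    import torch.multiprocessing as mp

    def run(rank, world_size):
        # Spawned processes inherit the thread cap set before mp.spawn().
        print(f"rank {rank}/{world_size}: OMP_NUM_THREADS={os.environ['OMP_NUM_THREADS']}")

    if __name__ == "__main__":
        world_size = 2  # stand-in for the number of GPUs / trainer processes
        # Give each rank an equal slice of half the CPU cores (at least one).
        os.environ["OMP_NUM_THREADS"] = str(max(1, mp.cpu_count() // 2 // world_size))
        mp.set_sharing_strategy("file_system")
        mp.spawn(run, args=(world_size,), nprocs=world_size)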