Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1bc77c5e
Unverified
Commit
1bc77c5e
authored
Sep 15, 2022
by
Rhett Ying
Committed by
GitHub
Sep 15, 2022
Browse files
[examples] Reduce memory consumption (#4558)
* [examples] Reduce memory consumption * refine help message * refine
parent
9a00cf19
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
examples/pytorch/graphsage/dist/train_dist.py
examples/pytorch/graphsage/dist/train_dist.py
+12
-6
No files found.
examples/pytorch/graphsage/dist/train_dist.py
View file @
1bc77c5e
...
@@ -185,8 +185,6 @@ def run(args, device, data):
...
@@ -185,8 +185,6 @@ def run(args, device, data):
loss_fcn
=
loss_fcn
.
to
(
device
)
loss_fcn
=
loss_fcn
.
to
(
device
)
optimizer
=
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
optimizer
=
optim
.
Adam
(
model
.
parameters
(),
lr
=
args
.
lr
)
train_size
=
th
.
sum
(
g
.
ndata
[
'train_mask'
][
0
:
g
.
number_of_nodes
()])
# Training loop
# Training loop
iter_tput
=
[]
iter_tput
=
[]
epoch
=
0
epoch
=
0
...
@@ -284,13 +282,17 @@ def main(args):
...
@@ -284,13 +282,17 @@ def main(args):
g
.
rank
(),
len
(
train_nid
),
len
(
np
.
intersect1d
(
train_nid
.
numpy
(),
local_nid
)),
g
.
rank
(),
len
(
train_nid
),
len
(
np
.
intersect1d
(
train_nid
.
numpy
(),
local_nid
)),
len
(
val_nid
),
len
(
np
.
intersect1d
(
val_nid
.
numpy
(),
local_nid
)),
len
(
val_nid
),
len
(
np
.
intersect1d
(
val_nid
.
numpy
(),
local_nid
)),
len
(
test_nid
),
len
(
np
.
intersect1d
(
test_nid
.
numpy
(),
local_nid
))))
len
(
test_nid
),
len
(
np
.
intersect1d
(
test_nid
.
numpy
(),
local_nid
))))
del
local_nid
if
args
.
num_gpus
==
-
1
:
if
args
.
num_gpus
==
-
1
:
device
=
th
.
device
(
'cpu'
)
device
=
th
.
device
(
'cpu'
)
else
:
else
:
dev_id
=
g
.
rank
()
%
args
.
num_gpus
dev_id
=
g
.
rank
()
%
args
.
num_gpus
device
=
th
.
device
(
'cuda:'
+
str
(
dev_id
))
device
=
th
.
device
(
'cuda:'
+
str
(
dev_id
))
n_classes
=
args
.
n_classes
if
n_classes
==
-
1
:
labels
=
g
.
ndata
[
'labels'
][
np
.
arange
(
g
.
number_of_nodes
())]
labels
=
g
.
ndata
[
'labels'
][
np
.
arange
(
g
.
number_of_nodes
())]
n_classes
=
len
(
th
.
unique
(
labels
[
th
.
logical_not
(
th
.
isnan
(
labels
))]))
n_classes
=
len
(
th
.
unique
(
labels
[
th
.
logical_not
(
th
.
isnan
(
labels
))]))
del
labels
print
(
'#labels:'
,
n_classes
)
print
(
'#labels:'
,
n_classes
)
# Pack data
# Pack data
...
@@ -307,8 +309,12 @@ if __name__ == '__main__':
...
@@ -307,8 +309,12 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--ip_config'
,
type
=
str
,
help
=
'The file for IP configuration'
)
parser
.
add_argument
(
'--ip_config'
,
type
=
str
,
help
=
'The file for IP configuration'
)
parser
.
add_argument
(
'--part_config'
,
type
=
str
,
help
=
'The path to the partition config file'
)
parser
.
add_argument
(
'--part_config'
,
type
=
str
,
help
=
'The path to the partition config file'
)
parser
.
add_argument
(
'--num_clients'
,
type
=
int
,
help
=
'The number of clients'
)
parser
.
add_argument
(
'--num_clients'
,
type
=
int
,
help
=
'The number of clients'
)
parser
.
add_argument
(
'--n_classes'
,
type
=
int
,
help
=
'the number of classes'
)
parser
.
add_argument
(
'--n_classes'
,
type
=
int
,
default
=-
1
,
parser
.
add_argument
(
'--backend'
,
type
=
str
,
default
=
'gloo'
,
help
=
'pytorch distributed backend'
)
help
=
'The number of classes. If not specified, this'
' value will be calculated via scaning all the labels'
' in the dataset which probably causes memory burst.'
)
parser
.
add_argument
(
'--backend'
,
type
=
str
,
default
=
'gloo'
,
help
=
'pytorch distributed backend'
)
parser
.
add_argument
(
'--num_gpus'
,
type
=
int
,
default
=-
1
,
parser
.
add_argument
(
'--num_gpus'
,
type
=
int
,
default
=-
1
,
help
=
"the number of GPU device. Use -1 for CPU training"
)
help
=
"the number of GPU device. Use -1 for CPU training"
)
parser
.
add_argument
(
'--num_epochs'
,
type
=
int
,
default
=
20
)
parser
.
add_argument
(
'--num_epochs'
,
type
=
int
,
default
=
20
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment