OpenDAS / dgl · Commits · b6f774cd
"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "2595fa98423e70acaeee5dc54f954c0974bb6820"
Unverified commit b6f774cd, authored Jun 08, 2023 by Rhett Ying; committed by GitHub on Jun 08, 2023

[Examples] fix lint (#5838)

Parent: 5ada3fc9
Changes: 1 changed file with 9 additions and 13 deletions

examples/pytorch/graphsage/dist/train_dist.py (+9 −13)
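Since the commit is a pure lint/formatting cleanup, the hunks below reduce to two mechanical patterns: third-party imports regrouped into one sorted block, and short wrapped calls collapsed onto a single line. A minimal sketch of both, assuming a black/isort-style formatter (the commit message does not name the exact lint tooling):

# Sketch only: num_nodes and batch_size are made-up values for illustration.
# Pattern 1: third-party imports sorted alphabetically in a single block,
# so the dgl imports move ahead of numpy/torch instead of trailing them.
import dgl
import dgl.nn.pytorch as dglnn
import numpy as np
import torch as th

# Pattern 2: a short call is collapsed onto one line once it fits the
# line-length limit.
num_nodes, batch_size = 1_000, 32
print(f"|V|={num_nodes}, eval batch size: {batch_size}")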
examples/pytorch/graphsage/dist/train_dist.py @ b6f774cd
@@ -3,14 +3,16 @@ import socket
 import time
 from contextlib import contextmanager

+import dgl
+import dgl.nn.pytorch as dglnn
 import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import tqdm
-import dgl
-import dgl.nn.pytorch as dglnn


 def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
     """
@@ -83,9 +85,7 @@ class DistSAGE(nn.Module):
             "h_last",
             persistent=True,
         )
-        print(
-            f"|V|={g.num_nodes()}, eval batch size: {batch_size}"
-        )
+        print(f"|V|={g.num_nodes()}, eval batch size: {batch_size}")
         sampler = dgl.dataloading.NeighborSampler([-1])
         dataloader = dgl.dataloading.DistNodeDataLoader(
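The hunk above cuts off at the opening of the DistNodeDataLoader call. A hedged sketch of how that constructor is typically completed for full-graph inference; the variable names (nodes, batch_size) and the keyword arguments are assumptions, not part of this diff:

# Sketch under assumptions: g is the DistGraph, nodes holds the node IDs this
# worker is responsible for, and batch_size comes from the caller.
sampler = dgl.dataloading.NeighborSampler([-1])  # -1 = take all neighbors
dataloader = dgl.dataloading.DistNodeDataLoader(
    g,
    nodes,
    sampler,
    batch_size=batch_size,
    shuffle=False,    # deterministic order for evaluation
    drop_last=False,  # keep the final, possibly smaller batch
)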
@@ -249,7 +249,7 @@ def run(args, device, data):
                     acc.item(),
                     np.mean(iter_tput[3:]),
                     gpu_mem_alloc,
-                    np.sum(step_time[-args.log_every:]),
+                    np.sum(step_time[-args.log_every :]),
                 )
             )
             start = time.time()
@@ -284,8 +284,7 @@ def run(args, device, data):
             device,
         )
         print(
-            "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}"
-            .format(
+            "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                 g.rank(), val_acc, test_acc, time.time() - start
             )
         )
@@ -298,10 +297,7 @@ def main(args):
     print(socket.gethostname(), "Initializing DGL process group")
     th.distributed.init_process_group(backend=args.backend)
     print(socket.gethostname(), "Initializing DistGraph")
-    g = dgl.distributed.DistGraph(
-        args.graph_name,
-        part_config=args.part_config,
-    )
+    g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
     print(socket.gethostname(), "rank:", g.rank())
     pb = g.get_partition_book()
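In the broader example, the partition book fetched above is then used to derive the per-worker training split. A hedged sketch of that follow-up step; the node_split call and the "train_mask" field name are assumptions drawn from the usual DGL distributed workflow, not shown in this diff:

# Sketch under assumptions: g and pb come from the lines above, and the
# example stores a boolean "train_mask" on the distributed graph's node data.
train_nid = dgl.distributed.node_split(
    g.ndata["train_mask"], pb, force_even=True
)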
@@ -413,7 +409,7 @@ if __name__ == "__main__":
         default=False,
         action="store_true",
         help="Pad train nid to the same length across machine, to ensure num "
-             "of batches to be the same.",
+        "of batches to be the same.",
     )
     parser.add_argument(
         "--net_type",