Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
0767c5fc
"vscode:/vscode.git/clone" did not exist on "28f404349d69da1af7b52f18b022bc7971951a41"
Unverified
Commit
0767c5fc
authored
Feb 07, 2022
by
Jinjing Zhou
Committed by
GitHub
Feb 07, 2022
Browse files
Fix dist example padding problem (#3687)
parent
f282ee30
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
2 deletions
+9
-2
examples/pytorch/graphsage/experimental/train_dist.py
examples/pytorch/graphsage/experimental/train_dist.py
+9
-2
No files found.
examples/pytorch/graphsage/experimental/train_dist.py
View file @
0767c5fc
...
...
@@ -184,7 +184,12 @@ def pad_data(nids, device):
def
run
(
args
,
device
,
data
):
# Unpack data
train_nid
,
val_nid
,
test_nid
,
in_feats
,
n_classes
,
g
=
data
shuffle
=
True
if
args
.
pad_data
:
train_nid
=
pad_data
(
train_nid
,
device
)
# Current pipeline doesn't support duplicate node id within the same batch
# Therefore turn off shuffling to avoid potential duplicate node id within the same batch
shuffle
=
False
# Create sampler
sampler
=
NeighborSampler
(
g
,
[
int
(
fanout
)
for
fanout
in
args
.
fan_out
.
split
(
','
)],
dgl
.
distributed
.
sample_neighbors
,
device
)
...
...
@@ -194,7 +199,7 @@ def run(args, device, data):
dataset
=
train_nid
.
numpy
(),
batch_size
=
args
.
batch_size
,
collate_fn
=
sampler
.
sample_blocks
,
shuffle
=
Tru
e
,
shuffle
=
shuffl
e
,
drop_last
=
False
)
# Define model and optimizer
...
...
@@ -341,6 +346,8 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--dropout'
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
'--local_rank'
,
type
=
int
,
help
=
'get rank of the process'
)
parser
.
add_argument
(
'--standalone'
,
action
=
'store_true'
,
help
=
'run in the standalone mode'
)
parser
.
add_argument
(
'--pad-data'
,
default
=
False
,
action
=
'store_true'
,
help
=
'Pad train nid to the same length across machine, to ensure num of batches to be the same.'
)
args
=
parser
.
parse_args
()
print
(
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment