OpenDAS / dgl · Commit cebf3364 (unverified)

fix (#1950)

Authored Aug 06, 2020 by Jinjing Zhou; committed by GitHub on Aug 06, 2020.
Parent: 967ecb80
Showing 2 changed files with 4 additions and 2 deletions:

examples/pytorch/graphsage/train_cv_multi_gpu.py        +2 -1
examples/pytorch/graphsage/train_sampling_multi_gpu.py  +2 -1
examples/pytorch/graphsage/train_cv_multi_gpu.py

@@ -11,6 +11,7 @@ import time
 import argparse
 import tqdm
 import traceback
+import math
 from _thread import start_new_thread
 from functools import wraps
 from dgl.data import RedditDataset
@@ -267,7 +268,7 @@ def run(proc_id, n_gpus, args, devices, data):
     val_mask = th.BoolTensor(val_mask)
     # Split train_nid
-    train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id]
+    train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]
     # Create sampler
     sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])
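The behavioral change here is the chunk size passed to th.split when the training node IDs are sharded across GPUs. Below is a minimal sketch, in plain PyTorch, of why a floor-sized chunk can silently drop training nodes when len(train_nid) is not divisible by n_gpus. The node count of 10 and 4 workers are illustrative assumptions, and the sketch uses true division (/) inside math.ceil so the rounding up is visible; the committed lines keep integer division (//).

import math
import torch as th

train_nid = th.arange(10)   # 10 training node IDs (hypothetical)
n_gpus = 4

# Floor-sized chunks: th.split yields 5 chunks of size 2, so the fifth chunk
# (nodes 8 and 9) is never indexed by any proc_id in 0..3 and is silently dropped.
print([c.tolist() for c in th.split(train_nid, len(train_nid) // n_gpus)])
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

# Ceil-sized chunks: exactly 4 chunks of sizes 3, 3, 3, 1, so every training
# node is assigned to one of the 4 workers.
print([c.tolist() for c in th.split(train_nid, math.ceil(len(train_nid) / n_gpus))])
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]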
examples/pytorch/graphsage/train_sampling_multi_gpu.py

@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn
 import time
+import math
 import argparse
 from dgl.data import RedditDataset
 from torch.nn.parallel import DistributedDataParallel
@@ -145,7 +146,7 @@ def run(proc_id, n_gpus, args, devices, data):
     test_nid = test_mask.nonzero()[:, 0]
     # Split train_nid
-    train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id]
+    train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]
     # Create PyTorch DataLoader for constructing blocks
     sampler = dgl.sampling.MultiLayerNeighborSampler(
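For context, a rough sketch of how each worker in run(proc_id, n_gpus, ...) might consume its shard: slice out the per-process node IDs, then iterate them in minibatches of seed nodes. The helper name shard_for_proc, the tensor sizes, and the use of a plain torch DataLoader in place of DGL's sampler-driven loader are assumptions for illustration only, not the example's actual code.

import math
import torch as th
from torch.utils.data import DataLoader

def shard_for_proc(train_nid, proc_id, n_gpus):
    # Hypothetical helper mirroring the pattern in the diff: each worker
    # takes the proc_id-th ceil-sized chunk of the training node IDs.
    size = math.ceil(len(train_nid) / n_gpus)  # true division for illustration
    return th.split(train_nid, size)[proc_id]

train_nid = th.arange(10)                       # hypothetical training node IDs
seeds = shard_for_proc(train_nid, proc_id=1, n_gpus=4)

# A plain DataLoader stands in for the DGL sampler-driven loader here,
# just to show the shard-then-batch pattern used in these examples.
loader = DataLoader(seeds, batch_size=3, shuffle=True, drop_last=False)
for step, batch_nids in enumerate(loader):
    print(step, batch_nids.tolist())            # seed node IDs for this minibatch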