Unverified Commit cebf3364 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

fix (#1950)

parent 967ecb80
...@@ -11,6 +11,7 @@ import time ...@@ -11,6 +11,7 @@ import time
import argparse import argparse
import tqdm import tqdm
import traceback import traceback
import math
from _thread import start_new_thread from _thread import start_new_thread
from functools import wraps from functools import wraps
from dgl.data import RedditDataset from dgl.data import RedditDataset
...@@ -267,7 +268,7 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -267,7 +268,7 @@ def run(proc_id, n_gpus, args, devices, data):
val_mask = th.BoolTensor(val_mask) val_mask = th.BoolTensor(val_mask)
# Split train_nid # Split train_nid
train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id] train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]
# Create sampler # Create sampler
sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')]) sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])
......
...@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader ...@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
import dgl.function as fn import dgl.function as fn
import dgl.nn.pytorch as dglnn import dgl.nn.pytorch as dglnn
import time import time
import math
import argparse import argparse
from dgl.data import RedditDataset from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
...@@ -145,7 +146,7 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -145,7 +146,7 @@ def run(proc_id, n_gpus, args, devices, data):
test_nid = test_mask.nonzero()[:, 0] test_nid = test_mask.nonzero()[:, 0]
# Split train_nid # Split train_nid
train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id] train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]
# Create PyTorch DataLoader for constructing blocks # Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler( sampler = dgl.sampling.MultiLayerNeighborSampler(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment