"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "8708851eca17ad5d61307d7a08b702ad3e77bb4e"
Unverified Commit cebf3364 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

fix (#1950)

parent 967ecb80
......@@ -11,6 +11,7 @@ import time
import argparse
import tqdm
import traceback
import math
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
......@@ -267,7 +268,7 @@ def run(proc_id, n_gpus, args, devices, data):
val_mask = th.BoolTensor(val_mask)
# Split train_nid
train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id]
train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]
# Create sampler
sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])
......
......@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import math
import argparse
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
......@@ -145,7 +146,7 @@ def run(proc_id, n_gpus, args, devices, data):
test_nid = test_mask.nonzero()[:, 0]
# Split train_nid
train_nid = th.split(train_nid, len(train_nid) // n_gpus)[proc_id]
train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.sampling.MultiLayerNeighborSampler(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment