Unverified Commit 55f3e926 authored by Leo's avatar Leo Committed by GitHub
Browse files

[Fix] use DistDataLoader instead of Pytorch’s DataLoader. (#2182)


Co-authored-by: default avatarhzliuchw <hzliuchw@corp.netease.com>
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent 0afacfb8
......@@ -20,7 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from dgl.distributed import DistDataLoader
#from pyinstrument import Profiler
class SAGE(nn.Module):
......@@ -207,13 +207,12 @@ class DistSAGE(SAGE):
sampler = PosNeighborSampler(g, [-1], dgl.distributed.sample_neighbors)
print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size))
# Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader(
dataloader = DistDataLoader(
dataset=nodes,
batch_size=batch_size,
collate_fn=sampler.sample_blocks,
shuffle=False,
drop_last=False,
num_workers=args.num_workers)
drop_last=False)
for blocks in tqdm.tqdm(dataloader):
block = blocks[0].to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment