Unverified commit 55f3e926 authored by Leo, committed by GitHub

[Fix] use DistDataLoader instead of PyTorch's DataLoader. (#2182)


Co-authored-by: hzliuchw <hzliuchw@corp.netease.com>
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 0afacfb8
@@ -20,7 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.multiprocessing as mp
-from torch.utils.data import DataLoader
+from dgl.distributed import DistDataLoader
 #from pyinstrument import Profiler

 class SAGE(nn.Module):
@@ -207,13 +207,12 @@ class DistSAGE(SAGE):
     sampler = PosNeighborSampler(g, [-1], dgl.distributed.sample_neighbors)
     print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size))
     # Create PyTorch DataLoader for constructing blocks
-    dataloader = DataLoader(
+    dataloader = DistDataLoader(
         dataset=nodes,
         batch_size=batch_size,
         collate_fn=sampler.sample_blocks,
         shuffle=False,
-        drop_last=False,
-        num_workers=args.num_workers)
+        drop_last=False)
     for blocks in tqdm.tqdm(dataloader):
         block = blocks[0].to(device)
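For context on the change above, below is a minimal sketch of how a dgl.distributed.DistDataLoader is typically set up end to end in a distributed training script. It is not part of this commit: the file names 'ip_config.txt' and 'part.json', the graph name 'graph_name', the batch size, and the one-hop sample_blocks helper are illustrative placeholders, and the sketch assumes the graph has already been partitioned and DGL servers are running.

    # Sketch only; 'ip_config.txt', 'graph_name', 'part.json', sample_blocks,
    # and the batch size are placeholder names, not taken from this commit.
    import dgl
    import torch as th
    from dgl.distributed import DistDataLoader

    dgl.distributed.initialize('ip_config.txt')   # connect this trainer to the DGL servers
    g = dgl.distributed.DistGraph('graph_name', part_config='part.json')

    def sample_blocks(seeds):
        # One-hop full-neighbor sampling (fanout -1), mirroring the sampler in the diff.
        seeds = th.LongTensor(seeds)
        frontier = dgl.distributed.sample_neighbors(g, seeds, -1)
        block = dgl.to_block(frontier, seeds)
        return [block]

    nodes = th.arange(g.number_of_nodes())
    dataloader = DistDataLoader(
        dataset=nodes.numpy(),        # a plain sequence of node IDs to batch over
        batch_size=1024,
        collate_fn=sample_blocks,
        shuffle=False,
        drop_last=False)              # no num_workers argument here

    for blocks in dataloader:
        block = blocks[0]
        # forward pass over `block` goes here

The reason the commit drops num_workers=args.num_workers is that DistDataLoader, unlike torch.utils.data.DataLoader, does not spawn torch worker processes; parallel sampling is handled by the sampler processes that DGL's distributed runtime launches via dgl.distributed.initialize.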