Unverified Commit 0767c5fc authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix dist example padding problem (#3687)

parent f282ee30
...@@ -184,7 +184,12 @@ def pad_data(nids, device): ...@@ -184,7 +184,12 @@ def pad_data(nids, device):
def run(args, device, data): def run(args, device, data):
# Unpack data # Unpack data
train_nid, val_nid, test_nid, in_feats, n_classes, g = data train_nid, val_nid, test_nid, in_feats, n_classes, g = data
train_nid = pad_data(train_nid, device) shuffle = True
if args.pad_data:
train_nid = pad_data(train_nid, device)
# Current pipeline doesn't support duplicate node id within the same batch
# Therefore turn off shuffling to avoid potential duplicate node id within the same batch
shuffle = False
# Create sampler # Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
dgl.distributed.sample_neighbors, device) dgl.distributed.sample_neighbors, device)
...@@ -194,7 +199,7 @@ def run(args, device, data): ...@@ -194,7 +199,7 @@ def run(args, device, data):
dataset=train_nid.numpy(), dataset=train_nid.numpy(),
batch_size=args.batch_size, batch_size=args.batch_size,
collate_fn=sampler.sample_blocks, collate_fn=sampler.sample_blocks,
shuffle=True, shuffle=shuffle,
drop_last=False) drop_last=False)
# Define model and optimizer # Define model and optimizer
...@@ -341,6 +346,8 @@ if __name__ == '__main__': ...@@ -341,6 +346,8 @@ if __name__ == '__main__':
parser.add_argument('--dropout', type=float, default=0.5) parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--local_rank', type=int, help='get rank of the process') parser.add_argument('--local_rank', type=int, help='get rank of the process')
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode') parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
parser.add_argument('--pad-data', default=False, action='store_true',
help='Pad train nid to the same length across machine, to ensure num of batches to be the same.')
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment