"vscode:/vscode.git/clone" did not exist on "28f404349d69da1af7b52f18b022bc7971951a41"
Unverified Commit 0767c5fc authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Fix dist example padding problem (#3687)

parent f282ee30
......@@ -184,7 +184,12 @@ def pad_data(nids, device):
def run(args, device, data):
# Unpack data
train_nid, val_nid, test_nid, in_feats, n_classes, g = data
shuffle = True
if args.pad_data:
train_nid = pad_data(train_nid, device)
# Current pipeline doesn't support duplicate node id within the same batch
# Therefore turn off shuffling to avoid potential duplicate node id within the same batch
shuffle = False
# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
dgl.distributed.sample_neighbors, device)
......@@ -194,7 +199,7 @@ def run(args, device, data):
dataset=train_nid.numpy(),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
shuffle=shuffle,
drop_last=False)
# Define model and optimizer
......@@ -341,6 +346,8 @@ if __name__ == '__main__':
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--local_rank', type=int, help='get rank of the process')
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
parser.add_argument('--pad-data', default=False, action='store_true',
help='Pad train nid to the same length across machine, to ensure num of batches to be the same.')
args = parser.parse_args()
print(args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment