Unverified Commit 9497a9be authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

[Bugfix][Examples] Fix graphsage multigpu training example training set size (#3002)



* Make multigpu graphsage use whole dataset

* Specify squeeze dimension

* Remove squeeze dimension
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 5be937a7
......@@ -13,7 +13,7 @@ from torch.nn.parallel import DistributedDataParallel
import tqdm
from model import SAGE
from load_graph import load_reddit, inductive_split
from load_graph import load_reddit, inductive_split, load_ogb
def compute_acc(pred, labels):
"""
......@@ -85,7 +85,6 @@ def run(proc_id, n_gpus, args, devices, data):
train_nid = train_mask.nonzero().squeeze()
val_nid = val_mask.nonzero().squeeze()
test_nid = test_mask.nonzero().squeeze()
train_nid = train_nid[:n_gpus * args.batch_size + 1]
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler(
......@@ -172,6 +171,7 @@ if __name__ == '__main__':
argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=str, default='0',
help="Comma separated list of GPU device IDs.")
argparser.add_argument('--dataset', type=str, default='reddit')
argparser.add_argument('--num-epochs', type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16)
argparser.add_argument('--num-layers', type=int, default=2)
......@@ -195,7 +195,13 @@ if __name__ == '__main__':
devices = list(map(int, args.gpu.split(',')))
n_gpus = len(devices)
if args.dataset == 'reddit':
g, n_classes = load_reddit()
elif args.dataset == 'ogbn-products':
g, n_classes = load_ogb('ogbn-products')
else:
raise Exception('unknown dataset')
# Construct graph
g = dgl.as_heterograph(g)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment