Unverified Commit ba110e50 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix VRGCN with new dataset (#1955)

parent 13ea9ddb
......@@ -220,10 +220,8 @@ def run(args, dev_id, data):
# Unpack data
train_mask, val_mask, in_feats, labels, n_classes, g = data
train_nid = th.LongTensor(np.nonzero(train_mask)[0])
val_nid = th.LongTensor(np.nonzero(val_mask)[0])
train_mask = th.BoolTensor(train_mask)
val_mask = th.BoolTensor(val_mask)
train_nid = train_mask.nonzero()[:, 0]
val_nid = val_mask.nonzero()[:, 0]
# Create sampler
sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])
......
......@@ -262,10 +262,8 @@ def run(proc_id, n_gpus, args, devices, data):
# Unpack data
train_mask, val_mask, in_feats, labels, n_classes, g = data
train_nid = th.LongTensor(np.nonzero(train_mask)[0])
val_nid = th.LongTensor(np.nonzero(val_mask)[0])
train_mask = th.BoolTensor(train_mask)
val_mask = th.BoolTensor(val_mask)
train_nid = train_mask.nonzero()[:, 0]
val_nid = val_mask.nonzero()[:, 0]
# Split train_nid
train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment