Unverified Commit ba110e50 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Bug] Fix VRGCN with new dataset (#1955)

parent 13ea9ddb
...@@ -220,10 +220,8 @@ def run(args, dev_id, data): ...@@ -220,10 +220,8 @@ def run(args, dev_id, data):
# Unpack data # Unpack data
train_mask, val_mask, in_feats, labels, n_classes, g = data train_mask, val_mask, in_feats, labels, n_classes, g = data
train_nid = th.LongTensor(np.nonzero(train_mask)[0]) train_nid = train_mask.nonzero()[:, 0]
val_nid = th.LongTensor(np.nonzero(val_mask)[0]) val_nid = val_mask.nonzero()[:, 0]
train_mask = th.BoolTensor(train_mask)
val_mask = th.BoolTensor(val_mask)
# Create sampler # Create sampler
sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')]) sampler = NeighborSampler(g, [int(_) for _ in args.fan_out.split(',')])
......
...@@ -262,10 +262,8 @@ def run(proc_id, n_gpus, args, devices, data): ...@@ -262,10 +262,8 @@ def run(proc_id, n_gpus, args, devices, data):
# Unpack data # Unpack data
train_mask, val_mask, in_feats, labels, n_classes, g = data train_mask, val_mask, in_feats, labels, n_classes, g = data
train_nid = th.LongTensor(np.nonzero(train_mask)[0]) train_nid = train_mask.nonzero()[:, 0]
val_nid = th.LongTensor(np.nonzero(val_mask)[0]) val_nid = val_mask.nonzero()[:, 0]
train_mask = th.BoolTensor(train_mask)
val_mask = th.BoolTensor(val_mask)
# Split train_nid # Split train_nid
train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id] train_nid = th.split(train_nid, math.ceil(len(train_nid) // n_gpus))[proc_id]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment