Unverified Commit 62b4bbb4 authored by mszarma's avatar mszarma Committed by GitHub
Browse files

[Fix] Enable mini-batch rgcn for CPU (#2345)

parent 77968e30
......@@ -185,7 +185,7 @@ def evaluate(model, embed_layer, eval_loader, node_feats):
@thread_wrapped_func
def run(proc_id, n_gpus, args, devices, dataset, split, queue=None):
dev_id = devices[proc_id]
dev_id = devices[proc_id] if devices[proc_id] != 'cpu' else -1
g, node_feats, num_of_ntype, num_classes, num_rels, target_idx, \
train_idx, val_idx, test_idx, labels = dataset
if split is not None:
......
......@@ -78,7 +78,7 @@ class RelGraphEmbedLayer(nn.Module):
sparse_emb=False,
embed_name='embed'):
super(RelGraphEmbedLayer, self).__init__()
self.dev_id = dev_id
self.dev_id = th.device(dev_id if dev_id >= 0 else 'cpu')
self.embed_size = embed_size
self.embed_name = embed_name
self.num_nodes = num_nodes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment