Unverified commit c56e27a8, authored by Chang Liu and committed by GitHub

[Bugfix] Accessing data from the indexes stored in same device (#4242)

* First update to fix two examples

* Update to fix RGCN/graphsage example and dataloader

* Update
parent 14e4e1b0
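
All of the changes below apply the same fix: a tensor used for indexing has to live on the same device as the tensor it indexes (or be moved there) before the indexing happens. A minimal standalone sketch of the pattern, not taken from this commit, using names that mirror the graphsage example (nfeat, input_nodes, device):

    import torch

    # Node features kept on the CPU, e.g. to save GPU memory.
    nfeat = torch.randn(1000, 16)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mini-batch node ids from the sampler may already sit on the GPU.
    input_nodes = torch.arange(64, device=device)

    # Indexing a CPU tensor with CUDA indices typically raises a device
    # mismatch error, so the indices are moved to the tensor's device first.
    input_nodes = input_nodes.to(nfeat.device)
    batch_inputs = nfeat[input_nodes].to(device)  # gather on CPU, then move the batch
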
@@ -133,10 +133,9 @@ def run(proc_id, n_gpus, args, devices, data):
# blocks.
tic_step = time.time()
for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
input_nodes = input_nodes.to(nfeat.device)
batch_inputs = nfeat[input_nodes].to(device)
pos_graph = pos_graph.to(device)
neg_graph = neg_graph.to(device)
blocks = [block.int().to(device) for block in blocks]
blocks = [block.int() for block in blocks]
d_step = time.time()
# Compute loss and prediction
@@ -104,7 +104,6 @@ def train(rank, world_size, graph, num_classes, split_idx):
# move ids to GPU
train_idx = train_idx.to('cuda')
valid_idx = valid_idx.to('cuda')
test_idx = test_idx.to('cuda')
# For training, each process/GPU will get a subset of the
    # train_idx/valid_idx, and generate mini-batches independently. This allows
@@ -72,7 +72,7 @@ def evaluate(model, graph, dataloader):
def layerwise_infer(device, graph, nid, model, batch_size):
model.eval()
with torch.no_grad():
pred = model.inference(graph, device, batch_size)
pred = model.inference(graph, device, batch_size).to(device)
pred = pred[nid]
label = graph.ndata['label'][nid]
return MF.accuracy(pred, label)
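
For the layerwise_infer change above: the predictions returned by layer-wise inference may live on a different device than the node ids and labels used to evaluate them, so the result is moved onto one device before pred[nid]. A hedged standalone sketch with made-up shapes (not from this commit):

    import torch

    pred = torch.randn(1000, 47)                    # inference output, e.g. on the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    nid = torch.arange(100, device=device)          # evaluation node ids already on the GPU

    pred = pred.to(device)                          # mirrors the added .to(device)
    batch_pred = pred[nid]                          # indices and data now share a device
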
@@ -121,6 +121,7 @@ def main(args):
ns_mode=True)
labels = labels.to(device)
model = model.to(device)
inv_target = inv_target.to(device)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)
@@ -60,6 +60,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None):
ns_mode=True)
labels = labels.to(device)
model = model.to(device)
inv_target = inv_target.to(device)
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)
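
Both RGCN hunks add the same line: inv_target is moved to the training device next to labels and the model. The sketch below assumes inv_target is an index/mapping tensor that is later used together with GPU-resident tensors (that usage is an assumption here, not shown in the diff):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logits = torch.randn(500, 8, device=device)     # model output on the GPU
    inv_target = torch.randint(0, 500, (500,))      # mapping built on the CPU

    inv_target = inv_target.to(device)              # mirrors the added line
    seeds = torch.arange(32, device=device)
    batch_logits = logits[inv_target[seeds]]        # every tensor on one device
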
@@ -228,7 +228,7 @@ class DDPTensorizedDataset(torch.utils.data.IterableDataset):
np.random.shuffle(self._indices[:self.num_indices].numpy())
else:
self._indices[:self.num_indices] = self._indices[
torch.randperm(self.num_indices, device=self._device)]
torch.randperm(self.num_indices, device=self._indices.device)]
if not self.drop_last:
# pad extra
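
The DDPTensorizedDataset hunk changes where the shuffle permutation is created: torch.randperm now runs on the device the index tensor actually lives on (self._indices.device) instead of a separately tracked target device. A small standalone sketch of the same in-place shuffle (assumed sizes and placement; in the real class the indices' device depends on the dataloader setup):

    import torch

    indices = torch.arange(10_000)                  # e.g. a CPU-resident index tensor
    perm = torch.randperm(indices.numel(), device=indices.device)
    indices[:] = indices[perm]                      # shuffle in place without a device mismatch
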