Unverified Commit ba61a566 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[Bug] fix inference for labor example (#6148)

parent 3fb81fca
......@@ -47,13 +47,14 @@ class SAGE(nn.Module):
h = self.dropout(h)
return h
def inference(self, g, device, batch_size, num_workers, buffer_device=None):
def inference(self, g, device, batch_size, use_uva, num_workers):
# The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching.
g.ndata["h"] = g.ndata["features"]
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
1, prefetch_node_feats=["h"]
)
pin_memory = g.device != device and use_uva
dataloader = dgl.dataloading.DataLoader(
g,
th.arange(g.num_nodes(), dtype=g.idtype, device=g.device),
......@@ -62,26 +63,30 @@ class SAGE(nn.Module):
batch_size=batch_size,
shuffle=False,
drop_last=False,
use_uva=use_uva,
num_workers=num_workers,
persistent_workers=(num_workers > 0),
)
if buffer_device is None:
buffer_device = device
self.train(False)
self.eval()
for l, layer in enumerate(self.layers):
y = th.zeros(
y = th.empty(
g.num_nodes(),
self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
device=buffer_device,
dtype=g.ndata["h"].dtype,
device=g.device,
pin_memory=pin_memory,
)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
x = blocks[0].srcdata["h"]
h = layer(blocks[0], x)
if l != len(self.layers) - 1:
if l < len(self.layers) - 1:
h = self.activation(h)
h = self.dropout(h)
y[output_nodes] = h.to(buffer_device)
# by design, our output nodes are contiguous
y[output_nodes[0].item() : output_nodes[-1].item() + 1] = h.to(
y.device
)
g.ndata["h"] = y
return y
......@@ -269,18 +269,9 @@ class DataModule(LightningDataModule):
dataloader_device = device
self.g = g
if cast_to_int:
self.train_nid, self.val_nid, self.test_nid = (
train_nid.int(),
val_nid.int(),
test_nid.int(),
)
else:
self.train_nid, self.val_nid, self.test_nid = (
train_nid,
val_nid,
test_nid,
)
self.train_nid = train_nid.to(g.idtype)
self.val_nid = val_nid.to(g.idtype)
self.test_nid = test_nid.to(g.idtype)
self.sampler = sampler
self.device = dataloader_device
self.use_uva = use_uva
......@@ -385,7 +376,7 @@ if __name__ == "__main__":
argparser.add_argument(
"--gpu",
type=int,
default=0,
default=0 if th.cuda.is_available() else -1,
help="GPU device ID. Use -1 for CPU training",
)
argparser.add_argument("--dataset", type=str, default="reddit")
......@@ -493,7 +484,7 @@ if __name__ == "__main__":
logger = TensorBoardLogger(args.logdir, name=subdir)
trainer = Trainer(
accelerator="gpu" if args.gpu != -1 else "cpu",
devices=[args.gpu],
devices=[args.gpu] if args.gpu != -1 else "auto",
max_epochs=args.num_epochs,
max_steps=args.num_steps,
min_steps=args.min_steps,
......@@ -521,15 +512,16 @@ if __name__ == "__main__":
graph,
f"cuda:{args.gpu}" if args.gpu != -1 else "cpu",
4096,
args.use_uva,
args.num_workers,
graph.device,
)
for nid, split_name in zip(
[datamodule.train_nid, datamodule.val_nid, datamodule.test_nid],
["Train", "Validation", "Test"],
):
nid = nid.to(pred.device).long()
pred_nid = pred[nid]
label = graph.ndata["labels"][nid]
f1score = model.f1score_class().to(pred.device)
acc = f1score(pred_nid, label)
print(f"{split_name} accuracy:", acc.item())
print(f"{split_name} accuracy: {acc.item()}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment