Unverified Commit a1f74982 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Fix Various Examples Related to TorchMetrics (#5521)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent 97286f98
...@@ -10,7 +10,7 @@ Dependencies ...@@ -10,7 +10,7 @@ Dependencies
- Python 3.7+(for string formatting features) - Python 3.7+(for string formatting features)
- PyTorch 1.9.0+ - PyTorch 1.9.0+
- sklearn - sklearn
- TorchMetrics - TorchMetrics 0.11.4
## Run Experiments ## Run Experiments
......
...@@ -72,7 +72,12 @@ for _ in range(10): ...@@ -72,7 +72,12 @@ for _ in range(10):
loss.backward() loss.backward()
opt.step() opt.step()
if it % 20 == 0: if it % 20 == 0:
acc = MF.accuracy(y_hat[m], y[m]) acc = MF.accuracy(
y_hat[m],
y[m],
task="multiclass",
num_classes=dataset.num_classes,
)
mem = torch.cuda.max_memory_allocated() / 1000000 mem = torch.cuda.max_memory_allocated() / 1000000
print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB") print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB")
tt = time.time() tt = time.time()
...@@ -97,8 +102,18 @@ for _ in range(10): ...@@ -97,8 +102,18 @@ for _ in range(10):
val_labels = torch.cat(val_labels, 0) val_labels = torch.cat(val_labels, 0)
test_preds = torch.cat(test_preds, 0) test_preds = torch.cat(test_preds, 0)
test_labels = torch.cat(test_labels, 0) test_labels = torch.cat(test_labels, 0)
val_acc = MF.accuracy(val_preds, val_labels) val_acc = MF.accuracy(
test_acc = MF.accuracy(test_preds, test_labels) val_preds,
val_labels,
task="multiclass",
num_classes=dataset.num_classes,
)
test_acc = MF.accuracy(
test_preds,
test_labels,
task="multiclass",
num_classes=dataset.num_classes,
)
print("Validation acc:", val_acc.item(), "Test acc:", test_acc.item()) print("Validation acc:", val_acc.item(), "Test acc:", test_acc.item())
print(np.mean(durations[4:]), np.std(durations[4:])) print(np.mean(durations[4:]), np.std(durations[4:]))
...@@ -10,7 +10,7 @@ Requirements ...@@ -10,7 +10,7 @@ Requirements
------------ ------------
```bash ```bash
pip install requests torchmetrics pip install requests torchmetrics==0.11.4 ogb
``` ```
How to run How to run
...@@ -45,8 +45,7 @@ Test Accuracy: 0.7632 ...@@ -45,8 +45,7 @@ Test Accuracy: 0.7632
### PyTorch Lightning for node classification ### PyTorch Lightning for node classification
Train w/ mini-batch sampling for node classification with PyTorch Lightning on OGB-products. Train w/ mini-batch sampling for node classification with PyTorch Lightning on OGB-products. It requires PyTorch Lightning 2.0.1. It works with both single GPU and multiple GPUs:
Works with both single GPU and multiple GPUs:
```bash ```bash
python3 lightning/node_classification.py python3 lightning/node_classification.py
......
...@@ -27,8 +27,8 @@ class SAGE(LightningModule): ...@@ -27,8 +27,8 @@ class SAGE(LightningModule):
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
self.n_hidden = n_hidden self.n_hidden = n_hidden
self.n_classes = n_classes self.n_classes = n_classes
self.train_acc = Accuracy() self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
self.val_acc = Accuracy() self.val_acc = Accuracy(task="multiclass", num_classes=n_classes)
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
...@@ -180,9 +180,11 @@ if __name__ == "__main__": ...@@ -180,9 +180,11 @@ if __name__ == "__main__":
# Train # Train
checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1) checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
# Use this for single GPU # Use this for single GPU
# trainer = Trainer(gpus=[0], max_epochs=10, callbacks=[checkpoint_callback]) # trainer = Trainer(accelerator="gpu", devices=[0], max_epochs=10,
# callbacks=[checkpoint_callback])
trainer = Trainer( trainer = Trainer(
gpus=[0, 1, 2, 3], accelerator="gpu",
devices=[0, 1, 2, 3],
max_epochs=10, max_epochs=10,
callbacks=[checkpoint_callback], callbacks=[checkpoint_callback],
strategy="ddp_spawn", strategy="ddp_spawn",
...@@ -203,5 +205,7 @@ if __name__ == "__main__": ...@@ -203,5 +205,7 @@ if __name__ == "__main__":
pred = model.inference(graph, "cuda", 4096, 12, graph.device) pred = model.inference(graph, "cuda", 4096, 12, graph.device)
pred = pred[test_idx] pred = pred[test_idx]
label = graph.ndata["label"][test_idx] label = graph.ndata["label"][test_idx]
acc = MF.accuracy(pred, label) acc = MF.accuracy(
pred, label, task="multiclass", num_classes=datamodule.n_classes
)
print("Test accuracy:", acc) print("Test accuracy:", acc)
...@@ -5,7 +5,6 @@ import dgl.nn as dglnn ...@@ -5,7 +5,6 @@ import dgl.nn as dglnn
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm import tqdm
from dgl.dataloading import ( from dgl.dataloading import (
as_edge_prediction_sampler, as_edge_prediction_sampler,
......
...@@ -74,7 +74,7 @@ class SAGE(nn.Module): ...@@ -74,7 +74,7 @@ class SAGE(nn.Module):
return y return y
def evaluate(model, graph, dataloader): def evaluate(model, graph, dataloader, num_classes):
model.eval() model.eval()
ys = [] ys = []
y_hats = [] y_hats = []
...@@ -83,10 +83,15 @@ def evaluate(model, graph, dataloader): ...@@ -83,10 +83,15 @@ def evaluate(model, graph, dataloader):
x = blocks[0].srcdata["feat"] x = blocks[0].srcdata["feat"]
ys.append(blocks[-1].dstdata["label"]) ys.append(blocks[-1].dstdata["label"])
y_hats.append(model(blocks, x)) y_hats.append(model(blocks, x))
return MF.accuracy(torch.cat(y_hats), torch.cat(ys)) return MF.accuracy(
torch.cat(y_hats),
torch.cat(ys),
task="multiclass",
num_classes=num_classes,
)
def layerwise_infer(device, graph, nid, model, batch_size): def layerwise_infer(device, graph, nid, model, num_classes, batch_size):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
pred = model.inference( pred = model.inference(
...@@ -94,10 +99,12 @@ def layerwise_infer(device, graph, nid, model, batch_size): ...@@ -94,10 +99,12 @@ def layerwise_infer(device, graph, nid, model, batch_size):
) # pred in buffer_device ) # pred in buffer_device
pred = pred[nid] pred = pred[nid]
label = graph.ndata["label"][nid].to(pred.device) label = graph.ndata["label"][nid].to(pred.device)
return MF.accuracy(pred, label) return MF.accuracy(
pred, label, task="multiclass", num_classes=num_classes
)
def train(args, device, g, dataset, model): def train(args, device, g, dataset, model, num_classes):
# create sampler & dataloader # create sampler & dataloader
train_idx = dataset.train_idx.to(device) train_idx = dataset.train_idx.to(device)
val_idx = dataset.val_idx.to(device) val_idx = dataset.val_idx.to(device)
...@@ -147,7 +154,7 @@ def train(args, device, g, dataset, model): ...@@ -147,7 +154,7 @@ def train(args, device, g, dataset, model):
loss.backward() loss.backward()
opt.step() opt.step()
total_loss += loss.item() total_loss += loss.item()
acc = evaluate(model, g, val_dataloader) acc = evaluate(model, g, val_dataloader, num_classes)
print( print(
"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format( "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
epoch, total_loss / (it + 1), acc.item() epoch, total_loss / (it + 1), acc.item()
...@@ -174,6 +181,7 @@ if __name__ == "__main__": ...@@ -174,6 +181,7 @@ if __name__ == "__main__":
dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products")) dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
g = dataset[0] g = dataset[0]
g = g.to("cuda" if args.mode == "puregpu" else "cpu") g = g.to("cuda" if args.mode == "puregpu" else "cpu")
num_classes = dataset.num_classes
device = torch.device("cpu" if args.mode == "cpu" else "cuda") device = torch.device("cpu" if args.mode == "cpu" else "cuda")
# create GraphSAGE model # create GraphSAGE model
...@@ -183,9 +191,11 @@ if __name__ == "__main__": ...@@ -183,9 +191,11 @@ if __name__ == "__main__":
# model training # model training
print("Training...") print("Training...")
train(args, device, g, dataset, model) train(args, device, g, dataset, model, num_classes)
# test the model # test the model
print("Testing...") print("Testing...")
acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096) acc = layerwise_infer(
device, g, dataset.test_idx, model, num_classes, batch_size=4096
)
print("Test Accuracy {:.4f}".format(acc.item())) print("Test Accuracy {:.4f}".format(acc.item()))
...@@ -5,7 +5,7 @@ Requirements ...@@ -5,7 +5,7 @@ Requirements
------------ ------------
```bash ```bash
pip install torchmetrics pip install torchmetrics==0.11.4
``` ```
How to run How to run
......
...@@ -86,7 +86,7 @@ class SAGE(nn.Module): ...@@ -86,7 +86,7 @@ class SAGE(nn.Module):
return y return y
def evaluate(model, g, dataloader): def evaluate(model, g, num_classes, dataloader):
model.eval() model.eval()
ys = [] ys = []
y_hats = [] y_hats = []
...@@ -95,11 +95,16 @@ def evaluate(model, g, dataloader): ...@@ -95,11 +95,16 @@ def evaluate(model, g, dataloader):
x = blocks[0].srcdata["feat"] x = blocks[0].srcdata["feat"]
ys.append(blocks[-1].dstdata["label"]) ys.append(blocks[-1].dstdata["label"])
y_hats.append(model(blocks, x)) y_hats.append(model(blocks, x))
return MF.accuracy(torch.cat(y_hats), torch.cat(ys)) return MF.accuracy(
torch.cat(y_hats),
torch.cat(ys),
task="multiclass",
num_classes=num_classes,
)
def layerwise_infer( def layerwise_infer(
proc_id, device, g, nid, model, use_uva, batch_size=2**16 proc_id, device, g, num_classes, nid, model, use_uva, batch_size=2**16
): ):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
...@@ -107,11 +112,15 @@ def layerwise_infer( ...@@ -107,11 +112,15 @@ def layerwise_infer(
pred = pred[nid] pred = pred[nid]
labels = g.ndata["label"][nid].to(pred.device) labels = g.ndata["label"][nid].to(pred.device)
if proc_id == 0: if proc_id == 0:
acc = MF.accuracy(pred, labels) acc = MF.accuracy(
pred, labels, task="multiclass", num_classes=num_classes
)
print("Test Accuracy {:.4f}".format(acc.item())) print("Test Accuracy {:.4f}".format(acc.item()))
def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva): def train(
proc_id, nprocs, device, g, num_classes, train_idx, val_idx, model, use_uva
):
sampler = NeighborSampler( sampler = NeighborSampler(
[10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"] [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
) )
...@@ -154,7 +163,9 @@ def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva): ...@@ -154,7 +163,9 @@ def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva):
loss.backward() loss.backward()
opt.step() opt.step()
total_loss += loss total_loss += loss
acc = evaluate(model, g, val_dataloader).to(device) / nprocs acc = (
evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs
)
dist.reduce(acc, 0) dist.reduce(acc, 0)
if proc_id == 0: if proc_id == 0:
print( print(
...@@ -175,20 +186,30 @@ def run(proc_id, nprocs, devices, g, data, mode): ...@@ -175,20 +186,30 @@ def run(proc_id, nprocs, devices, g, data, mode):
world_size=nprocs, world_size=nprocs,
rank=proc_id, rank=proc_id,
) )
out_size, train_idx, val_idx, test_idx = data num_classes, train_idx, val_idx, test_idx = data
train_idx = train_idx.to(device) train_idx = train_idx.to(device)
val_idx = val_idx.to(device) val_idx = val_idx.to(device)
g = g.to(device if mode == "puregpu" else "cpu") g = g.to(device if mode == "puregpu" else "cpu")
# create GraphSAGE model (distributed) # create GraphSAGE model (distributed)
in_size = g.ndata["feat"].shape[1] in_size = g.ndata["feat"].shape[1]
model = SAGE(in_size, 256, out_size).to(device) model = SAGE(in_size, 256, num_classes).to(device)
model = DistributedDataParallel( model = DistributedDataParallel(
model, device_ids=[device], output_device=device model, device_ids=[device], output_device=device
) )
# training + testing # training + testing
use_uva = mode == "mixed" use_uva = mode == "mixed"
train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva) train(
layerwise_infer(proc_id, device, g, test_idx, model, use_uva) proc_id,
nprocs,
device,
g,
num_classes,
train_idx,
val_idx,
model,
use_uva,
)
layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)
# cleanup process group # cleanup process group
dist.destroy_process_group() dist.destroy_process_group()
......
...@@ -4,11 +4,11 @@ This is an adaptation of RGCN where graph convolution is replaced with graph att ...@@ -4,11 +4,11 @@ This is an adaptation of RGCN where graph convolution is replaced with graph att
Dependencies Dependencies
------------ ------------
- torchmetrics - torchmetrics 0.11.4
Install as follows: Install as follows:
```bash ```bash
pip install torchmetrics pip install torchmetrics==0.11.4
``` ```
How to Run How to Run
......
...@@ -57,7 +57,7 @@ class HeteroGAT(nn.Module): ...@@ -57,7 +57,7 @@ class HeteroGAT(nn.Module):
return self.linear(h["paper"]) return self.linear(h["paper"])
def evaluate(model, dataloader, desc): def evaluate(num_classes, model, dataloader, desc):
preds = [] preds = []
labels = [] labels = []
with torch.no_grad(): with torch.no_grad():
...@@ -71,11 +71,13 @@ def evaluate(model, dataloader, desc): ...@@ -71,11 +71,13 @@ def evaluate(model, dataloader, desc):
labels.append(y.cpu()) labels.append(y.cpu())
preds = torch.cat(preds, 0) preds = torch.cat(preds, 0)
labels = torch.cat(labels, 0) labels = torch.cat(labels, 0)
acc = MF.accuracy(preds, labels) acc = MF.accuracy(
preds, labels, task="multiclass", num_classes=num_classes
)
return acc return acc
def train(train_loader, val_loader, test_loader, model): def train(train_loader, val_loader, test_loader, num_classes, model):
# loss function and optimizer # loss function and optimizer
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
...@@ -96,8 +98,8 @@ def train(train_loader, val_loader, test_loader, model): ...@@ -96,8 +98,8 @@ def train(train_loader, val_loader, test_loader, model):
opt.step() opt.step()
total_loss += loss.item() total_loss += loss.item()
model.eval() model.eval()
val_acc = evaluate(model, val_dataloader, "Val. ") val_acc = evaluate(num_classes, model, val_dataloader, "Val. ")
test_acc = evaluate(model, test_dataloader, "Test ") test_acc = evaluate(num_classes, model, test_dataloader, "Test ")
print( print(
f"Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}" f"Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}"
) )
...@@ -138,8 +140,8 @@ if __name__ == "__main__": ...@@ -138,8 +140,8 @@ if __name__ == "__main__":
# create RGAT model # create RGAT model
in_size = graph.ndata["feat"]["paper"].shape[1] in_size = graph.ndata["feat"]["paper"].shape[1]
out_size = dataset.num_classes num_classes = dataset.num_classes
model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device) model = HeteroGAT(graph.etypes, in_size, 256, num_classes).to(device)
# dataloader + model training + testing # dataloader + model training + testing
train_sampler = NeighborSampler( train_sampler = NeighborSampler(
...@@ -186,4 +188,4 @@ if __name__ == "__main__": ...@@ -186,4 +188,4 @@ if __name__ == "__main__":
use_uva=torch.cuda.is_available(), use_uva=torch.cuda.is_available(),
) )
train(train_dataloader, val_dataloader, test_dataloader, model) train(train_dataloader, val_dataloader, test_dataloader, num_classes, model)
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
### Dependencies ### Dependencies
- rdflib - rdflib
- torchmetrics - torchmetrics 0.11.4
Install as follows: Install as follows:
```bash ```bash
pip install rdflib pip install rdflib
pip install torchmetrics pip install torchmetrics==0.11.4
``` ```
How to run How to run
......
...@@ -38,16 +38,21 @@ class RGCN(nn.Module): ...@@ -38,16 +38,21 @@ class RGCN(nn.Module):
return h return h
def evaluate(g, target_idx, labels, test_mask, model): def evaluate(g, target_idx, labels, num_classes, test_mask, model):
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze() test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
logits = model(g) logits = model(g)
logits = logits[target_idx] logits = logits[target_idx]
return accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item() return accuracy(
logits[test_idx].argmax(dim=1),
labels[test_idx],
task="multiclass",
num_classes=num_classes,
).item()
def train(g, target_idx, labels, train_mask, model): def train(g, target_idx, labels, num_classes, train_mask, model):
# define train idx, loss function and optimizer # define train idx, loss function and optimizer
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze() train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
...@@ -62,7 +67,10 @@ def train(g, target_idx, labels, train_mask, model): ...@@ -62,7 +67,10 @@ def train(g, target_idx, labels, train_mask, model):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
acc = accuracy( acc = accuracy(
logits[train_idx].argmax(dim=1), labels[train_idx] logits[train_idx].argmax(dim=1),
labels[train_idx],
task="multiclass",
num_classes=num_classes,
).item() ).item()
print( print(
"Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f} ".format( "Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f} ".format(
...@@ -112,9 +120,9 @@ if __name__ == "__main__": ...@@ -112,9 +120,9 @@ if __name__ == "__main__":
target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id] target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
# create RGCN model # create RGCN model
in_size = g.num_nodes() # featureless with one-hot encoding in_size = g.num_nodes() # featureless with one-hot encoding
out_size = data.num_classes num_classes = data.num_classes
model = RGCN(in_size, 16, out_size, num_rels).to(device) model = RGCN(in_size, 16, num_classes, num_rels).to(device)
train(g, target_idx, labels, train_mask, model) train(g, target_idx, labels, num_classes, train_mask, model)
acc = evaluate(g, target_idx, labels, test_mask, model) acc = evaluate(g, target_idx, labels, num_classes, test_mask, model)
print("Test accuracy {:.4f}".format(acc)) print("Test accuracy {:.4f}".format(acc))
...@@ -41,7 +41,7 @@ class RGCN(nn.Module): ...@@ -41,7 +41,7 @@ class RGCN(nn.Module):
return h return h
def evaluate(model, label, dataloader, inv_target): def evaluate(model, labels, num_classes, dataloader, inv_target):
model.eval() model.eval()
eval_logits = [] eval_logits = []
eval_seeds = [] eval_seeds = []
...@@ -55,10 +55,15 @@ def evaluate(model, label, dataloader, inv_target): ...@@ -55,10 +55,15 @@ def evaluate(model, label, dataloader, inv_target):
eval_seeds.append(output_nodes.cpu().detach()) eval_seeds.append(output_nodes.cpu().detach())
eval_logits = torch.cat(eval_logits) eval_logits = torch.cat(eval_logits)
eval_seeds = torch.cat(eval_seeds) eval_seeds = torch.cat(eval_seeds)
return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item() return accuracy(
eval_logits.argmax(dim=1),
labels[eval_seeds].cpu(),
task="multiclass",
num_classes=num_classes,
).item()
def train(device, g, target_idx, labels, train_mask, model): def train(device, g, target_idx, labels, train_mask, num_classes, model):
# define train idx, loss function and optimizer # define train idx, loss function and optimizer
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze() train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
...@@ -95,7 +100,7 @@ def train(device, g, target_idx, labels, train_mask, model): ...@@ -95,7 +100,7 @@ def train(device, g, target_idx, labels, train_mask, model):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
total_loss += loss.item() total_loss += loss.item()
acc = evaluate(model, labels, val_loader, inv_target) acc = evaluate(model, labels, num_classes, val_loader, inv_target)
print( print(
"Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format( "Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format(
epoch, total_loss / (it + 1), acc epoch, total_loss / (it + 1), acc
...@@ -150,10 +155,10 @@ if __name__ == "__main__": ...@@ -150,10 +155,10 @@ if __name__ == "__main__":
# create RGCN model # create RGCN model
in_size = g.num_nodes() # featureless with one-hot encoding in_size = g.num_nodes() # featureless with one-hot encoding
out_size = data.num_classes num_classes = data.num_classes
model = RGCN(in_size, 16, out_size, num_rels).to(device) model = RGCN(in_size, 16, num_classes, num_rels).to(device)
train(device, g, target_idx, labels, train_mask, model) train(device, g, target_idx, labels, train_mask, num_classes, model)
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze() test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
test_sampler = MultiLayerNeighborSampler( test_sampler = MultiLayerNeighborSampler(
[-1, -1] [-1, -1]
...@@ -166,5 +171,5 @@ if __name__ == "__main__": ...@@ -166,5 +171,5 @@ if __name__ == "__main__":
batch_size=32, batch_size=32,
shuffle=False, shuffle=False,
) )
acc = evaluate(model, labels, test_loader, inv_target) acc = evaluate(model, labels, num_classes, test_loader, inv_target)
print("Test accuracy {:.4f}".format(acc)) print("Test accuracy {:.4f}".format(acc))
...@@ -45,7 +45,7 @@ class RGCN(nn.Module): ...@@ -45,7 +45,7 @@ class RGCN(nn.Module):
return h return h
def evaluate(model, labels, dataloader, inv_target): def evaluate(model, labels, num_classes, dataloader, inv_target):
model.eval() model.eval()
eval_logits = [] eval_logits = []
eval_seeds = [] eval_seeds = []
...@@ -61,12 +61,25 @@ def evaluate(model, labels, dataloader, inv_target): ...@@ -61,12 +61,25 @@ def evaluate(model, labels, dataloader, inv_target):
eval_seeds = torch.cat(eval_seeds) eval_seeds = torch.cat(eval_seeds)
num_seeds = len(eval_seeds) num_seeds = len(eval_seeds)
loc_sum = accuracy( loc_sum = accuracy(
eval_logits.argmax(dim=1), labels[eval_seeds].cpu() eval_logits.argmax(dim=1),
labels[eval_seeds].cpu(),
task="multiclass",
num_classes=num_classes,
) * float(num_seeds) ) * float(num_seeds)
return torch.tensor([loc_sum.item(), float(num_seeds)]) return torch.tensor([loc_sum.item(), float(num_seeds)])
def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model): def train(
proc_id,
device,
g,
target_idx,
labels,
num_classes,
train_idx,
inv_target,
model,
):
# define loss function and optimizer # define loss function and optimizer
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4) optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
...@@ -106,9 +119,9 @@ def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model): ...@@ -106,9 +119,9 @@ def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model):
total_loss += loss.item() total_loss += loss.item()
# torchmetric accuracy defined as num_correct_labels / num_train_nodes # torchmetric accuracy defined as num_correct_labels / num_train_nodes
# loc_acc_split = [loc_accuracy * loc_num_train_nodes, loc_num_train_nodes] # loc_acc_split = [loc_accuracy * loc_num_train_nodes, loc_num_train_nodes]
loc_acc_split = evaluate(model, labels, val_loader, inv_target).to( loc_acc_split = evaluate(
device model, labels, num_classes, val_loader, inv_target
) ).to(device)
dist.reduce(loc_acc_split, 0) dist.reduce(loc_acc_split, 0)
if proc_id == 0: if proc_id == 0:
acc = loc_acc_split[0] / loc_acc_split[1] acc = loc_acc_split[0] / loc_acc_split[1]
...@@ -143,13 +156,22 @@ def run(proc_id, nprocs, devices, g, data): ...@@ -143,13 +156,22 @@ def run(proc_id, nprocs, devices, g, data):
inv_target = inv_target.to(device) inv_target = inv_target.to(device)
# create RGCN model (distributed) # create RGCN model (distributed)
in_size = g.num_nodes() in_size = g.num_nodes()
out_size = num_classes model = RGCN(in_size, 16, num_classes, num_rels).to(device)
model = RGCN(in_size, 16, out_size, num_rels).to(device)
model = DistributedDataParallel( model = DistributedDataParallel(
model, device_ids=[device], output_device=device model, device_ids=[device], output_device=device
) )
# training + testing # training + testing
train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model) train(
proc_id,
device,
g,
target_idx,
labels,
num_classes,
train_idx,
inv_target,
model,
)
test_sampler = MultiLayerNeighborSampler( test_sampler = MultiLayerNeighborSampler(
[-1, -1] [-1, -1]
) # -1 for sampling all neighbors ) # -1 for sampling all neighbors
...@@ -162,7 +184,9 @@ def run(proc_id, nprocs, devices, g, data): ...@@ -162,7 +184,9 @@ def run(proc_id, nprocs, devices, g, data):
shuffle=False, shuffle=False,
use_ddp=True, use_ddp=True,
) )
loc_acc_split = evaluate(model, labels, test_loader, inv_target).to(device) loc_acc_split = evaluate(
model, labels, num_classes, test_loader, inv_target
).to(device)
dist.reduce(loc_acc_split, 0) dist.reduce(loc_acc_split, 0)
if proc_id == 0: if proc_id == 0:
acc = loc_acc_split[0] / loc_acc_split[1] acc = loc_acc_split[0] / loc_acc_split[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment