Unverified Commit a1f74982 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Fix Various Examples Related to TorchMetrics (#5521)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
parent 97286f98
Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks
============ ============
- Paper link: [Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks](https://arxiv.org/abs/1905.07953) - Paper link: [Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks](https://arxiv.org/abs/1905.07953)
- Author's code repo: [https://github.com/google-research/google-research/blob/master/cluster_gcn/](https://github.com/google-research/google-research/blob/master/cluster_gcn/). - Author's code repo: [https://github.com/google-research/google-research/blob/master/cluster_gcn/](https://github.com/google-research/google-research/blob/master/cluster_gcn/).
This repo reproduces the reported speed and performance maximally on Reddit and PPI. However, the diag enhancement is not covered, as the GraphSage aggregator already achieves a satisfying F1 score. This repo reproduces the reported speed and performance maximally on Reddit and PPI. However, the diag enhancement is not covered, as the GraphSage aggregator already achieves a satisfying F1 score.
...@@ -10,7 +10,7 @@ Dependencies ...@@ -10,7 +10,7 @@ Dependencies
- Python 3.7+ (for string formatting features) - Python 3.7+ (for string formatting features)
- PyTorch 1.9.0+ - PyTorch 1.9.0+
- sklearn - sklearn
- TorchMetrics - TorchMetrics 0.11.4
## Run Experiments ## Run Experiments
......
...@@ -72,7 +72,12 @@ for _ in range(10): ...@@ -72,7 +72,12 @@ for _ in range(10):
loss.backward() loss.backward()
opt.step() opt.step()
if it % 20 == 0: if it % 20 == 0:
acc = MF.accuracy(y_hat[m], y[m]) acc = MF.accuracy(
y_hat[m],
y[m],
task="multiclass",
num_classes=dataset.num_classes,
)
mem = torch.cuda.max_memory_allocated() / 1000000 mem = torch.cuda.max_memory_allocated() / 1000000
print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB") print("Loss", loss.item(), "Acc", acc.item(), "GPU Mem", mem, "MB")
tt = time.time() tt = time.time()
...@@ -97,8 +102,18 @@ for _ in range(10): ...@@ -97,8 +102,18 @@ for _ in range(10):
val_labels = torch.cat(val_labels, 0) val_labels = torch.cat(val_labels, 0)
test_preds = torch.cat(test_preds, 0) test_preds = torch.cat(test_preds, 0)
test_labels = torch.cat(test_labels, 0) test_labels = torch.cat(test_labels, 0)
val_acc = MF.accuracy(val_preds, val_labels) val_acc = MF.accuracy(
test_acc = MF.accuracy(test_preds, test_labels) val_preds,
val_labels,
task="multiclass",
num_classes=dataset.num_classes,
)
test_acc = MF.accuracy(
test_preds,
test_labels,
task="multiclass",
num_classes=dataset.num_classes,
)
print("Validation acc:", val_acc.item(), "Test acc:", test_acc.item()) print("Validation acc:", val_acc.item(), "Test acc:", test_acc.item())
print(np.mean(durations[4:]), np.std(durations[4:])) print(np.mean(durations[4:]), np.std(durations[4:]))
...@@ -10,7 +10,7 @@ Requirements ...@@ -10,7 +10,7 @@ Requirements
------------ ------------
```bash ```bash
pip install requests torchmetrics pip install requests torchmetrics==0.11.4 ogb
``` ```
How to run How to run
...@@ -25,7 +25,7 @@ python3 train_full.py --dataset cora --gpu 0 # full graph ...@@ -25,7 +25,7 @@ python3 train_full.py --dataset cora --gpu 0 # full graph
Results: Results:
``` ```
* cora: ~0.8330 * cora: ~0.8330
* citeseer: ~0.7110 * citeseer: ~0.7110
* pubmed: ~0.7830 * pubmed: ~0.7830
``` ```
...@@ -45,8 +45,7 @@ Test Accuracy: 0.7632 ...@@ -45,8 +45,7 @@ Test Accuracy: 0.7632
### PyTorch Lightning for node classification ### PyTorch Lightning for node classification
Train w/ mini-batch sampling for node classification with PyTorch Lightning on OGB-products. Train w/ mini-batch sampling for node classification with PyTorch Lightning on OGB-products. It requires PyTorch Lightning 2.0.1. It works with both single GPU and multiple GPUs:
Works with both single GPU and multiple GPUs:
```bash ```bash
python3 lightning/node_classification.py python3 lightning/node_classification.py
......
...@@ -27,8 +27,8 @@ class SAGE(LightningModule): ...@@ -27,8 +27,8 @@ class SAGE(LightningModule):
self.dropout = nn.Dropout(0.5) self.dropout = nn.Dropout(0.5)
self.n_hidden = n_hidden self.n_hidden = n_hidden
self.n_classes = n_classes self.n_classes = n_classes
self.train_acc = Accuracy() self.train_acc = Accuracy(task="multiclass", num_classes=n_classes)
self.val_acc = Accuracy() self.val_acc = Accuracy(task="multiclass", num_classes=n_classes)
def forward(self, blocks, x): def forward(self, blocks, x):
h = x h = x
...@@ -180,9 +180,11 @@ if __name__ == "__main__": ...@@ -180,9 +180,11 @@ if __name__ == "__main__":
# Train # Train
checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1) checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
# Use this for single GPU # Use this for single GPU
# trainer = Trainer(gpus=[0], max_epochs=10, callbacks=[checkpoint_callback]) # trainer = Trainer(accelerator="gpu", devices=[0], max_epochs=10,
# callbacks=[checkpoint_callback])
trainer = Trainer( trainer = Trainer(
gpus=[0, 1, 2, 3], accelerator="gpu",
devices=[0, 1, 2, 3],
max_epochs=10, max_epochs=10,
callbacks=[checkpoint_callback], callbacks=[checkpoint_callback],
strategy="ddp_spawn", strategy="ddp_spawn",
...@@ -203,5 +205,7 @@ if __name__ == "__main__": ...@@ -203,5 +205,7 @@ if __name__ == "__main__":
pred = model.inference(graph, "cuda", 4096, 12, graph.device) pred = model.inference(graph, "cuda", 4096, 12, graph.device)
pred = pred[test_idx] pred = pred[test_idx]
label = graph.ndata["label"][test_idx] label = graph.ndata["label"][test_idx]
acc = MF.accuracy(pred, label) acc = MF.accuracy(
pred, label, task="multiclass", num_classes=datamodule.n_classes
)
print("Test accuracy:", acc) print("Test accuracy:", acc)
...@@ -5,7 +5,6 @@ import dgl.nn as dglnn ...@@ -5,7 +5,6 @@ import dgl.nn as dglnn
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm import tqdm
from dgl.dataloading import ( from dgl.dataloading import (
as_edge_prediction_sampler, as_edge_prediction_sampler,
......
...@@ -74,7 +74,7 @@ class SAGE(nn.Module): ...@@ -74,7 +74,7 @@ class SAGE(nn.Module):
return y return y
def evaluate(model, graph, dataloader): def evaluate(model, graph, dataloader, num_classes):
model.eval() model.eval()
ys = [] ys = []
y_hats = [] y_hats = []
...@@ -83,10 +83,15 @@ def evaluate(model, graph, dataloader): ...@@ -83,10 +83,15 @@ def evaluate(model, graph, dataloader):
x = blocks[0].srcdata["feat"] x = blocks[0].srcdata["feat"]
ys.append(blocks[-1].dstdata["label"]) ys.append(blocks[-1].dstdata["label"])
y_hats.append(model(blocks, x)) y_hats.append(model(blocks, x))
return MF.accuracy(torch.cat(y_hats), torch.cat(ys)) return MF.accuracy(
torch.cat(y_hats),
torch.cat(ys),
task="multiclass",
num_classes=num_classes,
)
def layerwise_infer(device, graph, nid, model, batch_size): def layerwise_infer(device, graph, nid, model, num_classes, batch_size):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
pred = model.inference( pred = model.inference(
...@@ -94,10 +99,12 @@ def layerwise_infer(device, graph, nid, model, batch_size): ...@@ -94,10 +99,12 @@ def layerwise_infer(device, graph, nid, model, batch_size):
) # pred in buffer_device ) # pred in buffer_device
pred = pred[nid] pred = pred[nid]
label = graph.ndata["label"][nid].to(pred.device) label = graph.ndata["label"][nid].to(pred.device)
return MF.accuracy(pred, label) return MF.accuracy(
pred, label, task="multiclass", num_classes=num_classes
)
def train(args, device, g, dataset, model): def train(args, device, g, dataset, model, num_classes):
# create sampler & dataloader # create sampler & dataloader
train_idx = dataset.train_idx.to(device) train_idx = dataset.train_idx.to(device)
val_idx = dataset.val_idx.to(device) val_idx = dataset.val_idx.to(device)
...@@ -147,7 +154,7 @@ def train(args, device, g, dataset, model): ...@@ -147,7 +154,7 @@ def train(args, device, g, dataset, model):
loss.backward() loss.backward()
opt.step() opt.step()
total_loss += loss.item() total_loss += loss.item()
acc = evaluate(model, g, val_dataloader) acc = evaluate(model, g, val_dataloader, num_classes)
print( print(
"Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format( "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
epoch, total_loss / (it + 1), acc.item() epoch, total_loss / (it + 1), acc.item()
...@@ -174,6 +181,7 @@ if __name__ == "__main__": ...@@ -174,6 +181,7 @@ if __name__ == "__main__":
dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products")) dataset = AsNodePredDataset(DglNodePropPredDataset("ogbn-products"))
g = dataset[0] g = dataset[0]
g = g.to("cuda" if args.mode == "puregpu" else "cpu") g = g.to("cuda" if args.mode == "puregpu" else "cpu")
num_classes = dataset.num_classes
device = torch.device("cpu" if args.mode == "cpu" else "cuda") device = torch.device("cpu" if args.mode == "cpu" else "cuda")
# create GraphSAGE model # create GraphSAGE model
...@@ -183,9 +191,11 @@ if __name__ == "__main__": ...@@ -183,9 +191,11 @@ if __name__ == "__main__":
# model training # model training
print("Training...") print("Training...")
train(args, device, g, dataset, model) train(args, device, g, dataset, model, num_classes)
# test the model # test the model
print("Testing...") print("Testing...")
acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096) acc = layerwise_infer(
device, g, dataset.test_idx, model, num_classes, batch_size=4096
)
print("Test Accuracy {:.4f}".format(acc.item())) print("Test Accuracy {:.4f}".format(acc.item()))
...@@ -5,7 +5,7 @@ Requirements ...@@ -5,7 +5,7 @@ Requirements
------------ ------------
```bash ```bash
pip install torchmetrics pip install torchmetrics==0.11.4
``` ```
How to run How to run
...@@ -68,6 +68,6 @@ Eval F1-score: ~0.7999 Test F1-score: ~0.6383 ...@@ -68,6 +68,6 @@ Eval F1-score: ~0.7999 Test F1-score: ~0.6383
Notably, Notably,
* The loss function is defined by predicting whether an edge exists between two nodes or not. * The loss function is defined by predicting whether an edge exists between two nodes or not.
* When computing the score of `(u, v)`, the connections between node `u` and `v` are removed from neighbor sampling. * When computing the score of `(u, v)`, the connections between node `u` and `v` are removed from neighbor sampling.
* The performance of the learned embeddings is measured by training a softmax regression with scikit-learn. * The performance of the learned embeddings is measured by training a softmax regression with scikit-learn.
...@@ -86,7 +86,7 @@ class SAGE(nn.Module): ...@@ -86,7 +86,7 @@ class SAGE(nn.Module):
return y return y
def evaluate(model, g, dataloader): def evaluate(model, g, num_classes, dataloader):
model.eval() model.eval()
ys = [] ys = []
y_hats = [] y_hats = []
...@@ -95,11 +95,16 @@ def evaluate(model, g, dataloader): ...@@ -95,11 +95,16 @@ def evaluate(model, g, dataloader):
x = blocks[0].srcdata["feat"] x = blocks[0].srcdata["feat"]
ys.append(blocks[-1].dstdata["label"]) ys.append(blocks[-1].dstdata["label"])
y_hats.append(model(blocks, x)) y_hats.append(model(blocks, x))
return MF.accuracy(torch.cat(y_hats), torch.cat(ys)) return MF.accuracy(
torch.cat(y_hats),
torch.cat(ys),
task="multiclass",
num_classes=num_classes,
)
def layerwise_infer( def layerwise_infer(
proc_id, device, g, nid, model, use_uva, batch_size=2**16 proc_id, device, g, num_classes, nid, model, use_uva, batch_size=2**16
): ):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
...@@ -107,11 +112,15 @@ def layerwise_infer( ...@@ -107,11 +112,15 @@ def layerwise_infer(
pred = pred[nid] pred = pred[nid]
labels = g.ndata["label"][nid].to(pred.device) labels = g.ndata["label"][nid].to(pred.device)
if proc_id == 0: if proc_id == 0:
acc = MF.accuracy(pred, labels) acc = MF.accuracy(
pred, labels, task="multiclass", num_classes=num_classes
)
print("Test Accuracy {:.4f}".format(acc.item())) print("Test Accuracy {:.4f}".format(acc.item()))
def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva): def train(
proc_id, nprocs, device, g, num_classes, train_idx, val_idx, model, use_uva
):
sampler = NeighborSampler( sampler = NeighborSampler(
[10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"] [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
) )
...@@ -154,7 +163,9 @@ def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva): ...@@ -154,7 +163,9 @@ def train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva):
loss.backward() loss.backward()
opt.step() opt.step()
total_loss += loss total_loss += loss
acc = evaluate(model, g, val_dataloader).to(device) / nprocs acc = (
evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs
)
dist.reduce(acc, 0) dist.reduce(acc, 0)
if proc_id == 0: if proc_id == 0:
print( print(
...@@ -175,20 +186,30 @@ def run(proc_id, nprocs, devices, g, data, mode): ...@@ -175,20 +186,30 @@ def run(proc_id, nprocs, devices, g, data, mode):
world_size=nprocs, world_size=nprocs,
rank=proc_id, rank=proc_id,
) )
out_size, train_idx, val_idx, test_idx = data num_classes, train_idx, val_idx, test_idx = data
train_idx = train_idx.to(device) train_idx = train_idx.to(device)
val_idx = val_idx.to(device) val_idx = val_idx.to(device)
g = g.to(device if mode == "puregpu" else "cpu") g = g.to(device if mode == "puregpu" else "cpu")
# create GraphSAGE model (distributed) # create GraphSAGE model (distributed)
in_size = g.ndata["feat"].shape[1] in_size = g.ndata["feat"].shape[1]
model = SAGE(in_size, 256, out_size).to(device) model = SAGE(in_size, 256, num_classes).to(device)
model = DistributedDataParallel( model = DistributedDataParallel(
model, device_ids=[device], output_device=device model, device_ids=[device], output_device=device
) )
# training + testing # training + testing
use_uva = mode == "mixed" use_uva = mode == "mixed"
train(proc_id, nprocs, device, g, train_idx, val_idx, model, use_uva) train(
layerwise_infer(proc_id, device, g, test_idx, model, use_uva) proc_id,
nprocs,
device,
g,
num_classes,
train_idx,
val_idx,
model,
use_uva,
)
layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)
# cleanup process group # cleanup process group
dist.destroy_process_group() dist.destroy_process_group()
......
...@@ -4,11 +4,11 @@ This is an adaptation of RGCN where graph convolution is replaced with graph att ...@@ -4,11 +4,11 @@ This is an adaptation of RGCN where graph convolution is replaced with graph att
Dependencies Dependencies
------------ ------------
- torchmetrics - torchmetrics 0.11.4
Install as follows: Install as follows:
```bash ```bash
pip install torchmetrics pip install torchmetrics==0.11.4
``` ```
How to Run How to Run
......
...@@ -57,7 +57,7 @@ class HeteroGAT(nn.Module): ...@@ -57,7 +57,7 @@ class HeteroGAT(nn.Module):
return self.linear(h["paper"]) return self.linear(h["paper"])
def evaluate(model, dataloader, desc): def evaluate(num_classes, model, dataloader, desc):
preds = [] preds = []
labels = [] labels = []
with torch.no_grad(): with torch.no_grad():
...@@ -71,11 +71,13 @@ def evaluate(model, dataloader, desc): ...@@ -71,11 +71,13 @@ def evaluate(model, dataloader, desc):
labels.append(y.cpu()) labels.append(y.cpu())
preds = torch.cat(preds, 0) preds = torch.cat(preds, 0)
labels = torch.cat(labels, 0) labels = torch.cat(labels, 0)
acc = MF.accuracy(preds, labels) acc = MF.accuracy(
preds, labels, task="multiclass", num_classes=num_classes
)
return acc return acc
def train(train_loader, val_loader, test_loader, model): def train(train_loader, val_loader, test_loader, num_classes, model):
# loss function and optimizer # loss function and optimizer
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4) opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
...@@ -96,8 +98,8 @@ def train(train_loader, val_loader, test_loader, model): ...@@ -96,8 +98,8 @@ def train(train_loader, val_loader, test_loader, model):
opt.step() opt.step()
total_loss += loss.item() total_loss += loss.item()
model.eval() model.eval()
val_acc = evaluate(model, val_dataloader, "Val. ") val_acc = evaluate(num_classes, model, val_dataloader, "Val. ")
test_acc = evaluate(model, test_dataloader, "Test ") test_acc = evaluate(num_classes, model, test_dataloader, "Test ")
print( print(
f"Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}" f"Epoch {epoch:05d} | Loss {total_loss/(it+1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}"
) )
...@@ -138,8 +140,8 @@ if __name__ == "__main__": ...@@ -138,8 +140,8 @@ if __name__ == "__main__":
# create RGAT model # create RGAT model
in_size = graph.ndata["feat"]["paper"].shape[1] in_size = graph.ndata["feat"]["paper"].shape[1]
out_size = dataset.num_classes num_classes = dataset.num_classes
model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device) model = HeteroGAT(graph.etypes, in_size, 256, num_classes).to(device)
# dataloader + model training + testing # dataloader + model training + testing
train_sampler = NeighborSampler( train_sampler = NeighborSampler(
...@@ -186,4 +188,4 @@ if __name__ == "__main__": ...@@ -186,4 +188,4 @@ if __name__ == "__main__":
use_uva=torch.cuda.is_available(), use_uva=torch.cuda.is_available(),
) )
train(train_dataloader, val_dataloader, test_dataloader, model) train(train_dataloader, val_dataloader, test_dataloader, num_classes, model)
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
### Dependencies ### Dependencies
- rdflib - rdflib
- torchmetrics - torchmetrics 0.11.4
Install as follows: Install as follows:
```bash ```bash
pip install rdflib pip install rdflib
pip install torchmetrics pip install torchmetrics==0.11.4
``` ```
How to run How to run
......
...@@ -38,16 +38,21 @@ class RGCN(nn.Module): ...@@ -38,16 +38,21 @@ class RGCN(nn.Module):
return h return h
def evaluate(g, target_idx, labels, test_mask, model): def evaluate(g, target_idx, labels, num_classes, test_mask, model):
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze() test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
logits = model(g) logits = model(g)
logits = logits[target_idx] logits = logits[target_idx]
return accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item() return accuracy(
logits[test_idx].argmax(dim=1),
labels[test_idx],
task="multiclass",
num_classes=num_classes,
).item()
def train(g, target_idx, labels, train_mask, model): def train(g, target_idx, labels, num_classes, train_mask, model):
# define train idx, loss function and optimizer # define train idx, loss function and optimizer
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze() train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
...@@ -62,7 +67,10 @@ def train(g, target_idx, labels, train_mask, model): ...@@ -62,7 +67,10 @@ def train(g, target_idx, labels, train_mask, model):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
acc = accuracy( acc = accuracy(
logits[train_idx].argmax(dim=1), labels[train_idx] logits[train_idx].argmax(dim=1),
labels[train_idx],
task="multiclass",
num_classes=num_classes,
).item() ).item()
print( print(
"Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f} ".format( "Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f} ".format(
...@@ -112,9 +120,9 @@ if __name__ == "__main__": ...@@ -112,9 +120,9 @@ if __name__ == "__main__":
target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id] target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
# create RGCN model # create RGCN model
in_size = g.num_nodes() # featureless with one-hot encoding in_size = g.num_nodes() # featureless with one-hot encoding
out_size = data.num_classes num_classes = data.num_classes
model = RGCN(in_size, 16, out_size, num_rels).to(device) model = RGCN(in_size, 16, num_classes, num_rels).to(device)
train(g, target_idx, labels, train_mask, model) train(g, target_idx, labels, num_classes, train_mask, model)
acc = evaluate(g, target_idx, labels, test_mask, model) acc = evaluate(g, target_idx, labels, num_classes, test_mask, model)
print("Test accuracy {:.4f}".format(acc)) print("Test accuracy {:.4f}".format(acc))
...@@ -41,7 +41,7 @@ class RGCN(nn.Module): ...@@ -41,7 +41,7 @@ class RGCN(nn.Module):
return h return h
def evaluate(model, label, dataloader, inv_target): def evaluate(model, labels, num_classes, dataloader, inv_target):
model.eval() model.eval()
eval_logits = [] eval_logits = []
eval_seeds = [] eval_seeds = []
...@@ -55,10 +55,15 @@ def evaluate(model, label, dataloader, inv_target): ...@@ -55,10 +55,15 @@ def evaluate(model, label, dataloader, inv_target):
eval_seeds.append(output_nodes.cpu().detach()) eval_seeds.append(output_nodes.cpu().detach())
eval_logits = torch.cat(eval_logits) eval_logits = torch.cat(eval_logits)
eval_seeds = torch.cat(eval_seeds) eval_seeds = torch.cat(eval_seeds)
return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item() return accuracy(
eval_logits.argmax(dim=1),
labels[eval_seeds].cpu(),
task="multiclass",
num_classes=num_classes,
).item()
def train(device, g, target_idx, labels, train_mask, model): def train(device, g, target_idx, labels, train_mask, num_classes, model):
# define train idx, loss function and optimizer # define train idx, loss function and optimizer
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze() train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
...@@ -95,7 +100,7 @@ def train(device, g, target_idx, labels, train_mask, model): ...@@ -95,7 +100,7 @@ def train(device, g, target_idx, labels, train_mask, model):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
total_loss += loss.item() total_loss += loss.item()
acc = evaluate(model, labels, val_loader, inv_target) acc = evaluate(model, labels, num_classes, val_loader, inv_target)
print( print(
"Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format( "Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} ".format(
epoch, total_loss / (it + 1), acc epoch, total_loss / (it + 1), acc
...@@ -150,10 +155,10 @@ if __name__ == "__main__": ...@@ -150,10 +155,10 @@ if __name__ == "__main__":
# create RGCN model # create RGCN model
in_size = g.num_nodes() # featureless with one-hot encoding in_size = g.num_nodes() # featureless with one-hot encoding
out_size = data.num_classes num_classes = data.num_classes
model = RGCN(in_size, 16, out_size, num_rels).to(device) model = RGCN(in_size, 16, num_classes, num_rels).to(device)
train(device, g, target_idx, labels, train_mask, model) train(device, g, target_idx, labels, train_mask, num_classes, model)
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze() test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
test_sampler = MultiLayerNeighborSampler( test_sampler = MultiLayerNeighborSampler(
[-1, -1] [-1, -1]
...@@ -166,5 +171,5 @@ if __name__ == "__main__": ...@@ -166,5 +171,5 @@ if __name__ == "__main__":
batch_size=32, batch_size=32,
shuffle=False, shuffle=False,
) )
acc = evaluate(model, labels, test_loader, inv_target) acc = evaluate(model, labels, num_classes, test_loader, inv_target)
print("Test accuracy {:.4f}".format(acc)) print("Test accuracy {:.4f}".format(acc))
...@@ -45,7 +45,7 @@ class RGCN(nn.Module): ...@@ -45,7 +45,7 @@ class RGCN(nn.Module):
return h return h
def evaluate(model, labels, dataloader, inv_target): def evaluate(model, labels, num_classes, dataloader, inv_target):
model.eval() model.eval()
eval_logits = [] eval_logits = []
eval_seeds = [] eval_seeds = []
...@@ -61,12 +61,25 @@ def evaluate(model, labels, dataloader, inv_target): ...@@ -61,12 +61,25 @@ def evaluate(model, labels, dataloader, inv_target):
eval_seeds = torch.cat(eval_seeds) eval_seeds = torch.cat(eval_seeds)
num_seeds = len(eval_seeds) num_seeds = len(eval_seeds)
loc_sum = accuracy( loc_sum = accuracy(
eval_logits.argmax(dim=1), labels[eval_seeds].cpu() eval_logits.argmax(dim=1),
labels[eval_seeds].cpu(),
task="multiclass",
num_classes=num_classes,
) * float(num_seeds) ) * float(num_seeds)
return torch.tensor([loc_sum.item(), float(num_seeds)]) return torch.tensor([loc_sum.item(), float(num_seeds)])
def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model): def train(
proc_id,
device,
g,
target_idx,
labels,
num_classes,
train_idx,
inv_target,
model,
):
# define loss function and optimizer # define loss function and optimizer
loss_fcn = nn.CrossEntropyLoss() loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4) optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
...@@ -106,9 +119,9 @@ def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model): ...@@ -106,9 +119,9 @@ def train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model):
total_loss += loss.item() total_loss += loss.item()
# torchmetric accuracy defined as num_correct_labels / num_train_nodes # torchmetric accuracy defined as num_correct_labels / num_train_nodes
# loc_acc_split = [loc_accuracy * loc_num_train_nodes, loc_num_train_nodes] # loc_acc_split = [loc_accuracy * loc_num_train_nodes, loc_num_train_nodes]
loc_acc_split = evaluate(model, labels, val_loader, inv_target).to( loc_acc_split = evaluate(
device model, labels, num_classes, val_loader, inv_target
) ).to(device)
dist.reduce(loc_acc_split, 0) dist.reduce(loc_acc_split, 0)
if proc_id == 0: if proc_id == 0:
acc = loc_acc_split[0] / loc_acc_split[1] acc = loc_acc_split[0] / loc_acc_split[1]
...@@ -143,13 +156,22 @@ def run(proc_id, nprocs, devices, g, data): ...@@ -143,13 +156,22 @@ def run(proc_id, nprocs, devices, g, data):
inv_target = inv_target.to(device) inv_target = inv_target.to(device)
# create RGCN model (distributed) # create RGCN model (distributed)
in_size = g.num_nodes() in_size = g.num_nodes()
out_size = num_classes model = RGCN(in_size, 16, num_classes, num_rels).to(device)
model = RGCN(in_size, 16, out_size, num_rels).to(device)
model = DistributedDataParallel( model = DistributedDataParallel(
model, device_ids=[device], output_device=device model, device_ids=[device], output_device=device
) )
# training + testing # training + testing
train(proc_id, device, g, target_idx, labels, train_idx, inv_target, model) train(
proc_id,
device,
g,
target_idx,
labels,
num_classes,
train_idx,
inv_target,
model,
)
test_sampler = MultiLayerNeighborSampler( test_sampler = MultiLayerNeighborSampler(
[-1, -1] [-1, -1]
) # -1 for sampling all neighbors ) # -1 for sampling all neighbors
...@@ -162,7 +184,9 @@ def run(proc_id, nprocs, devices, g, data): ...@@ -162,7 +184,9 @@ def run(proc_id, nprocs, devices, g, data):
shuffle=False, shuffle=False,
use_ddp=True, use_ddp=True,
) )
loc_acc_split = evaluate(model, labels, test_loader, inv_target).to(device) loc_acc_split = evaluate(
model, labels, num_classes, test_loader, inv_target
).to(device)
dist.reduce(loc_acc_split, 0) dist.reduce(loc_acc_split, 0)
if proc_id == 0: if proc_id == 0:
acc = loc_acc_split[0] / loc_acc_split[1] acc = loc_acc_split[0] / loc_acc_split[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment