Unverified commit 07db09f7 authored by SinuoXu, committed by GitHub

[Example] add ogc method (#6437)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 32749512
@@ -5,6 +5,14 @@ The folder contains example implementations of selected research papers related
* For examples working with a certain release, check out `https://github.com/dmlc/dgl/tree/<release_version>/examples` (e.g., https://github.com/dmlc/dgl/tree/0.5.x/examples)
To quickly locate the examples of your interest, search for the tagged keywords or use the search tool on [dgl.ai](https://www.dgl.ai/).
## 2023
- <a name="ogc"></a> Wang et al. From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited. [Paper link](https://arxiv.org/abs/2309.13599)
- Example code: [PyTorch](../examples/pytorch/ogc)
- Tags: semi-supervised node classification
## 2022
- <a name="labor"></a> Balin et al. Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs. [Paper link](https://arxiv.org/abs/2210.13339)
- Example code: [PyTorch](../examples/labor/train_lightning.py)
### README.md
# Optimized Graph Convolution (OGC)
This DGL example implements the OGC method from the paper: [From Cluster Assumption to Graph Convolution: Graph-based Semi-Supervised Learning Revisited](https://arxiv.org/abs/2309.13599).
With only one trainable layer, OGC is a simple yet powerful graph convolution method.
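To make the one-layer structure concrete, here is a rough sketch of the per-iteration updates implemented in `ogc.py`/`main.py` below (the notation is ours, not taken verbatim from the paper). With node embeddings $U$, one-hot labels $Y$, linear classifier weight $W$, diagonal train-mask matrix $S$, and symmetrically normalized adjacency $\tilde{A}$, each iteration applies a lazy graph convolution (LGC) step

$$U \leftarrow \big((1-\beta)I + \beta\tilde{A}\big)\,U,$$

followed by a supervised gradient (SEB) step on the squared-error loss

$$U \leftarrow U - \eta_{\text{sup}} \cdot 2\,S\,(UW^{\top} - Y)\,W,$$

where $\eta_{\text{sup}}$ is decayed by the factor `decline` after each iteration and the classifier $W$ is retrained on the updated embeddings.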
## Example Implementor
This example was implemented by [Sinuo Xu](https://github.com/SinuoXu) when she was an undergraduate at SJTU.
## Dependencies
- Python 3.11.5
- PyTorch 2.0.1
- DGL 1.1.2
- scikit-learn 1.3.1
## Dataset
This example uses DGL's built-in Citeseer, Cora, and Pubmed datasets, summarized below:
| Dataset | #Nodes | #Edges | #Feats | #Classes | #Train Nodes | #Val Nodes | #Test Nodes |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Citeseer | 3,327 | 9,228 | 3,703 | 6 | 120 | 500 | 1000 |
| Cora | 2,708 | 10,556 | 1,433 | 7 | 140 | 500 | 1000 |
| Pubmed | 19,717 | 88,651 | 500 | 3 | 60 | 500 | 1000 |
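For reference, a minimal snippet that mirrors the preprocessing in `main.py` below, loading one of these datasets with DGL's built-in API and inspecting its split sizes:

```python
from dgl import AddSelfLoop
from dgl.data import CoraGraphDataset

# Load Cora and add a self-loop to every node, as main.py does.
data = CoraGraphDataset(transform=AddSelfLoop())
g = data[0]
print(g.num_nodes())                     # 2708
print(int(g.ndata["train_mask"].sum()))  # 140 training nodes
```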
## Usage
```bash
python main.py --dataset cora
python main.py --dataset citeseer
python main.py --dataset pubmed
```
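Every hyperparameter exposed by the argument parser in `main.py` (see below) can be overridden on the command line, e.g.:

```bash
python main.py --dataset cora --beta 0.1 --lr_sup 0.001 --decline 0.9 --device cuda
```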
## Performance
| Dataset | Cora | Citeseer | Pubmed |
| :-: | :-: | :-: | :-: |
| OGC (DGL) | **86.9(±0.2)** | **77.4(±0.1)** | **83.6(±0.1)** |
| OGC (Reported) | **86.9(±0.0)** | **77.4(±0.0)** | 83.4(±0.0) |
### ogc.py

import dgl.sparse as dglsp
import torch.nn as nn
import torch.nn.functional as F
from utils import LinearNeuralNetwork
class OGC(nn.Module):
def __init__(self, graph):
super(OGC, self).__init__()
self.linear_clf = LinearNeuralNetwork(
nfeat=graph.ndata["feat"].shape[1],
nclass=graph.ndata["label"].max().item() + 1,
bias=False,
)
self.label = graph.ndata["label"]
self.label_one_hot = F.one_hot(graph.ndata["label"]).float()
        # LIM trick: use only the training set to construct this matrix;
        # otherwise, use both the train and val sets.
self.label_idx_mat = dglsp.diag(graph.ndata["train_mask"]).float()
self.test_mask = graph.ndata["test_mask"]
self.tv_mask = graph.ndata["train_mask"] + graph.ndata["val_mask"]
def forward(self, x):
return self.linear_clf(x)
def update_embeds(self, embeds, lazy_adj, args):
"""Update classifier's weight by training a linear supervised model."""
pred_label = self(embeds).data
clf_weight = self.linear_clf.W.weight.data
        # LGC step: smooth the embeddings with one lazy graph convolution.
        embeds = dglsp.spmm(lazy_adj, embeds)
        # SEB step: take a gradient step on the supervised squared loss.
deriv_sup = 2 * dglsp.matmul(
dglsp.spmm(self.label_idx_mat, -self.label_one_hot + pred_label),
clf_weight,
)
embeds = embeds - args.lr_sup * deriv_sup
args.lr_sup = args.lr_sup * args.decline
return embeds
### main.py

import argparse
import time
import dgl.sparse as dglsp
import torch.nn.functional as F
import torch.optim as optim
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from ogc import OGC
from utils import model_test, symmetric_normalize_adjacency
def train(model, embeds, lazy_adj, args):
patience = 0
_, _, last_acc, last_output = model_test(model, embeds)
tv_mask = model.tv_mask
optimizer = optim.SGD(model.parameters(), lr=args.lr_clf)
for i in range(64):
model.train()
output = model(embeds)
loss_tv = F.mse_loss(
output[tv_mask], model.label_one_hot[tv_mask], reduction="sum"
)
optimizer.zero_grad()
loss_tv.backward()
optimizer.step()
# Updating node embeds by LGC and SEB jointly.
embeds = model.update_embeds(embeds, lazy_adj, args)
loss_tv, acc_tv, acc_test, pred = model_test(model, embeds)
print(
"epoch {} loss_tv {:.4f} acc_tv {:.4f} acc_test {:.4f}".format(
i + 1, loss_tv, acc_tv, acc_test
)
)
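        # Early stopping: break once the fraction of unchanged test
        # predictions has exceeded max_sim_rate in more than max_patience
        # iterations.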
sim_rate = float(int((pred == last_output).sum()) / int(pred.shape[0]))
if sim_rate > args.max_sim_rate:
patience += 1
if patience > args.max_patience:
break
last_acc = acc_test
last_output = pred
return last_acc
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dataset",
type=str,
default="citeseer",
choices=["cora", "citeseer", "pubmed"],
help="dataset to use",
)
parser.add_argument(
"--decline", type=float, default=0.9, help="decline rate"
)
parser.add_argument(
"--lr_sup",
type=float,
default=0.001,
help="learning rate for supervised loss",
)
parser.add_argument(
"--lr_clf",
type=float,
default=0.5,
help="learning rate for the used linear classifier",
)
parser.add_argument(
"--beta",
type=float,
default=0.1,
help="moving probability that a node moves to its neighbors",
)
parser.add_argument(
"--max_sim_rate",
type=float,
default=0.995,
help="max label prediction similarity between iterations",
)
parser.add_argument(
"--max_patience",
type=int,
default=2,
help="tolerance for consecutively similar test predictions",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="device to use",
)
args, _ = parser.parse_known_args()
# Load and preprocess dataset.
transform = AddSelfLoop()
if args.dataset == "cora":
data = CoraGraphDataset(transform=transform)
elif args.dataset == "citeseer":
data = CiteseerGraphDataset(transform=transform)
elif args.dataset == "pubmed":
data = PubmedGraphDataset(transform=transform)
else:
raise ValueError("Unknown dataset: {}".format(args.dataset))
graph = data[0].to(args.device)
features = graph.ndata["feat"]
adj = symmetric_normalize_adjacency(graph)
I_N = dglsp.identity((features.shape[0], features.shape[0]))
# Lazy random walk (also known as lazy graph convolution).
lazy_adj = dglsp.add((1 - args.beta) * I_N, args.beta * adj).to(args.device)
model = OGC(graph).to(args.device)
start_time = time.time()
res = train(model, features, lazy_adj, args)
time_tot = time.time() - start_time
print(f"Test Acc:{res:.4f}")
print(f"Total Time:{time_tot:.4f}")
### utils.py

import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearNeuralNetwork(nn.Module):
def __init__(self, nfeat, nclass, bias=True):
super(LinearNeuralNetwork, self).__init__()
self.W = nn.Linear(nfeat, nclass, bias=bias)
def forward(self, x):
return self.W(x)
def symmetric_normalize_adjacency(graph):
"""Symmetric normalize graph adjacency matrix."""
indices = torch.stack(graph.edges())
n = graph.num_nodes()
adj = dglsp.spmatrix(indices, shape=(n, n))
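    # Compute D^{-1/2} @ A @ D^{-1/2}, where D is the diagonal degree matrix.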
deg_invsqrt = dglsp.diag(adj.sum(0)) ** -0.5
return deg_invsqrt @ adj @ deg_invsqrt
def model_test(model, embeds):
model.eval()
with torch.no_grad():
output = model(embeds)
pred = output.argmax(dim=-1)
test_mask, tv_mask = model.test_mask, model.tv_mask
loss_tv = F.mse_loss(output[tv_mask], model.label_one_hot[tv_mask])
accs = []
for mask in [tv_mask, test_mask]:
accs.append(float((pred[mask] == model.label[mask]).sum() / mask.sum()))
return loss_tv.item(), accs[0], accs[1], pred
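As a quick sanity check of `symmetric_normalize_adjacency` (our own snippet, not part of the example): on an undirected triangle every node has degree 2, so each nonzero entry of $D^{-1/2}AD^{-1/2}$ should equal $1/\sqrt{2 \cdot 2} = 0.5$.

```python
import dgl
import torch
from utils import symmetric_normalize_adjacency

# Undirected triangle: both directions of every edge are listed explicitly.
g = dgl.graph(
    (torch.tensor([0, 1, 2, 1, 2, 0]), torch.tensor([1, 2, 0, 0, 1, 2]))
)
print(symmetric_normalize_adjacency(g).to_dense())  # all nonzeros are 0.5
```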