"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "b83bdb0cbee98c5e73ff193042a832381565d313"
Unverified commit 765d725e authored by Israt Nisa, committed by GitHub

[Sparse] GAT example using sparse operator (#4645)



* GAT example

* Complete GAT and Add README

* Update
Co-authored-by: Israt Nisa <nisisrat@amazon.com>
Co-authored-by: czkkkkkk <zekucai@gmail.com>
parent 9ceb63db
Graph Attention Networks (GAT)
============
- Paper link: [https://arxiv.org/abs/1710.10903](https://arxiv.org/abs/1710.10903)
- Author's code repo (TensorFlow implementation):
[https://github.com/PetarV-/GAT](https://github.com/PetarV-/GAT).
- Popular PyTorch implementation:
[https://github.com/Diego999/pyGAT](https://github.com/Diego999/pyGAT).
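For reference, the attention computed by the `GATConv` module in `train.py` below follows the paper: for an edge $(u, v)$, with shared projection $W$ and attention vector $\vec{a} = [\vec{a}_l \,\|\, \vec{a}_r]$,

$$
e_{uv} = \mathrm{LeakyReLU}\left(\vec{a}^{\top} [W h_u \,\|\, W h_v]\right)
       = \mathrm{LeakyReLU}\left(\vec{a}_l^{\top} W h_u + \vec{a}_r^{\top} W h_v\right),
\qquad
\alpha_{uv} = \mathrm{softmax}\left(e_{uv}\right).
$$

The right-hand decomposition is what the sparse implementation exploits: the two dot products are computed once per node, gathered onto the edges of the sparse adjacency matrix, and normalized with a sparse `softmax`.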
How to run
-------
### Sparse tensor GATConv module
Run with the following command (available datasets: "cora", "citeseer", "pubmed"):
```bash
python3 train.py --dataset cora
```
Summary
-------
Test accuracy:
* cora: ~0.810
* citeseer: ~0.697
* pubmed: ~0.774
import argparse

import torch
import torch.nn as nn
from torch.nn import init

from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from dgl.mock_sparse import bspmm, create_from_coo, softmax


class GATConv(nn.Module):
    def __init__(self, in_size, out_size, n_heads):
        super(GATConv, self).__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.n_heads = n_heads
        # Shared linear projection W and the attention vector a, split into
        # its source (a_l) and destination (a_r) halves.
        self.W = nn.Parameter(torch.Tensor(in_size, out_size * n_heads))
        self.a_l = nn.Parameter(torch.Tensor(1, n_heads, out_size))
        self.a_r = nn.Parameter(torch.Tensor(1, n_heads, out_size))
        self.leaky_relu = nn.LeakyReLU(0.2)
        init.xavier_uniform_(self.W)
        init.xavier_uniform_(self.a_l)
        init.xavier_uniform_(self.a_r)
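
    # forward() implements, per attention head, the per-edge score
    #     e_uv = LeakyReLU(a_l . Wh_u + a_r . Wh_v),
    # normalizes the scores with a sparse softmax, and aggregates neighbor
    # features with a batched sparse-dense matmul (bspmm).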
    def forward(self, A, h):
        Wh = (h @ self.W).view(-1, self.n_heads, self.out_size)  # |V| x N_h x D_o
        Wh1 = (Wh * self.a_l).sum(2)  # |V| x N_h
        Wh2 = (Wh * self.a_r).sum(2)  # |V| x N_h
        Wh1 = Wh1[A.row, :]  # |E| x N_h
        Wh2 = Wh2[A.col, :]  # |E| x N_h
        e = Wh1 + Wh2  # |E| x N_h
        e = self.leaky_relu(e)  # |E| x N_h
        # Attach the per-edge, per-head scores to the graph structure.
        A = create_from_coo(A.row, A.col, e, A.shape)  # |V| x |V| x N_h SparseMatrix
        A_hat = softmax(A)  # |V| x |V| x N_h SparseMatrix
        # bspmm batches over the trailing (head) dimension, so move the heads
        # last with a transpose; a reshape here would scramble the features.
        Wh = Wh.transpose(1, 2)  # |V| x D_o x N_h
        h_prime = bspmm(A_hat, Wh)  # |V| x D_o x N_h
        return torch.relu(h_prime)


class GAT(nn.Module):
    def __init__(self, in_size, hidden_size, out_size, n_heads):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GATConv(in_size, hidden_size, n_heads))
        self.layers.append(GATConv(hidden_size * n_heads, out_size, n_heads))

    def forward(self, A, features):
        h = features
        for i, layer in enumerate(self.layers):
            h = layer(A, h)  # |V| x D_o x N_h
            if i == len(self.layers) - 1:  # last layer: average the heads
                h = h.mean(-1)
            else:  # hidden layer(s): concatenate the heads
                h = h.flatten(1)
        return h


def evaluate(A, features, labels, mask, model):
    model.eval()
    with torch.no_grad():
        logits = model(A, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def train(A, features, labels, masks, model):
    # Define train/val samples, loss function and optimizer.
    train_mask = masks[0]
    val_mask = masks[1]
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)

    # Training loop.
    for epoch in range(50):
        model.train()
        logits = model(A, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(A, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    args = parser.parse_args()
    print("Training with DGL SparseMatrix GATConv module.")

    # Load and preprocess the dataset.
    transform = (
        AddSelfLoop()
    )  # by default, it first removes existing self-loops to prevent duplication
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    g = g.int()
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]

    # Build the sparse adjacency matrix from the graph's COO representation.
    row, col = g.adj_sparse("coo")
    A = create_from_coo(row, col, shape=(g.number_of_nodes(), g.number_of_nodes()))

    # Create the GAT model (8 hidden units, 8 attention heads).
    in_size = features.shape[1]
    out_size = data.num_classes
    model = GAT(in_size, 8, out_size, 8)

    # Model training.
    print("Training...")
    train(A, features, labels, masks, model)

    # Test the model.
    print("Testing...")
    acc = evaluate(A, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(acc))
@@ -28,7 +28,14 @@ def _sparse_dense_mm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:
     torch.Tensor
         The result of multiplication
     """
-    return torch.matmul(A.adj, X)
+    is_one_dim = False
+    if len(X.shape) == 1:
+        is_one_dim = True
+        X = X.view(-1, 1)
+    ret = torch.sparse.mm(A.adj, X)
+    if is_one_dim:
+        ret = ret.view(-1)
+    return ret
 
 def _sparse_sparse_mm(A1: SparseMatrix, A2: SparseMatrix) -> SparseMatrix:
     """Internal function for multiplying a sparse matrix by a sparse matrix
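The hunk above exists because `torch.sparse.mm` expects a 2-D dense right-hand operand, so the new code promotes a 1-D vector to a one-column matrix and squeezes the result back. A minimal standalone sketch of that trick, with a plain PyTorch sparse COO tensor standing in for `A.adj` (the `SparseMatrix` wrapper is omitted here):

```python
import torch

# A 3x3 sparse COO matrix standing in for A.adj.
indices = torch.tensor([[0, 1, 2], [1, 2, 0]])
values = torch.tensor([1.0, 2.0, 3.0])
A = torch.sparse_coo_tensor(indices, values, (3, 3))

x = torch.tensor([1.0, 2.0, 3.0])  # 1-D input

# torch.sparse.mm needs a 2-D dense operand, so promote x to a
# single-column matrix, multiply, and squeeze the result back to 1-D.
y = torch.sparse.mm(A, x.view(-1, 1)).view(-1)
print(y)  # tensor([2., 6., 3.])
```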