Commit 91efc915 authored by rusty1s

add gcn example

parent e2d2af18
import argparse
import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric_autoscale.models import GCN
from torch_geometric_autoscale import (get_data, metis, permute,
                                       SubgraphLoader, compute_acc)
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
                    help='Root directory of dataset storage.')
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()
torch.manual_seed(12345)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
data, in_channels, out_channels = get_data(args.root, name='cora')
# Pre-process adjacency matrix for GCN:
data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)
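# gcn_norm adds self-loops and applies the symmetric normalization
# D^{-1/2} (A + I) D^{-1/2} once up front; the GCN model below builds its
# GCNConv layers with normalize=False, so this work is not repeated in every
# layer and epoch.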
# Pre-partition the graph using Metis:
perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)
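# The loader yields mini-batches covering one or a few METIS partitions plus
# their one-hop neighbors: the first 'batch_size' rows are the in-batch nodes,
# 'n_id' maps back to global node indices, and 'offset'/'count' locate the
# in-batch nodes in the history buffers (used by the model's push/pull step).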
model = GCN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=16,
    out_channels=out_channels,
    num_layers=2,
    dropout=0.5,
    drop_input=True,
    pool_size=2,
    buffer_size=1000,
).to(device)
optimizer = torch.optim.Adam([
    dict(params=model.reg_modules.parameters(), weight_decay=5e-4),
    dict(params=model.nonreg_modules.parameters(), weight_decay=0)
], lr=0.01)
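# Weight decay is applied only to the 'regularized' modules (all but the final
# convolution, plus the batch norms); the final prediction layer is trained
# without weight decay.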
criterion = torch.nn.CrossEntropyLoss()


def train(data, model, loader, optimizer):
    model.train()
    for batch, batch_size, n_id, offset, count in loader:
        batch = batch.to(device)
        train_mask = batch.train_mask[:batch_size]
        optimizer.zero_grad()
        out = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
        loss = criterion(out[train_mask], batch.y[:batch_size][train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()


@torch.no_grad()
def test(data, model):
    model.eval()
    out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
    train_acc = compute_acc(out, data.y, data.train_mask)
    val_acc = compute_acc(out, data.y, data.val_mask)
    test_acc = compute_acc(out, data.y, data.test_mask)
    return train_acc, val_acc, test_acc

test(data, model) # Fill history.
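# The initial full-graph inference above writes every node's intermediate
# embeddings into the model's history buffers, so the first mini-batch epoch
# can already pull (approximate) embeddings for out-of-batch neighbors.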
best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train(data, model, loader, optimizer)
    train_acc, val_acc, tmp_test_acc = test(data, model)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {tmp_test_acc:.4f}, Final: {test_acc:.4f}')
@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data

-relabel_fn = torch.ops.scaling_gnns.relabel_one_hop
+relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop


 class SubData(NamedTuple):
......
@@ -7,6 +7,8 @@ from torch import Tensor
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data

+partition_fn = torch.ops.torch_sparse.partition
+

 def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
           log: bool = True) -> Tuple[Tensor, Tensor]:
@@ -22,8 +24,7 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
         perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
     else:
         rowptr, col, _ = adj_t.csr()
-        cluster = torch.ops.torch_sparse.partition(rowptr, col, None,
-                                                   num_parts, recursive)
+        cluster = partition_fn(rowptr, col, None, num_parts, recursive)
         cluster, perm = cluster.sort()
         ptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
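For intuition on the perm/ptr pair built here: sorting the METIS cluster assignment groups each partition's nodes into one contiguous range, and ind2ptr turns the sorted assignment into CSR-style boundaries. A small hand-made illustration (values invented for this note, not taken from the repository):

import torch

cluster = torch.tensor([1, 0, 1, 0, 2])  # partition id of each node
cluster, perm = cluster.sort()           # perm holds the matching original node indices
# cluster is now [0, 0, 1, 1, 2]; ind2ptr over it gives ptr = [0, 2, 4, 5],
# i.e. partition i owns the permuted nodes perm[ptr[i]:ptr[i + 1]].
ptr = torch.tensor([0, 2, 4, 5])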
......
 from .base import ScalableGNN
 from .gcn import GCN
-from .gat import GAT
-from .appnp import APPNP
-from .gcn2 import GCN2
-from .gin import GIN
-from .pna import PNA
-from .pna_jk import PNA_JK
+# from .gat import GAT
+# from .appnp import APPNP
+# from .gcn2 import GCN2
+# from .gin import GIN
+# from .pna import PNA
+# from .pna_jk import PNA_JK

 __all__ = [
     'ScalableGNN',
     'GCN',
-    'GAT',
-    'APPNP',
-    'GCN2',
-    'GIN',
-    'PNA',
-    'PNA_JK',
+    # 'GAT',
+    # 'APPNP',
+    # 'GCN2',
+    # 'GIN',
+    # 'PNA',
+    # 'PNA_JK',
 ]
@@ -164,5 +164,5 @@ class ScalableGNN(torch.nn.Module):
     @torch.no_grad()
     def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
-                      state: Dict[Any]) -> Tensor:
+                      state: Dict[str, Any]) -> Tensor:
         raise NotImplementedError
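For context on this signature: forward_layer lets a subclass evaluate a single layer in isolation, which is what allows ScalableGNN to run full-graph inference one layer at a time instead of materializing the whole network at once. A minimal sketch of that pattern, assuming only the signature shown above plus a num_layers attribute (the helper name layerwise_inference and the loop are illustrative, not the library's actual inference code):

import torch


@torch.no_grad()
def layerwise_inference(model, x, adj_t):
    # Evaluate one layer over the full graph before moving on to the next;
    # 'state' is a scratch dict shared across the per-layer calls.
    state = {}
    for layer in range(model.num_layers):
        x = model.forward_layer(layer, x, adj_t, state)
    return x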
-from typing import Optional, Dict, Any
+from typing import Optional

 import torch
 from torch import Tensor
 import torch.nn.functional as F
-from torch.nn import ModuleList, BatchNorm1d
+from torch.nn import ModuleList, Linear, BatchNorm1d
 from torch_sparse import SparseTensor
 from torch_geometric.nn import GCNConv

-from scaling_gnns.models.base2 import ScalableGNN
+from torch_geometric_autoscale.models import ScalableGNN


 class GCN(ScalableGNN):
     def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                  out_channels: int, num_layers: int, dropout: float = 0.0,
                  drop_input: bool = True, batch_norm: bool = False,
-                 residual: bool = False, pool_size: Optional[int] = None,
+                 residual: bool = False, linear: bool = False,
+                 pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
         super(GCN, self).__init__(num_nodes, hidden_channels, num_layers,
                                   pool_size, buffer_size, device)
@@ -25,29 +26,43 @@ class GCN(ScalableGNN):
         self.drop_input = drop_input
         self.batch_norm = batch_norm
         self.residual = residual
+        self.linear = linear
+
+        if linear:
+            self.lins = ModuleList()
+            self.lins.append(Linear(in_channels, hidden_channels))
+            self.lins.append(Linear(hidden_channels, out_channels))

         self.convs = ModuleList()
         for i in range(num_layers):
-            in_dim = in_channels if i == 0 else hidden_channels
-            out_dim = out_channels if i == num_layers - 1 else hidden_channels
+            in_dim = out_dim = hidden_channels
+            if i == 0 and not linear:
+                in_dim = in_channels
+            if i == num_layers - 1 and not linear:
+                out_dim = out_channels
             conv = GCNConv(in_dim, out_dim, normalize=False)
             self.convs.append(conv)

         self.bns = ModuleList()
-        for i in range(num_layers - 1):
+        for i in range(num_layers):
             bn = BatchNorm1d(hidden_channels)
             self.bns.append(bn)

     @property
     def reg_modules(self):
-        return ModuleList(list(self.convs[:-1]) + list(self.bns))
+        if self.linear:
+            return ModuleList(list(self.convs) + list(self.bns))
+        else:
+            return ModuleList(list(self.convs[:-1]) + list(self.bns))

     @property
     def nonreg_modules(self):
-        return self.convs[-1:]
+        return self.lins if self.linear else self.convs[-1:]

     def reset_parameters(self):
         super(GCN, self).reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for bn in self.bns:
@@ -61,6 +76,10 @@ class GCN(ScalableGNN):
         if self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)

+        if self.linear:
+            x = self.lins[0](x).relu_()
+            x = F.dropout(x, p=self.dropout, training=self.training)
+
         for conv, bn, hist in zip(self.convs[:-1], self.bns, self.histories):
             h = conv(x, adj_t)
             if self.batch_norm:
@@ -71,23 +90,41 @@ class GCN(ScalableGNN):
             x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
             x = F.dropout(x, p=self.dropout, training=self.training)

-        x = self.convs[-1](x, adj_t)
-        return x
+        h = self.convs[-1](x, adj_t)
+
+        if not self.linear:
+            return h
+
+        if self.batch_norm:
+            h = self.bns[-1](h)
+        if self.residual and h.size(-1) == x.size(-1):
+            h += x[:h.size(0)]
+        h = h.relu_()
+        h = F.dropout(h, p=self.dropout, training=self.training)
+        return self.lins[1](h)

     @torch.no_grad()
-    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
-                      state: Dict[Any]) -> Tensor:
-        if layer == 0 and self.drop_input:
-            x = F.dropout(x, p=self.dropout, training=self.training)
+    def forward_layer(self, layer, x, adj_t, state):
+        if layer == 0:
+            if self.drop_input:
+                x = F.dropout(x, p=self.dropout, training=self.training)
+            if self.linear:
+                x = self.lins[0](x).relu_()
+                x = F.dropout(x, p=self.dropout, training=self.training)
+        else:
+            x = F.dropout(x, p=self.dropout, training=self.training)

         h = self.convs[layer](x, adj_t)

-        if layer < self.num_layers - 1:
+        if layer < self.num_layers - 1 or self.linear:
             if self.batch_norm:
                 h = self.bns[layer](h)
             if self.residual and h.size(-1) == x.size(-1):
                 h += x[:h.size(0)]
             h = h.relu_()

+        if self.linear:
+            h = F.dropout(h, p=self.dropout, training=self.training)
+            h = self.lins[1](h)
+
         return h
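The new linear flag wraps the convolutions with an input and an output Linear projection, so every GCNConv maps hidden_channels to hidden_channels, and batch norm, the residual connection and ReLU can also be applied after the final convolution before the output projection. A hedged instantiation example, reusing data, in_channels, out_channels and device from the example script above (the hyper-parameter values are made up for illustration):

model = GCN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=256,
    out_channels=out_channels,
    num_layers=3,
    dropout=0.5,
    drop_input=True,
    batch_norm=True,
    residual=True,
    linear=True,        # adds Linear(in, hidden) and Linear(hidden, out)
    pool_size=2,
    buffer_size=5000,
).to(device)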
@@ -4,9 +4,9 @@ import torch
 from torch import Tensor
 from torch.cuda import Stream

-synchronize = torch.ops.scaling_gnns.synchronize
-read_async = torch.ops.scaling_gnns.read_async
-write_async = torch.ops.scaling_gnns.write_async
+synchronize = torch.ops.torch_geometric_autoscale.synchronize
+read_async = torch.ops.torch_geometric_autoscale.read_async
+write_async = torch.ops.torch_geometric_autoscale.write_async


 class AsyncIOPool(torch.nn.Module):
......
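The torch.ops.torch_geometric_autoscale.* bindings above resolve custom C++/CUDA operators, which must be registered before first use; packages typically do this by loading a compiled extension at import time. A generic sketch of that mechanism (the shared-library path is a placeholder, not this package's actual build layout):

import torch

# Placeholder path: the real extension is compiled and registered by the
# package itself when it is imported.
torch.ops.load_library('path/to/lib_torch_geometric_autoscale_ops.so')

synchronize = torch.ops.torch_geometric_autoscale.synchronize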