Commit 91efc915 authored by rusty1s

add gcn example

parent e2d2af18
import argparse

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from torch_geometric_autoscale.models import GCN
from torch_geometric_autoscale import (get_data, metis, permute,
                                       SubgraphLoader, compute_acc)

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
                    help='Root directory of dataset storage.')
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()

torch.manual_seed(12345)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

data, in_channels, out_channels = get_data(args.root, name='cora')
# Pre-process adjacency matrix for GCN:
data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)

# Pre-partition the graph using METIS:
perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)

model = GCN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=16,
    out_channels=out_channels,
    num_layers=2,
    dropout=0.5,
    drop_input=True,
    pool_size=2,
    buffer_size=1000,
).to(device)

# Apply weight decay only to the regularized modules (first layer + norms):
optimizer = torch.optim.Adam([
    dict(params=model.reg_modules.parameters(), weight_decay=5e-4),
    dict(params=model.nonreg_modules.parameters(), weight_decay=0),
], lr=0.01)
criterion = torch.nn.CrossEntropyLoss()


def train(data, model, loader, optimizer):
    model.train()
    for batch, batch_size, n_id, offset, count in loader:
        batch = batch.to(device)
        train_mask = batch.train_mask[:batch_size]

        optimizer.zero_grad()
        out = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
        loss = criterion(out[train_mask], batch.y[:batch_size][train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()


@torch.no_grad()
def test(data, model):
    model.eval()
    out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
    train_acc = compute_acc(out, data.y, data.train_mask)
    val_acc = compute_acc(out, data.y, data.val_mask)
    test_acc = compute_acc(out, data.y, data.test_mask)
    return train_acc, val_acc, test_acc


test(data, model)  # Fill history.
best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train(data, model, loader, optimizer)
    train_acc, val_acc, tmp_test_acc = test(data, model)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {tmp_test_acc:.4f}, Final: {test_acc:.4f}')
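A minimal way to run this example (the file name is not shown on this page; assuming it is saved as gcn.py and that the Cora dataset can be downloaded into the given root directory):

    python gcn.py --root /tmp/datasets --device 0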
@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data

-relabel_fn = torch.ops.scaling_gnns.relabel_one_hop
+relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop


 class SubData(NamedTuple):
...
@@ -7,6 +7,8 @@ from torch import Tensor
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data

+partition_fn = torch.ops.torch_sparse.partition
+

 def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
           log: bool = True) -> Tuple[Tensor, Tensor]:
@@ -22,8 +24,7 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
         perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
     else:
         rowptr, col, _ = adj_t.csr()
-        cluster = torch.ops.torch_sparse.partition(rowptr, col, None,
-                                                   num_parts, recursive)
+        cluster = partition_fn(rowptr, col, None, num_parts, recursive)
         cluster, perm = cluster.sort()
         ptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
...
 from .base import ScalableGNN
 from .gcn import GCN
-from .gat import GAT
-from .appnp import APPNP
-from .gcn2 import GCN2
-from .gin import GIN
-from .pna import PNA
-from .pna_jk import PNA_JK
+# from .gat import GAT
+# from .appnp import APPNP
+# from .gcn2 import GCN2
+# from .gin import GIN
+# from .pna import PNA
+# from .pna_jk import PNA_JK

 __all__ = [
     'ScalableGNN',
     'GCN',
-    'GAT',
-    'APPNP',
-    'GCN2',
-    'GIN',
-    'PNA',
-    'PNA_JK',
+    # 'GAT',
+    # 'APPNP',
+    # 'GCN2',
+    # 'GIN',
+    # 'PNA',
+    # 'PNA_JK',
 ]
@@ -164,5 +164,5 @@ class ScalableGNN(torch.nn.Module):
     @torch.no_grad()
     def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
-                      state: Dict[Any]) -> Tensor:
+                      state: Dict[str, Any]) -> Tensor:
         raise NotImplementedError
-from typing import Optional, Dict, Any
+from typing import Optional

 import torch
 from torch import Tensor
 import torch.nn.functional as F
-from torch.nn import ModuleList, BatchNorm1d
+from torch.nn import ModuleList, Linear, BatchNorm1d
 from torch_sparse import SparseTensor
 from torch_geometric.nn import GCNConv

-from scaling_gnns.models.base2 import ScalableGNN
+from torch_geometric_autoscale.models import ScalableGNN


 class GCN(ScalableGNN):
     def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                  out_channels: int, num_layers: int, dropout: float = 0.0,
                  drop_input: bool = True, batch_norm: bool = False,
-                 residual: bool = False, pool_size: Optional[int] = None,
+                 residual: bool = False, linear: bool = False,
+                 pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
         super(GCN, self).__init__(num_nodes, hidden_channels, num_layers,
                                   pool_size, buffer_size, device)
@@ -25,29 +26,43 @@ class GCN(ScalableGNN):
         self.drop_input = drop_input
         self.batch_norm = batch_norm
         self.residual = residual
+        self.linear = linear

+        if linear:
+            self.lins = ModuleList()
+            self.lins.append(Linear(in_channels, hidden_channels))
+            self.lins.append(Linear(hidden_channels, out_channels))

         self.convs = ModuleList()
         for i in range(num_layers):
-            in_dim = in_channels if i == 0 else hidden_channels
-            out_dim = out_channels if i == num_layers - 1 else hidden_channels
+            in_dim = out_dim = hidden_channels
+            if i == 0 and not linear:
+                in_dim = in_channels
+            if i == num_layers - 1 and not linear:
+                out_dim = out_channels
             conv = GCNConv(in_dim, out_dim, normalize=False)
             self.convs.append(conv)

         self.bns = ModuleList()
-        for i in range(num_layers - 1):
+        for i in range(num_layers):
             bn = BatchNorm1d(hidden_channels)
             self.bns.append(bn)

     @property
     def reg_modules(self):
-        return ModuleList(list(self.convs[:-1]) + list(self.bns))
+        if self.linear:
+            return ModuleList(list(self.convs) + list(self.bns))
+        else:
+            return ModuleList(list(self.convs[:-1]) + list(self.bns))

     @property
     def nonreg_modules(self):
-        return self.convs[-1:]
+        return self.lins if self.linear else self.convs[-1:]

     def reset_parameters(self):
         super(GCN, self).reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for bn in self.bns:
@@ -61,6 +76,10 @@ class GCN(ScalableGNN):
         if self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)

+        if self.linear:
+            x = self.lins[0](x).relu_()
+            x = F.dropout(x, p=self.dropout, training=self.training)
+
         for conv, bn, hist in zip(self.convs[:-1], self.bns, self.histories):
             h = conv(x, adj_t)
             if self.batch_norm:
@@ -71,23 +90,41 @@ class GCN(ScalableGNN):
             x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
             x = F.dropout(x, p=self.dropout, training=self.training)

-        x = self.convs[-1](x, adj_t)
-        return x
+        h = self.convs[-1](x, adj_t)
+
+        if not self.linear:
+            return h
+
+        if self.batch_norm:
+            h = self.bns[-1](h)
+        if self.residual and h.size(-1) == x.size(-1):
+            h += x[:h.size(0)]
+        h = h.relu_()
+        h = F.dropout(h, p=self.dropout, training=self.training)
+        return self.lins[1](h)

     @torch.no_grad()
-    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
-                      state: Dict[Any]) -> Tensor:
-        if layer == 0 and self.drop_input:
-            x = F.dropout(x, p=self.dropout, training=self.training)
+    def forward_layer(self, layer, x, adj_t, state):
+        if layer == 0:
+            if self.drop_input:
+                x = F.dropout(x, p=self.dropout, training=self.training)
+            if self.linear:
+                x = self.lins[0](x).relu_()
+                x = F.dropout(x, p=self.dropout, training=self.training)
         else:
             x = F.dropout(x, p=self.dropout, training=self.training)

         h = self.convs[layer](x, adj_t)

-        if layer < self.num_layers - 1:
+        if layer < self.num_layers - 1 or self.linear:
             if self.batch_norm:
                 h = self.bns[layer](h)
             if self.residual and h.size(-1) == x.size(-1):
                 h += x[:h.size(0)]
             h = h.relu_()

+        if self.linear:
+            h = F.dropout(h, p=self.dropout, training=self.training)
+            h = self.lins[1](h)
+
         return h
@@ -4,9 +4,9 @@ import torch
 from torch import Tensor
 from torch.cuda import Stream

-synchronize = torch.ops.scaling_gnns.synchronize
-read_async = torch.ops.scaling_gnns.read_async
-write_async = torch.ops.scaling_gnns.write_async
+synchronize = torch.ops.torch_geometric_autoscale.synchronize
+read_async = torch.ops.torch_geometric_autoscale.read_async
+write_async = torch.ops.torch_geometric_autoscale.write_async


 class AsyncIOPool(torch.nn.Module):
...