Unverified Commit 81915f55 authored by Krzysztof Sadowski, committed by GitHub

[Examples] RGCN Heterogeneous on ogbn-mag (#3371)



* upload

* cleanup of unused code

* default gpu training/inference

* layer norm instead of batch norm

* fix for default inference mode

* simplified embedding forward method
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 2150fcaf
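# --- Training script (entrypoint; its filename is not shown in this extract).
# --- It imports model.py and utils.py, which follow below.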
import argparse
from itertools import chain
from timeit import default_timer
from typing import Callable, Tuple, Union
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from model import EntityClassify, RelGraphEmbedding
def train(
embedding_layer: nn.Module,
model: nn.Module,
device: Union[str, torch.device],
embedding_optimizer: torch.optim.Optimizer,
model_optimizer: torch.optim.Optimizer,
loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
labels: torch.Tensor,
predict_category: str,
dataloader: dgl.dataloading.NodeDataLoader,
) -> Tuple[float, float, float]:
embedding_layer.train()
model.train()
total_loss = 0
total_accuracy = 0
start = default_timer()
embedding_layer = embedding_layer.to(device)
model = model.to(device)
loss_function = loss_function.to(device)
for step, (in_nodes, out_nodes, blocks) in enumerate(dataloader):
embedding_optimizer.zero_grad()
model_optimizer.zero_grad()
in_nodes = {rel: nid.to(device) for rel, nid in in_nodes.items()}
out_nodes = out_nodes[predict_category].to(device)
blocks = [block.to(device) for block in blocks]
batch_labels = labels[out_nodes].to(device)
embedding = embedding_layer(in_nodes=in_nodes, device=device)
logits = model(blocks, embedding)[predict_category]
loss = loss_function(logits, batch_labels)
indices = logits.argmax(dim=-1)
correct = torch.sum(indices == batch_labels)
accuracy = correct.item() / len(batch_labels)
loss.backward()
model_optimizer.step()
embedding_optimizer.step()
total_loss += loss.item()
total_accuracy += accuracy
stop = default_timer()
time = stop - start
total_loss /= step + 1
total_accuracy /= step + 1
return time, total_loss, total_accuracy
def validate(
embedding_layer: nn.Module,
model: nn.Module,
device: Union[str, torch.device],
inference_mode: str,
loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
hg: dgl.DGLHeteroGraph,
labels: torch.Tensor,
predict_category: str,
dataloader: dgl.dataloading.NodeDataLoader = None,
eval_batch_size: int = None,
eval_num_workers: int = None,
mask: torch.Tensor = None,
) -> Tuple[float, float, float]:
embedding_layer.eval()
model.eval()
start = default_timer()
embedding_layer = embedding_layer.to(device)
model = model.to(device)
loss_function = loss_function.to(device)
valid_labels = labels[mask].to(device)
with torch.no_grad():
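# Three evaluation modes: mini-batch neighbor sampling (same fanouts as
# training), layer-wise full-neighbor inference via model.inference, and a
# single full-graph forward pass.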
if inference_mode == 'neighbor_sampler':
total_loss = 0
total_accuracy = 0
for step, (in_nodes, out_nodes, blocks) in enumerate(dataloader):
in_nodes = {rel: nid.to(device)
for rel, nid in in_nodes.items()}
out_nodes = out_nodes[predict_category].to(device)
blocks = [block.to(device) for block in blocks]
batch_labels = labels[out_nodes].to(device)
embedding = embedding_layer(in_nodes=in_nodes, device=device)
logits = model(blocks, embedding)[predict_category]
loss = loss_function(logits, batch_labels)
indices = logits.argmax(dim=-1)
correct = torch.sum(indices == batch_labels)
accuracy = correct.item() / len(batch_labels)
total_loss += loss.item()
total_accuracy += accuracy
total_loss /= step + 1
total_accuracy /= step + 1
elif inference_mode == 'full_neighbor_sampler':
logits = model.inference(
hg,
eval_batch_size,
eval_num_workers,
embedding_layer,
device,
)[predict_category][mask]
total_loss = loss_function(logits, valid_labels)
indices = logits.argmax(dim=-1)
correct = torch.sum(indices == valid_labels)
total_accuracy = correct.item() / len(valid_labels)
total_loss = total_loss.item()
else:
embedding = embedding_layer(device=device)
logits = model(hg, embedding)[predict_category][mask]
total_loss = loss_function(logits, valid_labels)
indices = logits.argmax(dim=-1)
correct = torch.sum(indices == valid_labels)
total_accuracy = correct.item() / len(valid_labels)
total_loss = total_loss.item()
stop = default_timer()
time = stop - start
return time, total_loss, total_accuracy
def run(args: argparse.Namespace) -> None:
torch.manual_seed(args.seed)
dataset, hg, train_idx, valid_idx, test_idx = utils.process_dataset(
args.dataset,
root=args.dataset_root,
)
predict_category = dataset.predict_category
labels = hg.nodes[predict_category].data['labels']
training_device = torch.device('cuda' if args.gpu_training else 'cpu')
inference_device = torch.device('cuda' if args.gpu_inference else 'cpu')
inference_mode = args.inference_mode
fanouts = [int(fanout) for fanout in args.fanouts.split(',')]
train_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
train_dataloader = dgl.dataloading.NodeDataLoader(
hg,
{predict_category: train_idx},
train_sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers,
)
if inference_mode == 'neighbor_sampler':
valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
valid_dataloader = dgl.dataloading.NodeDataLoader(
hg,
{predict_category: valid_idx},
valid_sampler,
batch_size=args.eval_batch_size,
shuffle=False,
drop_last=False,
num_workers=args.eval_num_workers,
)
if args.test_validation:
test_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
test_dataloader = dgl.dataloading.NodeDataLoader(
hg,
{predict_category: test_idx},
test_sampler,
batch_size=args.eval_batch_size,
shuffle=False,
drop_last=False,
num_workers=args.eval_num_workers,
)
else:
valid_dataloader = None
if args.test_validation:
test_dataloader = None
in_feats = hg.nodes[predict_category].data['feat'].shape[-1]
out_feats = dataset.num_classes
num_nodes = {}
node_feats = {}
for ntype in hg.ntypes:
num_nodes[ntype] = hg.num_nodes(ntype)
node_feats[ntype] = hg.nodes[ntype].data.get('feat')
activations = {'leaky_relu': F.leaky_relu, 'relu': F.relu}
# Pass the projection flag through so embedding_layer.embeddings exists
# whenever --node-feats-projection is set (it is read below).
embedding_layer = RelGraphEmbedding(hg, in_feats, num_nodes, node_feats,
args.node_feats_projection)
model = EntityClassify(
hg,
in_feats,
args.hidden_feats,
out_feats,
args.num_bases,
args.num_layers,
norm=args.norm,
layer_norm=args.layer_norm,
input_dropout=args.input_dropout,
dropout=args.dropout,
activation=activations[args.activation],
self_loop=args.self_loop,
)
loss_function = nn.CrossEntropyLoss()
embedding_optimizer = torch.optim.SparseAdam(
embedding_layer.node_embeddings.parameters(), lr=args.embedding_lr)
if args.node_feats_projection:
all_parameters = chain(
model.parameters(), embedding_layer.embeddings.parameters())
model_optimizer = torch.optim.Adam(all_parameters, lr=args.model_lr)
else:
model_optimizer = torch.optim.Adam(
model.parameters(), lr=args.model_lr)
checkpoint = utils.Callback(args.early_stopping_patience,
args.early_stopping_monitor)
print('## Training started ##')
for epoch in range(args.num_epochs):
train_time, train_loss, train_accuracy = train(
embedding_layer,
model,
training_device,
embedding_optimizer,
model_optimizer,
loss_function,
labels,
predict_category,
train_dataloader,
)
valid_time, valid_loss, valid_accuracy = validate(
embedding_layer,
model,
inference_device,
inference_mode,
loss_function,
hg,
labels,
predict_category=predict_category,
dataloader=valid_dataloader,
eval_batch_size=args.eval_batch_size,
eval_num_workers=args.eval_num_workers,
mask=valid_idx,
)
checkpoint.create(
epoch,
train_time,
valid_time,
train_loss,
valid_loss,
train_accuracy,
valid_accuracy,
{'embedding_layer': embedding_layer, 'model': model},
)
print(
f'Epoch: {epoch + 1:03} '
f'Train Loss: {train_loss:.2f} '
f'Valid Loss: {valid_loss:.2f} '
f'Train Accuracy: {train_accuracy:.4f} '
f'Valid Accuracy: {valid_accuracy:.4f} '
f'Train Epoch Time: {train_time:.2f} '
f'Valid Epoch Time: {valid_time:.2f}'
)
if checkpoint.should_stop:
print('## Training finished: early stopping ##')
break
elif epoch >= args.num_epochs - 1:
print('## Training finished ##')
print(
f'Best Epoch: {checkpoint.best_epoch} '
f'Train Loss: {checkpoint.best_epoch_train_loss:.2f} '
f'Valid Loss: {checkpoint.best_epoch_valid_loss:.2f} '
f'Train Accuracy: {checkpoint.best_epoch_train_accuracy:.4f} '
f'Valid Accuracy: {checkpoint.best_epoch_valid_accuracy:.4f}'
)
if args.test_validation:
print('## Test data validation ##')
embedding_layer.load_state_dict(
checkpoint.best_epoch_model_parameters['embedding_layer'])
model.load_state_dict(checkpoint.best_epoch_model_parameters['model'])
test_time, test_loss, test_accuracy = validate(
embedding_layer,
model,
inference_device,
inference_mode,
loss_function,
hg,
labels,
predict_category=predict_category,
dataloader=test_dataloader,
eval_batch_size=args.eval_batch_size,
eval_num_workers=args.eval_num_workers,
mask=test_idx,
)
print(
f'Test Loss: {test_loss:.2f} '
f'Test Accuracy: {test_accuracy:.4f} '
f'Test Epoch Time: {test_time:.2f}'
)
if __name__ == '__main__':
argparser = argparse.ArgumentParser('RGCN')
argparser.add_argument('--gpu-training', dest='gpu_training',
action='store_true')
argparser.add_argument('--no-gpu-training', dest='gpu_training',
action='store_false')
argparser.set_defaults(gpu_training=True)
argparser.add_argument('--gpu-inference', dest='gpu_inference',
action='store_true')
argparser.add_argument('--no-gpu-inference', dest='gpu_inference',
action='store_false')
argparser.set_defaults(gpu_inference=True)
argparser.add_argument('--inference-mode', default='neighbor_sampler', type=str,
choices=['neighbor_sampler', 'full_neighbor_sampler', 'full_graph'])
argparser.add_argument('--dataset', default='ogbn-mag', type=str,
choices=['ogbn-mag'])
argparser.add_argument('--dataset-root', default='dataset', type=str)
argparser.add_argument('--num-epochs', default=500, type=int)
argparser.add_argument('--embedding-lr', default=0.01, type=float)
argparser.add_argument('--model-lr', default=0.01, type=float)
argparser.add_argument('--node-feats-projection',
dest='node_feats_projection', action='store_true')
argparser.add_argument('--no-node-feats-projection',
dest='node_feats_projection', action='store_false')
argparser.set_defaults(node_feats_projection=False)
argparser.add_argument('--hidden-feats', default=64, type=int)
argparser.add_argument('--num-bases', default=2, type=int)
argparser.add_argument('--num-layers', default=2, type=int)
argparser.add_argument('--norm', default='right',
type=str, choices=['both', 'none', 'right'])
argparser.add_argument('--layer-norm', dest='layer_norm',
action='store_true')
argparser.add_argument('--no-layer-norm', dest='layer_norm',
action='store_false')
argparser.set_defaults(layer_norm=False)
argparser.add_argument('--input-dropout', default=0.1, type=float)
argparser.add_argument('--dropout', default=0.5, type=float)
argparser.add_argument('--activation', default='relu', type=str,
choices=['leaky_relu', 'relu'])
argparser.add_argument('--self-loop', dest='self_loop',
action='store_true')
argparser.add_argument('--no-self-loop', dest='self_loop',
action='store_false')
argparser.set_defaults(self_loop=True)
argparser.add_argument('--fanouts', default='25,20', type=str)
argparser.add_argument('--batch-size', default=1024, type=int)
argparser.add_argument('--eval-batch-size', default=1024, type=int)
argparser.add_argument('--num-workers', default=4, type=int)
argparser.add_argument('--eval-num-workers', default=4, type=int)
argparser.add_argument('--early-stopping-patience', default=10, type=int)
argparser.add_argument('--early-stopping-monitor', default='loss',
type=str, choices=['accuracy', 'loss'])
argparser.add_argument('--test-validation', dest='test_validation',
action='store_true')
argparser.add_argument('--no-test-validation', dest='test_validation',
action='store_false')
argparser.set_defaults(test_validation=True)
argparser.add_argument('--seed', default=13, type=int)
args = argparser.parse_args()
run(args)
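# --- model.py: RelGraphEmbedding, RelGraphConvLayer, and EntityClassify ---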
from typing import Callable, Dict, List, Union
import dgl
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
class RelGraphEmbedding(nn.Module):
def __init__(
self,
hg: dgl.DGLHeteroGraph,
embedding_size: int,
num_nodes: Dict[str, int],
node_feats: Dict[str, torch.Tensor],
node_feats_projection: bool = False,
):
super().__init__()
self._hg = hg
self._node_feats = node_feats
self._node_feats_projection = node_feats_projection
self.node_embeddings = nn.ModuleDict()
if node_feats_projection:
self.embeddings = nn.ParameterDict()
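# Node types without input features get a learnable sparse embedding table;
# node types with features are used as-is, or linearly projected into the
# shared embedding space when node_feats_projection is enabled.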
for ntype in hg.ntypes:
if node_feats[ntype] is None:
node_embedding = nn.Embedding(
num_nodes[ntype], embedding_size, sparse=True)
nn.init.uniform_(node_embedding.weight, -1, 1)
self.node_embeddings[ntype] = node_embedding
elif node_feats[ntype] is not None and node_feats_projection:
input_embedding_size = node_feats[ntype].shape[-1]
embedding = nn.Parameter(torch.Tensor(
input_embedding_size, embedding_size))
nn.init.xavier_uniform_(embedding)
self.embeddings[ntype] = embedding
def forward(
self,
in_nodes: Dict[str, torch.Tensor] = None,
device: torch.device = None,
) -> Dict[str, torch.Tensor]:
if in_nodes is not None:
ntypes = list(in_nodes.keys())
nids = list(in_nodes.values())
else:
ntypes = self._hg.ntypes
nids = [self._hg.nodes(ntype) for ntype in ntypes]
x = {}
for ntype, nid in zip(ntypes, nids):
if self._node_feats[ntype] is None:
x[ntype] = self.node_embeddings[ntype](nid)
else:
if device is not None:
self._node_feats[ntype] = self._node_feats[ntype].to(
device)
if self._node_feats_projection:
x[ntype] = self._node_feats[ntype][nid] @ self.embeddings[ntype]
else:
x[ntype] = self._node_feats[ntype][nid]
return x
class RelGraphConvLayer(nn.Module):
def __init__(
self,
in_feats: int,
out_feats: int,
rel_names: List[str],
num_bases: int,
norm: str = 'right',
weight: bool = True,
bias: bool = True,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
dropout: float = None,
self_loop: bool = False,
):
super().__init__()
self._rel_names = rel_names
self._num_rels = len(rel_names)
self._conv = dglnn.HeteroGraphConv({rel: dglnn.GraphConv(
in_feats, out_feats, norm=norm, weight=False, bias=False) for rel in rel_names})
self._use_weight = weight
self._use_basis = num_bases < self._num_rels and weight
self._use_bias = bias
self._activation = activation
self._dropout = nn.Dropout(dropout) if dropout is not None else None
self._use_self_loop = self_loop
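# Basis decomposition as in the RGCN paper: with num_bases < num_rels, each
# per-relation weight matrix is a learned combination of shared basis
# matrices (dglnn.WeightBasis), limiting parameter growth in the number of
# relations.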
if weight:
if self._use_basis:
self.basis = dglnn.WeightBasis(
(in_feats, out_feats), num_bases, self._num_rels)
else:
self.weight = nn.Parameter(torch.Tensor(
self._num_rels, in_feats, out_feats))
nn.init.xavier_uniform_(
self.weight, gain=nn.init.calculate_gain('relu'))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_feats))
nn.init.zeros_(self.bias)
if self_loop:
self.self_loop_weight = nn.Parameter(
torch.Tensor(in_feats, out_feats))
nn.init.xavier_uniform_(
self.self_loop_weight, gain=nn.init.calculate_gain('relu'))
def _apply_layers(
self,
ntype: str,
inputs: torch.Tensor,
inputs_dst: Dict[str, torch.Tensor] = None,
) -> torch.Tensor:
x = inputs
if inputs_dst is not None:
x += torch.matmul(inputs_dst[ntype], self.self_loop_weight)
if self._use_bias:
x += self.bias
if self._activation is not None:
x = self._activation(x)
if self._dropout is not None:
x = self._dropout(x)
return x
def forward(
self,
hg: dgl.DGLHeteroGraph,
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
hg = hg.local_var()
if self._use_weight:
weight = self.basis() if self._use_basis else self.weight
weight_dict = {self._rel_names[i]: {'weight': w.squeeze(
dim=0)} for i, w in enumerate(torch.split(weight, 1, dim=0))}
else:
weight_dict = {}
if self._use_self_loop:
if hg.is_block:
inputs_dst = {ntype: h[:hg.num_dst_nodes(
ntype)] for ntype, h in inputs.items()}
else:
inputs_dst = inputs
else:
inputs_dst = None
x = self._conv(hg, inputs, mod_kwargs=weight_dict)
x = {ntype: self._apply_layers(ntype, h, inputs_dst)
for ntype, h in x.items()}
return x
class EntityClassify(nn.Module):
def __init__(
self,
hg: dgl.DGLHeteroGraph,
in_feats: int,
hidden_feats: int,
out_feats: int,
num_bases: int,
num_layers: int,
norm: str = 'right',
layer_norm: bool = False,
input_dropout: float = 0,
dropout: float = 0,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
self_loop: bool = False,
):
super().__init__()
self._hidden_feats = hidden_feats
self._out_feats = out_feats
self._num_layers = num_layers
self._input_dropout = nn.Dropout(input_dropout)
self._dropout = nn.Dropout(dropout)
self._activation = activation
self._rel_names = sorted(list(set(hg.etypes)))
self._num_rels = len(self._rel_names)
if num_bases < 0 or num_bases > self._num_rels:
self._num_bases = self._num_rels
else:
self._num_bases = num_bases
self._layers = nn.ModuleList()
self._layers.append(RelGraphConvLayer(
in_feats,
hidden_feats,
self._rel_names,
self._num_bases,
norm=norm,
self_loop=self_loop,
))
for _ in range(1, num_layers - 1):
self._layers.append(RelGraphConvLayer(
hidden_feats,
hidden_feats,
self._rel_names,
self._num_bases,
norm=norm,
self_loop=self_loop,
))
self._layers.append(RelGraphConvLayer(
hidden_feats,
out_feats,
self._rel_names,
self._num_bases,
norm=norm,
self_loop=self_loop,
))
if layer_norm:
self._layer_norms = nn.ModuleList()
for _ in range(num_layers - 1):
self._layer_norms.append(nn.LayerNorm(hidden_feats))
else:
self._layer_norms = None
def _apply_layers(
self,
layer_idx: int,
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
x = inputs
for ntype, h in x.items():
if self._layer_norms is not None:
h = self._layer_norms[layer_idx](h)
if self._activation is not None:
h = self._activation(h)
x[ntype] = self._dropout(h)
return x
def forward(
self,
hg: Union[dgl.DGLHeteroGraph, List[dgl.DGLHeteroGraph]],
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
x = {ntype: self._input_dropout(h) for ntype, h in inputs.items()}
if isinstance(hg, list):
for i, (layer, block) in enumerate(zip(self._layers, hg)):
x = layer(block, x)
if i < self._num_layers - 1:
x = self._apply_layers(i, x)
else:
for i, layer in enumerate(self._layers):
x = layer(hg, x)
if i < self._num_layers - 1:
x = self._apply_layers(i, x)
return x
def inference(
self,
hg: dgl.DGLHeteroGraph,
batch_size: int,
num_workers: int,
embedding_layer: nn.Module,
device: torch.device,
) -> Dict[str, torch.Tensor]:
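# Layer-wise offline inference: compute representations for all nodes one
# layer at a time with single-hop full-neighbor sampling, avoiding the
# exponential neighborhood blow-up of sampling all layers at once.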
for i, layer in enumerate(self._layers):
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
hg,
{ntype: hg.nodes(ntype) for ntype in hg.ntypes},
sampler,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=num_workers,
)
if i < self._num_layers - 1:
y = {ntype: torch.zeros(hg.num_nodes(
ntype), self._hidden_feats, device=device) for ntype in hg.ntypes}
else:
y = {ntype: torch.zeros(hg.num_nodes(
ntype), self._out_feats, device=device) for ntype in hg.ntypes}
for in_nodes, out_nodes, blocks in dataloader:
in_nodes = {rel: nid.to(device)
for rel, nid in in_nodes.items()}
out_nodes = {rel: nid.to(device)
for rel, nid in out_nodes.items()}
block = blocks[0].to(device)
if i == 0:
h = embedding_layer(in_nodes=in_nodes, device=device)
else:
h = {ntype: x[ntype][in_nodes[ntype]]
for ntype in hg.ntypes}
h = layer(block, h)
if i < self._num_layers - 1:
h = self._apply_layers(i, h)
for ntype in h:
y[ntype][out_nodes[ntype]] = h[ntype]
x = y
return x
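# --- utils.py: Callback (checkpointing / early stopping) and ogbn-mag dataset loading ---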
from copy import deepcopy
from typing import Dict, List, Tuple, Union
import dgl
import torch
import torch.nn as nn
from ogb.nodeproppred import DglNodePropPredDataset
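# Records per-epoch metrics, snapshots the best model's parameters, and
# signals early stopping once the monitored metric has not improved for
# `patience` consecutive epochs.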
class Callback:
def __init__(
self,
patience: int,
monitor: str,
) -> None:
self._patience = patience
self._monitor = monitor
self._lookback = 0
self._best_epoch = None
self._train_times = []
self._valid_times = []
self._train_losses = []
self._valid_losses = []
self._train_accuracies = []
self._valid_accuracies = []
self._model_parameters = {}
@property
def best_epoch(self) -> int:
return self._best_epoch + 1
@property
def train_times(self) -> List[float]:
return self._train_times
@property
def valid_times(self) -> List[float]:
return self._valid_times
@property
def train_losses(self) -> List[float]:
return self._train_losses
@property
def valid_losses(self) -> List[float]:
return self._valid_losses
@property
def train_accuracies(self) -> List[float]:
return self._train_accuracies
@property
def valid_accuracies(self) -> List[float]:
return self._valid_accuracies
@property
def best_epoch_training_time(self) -> float:
return sum(self._train_times[:self._best_epoch])
@property
def best_epoch_train_loss(self) -> float:
return self._train_losses[self._best_epoch]
@property
def best_epoch_valid_loss(self) -> float:
return self._valid_losses[self._best_epoch]
@property
def best_epoch_train_accuracy(self) -> float:
return self._train_accuracies[self._best_epoch]
@property
def best_epoch_valid_accuracy(self) -> float:
return self._valid_accuracies[self._best_epoch]
@property
def best_epoch_model_parameters(
self) -> Union[Dict[str, torch.Tensor], Dict[str, Dict[str, torch.Tensor]]]:
return self._model_parameters
@property
def should_stop(self) -> bool:
return self._lookback >= self._patience
def create(
self,
epoch: int,
train_time: float,
valid_time: float,
train_loss: float,
valid_loss: float,
train_accuracy: float,
valid_accuracy: float,
model: Union[nn.Module, Dict[str, nn.Module]],
) -> None:
self._train_times.append(train_time)
self._valid_times.append(valid_time)
self._train_losses.append(train_loss)
self._valid_losses.append(valid_loss)
self._train_accuracies.append(train_accuracy)
self._valid_accuracies.append(valid_accuracy)
best_epoch = False
if self._best_epoch is None:
best_epoch = True
elif self._monitor == 'loss':
if valid_loss < self._valid_losses[self._best_epoch]:
best_epoch = True
elif self._monitor == 'accuracy':
if valid_accuracy > self._valid_accuracies[self._best_epoch]:
best_epoch = True
if best_epoch:
self._best_epoch = epoch
if isinstance(model, dict):
for name, current_model in model.items():
self._model_parameters[name] = deepcopy(
current_model.to('cpu').state_dict())
else:
self._model_parameters = deepcopy(model.to('cpu').state_dict())
self._lookback = 0
else:
self._lookback += 1
class OGBDataset:
def __init__(
self,
g: Union[dgl.DGLGraph, dgl.DGLHeteroGraph],
num_labels: int,
predict_category: str = None,
) -> None:
self._g = g
self._num_labels = num_labels
self._predict_category = predict_category
@property
def num_labels(self) -> int:
return self._num_labels
@property
def num_classes(self) -> int:
return self._num_labels
@property
def predict_category(self) -> str:
return self._predict_category
def __getitem__(self, idx: int) -> Union[dgl.DGLGraph, dgl.DGLHeteroGraph]:
return self._g
def load_ogbn_mag(root: str = None) -> OGBDataset:
dataset = DglNodePropPredDataset(name='ogbn-mag', root=root)
split_idx = dataset.get_idx_split()
train_idx = split_idx['train']['paper']
valid_idx = split_idx['valid']['paper']
test_idx = split_idx['test']['paper']
hg_original, labels = dataset[0]
labels = labels['paper'].squeeze()
num_labels = dataset.num_classes
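# ogbn-mag edges are directed; add a reverse relation ('rev-*') for each
# canonical edge type so messages can also flow against edge direction.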
subgraphs = {}
for etype in hg_original.canonical_etypes:
src, dst = hg_original.all_edges(etype=etype)
subgraphs[etype] = (src, dst)
subgraphs[(etype[2], f'rev-{etype[1]}', etype[0])] = (dst, src)
hg = dgl.heterograph(subgraphs)
hg.nodes['paper'].data['feat'] = hg_original.nodes['paper'].data['feat']
hg.nodes['paper'].data['labels'] = labels
train_mask = torch.zeros((hg.num_nodes('paper'),), dtype=torch.bool)
train_mask[train_idx] = True
valid_mask = torch.zeros((hg.num_nodes('paper'),), dtype=torch.bool)
valid_mask[valid_idx] = True
test_mask = torch.zeros((hg.num_nodes('paper'),), dtype=torch.bool)
test_mask[test_idx] = True
hg.nodes['paper'].data['train_mask'] = train_mask
hg.nodes['paper'].data['valid_mask'] = valid_mask
hg.nodes['paper'].data['test_mask'] = test_mask
ogb_dataset = OGBDataset(hg, num_labels, 'paper')
return ogb_dataset
def process_dataset(
name: str,
root: str = None,
) -> Tuple[OGBDataset, dgl.DGLHeteroGraph, torch.Tensor, torch.Tensor, torch.Tensor]:
if root is None:
root = 'datasets'
if name == 'ogbn-mag':
dataset = load_ogbn_mag(root=root)
else:
raise ValueError(f'Unsupported dataset: {name}')
g = dataset[0]
predict_category = dataset.predict_category
train_idx = torch.nonzero(
g.nodes[predict_category].data['train_mask'], as_tuple=True)[0]
valid_idx = torch.nonzero(
g.nodes[predict_category].data['valid_mask'], as_tuple=True)[0]
test_idx = torch.nonzero(
g.nodes[predict_category].data['test_mask'], as_tuple=True)[0]
return dataset, g, train_idx, valid_idx, test_idx
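# Example invocation (a sketch: `main.py` is a placeholder name, since the
# entrypoint filename is not shown in this extract; the flags match the
# argparser above):
#   python main.py --dataset ogbn-mag --num-epochs 500 --fanouts 25,20 \
#       --batch-size 1024 --inference-mode neighbor_sampler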