Unverified Commit a88e7f7e authored by YJ-Zhao and committed by GitHub

[Example]rgcn-ogbn-mag (#4331)



* rgcn-ogbn-mag

* Add link in README.md

* correct code format, add the reset_parameters function to the HeteroEmbedding module

* add the annotation in hetero.py

* add a unit test

* modify format

* Update
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-50-143.us-west-2.compute.internal>
parent 919b7838
......@@ -249,7 +249,7 @@ To quickly locate the examples of your interest, search for the tagged keywords
- Tags: matrix completion, recommender system, link prediction, bipartite graphs
- <a name="graphsage"></a> Hamilton et al. Inductive Representation Learning on Large Graphs. [Paper link](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf).
- Example code: [PyTorch](../examples/pytorch/graphsage), [PyTorch on ogbn-products](../examples/pytorch/ogb/ogbn-products), [PyTorch on ogbn-mag](../examples/pytorch/ogb/ogbn-mag), [PyTorch on ogbl-ppa](https://github.com/awslabs/dgl-lifesci/tree/master/examples/link_prediction/ogbl-ppa), [MXNet](../examples/mxnet/graphsage)
- Tags: node classification, sampling, unsupervised learning, link prediction, OGB
- <a name="metapath2vec"></a> Dong et al. metapath2vec: Scalable Representation Learning for Heterogeneous Networks. [Paper link](https://dl.acm.org/doi/10.1145/3097983.3098036).
......
......@@ -5,11 +5,6 @@ The following options can be specified via command line arguments:
```
optional arguments:
  -h, --help   show this help message and exit
  --runs RUNS
```
......@@ -58,68 +53,3 @@ ParameterDict(
The input features are passed to a modified version of the R-GCN architecture. As in the R-GCN paper, each _edge-type_ has its own linear projection matrix (the "weight" ModuleDict below). Unlike the original paper, however, each _node-type_ also has its own "self" linear projection matrix (the "loop_weights" ModuleDict below). There are 7 edge-types: 4 natural edge-types ("cites", "affiliated_with", "has_topic" and "writes") and 3 manufactured reverse edge-types ("rev-affiliated_with", "rev-has_topic", "rev-writes"). As mentioned above, note that there is _not_ a reverse edge type like "rev-cites"; the reverse edges are instead folded into "cites", presumably because its source and destination nodes are both of type "paper". Whereas the 7 "relation" linear layers have no bias, the 4 "self" linear layers do.
With two of these layers, a hidden dimension of 64, and 349 output classes, we end up with 337,460 R-GCN model parameters.
```
EntityClassify(
  (layers): ModuleList(
    (0): RelGraphConvLayer(
      (conv): HeteroGraphConv(
        (mods): ModuleDict(
          (affiliated_with): GraphConv(in=128, out=64, normalization=right, activation=None)
          (cites): GraphConv(in=128, out=64, normalization=right, activation=None)
          (has_topic): GraphConv(in=128, out=64, normalization=right, activation=None)
          (rev-affiliated_with): GraphConv(in=128, out=64, normalization=right, activation=None)
          (rev-has_topic): GraphConv(in=128, out=64, normalization=right, activation=None)
          (rev-writes): GraphConv(in=128, out=64, normalization=right, activation=None)
          (writes): GraphConv(in=128, out=64, normalization=right, activation=None)
        )
      )
      (weight): ModuleDict(
        (affiliated_with): Linear(in_features=128, out_features=64, bias=False)
        (cites): Linear(in_features=128, out_features=64, bias=False)
        (has_topic): Linear(in_features=128, out_features=64, bias=False)
        (rev-affiliated_with): Linear(in_features=128, out_features=64, bias=False)
        (rev-has_topic): Linear(in_features=128, out_features=64, bias=False)
        (rev-writes): Linear(in_features=128, out_features=64, bias=False)
        (writes): Linear(in_features=128, out_features=64, bias=False)
      )
      (loop_weights): ModuleDict(
        (author): Linear(in_features=128, out_features=64, bias=True)
        (field_of_study): Linear(in_features=128, out_features=64, bias=True)
        (institution): Linear(in_features=128, out_features=64, bias=True)
        (paper): Linear(in_features=128, out_features=64, bias=True)
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): RelGraphConvLayer(
      (conv): HeteroGraphConv(
        (mods): ModuleDict(
          (affiliated_with): GraphConv(in=64, out=349, normalization=right, activation=None)
          (cites): GraphConv(in=64, out=349, normalization=right, activation=None)
          (has_topic): GraphConv(in=64, out=349, normalization=right, activation=None)
          (rev-affiliated_with): GraphConv(in=64, out=349, normalization=right, activation=None)
          (rev-has_topic): GraphConv(in=64, out=349, normalization=right, activation=None)
          (rev-writes): GraphConv(in=64, out=349, normalization=right, activation=None)
          (writes): GraphConv(in=64, out=349, normalization=right, activation=None)
        )
      )
      (weight): ModuleDict(
        (affiliated_with): Linear(in_features=64, out_features=349, bias=False)
        (cites): Linear(in_features=64, out_features=349, bias=False)
        (has_topic): Linear(in_features=64, out_features=349, bias=False)
        (rev-affiliated_with): Linear(in_features=64, out_features=349, bias=False)
        (rev-has_topic): Linear(in_features=64, out_features=349, bias=False)
        (rev-writes): Linear(in_features=64, out_features=349, bias=False)
        (writes): Linear(in_features=64, out_features=349, bias=False)
      )
      (loop_weights): ModuleDict(
        (author): Linear(in_features=64, out_features=349, bias=True)
        (field_of_study): Linear(in_features=64, out_features=349, bias=True)
        (institution): Linear(in_features=64, out_features=349, bias=True)
        (paper): Linear(in_features=64, out_features=349, bias=True)
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
  )
)
```
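As a sanity check, the 337,460 figure can be reproduced by hand from the shapes printed above (a back-of-the-envelope calculation, not part of the example code):

```
# Layer 0: 7 relation weights (128x64, no bias) + 4 self-loop weights (128x64 + 64 bias)
layer0 = 7 * 128 * 64 + 4 * (128 * 64 + 64)   # 57,344 + 33,024 = 90,368
# Layer 1: 7 relation weights (64x349, no bias) + 4 self-loop weights (64x349 + 349 bias)
layer1 = 7 * 64 * 349 + 4 * (64 * 349 + 349)  # 156,352 + 90,740 = 247,092
print(layer0 + layer1)                        # 337,460
```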
......@@ -4,136 +4,93 @@ from tqdm import tqdm
import dgl
import dgl.nn as dglnn
from dgl.nn import HeteroEmbedding
from dgl import Compose, AddReverse, ToSimple
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
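# ToSimple() removes duplicate (parallel) edges and AddReverse() adds reverse
# edge types so messages can also flow against edge direction; by default,
# reverse edges whose endpoints share a node type ("paper cites paper") are
# folded into the original relation instead of getting a new edge type.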
def prepare_data(args):
    dataset = DglNodePropPredDataset(name="ogbn-mag")
    split_idx = dataset.get_idx_split()
    # graph: dgl graph object, label: torch tensor of shape (num_nodes, num_tasks)
    g, labels = dataset[0]
    labels = labels['paper'].flatten()

    transform = Compose([ToSimple(), AddReverse()])
    g = transform(g)
    print("Loaded graph: {}".format(g))

    logger = Logger(args.runs)

    # train sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 20])
    train_loader = dgl.dataloading.DataLoader(
        g, split_idx['train'], sampler,
        batch_size=1024, shuffle=True, num_workers=0)

    return g, labels, dataset.num_classes, split_idx, logger, train_loader

def extract_embed(node_embed, input_nodes):
    emb = node_embed({
        ntype: input_nodes[ntype] for ntype in input_nodes if ntype != 'paper'
    })
    return emb

def rel_graph_embed(graph, embed_size):
    node_num = {}
    for ntype in graph.ntypes:
        if ntype == 'paper':
            continue
        node_num[ntype] = graph.num_nodes(ntype)
    embeds = HeteroEmbedding(node_num, embed_size)
    return embeds
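# NOTE: 'paper' is the only node type that comes with input features (128-dim),
# so only the featureless node types (author, field_of_study, institution) get
# learnable embeddings here, and extract_embed skips 'paper' for the same reason.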
class RelGraphConvLayer(nn.Module):
r"""Relational graph convolution layer.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
ntypes : list[str]
Node type names
rel_names : list[str]
Relation names.
weight : bool, optional
True if a linear layer is applied after message passing. Default: True
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
ntypes,
rel_names,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConvLayer, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.ntypes = ntypes
self.rel_names = rel_names
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.conv = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False)
for rel in rel_names
})
self.use_weight = weight
if self.use_weight:
self.weight = nn.ModuleDict({
rel_name: nn.Linear(in_feat, out_feat, bias=False)
for rel_name in self.rel_names
})
# weight for self loop
if self.self_loop:
self.loop_weights = nn.ModuleDict({
                ntype: nn.Linear(in_feat, out_feat, bias=True)
for ntype in self.ntypes
})
self.dropout = nn.Dropout(dropout)
self.reset_parameters()
def reset_parameters(self):
if self.use_weight:
for layer in self.weight.values():
layer.reset_parameters()
if self.self_loop:
for layer in self.loop_weights.values():
layer.reset_parameters()
def forward(self, g, inputs):
"""Forward computation
"""
Parameters
----------
g : DGLHeteroGraph
......@@ -147,83 +104,41 @@ class RelGraphConvLayer(nn.Module):
New node features for each node type.
"""
g = g.local_var()
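        # The per-relation GraphConvs were built with weight=False, so each
        # relation's projection is injected at call time via mod_kwargs.
        # nn.Linear stores its weight as (out, in); the transpose below yields
        # the (in, out) layout that GraphConv expects.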
if self.use_weight:
wdict = {rel_name: {'weight': self.weight[rel_name].weight.T}
for rel_name in self.rel_names}
else:
wdict = {}
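        # In a block (message-flow graph), the destination nodes are the first
        # num_dst_nodes entries of each input tensor; slice them out here for
        # the per-node-type self-loop term applied in _apply below.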
if g.is_block:
inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
inputs_dst = inputs
hs = self.conv(g, inputs, mod_kwargs=wdict)
def _apply(ntype, h):
if self.self_loop:
h = h + self.loop_weights[ntype](inputs_dst[ntype])
if self.activation:
h = self.activation(h)
return self.dropout(h)
        return {ntype : _apply(ntype, h) for ntype, h in hs.items()}
class EntityClassify(nn.Module):
    r"""
    R-GCN node classification model

    Parameters
    ----------
    g : DGLGraph
        The heterogeneous graph used for message passing
    in_dim : int
        Input feature size.
    out_dim : int
        Output dimension size.
    """
    def __init__(self, g, in_dim, out_dim):
        super(EntityClassify, self).__init__()
        self.g = g
        self.in_dim = in_dim
        self.h_dim = 64
        self.out_dim = out_dim
        self.rel_names = list(set(g.etypes))
        self.rel_names.sort()
        self.dropout = 0.5

        self.layers = nn.ModuleList()
        # i2h
        self.layers.append(RelGraphConvLayer(
            self.in_dim, self.h_dim, g.ntypes, self.rel_names,
            activation=F.relu, dropout=self.dropout))
        # h2o
        self.layers.append(RelGraphConvLayer(
            self.h_dim, self.out_dim, g.ntypes, self.rel_names,
            activation=None))
def reset_parameters(self):
for layer in self.layers:
......@@ -234,7 +149,6 @@ class EntityClassify(nn.Module):
h = layer(block, h)
return h
class Logger(object):
r"""
This class was taken directly from the PyG implementation and can be found
......@@ -242,8 +156,7 @@ class Logger(object):
This was done to ensure that performance was measured in precisely the same way
"""
    def __init__(self, runs):
self.results = [[] for _ in range(runs)]
def add_result(self, run, result):
......@@ -283,110 +196,14 @@ class Logger(object):
r = best_result[:, 3]
print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}')
def train(g, model, node_embed, optimizer, train_loader, split_idx,
          labels, logger, device, run):
    print("start training...")
    category = 'paper'

    for epoch in range(3):
        num_train = split_idx['train'][category].shape[0]
        pbar = tqdm(total=num_train)
        pbar.set_description(f'Epoch {epoch:02d}')
        model.train()
......@@ -400,11 +217,9 @@ def train(g, model, node_embed, optimizer, train_loader, split_idx,
emb = extract_embed(node_embed, input_nodes)
# Add the batch's raw "paper" features
emb.update({'paper': g.ndata['feat']['paper'][input_nodes['paper']]})
        emb = {k : e.to(device) for k, e in emb.items()}
        lbl = labels[seeds].to(device)
optimizer.zero_grad()
logits = model(emb, blocks)[category]
......@@ -418,9 +233,9 @@ def train(g, model, node_embed, optimizer, train_loader, split_idx,
pbar.update(batch_size)
pbar.close()
    loss = total_loss / num_train
    result = test(g, model, node_embed, labels, device, split_idx)
logger.add_result(run, result)
train_acc, valid_acc, test_acc = result
print(f'Run: {run + 1:02d}, '
......@@ -433,19 +248,19 @@ def train(g, model, node_embed, optimizer, train_loader, split_idx,
return logger
@th.no_grad()
def test(g, model, node_embed, y_true, device, split_idx):
model.eval()
category = 'paper'
evaluator = Evaluator(name='ogbn-mag')
    # 2 GNN layers
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
loader = dgl.dataloading.DataLoader(
g, {'paper': th.arange(g.num_nodes('paper'))}, sampler,
batch_size=16384, shuffle=False, num_workers=0)
    pbar = tqdm(total=y_true.size(0))
    pbar.set_description('Inference')
y_hats = list()
......@@ -457,9 +272,7 @@ def test(g, model, node_embed, y_true, device, split_idx, args):
emb = extract_embed(node_embed, input_nodes)
# Get the batch's raw "paper" features
emb.update({'paper': g.ndata['feat']['paper'][input_nodes['paper']]})
        emb = {k : e.to(device) for k, e in emb.items()}
logits = model(emb, blocks)[category]
        y_hat = logits.log_softmax(dim=-1).argmax(dim=1, keepdim=True)
......@@ -488,40 +301,36 @@ def test(g, model, node_embed, y_true, device, split_idx, args):
return train_acc, valid_acc, test_acc
def main(args):
    device = 'cuda:0' if th.cuda.is_available() else 'cpu'

    g, labels, num_classes, split_idx, logger, train_loader = prepare_data(args)

    embed_layer = rel_graph_embed(g, 128)
    model = EntityClassify(g, 128, num_classes).to(device)

    print(f"Number of embedding parameters: {sum(p.numel() for p in embed_layer.parameters())}")
    print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())}")
    for run in range(args.runs):
embed_layer.reset_parameters()
model.reset_parameters()
# optimizer
all_params = itertools.chain(model.parameters(), embed_layer.parameters())
        optimizer = th.optim.Adam(all_params, lr=0.01)

        logger = train(g, model, embed_layer, optimizer, train_loader, split_idx,
                       labels, logger, device, run)
logger.print_statistics(run)
print("Final performance: ")
logger.print_statistics()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN')
    parser.add_argument('--runs', type=int, default=10)
    args = parser.parse_args()
main(args)
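# Example invocation (assuming this file is saved as main.py):
#   python main.py --runs 10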
import argparse
from itertools import chain
from timeit import default_timer
from typing import Callable, Tuple, Union
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils
from model import EntityClassify, RelGraphEmbedding
def train(
embedding_layer: nn.Module,
model: nn.Module,
device: Union[str, torch.device],
embedding_optimizer: torch.optim.Optimizer,
model_optimizer: torch.optim.Optimizer,
loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
labels: torch.Tensor,
predict_category: str,
dataloader: dgl.dataloading.DataLoader,
) -> Tuple[float, float, float]:
model.train()
total_loss = 0
total_accuracy = 0
start = default_timer()
embedding_layer = embedding_layer.to(device)
model = model.to(device)
loss_function = loss_function.to(device)
for step, (in_nodes, out_nodes, blocks) in enumerate(dataloader):
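        # Each minibatch yields the sampled input node IDs (per node type), the
        # seed/output node IDs, and one message-flow graph (block) per GNN layer.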
embedding_optimizer.zero_grad()
model_optimizer.zero_grad()
in_nodes = {rel: nid.to(device) for rel, nid in in_nodes.items()}
out_nodes = out_nodes[predict_category].to(labels.device)
blocks = [block.to(device) for block in blocks]
batch_labels = labels[out_nodes].to(device)
embedding = embedding_layer(in_nodes=in_nodes, device=device)
logits = model(blocks, embedding)[predict_category]
loss = loss_function(logits, batch_labels)
indices = logits.argmax(dim=-1)
correct = torch.sum(indices == batch_labels)
accuracy = correct.item() / len(batch_labels)
loss.backward()
model_optimizer.step()
embedding_optimizer.step()
total_loss += loss.item()
total_accuracy += accuracy
stop = default_timer()
time = stop - start
total_loss /= step + 1
total_accuracy /= step + 1
return time, total_loss, total_accuracy
def validate(
embedding_layer: nn.Module,
model: nn.Module,
device: Union[str, torch.device],
inference_mode: str,
loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
hg: dgl.DGLHeteroGraph,
labels: torch.Tensor,
predict_category: str,
dataloader: dgl.dataloading.DataLoader = None,
eval_batch_size: int = None,
eval_num_workers: int = None,
mask: torch.Tensor = None,
) -> Tuple[float, float, float]:
embedding_layer.eval()
model.eval()
start = default_timer()
embedding_layer = embedding_layer.to(device)
model = model.to(device)
loss_function = loss_function.to(device)
valid_labels = labels[mask].to(device)
with torch.no_grad():
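        # Three evaluation paths: minibatch neighbor sampling (mirrors training),
        # layer-wise full-neighbor inference via model.inference, or a single
        # full-graph forward pass.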
if inference_mode == 'neighbor_sampler':
total_loss = 0
total_accuracy = 0
for step, (in_nodes, out_nodes, blocks) in enumerate(dataloader):
in_nodes = {rel: nid.to(device)
for rel, nid in in_nodes.items()}
out_nodes = out_nodes[predict_category].to(labels.device)
blocks = [block.to(device) for block in blocks]
batch_labels = labels[out_nodes].to(device)
embedding = embedding_layer(in_nodes=in_nodes, device=device)
logits = model(blocks, embedding)[predict_category]
loss = loss_function(logits, batch_labels)
indices = logits.argmax(dim=-1)
correct = torch.sum(indices == batch_labels)
accuracy = correct.item() / len(batch_labels)
total_loss += loss.item()
total_accuracy += accuracy
total_loss /= step + 1
total_accuracy /= step + 1
elif inference_mode == 'full_neighbor_sampler':
logits = model.inference(
hg,
eval_batch_size,
eval_num_workers,
embedding_layer,
device,
)[predict_category][mask]
total_loss = loss_function(logits, valid_labels)
indices = logits.argmax(dim=-1)
correct = torch.sum(indices == valid_labels)
total_accuracy = correct.item() / len(valid_labels)
total_loss = total_loss.item()
else:
embedding = embedding_layer(device=device)
logits = model(hg, embedding)[predict_category][mask]
total_loss = loss_function(logits, valid_labels)
indices = logits.argmax(dim=-1)
correct = torch.sum(indices == valid_labels)
total_accuracy = correct.item() / len(valid_labels)
total_loss = total_loss.item()
stop = default_timer()
time = stop - start
return time, total_loss, total_accuracy
def run(args: argparse.Namespace) -> None:
torch.manual_seed(args.seed)
dataset, hg, train_idx, valid_idx, test_idx = utils.process_dataset(
args.dataset,
root=args.dataset_root,
)
predict_category = dataset.predict_category
labels = hg.nodes[predict_category].data['labels']
training_device = torch.device('cuda' if args.gpu_training else 'cpu')
inference_device = torch.device('cuda' if args.gpu_inference else 'cpu')
    inference_mode = args.inference_mode
fanouts = [int(fanout) for fanout in args.fanouts.split(',')]
train_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
train_dataloader = dgl.dataloading.DataLoader(
hg,
{predict_category: train_idx},
train_sampler,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers,
)
    if inference_mode == 'neighbor_sampler':
valid_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
valid_dataloader = dgl.dataloading.DataLoader(
hg,
{predict_category: valid_idx},
valid_sampler,
batch_size=args.eval_batch_size,
shuffle=False,
drop_last=False,
num_workers=args.eval_num_workers,
)
if args.test_validation:
test_sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)
test_dataloader = dgl.dataloading.DataLoader(
hg,
{predict_category: test_idx},
test_sampler,
batch_size=args.eval_batch_size,
shuffle=False,
drop_last=False,
num_workers=args.eval_num_workers,
)
else:
valid_dataloader = None
if args.test_validation:
test_dataloader = None
in_feats = hg.nodes[predict_category].data['feat'].shape[-1]
out_feats = dataset.num_classes
num_nodes = {}
node_feats = {}
for ntype in hg.ntypes:
num_nodes[ntype] = hg.num_nodes(ntype)
node_feats[ntype] = hg.nodes[ntype].data.get('feat')
activations = {'leaky_relu': F.leaky_relu, 'relu': F.relu}
embedding_layer = RelGraphEmbedding(hg, in_feats, num_nodes, node_feats)
model = EntityClassify(
hg,
in_feats,
args.hidden_feats,
out_feats,
args.num_bases,
args.num_layers,
norm=args.norm,
layer_norm=args.layer_norm,
input_dropout=args.input_dropout,
dropout=args.dropout,
activation=activations[args.activation],
self_loop=args.self_loop,
)
loss_function = nn.CrossEntropyLoss()
embedding_optimizer = torch.optim.SparseAdam(
embedding_layer.node_embeddings.parameters(), lr=args.embedding_lr)
if args.node_feats_projection:
all_parameters = chain(
model.parameters(), embedding_layer.embeddings.parameters())
model_optimizer = torch.optim.Adam(all_parameters, lr=args.model_lr)
else:
model_optimizer = torch.optim.Adam(
model.parameters(), lr=args.model_lr)
checkpoint = utils.Callback(args.early_stopping_patience,
args.early_stopping_monitor)
print('## Training started ##')
for epoch in range(args.num_epochs):
train_time, train_loss, train_accuracy = train(
embedding_layer,
model,
training_device,
embedding_optimizer,
model_optimizer,
loss_function,
labels,
predict_category,
train_dataloader,
)
valid_time, valid_loss, valid_accuracy = validate(
embedding_layer,
model,
inference_device,
            inference_mode,
loss_function,
hg,
labels,
predict_category=predict_category,
dataloader=valid_dataloader,
eval_batch_size=args.eval_batch_size,
eval_num_workers=args.eval_num_workers,
mask=valid_idx,
)
checkpoint.create(
epoch,
train_time,
valid_time,
train_loss,
valid_loss,
train_accuracy,
valid_accuracy,
{'embedding_layer': embedding_layer, 'model': model},
)
print(
f'Epoch: {epoch + 1:03} '
f'Train Loss: {train_loss:.2f} '
f'Valid Loss: {valid_loss:.2f} '
f'Train Accuracy: {train_accuracy:.4f} '
f'Valid Accuracy: {valid_accuracy:.4f} '
f'Train Epoch Time: {train_time:.2f} '
f'Valid Epoch Time: {valid_time:.2f}'
)
if checkpoint.should_stop:
print('## Training finished: early stopping ##')
break
elif epoch >= args.num_epochs - 1:
print('## Training finished ##')
print(
f'Best Epoch: {checkpoint.best_epoch} '
f'Train Loss: {checkpoint.best_epoch_train_loss:.2f} '
f'Valid Loss: {checkpoint.best_epoch_valid_loss:.2f} '
f'Train Accuracy: {checkpoint.best_epoch_train_accuracy:.4f} '
f'Valid Accuracy: {checkpoint.best_epoch_valid_accuracy:.4f}'
)
if args.test_validation:
print('## Test data validation ##')
embedding_layer.load_state_dict(
checkpoint.best_epoch_model_parameters['embedding_layer'])
model.load_state_dict(checkpoint.best_epoch_model_parameters['model'])
test_time, test_loss, test_accuracy = validate(
embedding_layer,
model,
inference_device,
            inference_mode,
loss_function,
hg,
labels,
predict_category=predict_category,
dataloader=test_dataloader,
eval_batch_size=args.eval_batch_size,
eval_num_workers=args.eval_num_workers,
mask=test_idx,
)
print(
f'Test Loss: {test_loss:.2f} '
f'Test Accuracy: {test_accuracy:.4f} '
f'Test Epoch Time: {test_time:.2f}'
)
if __name__ == '__main__':
argparser = argparse.ArgumentParser('RGCN')
argparser.add_argument('--gpu-training', dest='gpu_training',
action='store_true')
argparser.add_argument('--no-gpu-training', dest='gpu_training',
action='store_false')
argparser.set_defaults(gpu_training=True)
argparser.add_argument('--gpu-inference', dest='gpu_inference',
action='store_true')
argparser.add_argument('--no-gpu-inference', dest='gpu_inference',
action='store_false')
argparser.set_defaults(gpu_inference=True)
argparser.add_argument('--inference-mode', default='neighbor_sampler', type=str,
choices=['neighbor_sampler', 'full_neighbor_sampler', 'full_graph'])
argparser.add_argument('--dataset', default='ogbn-mag', type=str,
choices=['ogbn-mag'])
argparser.add_argument('--dataset-root', default='dataset', type=str)
argparser.add_argument('--num-epochs', default=500, type=int)
argparser.add_argument('--embedding-lr', default=0.01, type=float)
argparser.add_argument('--model-lr', default=0.01, type=float)
argparser.add_argument('--node-feats-projection',
dest='node_feats_projection', action='store_true')
argparser.add_argument('--no-node-feats-projection',
dest='node_feats_projection', action='store_false')
argparser.set_defaults(node_feats_projection=False)
argparser.add_argument('--hidden-feats', default=64, type=int)
argparser.add_argument('--num-bases', default=2, type=int)
argparser.add_argument('--num-layers', default=2, type=int)
argparser.add_argument('--norm', default='right',
type=str, choices=['both', 'none', 'right'])
argparser.add_argument('--layer-norm', dest='layer_norm',
action='store_true')
argparser.add_argument('--no-layer-norm', dest='layer_norm',
action='store_false')
argparser.set_defaults(layer_norm=False)
argparser.add_argument('--input-dropout', default=0.1, type=float)
argparser.add_argument('--dropout', default=0.5, type=float)
argparser.add_argument('--activation', default='relu', type=str,
choices=['leaky_relu', 'relu'])
argparser.add_argument('--self-loop', dest='self_loop',
action='store_true')
argparser.add_argument('--no-self-loop', dest='self_loop',
action='store_false')
argparser.set_defaults(self_loop=True)
argparser.add_argument('--fanouts', default='25,20', type=str)
argparser.add_argument('--batch-size', default=1024, type=int)
argparser.add_argument('--eval-batch-size', default=1024, type=int)
argparser.add_argument('--num-workers', default=4, type=int)
argparser.add_argument('--eval-num-workers', default=4, type=int)
argparser.add_argument('--early-stopping-patience', default=10, type=int)
argparser.add_argument('--early-stopping-monitor', default='loss',
type=str, choices=['accuracy', 'loss'])
argparser.add_argument('--test-validation', dest='test_validation',
action='store_true')
argparser.add_argument('--no-test-validation', dest='test_validation',
action='store_false')
argparser.set_defaults(test_validation=True)
argparser.add_argument('--seed', default=13, type=int)
args = argparser.parse_args()
run(args)
from typing import Callable, Dict, List, Union
import dgl
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
class RelGraphEmbedding(nn.Module):
def __init__(
self,
hg: dgl.DGLHeteroGraph,
embedding_size: int,
num_nodes: Dict[str, int],
node_feats: Dict[str, torch.Tensor],
node_feats_projection: bool = False,
):
super().__init__()
self._hg = hg
self._node_feats = node_feats
self._node_feats_projection = node_feats_projection
self.node_embeddings = nn.ModuleDict()
if node_feats_projection:
self.embeddings = nn.ParameterDict()
for ntype in hg.ntypes:
if node_feats[ntype] is None:
node_embedding = nn.Embedding(
num_nodes[ntype], embedding_size, sparse=True)
nn.init.uniform_(node_embedding.weight, -1, 1)
self.node_embeddings[ntype] = node_embedding
elif node_feats[ntype] is not None and node_feats_projection:
input_embedding_size = node_feats[ntype].shape[-1]
embedding = nn.Parameter(torch.Tensor(
input_embedding_size, embedding_size))
nn.init.xavier_uniform_(embedding)
self.embeddings[ntype] = embedding
def forward(
self,
in_nodes: Dict[str, torch.Tensor] = None,
device: torch.device = None,
) -> Dict[str, torch.Tensor]:
if in_nodes is not None:
ntypes = [ntype for ntype in in_nodes.keys()]
nids = [nid for nid in in_nodes.values()]
else:
ntypes = self._hg.ntypes
nids = [self._hg.nodes(ntype) for ntype in ntypes]
x = {}
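        # Featureless node types draw from their learned embedding tables; node
        # types with raw features return those features, optionally projected by
        # a learnable matrix into the shared embedding space.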
for ntype, nid in zip(ntypes, nids):
if self._node_feats[ntype] is None:
x[ntype] = self.node_embeddings[ntype](nid)
else:
if device is not None:
self._node_feats[ntype] = self._node_feats[ntype].to(
device)
if self._node_feats_projection:
x[ntype] = self._node_feats[ntype][nid] @ self.embeddings[ntype]
else:
x[ntype] = self._node_feats[ntype][nid]
return x
class RelGraphConvLayer(nn.Module):
def __init__(
self,
in_feats: int,
out_feats: int,
rel_names: List[str],
num_bases: int,
norm: str = 'right',
weight: bool = True,
bias: bool = True,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
dropout: float = None,
self_loop: bool = False,
):
super().__init__()
self._rel_names = rel_names
self._num_rels = len(rel_names)
self._conv = dglnn.HeteroGraphConv({rel: dglnn.GraphConv(
in_feats, out_feats, norm=norm, weight=False, bias=False) for rel in rel_names})
self._use_weight = weight
self._use_basis = num_bases < self._num_rels and weight
self._use_bias = bias
self._activation = activation
self._dropout = nn.Dropout(dropout) if dropout is not None else None
self._use_self_loop = self_loop
if weight:
if self._use_basis:
self.basis = dglnn.WeightBasis(
(in_feats, out_feats), num_bases, self._num_rels)
else:
self.weight = nn.Parameter(torch.Tensor(
self._num_rels, in_feats, out_feats))
nn.init.xavier_uniform_(
self.weight, gain=nn.init.calculate_gain('relu'))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_feats))
nn.init.zeros_(self.bias)
if self_loop:
self.self_loop_weight = nn.Parameter(
torch.Tensor(in_feats, out_feats))
nn.init.xavier_uniform_(
self.self_loop_weight, gain=nn.init.calculate_gain('relu'))
def _apply_layers(
self,
ntype: str,
inputs: torch.Tensor,
inputs_dst: torch.Tensor = None,
) -> torch.Tensor:
x = inputs
if inputs_dst is not None:
x += torch.matmul(inputs_dst[ntype], self.self_loop_weight)
if self._use_bias:
x += self.bias
if self._activation is not None:
x = self._activation(x)
if self._dropout is not None:
x = self._dropout(x)
return x
def forward(
self,
hg: dgl.DGLHeteroGraph,
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
hg = hg.local_var()
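        # With basis decomposition (Schlichtkrull et al.), each relation weight
        # W_r is a learned combination of num_bases shared basis matrices, which
        # decouples the number of full weight matrices from the number of relations.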
if self._use_weight:
weight = self.basis() if self._use_basis else self.weight
weight_dict = {self._rel_names[i]: {'weight': w.squeeze(
dim=0)} for i, w in enumerate(torch.split(weight, 1, dim=0))}
else:
weight_dict = {}
if self._use_self_loop:
if hg.is_block:
inputs_dst = {ntype: h[:hg.num_dst_nodes(
ntype)] for ntype, h in inputs.items()}
else:
inputs_dst = inputs
else:
inputs_dst = None
x = self._conv(hg, inputs, mod_kwargs=weight_dict)
x = {ntype: self._apply_layers(ntype, h, inputs_dst)
for ntype, h in x.items()}
return x
class EntityClassify(nn.Module):
def __init__(
self,
hg: dgl.DGLHeteroGraph,
in_feats: int,
hidden_feats: int,
out_feats: int,
num_bases: int,
num_layers: int,
norm: str = 'right',
layer_norm: bool = False,
input_dropout: float = 0,
dropout: float = 0,
activation: Callable[[torch.Tensor], torch.Tensor] = None,
self_loop: bool = False,
):
super().__init__()
self._hidden_feats = hidden_feats
self._out_feats = out_feats
self._num_layers = num_layers
self._input_dropout = nn.Dropout(input_dropout)
self._dropout = nn.Dropout(dropout)
self._activation = activation
self._rel_names = sorted(list(set(hg.etypes)))
self._num_rels = len(self._rel_names)
if num_bases < 0 or num_bases > self._num_rels:
self._num_bases = self._num_rels
else:
self._num_bases = num_bases
self._layers = nn.ModuleList()
self._layers.append(RelGraphConvLayer(
in_feats,
hidden_feats,
self._rel_names,
self._num_bases,
norm=norm,
self_loop=self_loop,
))
for _ in range(1, num_layers - 1):
self._layers.append(RelGraphConvLayer(
hidden_feats,
hidden_feats,
self._rel_names,
self._num_bases,
norm=norm,
self_loop=self_loop,
))
self._layers.append(RelGraphConvLayer(
hidden_feats,
out_feats,
self._rel_names,
self._num_bases,
norm=norm,
self_loop=self_loop,
))
if layer_norm:
self._layer_norms = nn.ModuleList()
for _ in range(num_layers - 1):
self._layer_norms.append(nn.LayerNorm(hidden_feats))
else:
self._layer_norms = None
def _apply_layers(
self,
layer_idx: int,
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
x = inputs
for ntype, h in x.items():
if self._layer_norms is not None:
h = self._layer_norms[layer_idx](h)
if self._activation is not None:
h = self._activation(h)
x[ntype] = self._dropout(h)
return x
def forward(
self,
hg: Union[dgl.DGLHeteroGraph, List[dgl.DGLHeteroGraph]],
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
x = {ntype: self._input_dropout(h) for ntype, h in inputs.items()}
if isinstance(hg, list):
for i, (layer, block) in enumerate(zip(self._layers, hg)):
x = layer(block, x)
if i < self._num_layers - 1:
x = self._apply_layers(i, x)
else:
for i, layer in enumerate(self._layers):
x = layer(hg, x)
if i < self._num_layers - 1:
x = self._apply_layers(i, x)
return x
def inference(
self,
hg: dgl.DGLHeteroGraph,
batch_size: int,
num_workers: int,
embedding_layer: nn.Module,
device: torch.device,
) -> Dict[str, torch.Tensor]:
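        # Layer-wise offline inference: compute representations for all nodes one
        # layer at a time over full 1-hop neighborhoods, instead of the sampled
        # multi-hop neighborhoods used during minibatch training.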
for i, layer in enumerate(self._layers):
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
hg,
{ntype: hg.nodes(ntype) for ntype in hg.ntypes},
sampler,
batch_size=batch_size,
shuffle=False,
drop_last=False,
num_workers=num_workers,
)
if i < self._num_layers - 1:
y = {ntype: torch.zeros(hg.num_nodes(
ntype), self._hidden_feats, device=device) for ntype in hg.ntypes}
else:
y = {ntype: torch.zeros(hg.num_nodes(
ntype), self._out_feats, device=device) for ntype in hg.ntypes}
for in_nodes, out_nodes, blocks in dataloader:
in_nodes = {rel: nid.to(device)
for rel, nid in in_nodes.items()}
out_nodes = {rel: nid.to(device)
for rel, nid in out_nodes.items()}
block = blocks[0].to(device)
if i == 0:
h = embedding_layer(in_nodes=in_nodes, device=device)
else:
h = {ntype: x[ntype][in_nodes[ntype]]
for ntype in hg.ntypes}
h = layer(block, h)
if i < self._num_layers - 1:
h = self._apply_layers(i, h)
for ntype in out_nodes:
y[ntype][out_nodes[ntype]] = h[ntype]
x = y
return x
from copy import deepcopy
from typing import Dict, List, Tuple, Union
import dgl
import torch
import torch.nn as nn
from ogb.nodeproppred import DglNodePropPredDataset
class Callback:
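    # Tracks per-epoch metrics, snapshots the best model's state_dict, and
    # signals early stopping once the monitored metric has not improved for
    # `patience` consecutive epochs.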
def __init__(
self,
patience: int,
monitor: str,
) -> None:
self._patience = patience
self._monitor = monitor
self._lookback = 0
self._best_epoch = None
self._train_times = []
self._valid_times = []
self._train_losses = []
self._valid_losses = []
self._train_accuracies = []
self._valid_accuracies = []
self._model_parameters = {}
@property
def best_epoch(self) -> int:
return self._best_epoch + 1
@property
def train_times(self) -> List[float]:
return self._train_times
@property
def valid_times(self) -> List[float]:
return self._valid_times
@property
def train_losses(self) -> List[float]:
return self._train_losses
@property
def valid_losses(self) -> List[float]:
return self._valid_losses
@property
def train_accuracies(self) -> List[float]:
return self._train_accuracies
@property
def valid_accuracies(self) -> List[float]:
return self._valid_accuracies
@property
def best_epoch_training_time(self) -> float:
return sum(self._train_times[:self._best_epoch])
@property
def best_epoch_train_loss(self) -> float:
return self._train_losses[self._best_epoch]
@property
def best_epoch_valid_loss(self) -> float:
return self._valid_losses[self._best_epoch]
@property
def best_epoch_train_accuracy(self) -> float:
return self._train_accuracies[self._best_epoch]
@property
def best_epoch_valid_accuracy(self) -> float:
return self._valid_accuracies[self._best_epoch]
@property
def best_epoch_model_parameters(
self) -> Union[Dict[str, torch.Tensor], Dict[str, Dict[str, torch.Tensor]]]:
return self._model_parameters
@property
def should_stop(self) -> bool:
return self._lookback >= self._patience
def create(
self,
epoch: int,
train_time: float,
valid_time: float,
train_loss: float,
valid_loss: float,
train_accuracy: float,
valid_accuracy: float,
model: Union[nn.Module, Dict[str, nn.Module]],
) -> None:
self._train_times.append(train_time)
self._valid_times.append(valid_time)
self._train_losses.append(train_loss)
self._valid_losses.append(valid_loss)
self._train_accuracies.append(train_accuracy)
self._valid_accuracies.append(valid_accuracy)
best_epoch = False
if self._best_epoch is None:
best_epoch = True
elif self._monitor == 'loss':
if valid_loss < self._valid_losses[self._best_epoch]:
best_epoch = True
elif self._monitor == 'accuracy':
if valid_accuracy > self._valid_accuracies[self._best_epoch]:
best_epoch = True
if best_epoch:
self._best_epoch = epoch
if isinstance(model, dict):
for name, current_model in model.items():
self._model_parameters[name] = deepcopy(
current_model.to('cpu').state_dict())
else:
self._model_parameters = deepcopy(model.to('cpu').state_dict())
self._lookback = 0
else:
self._lookback += 1
class OGBDataset:
def __init__(
self,
g: Union[dgl.DGLGraph, dgl.DGLHeteroGraph],
num_labels: int,
predict_category: str = None,
) -> None:
self._g = g
self._num_labels = num_labels
self._predict_category = predict_category
@property
def num_labels(self) -> int:
return self._num_labels
@property
def num_classes(self) -> int:
return self._num_labels
@property
def predict_category(self) -> str:
return self._predict_category
def __getitem__(self, idx: int) -> Union[dgl.DGLGraph, dgl.DGLHeteroGraph]:
return self._g
def load_ogbn_mag(root: str = None) -> OGBDataset:
dataset = DglNodePropPredDataset(name='ogbn-mag', root=root)
split_idx = dataset.get_idx_split()
train_idx = split_idx['train']['paper']
valid_idx = split_idx['valid']['paper']
test_idx = split_idx['test']['paper']
hg_original, labels = dataset[0]
labels = labels['paper'].squeeze()
num_labels = dataset.num_classes
subgraphs = {}
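    # Add an explicit 'rev-<relation>' edge type for every canonical edge type
    # so that message passing can also flow against the original edge direction.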
for etype in hg_original.canonical_etypes:
src, dst = hg_original.all_edges(etype=etype)
subgraphs[etype] = (src, dst)
subgraphs[(etype[2], f'rev-{etype[1]}', etype[0])] = (dst, src)
hg = dgl.heterograph(subgraphs)
hg.nodes['paper'].data['feat'] = hg_original.nodes['paper'].data['feat']
hg.nodes['paper'].data['labels'] = labels
train_mask = torch.zeros((hg.num_nodes('paper'),), dtype=torch.bool)
train_mask[train_idx] = True
valid_mask = torch.zeros((hg.num_nodes('paper'),), dtype=torch.bool)
valid_mask[valid_idx] = True
test_mask = torch.zeros((hg.num_nodes('paper'),), dtype=torch.bool)
test_mask[test_idx] = True
hg.nodes['paper'].data['train_mask'] = train_mask
hg.nodes['paper'].data['valid_mask'] = valid_mask
hg.nodes['paper'].data['test_mask'] = test_mask
ogb_dataset = OGBDataset(hg, num_labels, 'paper')
return ogb_dataset
def process_dataset(
name: str,
root: str = None,
) -> Tuple[OGBDataset, dgl.DGLHeteroGraph, torch.Tensor, torch.Tensor, torch.Tensor]:
if root is None:
root = 'datasets'
if name == 'ogbn-mag':
dataset = load_ogbn_mag(root=root)
g = dataset[0]
predict_category = dataset.predict_category
train_idx = torch.nonzero(
g.nodes[predict_category].data['train_mask'], as_tuple=True)[0]
valid_idx = torch.nonzero(
g.nodes[predict_category].data['valid_mask'], as_tuple=True)[0]
test_idx = torch.nonzero(
g.nodes[predict_category].data['test_mask'], as_tuple=True)[0]
return dataset, g, train_idx, valid_idx, test_idx
......@@ -360,6 +360,13 @@ class HeteroEmbedding(nn.Module):
"""
return {self.raw_keys[typ]: emb.weight for typ, emb in self.embeds.items()}
def reset_parameters(self):
"""
Use the xavier method in nn.init module to make the parameters uniformly distributed
"""
for typ in self.embeds.keys():
nn.init.xavier_uniform_(self.embeds[typ].weight)
def forward(self, input_ids):
"""Forward function
......
......@@ -1262,6 +1262,11 @@ def test_hetero_embedding(out_dim):
assert embeds['user'].shape == (2, out_dim)
assert embeds[('user', 'follows', 'user')].shape == (3, out_dim)
layer.reset_parameters()
embeds = layer.weight
assert embeds['user'].shape == (2, out_dim)
assert embeds[('user', 'follows', 'user')].shape == (3, out_dim)
embeds = layer({
'user': F.tensor([0], dtype=F.int64),
('user', 'follows', 'user'): F.tensor([0, 2], dtype=F.int64)
......