Unverified Commit 0227ddfb authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[NN] Rework RelGraphConv and HGTConv (#3742)

* WIP: TypedLinear and new RelGraphConv

* wip

* further simplify RGCN

* a bunch of tweak for performance; add basic cpu support

* update on segmm

* wip: segment.cu

* new backward kernel works

* fix a bunch of bugs in kernel; leave idx_a for future

* add nn test for typed_linear

* rgcn nn test

* bugfix in corner case; update RGCN README

* doc

* fix cpp lint

* fix lint

* fix ut

* wip: hgtconv; presorted flag for rgcn

* hgt code and ut; WIP: some fix on reorder graph

* better typed linear init

* fix ut

* fix lint; add docstring
parent 4f00d5ac
......@@ -16,22 +16,19 @@ class RGCN(nn.Module):
num_rels,
num_bases,
num_hidden_layers,
dropout,
lowmem):
dropout):
super(RGCN, self).__init__()
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConv(num_nodes, n_hidden, num_rels, "basis",
num_bases, activation=F.relu, dropout=dropout,
low_mem=lowmem))
num_bases, activation=F.relu, dropout=dropout))
# h2h
for i in range(num_hidden_layers):
self.layers.append(RelGraphConv(n_hidden, n_hidden, num_rels, "basis",
num_bases, activation=F.relu, dropout=dropout,
low_mem=lowmem))
num_bases, activation=F.relu, dropout=dropout))
# o2h
self.layers.append(RelGraphConv(n_hidden, num_classes, num_rels, "basis",
num_bases, activation=None, low_mem=lowmem))
num_bases, activation=None))
def forward(self, g, h, r, norm):
for layer in self.layers:
......@@ -40,9 +37,8 @@ class RGCN(nn.Module):
@utils.benchmark('time', 300)
@utils.parametrize('data', ['aifb'])
@utils.parametrize('lowmem', [True, False])
@utils.parametrize('use_type_count', [True, False])
def track_time(data, lowmem, use_type_count):
def track_time(data, use_type_count):
# args
if data == 'aifb':
num_bases = -1
......@@ -108,8 +104,7 @@ def track_time(data, lowmem, use_type_count):
num_rels,
num_bases,
0,
0,
lowmem).to(device)
0).to(device)
optimizer = torch.optim.Adam(model.parameters(),
lr=1e-2,
......
......@@ -295,7 +295,7 @@ TransR
:members: rel_emb, rel_project, forward, reset_parameters
:show-inheritance:
Heterogeneous Graph Convolution Module
Heterogeneous Learning Module
----------------------------------------
HeteroGraphConv
......@@ -319,9 +319,17 @@ HeteroEmbedding
.. _apinn-pytorch-util:
Utility Modules
----------------------------------------
TypedLinear
----------------------------------------
.. autoclass:: dgl.nn.pytorch.TypedLinear
:members: forward
:show-inheritance:
Sequential
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -18,57 +18,36 @@ pip install rdflib pandas
Example code was tested with rdflib 4.2.2 and pandas 0.23.4
### Entity Classification
AIFB: accuracy 96.29% (3 runs, DGL), 95.83% (paper)
```
python entity.py -d aifb --l2norm 0 --gpu 0
```
MUTAG: accuracy 72.55% (3 runs, DGL), 73.23% (paper)
For AIFB, MUTAG, BGS and AM,
```
python entity.py -d aifb --wd 0 --gpu 0
python entity.py -d mutag --n-bases 30 --gpu 0
```
BGS: accuracy 89.70% (3 runs, DGL), 83.10% (paper)
```
python entity.py -d bgs --n-bases 40 --gpu 0
```
AM: accuracy 89.56% (3 runs, DGL), 89.29% (paper)
```
python entity.py -d am --n-bases 40 --n-hidden 10
python entity.py -d am --n-bases 40 --n-hidden 10 --gpu 0
```
### Entity Classification with minibatch
AIFB: accuracy avg(5 runs) 91.10%, best 97.22% (DGL)
For AIFB, MUTAG, BGS and AM,
```
python entity_sample.py -d aifb --l2norm 0 --gpu 0 --fanout='20,20' --batch-size 128
python entity_sample.py -d aifb --wd 0 --gpu 0 --fanout='20,20' --batch-size 128
python entity_sample.py -d mutag --n-bases 30 --gpu 0 --batch-size 64 --fanout='-1,-1' --use-self-loop --n-epochs 20 --dropout 0.5
python entity_sample.py -d bgs --n-bases 40 --gpu 0 --fanout='-1,-1' --n-epochs=16 --batch-size=16 --dropout 0.3
python entity_sample.py -d am --n-bases 40 --gpu 0 --fanout='35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --dropout 0.7
```
MUTAG: accuracy avg(10 runs) 66.47%, best 72.06% (DGL)
```
python entity_sample.py -d mutag --n-bases 30 --gpu 0 --batch-size 64 --fanout "-1, -1" --use-self-loop --n-epochs 20 --sparse-lr 0.01 --dropout 0.5
```
BGS: accuracy avg(5 runs) 84.83%, best 89.66% (DGL)
```
python entity_sample.py -d bgs --n-bases 40 --gpu 0 --fanout "-1, -1" --n-epochs=16 --batch-size=16 --sparse-lr 0.05 --dropout 0.3
```
AM: accuracy avg(5 runs) 88.58%, best 89.90% (DGL)
```
python entity_sample.py -d am --n-bases 40 --gpu 0 --fanout '35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --sparse-lr 0.02 --dropout 0.7
```
### Entity Classification on multiple GPUs
To use multiple GPUs, replace `entity_sample.py` with `entity_sample_multi_gpu.py` and specify
multiple GPU IDs separated by comma, e.g., `--gpu 0,1`.
### Link Prediction
FB15k-237: MRR 0.163 (DGL), 0.158 (paper)
FB15k-237 in RAW-MRR
```
python link.py --gpu 0 --eval-protocol raw
```
FB15k-237: Filtered-MRR 0.247
FB15k-237 in Filtered-MRR
```
python link.py --gpu 0 --eval-protocol filtered
```
"""
Differences compared to tkipf/relation-gcn
* l2norm applied to all weights
* remove nodes that won't be touched
* weight decay applied to all weights
"""
import argparse
import torch as th
import torch.nn.functional as F
......@@ -17,13 +15,7 @@ def main(args):
g, num_rels, num_classes, labels, train_idx, test_idx, target_idx = load_data(
args.dataset, get_norm=True)
num_nodes = g.num_nodes()
# Since the nodes are featureless, learn node embeddings from scratch
# This requires passing the node IDs to the model.
feats = th.arange(num_nodes)
model = RGCN(num_nodes,
model = RGCN(g.num_nodes(),
args.n_hidden,
num_classes,
num_rels,
......@@ -33,16 +25,15 @@ def main(args):
device = th.device(args.gpu)
else:
device = th.device('cpu')
feats = feats.to(device)
labels = labels.to(device)
model = model.to(device)
g = g.to(device)
g = g.int().to(device)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)
model.train()
for epoch in range(50):
logits = model(g, feats)
for epoch in range(100):
logits = model(g)
logits = logits[target_idx]
loss = F.cross_entropy(logits[train_idx], labels[train_idx])
optimizer.zero_grad()
......@@ -56,7 +47,7 @@ def main(args):
model.eval()
with th.no_grad():
logits = model(g, feats)
logits = model(g)
logits = logits[target_idx]
test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()
print("Test Accuracy: {:.4f}".format(test_acc))
......@@ -72,8 +63,8 @@ if __name__ == '__main__':
parser.add_argument("-d", "--dataset", type=str, required=True,
choices=['aifb', 'mutag', 'bgs', 'am'],
help="dataset to use")
parser.add_argument("--l2norm", type=float, default=5e-4,
help="l2 norm coef")
parser.add_argument("--wd", type=float, default=5e-4,
help="weight decay")
args = parser.parse_args()
print(args)
......
"""
Differences compared to tkipf/relation-gcn
* l2norm applied to all weights
* weight decay applied to all weights
* remove nodes that won't be touched
"""
import argparse
......@@ -13,7 +13,7 @@ from torchmetrics.functional import accuracy
from tqdm import tqdm
from entity_utils import load_data
from model import RelGraphEmbedLayer, RGCN
from model import RGCN
def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=False):
fanouts = [int(fanout) for fanout in args.fanout.split(',')]
......@@ -54,21 +54,6 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F
return train_loader, val_loader, test_loader
def init_models(args, device, num_nodes, num_classes, num_rels):
    """Create the node-embedding layer and the RGCN model for sampled training.

    Parameters
    ----------
    args
        Parsed command-line arguments; ``n_hidden``, ``n_bases``, ``dropout``
        and ``use_self_loop`` are read from it.
    device
        Device handed to the embedding layer as its output device.
    num_nodes : int
        Number of nodes to create embeddings for.
    num_classes : int
        Output size of the RGCN model.
    num_rels : int
        Number of relation types.

    Returns
    -------
    tuple
        ``(embed_layer, model)``.
    """
    hidden = args.n_hidden
    # The embedding layer is built first, then the model that consumes its
    # ``hidden``-sized output.
    embed_layer = RelGraphEmbedLayer(device, num_nodes, hidden)
    model = RGCN(hidden, hidden, num_classes, num_rels,
                 num_bases=args.n_bases,
                 dropout=args.dropout,
                 self_loop=args.use_self_loop)
    return embed_layer, model
def process_batch(inv_target, batch):
_, seeds, blocks = batch
# map the seed nodes back to their type-specific ids,
......@@ -80,38 +65,32 @@ def process_batch(inv_target, batch):
return seeds, blocks
def train(model, embed_layer, train_loader, inv_target,
labels, emb_optimizer, optimizer):
def train(model, train_loader, inv_target,
labels, optimizer):
model.train()
embed_layer.train()
for sample_data in train_loader:
seeds, blocks = process_batch(inv_target, sample_data)
feats = embed_layer(blocks[0].srcdata[dgl.NID].cpu())
logits = model(blocks, feats)
logits = model.forward(blocks)
loss = F.cross_entropy(logits, labels[seeds])
emb_optimizer.zero_grad()
optimizer.zero_grad()
optimizer.zero_grad()
loss.backward()
emb_optimizer.step()
optimizer.step()
train_acc = accuracy(logits.argmax(dim=1), labels[seeds]).item()
return train_acc, loss.item()
def evaluate(model, embed_layer, eval_loader, inv_target):
def evaluate(model, eval_loader, inv_target):
model.eval()
embed_layer.eval()
eval_logits = []
eval_seeds = []
with th.no_grad():
for sample_data in tqdm(eval_loader):
seeds, blocks = process_batch(inv_target, sample_data)
feats = embed_layer(blocks[0].srcdata[dgl.NID].cpu())
logits = model(blocks, feats)
logits = model.forward(blocks)
eval_logits.append(logits.cpu().detach())
eval_seeds.append(seeds.cpu().detach())
......@@ -131,26 +110,30 @@ def main(args):
train_loader, val_loader, test_loader = init_dataloaders(
args, g, train_idx, test_idx, target_idx, args.gpu)
embed_layer, model = init_models(args, device, g.num_nodes(), num_classes, num_rels)
model = RGCN(g.num_nodes(),
args.n_hidden,
num_classes,
num_rels,
num_bases=args.n_bases,
dropout=args.dropout,
self_loop=args.use_self_loop,
ns_mode=True)
labels = labels.to(device)
model = model.to(device)
emb_optimizer = th.optim.SparseAdam(embed_layer.parameters(), lr=args.sparse_lr)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)
for epoch in range(args.n_epochs):
train_acc, loss = train(model, embed_layer, train_loader, inv_target,
labels, emb_optimizer, optimizer)
train_acc, loss = train(model, train_loader, inv_target, labels, optimizer)
print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
epoch, args.n_epochs, train_acc, loss))
val_logits, val_seeds = evaluate(model, embed_layer, val_loader, inv_target)
val_logits, val_seeds = evaluate(model, val_loader, inv_target)
val_acc = accuracy(val_logits.argmax(dim=1), labels[val_seeds].cpu()).item()
print("Validation Accuracy: {:.4f}".format(val_acc))
test_logits, test_seeds = evaluate(model, embed_layer,
test_loader, inv_target)
test_logits, test_seeds = evaluate(model, test_loader, inv_target)
test_acc = accuracy(test_logits.argmax(dim=1), labels[test_seeds].cpu()).item()
print("Final Test Accuracy: {:.4f}".format(test_acc))
......@@ -162,8 +145,6 @@ if __name__ == '__main__':
help="number of hidden units")
parser.add_argument("--gpu", type=int, default=0,
help="gpu")
parser.add_argument("--sparse-lr", type=float, default=2e-2,
help="sparse embedding learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-epochs", type=int, default=50,
......@@ -171,8 +152,8 @@ if __name__ == '__main__':
parser.add_argument("-d", "--dataset", type=str, required=True,
choices=['aifb', 'mutag', 'bgs', 'am'],
help="dataset to use")
parser.add_argument("--l2norm", type=float, default=5e-4,
help="l2 norm coef")
parser.add_argument("--wd", type=float, default=5e-4,
help="weight decay")
parser.add_argument("--fanout", type=str, default="4, 4",
help="Fan-out of neighbor sampling")
parser.add_argument("--use-self-loop", default=False, action='store_true',
......
"""
Differences compared to tkipf/relation-gcn
* l2norm applied to all weights
* remove nodes that won't be touched
* weight decay applied to all weights
"""
import argparse
import gc
......@@ -14,7 +13,8 @@ from torchmetrics.functional import accuracy
from torch.nn.parallel import DistributedDataParallel
from entity_utils import load_data
from entity_sample import init_dataloaders, init_models, train, evaluate
from entity_sample import init_dataloaders, train, evaluate
from model import RGCN
def collect_eval(n_gpus, queue, labels):
eval_logits = []
......@@ -48,21 +48,25 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None):
use_ddp = True if n_gpus > 1 else False
train_loader, val_loader, test_loader = init_dataloaders(
args, g, train_idx, test_idx, target_idx, dev_id, use_ddp=use_ddp)
embed_layer, model = init_models(args, device, g.num_nodes(), num_classes, num_rels)
model = RGCN(g.num_nodes(),
args.n_hidden,
num_classes,
num_rels,
num_bases=args.n_bases,
dropout=args.dropout,
self_loop=args.use_self_loop,
ns_mode=True)
labels = labels.to(device)
model = model.to(device)
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
embed_layer = DistributedDataParallel(embed_layer, device_ids=None, output_device=None)
emb_optimizer = th.optim.SparseAdam(embed_layer.module.parameters(), lr=args.sparse_lr)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)
th.set_num_threads(n_cpus)
for epoch in range(args.n_epochs):
train_loader.set_epoch(epoch)
train_acc, loss = train(model, embed_layer, train_loader, inv_target,
labels, emb_optimizer, optimizer)
train_acc, loss = train(model, train_loader, inv_target,
labels, optimizer)
if proc_id == 0:
print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
......@@ -71,7 +75,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None):
# garbage collection that empties the queue
gc.collect()
val_logits, val_seeds = evaluate(model, embed_layer, val_loader, inv_target)
val_logits, val_seeds = evaluate(model, val_loader, inv_target)
queue.put((val_logits, val_seeds))
# gather evaluation result from multiple processes
......@@ -81,7 +85,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None):
# garbage collection that empties the queue
gc.collect()
test_logits, test_seeds = evaluate(model, embed_layer, test_loader, inv_target)
test_logits, test_seeds = evaluate(model, test_loader, inv_target)
queue.put((test_logits, test_seeds))
if proc_id == 0:
test_acc = collect_eval(n_gpus, queue, labels)
......@@ -119,8 +123,6 @@ if __name__ == '__main__':
help="number of hidden units")
parser.add_argument("--gpu", type=str, default='0',
help="gpu")
parser.add_argument("--sparse-lr", type=float, default=2e-2,
help="sparse embedding learning rate")
parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-epochs", type=int, default=50,
......@@ -128,8 +130,8 @@ if __name__ == '__main__':
parser.add_argument("-d", "--dataset", type=str, required=True,
choices=['aifb', 'mutag', 'bgs', 'am'],
help="dataset to use")
parser.add_argument("--l2norm", type=float, default=5e-4,
help="l2 norm coef")
parser.add_argument("--wd", type=float, default=5e-4,
help="weight decay")
parser.add_argument("--fanout", type=str, default="4, 4",
help="Fan-out of neighbor sampling")
parser.add_argument("--use-self-loop", default=False, action='store_true',
......
......@@ -20,7 +20,8 @@ class LinkPredict(nn.Module):
def __init__(self, in_dim, num_rels, h_dim=500, num_bases=100, dropout=0.2, reg_param=0.01):
super(LinkPredict, self).__init__()
self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, regularizer="bdd",
num_bases=num_bases, dropout=dropout, self_loop=True, link_pred=True)
num_bases=num_bases, dropout=dropout, self_loop=True)
self.dropout = nn.Dropout(dropout)
self.reg_param = reg_param
self.w_relation = nn.Parameter(th.Tensor(num_rels, h_dim))
nn.init.xavier_uniform_(self.w_relation,
......@@ -34,8 +35,8 @@ class LinkPredict(nn.Module):
score = th.sum(s * r * o, dim=1)
return score
def forward(self, g, h):
return self.rgcn(g, h)
def forward(self, g, nids):
return self.dropout(self.rgcn(g, nids=nids))
def regularization_loss(self, embedding):
return th.mean(embedding.pow(2)) + th.mean(self.w_relation.pow(2))
......@@ -54,7 +55,7 @@ def main(args):
num_rels = data.num_rels
train_g, test_g = preprocess(graph, num_rels)
test_node_id = th.arange(0, num_nodes).view(-1, 1)
test_nids = th.arange(0, num_nodes)
test_mask = graph.edata['test_mask']
subg_iter = SubgraphIterator(train_g, num_rels, args.edge_sampler)
dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])
......@@ -77,14 +78,14 @@ def main(args):
for epoch, batch_data in enumerate(dataloader):
model.train()
g, node_id, data, labels = batch_data
g, train_nids, edges, labels = batch_data
g = g.to(device)
node_id = node_id.to(device)
data = data.to(device)
train_nids = train_nids.to(device)
edges = edges.to(device)
labels = labels.to(device)
embed = model(g, node_id)
loss = model.get_loss(embed, data, labels)
embed = model(g, train_nids)
loss = model.get_loss(embed, edges, labels)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # clip gradients
......@@ -97,7 +98,7 @@ def main(args):
model = model.cpu()
model.eval()
print("start eval")
embed = model(test_g, test_node_id)
embed = model(test_g, test_nids)
mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
batch_size=500, eval_p=args.eval_protocol)
# save best model
......@@ -114,7 +115,7 @@ def main(args):
model.eval()
model.load_state_dict(checkpoint['state_dict'])
print("Using best epoch: {}".format(checkpoint['epoch']))
embed = model(test_g, test_node_id)
embed = model(test_g, test_nids)
calc_mrr(embed, model.w_relation, test_mask, triplets,
batch_size=500, eval_p=args.eval_protocol)
......
......@@ -158,7 +158,7 @@ class SubgraphIterator:
sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
sub_g.edata[dgl.ETYPE] = th.from_numpy(rel)
sub_g.edata['norm'] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
uniq_v = th.from_numpy(uniq_v).view(-1, 1).long()
uniq_v = th.from_numpy(uniq_v).view(-1).long()
return sub_g, uniq_v, samples, labels
......
......@@ -7,81 +7,32 @@ import dgl
from dgl.nn.pytorch import RelGraphConv
class RGCN(nn.Module):
def __init__(self, in_dim, h_dim, out_dim, num_rels,
def __init__(self, num_nodes, h_dim, out_dim, num_rels,
regularizer="basis", num_bases=-1, dropout=0.,
self_loop=False, link_pred=False):
self_loop=False,
ns_mode=False):
super(RGCN, self).__init__()
self.layers = nn.ModuleList()
if link_pred:
self.emb = nn.Embedding(in_dim, h_dim)
in_dim = h_dim
else:
self.emb = None
self.layers.append(RelGraphConv(in_dim, h_dim, num_rels, regularizer,
num_bases, activation=F.relu, self_loop=self_loop,
dropout=dropout))
# For entity classification, dropout should not be applied to the output layer
if not link_pred:
dropout = 0.
self.layers.append(RelGraphConv(h_dim, out_dim, num_rels, regularizer,
num_bases, self_loop=self_loop, dropout=dropout))
def forward(self, g, h):
if isinstance(g, DGLGraph):
blocks = [g] * len(self.layers)
if num_bases == -1:
num_bases = num_rels
self.emb = nn.Embedding(num_nodes, h_dim)
self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer,
num_bases, self_loop=self_loop)
self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer, num_bases, self_loop=self_loop)
self.dropout = nn.Dropout(dropout)
self.ns_mode = ns_mode
def forward(self, g, nids=None):
if self.ns_mode:
# forward for neighbor sampling
x = self.emb(g[0].srcdata[dgl.NID])
h = self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm'])
h = self.dropout(F.relu(h))
h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
return h
else:
blocks = g
if self.emb is not None:
h = self.emb(h.squeeze())
for layer, block in zip(self.layers, blocks):
h = layer(block, h, block.edata[dgl.ETYPE], block.edata['norm'])
x = self.emb.weight if nids is None else self.emb(nids)
h = self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm'])
h = self.dropout(F.relu(h))
h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
return h
def initializer(emb):
    """Fill ``emb`` in place with samples from U(-1, 1) and return it."""
    # ``uniform_`` mutates its receiver and hands the same tensor back.
    return emb.uniform_(-1.0, 1.0)
class RelGraphEmbedLayer(nn.Module):
    """Learnable node-embedding table for a featureless heterograph.

    Parameters
    ----------
    out_dev
        Device the looked-up embeddings are moved to before being returned.
    num_nodes : int
        Number of rows in the embedding table (one per node).
    embed_size : int
        Size of each embedding vector.
    """

    def __init__(self, out_dev, num_nodes, embed_size):
        super().__init__()
        self.out_dev = out_dev
        self.embed_size = embed_size

        # One embedding row per node, stored with sparse gradients.
        table = nn.Embedding(num_nodes, embed_size, sparse=True)
        self.node_embed = table
        nn.init.uniform_(table.weight, -1.0, 1.0)

    def forward(self, node_ids):
        """Look up the embeddings of the given nodes.

        Parameters
        ----------
        node_ids : tensor
            Raw node IDs.

        Returns
        -------
        tensor
            The selected embedding rows, moved to ``out_dev``.
        """
        looked_up = self.node_embed(node_ids)
        return looked_up.to(self.out_dev)
......@@ -2,7 +2,8 @@ import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero, _gather_mm, _gather_mm_scatter, _segment_mm
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...sparse import _gather_mm, _gather_mm_scatter, _segment_mm, _segment_mm_backward_B
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _edge_softmax_forward, _edge_softmax_backward
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...heterograph_index import create_unitgraph_from_csr
......@@ -697,22 +698,16 @@ class SEGMENTMM(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, seglen_A):
if A.shape[0] != th.sum(seglen_A):
raise Exception("The summation of the elements of seglen_A must be equal to " +
"dimension 0 of A. Expected "+ str(A.shape[0]) + "got" + str(th.sum(seglen_A)))
if B.dim() != 3:
raise Exception("Expected dimension of B is 3. Got " + str(B.dim()))
# Reshaping B form 3D to 2D
B_3D_shape = B.shape
B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
raise ValueError("segment_mm expects B to be a 3D tensor.")
C = th.zeros((A.shape[0], B.shape[2]), device=A.device, dtype=A.dtype)
C = _segment_mm(A, B, C, seglen_A)
ctx.backward_cache = A, B, seglen_A, B_3D_shape
ctx.backward_cache = A, B, seglen_A
return C
@staticmethod
def backward(ctx, dZ):
A, B, seglen_A, B_3D_shape = ctx.backward_cache
A, B, seglen_A = ctx.backward_cache
A_grad = B_grad = None
if ctx.needs_input_grad[0]:
# Compute A_grad = Out_grad * B^T
......@@ -721,9 +716,8 @@ class SEGMENTMM(th.autograd.Function):
if ctx.needs_input_grad[1]:
# Compute B_grad = A^T * Out_grad
B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
B_grad = _segment_mm(A, dZ, B_grad, seglen_A, a_trans=True)
B_grad = B_grad.reshape(B_3D_shape[0], B_3D_shape[1], B_3D_shape[2])
return A_grad, B_grad, None, None, None, None, None, None
B_grad = _segment_mm_backward_B(A, dZ, B_grad, seglen_A)
return A_grad, B_grad, None
class GATHERMM(th.autograd.Function):
......@@ -731,31 +725,27 @@ class GATHERMM(th.autograd.Function):
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, A, B, idx_a, idx_b):
if B.dim() != 3:
raise Exception("Expected dimension of B is 3. Got " + str(B.dim()))
# Reshaping B form 3D to 2D
B_3D_shape = B.shape
B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
C = _gather_mm(A, B, C, B_3D_shape[0], idx_a, idx_b)
ctx.backward_cache = A, B, idx_a, idx_b, B_3D_shape
raise ValueError("Expected dimension of B is 3. Got " + str(B.dim()))
N = len(idx_b) if idx_a is None else len(idx_a)
C = th.zeros((N, B.shape[2]), device=A.device, dtype=A.dtype)
C = _gather_mm(A, B, C, idx_a, idx_b)
ctx.backward_cache = A, B, idx_a, idx_b
return C
@staticmethod
def backward(ctx, dZ):
A, B, idx_a, idx_b, B_3D_shape = ctx.backward_cache
A, B, idx_a, idx_b = ctx.backward_cache
A_grad = B_grad = None
if ctx.needs_input_grad[0]:
# Compute A_grad = Out_grad * B^T
A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
A_grad = _gather_mm_scatter(dZ, B, A_grad, B_3D_shape[0],
idx_b=idx_b, idx_c=idx_a, b_trans=True)
A_grad = _gather_mm_scatter(dZ, B.transpose(1, 2), A_grad,
idx_b=idx_b, idx_c=idx_a)
if ctx.needs_input_grad[1]:
# Compute B_grad = A^T * Out_grad
B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
B_grad = _gather_mm_scatter(A, dZ, B_grad, B_3D_shape[0],
idx_a=idx_a, idx_c=idx_b)
B_grad = B_grad.reshape(B_3D_shape[0], B_3D_shape[1], B_3D_shape[2])
return A_grad, B_grad, None, None, None, None, None, None
B_grad = _gather_mm_scatter(A, dZ, B_grad, idx_a=idx_a, idx_c=idx_b)
return A_grad, B_grad, None, None
def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
if op == 'sub':
......@@ -834,7 +824,20 @@ def csrmask(gidxA, A_weights, gidxB):
return CSRMask.apply(gidxA, A_weights, gidxB)
def segment_mm(A, B, seglen_A):
    """Multiply consecutive row segments of ``A`` by the matching matrix in ``B``.

    Segment ``i`` consists of the next ``seglen_A[i]`` rows of ``A`` and is
    multiplied by ``B[i]``; the per-segment products are concatenated along
    dimension 0.

    Tensors that are not on CPU are dispatched to ``SEGMENTMM``; CPU tensors
    use a plain loop over the segments.
    """
    if A.device.type != 'cpu':
        return SEGMENTMM.apply(A, B, seglen_A)

    pieces = []
    start = 0
    for seg, weight in enumerate(B):
        stop = start + seglen_A[seg]
        pieces.append(A[start:stop] @ weight)
        start = stop
    return th.cat(pieces)
def gather_mm(A, B, idx_a = None, idx_b = None):
return GATHERMM.apply(A, B, idx_a, idx_b)
def gather_mm(A, B, idx_A=None, idx_B=None):
    """Row-wise matrix multiply with optional gathering of both operands.

    When given, ``idx_A`` selects rows of ``A`` and ``idx_B`` selects matrices
    of ``B``; row ``i`` of the result is the ``i``-th selected row of ``A``
    times the ``i``-th selected matrix of ``B``.

    Tensors that are not on CPU are dispatched to ``GATHERMM``; CPU tensors
    use an explicit gather followed by a batched matmul.
    """
    if A.device.type != 'cpu':
        return GATHERMM.apply(A, B, idx_A, idx_B)

    if idx_A is not None:
        A = A[idx_A]
    if idx_B is not None:
        B = B[idx_B]
    # (N, 1, in) x (N, in, out) -> (N, 1, out); drop the singleton row dim.
    product = th.bmm(A.unsqueeze(1), B)
    return product.squeeze(1)
......@@ -2,6 +2,7 @@
from .conv import *
from .explain import *
from .link import *
from .linear import *
from .glob import *
from .softmax import *
from .factory import *
......
......@@ -25,9 +25,10 @@ from .cfconv import CFConv
from .dotgatconv import DotGatConv
from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
from .gcn2conv import GCN2Conv
from .hgtconv import HGTConv
__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
'TWIRLSUnfoldingAndAttention', 'GCN2Conv']
'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv']
"""Heterogeneous Graph Transformer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import torch
import torch.nn as nn
from .... import function as fn
from ..linear import TypedLinear
from ..softmax import edge_softmax
class HGTConv(nn.Module):
    r"""Heterogeneous graph transformer convolution.

    Introduced in "`Heterogeneous Graph Transformer <https://arxiv.org/abs/2003.01332>`__".
    Given a graph :math:`G(V, E)` and input node features :math:`H^{(l-1)}`,
    it computes the new node features as follows:

    Compute a multi-head attention score for each edge :math:`(s, e, t)` in the graph:

    .. math::

      Attention(s, e, t) = \text{Softmax}\left(||_{i\in[1,h]}ATT-head^i(s, e, t)\right) \\
      ATT-head^i(s, e, t) = \left(K^i(s)W^{ATT}_{\phi(e)}Q^i(t)^{\top}\right)\cdot
        \frac{\mu_{(\tau(s),\phi(e),\tau(t))}}{\sqrt{d}} \\
      K^i(s) = \text{K-Linear}^i_{\tau(s)}(H^{(l-1)}[s]) \\
      Q^i(t) = \text{Q-Linear}^i_{\tau(t)}(H^{(l-1)}[t]) \\

    Compute the message to send on each edge :math:`(s, e, t)`:

    .. math::

      Message(s, e, t) = ||_{i\in[1, h]} MSG-head^i(s, e, t) \\
      MSG-head^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s])W^{MSG}_{\phi(e)} \\

    Send messages to target nodes :math:`t` and aggregate:

    .. math::

      \tilde{H}^{(l)}[t] = \sum_{\forall s\in \mathcal{N}(t)}\left( Attention(s,e,t)
      \cdot Message(s,e,t)\right)

    Compute new node features:

    .. math::

      H^{(l)}[t]=\text{A-Linear}_{\tau(t)}(\sigma(\tilde{H}^{(l)}[t])) + H^{(l-1)}[t]

    Parameters
    ----------
    in_size : int
        Input node feature size.
    head_size : int
        Output head size. The output node feature size is ``head_size * num_heads``.
    num_heads : int
        Number of heads. The output node feature size is ``head_size * num_heads``.
    num_ntypes : int
        Number of node types.
    num_etypes : int
        Number of edge types.
    dropout : optional, float
        Dropout rate.
    use_norm : optional, bool
        If true, apply a layer norm on the output node feature.

    Notes
    -----
    In this implementation the last step is a gated skip connection rather than
    a plain sum: with :math:`\alpha = \text{sigmoid}(\text{skip}[\tau(t)])`, the
    output is :math:`\alpha \cdot \text{Dropout}(\text{A-Linear}(\tilde{H}[t]))
    + (1 - \alpha) \cdot H^{(l-1)}[t]`. When ``in_size`` differs from
    ``head_size * num_heads``, the input is first projected by a learnable
    matrix so that the two terms have the same shape. No extra activation is
    applied to the aggregated messages before ``A-Linear``.
    """
    def __init__(self,
                 in_size,
                 head_size,
                 num_heads,
                 num_ntypes,
                 num_etypes,
                 dropout=0.2,
                 use_norm=False):
        super().__init__()
        self.in_size = in_size
        self.head_size = head_size
        self.num_heads = num_heads
        # Scaling denominator of the attention logits.
        self.sqrt_d = math.sqrt(head_size)
        self.use_norm = use_norm

        # Node-type-specific projections (keys, queries, values, output).
        self.linear_k = TypedLinear(in_size, head_size * num_heads, num_ntypes)
        self.linear_q = TypedLinear(in_size, head_size * num_heads, num_ntypes)
        self.linear_v = TypedLinear(in_size, head_size * num_heads, num_ntypes)
        self.linear_a = TypedLinear(head_size * num_heads, head_size * num_heads, num_ntypes)

        # Per-head, edge-type-specific parameters: a scalar prior per edge type
        # and two typed linear maps (attention and message).
        self.relation_pri = nn.ParameterList([nn.Parameter(torch.ones(num_etypes))
                                              for i in range(num_heads)])
        self.relation_att = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes)
                                           for i in range(num_heads)])
        self.relation_msg = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes)
                                           for i in range(num_heads)])
        # One learnable skip-gate logit per node type.
        self.skip = nn.Parameter(torch.ones(num_ntypes))
        self.drop = nn.Dropout(dropout)
        if use_norm:
            self.norm = nn.LayerNorm(head_size * num_heads)
        # Only needed when the residual input must be resized to the output width.
        if in_size != head_size * num_heads:
            self.residual_w = nn.Parameter(torch.Tensor(in_size, head_size * num_heads))
            nn.init.xavier_uniform_(self.residual_w)

    def forward(self, g, x, ntype, etype, *, presorted=False):
        """Forward computation.

        Parameters
        ----------
        g : DGLGraph
            The input graph.
        x : torch.Tensor
            A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.
        ntype : torch.Tensor
            An 1D integer tensor of node types. Shape: :math:`(|V|,)`.
        etype : torch.Tensor
            An 1D integer tensor of edge types. Shape: :math:`(|E|,)`.
        presorted : bool, optional
            Whether *both* the nodes and the edges of the input graph have been sorted by
            their types. Forward on pre-sorted graph may be faster. Graphs created by
            :func:`~dgl.to_homogeneous` automatically satisfy the condition.
            Also see :func:`~dgl.reorder_graph` for manually reordering the nodes and edges.

        Returns
        -------
        torch.Tensor
            New node features. Shape: :math:`(|V|, D_{head} * N_{head})`.
        """
        # message() receives only the edge batch, so the flag is stashed on the
        # module for it to read.
        self.presorted = presorted
        with g.local_scope():
            # Project per node type, then split the feature dim into heads:
            # (|V|, num_heads, head_size).
            k = self.linear_k(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
            q = self.linear_q(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
            v = self.linear_v(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
            g.srcdata['k'] = k
            g.dstdata['q'] = q
            g.srcdata['v'] = v
            g.edata['etype'] = etype
            # Writes the attention logits 'a' and the messages 'm' on every edge.
            g.apply_edges(self.message)
            # Weight each head's message by its normalized attention score.
            g.edata['m'] = g.edata['m'] * edge_softmax(g, g.edata['a']).unsqueeze(-1)
            g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'h'))
            # Merge the heads back into a single feature dimension.
            h = g.dstdata['h'].view(-1, self.num_heads * self.head_size)
            # target-specific aggregation
            h = self.drop(self.linear_a(h, ntype, presorted))
            # Per-node gate in (0, 1), selected by node type.
            alpha = torch.sigmoid(self.skip[ntype]).unsqueeze(-1)
            if x.shape != h.shape:
                # Input and output widths differ: project x before mixing.
                h = h * alpha + (x @ self.residual_w) * (1 - alpha)
            else:
                h = h * alpha + x * (1 - alpha)
            if self.use_norm:
                h = self.norm(h)
            return h

    def message(self, edges):
        """Message function.

        For every head, computes the attention logit and the message of each
        edge from the source keys/values, the destination queries and the
        edge-type-specific linear maps.

        Returns a dict with ``'a'`` (per-head logits stacked along dim 1) and
        ``'m'`` (per-head messages stacked along dim 1).
        """
        a, m = [], []
        etype = edges.data['etype']
        # Split the head dimension into one tensor per head.
        k = torch.unbind(edges.src['k'], dim=1)
        q = torch.unbind(edges.dst['q'], dim=1)
        v = torch.unbind(edges.src['v'], dim=1)
        for i in range(self.num_heads):
            kw = self.relation_att[i](k[i], etype, self.presorted)  # (E, O)
            a.append((kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d)   # (E,)
            m.append(self.relation_msg[i](v[i], etype, self.presorted))  # (E, O)
        return {'a' : torch.stack(a, dim=1), 'm' : torch.stack(m, dim=1)}
"""Torch Module for Relational graph convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import functools
import numpy as np
import torch as th
from torch import nn
from .... import function as fn
from .. import utils
from ....base import DGLError
from .... import edge_subgraph
from ..linear import TypedLinear
class RelGraphConv(nn.Module):
r"""Relational graph convolution layer.
......@@ -55,22 +51,21 @@ class RelGraphConv(nn.Module):
Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
num_rels : int
Number of relations.
regularizer : str
Which weight regularizer to use "basis" or "bdd".
"basis" is short for basis-diagonal-decomposition.
"bdd" is short for block-diagonal-decomposition.
regularizer : str, optional
Which weight regularizer to use "basis" or "bdd":
- "basis" is short for basis-decomposition.
- "bdd" is short for block-diagonal-decomposition.
Default applies no regularization.
num_bases : int, optional
Number of bases. If is none, use number of relations. Default: ``None``.
Number of bases. Needed when ``regularizer`` is specified. Default: ``None``.
bias : bool, optional
True if bias is added. Default: ``True``.
activation : callable, optional
Activation function. Default: ``None``.
self_loop : bool, optional
True to include self loop message. Default: ``True``.
low_mem : bool, optional
True to use low memory implementation of relation message passing function. Default: False.
This option trades speed with memory consumption, and will slowdown the forward/backward.
Turn it on when you encounter OOM problem during training or evaluation. Default: ``False``.
dropout : float, optional
Dropout rate. Default: ``0.0``
layer_norm: float, optional
......@@ -86,9 +81,7 @@ class RelGraphConv(nn.Module):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10)
>>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)
>>> conv.weight.shape
torch.Size([2, 10, 2])
>>> etype = th.tensor(np.array([0,1,2,0,1,2]).astype(np.int64))
>>> etype = th.tensor([0,1,2,0,1,2])
>>> res = conv(g, feat, etype)
>>> res
tensor([[ 0.3996, -2.3303],
......@@ -97,80 +90,32 @@ class RelGraphConv(nn.Module):
[ 2.1046, -2.8654],
[-0.4323, -0.1440],
[-0.1309, -1.0000]], grad_fn=<AddBackward0>)
>>> # One-hot input
>>> one_hot_feat = th.tensor(np.array([0,1,2,3,4,5]).astype(np.int64))
>>> res = conv(g, one_hot_feat, etype)
>>> res
tensor([[ 0.5925, 0.0985],
[-0.3953, 0.8408],
[-0.9819, 0.5284],
[-1.0085, -0.1721],
[ 0.5962, 1.2002],
[ 0.0365, -0.3532]], grad_fn=<AddBackward0>)
"""
def __init__(self,
in_feat,
out_feat,
num_rels,
regularizer="basis",
regularizer=None,
num_bases=None,
bias=True,
activation=None,
self_loop=True,
low_mem=False,
dropout=0.0,
layer_norm=False):
super(RelGraphConv, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.num_rels = num_rels
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
self.num_bases = self.num_rels
super().__init__()
self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases)
self.bias = bias
self.activation = activation
self.self_loop = self_loop
self.low_mem = low_mem
self.layer_norm = layer_norm
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
# message func
self.message_func = self.basis_message_func
elif regularizer == "bdd":
if in_feat % self.num_bases != 0 or out_feat % self.num_bases != 0:
raise ValueError(
'Feature size must be a multiplier of num_bases (%d).'
% self.num_bases
)
# add block diagonal weights
self.submat_in = in_feat // self.num_bases
self.submat_out = out_feat // self.num_bases
# assuming in_feat and out_feat are both divisible by num_bases
self.weight = nn.Parameter(th.Tensor(
self.num_rels, self.num_bases * self.submat_in * self.submat_out))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
# message func
self.message_func = self.bdd_message_func
else:
raise ValueError("Regularizer must be either 'basis' or 'bdd'")
# bias
if self.bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# TODO(minjie): consider remove those options in the future to make
# the module only about graph convolution.
# layer norm
if self.layer_norm:
self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)
......@@ -178,121 +123,18 @@ class RelGraphConv(nn.Module):
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def basis_message_func(self, edges, etypes):
"""Message function for basis regularizer.
Parameters
----------
edges : dgl.EdgeBatch
Input to DGL message UDF.
etypes : torch.Tensor or list[int]
Edge type data. Could be either:
* An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
Preferred format if ``lowmem == False``.
* An integer list. The i^th element is the number of edges of the i^th type.
This requires the input graph to store edges sorted by their type IDs.
Preferred format if ``lowmem == True``.
"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = th.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
h = edges.src['h']
device = h.device
if h.dtype == th.int64 and h.ndim == 1:
# Each element is the node's ID. Use index select: weight[etypes, h, :]
# The following is a faster version of it.
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
idim = weight.shape[1]
weight = weight.view(-1, weight.shape[2])
flatidx = etypes * idim + h
msg = weight.index_select(0, flatidx)
elif self.low_mem:
# A more memory-friendly implementation.
# Calculate msg @ W_r before put msg into edge.
assert isinstance(etypes, list)
h_t = th.split(h, etypes)
msg = []
for etype in range(self.num_rels):
if h_t[etype].shape[0] == 0:
continue
msg.append(th.matmul(h_t[etype], weight[etype]))
msg = th.cat(msg)
else:
# Use batched matmult
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
weight = weight.index_select(0, etypes)
msg = th.bmm(h.unsqueeze(1), weight).squeeze(1)
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
def bdd_message_func(self, edges, etypes):
"""Message function for block-diagonal-decomposition regularizer.
Parameters
----------
edges : dgl.EdgeBatch
Input to DGL message UDF.
etypes : torch.Tensor or list[int]
Edge type data. Could be either:
* An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
Preferred format if ``lowmem == False``.
* An integer list. The i^th element is the number of edges of the i^th type.
This requires the input graph to store edges sorted by their type IDs.
Preferred format if ``lowmem == True``.
"""
h = edges.src['h']
device = h.device
if h.dtype == th.int64 and h.ndim == 1:
raise TypeError('Block decomposition does not allow integer ID feature.')
if self.low_mem:
# A more memory-friendly implementation.
# Calculate msg @ W_r before put msg into edge.
assert isinstance(etypes, list)
h_t = th.split(h, etypes)
msg = []
for etype in range(self.num_rels):
if h_t[etype].shape[0] == 0:
continue
tmp_w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
tmp_h = h_t[etype].view(-1, self.num_bases, self.submat_in)
msg.append(th.einsum('abc,bcd->abd', tmp_h, tmp_w).reshape(-1, self.out_feat))
msg = th.cat(msg)
else:
# Use batched matmult
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
weight = self.weight.index_select(0, etypes).view(
-1, self.submat_in, self.submat_out)
node = h.view(-1, 1, self.submat_in)
msg = th.bmm(node, weight).view(-1, self.out_feat)
def message(self, edges):
"""Message function."""
m = self.linear_r(edges.src['h'], edges.data['etype'], self.presorted)
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
m = m * edges.data['norm']
return {'m' : m}
def forward(self, g, feat, etypes, norm=None):
def forward(self, g, feat, etypes, norm=None, *, presorted=False):
"""Forward computation.
Parameters
......@@ -300,88 +142,39 @@ class RelGraphConv(nn.Module):
g : DGLGraph
The graph.
feat : torch.Tensor
Input node features. Could be either
* :math:`(|V|, D)` dense tensor
* :math:`(|V|,)` int64 vector, representing the categorical values of each
node. It then treat the input feature as an one-hot encoding feature.
A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.
etypes : torch.Tensor or list[int]
Edge type data. Could be either
* An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
Preferred format if ``lowmem == False``.
* An integer list. The i^th element is the number of edges of the i^th type.
This requires the input graph to store edges sorted by their type IDs.
Preferred format if ``lowmem == True``.
An 1D integer tensor of edge types. Shape: :math:`(|E|,)`.
norm : torch.Tensor, optional
Edge normalizer. Could be either
* An :math:`(|E|, 1)` tensor storing the normalizer on each edge.
An 1D tensor of edge norm value. Shape: :math:`(|E|,)`.
presorted : bool, optional
Whether the edges of the input graph have been sorted by their types.
Forward on pre-sorted graph may be faster. Graphs created
by :func:`~dgl.to_homogeneous` automatically satisfy the condition.
Also see :func:`~dgl.reorder_graph` for sorting edges manually.
Returns
-------
torch.Tensor
New node features.
Notes
-----
Under the ``low_mem`` mode, DGL will sort the graph based on the edge types
and compute message passing one type at a time. DGL recommends sorts the
graph beforehand (and cache it if possible) and provides the integer list
format to the ``etypes`` argument. Use DGL's :func:`~dgl.to_homogeneous` API
to get a sorted homogeneous graph from a heterogeneous graph. Pass ``return_count=True``
to it to get the ``etypes`` in integer list.
New node features. Shape: :math:`(|V|, D_{out})`.
"""
if isinstance(etypes, th.Tensor):
if len(etypes) != g.num_edges():
raise DGLError('"etypes" tensor must have length equal to the number of edges'
' in the graph. But got {} and {}.'.format(
len(etypes), g.num_edges()))
if self.low_mem and not (feat.dtype == th.int64 and feat.ndim == 1):
# Low-mem optimization is not enabled for node ID input. When enabled,
# it first sorts the graph based on the edge types (the sorting will not
# change the node IDs). It then converts the etypes tensor to an integer
# list, where each element is the number of edges of the type.
# Sort the graph based on the etypes
sorted_etypes, index = th.sort(etypes)
g = edge_subgraph(g, index, relabel_nodes=False)
# Create a new etypes to be an integer list of number of edges.
pos = _searchsorted(sorted_etypes, th.arange(self.num_rels, device=g.device))
num = th.tensor([len(etypes)], device=g.device)
etypes = (th.cat([pos[1:], num]) - pos).tolist()
if norm is not None:
norm = norm[index]
self.presorted = presorted
with g.local_scope():
g.srcdata['h'] = feat
if norm is not None:
g.edata['norm'] = norm
if self.self_loop:
loop_message = utils.matmul_maybe_select(feat[:g.number_of_dst_nodes()],
self.loop_weight)
g.edata['etype'] = etypes
# message passing
g.update_all(functools.partial(self.message_func, etypes=etypes),
fn.sum(msg='msg', out='h'))
g.update_all(self.message, fn.sum('m', 'h'))
# apply bias and activation
node_repr = g.dstdata['h']
h = g.dstdata['h']
if self.layer_norm:
node_repr = self.layer_norm_weight(node_repr)
h = self.layer_norm_weight(h)
if self.bias:
node_repr = node_repr + self.h_bias
h = h + self.h_bias
if self.self_loop:
node_repr = node_repr + loop_message
h = h + feat[:g.num_dst_nodes()] @ self.loop_weight
if self.activation:
node_repr = self.activation(node_repr)
node_repr = self.dropout(node_repr)
return node_repr
_TORCH_HAS_SEARCHSORTED = getattr(th, 'searchsorted', None)
def _searchsorted(sorted_sequence, values):
# searchsorted is introduced to PyTorch in 1.6.0
if _TORCH_HAS_SEARCHSORTED:
return th.searchsorted(sorted_sequence, values)
else:
device = values.device
return th.from_numpy(np.searchsorted(sorted_sequence.cpu().numpy(),
values.cpu().numpy())).to(device)
h = self.activation(h)
h = self.dropout(h)
return h
"""Various commonly used linear modules"""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import math
import torch
import torch.nn as nn
from ...ops import segment_mm, gather_mm
__all__ = ['TypedLinear']
class TypedLinear(nn.Module):
    r"""Linear transformation according to types.

    For each sample of the input batch :math:`x \in X`, apply linear transformation
    :math:`xW_t`, where :math:`t` is the type of :math:`x`.

    The module supports two regularization methods (basis-decomposition and
    block-diagonal-decomposition) proposed by "`Modeling Relational Data
    with Graph Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"

    The basis regularization decomposes :math:`W_t` by:

    .. math::

       W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}

    where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined
    with coefficients :math:`a_{tb}^{(l)}`.

    The block-diagonal-decomposition regularization decomposes :math:`W_t` into :math:`B`
    block-diagonal matrices. We refer to :math:`B` as the number of bases:

    .. math::

       W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}

    where :math:`B` is the number of bases, :math:`Q_{tb}^{(l)}` are block
    bases with shape :math:`R^{(d^{(l+1)}/B)\times(d^{l}/B)}`.

    Parameters
    ----------
    in_size : int
        Input feature size.
    out_size : int
        Output feature size.
    num_types : int
        Total number of types.
    regularizer : str, optional
        Which weight regularizer to use "basis" or "bdd":

         - "basis" is short for basis-decomposition.
         - "bdd" is short for block-diagonal-decomposition.

        Default applies no regularization.
    num_bases : int, optional
        Number of bases. Needed when ``regularizer`` is specified. Typically smaller
        than ``num_types``.
        Default: ``None``.

    Examples
    --------

    No regularization.

    >>> from dgl.nn import TypedLinear
    >>> import torch
    >>>
    >>> x = torch.randn(100, 32)
    >>> x_type = torch.randint(0, 5, (100,))
    >>> m = TypedLinear(32, 64, 5)
    >>> y = m(x, x_type)
    >>> print(y.shape)
    torch.Size([100, 64])

    With basis regularization

    >>> x = torch.randn(100, 32)
    >>> x_type = torch.randint(0, 5, (100,))
    >>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
    >>> y = m(x, x_type)
    >>> print(y.shape)
    torch.Size([100, 64])
    """
    def __init__(self, in_size, out_size, num_types,
                 regularizer=None, num_bases=None):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.num_types = num_types
        if regularizer is None:
            # One dense (in_size, out_size) matrix per type.
            self.W = nn.Parameter(torch.Tensor(num_types, in_size, out_size))
        elif regularizer == 'basis':
            if num_bases is None:
                raise ValueError('Missing "num_bases" for basis regularization.')
            # Shared bases plus per-type combination coefficients.
            self.W = nn.Parameter(torch.Tensor(num_bases, in_size, out_size))
            self.coeff = nn.Parameter(torch.Tensor(num_types, num_bases))
            self.num_bases = num_bases
        elif regularizer == 'bdd':
            if num_bases is None:
                raise ValueError('Missing "num_bases" for bdd regularization.')
            if in_size % num_bases != 0 or out_size % num_bases != 0:
                raise ValueError(
                    'Input and output sizes must be divisible by num_bases.'
                )
            self.submat_in = in_size // num_bases
            self.submat_out = out_size // num_bases
            # Each type stores its num_bases diagonal blocks flattened into one row.
            self.W = nn.Parameter(torch.Tensor(
                num_types, num_bases * self.submat_in * self.submat_out))
            self.num_bases = num_bases
        else:
            raise ValueError(
                f'Supported regularizer options: "basis", "bdd", but got {regularizer}')
        self.regularizer = regularizer
        self.reset_parameters()

    def reset_parameters(self):
        """Reset parameters.

        ``W`` is drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` where
        ``fan_in`` is ``in_size`` (``submat_in`` for "bdd"). For "basis", the
        coefficients are initialized with Xavier-uniform using the ReLU gain.

        Raises
        ------
        ValueError
            If ``self.regularizer`` is not ``None``, "basis" or "bdd".
        """
        with torch.no_grad():
            # Follow torch.nn.Linear 's initialization to use kaiming_uniform_ on in_size
            if self.regularizer is None:
                nn.init.uniform_(self.W, -1/math.sqrt(self.in_size), 1/math.sqrt(self.in_size))
            elif self.regularizer == 'basis':
                nn.init.uniform_(self.W, -1/math.sqrt(self.in_size), 1/math.sqrt(self.in_size))
                nn.init.xavier_uniform_(self.coeff, gain=nn.init.calculate_gain('relu'))
            elif self.regularizer == 'bdd':
                nn.init.uniform_(self.W, -1/math.sqrt(self.submat_in), 1/math.sqrt(self.submat_in))
            else:
                # Use the attribute: there is no local ``regularizer`` in this scope.
                raise ValueError(
                    f'Supported regularizer options: "basis", "bdd", '
                    f'but got {self.regularizer}')

    def get_weight(self):
        """Get type-wise weight.

        Returns
        -------
        torch.Tensor
            * No regularizer: ``W`` itself, shape ``(num_types, in_size, out_size)``.
            * "basis": the bases combined with ``coeff``, shape
              ``(num_types, in_size, out_size)``.
            * "bdd": the flattened block parameter ``W`` as stored, shape
              ``(num_types, num_bases * submat_in * submat_out)``.

        Raises
        ------
        ValueError
            If ``self.regularizer`` is not ``None``, "basis" or "bdd".
        """
        if self.regularizer is None:
            return self.W
        elif self.regularizer == 'basis':
            W = self.W.view(self.num_bases, self.in_size * self.out_size)
            return (self.coeff @ W).view(self.num_types, self.in_size, self.out_size)
        elif self.regularizer == 'bdd':
            return self.W
        else:
            # Use the attribute: there is no local ``regularizer`` in this scope.
            raise ValueError(
                f'Supported regularizer options: "basis", "bdd", '
                f'but got {self.regularizer}')

    def forward(self, x, x_type, sorted_by_type=False):
        """Forward computation.

        Parameters
        ----------
        x : torch.Tensor
            A 2D input tensor. Shape: (N, D1)
        x_type : torch.Tensor
            A 1D integer tensor storing the type of the elements in ``x`` with one-to-one
            correspondence. Shape: (N,)
        sorted_by_type : bool, optional
            Whether the inputs have been sorted by the types. Forward on pre-sorted inputs may
            be faster. Not consulted when the regularizer is "bdd".

        Returns
        -------
        y : torch.Tensor
            The transformed output tensor. Shape: (N, D2)
        """
        w = self.get_weight()
        if self.regularizer == 'bdd':
            # Gather each sample's blocks and multiply block by block.
            w = w.index_select(0, x_type).view(-1, self.submat_in, self.submat_out)
            x = x.view(-1, 1, self.submat_in)
            return torch.bmm(x, w).view(-1, self.out_size)
        elif sorted_by_type:
            # Derive per-type segment lengths from the start offset of every type.
            pos_l = torch.searchsorted(x_type, torch.arange(self.num_types, device=x.device))
            pos_r = torch.cat([pos_l[1:], torch.tensor([len(x_type)], device=x.device)])
            seglen = (pos_r - pos_l).cpu()  # XXX(minjie): cause device synchronize
            return segment_mm(x, w, seglen_a=seglen)
        else:
            return gather_mm(x, w, idx_b=x_type)

    def __repr__(self):
        if self.regularizer is None:
            return (f'TypedLinear(in_size={self.in_size}, out_size={self.out_size}, '
                    f'num_types={self.num_types})')
        else:
            return (f'TypedLinear(in_size={self.in_size}, out_size={self.out_size}, '
                    f'num_types={self.num_types}, regularizer={self.regularizer}, '
                    f'num_bases={self.num_bases})')
"""dgl gather_mm operator module."""
from ..backend import gather_mm as gather_mm_internal
from ..backend import segment_mm as segment_mm_internal
from .. import backend as F
__all__ = ['gather_mm', 'segment_mm']
__all__ = ['gather_mm']
def segment_mm(lhs_data, rhs_data, seglen_lhs):
    r"""Multiply consecutive row segments of ``lhs_data`` by per-segment operands.

    For example, with ``seglen_lhs == [10, 5, 0, 3]`` four products are computed:

      lhs_data[0:10] @ rhs_data[0], lhs_data[10:15] @ rhs_data[1],
      lhs_data[15:15] @ rhs_data[2], lhs_data[15:18] @ rhs_data[3]

    Parameters
    ----------
    lhs_data : tensor
        Left operand, a 2-D tensor of shape (N, D1).
    rhs_data : tensor
        Right operand, a 2-D tensor of shape (R * D1, D2).
    seglen_lhs : tensor
        Integer tensor of shape (R,) giving the length of each segment of
        ``lhs_data``. The lengths must sum to N.

    Returns
    -------
    tensor
        Dense output matrix of shape (N, D2).
    """
    # Pure pass-through to the backend implementation.
    product = segment_mm_internal(lhs_data, rhs_data, seglen_lhs)
    return product
def gather_mm(lhs_data, rhs_data, idx_lhs = None, idx_rhs = None):
def gather_mm(a, b, *, idx_b):
r"""Gather data according to the given indices and perform matrix multiplication.
Let the result tensor be C, the operator conducts the following computation:
If both idx_lhs and idx_rhs are not none:
c[i] = lhs_data[idx_lhs[i]] @ rhs_data[idx_rhs[i]]
, where len(C) == len(idx_lhs) == len(idx_rhs)
If idx_lhs is given but not idx_rhs:
c[i] = rhs_data[idx_lhs[i]] @ rhs_data[i]
, where len(C) == len(idx_lhs)
If idx_rhs is given but not idx_lhs:
Let the result tensor be ``c``, the operator conducts the following computation:
c[i] = lhs_data[i] @ rhs_data[idx_rhs[i]]
, where len(C) == len(idx_rhs)
c[i] = a[i] @ b[idx_b[i]]
, where len(c) == len(idx_b)
Parameters
----------
lhs_data : tensor
2-D tensor of shape (N, D1)
rhs_data : tensor
3-D tensor of shape (R, D1, D2)
idx_lhs : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,).
idx_rhs : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,).
a : Tensor
A 2-D tensor of shape ``(N, D1)``
b : Tensor
A 3-D tensor of shape ``(R, D1, D2)``
idx_b : Tensor, optional
An 1-D integer tensor of shape ``(N,)``.
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
The output dense matrix of shape ``(N, D2)``
"""
return gather_mm_internal(lhs_data, rhs_data, idx_lhs, idx_rhs)
N, D1 = F.shape(a)
R, _, D2 = F.shape(b)
if N > 1000000 or D1 > 8 or D2 > 8:
# Use segment_mm for large workload
import torch
sorted_idx_b, perm = torch.sort(idx_b)
_, rev_perm = torch.sort(perm)
sorted_a = torch.index_select(a, 0, perm)
pos_l = torch.searchsorted(sorted_idx_b, torch.arange(R, device=a.device))
pos_r = torch.cat([pos_l[1:], torch.tensor([len(idx_b)], device=a.device)])
seglen = (pos_r - pos_l).cpu() # XXX(minjie): cause device synchronize
return torch.index_select(F.segment_mm(sorted_a, b, seglen), 0, rev_perm)
else:
return F.gather_mm(a, b, None, idx_b)
......@@ -3,6 +3,7 @@
from ..base import DGLError
from .. import backend as F
__all__ = ['segment_reduce', 'segment_softmax', 'segment_mm']
def segment_reduce(seglen, value, reducer='sum'):
"""Segment reduction operator.
......@@ -98,3 +99,29 @@ def segment_softmax(seglen, value):
value = F.exp(value - F.repeat(value_max, seglen, dim=0))
value_sum = segment_reduce(seglen, value, reducer='sum')
return value / F.repeat(value_sum, seglen, dim=0)
def segment_mm(a, b, seglen_a):
    r"""Multiply consecutive row segments of ``a`` by the matching matrix in ``b``.

    For example, with ``seglen_a == [10, 5, 0, 3]`` four products are computed::

      a[0:10] @ b[0], a[10:15] @ b[1],
      a[15:15] @ b[2], a[15:18] @ b[3]

    Parameters
    ----------
    a : Tensor
        Left operand, a 2-D tensor of shape ``(N, D1)``.
    b : Tensor
        Right operand, a 3-D tensor of shape ``(R, D1, D2)``.
    seglen_a : Tensor
        Integer tensor of shape ``(R,)`` giving the length of each segment of
        ``a``. The lengths must sum to ``N``.

    Returns
    -------
    Tensor
        Dense output matrix of shape ``(N, D2)``.
    """
    # Pure pass-through to the backend implementation.
    product = F.segment_mm(a, b, seglen_a)
    return product
......@@ -389,108 +389,43 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
return out, (list_arg_u, list_arg_e, list_arg_u_ntype, list_arg_e_etype)
def _segment_mm(A, B, out, seglen_A, a_trans=False, b_trans=False):
r""" Dense Matrix Multiplication interface. It multiplies dense tensor A
and dense tensor B according to relation types. A is sorted and concatenated
according to relation types.
Parameters
----------
A : tensor
2-D tensor of shape (N, D1)
B : tensor
2-D tensor of shape (R * D1, D2)
seglen_A : Tensor
An integer tensor of shape (R,). Each element is the length of segments
of input ``A``. The summation of all elements must be equal to N.
a_trans : bool
Indicates whether matrix A needs to be tranposed
b_trans : bool
Indicates whether matrix B needs to be tranposed
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
# TODO(Israt): Add CPU support. Currently, only handles GPU code
def _segment_mm(A, B, out, seglen_A, b_trans=False):
"""Invoke the C API of segment_mm."""
_CAPI_DGLKernelSEGMENTMM(to_dgl_nd(A),
to_dgl_nd(B),
to_dgl_nd_for_write(out),
to_dgl_nd(seglen_A),
a_trans, b_trans)
False, b_trans)
return out
def _gather_mm(A, B, out, num_rel, idx_a=None, idx_b=None):
r""" Generalized Dense Matrix Multiplication interface. It multiplies
tensor A and B according to relation types and outputs in out. B is a
concatenated tensor across relation types. A is unsorted and the
relation type is fetched from param etypes.
Parameters
----------
A : tensor
2-D tensor of shape (N, D1)
B : tensor
2-D tensor of shape (R * D1, D2)
idx_a : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,)
idx_b : Tensor, optional
If specified, must be a 1-D integer tensor of shape (N,)
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
# TODO(Israt): Add CPU support. Currently, only handles GPU code
def _segment_mm_backward_B(A, dC, dB, seglen):
    """Call the C API for the backward pass of segment_mm on B.

    ``A``, ``dC`` and ``seglen`` are converted with ``to_dgl_nd``; ``dB`` is
    converted with ``to_dgl_nd_for_write`` and is the object returned.
    """
    lhs_nd = to_dgl_nd(A)
    grad_out_nd = to_dgl_nd(dC)
    grad_b_nd = to_dgl_nd_for_write(dB)
    seglen_nd = to_dgl_nd(seglen)
    _CAPI_DGLKernelSEGMENTMMBackwardB(lhs_nd, grad_out_nd, grad_b_nd, seglen_nd)
    return dB
def _gather_mm(A, B, out, idx_a=None, idx_b=None):
r"""Invoke the C API of the gather_mm operator."""
_CAPI_DGLKernelGATHERMM(to_dgl_nd(A),
to_dgl_nd(B),
to_dgl_nd_for_write(out),
to_dgl_nd(idx_a),
to_dgl_nd(idx_b),
num_rel)
to_dgl_nd(idx_b))
return out
def _gather_mm_scatter(A, B, out, num_rel, idx_a=None, idx_b=None, idx_c=None,
a_trans=False, b_trans=False):
r""" Generalized Dense Matrix Multiplication interface. It multiplies
tensor A and B according to relation types and outputs in out. B is a
concatenated tensor across relation types. A is unsorted and the
relation type is fetched from param etypes.
Parameters
----------
A : tensor
2-D tensor of shape (N, D1)
B : tensor
2-D tensor of shape (R * D1, D2)
idx_a : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,)
idx_b : Tensor, optional
If specified, must be a 1-D integer tensor of shape (N,)
idx_c : Tensor, optional
If specified, must be a 1-D integer tensor of shape (N,)
A_trans : bool
Indicates whether matrix A needs to be tranposed
B_trans : bool
Indicates whether matrix B needs to be tranposed
Returns
-------
Tensor
The output dense matrix of shape (N, D2)
"""
# TODO(Israt): Add CPU support. Currently, only handles GPU code
_CAPI_DGLKernelGATHERMMSCATTER(to_dgl_nd(A),
def _gather_mm_scatter(A, B, out, idx_a=None, idx_b=None, idx_c=None):
r"""Invoke the C API of the gather_mm_scatter operator."""
_CAPI_DGLKernelGATHERMMSCATTER(
to_dgl_nd(A),
to_dgl_nd(B),
to_dgl_nd_for_write(out),
to_dgl_nd(idx_a),
to_dgl_nd(idx_b),
to_dgl_nd(idx_c),
num_rel, a_trans, b_trans)
to_dgl_nd(idx_c))
return out
......
......@@ -22,7 +22,7 @@ import scipy.sparse as sparse
import scipy.sparse.linalg
from .._ffi.function import _init_api
from ..base import dgl_warning, DGLError
from ..base import dgl_warning, DGLError, NID, EID
from .. import convert
from ..heterograph import DGLHeteroGraph, DGLBlock
from ..heterograph_index import create_metagraph_index, create_heterograph_from_relations
......@@ -2973,7 +2973,7 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
return new_g
def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
def reorder_graph(g, node_permute_algo=None, edge_permute_algo='src',
store_ids=True, permute_config=None):
r"""Return a new graph with nodes and edges re-ordered/re-labeled
according to the specified permute algorithm.
......@@ -2994,7 +2994,7 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
g : DGLGraph
The homogeneous graph.
node_permute_algo: str, optional
The permutation algorithm to re-order nodes. Options are ``rcmk`` or
The permutation algorithm to re-order nodes. If given, the options are ``rcmk`` or
``metis`` or ``custom``. Default: ``None``, in which case the node order is left unchanged.
* ``rcmk``: Use the `Reverse Cuthill–McKee <https://docs.scipy.org/doc/scipy/reference/
......@@ -3014,6 +3014,8 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
* ``src``: Edges are arranged according to their source nodes.
* ``dst``: Edges are arranged according to their destination nodes.
* ``custom``: Edges are arranged according to the user-provided edge permutation
array (provided in :attr:`permute_config`).
store_ids: bool, optional
If True, DGL will store the original node and edge IDs in the ndata and edata
of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
......@@ -3023,9 +3025,12 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
* For ``rcmk``, this argument is not required.
* For ``metis``, users should specify the number of partitions ``k`` (e.g.,
``permute_config={'k':10}`` to partition the graph to 10 clusters).
* For ``custom``, users should provide a node permutation array ``nodes_perm``.
The array must be an integer list or a tensor with the same device of the
input graph.
* For ``custom`` node reordering, users should provide a node permutation
array ``nodes_perm``. The array must be an integer list or a tensor with
the same device of the input graph.
* For ``custom`` edge reordering, users should provide an edge permutation
array ``edges_perm``. The array must be an integer list or a tensor with
the same device of the input graph.
Returns
-------
......@@ -3118,49 +3123,83 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
[2],
[1]]), '_ID': tensor([0, 2, 4, 1, 3])}
Reorder according to node and edge types:
>>> ntype = ... # some node type array
>>> etype = ... # some edge type array
>>> sorted_ntype, idx_nt = torch.sort(ntype)
>>> sorted_etype, idx_et = torch.sort(etype)
>>> rg = dgl.reorder_graph(g, node_permute_algo='custom', edge_permute_algo='custom',
... permute_config={'nodes_perm' : idx_nt.to(g.idtype),
... 'edges_perm' : idx_et.to(g.idtype)})
"""
# sanity checks
if not g.is_homogeneous:
raise DGLError("Homograph is supported only.")
raise DGLError("Only homogeneous graphs are supported.")
expected_node_algo = ['rcmk', 'metis', 'custom']
if node_permute_algo not in expected_node_algo:
if node_permute_algo is not None and node_permute_algo not in expected_node_algo:
raise DGLError("Unexpected node_permute_algo is specified: {}. Expected algos: {}".format(
node_permute_algo, expected_node_algo))
expected_edge_algo = ['src', 'dst']
expected_edge_algo = ['src', 'dst', 'custom']
if edge_permute_algo not in expected_edge_algo:
raise DGLError("Unexpected edge_permute_algo is specified: {}. Expected algos: {}".format(
edge_permute_algo, expected_edge_algo))
# generate nodes permutation
g.edata['__orig__'] = F.arange(0, g.num_edges(), g.idtype, g.device)
# reorder nodes
if node_permute_algo == 'rcmk':
nodes_perm = rcmk_perm(g)
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
elif node_permute_algo == 'metis':
if permute_config is None or 'k' not in permute_config:
raise DGLError(
"Partition parts 'k' is required for metis. Please specify in permute_config.")
nodes_perm = metis_perm(g, permute_config['k'])
else:
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
elif node_permute_algo == 'custom':
if permute_config is None or 'nodes_perm' not in permute_config:
raise DGLError(
"permute_algo is specified as custom, but no 'nodes_perm' is specified in \
"node_permute_algo is specified as custom, but no 'nodes_perm' is specified in \
permute_config.")
nodes_perm = permute_config['nodes_perm']
if len(nodes_perm) != g.num_nodes():
raise DGLError("Length of passed in nodes_perm[{}] does not \
match graph num_nodes[{}].".format(len(nodes_perm), g.num_nodes()))
raise DGLError("Length of 'nodes_perm' ({}) does not \
match graph num_nodes ({}).".format(len(nodes_perm), g.num_nodes()))
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
else:
nodes_perm = F.arange(0, g.num_nodes(), g.idtype, g.device)
rg = g.clone()
# reorder nodes
rg = subgraph.node_subgraph(g, nodes_perm, store_ids=store_ids)
if store_ids:
rg.ndata[NID] = F.copy_to(F.tensor(nodes_perm, g.idtype), g.device)
g.edata.pop('__orig__')
# reorder edges
if edge_permute_algo == 'src':
# the output graph of dgl.node_subgraph() is ordered/labeled
# according to src already. Nothing needs to do.
pass
edges_perm = np.argsort(F.asnumpy(rg.edges()[0]))
rg = subgraph.edge_subgraph(
rg, edges_perm, relabel_nodes=False, store_ids=False)
elif edge_permute_algo == 'dst':
edges_perm = np.argsort(F.asnumpy(rg.edges()[1]))
rg = subgraph.edge_subgraph(
rg, edges_perm, relabel_nodes=False, store_ids=store_ids)
rg, edges_perm, relabel_nodes=False, store_ids=False)
elif edge_permute_algo == 'custom':
if permute_config is None or 'edges_perm' not in permute_config:
raise DGLError(
"edge_permute_algo is specified as custom, but no 'edges_perm' is specified in \
permute_config.")
edges_perm = permute_config['edges_perm']
# First revert the edge reorder caused by node reorder and then
# apply user-provided edge permutation
rev_id = F.argsort(rg.edata['__orig__'], 0, False)
edges_perm = F.astype(F.gather_row(rev_id, edges_perm), rg.idtype)
rg = subgraph.edge_subgraph(
rg, edges_perm, relabel_nodes=False, store_ids=False)
if store_ids:
rg.edata[EID] = rg.edata.pop('__orig__')
return rg
......
......@@ -23,108 +23,114 @@ namespace aten {
} while (0)
// CPU placeholder for SegmentMM: the only statement in the body reports a
// fatal "Unsupported CPU kernel" error through LOG(FATAL); no computation on
// A, B, C or seglen_A is performed here.
// NOTE(review): this span comes from a rendered diff, so the pre-rename
// (segmentMM) and post-rename (SegmentMM) lines appear back to back. The
// SWITCH_BITS wrapper presumably belongs to only one side of the change --
// confirm against the repository before treating this as compilable code.
/*! \brief Generalized segmentMM. */
/*! \brief Generalized SegmentMM. */
template <int XPU, typename IdType, int bits>
void segmentMM(const NDArray A,
void SegmentMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
bool a_trans, bool b_trans) {
SWITCH_BITS(bits, DType, {
LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
});
}
/*!
 * \brief CPU placeholder for SegmentMMBackwardB.
 *
 * The body never touches its arguments: it only reports a fatal
 * "Unsupported CPU kernel" error through LOG(FATAL).
 *
 * NOTE(review): judging by the name and the dC/dB parameters this is
 * presumably the gradient of SegmentMM with respect to B -- confirm against
 * the GPU implementation.
 *
 * \param A      Input array (unused here).
 * \param dC     Input array (unused here).
 * \param dB     Output array (unused here).
 * \param seglen Segment-length array (unused here).
 */
template <int XPU, typename IdType, int bits>
void SegmentMMBackwardB(const NDArray A, const NDArray dC,
                        NDArray dB, const NDArray seglen) {
  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
}
// CPU placeholder for GatherMM: the only statement in the body reports a
// fatal "Unsupported CPU kernel" error through LOG(FATAL).
// NOTE(review): this span comes from a rendered diff and interleaves two
// revisions of the signature: gatherMM(..., idx_b, num_rel) and
// GatherMM(..., idx_b). The GatherMM explicit instantiations further down
// take (A, B, C, idx_a, idx_b), so the num_rel / SWITCH_BITS lines are
// presumably the removed side -- confirm against the repository.
/*! \brief Generalized GatherMM. */
template <int XPU, typename IdType, int bits>
void gatherMM(const NDArray A,
void GatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const int num_rel) {
SWITCH_BITS(bits, DType, {
const NDArray idx_b) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
});
}
// CPU placeholder for GatherMMScatter: the only statement in the body reports
// a fatal "Unsupported CPU kernel" error through LOG(FATAL).
// NOTE(review): this span comes from a rendered diff and interleaves two
// revisions of the signature: gatherMM_scatter(..., idx_c, num_rel, a_trans,
// b_trans) and GatherMMScatter(..., idx_c). The GatherMMScatter explicit
// instantiations further down take (A, B, C, idx_a, idx_b, idx_c), so the
// num_rel / a_trans / b_trans / SWITCH_BITS lines are presumably the removed
// side -- confirm against the repository.
// NOTE(review): the fatal message below names "GatherMM" although this
// function is GatherMMScatter (the SegmentMMBackwardB stub names itself in
// its message); this looks like a copy-paste slip that makes the error
// harder to trace -- confirm and consider correcting the text upstream.
/*! \brief Generalized GatherMM_scatter. */
template <int XPU, typename IdType, int bits>
void gatherMM_scatter(const NDArray A,
void GatherMMScatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c,
const int num_rel,
bool a_trans, bool b_trans) {
SWITCH_BITS(bits, DType, {
const NDArray idx_c) {
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
});
}
// Explicit instantiations of the GatherMM CPU stub for kDLCPU, covering
// IdType in {int32_t, int64_t} and bits in {16, 32, 64}.
// NOTE(review): rendered-diff residue -- each instantiation appears twice,
// once under the old name gatherMM (with a trailing num_rel parameter) and
// once under the new name GatherMM (ending at idx_b), with the two versions'
// lines interleaved. Confirm against the repository which lines survive.
template void gatherMM<kDLCPU, int32_t, 16>(
template void GatherMM<kDLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 16>(
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int32_t, 32>(
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 32>(
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int32_t, 64>(
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
template void gatherMM<kDLCPU, int64_t, 64>(
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const int num_rel);
const NDArray idx_a, const NDArray idx_b);
// Explicit instantiations of the GatherMMScatter CPU stub for kDLCPU,
// covering IdType in {int32_t, int64_t} and bits in {16, 32, 64}.
// NOTE(review): rendered-diff residue -- each instantiation appears twice,
// once under the old name gatherMM_scatter (with trailing num_rel, a_trans,
// b_trans parameters) and once under the new name GatherMMScatter (ending at
// idx_c), with the two versions' lines interleaved. Confirm against the
// repository which lines survive.
template void gatherMM_scatter<kDLCPU, int32_t, 16>(
template void GatherMMScatter<kDLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 16>(
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int32_t, 32>(
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 32>(
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int32_t, 64>(
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
template void gatherMM_scatter<kDLCPU, int64_t, 64>(
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
const int num_rel, bool a_trans, bool b_trans);
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
// Explicit instantiations of the SegmentMM CPU stub for kDLCPU, covering
// IdType in {int32_t, int64_t} and bits in {16, 32, 64}.
// NOTE(review): rendered-diff residue -- every instantiation header appears
// twice (old name segmentMM immediately followed by new name SegmentMM)
// while the parameter lines are shared. Only one header per pair can be
// present in real code; confirm against the repository.
template void segmentMM<kDLCPU, int32_t, 16>(
template void SegmentMM<kDLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 16>(
template void SegmentMM<kDLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int32_t, 32>(
template void SegmentMM<kDLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 32>(
template void SegmentMM<kDLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int32_t, 64>(
template void SegmentMM<kDLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void segmentMM<kDLCPU, int64_t, 64>(
template void SegmentMM<kDLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
// Explicit instantiations of the SegmentMMBackwardB CPU stub for kDLCPU:
// one per (IdType, bits) pair with IdType in {int32_t, int64_t} and
// bits in {16, 32, 64}.
template void SegmentMMBackwardB<kDLCPU, int32_t, 16>(
    const NDArray A,
    const NDArray dC,
    NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 16>(
    const NDArray A,
    const NDArray dC,
    NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 32>(
    const NDArray A,
    const NDArray dC,
    NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 32>(
    const NDArray A,
    const NDArray dC,
    NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 64>(
    const NDArray A,
    const NDArray dC,
    NDArray dB,
    const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 64>(
    const NDArray A,
    const NDArray dC,
    NDArray dB,
    const NDArray seglen);
} // namespace aten
} // namespace dgl
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment