"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b2add10d132377de0a935faa5e7cc42b6320fa53"
Unverified Commit 0227ddfb authored by Minjie Wang, committed by GitHub

[NN] Rework RelGraphConv and HGTConv (#3742)

* WIP: TypedLinear and new RelGraphConv

* wip

* further simplify RGCN

* a bunch of tweaks for performance; add basic cpu support

* update on segmm

* wip: segment.cu

* new backward kernel works

* fix a bunch of bugs in kernel; leave idx_a for future

* add nn test for typed_linear

* rgcn nn test

* bugfix in corner case; update RGCN README

* doc

* fix cpp lint

* fix lint

* fix ut

* wip: hgtconv; presorted flag for rgcn

* hgt code and ut; WIP: some fix on reorder graph

* better typed linear init

* fix ut

* fix lint; add docstring
parent 4f00d5ac
@@ -16,22 +16,19 @@ class RGCN(nn.Module):
                  num_rels,
                  num_bases,
                  num_hidden_layers,
-                 dropout,
-                 lowmem):
+                 dropout):
         super(RGCN, self).__init__()
         self.layers = nn.ModuleList()
         # i2h
         self.layers.append(RelGraphConv(num_nodes, n_hidden, num_rels, "basis",
-                                        num_bases, activation=F.relu, dropout=dropout,
-                                        low_mem=lowmem))
+                                        num_bases, activation=F.relu, dropout=dropout))
         # h2h
         for i in range(num_hidden_layers):
             self.layers.append(RelGraphConv(n_hidden, n_hidden, num_rels, "basis",
-                                            num_bases, activation=F.relu, dropout=dropout,
-                                            low_mem=lowmem))
+                                            num_bases, activation=F.relu, dropout=dropout))
         # o2h
         self.layers.append(RelGraphConv(n_hidden, num_classes, num_rels, "basis",
-                                        num_bases, activation=None, low_mem=lowmem))
+                                        num_bases, activation=None))

     def forward(self, g, h, r, norm):
         for layer in self.layers:
@@ -40,9 +37,8 @@ class RGCN(nn.Module):
 @utils.benchmark('time', 300)
 @utils.parametrize('data', ['aifb'])
-@utils.parametrize('lowmem', [True, False])
 @utils.parametrize('use_type_count', [True, False])
-def track_time(data, lowmem, use_type_count):
+def track_time(data, use_type_count):
     # args
     if data == 'aifb':
         num_bases = -1
@@ -108,8 +104,7 @@ def track_time(data, use_type_count):
                  num_rels,
                  num_bases,
                  0,
-                 0,
-                 lowmem).to(device)
+                 0).to(device)
     optimizer = torch.optim.Adam(model.parameters(),
                                  lr=1e-2,
...
@@ -295,7 +295,7 @@ TransR
    :members: rel_emb, rel_project, forward, reset_parameters
    :show-inheritance:

-Heterogeneous Graph Convolution Module
+Heterogeneous Learning Module
 ----------------------------------------

 HeteroGraphConv
@@ -319,9 +319,17 @@ HeteroEmbedding
 .. _apinn-pytorch-util:

 Utility Modules
 ----------------------------------------

+TypedLinear
+----------------------------------------
+
+.. autoclass:: dgl.nn.pytorch.TypedLinear
+    :members: forward
+    :show-inheritance:
+
 Sequential
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
@@ -18,57 +18,36 @@ pip install rdflib pandas
 Example code was tested with rdflib 4.2.2 and pandas 0.23.4

 ### Entity Classification
-AIFB: accuracy 96.29% (3 runs, DGL), 95.83% (paper)
-```
-python entity.py -d aifb --l2norm 0 --gpu 0
-```
-MUTAG: accuracy 72.55% (3 runs, DGL), 73.23% (paper)
+For AIFB, MUTAG, BGS and AM,
 ```
+python entity.py -d aifb --wd 0 --gpu 0
 python entity.py -d mutag --n-bases 30 --gpu 0
-```
-BGS: accuracy 89.70% (3 runs, DGL), 83.10% (paper)
-```
 python entity.py -d bgs --n-bases 40 --gpu 0
-```
-AM: accuracy 89.56% (3 runs, DGL), 89.29% (paper)
-```
-python entity.py -d am --n-bases 40 --n-hidden 10
+python entity.py -d am --n-bases 40 --n-hidden 10 --gpu 0
 ```

 ### Entity Classification with minibatch
-AIFB: accuracy avg(5 runs) 91.10%, best 97.22% (DGL)
+For AIFB, MUTAG, BGS and AM,
 ```
-python entity_sample.py -d aifb --l2norm 0 --gpu 0 --fanout='20,20' --batch-size 128
-```
-MUTAG: accuracy avg(10 runs) 66.47%, best 72.06% (DGL)
-```
-python entity_sample.py -d mutag --n-bases 30 --gpu 0 --batch-size 64 --fanout "-1, -1" --use-self-loop --n-epochs 20 --sparse-lr 0.01 --dropout 0.5
-```
-BGS: accuracy avg(5 runs) 84.83%, best 89.66% (DGL)
-```
-python entity_sample.py -d bgs --n-bases 40 --gpu 0 --fanout "-1, -1" --n-epochs=16 --batch-size=16 --sparse-lr 0.05 --dropout 0.3
-```
-AM: accuracy avg(5 runs) 88.58%, best 89.90% (DGL)
-```
-python entity_sample.py -d am --n-bases 40 --gpu 0 --fanout '35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --sparse-lr 0.02 --dropout 0.7
+python entity_sample.py -d aifb --wd 0 --gpu 0 --fanout='20,20' --batch-size 128
+python entity_sample.py -d mutag --n-bases 30 --gpu 0 --batch-size 64 --fanout='-1,-1' --use-self-loop --n-epochs 20 --dropout 0.5
+python entity_sample.py -d bgs --n-bases 40 --gpu 0 --fanout='-1,-1' --n-epochs=16 --batch-size=16 --dropout 0.3
+python entity_sample.py -d am --n-bases 40 --gpu 0 --fanout='35,35' --batch-size 64 --n-hidden 16 --use-self-loop --n-epochs=20 --dropout 0.7
 ```

+### Entity Classification on multiple GPUs
 To use multiple GPUs, replace `entity_sample.py` with `entity_sample_multi_gpu.py` and specify
 multiple GPU IDs separated by comma, e.g., `--gpu 0,1`.

 ### Link Prediction
-FB15k-237: MRR 0.163 (DGL), 0.158 (paper)
+FB15k-237 in RAW-MRR
 ```
 python link.py --gpu 0 --eval-protocol raw
 ```
-FB15k-237: Filtered-MRR 0.247
+FB15k-237 in Filtered-MRR
 ```
 python link.py --gpu 0 --eval-protocol filtered
 ```
""" """
Differences compared to tkipf/relation-gcn Differences compared to tkipf/relation-gcn
* l2norm applied to all weights * weight decay applied to all weights
* remove nodes that won't be touched
""" """
import argparse import argparse
import torch as th import torch as th
import torch.nn.functional as F import torch.nn.functional as F
...@@ -17,13 +15,7 @@ def main(args): ...@@ -17,13 +15,7 @@ def main(args):
g, num_rels, num_classes, labels, train_idx, test_idx, target_idx = load_data( g, num_rels, num_classes, labels, train_idx, test_idx, target_idx = load_data(
args.dataset, get_norm=True) args.dataset, get_norm=True)
num_nodes = g.num_nodes() model = RGCN(g.num_nodes(),
# Since the nodes are featureless, learn node embeddings from scratch
# This requires passing the node IDs to the model.
feats = th.arange(num_nodes)
model = RGCN(num_nodes,
args.n_hidden, args.n_hidden,
num_classes, num_classes,
num_rels, num_rels,
...@@ -33,16 +25,15 @@ def main(args): ...@@ -33,16 +25,15 @@ def main(args):
device = th.device(args.gpu) device = th.device(args.gpu)
else: else:
device = th.device('cpu') device = th.device('cpu')
feats = feats.to(device)
labels = labels.to(device) labels = labels.to(device)
model = model.to(device) model = model.to(device)
g = g.to(device) g = g.int().to(device)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm) optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)
model.train() model.train()
for epoch in range(50): for epoch in range(100):
logits = model(g, feats) logits = model(g)
logits = logits[target_idx] logits = logits[target_idx]
loss = F.cross_entropy(logits[train_idx], labels[train_idx]) loss = F.cross_entropy(logits[train_idx], labels[train_idx])
optimizer.zero_grad() optimizer.zero_grad()
...@@ -56,7 +47,7 @@ def main(args): ...@@ -56,7 +47,7 @@ def main(args):
model.eval() model.eval()
with th.no_grad(): with th.no_grad():
logits = model(g, feats) logits = model(g)
logits = logits[target_idx] logits = logits[target_idx]
test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item() test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()
print("Test Accuracy: {:.4f}".format(test_acc)) print("Test Accuracy: {:.4f}".format(test_acc))
...@@ -72,8 +63,8 @@ if __name__ == '__main__': ...@@ -72,8 +63,8 @@ if __name__ == '__main__':
parser.add_argument("-d", "--dataset", type=str, required=True, parser.add_argument("-d", "--dataset", type=str, required=True,
choices=['aifb', 'mutag', 'bgs', 'am'], choices=['aifb', 'mutag', 'bgs', 'am'],
help="dataset to use") help="dataset to use")
parser.add_argument("--l2norm", type=float, default=5e-4, parser.add_argument("--wd", type=float, default=5e-4,
help="l2 norm coef") help="weight decay")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
""" """
Differences compared to tkipf/relation-gcn Differences compared to tkipf/relation-gcn
* l2norm applied to all weights * weight decay applied to all weights
* remove nodes that won't be touched * remove nodes that won't be touched
""" """
import argparse import argparse
...@@ -13,7 +13,7 @@ from torchmetrics.functional import accuracy ...@@ -13,7 +13,7 @@ from torchmetrics.functional import accuracy
from tqdm import tqdm from tqdm import tqdm
from entity_utils import load_data from entity_utils import load_data
from model import RelGraphEmbedLayer, RGCN from model import RGCN
def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=False): def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=False):
fanouts = [int(fanout) for fanout in args.fanout.split(',')] fanouts = [int(fanout) for fanout in args.fanout.split(',')]
...@@ -54,21 +54,6 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F ...@@ -54,21 +54,6 @@ def init_dataloaders(args, g, train_idx, test_idx, target_idx, device, use_ddp=F
return train_loader, val_loader, test_loader return train_loader, val_loader, test_loader
def init_models(args, device, num_nodes, num_classes, num_rels):
embed_layer = RelGraphEmbedLayer(device,
num_nodes,
args.n_hidden)
model = RGCN(args.n_hidden,
args.n_hidden,
num_classes,
num_rels,
num_bases=args.n_bases,
dropout=args.dropout,
self_loop=args.use_self_loop)
return embed_layer, model
def process_batch(inv_target, batch): def process_batch(inv_target, batch):
_, seeds, blocks = batch _, seeds, blocks = batch
# map the seed nodes back to their type-specific ids, # map the seed nodes back to their type-specific ids,
...@@ -80,38 +65,32 @@ def process_batch(inv_target, batch): ...@@ -80,38 +65,32 @@ def process_batch(inv_target, batch):
return seeds, blocks return seeds, blocks
def train(model, embed_layer, train_loader, inv_target, def train(model, train_loader, inv_target,
labels, emb_optimizer, optimizer): labels, optimizer):
model.train() model.train()
embed_layer.train()
for sample_data in train_loader: for sample_data in train_loader:
seeds, blocks = process_batch(inv_target, sample_data) seeds, blocks = process_batch(inv_target, sample_data)
feats = embed_layer(blocks[0].srcdata[dgl.NID].cpu()) logits = model.forward(blocks)
logits = model(blocks, feats)
loss = F.cross_entropy(logits, labels[seeds]) loss = F.cross_entropy(logits, labels[seeds])
emb_optimizer.zero_grad()
optimizer.zero_grad()
optimizer.zero_grad()
loss.backward() loss.backward()
emb_optimizer.step()
optimizer.step() optimizer.step()
train_acc = accuracy(logits.argmax(dim=1), labels[seeds]).item() train_acc = accuracy(logits.argmax(dim=1), labels[seeds]).item()
return train_acc, loss.item() return train_acc, loss.item()
def evaluate(model, embed_layer, eval_loader, inv_target): def evaluate(model, eval_loader, inv_target):
model.eval() model.eval()
embed_layer.eval()
eval_logits = [] eval_logits = []
eval_seeds = [] eval_seeds = []
with th.no_grad(): with th.no_grad():
for sample_data in tqdm(eval_loader): for sample_data in tqdm(eval_loader):
seeds, blocks = process_batch(inv_target, sample_data) seeds, blocks = process_batch(inv_target, sample_data)
feats = embed_layer(blocks[0].srcdata[dgl.NID].cpu()) logits = model.forward(blocks)
logits = model(blocks, feats)
eval_logits.append(logits.cpu().detach()) eval_logits.append(logits.cpu().detach())
eval_seeds.append(seeds.cpu().detach()) eval_seeds.append(seeds.cpu().detach())
...@@ -131,26 +110,30 @@ def main(args): ...@@ -131,26 +110,30 @@ def main(args):
train_loader, val_loader, test_loader = init_dataloaders( train_loader, val_loader, test_loader = init_dataloaders(
args, g, train_idx, test_idx, target_idx, args.gpu) args, g, train_idx, test_idx, target_idx, args.gpu)
embed_layer, model = init_models(args, device, g.num_nodes(), num_classes, num_rels)
model = RGCN(g.num_nodes(),
args.n_hidden,
num_classes,
num_rels,
num_bases=args.n_bases,
dropout=args.dropout,
self_loop=args.use_self_loop,
ns_mode=True)
labels = labels.to(device) labels = labels.to(device)
model = model.to(device) model = model.to(device)
emb_optimizer = th.optim.SparseAdam(embed_layer.parameters(), lr=args.sparse_lr) optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm)
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
train_acc, loss = train(model, embed_layer, train_loader, inv_target, train_acc, loss = train(model, train_loader, inv_target, labels, optimizer)
labels, emb_optimizer, optimizer)
print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format( print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
epoch, args.n_epochs, train_acc, loss)) epoch, args.n_epochs, train_acc, loss))
val_logits, val_seeds = evaluate(model, embed_layer, val_loader, inv_target) val_logits, val_seeds = evaluate(model, val_loader, inv_target)
val_acc = accuracy(val_logits.argmax(dim=1), labels[val_seeds].cpu()).item() val_acc = accuracy(val_logits.argmax(dim=1), labels[val_seeds].cpu()).item()
print("Validation Accuracy: {:.4f}".format(val_acc)) print("Validation Accuracy: {:.4f}".format(val_acc))
test_logits, test_seeds = evaluate(model, embed_layer, test_logits, test_seeds = evaluate(model, test_loader, inv_target)
test_loader, inv_target)
test_acc = accuracy(test_logits.argmax(dim=1), labels[test_seeds].cpu()).item() test_acc = accuracy(test_logits.argmax(dim=1), labels[test_seeds].cpu()).item()
print("Final Test Accuracy: {:.4f}".format(test_acc)) print("Final Test Accuracy: {:.4f}".format(test_acc))
...@@ -162,8 +145,6 @@ if __name__ == '__main__': ...@@ -162,8 +145,6 @@ if __name__ == '__main__':
help="number of hidden units") help="number of hidden units")
parser.add_argument("--gpu", type=int, default=0, parser.add_argument("--gpu", type=int, default=0,
help="gpu") help="gpu")
parser.add_argument("--sparse-lr", type=float, default=2e-2,
help="sparse embedding learning rate")
parser.add_argument("--n-bases", type=int, default=-1, parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]") help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-epochs", type=int, default=50, parser.add_argument("--n-epochs", type=int, default=50,
...@@ -171,8 +152,8 @@ if __name__ == '__main__': ...@@ -171,8 +152,8 @@ if __name__ == '__main__':
parser.add_argument("-d", "--dataset", type=str, required=True, parser.add_argument("-d", "--dataset", type=str, required=True,
choices=['aifb', 'mutag', 'bgs', 'am'], choices=['aifb', 'mutag', 'bgs', 'am'],
help="dataset to use") help="dataset to use")
parser.add_argument("--l2norm", type=float, default=5e-4, parser.add_argument("--wd", type=float, default=5e-4,
help="l2 norm coef") help="weight decay")
parser.add_argument("--fanout", type=str, default="4, 4", parser.add_argument("--fanout", type=str, default="4, 4",
help="Fan-out of neighbor sampling") help="Fan-out of neighbor sampling")
parser.add_argument("--use-self-loop", default=False, action='store_true', parser.add_argument("--use-self-loop", default=False, action='store_true',
......
""" """
Differences compared to tkipf/relation-gcn Differences compared to tkipf/relation-gcn
* l2norm applied to all weights * weight decay applied to all weights
* remove nodes that won't be touched
""" """
import argparse import argparse
import gc import gc
...@@ -14,7 +13,8 @@ from torchmetrics.functional import accuracy ...@@ -14,7 +13,8 @@ from torchmetrics.functional import accuracy
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from entity_utils import load_data from entity_utils import load_data
from entity_sample import init_dataloaders, init_models, train, evaluate from entity_sample import init_dataloaders, train, evaluate
from model import RGCN
def collect_eval(n_gpus, queue, labels): def collect_eval(n_gpus, queue, labels):
eval_logits = [] eval_logits = []
...@@ -48,21 +48,25 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None): ...@@ -48,21 +48,25 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None):
use_ddp = True if n_gpus > 1 else False use_ddp = True if n_gpus > 1 else False
train_loader, val_loader, test_loader = init_dataloaders( train_loader, val_loader, test_loader = init_dataloaders(
args, g, train_idx, test_idx, target_idx, dev_id, use_ddp=use_ddp) args, g, train_idx, test_idx, target_idx, dev_id, use_ddp=use_ddp)
embed_layer, model = init_models(args, device, g.num_nodes(), num_classes, num_rels)
model = RGCN(g.num_nodes(),
args.n_hidden,
num_classes,
num_rels,
num_bases=args.n_bases,
dropout=args.dropout,
self_loop=args.use_self_loop,
ns_mode=True)
labels = labels.to(device) labels = labels.to(device)
model = model.to(device) model = model.to(device)
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id) model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
embed_layer = DistributedDataParallel(embed_layer, device_ids=None, output_device=None)
emb_optimizer = th.optim.SparseAdam(embed_layer.module.parameters(), lr=args.sparse_lr) optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.wd)
optimizer = th.optim.Adam(model.parameters(), lr=1e-2, weight_decay=args.l2norm)
th.set_num_threads(n_cpus) th.set_num_threads(n_cpus)
for epoch in range(args.n_epochs): for epoch in range(args.n_epochs):
train_loader.set_epoch(epoch) train_acc, loss = train(model, train_loader, inv_target,
train_acc, loss = train(model, embed_layer, train_loader, inv_target, labels, optimizer)
labels, emb_optimizer, optimizer)
if proc_id == 0: if proc_id == 0:
print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format( print("Epoch {:05d}/{:05d} | Train Accuracy: {:.4f} | Train Loss: {:.4f}".format(
...@@ -71,7 +75,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None): ...@@ -71,7 +75,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None):
# garbage collection that empties the queue # garbage collection that empties the queue
gc.collect() gc.collect()
val_logits, val_seeds = evaluate(model, embed_layer, val_loader, inv_target) val_logits, val_seeds = evaluate(model, val_loader, inv_target)
queue.put((val_logits, val_seeds)) queue.put((val_logits, val_seeds))
# gather evaluation result from multiple processes # gather evaluation result from multiple processes
...@@ -81,7 +85,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None): ...@@ -81,7 +85,7 @@ def run(proc_id, n_gpus, n_cpus, args, devices, dataset, queue=None):
# garbage collection that empties the queue # garbage collection that empties the queue
gc.collect() gc.collect()
test_logits, test_seeds = evaluate(model, embed_layer, test_loader, inv_target) test_logits, test_seeds = evaluate(model, test_loader, inv_target)
queue.put((test_logits, test_seeds)) queue.put((test_logits, test_seeds))
if proc_id == 0: if proc_id == 0:
test_acc = collect_eval(n_gpus, queue, labels) test_acc = collect_eval(n_gpus, queue, labels)
...@@ -119,8 +123,6 @@ if __name__ == '__main__': ...@@ -119,8 +123,6 @@ if __name__ == '__main__':
help="number of hidden units") help="number of hidden units")
parser.add_argument("--gpu", type=str, default='0', parser.add_argument("--gpu", type=str, default='0',
help="gpu") help="gpu")
parser.add_argument("--sparse-lr", type=float, default=2e-2,
help="sparse embedding learning rate")
parser.add_argument("--n-bases", type=int, default=-1, parser.add_argument("--n-bases", type=int, default=-1,
help="number of filter weight matrices, default: -1 [use all]") help="number of filter weight matrices, default: -1 [use all]")
parser.add_argument("--n-epochs", type=int, default=50, parser.add_argument("--n-epochs", type=int, default=50,
...@@ -128,8 +130,8 @@ if __name__ == '__main__': ...@@ -128,8 +130,8 @@ if __name__ == '__main__':
parser.add_argument("-d", "--dataset", type=str, required=True, parser.add_argument("-d", "--dataset", type=str, required=True,
choices=['aifb', 'mutag', 'bgs', 'am'], choices=['aifb', 'mutag', 'bgs', 'am'],
help="dataset to use") help="dataset to use")
parser.add_argument("--l2norm", type=float, default=5e-4, parser.add_argument("--wd", type=float, default=5e-4,
help="l2 norm coef") help="weight decay")
parser.add_argument("--fanout", type=str, default="4, 4", parser.add_argument("--fanout", type=str, default="4, 4",
help="Fan-out of neighbor sampling") help="Fan-out of neighbor sampling")
parser.add_argument("--use-self-loop", default=False, action='store_true', parser.add_argument("--use-self-loop", default=False, action='store_true',
......
@@ -20,7 +20,8 @@ class LinkPredict(nn.Module):
     def __init__(self, in_dim, num_rels, h_dim=500, num_bases=100, dropout=0.2, reg_param=0.01):
         super(LinkPredict, self).__init__()
         self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, regularizer="bdd",
-                         num_bases=num_bases, dropout=dropout, self_loop=True, link_pred=True)
+                         num_bases=num_bases, dropout=dropout, self_loop=True)
+        self.dropout = nn.Dropout(dropout)
         self.reg_param = reg_param
         self.w_relation = nn.Parameter(th.Tensor(num_rels, h_dim))
         nn.init.xavier_uniform_(self.w_relation,
@@ -34,8 +35,8 @@ class LinkPredict(nn.Module):
         score = th.sum(s * r * o, dim=1)
         return score

-    def forward(self, g, h):
-        return self.rgcn(g, h)
+    def forward(self, g, nids):
+        return self.dropout(self.rgcn(g, nids=nids))

     def regularization_loss(self, embedding):
         return th.mean(embedding.pow(2)) + th.mean(self.w_relation.pow(2))
@@ -54,7 +55,7 @@ def main(args):
     num_rels = data.num_rels
     train_g, test_g = preprocess(graph, num_rels)
-    test_node_id = th.arange(0, num_nodes).view(-1, 1)
+    test_nids = th.arange(0, num_nodes)
     test_mask = graph.edata['test_mask']
     subg_iter = SubgraphIterator(train_g, num_rels, args.edge_sampler)
     dataloader = GraphDataLoader(subg_iter, batch_size=1, collate_fn=lambda x: x[0])
@@ -77,14 +78,14 @@ def main(args):
     for epoch, batch_data in enumerate(dataloader):
         model.train()
-        g, node_id, data, labels = batch_data
+        g, train_nids, edges, labels = batch_data
         g = g.to(device)
-        node_id = node_id.to(device)
-        data = data.to(device)
+        train_nids = train_nids.to(device)
+        edges = edges.to(device)
         labels = labels.to(device)

-        embed = model(g, node_id)
-        loss = model.get_loss(embed, data, labels)
+        embed = model(g, train_nids)
+        loss = model.get_loss(embed, edges, labels)
         optimizer.zero_grad()
         loss.backward()
         nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # clip gradients
@@ -97,7 +98,7 @@ def main(args):
         model = model.cpu()
         model.eval()
         print("start eval")
-        embed = model(test_g, test_node_id)
+        embed = model(test_g, test_nids)
         mrr = calc_mrr(embed, model.w_relation, test_mask, triplets,
                        batch_size=500, eval_p=args.eval_protocol)

     # save best model
@@ -114,7 +115,7 @@ def main(args):
     model.eval()
     model.load_state_dict(checkpoint['state_dict'])
     print("Using best epoch: {}".format(checkpoint['epoch']))
-    embed = model(test_g, test_node_id)
+    embed = model(test_g, test_nids)
     calc_mrr(embed, model.w_relation, test_mask, triplets,
              batch_size=500, eval_p=args.eval_protocol)
...
@@ -158,7 +158,7 @@ class SubgraphIterator:
         sub_g = dgl.graph((src, dst), num_nodes=num_nodes)
         sub_g.edata[dgl.ETYPE] = th.from_numpy(rel)
         sub_g.edata['norm'] = dgl.norm_by_dst(sub_g).unsqueeze(-1)
-        uniq_v = th.from_numpy(uniq_v).view(-1, 1).long()
+        uniq_v = th.from_numpy(uniq_v).view(-1).long()

         return sub_g, uniq_v, samples, labels
...
@@ -7,81 +7,32 @@ import dgl
 from dgl.nn.pytorch import RelGraphConv

 class RGCN(nn.Module):
-    def __init__(self, in_dim, h_dim, out_dim, num_rels,
-                 regularizer="basis", num_bases=-1, dropout=0.,
-                 self_loop=False, link_pred=False):
+    def __init__(self, num_nodes, h_dim, out_dim, num_rels,
+                 regularizer="basis", num_bases=-1, dropout=0.,
+                 self_loop=False,
+                 ns_mode=False):
         super(RGCN, self).__init__()
-        self.layers = nn.ModuleList()
-        if link_pred:
-            self.emb = nn.Embedding(in_dim, h_dim)
-            in_dim = h_dim
-        else:
-            self.emb = None
-        self.layers.append(RelGraphConv(in_dim, h_dim, num_rels, regularizer,
-                                        num_bases, activation=F.relu, self_loop=self_loop,
-                                        dropout=dropout))
-        # For entity classification, dropout should not be applied to the output layer
-        if not link_pred:
-            dropout = 0.
-        self.layers.append(RelGraphConv(h_dim, out_dim, num_rels, regularizer,
-                                        num_bases, self_loop=self_loop, dropout=dropout))
-
-    def forward(self, g, h):
-        if isinstance(g, DGLGraph):
-            blocks = [g] * len(self.layers)
-        else:
-            blocks = g
-        if self.emb is not None:
-            h = self.emb(h.squeeze())
-        for layer, block in zip(self.layers, blocks):
-            h = layer(block, h, block.edata[dgl.ETYPE], block.edata['norm'])
-        return h
-
-def initializer(emb):
-    emb.uniform_(-1.0, 1.0)
-    return emb
-
-class RelGraphEmbedLayer(nn.Module):
-    """Embedding layer for featureless heterograph.
-
-    Parameters
-    ----------
-    out_dev
-        Device to store the output embeddings
-    num_nodes : int
-        Number of nodes in the graph.
-    embed_size : int
-        Output embed size
-    """
-    def __init__(self,
-                 out_dev,
-                 num_nodes,
-                 embed_size):
-        super(RelGraphEmbedLayer, self).__init__()
-        self.out_dev = out_dev
-        self.embed_size = embed_size
-
-        # create embeddings for all nodes
-        self.node_embed = nn.Embedding(num_nodes, embed_size, sparse=True)
-        nn.init.uniform_(self.node_embed.weight, -1.0, 1.0)
-
-    def forward(self, node_ids):
-        """Forward computation
-
-        Parameters
-        ----------
-        node_ids : tensor
-            Raw node IDs.
-
-        Returns
-        -------
-        tensor
-            embeddings as the input of the next layer
-        """
-        embeds = self.node_embed(node_ids).to(self.out_dev)
-        return embeds
+        if num_bases == -1:
+            num_bases = num_rels
+        self.emb = nn.Embedding(num_nodes, h_dim)
+        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer,
+                                  num_bases, self_loop=self_loop)
+        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer, num_bases, self_loop=self_loop)
+        self.dropout = nn.Dropout(dropout)
+        self.ns_mode = ns_mode
+
+    def forward(self, g, nids=None):
+        if self.ns_mode:
+            # forward for neighbor sampling
+            x = self.emb(g[0].srcdata[dgl.NID])
+            h = self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm'])
+            h = self.dropout(F.relu(h))
+            h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
+            return h
+        else:
+            x = self.emb.weight if nids is None else self.emb(nids)
+            h = self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm'])
+            h = self.dropout(F.relu(h))
+            h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
+            return h
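Note on the reworked example model: `RGCN` now owns its node embedding table and two `RelGraphConv` layers, so callers pass only the graph (or a pair of blocks when `ns_mode=True`) instead of a node-ID feature tensor. A minimal full-graph sketch, assuming an arbitrary random graph (the sizes and tensors below are illustrative, not from the commit):

```
import dgl
import torch as th
from model import RGCN  # the reworked module above

# Illustrative graph: 100 nodes, 500 edges, 4 relation types.
g = dgl.rand_graph(100, 500)
g.edata[dgl.ETYPE] = th.randint(0, 4, (500,))
g.edata['norm'] = dgl.norm_by_dst(g).unsqueeze(-1)  # same normalizer link.py uses

model = RGCN(num_nodes=100, h_dim=16, out_dim=8, num_rels=4)
logits = model(g)    # full-graph mode: embeds all nodes internally
print(logits.shape)  # torch.Size([100, 8])
```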
@@ -2,7 +2,8 @@ import torch as th
 from distutils.version import LooseVersion

 from ...base import is_all, ALL
 from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp
-from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero, _gather_mm, _gather_mm_scatter, _segment_mm
+from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
+from ...sparse import _gather_mm, _gather_mm_scatter, _segment_mm, _segment_mm_backward_B
 from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _edge_softmax_forward, _edge_softmax_backward
 from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
 from ...heterograph_index import create_unitgraph_from_csr
@@ -697,22 +698,16 @@ class SEGMENTMM(th.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=th.float16)
     def forward(ctx, A, B, seglen_A):
-        if A.shape[0] != th.sum(seglen_A):
-            raise Exception("The summation of the elements of seglen_A must be equal to " +
-                            "dimension 0 of A. Expected " + str(A.shape[0]) + " got " + str(th.sum(seglen_A)))
         if B.dim() != 3:
-            raise Exception("Expected dimension of B is 3. Got " + str(B.dim()))
-        # Reshaping B from 3D to 2D
-        B_3D_shape = B.shape
-        B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
-        C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
+            raise ValueError("segment_mm expects B to be a 3D tensor.")
+        C = th.zeros((A.shape[0], B.shape[2]), device=A.device, dtype=A.dtype)
         C = _segment_mm(A, B, C, seglen_A)
-        ctx.backward_cache = A, B, seglen_A, B_3D_shape
+        ctx.backward_cache = A, B, seglen_A
         return C

     @staticmethod
     def backward(ctx, dZ):
-        A, B, seglen_A, B_3D_shape = ctx.backward_cache
+        A, B, seglen_A = ctx.backward_cache
         A_grad = B_grad = None
         if ctx.needs_input_grad[0]:
             # Compute A_grad = Out_grad * B^T
@@ -721,9 +716,8 @@ class SEGMENTMM(th.autograd.Function):
         if ctx.needs_input_grad[1]:
             # Compute B_grad = A^T * Out_grad
             B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
-            B_grad = _segment_mm(A, dZ, B_grad, seglen_A, a_trans=True)
-            B_grad = B_grad.reshape(B_3D_shape[0], B_3D_shape[1], B_3D_shape[2])
-        return A_grad, B_grad, None, None, None, None, None, None
+            B_grad = _segment_mm_backward_B(A, dZ, B_grad, seglen_A)
+        return A_grad, B_grad, None

 class GATHERMM(th.autograd.Function):
@@ -731,31 +725,27 @@ class GATHERMM(th.autograd.Function):
     @custom_fwd(cast_inputs=th.float16)
     def forward(ctx, A, B, idx_a, idx_b):
         if B.dim() != 3:
-            raise Exception("Expected dimension of B is 3. Got " + str(B.dim()))
-        # Reshaping B from 3D to 2D
-        B_3D_shape = B.shape
-        B = B.reshape(B.shape[0] * B.shape[1], B.shape[2])
-        C = th.zeros((A.shape[0], B.shape[1]), device=A.device, dtype=A.dtype)
-        C = _gather_mm(A, B, C, B_3D_shape[0], idx_a, idx_b)
-        ctx.backward_cache = A, B, idx_a, idx_b, B_3D_shape
+            raise ValueError("Expected dimension of B is 3. Got " + str(B.dim()))
+        N = len(idx_b) if idx_a is None else len(idx_a)
+        C = th.zeros((N, B.shape[2]), device=A.device, dtype=A.dtype)
+        C = _gather_mm(A, B, C, idx_a, idx_b)
+        ctx.backward_cache = A, B, idx_a, idx_b
         return C

     @staticmethod
     def backward(ctx, dZ):
-        A, B, idx_a, idx_b, B_3D_shape = ctx.backward_cache
+        A, B, idx_a, idx_b = ctx.backward_cache
         A_grad = B_grad = None
         if ctx.needs_input_grad[0]:
             # Compute A_grad = Out_grad * B^T
             A_grad = th.zeros(A.shape, device=A.device, dtype=A.dtype)
-            A_grad = _gather_mm_scatter(dZ, B, A_grad, B_3D_shape[0],
-                                        idx_b=idx_b, idx_c=idx_a, b_trans=True)
+            A_grad = _gather_mm_scatter(dZ, B.transpose(1, 2), A_grad,
+                                        idx_b=idx_b, idx_c=idx_a)
         if ctx.needs_input_grad[1]:
             # Compute B_grad = A^T * Out_grad
             B_grad = th.zeros(B.shape, device=B.device, dtype=B.dtype)
-            B_grad = _gather_mm_scatter(A, dZ, B_grad, B_3D_shape[0],
-                                        idx_a=idx_a, idx_c=idx_b)
-            B_grad = B_grad.reshape(B_3D_shape[0], B_3D_shape[1], B_3D_shape[2])
-        return A_grad, B_grad, None, None, None, None, None, None
+            B_grad = _gather_mm_scatter(A, dZ, B_grad, idx_a=idx_a, idx_c=idx_b)
+        return A_grad, B_grad, None, None

 def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
     if op == 'sub':
@@ -834,7 +824,20 @@ def csrmask(gidxA, A_weights, gidxB):
     return CSRMask.apply(gidxA, A_weights, gidxB)

 def segment_mm(A, B, seglen_A):
-    return SEGMENTMM.apply(A, B, seglen_A)
+    if A.device.type == 'cpu':
+        C = []
+        off = 0
+        for i in range(B.shape[0]):
+            C.append(A[off:off+seglen_A[i]] @ B[i])
+            off += seglen_A[i]
+        return th.cat(C)
+    else:
+        return SEGMENTMM.apply(A, B, seglen_A)

-def gather_mm(A, B, idx_a = None, idx_b = None):
-    return GATHERMM.apply(A, B, idx_a, idx_b)
+def gather_mm(A, B, idx_A=None, idx_B=None):
+    if A.device.type == 'cpu':
+        A = A[idx_A] if idx_A is not None else A
+        B = B[idx_B] if idx_B is not None else B
+        return th.bmm(A.unsqueeze(1), B).squeeze(1)
+    else:
+        return GATHERMM.apply(A, B, idx_A, idx_B)
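The semantics of the two primitives can be sanity-checked against plain PyTorch. A shape-level sketch with illustrative sizes (not part of the commit):

```
import torch as th

# segment_mm(A, B, seglen_A): rows of A are split into consecutive segments,
# and segment i is multiplied by the i-th weight slice B[i].
A = th.randn(10, 3)
B = th.randn(2, 3, 4)
seglen_A = th.tensor([6, 4])                  # 6 + 4 == A.shape[0]
C_ref = th.cat([A[:6] @ B[0], A[6:] @ B[1]])  # what the CPU branch above computes
assert C_ref.shape == (10, 4)

# gather_mm(A, B, idx_B): row i of A is multiplied by B[idx_B[i]].
idx_B = th.randint(0, 2, (10,))
C_ref = th.bmm(A.unsqueeze(1), B[idx_B]).squeeze(1)
assert C_ref.shape == (10, 4)
```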
@@ -2,6 +2,7 @@
 from .conv import *
 from .explain import *
 from .link import *
+from .linear import *
 from .glob import *
 from .softmax import *
 from .factory import *
...
@@ -25,9 +25,10 @@ from .cfconv import CFConv
 from .dotgatconv import DotGatConv
 from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
 from .gcn2conv import GCN2Conv
+from .hgtconv import HGTConv

 __all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
            'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
            'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
            'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
-           'TWIRLSUnfoldingAndAttention', 'GCN2Conv']
+           'TWIRLSUnfoldingAndAttention', 'GCN2Conv', 'HGTConv']
"""Heterogeneous Graph Transformer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import math
import torch
import torch.nn as nn
from .... import function as fn
from ..linear import TypedLinear
from ..softmax import edge_softmax
class HGTConv(nn.Module):
r"""Heterogeneous graph transformer convolution.
Introduced in "`Heterogeneous Graph Transformer <https://arxiv.org/abs/2003.01332>`__".
Given a graph :math:`G(V, E)` and input node features :math:`H^{(l-1)}`,
it computes the new node features as follows:
Compute a multi-head attention score for each edge :math:`(s, e, t)` in the graph:
.. math::
Attention(s, e, t) = \text{Softmax}\left(||_{i\in[1,h]}ATT-head^i(s, e, t)\right) \\
ATT-head^i(s, e, t) = \left(K^i(s)W^{ATT}_{\phi(e)}Q^i(t)^{\top}\right)\cdot
\frac{\mu_{(\tau(s),\phi(e),\tau(t))}}{\sqrt{d}} \\
K^i(s) = \text{K-Linear}^i_{\tau(s)}(H^{(l-1)}[s]) \\
Q^i(t) = \text{Q-Linear}^i_{\tau(t)}(H^{(l-1)}[t]) \\
Compute the message to send on each edge :math:`(s, e, t)`:
.. math::
Message(s, e, t) = ||_{i\in[1, h]} MSG-head^i(s, e, t) \\
MSG-head^i(s, e, t) = \text{M-Linear}^i_{\tau(s)}(H^{(l-1)}[s])W^{MSG}_{\phi(e)} \\
Send messages to target nodes :math:`t` and aggregate:
.. math::
\tilde{H}^{(l)}[t] = \sum_{\forall s\in \mathcal{N}(t)}\left( Attention(s,e,t)
\cdot Message(s,e,t)\right)
Compute new node features:
.. math::
H^{(l)}[t]=\text{A-Linear}_{\tau(t)}(\sigma(\tilde{H}^{(l)}[t])) + H^{(l-1)}[t]
Parameters
----------
in_size : int
Input node feature size.
head_size : int
Output head size. The output node feature size is ``head_size * num_heads``.
num_heads : int
Number of heads. The output node feature size is ``head_size * num_heads``.
num_ntypes : int
Number of node types.
num_etypes : int
Number of edge types.
dropout : float, optional
Dropout rate.
use_norm : bool, optional
If true, apply a layer norm on the output node feature.
Examples
--------
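A minimal usage sketch (the graph and sizes below are illustrative):

>>> import torch
>>> import dgl
>>> from dgl.nn import HGTConv
>>> g = dgl.graph(([0, 1], [1, 2]))   # arbitrary 3-node graph
>>> x = torch.randn(3, 4)             # node features
>>> ntype = torch.tensor([0, 0, 1])   # 2 node types
>>> etype = torch.tensor([0, 1])      # 2 edge types
>>> conv = HGTConv(4, 2, 3, 2, 2)     # head_size=2, num_heads=3
>>> out = conv(g, x, ntype, etype)
>>> out.shape
torch.Size([3, 6])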
"""
def __init__(self,
in_size,
head_size,
num_heads,
num_ntypes,
num_etypes,
dropout=0.2,
use_norm=False):
super().__init__()
self.in_size = in_size
self.head_size = head_size
self.num_heads = num_heads
self.sqrt_d = math.sqrt(head_size)
self.use_norm = use_norm
self.linear_k = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_q = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_v = TypedLinear(in_size, head_size * num_heads, num_ntypes)
self.linear_a = TypedLinear(head_size * num_heads, head_size * num_heads, num_ntypes)
self.relation_pri = nn.ParameterList([nn.Parameter(torch.ones(num_etypes))
for i in range(num_heads)])
self.relation_att = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)])
self.relation_msg = nn.ModuleList([TypedLinear(head_size, head_size, num_etypes)
for i in range(num_heads)])
self.skip = nn.Parameter(torch.ones(num_ntypes))
self.drop = nn.Dropout(dropout)
if use_norm:
self.norm = nn.LayerNorm(head_size * num_heads)
if in_size != head_size * num_heads:
self.residual_w = nn.Parameter(torch.Tensor(in_size, head_size * num_heads))
nn.init.xavier_uniform_(self.residual_w)
def forward(self, g, x, ntype, etype, *, presorted=False):
"""Forward computation.
Parameters
----------
g : DGLGraph
The input graph.
x : torch.Tensor
A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.
ntype : torch.Tensor
An 1D integer tensor of node types. Shape: :math:`(|V|,)`.
etype : torch.Tensor
An 1D integer tensor of edge types. Shape: :math:`(|E|,)`.
presorted : bool, optional
Whether *both* the nodes and the edges of the input graph have been sorted by
their types. Forward on pre-sorted graph may be faster. Graphs created by
:func:`~dgl.to_homogeneous` automatically satisfy the condition.
Also see :func:`~dgl.reorder_graph` for manually reordering the nodes and edges.
Returns
-------
torch.Tensor
New node features. Shape: :math:`(|V|, D_{head} * N_{head})`.
"""
self.presorted = presorted
with g.local_scope():
k = self.linear_k(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
q = self.linear_q(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
v = self.linear_v(x, ntype, presorted).view(-1, self.num_heads, self.head_size)
g.srcdata['k'] = k
g.dstdata['q'] = q
g.srcdata['v'] = v
g.edata['etype'] = etype
g.apply_edges(self.message)
g.edata['m'] = g.edata['m'] * edge_softmax(g, g.edata['a']).unsqueeze(-1)
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'h'))
h = g.dstdata['h'].view(-1, self.num_heads * self.head_size)
# target-specific aggregation
h = self.drop(self.linear_a(h, ntype, presorted))
alpha = torch.sigmoid(self.skip[ntype]).unsqueeze(-1)
if x.shape != h.shape:
h = h * alpha + (x @ self.residual_w) * (1 - alpha)
else:
h = h * alpha + x * (1 - alpha)
if self.use_norm:
h = self.norm(h)
return h
def message(self, edges):
"""Message function."""
a, m = [], []
etype = edges.data['etype']
k = torch.unbind(edges.src['k'], dim=1)
q = torch.unbind(edges.dst['q'], dim=1)
v = torch.unbind(edges.src['v'], dim=1)
for i in range(self.num_heads):
kw = self.relation_att[i](k[i], etype, self.presorted) # (E, O)
a.append((kw * q[i]).sum(-1) * self.relation_pri[i][etype] / self.sqrt_d) # (E,)
m.append(self.relation_msg[i](v[i], etype, self.presorted)) # (E, O)
return {'a' : torch.stack(a, dim=1), 'm' : torch.stack(m, dim=1)}
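Because `presorted` assumes both nodes and edges are grouped by type, the typical producer of such a graph is `dgl.to_homogeneous`. A hedged sketch (the heterograph below is hypothetical):

```
import dgl
import torch
from dgl.nn import HGTConv

# Illustrative heterograph with 2 node types and 2 edge types.
hg = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([0, 2], [0, 1]),
})
g = dgl.to_homogeneous(hg)   # nodes and edges come out grouped by type
ntype = g.ndata[dgl.NTYPE]
etype = g.edata[dgl.ETYPE]
x = torch.randn(g.num_nodes(), 4)
conv = HGTConv(4, 2, 2, num_ntypes=2, num_etypes=2)
out = conv(g, x, ntype, etype, presorted=True)  # takes the fast path
```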
"""Torch Module for Relational graph convolution layer""" """Torch Module for Relational graph convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name # pylint: disable= no-member, arguments-differ, invalid-name
import functools
import numpy as np
import torch as th import torch as th
from torch import nn from torch import nn
from .... import function as fn from .... import function as fn
from .. import utils from ..linear import TypedLinear
from ....base import DGLError
from .... import edge_subgraph
class RelGraphConv(nn.Module): class RelGraphConv(nn.Module):
r"""Relational graph convolution layer. r"""Relational graph convolution layer.
...@@ -55,22 +51,21 @@ class RelGraphConv(nn.Module): ...@@ -55,22 +51,21 @@ class RelGraphConv(nn.Module):
Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`. Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
num_rels : int num_rels : int
Number of relations. . Number of relations. .
regularizer : str regularizer : str, optional
Which weight regularizer to use "basis" or "bdd". Which weight regularizer to use "basis" or "bdd":
"basis" is short for basis-diagonal-decomposition.
"bdd" is short for block-diagonal-decomposition. - "basis" is short for basis-decomposition.
- "bdd" is short for block-diagonal-decomposition.
Default applies no regularization.
num_bases : int, optional num_bases : int, optional
Number of bases. If is none, use number of relations. Default: ``None``. Number of bases. Needed when ``regularizer`` is specified. Default: ``None``.
bias : bool, optional bias : bool, optional
True if bias is added. Default: ``True``. True if bias is added. Default: ``True``.
activation : callable, optional activation : callable, optional
Activation function. Default: ``None``. Activation function. Default: ``None``.
self_loop : bool, optional self_loop : bool, optional
True to include self loop message. Default: ``True``. True to include self loop message. Default: ``True``.
low_mem : bool, optional
True to use low memory implementation of relation message passing function. Default: False.
This option trades speed with memory consumption, and will slowdown the forward/backward.
Turn it on when you encounter OOM problem during training or evaluation. Default: ``False``.
dropout : float, optional dropout : float, optional
Dropout rate. Default: ``0.0`` Dropout rate. Default: ``0.0``
layer_norm: float, optional layer_norm: float, optional
...@@ -86,9 +81,7 @@ class RelGraphConv(nn.Module): ...@@ -86,9 +81,7 @@ class RelGraphConv(nn.Module):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 10) >>> feat = th.ones(6, 10)
>>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2) >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)
>>> conv.weight.shape >>> etype = th.tensor([0,1,2,0,1,2])
torch.Size([2, 10, 2])
>>> etype = th.tensor(np.array([0,1,2,0,1,2]).astype(np.int64))
>>> res = conv(g, feat, etype) >>> res = conv(g, feat, etype)
>>> res >>> res
tensor([[ 0.3996, -2.3303], tensor([[ 0.3996, -2.3303],
...@@ -97,80 +90,32 @@ class RelGraphConv(nn.Module): ...@@ -97,80 +90,32 @@ class RelGraphConv(nn.Module):
[ 2.1046, -2.8654], [ 2.1046, -2.8654],
[-0.4323, -0.1440], [-0.4323, -0.1440],
[-0.1309, -1.0000]], grad_fn=<AddBackward0>) [-0.1309, -1.0000]], grad_fn=<AddBackward0>)
>>> # One-hot input
>>> one_hot_feat = th.tensor(np.array([0,1,2,3,4,5]).astype(np.int64))
>>> res = conv(g, one_hot_feat, etype)
>>> res
tensor([[ 0.5925, 0.0985],
[-0.3953, 0.8408],
[-0.9819, 0.5284],
[-1.0085, -0.1721],
[ 0.5962, 1.2002],
[ 0.0365, -0.3532]], grad_fn=<AddBackward0>)
""" """
def __init__(self, def __init__(self,
in_feat, in_feat,
out_feat, out_feat,
num_rels, num_rels,
regularizer="basis", regularizer=None,
num_bases=None, num_bases=None,
bias=True, bias=True,
activation=None, activation=None,
self_loop=True, self_loop=True,
low_mem=False,
dropout=0.0, dropout=0.0,
layer_norm=False): layer_norm=False):
super(RelGraphConv, self).__init__() super().__init__()
self.in_feat = in_feat self.linear_r = TypedLinear(in_feat, out_feat, num_rels, regularizer, num_bases)
self.out_feat = out_feat
self.num_rels = num_rels
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
self.num_bases = self.num_rels
self.bias = bias self.bias = bias
self.activation = activation self.activation = activation
self.self_loop = self_loop self.self_loop = self_loop
self.low_mem = low_mem
self.layer_norm = layer_norm self.layer_norm = layer_norm
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
# message func
self.message_func = self.basis_message_func
elif regularizer == "bdd":
if in_feat % self.num_bases != 0 or out_feat % self.num_bases != 0:
raise ValueError(
'Feature size must be a multiplier of num_bases (%d).'
% self.num_bases
)
# add block diagonal weights
self.submat_in = in_feat // self.num_bases
self.submat_out = out_feat // self.num_bases
# assuming in_feat and out_feat are both divisible by num_bases
self.weight = nn.Parameter(th.Tensor(
self.num_rels, self.num_bases * self.submat_in * self.submat_out))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
# message func
self.message_func = self.bdd_message_func
else:
raise ValueError("Regularizer must be either 'basis' or 'bdd'")
# bias # bias
if self.bias: if self.bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat)) self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias) nn.init.zeros_(self.h_bias)
# TODO(minjie): consider remove those options in the future to make
# the module only about graph convolution.
# layer norm # layer norm
if self.layer_norm: if self.layer_norm:
self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True) self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)
...@@ -178,121 +123,18 @@ class RelGraphConv(nn.Module): ...@@ -178,121 +123,18 @@ class RelGraphConv(nn.Module):
# weight for self loop # weight for self loop
if self.self_loop: if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight, nn.init.xavier_uniform_(self.loop_weight, gain=nn.init.calculate_gain('relu'))
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def basis_message_func(self, edges, etypes): def message(self, edges):
"""Message function for basis regularizer. """Message function."""
m = self.linear_r(edges.src['h'], edges.data['etype'], self.presorted)
Parameters
----------
edges : dgl.EdgeBatch
Input to DGL message UDF.
etypes : torch.Tensor or list[int]
Edge type data. Could be either:
* An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
Preferred format if ``lowmem == False``.
* An integer list. The i^th element is the number of edges of the i^th type.
This requires the input graph to store edges sorted by their type IDs.
Preferred format if ``lowmem == True``.
"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = th.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
h = edges.src['h']
device = h.device
if h.dtype == th.int64 and h.ndim == 1:
# Each element is the node's ID. Use index select: weight[etypes, h, :]
# The following is a faster version of it.
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
idim = weight.shape[1]
weight = weight.view(-1, weight.shape[2])
flatidx = etypes * idim + h
msg = weight.index_select(0, flatidx)
elif self.low_mem:
# A more memory-friendly implementation.
# Calculate msg @ W_r before put msg into edge.
assert isinstance(etypes, list)
h_t = th.split(h, etypes)
msg = []
for etype in range(self.num_rels):
if h_t[etype].shape[0] == 0:
continue
msg.append(th.matmul(h_t[etype], weight[etype]))
msg = th.cat(msg)
else:
# Use batched matmult
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
weight = weight.index_select(0, etypes)
msg = th.bmm(h.unsqueeze(1), weight).squeeze(1)
if 'norm' in edges.data: if 'norm' in edges.data:
msg = msg * edges.data['norm'] m = m * edges.data['norm']
return {'msg': msg} return {'m' : m}
def bdd_message_func(self, edges, etypes): def forward(self, g, feat, etypes, norm=None, *, presorted=False):
"""Message function for block-diagonal-decomposition regularizer.
Parameters
----------
edges : dgl.EdgeBatch
Input to DGL message UDF.
etypes : torch.Tensor or list[int]
Edge type data. Could be either:
* An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
Preferred format if ``lowmem == False``.
* An integer list. The i^th element is the number of edges of the i^th type.
This requires the input graph to store edges sorted by their type IDs.
Preferred format if ``lowmem == True``.
"""
h = edges.src['h']
device = h.device
if h.dtype == th.int64 and h.ndim == 1:
raise TypeError('Block decomposition does not allow integer ID feature.')
if self.low_mem:
# A more memory-friendly implementation.
# Calculate msg @ W_r before put msg into edge.
assert isinstance(etypes, list)
h_t = th.split(h, etypes)
msg = []
for etype in range(self.num_rels):
if h_t[etype].shape[0] == 0:
continue
tmp_w = self.weight[etype].view(self.num_bases, self.submat_in, self.submat_out)
tmp_h = h_t[etype].view(-1, self.num_bases, self.submat_in)
msg.append(th.einsum('abc,bcd->abd', tmp_h, tmp_w).reshape(-1, self.out_feat))
msg = th.cat(msg)
else:
# Use batched matmult
if isinstance(etypes, list):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
weight = self.weight.index_select(0, etypes).view(
-1, self.submat_in, self.submat_out)
node = h.view(-1, 1, self.submat_in)
msg = th.bmm(node, weight).view(-1, self.out_feat)
if 'norm' in edges.data:
msg = msg * edges.data['norm']
return {'msg': msg}
@@ -300,88 +142,39 @@ class RelGraphConv(nn.Module):
-    def forward(self, g, feat, etypes, norm=None):
+    def forward(self, g, feat, etypes, norm=None, *, presorted=False):
         """Forward computation.

         Parameters
         ----------
         g : DGLGraph
             The graph.
         feat : torch.Tensor
-            Input node features. Could be either
-
-                * :math:`(|V|, D)` dense tensor
-                * :math:`(|V|,)` int64 vector, representing the categorical values of each
-                  node. It then treats the input feature as a one-hot encoding feature.
+            A 2D tensor of node features. Shape: :math:`(|V|, D_{in})`.
         etypes : torch.Tensor or list[int]
-            Edge type data. Could be either
-
-                * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID.
-                  Preferred format if ``lowmem == False``.
-                * An integer list. The i^th element is the number of edges of the i^th type.
-                  This requires the input graph to store edges sorted by their type IDs.
-                  Preferred format if ``lowmem == True``.
+            An 1D integer tensor of edge types. Shape: :math:`(|E|,)`.
         norm : torch.Tensor, optional
-            Edge normalizer. Could be either
-
-                * An :math:`(|E|, 1)` tensor storing the normalizer on each edge.
+            An 1D tensor of edge norm values. Shape: :math:`(|E|,)`.
+        presorted : bool, optional
+            Whether the edges of the input graph have been sorted by their types.
+            Forward on pre-sorted graphs may be faster. Graphs created
+            by :func:`~dgl.to_homogeneous` automatically satisfy the condition.
+            Also see :func:`~dgl.reorder_graph` for sorting edges manually.

         Returns
         -------
         torch.Tensor
-            New node features.
-
-        Notes
-        -----
-        Under the ``low_mem`` mode, DGL will sort the graph based on the edge types
-        and compute message passing one type at a time. DGL recommends sorting the
-        graph beforehand (and caching it if possible) and providing the integer list
-        format to the ``etypes`` argument. Use DGL's :func:`~dgl.to_homogeneous` API
-        to get a sorted homogeneous graph from a heterogeneous graph. Pass ``return_count=True``
-        to it to get the ``etypes`` in integer list.
+            New node features. Shape: :math:`(|V|, D_{out})`.
         """
-        if isinstance(etypes, th.Tensor):
-            if len(etypes) != g.num_edges():
-                raise DGLError('"etypes" tensor must have length equal to the number of edges'
-                               ' in the graph. But got {} and {}.'.format(
-                                   len(etypes), g.num_edges()))
-            if self.low_mem and not (feat.dtype == th.int64 and feat.ndim == 1):
-                # Low-mem optimization is not enabled for node ID input. When enabled,
-                # it first sorts the graph based on the edge types (the sorting will not
-                # change the node IDs). It then converts the etypes tensor to an integer
-                # list, where each element is the number of edges of the type.
-                # Sort the graph based on the etypes
-                sorted_etypes, index = th.sort(etypes)
-                g = edge_subgraph(g, index, relabel_nodes=False)
-                # Create a new etypes to be an integer list of number of edges.
-                pos = _searchsorted(sorted_etypes, th.arange(self.num_rels, device=g.device))
-                num = th.tensor([len(etypes)], device=g.device)
-                etypes = (th.cat([pos[1:], num]) - pos).tolist()
-                if norm is not None:
-                    norm = norm[index]
+        self.presorted = presorted
         with g.local_scope():
             g.srcdata['h'] = feat
             if norm is not None:
                 g.edata['norm'] = norm
-            if self.self_loop:
-                loop_message = utils.matmul_maybe_select(feat[:g.number_of_dst_nodes()],
-                                                         self.loop_weight)
+            g.edata['etype'] = etypes
             # message passing
-            g.update_all(functools.partial(self.message_func, etypes=etypes),
-                         fn.sum(msg='msg', out='h'))
+            g.update_all(self.message, fn.sum('m', 'h'))
             # apply bias and activation
-            node_repr = g.dstdata['h']
+            h = g.dstdata['h']
             if self.layer_norm:
-                node_repr = self.layer_norm_weight(node_repr)
+                h = self.layer_norm_weight(h)
             if self.bias:
-                node_repr = node_repr + self.h_bias
+                h = h + self.h_bias
             if self.self_loop:
-                node_repr = node_repr + loop_message
+                h = h + feat[:g.num_dst_nodes()] @ self.loop_weight
             if self.activation:
-                node_repr = self.activation(node_repr)
-            node_repr = self.dropout(node_repr)
-            return node_repr
+                h = self.activation(h)
+            h = self.dropout(h)
+            return h
-_TORCH_HAS_SEARCHSORTED = getattr(th, 'searchsorted', None)
-
-def _searchsorted(sorted_sequence, values):
-    # searchsorted was introduced to PyTorch in 1.6.0
-    if _TORCH_HAS_SEARCHSORTED:
-        return th.searchsorted(sorted_sequence, values)
-    else:
-        device = values.device
-        return th.from_numpy(np.searchsorted(sorted_sequence.cpu().numpy(),
-                                             values.cpu().numpy())).to(device)
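With low_mem and the sorting helpers gone, a caller simply passes an etypes tensor (and, optionally, the new presorted flag) to the reworked layer. A minimal usage sketch, assuming the post-rework constructor signature RelGraphConv(in_feat, out_feat, num_rels, regularizer=..., num_bases=...):

    import dgl
    import torch as th
    from dgl.nn import RelGraphConv

    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))  # 4 nodes, 4 edges
    feat = th.randn(4, 8)                        # (|V|, D_in) node features
    etypes = th.tensor([0, 1, 0, 1])             # (|E|,) edge-type IDs
    conv = RelGraphConv(8, 16, num_rels=2, regularizer='basis', num_bases=2)
    out = conv(g, feat, etypes)                  # shape: (4, 16)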
"""Various commonly used linear modules"""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import math
import torch
import torch.nn as nn
from ...ops import segment_mm, gather_mm
__all__ = ['TypedLinear']
class TypedLinear(nn.Module):
r"""Linear transformation according to types.
For each sample of the input batch :math:`x \in X`, apply linear transformation
:math:`xW_t`, where :math:`t` is the type of :math:`x`.
The module supports two regularization methods (basis-decomposition and
block-diagonal-decomposition) proposed by "`Modeling Relational Data
with Graph Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
The basis regularization decomposes :math:`W_t` by:
.. math::
W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}
where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined
with coefficients :math:`a_{tb}^{(l)}`.
The block-diagonal-decomposition regularization decomposes :math:`W_t` into :math:`B`
block-diagonal matrices. We refer to :math:`B` as the number of bases:
.. math::
W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}
where :math:`B` is the number of bases, :math:`Q_{tb}^{(l)}` are block
bases with shape :math:`R^{(d^{(l+1)}/B)\times(d^{l}/B)}`.
Parameters
----------
in_size : int
Input feature size.
out_size : int
Output feature size.
num_types : int
Total number of types.
regularizer : str, optional
Which weight regularizer to use "basis" or "bdd":
- "basis" is short for basis-decomposition.
- "bdd" is short for block-diagonal-decomposition.
Default applies no regularization.
num_bases : int, optional
Number of bases. Needed when ``regularizer`` is specified. Typically smaller
than ``num_types``.
Default: ``None``.
Examples
--------
No regularization.
>>> from dgl.nn import TypedLinear
>>> import torch
>>>
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])
With basis regularization
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])
"""
def __init__(self, in_size, out_size, num_types,
regularizer=None, num_bases=None):
super().__init__()
self.in_size = in_size
self.out_size = out_size
self.num_types = num_types
if regularizer is None:
self.W = nn.Parameter(torch.Tensor(num_types, in_size, out_size))
elif regularizer == 'basis':
if num_bases is None:
raise ValueError('Missing "num_bases" for basis regularization.')
self.W = nn.Parameter(torch.Tensor(num_bases, in_size, out_size))
self.coeff = nn.Parameter(torch.Tensor(num_types, num_bases))
self.num_bases = num_bases
elif regularizer == 'bdd':
if num_bases is None:
raise ValueError('Missing "num_bases" for bdd regularization.')
if in_size % num_bases != 0 or out_size % num_bases != 0:
raise ValueError(
'Input and output sizes must be divisible by num_bases.'
)
self.submat_in = in_size // num_bases
self.submat_out = out_size // num_bases
self.W = nn.Parameter(torch.Tensor(
num_types, num_bases * self.submat_in * self.submat_out))
self.num_bases = num_bases
else:
raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}')
self.regularizer = regularizer
self.reset_parameters()
def reset_parameters(self):
"""Reset parameters"""
with torch.no_grad():
# Follow torch.nn.Linear 's initialization to use kaiming_uniform_ on in_size
if self.regularizer is None:
nn.init.uniform_(self.W, -1/math.sqrt(self.in_size), 1/math.sqrt(self.in_size))
elif self.regularizer == 'basis':
nn.init.uniform_(self.W, -1/math.sqrt(self.in_size), 1/math.sqrt(self.in_size))
nn.init.xavier_uniform_(self.coeff, gain=nn.init.calculate_gain('relu'))
elif self.regularizer == 'bdd':
nn.init.uniform_(self.W, -1/math.sqrt(self.submat_in), 1/math.sqrt(self.submat_in))
else:
raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}')
def get_weight(self):
"""Get type-wise weight"""
if self.regularizer is None:
return self.W
elif self.regularizer == 'basis':
W = self.W.view(self.num_bases, self.in_size * self.out_size)
return (self.coeff @ W).view(self.num_types, self.in_size, self.out_size)
elif self.regularizer == 'bdd':
return self.W
else:
raise ValueError(
f'Supported regularizer options: "basis", "bdd", but got {regularizer}')
def forward(self, x, x_type, sorted_by_type=False):
"""Forward computation.
Parameters
----------
x : torch.Tensor
A 2D input tensor. Shape: (N, D1)
x_type : torch.Tensor
A 1D integer tensor storing the type of the elements in ``x`` with one-to-one
correspondenc. Shape: (N,)
sorted_by_type : bool, optional
Whether the inputs have been sorted by the types. Forward on pre-sorted inputs may
be faster.
Returns
-------
y : torch.Tensor
The transformed output tensor. Shape: (N, D2)
"""
w = self.get_weight()
if self.regularizer == 'bdd':
w = w.index_select(0, x_type).view(-1, self.submat_in, self.submat_out)
x = x.view(-1, 1, self.submat_in)
return torch.bmm(x, w).view(-1, self.out_size)
elif sorted_by_type:
pos_l = torch.searchsorted(x_type, torch.arange(self.num_types, device=x.device))
pos_r = torch.cat([pos_l[1:], torch.tensor([len(x_type)], device=x.device)])
seglen = (pos_r - pos_l).cpu() # XXX(minjie): cause device synchronize
return segment_mm(x, w, seglen_a=seglen)
else:
return gather_mm(x, w, idx_b=x_type)
def __repr__(self):
if self.regularizer is None:
return (f'TypedLinear(in_size={self.in_size}, out_size={self.out_size}, '
f'num_types={self.num_types})')
else:
return (f'TypedLinear(in_size={self.in_size}, out_size={self.out_size}, '
f'num_types={self.num_types}, regularizer={self.regularizer}, '
f'num_bases={self.num_bases})')
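When callers can group the inputs by type once up front, the ``sorted_by_type=True`` path skips the per-row gather and runs one matmul per segment via segment_mm. A small sketch of that pattern (the sort and un-sort are the caller's responsibility):

    import torch
    from dgl.nn import TypedLinear

    x = torch.randn(100, 32)
    x_type = torch.randint(0, 5, (100,))
    m = TypedLinear(32, 64, 5)

    # Sort rows by type once, take the segmented fast path, then undo the sort.
    sorted_type, perm = torch.sort(x_type)
    y_sorted = m(x[perm], sorted_type, sorted_by_type=True)
    y = torch.empty_like(y_sorted)
    y[perm] = y_sorted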
"""dgl gather_mm operator module.""" """dgl gather_mm operator module."""
from ..backend import gather_mm as gather_mm_internal from .. import backend as F
from ..backend import segment_mm as segment_mm_internal
__all__ = ['gather_mm', 'segment_mm'] __all__ = ['gather_mm']
def segment_mm(lhs_data, rhs_data, seglen_lhs): def gather_mm(a, b, *, idx_b):
r""" Performs matrix multiplication according to segments.
Suppose ``seglen_lhs == [10, 5, 0, 3]``, the operator will perform
four matrix multiplications:
lhs_data[0:10] @ rhs_data[0], lhs_data[10:15] @ rhs_data[1],
lhs_data[15:15] @ rhs_data[2], lhs_data[15:18] @ rhs_data[3]
Parameters
----------
lhs_data : tensor
The left operand, 2-D tensor of shape (N, D1)
rhs_data : tensor
The right operand, 2-D tensor of shape (R * D1, D2)
seglen_lhs : tensor
An integer tensor of shape (R,). Each element is the length of segments
of input ``lhs_data``. The summation of all elements must be equal to N.
Returns
-------
tensor
The output dense matrix of shape (N, D2)
"""
return segment_mm_internal(lhs_data, rhs_data, seglen_lhs)
def gather_mm(lhs_data, rhs_data, idx_lhs = None, idx_rhs = None):
r"""Gather data according to the given indices and perform matrix multiplication. r"""Gather data according to the given indices and perform matrix multiplication.
Let the result tensor be C, the operator conducts the following computation: Let the result tensor be ``c``, the operator conducts the following computation:
If both idx_lhs and idx_rhs are not none:
c[i] = lhs_data[idx_lhs[i]] @ rhs_data[idx_rhs[i]]
, where len(C) == len(idx_lhs) == len(idx_rhs)
If idx_lhs is given but not idx_rhs:
c[i] = rhs_data[idx_lhs[i]] @ rhs_data[i]
, where len(C) == len(idx_lhs)
If idx_rhs is given but not idx_lhs:
c[i] = lhs_data[i] @ rhs_data[idx_rhs[i]] c[i] = a[i] @ b[idx_b[i]]
, where len(C) == len(idx_rhs) , where len(c) == len(idx_b)
Parameters Parameters
---------- ----------
lhs_data : tensor a : Tensor
2-D tensor of shape (N, D1) A 2-D tensor of shape ``(N, D1)``
rhs_data : tensor b : Tensor
3-D tensor of shape (R, D1, D2) A 3-D tensor of shape ``(R, D1, D2)``
idx_lhs : Tensor, optional idx_b : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,). An 1-D integer tensor of shape ``(N,)``.
idx_rhs : Tensor, optional
If specified, must be a 1-D integer tensor of shape (K,).
Returns Returns
------- -------
Tensor Tensor
The output dense matrix of shape (N, D2) The output dense matrix of shape ``(N, D2)``
""" """
return gather_mm_internal(lhs_data, rhs_data, idx_lhs, idx_rhs) N, D1 = F.shape(a)
R, _, D2 = F.shape(b)
if N > 1000000 or D1 > 8 or D2 > 8:
# Use segment_mm for large workload
import torch
sorted_idx_b, perm = torch.sort(idx_b)
_, rev_perm = torch.sort(perm)
sorted_a = torch.index_select(a, 0, perm)
pos_l = torch.searchsorted(sorted_idx_b, torch.arange(R, device=a.device))
pos_r = torch.cat([pos_l[1:], torch.tensor([len(idx_b)], device=a.device)])
seglen = (pos_r - pos_l).cpu() # XXX(minjie): cause device synchronize
return torch.index_select(F.segment_mm(sorted_a, b, seglen), 0, rev_perm)
else:
return F.gather_mm(a, b, None, idx_b)
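Semantically, gather_mm is just a batched matmul against gathered weights. A plain-PyTorch reference (handy for testing, not the fast path):

    import torch

    def gather_mm_reference(a, b, idx_b):
        # c[i] = a[i] @ b[idx_b[i]]
        return torch.bmm(a.unsqueeze(1), b[idx_b]).squeeze(1)

    a = torch.randn(10, 4)
    b = torch.randn(3, 4, 6)
    idx_b = torch.randint(0, 3, (10,))
    c = gather_mm_reference(a, b, idx_b)  # shape: (10, 6)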
@@ -3,6 +3,7 @@
 from ..base import DGLError
 from .. import backend as F

+__all__ = ['segment_reduce', 'segment_softmax', 'segment_mm']

 def segment_reduce(seglen, value, reducer='sum'):
     """Segment reduction operator.
@@ -98,3 +99,29 @@ def segment_softmax(seglen, value):
     value = F.exp(value - F.repeat(value_max, seglen, dim=0))
     value_sum = segment_reduce(seglen, value, reducer='sum')
     return value / F.repeat(value_sum, seglen, dim=0)
+
+def segment_mm(a, b, seglen_a):
+    r"""Performs matrix multiplication according to segments.
+
+    Suppose ``seglen_a == [10, 5, 0, 3]``, the operator will perform
+    four matrix multiplications::
+
+        a[0:10] @ b[0], a[10:15] @ b[1],
+        a[15:15] @ b[2], a[15:18] @ b[3]
+
+    Parameters
+    ----------
+    a : Tensor
+        The left operand, 2-D tensor of shape ``(N, D1)``
+    b : Tensor
+        The right operand, 3-D tensor of shape ``(R, D1, D2)``
+    seglen_a : Tensor
+        An integer tensor of shape ``(R,)``. Each element is the length of segments
+        of input ``a``. The summation of all elements must be equal to ``N``.
+
+    Returns
+    -------
+    Tensor
+        The output dense matrix of shape ``(N, D2)``
+    """
+    return F.segment_mm(a, b, seglen_a)
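The same semantics in plain PyTorch, useful as a correctness reference (assuming seglen_a sums to N):

    import torch

    def segment_mm_reference(a, b, seglen_a):
        # Multiply each contiguous segment of `a` with its own weight matrix.
        segs = torch.split(a, seglen_a.tolist())
        return torch.cat([seg @ b[i] for i, seg in enumerate(segs)])

    a = torch.randn(18, 4)
    b = torch.randn(4, 4, 6)
    c = segment_mm_reference(a, b, torch.tensor([10, 5, 0, 3]))  # shape: (18, 6)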
@@ -389,108 +389,43 @@ def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
     return out, (list_arg_u, list_arg_e, list_arg_u_ntype, list_arg_e_etype)

-def _segment_mm(A, B, out, seglen_A, a_trans=False, b_trans=False):
-    r"""Dense Matrix Multiplication interface. It multiplies dense tensor A
-    and dense tensor B according to relation types. A is sorted and concatenated
-    according to relation types.
-
-    Parameters
-    ----------
-    A : tensor
-        2-D tensor of shape (N, D1)
-    B : tensor
-        2-D tensor of shape (R * D1, D2)
-    seglen_A : Tensor
-        An integer tensor of shape (R,). Each element is the length of segments
-        of input ``A``. The summation of all elements must be equal to N.
-    a_trans : bool
-        Indicates whether matrix A needs to be transposed
-    b_trans : bool
-        Indicates whether matrix B needs to be transposed
-
-    Returns
-    -------
-    Tensor
-        The output dense matrix of shape (N, D2)
-    """
-    # TODO(Israt): Add CPU support. Currently, only handles GPU code
+def _segment_mm(A, B, out, seglen_A, b_trans=False):
+    """Invoke the C API of segment_mm."""
     _CAPI_DGLKernelSEGMENTMM(to_dgl_nd(A),
                              to_dgl_nd(B),
                              to_dgl_nd_for_write(out),
                              to_dgl_nd(seglen_A),
-                             a_trans, b_trans)
+                             False, b_trans)
     return out

+def _segment_mm_backward_B(A, dC, dB, seglen):
+    """Invoke the C API of the backward of segment_mm on B."""
+    _CAPI_DGLKernelSEGMENTMMBackwardB(
+        to_dgl_nd(A),
+        to_dgl_nd(dC),
+        to_dgl_nd_for_write(dB),
+        to_dgl_nd(seglen))
+    return dB

-def _gather_mm(A, B, out, num_rel, idx_a=None, idx_b=None):
-    r"""Generalized Dense Matrix Multiplication interface. It multiplies
-    tensor A and B according to relation types and outputs in out. B is a
-    concatenated tensor across relation types. A is unsorted and the
-    relation type is fetched from param etypes.
-
-    Parameters
-    ----------
-    A : tensor
-        2-D tensor of shape (N, D1)
-    B : tensor
-        2-D tensor of shape (R * D1, D2)
-    idx_a : Tensor, optional
-        If specified, must be a 1-D integer tensor of shape (K,)
-    idx_b : Tensor, optional
-        If specified, must be a 1-D integer tensor of shape (N,)
-
-    Returns
-    -------
-    Tensor
-        The output dense matrix of shape (N, D2)
-    """
-    # TODO(Israt): Add CPU support. Currently, only handles GPU code
+def _gather_mm(A, B, out, idx_a=None, idx_b=None):
+    r"""Invoke the C API of the gather_mm operator."""
     _CAPI_DGLKernelGATHERMM(to_dgl_nd(A),
                             to_dgl_nd(B),
                             to_dgl_nd_for_write(out),
                             to_dgl_nd(idx_a),
-                            to_dgl_nd(idx_b),
-                            num_rel)
+                            to_dgl_nd(idx_b))
     return out
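The new backward kernel for B follows directly from the per-segment product: for segment r, dB[r] = A_r^T @ dC_r. A hedged PyTorch sketch of the gradient the C kernel computes, assuming dB has shape (R, D1, D2):

    import torch

    def segment_mm_backward_b_reference(a, dc, seglen):
        # dB[r] = A_r^T @ dC_r for every segment r.
        a_segs = torch.split(a, seglen.tolist())
        dc_segs = torch.split(dc, seglen.tolist())
        return torch.stack([s.t() @ d for s, d in zip(a_segs, dc_segs)])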
-def _gather_mm_scatter(A, B, out, num_rel, idx_a=None, idx_b=None, idx_c=None,
-                       a_trans=False, b_trans=False):
-    r"""Generalized Dense Matrix Multiplication interface. It multiplies
-    tensor A and B according to relation types and outputs in out. B is a
-    concatenated tensor across relation types. A is unsorted and the
-    relation type is fetched from param etypes.
-
-    Parameters
-    ----------
-    A : tensor
-        2-D tensor of shape (N, D1)
-    B : tensor
-        2-D tensor of shape (R * D1, D2)
-    idx_a : Tensor, optional
-        If specified, must be a 1-D integer tensor of shape (K,)
-    idx_b : Tensor, optional
-        If specified, must be a 1-D integer tensor of shape (N,)
-    idx_c : Tensor, optional
-        If specified, must be a 1-D integer tensor of shape (N,)
-    a_trans : bool
-        Indicates whether matrix A needs to be transposed
-    b_trans : bool
-        Indicates whether matrix B needs to be transposed
-
-    Returns
-    -------
-    Tensor
-        The output dense matrix of shape (N, D2)
-    """
-    # TODO(Israt): Add CPU support. Currently, only handles GPU code
-    _CAPI_DGLKernelGATHERMMSCATTER(to_dgl_nd(A),
-                                   to_dgl_nd(B),
-                                   to_dgl_nd_for_write(out),
-                                   to_dgl_nd(idx_a),
-                                   to_dgl_nd(idx_b),
-                                   to_dgl_nd(idx_c),
-                                   num_rel, a_trans, b_trans)
-    return out
+def _gather_mm_scatter(A, B, out, idx_a=None, idx_b=None, idx_c=None):
+    r"""Invoke the C API of the gather_mm_scatter operator."""
+    _CAPI_DGLKernelGATHERMMSCATTER(
+        to_dgl_nd(A),
+        to_dgl_nd(B),
+        to_dgl_nd_for_write(out),
+        to_dgl_nd(idx_a),
+        to_dgl_nd(idx_b),
+        to_dgl_nd(idx_c))
+    return out
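gather_mm_scatter generalizes gather_mm with an optional gather index on A and a scatter index on the output. A hedged reference of what the semantics appear to be (gather_mm_scatter_reference is a hypothetical name; accumulation on duplicate idx_c entries is an assumption, consistent with how a backward pass would aggregate per-relation gradients):

    import torch

    def gather_mm_scatter_reference(a, b, out, idx_a, idx_b, idx_c):
        # out[idx_c[i]] += a[idx_a[i]] @ b[idx_b[i]]; duplicates in idx_c accumulate.
        prod = torch.bmm(a[idx_a].unsqueeze(1), b[idx_b]).squeeze(1)
        out.index_add_(0, idx_c, prod)
        return out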
@@ -22,7 +22,7 @@ import scipy.sparse as sparse
 import scipy.sparse.linalg

 from .._ffi.function import _init_api
-from ..base import dgl_warning, DGLError
+from ..base import dgl_warning, DGLError, NID, EID
 from .. import convert
 from ..heterograph import DGLHeteroGraph, DGLBlock
 from ..heterograph_index import create_metagraph_index, create_heterograph_from_relations
@@ -2973,7 +2973,7 @@ def sort_csc_by_tag(g, tag, tag_offset_name='_TAG_OFFSET'):
     return new_g

-def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
+def reorder_graph(g, node_permute_algo=None, edge_permute_algo='src',
                   store_ids=True, permute_config=None):
     r"""Return a new graph with nodes and edges re-ordered/re-labeled
     according to the specified permute algorithm.
@@ -2994,7 +2994,7 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
     g : DGLGraph
         The homogeneous graph.
     node_permute_algo: str, optional
-        The permutation algorithm to re-order nodes. Options are ``rcmk`` or
+        The permutation algorithm to re-order nodes. If given, the options are ``rcmk`` or
         ``metis`` or ``custom``. ``rcmk`` is the default value.

         * ``rcmk``: Use the `Reverse Cuthill–McKee <https://docs.scipy.org/doc/scipy/reference/
@@ -3014,6 +3014,8 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
         * ``src``: Edges are arranged according to their source nodes.
         * ``dst``: Edges are arranged according to their destination nodes.
+        * ``custom``: Edges are arranged according to the user-provided edge permutation
+          array (provided in :attr:`permute_config`).

     store_ids: bool, optional
         If True, DGL will store the original node and edge IDs in the ndata and edata
         of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.
@@ -3023,9 +3025,12 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
         * For ``rcmk``, this argument is not required.
         * For ``metis``, users should specify the number of partitions ``k`` (e.g.,
          ``permute_config={'k':10}`` to partition the graph to 10 clusters).
-        * For ``custom``, users should provide a node permutation array ``nodes_perm``.
-          The array must be an integer list or a tensor on the same device as the
-          input graph.
+        * For ``custom`` node reordering, users should provide a node permutation
+          array ``nodes_perm``. The array must be an integer list or a tensor on
+          the same device as the input graph.
+        * For ``custom`` edge reordering, users should provide an edge permutation
+          array ``edges_perm``. The array must be an integer list or a tensor on
+          the same device as the input graph.

     Returns
     -------
@@ -3118,49 +3123,83 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
             [2],
             [1]]), '_ID': tensor([0, 2, 4, 1, 3])}

+    Reorder according to node and edge types:
+
+    >>> ntype = ...  # some node type array
+    >>> etype = ...  # some edge type array
+    >>> sorted_ntype, idx_nt = torch.sort(ntype)
+    >>> sorted_etype, idx_et = torch.sort(etype)
+    >>> rg = dgl.reorder_graph(g, node_permute_algo='custom', edge_permute_algo='custom',
+    ...                        permute_config={'nodes_perm' : idx_nt.to(g.idtype),
+    ...                                        'edges_perm' : idx_et.to(g.idtype)})
     """
     # sanity checks
     if not g.is_homogeneous:
-        raise DGLError("Homograph is supported only.")
+        raise DGLError("Only homogeneous graphs are supported.")
     expected_node_algo = ['rcmk', 'metis', 'custom']
-    if node_permute_algo not in expected_node_algo:
+    if node_permute_algo is not None and node_permute_algo not in expected_node_algo:
         raise DGLError("Unexpected node_permute_algo is specified: {}. Expected algos: {}".format(
             node_permute_algo, expected_node_algo))
-    expected_edge_algo = ['src', 'dst']
+    expected_edge_algo = ['src', 'dst', 'custom']
     if edge_permute_algo not in expected_edge_algo:
         raise DGLError("Unexpected edge_permute_algo is specified: {}. Expected algos: {}".format(
             edge_permute_algo, expected_edge_algo))

-    # generate nodes permutation
+    g.edata['__orig__'] = F.arange(0, g.num_edges(), g.idtype, g.device)
+
+    # reorder nodes
     if node_permute_algo == 'rcmk':
         nodes_perm = rcmk_perm(g)
+        rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
     elif node_permute_algo == 'metis':
         if permute_config is None or 'k' not in permute_config:
             raise DGLError(
                 "Partition parts 'k' is required for metis. Please specify in permute_config.")
         nodes_perm = metis_perm(g, permute_config['k'])
-    else:
+        rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
+    elif node_permute_algo == 'custom':
         if permute_config is None or 'nodes_perm' not in permute_config:
             raise DGLError(
-                "permute_algo is specified as custom, but no 'nodes_perm' is specified in \
+                "node_permute_algo is specified as custom, but no 'nodes_perm' is specified in \
permute_config.")
         nodes_perm = permute_config['nodes_perm']
         if len(nodes_perm) != g.num_nodes():
-            raise DGLError("Length of passed in nodes_perm[{}] does not \
-match graph num_nodes[{}].".format(len(nodes_perm), g.num_nodes()))
+            raise DGLError("Length of 'nodes_perm' ({}) does not \
+match graph num_nodes ({}).".format(len(nodes_perm), g.num_nodes()))
+        rg = subgraph.node_subgraph(g, nodes_perm, store_ids=False)
+    else:
+        nodes_perm = F.arange(0, g.num_nodes(), g.idtype, g.device)
+        rg = g.clone()

-    # reorder nodes
-    rg = subgraph.node_subgraph(g, nodes_perm, store_ids=store_ids)
+    if store_ids:
+        rg.ndata[NID] = F.copy_to(F.tensor(nodes_perm, g.idtype), g.device)
+    g.edata.pop('__orig__')

     # reorder edges
     if edge_permute_algo == 'src':
-        # the output graph of dgl.node_subgraph() is ordered/labeled
-        # according to src already. Nothing needs to be done.
-        pass
+        edges_perm = np.argsort(F.asnumpy(rg.edges()[0]))
+        rg = subgraph.edge_subgraph(
+            rg, edges_perm, relabel_nodes=False, store_ids=False)
     elif edge_permute_algo == 'dst':
         edges_perm = np.argsort(F.asnumpy(rg.edges()[1]))
         rg = subgraph.edge_subgraph(
-            rg, edges_perm, relabel_nodes=False, store_ids=store_ids)
+            rg, edges_perm, relabel_nodes=False, store_ids=False)
+    elif edge_permute_algo == 'custom':
+        if permute_config is None or 'edges_perm' not in permute_config:
+            raise DGLError(
+                "edge_permute_algo is specified as custom, but no 'edges_perm' is specified in \
+permute_config.")
+        edges_perm = permute_config['edges_perm']
+        # First revert the edge reorder caused by the node reorder, then
+        # apply the user-provided edge permutation.
+        rev_id = F.argsort(rg.edata['__orig__'], 0, False)
+        edges_perm = F.astype(F.gather_row(rev_id, edges_perm), rg.idtype)
+        rg = subgraph.edge_subgraph(
+            rg, edges_perm, relabel_nodes=False, store_ids=False)
+
+    if store_ids:
+        rg.edata[EID] = rg.edata.pop('__orig__')

     return rg
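The '__orig__' bookkeeping matters because node reordering already permutes the edges: the user's edges_perm refers to original edge IDs, so it must be composed with the inverse of that implicit permutation. A small worked example of the index composition:

    import torch

    orig = torch.tensor([2, 0, 3, 1])        # edata['__orig__']: current position -> original ID
    rev_id = torch.argsort(orig)             # original ID -> current position
    edges_perm = torch.tensor([0, 1, 2, 3])  # user's permutation over original IDs
    current_perm = rev_id[edges_perm]        # tensor([1, 3, 0, 2]): positions in the current graph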
@@ -23,108 +23,114 @@ namespace aten {
 } while (0)

-/*! \brief Generalized segmentMM. */
+/*! \brief Generalized SegmentMM. */
 template <int XPU, typename IdType, int bits>
-void segmentMM(const NDArray A,
-               const NDArray B,
-               NDArray C,
-               const NDArray seglen_A,
-               bool a_trans, bool b_trans) {
-  SWITCH_BITS(bits, DType, {
-    LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
-  });
-}
+void SegmentMM(const NDArray A,
+               const NDArray B,
+               NDArray C,
+               const NDArray seglen_A,
+               bool a_trans, bool b_trans) {
+  LOG(FATAL) << "Unsupported CPU kernel for SegmentMM.";
+}
+
+template <int XPU, typename IdType, int bits>
+void SegmentMMBackwardB(const NDArray A,
+                        const NDArray dC,
+                        NDArray dB,
+                        const NDArray seglen) {
+  LOG(FATAL) << "Unsupported CPU kernel for SegmentMMBackwardB.";
+}

 /*! \brief Generalized GatherMM. */
 template <int XPU, typename IdType, int bits>
-void gatherMM(const NDArray A,
-              const NDArray B,
-              NDArray C,
-              const NDArray idx_a,
-              const NDArray idx_b,
-              const int num_rel) {
-  SWITCH_BITS(bits, DType, {
-    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
-  });
-}
+void GatherMM(const NDArray A,
+              const NDArray B,
+              NDArray C,
+              const NDArray idx_a,
+              const NDArray idx_b) {
+  LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
+}

 /*! \brief Generalized GatherMM_scatter. */
 template <int XPU, typename IdType, int bits>
-void gatherMM_scatter(const NDArray A,
-                      const NDArray B,
-                      NDArray C,
-                      const NDArray idx_a,
-                      const NDArray idx_b,
-                      const NDArray idx_c,
-                      const int num_rel,
-                      bool a_trans, bool b_trans) {
-  SWITCH_BITS(bits, DType, {
-    LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
-  });
-}
+void GatherMMScatter(const NDArray A,
+                     const NDArray B,
+                     NDArray C,
+                     const NDArray idx_a,
+                     const NDArray idx_b,
+                     const NDArray idx_c) {
+  LOG(FATAL) << "Unsupported CPU kernel for GatherMMScatter.";
+}

-template void gatherMM<kDLCPU, int32_t, 16>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const int num_rel);
-template void gatherMM<kDLCPU, int64_t, 16>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const int num_rel);
-template void gatherMM<kDLCPU, int32_t, 32>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const int num_rel);
-template void gatherMM<kDLCPU, int64_t, 32>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const int num_rel);
-template void gatherMM<kDLCPU, int32_t, 64>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const int num_rel);
-template void gatherMM<kDLCPU, int64_t, 64>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const int num_rel);
+template void GatherMM<kDLCPU, int32_t, 16>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b);
+template void GatherMM<kDLCPU, int64_t, 16>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b);
+template void GatherMM<kDLCPU, int32_t, 32>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b);
+template void GatherMM<kDLCPU, int64_t, 32>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b);
+template void GatherMM<kDLCPU, int32_t, 64>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b);
+template void GatherMM<kDLCPU, int64_t, 64>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b);

-template void gatherMM_scatter<kDLCPU, int32_t, 16>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
-    const int num_rel, bool a_trans, bool b_trans);
-template void gatherMM_scatter<kDLCPU, int64_t, 16>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
-    const int num_rel, bool a_trans, bool b_trans);
-template void gatherMM_scatter<kDLCPU, int32_t, 32>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
-    const int num_rel, bool a_trans, bool b_trans);
-template void gatherMM_scatter<kDLCPU, int64_t, 32>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
-    const int num_rel, bool a_trans, bool b_trans);
-template void gatherMM_scatter<kDLCPU, int32_t, 64>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
-    const int num_rel, bool a_trans, bool b_trans);
-template void gatherMM_scatter<kDLCPU, int64_t, 64>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c,
-    const int num_rel, bool a_trans, bool b_trans);
+template void GatherMMScatter<kDLCPU, int32_t, 16>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
+template void GatherMMScatter<kDLCPU, int64_t, 16>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
+template void GatherMMScatter<kDLCPU, int32_t, 32>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
+template void GatherMMScatter<kDLCPU, int64_t, 32>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
+template void GatherMMScatter<kDLCPU, int32_t, 64>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
+template void GatherMMScatter<kDLCPU, int64_t, 64>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);

-template void segmentMM<kDLCPU, int32_t, 16>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray seglen_A, bool a_trans, bool b_trans);
-template void segmentMM<kDLCPU, int64_t, 16>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray seglen_A, bool a_trans, bool b_trans);
-template void segmentMM<kDLCPU, int32_t, 32>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray seglen_A, bool a_trans, bool b_trans);
-template void segmentMM<kDLCPU, int64_t, 32>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray seglen_A, bool a_trans, bool b_trans);
-template void segmentMM<kDLCPU, int32_t, 64>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray seglen_A, bool a_trans, bool b_trans);
-template void segmentMM<kDLCPU, int64_t, 64>(
-    const NDArray A, const NDArray B, NDArray C,
-    const NDArray seglen_A, bool a_trans, bool b_trans);
+template void SegmentMM<kDLCPU, int32_t, 16>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray seglen_A, bool a_trans, bool b_trans);
+template void SegmentMM<kDLCPU, int64_t, 16>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray seglen_A, bool a_trans, bool b_trans);
+template void SegmentMM<kDLCPU, int32_t, 32>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray seglen_A, bool a_trans, bool b_trans);
+template void SegmentMM<kDLCPU, int64_t, 32>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray seglen_A, bool a_trans, bool b_trans);
+template void SegmentMM<kDLCPU, int32_t, 64>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray seglen_A, bool a_trans, bool b_trans);
+template void SegmentMM<kDLCPU, int64_t, 64>(
+    const NDArray A, const NDArray B, NDArray C,
+    const NDArray seglen_A, bool a_trans, bool b_trans);
+
+template void SegmentMMBackwardB<kDLCPU, int32_t, 16>(
+    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
+template void SegmentMMBackwardB<kDLCPU, int64_t, 16>(
+    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
+template void SegmentMMBackwardB<kDLCPU, int32_t, 32>(
+    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
+template void SegmentMMBackwardB<kDLCPU, int64_t, 32>(
+    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
+template void SegmentMMBackwardB<kDLCPU, int32_t, 64>(
+    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
+template void SegmentMMBackwardB<kDLCPU, int64_t, 64>(
+    const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);

 }  // namespace aten
 }  // namespace dgl