Unverified commit 00add9f2 authored by Minjie Wang, committed by GitHub

Merge pull request #90 from jermainewang/cpp

[GraphIndex] Graph index and many related changes
parents ec4216dd dce1f44d
@@ -4,43 +4,9 @@ Graph Convolutional Networks (GCN)
Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907)
Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn)
The folder contains three different implementations using DGL.

Naive GCN (gcn.py)
-------
The model is defined at the finest granularity (i.e., on *one* edge and *one* node).
* The message function `gcn_msg` computes the message for one edge. It simply returns the `h` representation of the source node.
```python
def gcn_msg(src, edge):
# src['h'] is a tensor of shape (D,). D is the feature length.
return src['h']
```
* The reduce function `gcn_reduce` accumulates the incoming messages for one node. The `msgs` argument is a list of all the messages. In GCN, the incoming messages are summed up.
```python
def gcn_reduce(node, msgs):
# msgs is a list of in-coming messages.
return sum(msgs)
```
* The update function `NodeUpdateModule` computes the new node representation `h` by applying a non-linear transformation to the reduced messages.
```python
class NodeUpdateModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node, accum):
# accum is a tensor of shape (D,).
h = self.linear(accum)
if self.activation:
h = self.activation(h)
return {'h' : h}
```
After defining the functions on each node/edge, the message passing is triggered by calling `update_all` on the DGLGraph object (in GCN module).
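Putting the pieces together, a single propagation step looks roughly like this (a minimal sketch assuming the networkx-style node views used elsewhere in this example; the random graph and feature size are illustrative):
```python
import networkx as nx
import torch
from dgl import DGLGraph

g = DGLGraph(nx.erdos_renyi_graph(100, 0.05))
layer = NodeUpdateModule(in_feats=16, out_feats=16)
for n in g.nodes():                       # give every node an initial 'h'
    g.nodes[n]['h'] = torch.randn(16)
g.update_all(gcn_msg, gcn_reduce, layer)  # one round of message passing
```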
Batched GCN (gcn_batch.py)
-----------
Defining the model on only one node and edge makes it hard to fully utilize GPUs. As a result, we allow users to define the model on a *batch of* nodes and edges, as sketched below.
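For reference, here is a sketch of what the batched counterparts of `gcn_msg` and `gcn_reduce` look like (mirroring the updated `gcn.py` in this commit; shapes are illustrative):
```python
def gcn_msg(src, edge):
    # src['h'] is now batched with shape (B, D), one row per edge.
    return {'m' : src['h']}

def gcn_reduce(node, msgs):
    # msgs['m'] has shape (B, deg, D); sum over the message axis.
    return {'h' : torch.sum(msgs['m'], 1)}
```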
......
@@ -2,6 +2,8 @@
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with batch processing
"""
import argparse
import numpy as np
@@ -9,14 +11,15 @@ import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gcn_msg(src, edge):
    return {'m' : src['h']}

def gcn_reduce(node, msgs):
    return {'h' : torch.sum(msgs['m'], 1)}

class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation=None):
@@ -32,7 +35,7 @@ class NodeApplyModule(nn.Module):
class GCN(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
@@ -40,7 +43,7 @@ class GCN(nn.Module):
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.g = g
        self.dropout = dropout
        # input layer
        self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
@@ -50,31 +53,24 @@ class GCN(nn.Module):
        # output layer
        self.layers.append(NodeApplyModule(n_hidden, n_classes))
    def forward(self, features):
        self.g.set_n_repr({'h' : features})
        for layer in self.layers:
            # apply dropout
            if self.dropout:
                self.g.apply_nodes(apply_node_func=
                    lambda node: {'h' : F.dropout(node['h'], p=self.dropout)})
            self.g.update_all(gcn_msg, gcn_reduce, layer)
        return self.g.pop_n_repr('h')

def main(args):
    # load and preprocess dataset
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.ByteTensor(data.train_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
@@ -83,11 +79,13 @@ def main(args):
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        mask = mask.cuda()

    # create GCN model
    g = DGLGraph(data.graph)
    model = GCN(g,
                in_feats,
                args.n_hidden,
                n_classes,
@@ -107,9 +105,9 @@ def main(args):
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp[mask], labels[mask])

        optimizer.zero_grad()
        loss.backward()
@@ -130,7 +128,7 @@ if __name__ == '__main__':
                        help="gpu")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=20,
                        help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
                        help="number of hidden gcn units")
......
@@ -23,10 +23,10 @@ class NodeApplyModule(nn.Module):
        self.activation = activation

    def forward(self, node):
        h = self.linear(node['h'])
        if self.activation:
            h = self.activation(h)
        return {'h' : h}

class GCN(nn.Module):
    def __init__(self,
@@ -49,14 +49,16 @@ class GCN(nn.Module):
        self.layers.append(NodeApplyModule(n_hidden, n_classes))

    def forward(self, features):
        self.g.set_n_repr({'h' : features})
        for layer in self.layers:
            # apply dropout
            if self.dropout:
                self.g.apply_nodes(apply_node_func=
                    lambda node: {'h' : F.dropout(node['h'], p=self.dropout)})
            self.g.update_all(fn.copy_src(src='h', out='m'),
                              fn.sum(msg='m', out='h'),
                              layer)
        return self.g.pop_n_repr('h')

def main(args):
    # load and preprocess dataset
......
Community Detection with Graph Neural Networks (CDGNN)
============
Paper link: [https://arxiv.org/abs/1705.08415](https://arxiv.org/abs/1705.08415)
Author's code repo: [https://github.com/joanbruna/GNN_community](https://github.com/joanbruna/GNN_community)
This folder contains a DGL implementation of the CDGNN model.
An experiment on the Stochastic Block Model in default settings can be run with
```bash
python train.py
```
An experiment on the Stochastic Block Model in customized settings can be run with
```bash
python train.py --batch-size BATCH_SIZE --gpu GPU --n-communities N_COMMUNITIES --n-features N_FEATURES --n-graphs N_GRAPH --n-iterations N_ITERATIONS --n-layers N_LAYER --n-nodes N_NODE --model-path MODEL_PATH --radius RADIUS
```
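Under the hood, `train.py` draws graph/line-graph pairs from `dgl.data.SBMMixture`; a minimal sketch of that loading step (the numeric values are illustrative, mirroring the flags above):
```python
from torch.utils.data import DataLoader
from dgl.data import SBMMixture

dataset = SBMMixture(16, 1000, 2)   # n_graphs, n_nodes, n_communities
loader = DataLoader(dataset, batch_size=1, shuffle=True,
                    collate_fn=dataset.collate_fn, drop_last=True)
g, lg, deg_g, deg_lg, eid2nid = next(iter(loader))
```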
@@ -3,237 +3,93 @@ Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415

Deviations from paper:
- Pm Pd
"""
import copy
import itertools

import dgl
import dgl.function as fn
import networkx as nx
import torch as th
import torch.nn as nn
import torch.nn.functional as F

class GNNModule(nn.Module):
    def __init__(self, in_feats, out_feats, radius):
        super().__init__()
        self.out_feats = out_feats
        self.radius = radius

        new_linear = lambda: nn.Linear(in_feats, out_feats * 2)
        new_module_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])

        self.theta_x, self.theta_deg, self.theta_y = \
            new_linear(), new_linear(), new_linear()
        self.theta_list = new_module_list()
        self.gamma_y, self.gamma_deg, self.gamma_x = \
            new_linear(), new_linear(), new_linear()
        self.gamma_list = new_module_list()

        self.bn_x = nn.BatchNorm1d(out_feats)
        self.bn_y = nn.BatchNorm1d(out_feats)
    def aggregate(self, g, z):
        z_list = []
        g.set_n_repr(z)
        g.update_all(fn.copy_src(), fn.sum())
        z_list.append(g.get_n_repr())
        for i in range(self.radius - 1):
            for j in range(2 ** i):
                g.update_all(fn.copy_src(), fn.sum())
            z_list.append(g.get_n_repr())
        return z_list
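    # Note (added): each update_all above multiplies the node features by the
    # adjacency matrix once, so z_list holds [Az, A^2 z, A^4 z, ...,
    # A^(2^(radius-1)) z]; entry j aggregates a 2^j-hop neighborhood.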
    def forward(self, g, lg, x, y, deg_g, deg_lg, eid2nid):
        xy = F.embedding(eid2nid, x)

        x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]

        g.set_e_repr(y)
        g.update_all(fn.copy_edge(), fn.sum())
        yx = g.get_n_repr()

        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
        x = self.bn_x(x[:, :self.out_feats] + F.relu(x[:, self.out_feats:]))

        y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]

        lg.set_n_repr(xy)
        lg.update_all(fn.copy_src(), fn.sum())
        xy = lg.get_n_repr()

        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
        y = self.bn_y(y[:, :self.out_feats] + F.relu(y[:, self.out_feats:]))

        return x, y
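    # Note (added): x lives on the nodes of g and y on the nodes of lg (the
    # edges of g); eid2nid appears to map each line-graph node back to an
    # endpoint in g so the two streams can exchange features, in the spirit
    # of the paper's Pm/Pd operators.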

class GNN(nn.Module):
    def __init__(self, feats, radius, n_classes):
        """
        Parameters
        ----------
        feats : list of int
            Feature size of each layer.
        radius : int
        n_classes : int
        """
        super(GNN, self).__init__()
        self.linear = nn.Linear(feats[-1], n_classes)
        self.module_list = nn.ModuleList([GNNModule(m, n, radius)
                                          for m, n in zip(feats[:-1], feats[1:])])

    def forward(self, g, lg, deg_g, deg_lg, eid2nid):
        def normalize(x):
            x = x - th.mean(x, 0)
            x = x / th.sqrt(th.mean(x * x, 0))
            return x

        x = normalize(deg_g)
        y = normalize(deg_lg)
        for module in self.module_list:
            x, y = module(g, lg, x, y, deg_g, deg_lg, eid2nid)
        return self.linear(x)
"""
By Minjie
"""
from __future__ import division
import math
import numpy as np
import scipy.sparse as sp
import networkx as nx
import matplotlib.pyplot as plt
class SSBM:
def __init__(self, n, k, a=10.0, b=2.0, regime='constant', rng=None):
"""Symmetric Stochastic Block Model.
n - number of nodes
k - number of communities
a - probability scale for intra-community edge
b - probability scale for inter-community edge
regime - If "logarithm", this generates SSBM(n, k, a*log(n)/n, b*log(n)/n)
If "constant", this generates SSBM(n, k, a/n, b/n)
If "mixed", this generates SSBM(n, k, a*log(n)/n, b/n)
"""
self.n = n
self.k = k
if regime == 'logarithm':
if math.sqrt(a) - math.sqrt(b) >= math.sqrt(k):
print('SSBM model with possible exact recovery.')
else:
print('SSBM model with impossible exact recovery.')
self.a = a * math.log(n) / n
self.b = b * math.log(n) / n
elif regime == 'constant':
snr = (a - b) ** 2 / (k * (a + (k - 1) * b))
if snr > 1:
print('SSBM model with possible detection.')
else:
print('SSBM model that may not have detection (snr=%.5f).' % snr)
self.a = a / n
self.b = b / n
elif regime == 'mixed':
self.a = a * math.log(n) / n
self.b = b / n
else:
raise ValueError('Unknown regime: %s' % regime)
if rng is None:
self.rng = np.random.RandomState()
else:
self.rng = rng
self._graph = None
def generate(self):
self.generate_communities()
print('Finished generating communities.')
self.generate_edges()
print('Finished generating edges.')
def generate_communities(self):
nodes = list(range(self.n))
size = self.n // self.k
self.block_size = size
self.comm2node = [nodes[i*size:(i+1)*size] for i in range(self.k)]
self.node2comm = [nid // size for nid in range(self.n)]
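# Note (added): this layout assumes n is divisible by k; any remainder
# nodes would fall outside the k equal-sized blocks.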
def generate_edges(self):
# TODO: dedup edges
us = []
vs = []
# generate intra-comm edges
for i in range(self.k):
sp_mat = sp.random(self.block_size, self.block_size,
density=self.a,
random_state=self.rng,
data_rvs=lambda l: np.ones(l))
u = sp_mat.row + i * self.block_size
v = sp_mat.col + i * self.block_size
us.append(u)
vs.append(v)
# generate inter-comm edges
for i in range(self.k):
for j in range(self.k):
if i == j:
continue
sp_mat = sp.random(self.block_size, self.block_size,
density=self.b,
random_state=self.rng,
data_rvs=lambda l: np.ones(l))
u = sp_mat.row + i * self.block_size
v = sp_mat.col + j * self.block_size
us.append(u)
vs.append(v)
us = np.hstack(us)
vs = np.hstack(vs)
self.sp_mat = sp.coo_matrix((np.ones(us.shape[0]), (us, vs)), shape=(self.n, self.n))
@property
def graph(self):
if self._graph is None:
self._graph = nx.from_scipy_sparse_matrix(self.sp_mat, create_using=nx.DiGraph())
return self._graph
def plot(self):
x = self.sp_mat.row
y = self.sp_mat.col
plt.scatter(x, y, s=0.5, marker='.', c='k')
plt.savefig('ssbm-%d-%d.pdf' % (self.n, self.k))
plt.clf()
# plot out degree distribution
out_degree = [d for _, d in self.graph.out_degree().items()]
plt.hist(out_degree, 100, normed=True)
plt.savefig('ssbm-%d-%d_out_degree.pdf' % (self.n, self.k))
plt.clf()
if __name__ == '__main__':
n = 1000
k = 10
ssbm = SSBM(n, k, regime='mixed', a=4, b=1)
ssbm.generate()
g = ssbm.graph
print('#nodes:', g.number_of_nodes())
print('#edges:', g.number_of_edges())
#ssbm.plot()
#lg = nx.line_graph(g)
# plot degree distribution
#degree = [d for _, d in lg.degree().items()]
#plt.hist(degree, 100, normed=True)
#plt.savefig('lg<ssbm-%d-%d>_degree.pdf' % (n, k))
#plt.clf()
"""
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --order 3 --radius 3
"""
import argparse
import networkx as nx
import torch as th
import torch.nn as nn
import torch.optim as optim
import gnn
parser = argparse.ArgumentParser()
parser.add_argument('--features', nargs='+', type=int)
parser.add_argument('--gpu', type=int)
parser.add_argument('--n-classes', type=int)
parser.add_argument('--n-iterations', type=int)
parser.add_argument('--n-nodes', type=int)
parser.add_argument('--order', type=int)
parser.add_argument('--radius', type=int)
args = parser.parse_args()
if args.gpu < 0:
cuda = False
else:
cuda = True
th.cuda.set_device(args.gpu)
g = nx.barabasi_albert_graph(args.n_nodes, 1).to_directed() # TODO SBM
y = th.multinomial(th.ones(args.n_classes), args.n_nodes, replacement=True)
network = gnn.GNN(args.features, args.order, args.radius, args.n_classes)
if cuda:
network.cuda()
ce = nn.CrossEntropyLoss()
adam = optim.Adam(network.parameters())
for i in range(args.n_iterations):
y_bar = network(g)
loss = ce(y_bar, y)
adam.zero_grad()
loss.backward()
adam.step()
print('[iteration %d] loss %f' % (i, loss.item()))
from __future__ import division
import argparse
from itertools import permutations
import networkx as nx
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import dgl
from dgl.data import SBMMixture
import gnn
import utils
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int,
help='Batch size', default=1)
parser.add_argument('--gpu', type=int,
help='GPU', default=-1)
parser.add_argument('--n-communities', type=int,
help='Number of communities', default=2)
parser.add_argument('--n-features', type=int,
help='Number of features per layer', default=2)
parser.add_argument('--n-graphs', type=int,
help='Number of graphs', default=6000)
parser.add_argument('--n-iterations', type=int,
help='Number of iterations', default=10000)
parser.add_argument('--n-layers', type=int,
help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int,
help='Number of nodes', default=1000)
parser.add_argument('--model-path', type=str,
help='Path to the checkpoint of model', default='model')
parser.add_argument('--radius', type=int,
help='Radius', default=3)
args = parser.parse_args()
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
dataset = SBMMixture(args.n_graphs, args.n_nodes, args.n_communities)
loader = utils.cycle(DataLoader(dataset, args.batch_size,
shuffle=True, collate_fn=dataset.collate_fn, drop_last=True))
ones = th.ones(args.n_nodes // args.n_communities)
y_list = [th.cat([th.cat([x * ones for x in p])] * args.batch_size).long().to(dev)
for p in permutations(range(args.n_communities))]
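# Note (added): community labels are only defined up to permutation, so
# y_list enumerates every relabeling; the loss below takes the minimum
# cross entropy over all of them (a permutation-invariant loss).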
feats = [1] + [args.n_features] * args.n_layers + [args.n_communities]
model = gnn.GNN(feats, args.radius, args.n_communities).to(dev)
opt = optim.Adamax(model.parameters(), lr=0.04)
for i in range(args.n_iterations):
g, lg, deg_g, deg_lg, eid2nid = next(loader)
deg_g = deg_g.to(dev)
deg_lg = deg_lg.to(dev)
eid2nid = eid2nid.to(dev)
y_bar = model(g, lg, deg_g, deg_lg, eid2nid)
loss = min(F.cross_entropy(y_bar, y) for y in y_list)
opt.zero_grad()
loss.backward()
opt.step()
placeholder = '0' * (len(str(args.n_iterations)) - len(str(i)))
print('[iteration %s%d] loss %f' % (placeholder, i, loss.item()))
th.save(model.state_dict(), args.model_path)
def cycle(loader):
while True:
for x in loader:
yield x
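# Example (added, illustrative): cycle() turns a finite DataLoader into an
# endless iterator, so a training loop can simply call next(loader):
#   loader = cycle(DataLoader(dataset, batch_size))
#   batch = next(loader)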
@@ -8,23 +8,17 @@ from torch.utils.data import DataLoader
import dgl
import dgl.data as data
import dgl.ndarray as nd
from tree_lstm import TreeLSTM

def tensor_topo_traverse(g, cuda, args):
    n = g.number_of_nodes()
    if cuda:
        adjmat = g._graph.adjacency_matrix().get(nd.gpu(args.gpu))
        mask = th.ones((n, 1)).cuda()
    else:
        adjmat = g._graph.adjacency_matrix().get(nd.cpu())
        mask = th.ones((n, 1))
    degree = th.spmm(adjmat, mask)
    while th.sum(mask) != 0.:
@@ -39,10 +33,17 @@ def main(args):
    cuda = args.gpu >= 0
    if cuda:
        th.cuda.set_device(args.gpu)
    def _batcher(trees):
        bg = dgl.batch(trees)
        if cuda:
            reprs = bg.get_n_repr()
            reprs = {key : val.cuda() for key, val in reprs.items()}
            bg.set_n_repr(reprs)
        return bg
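    # Note (added): dgl.batch merges the trees into one disjoint graph so a
    # single forward pass covers the whole batch; node features are moved to
    # the GPU here because DataLoader workers yield CPU tensors.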
    trainset = data.SST()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=_batcher,
                              shuffle=False,
                              num_workers=0)
    #testset = data.SST(mode='test')
@@ -69,18 +70,15 @@ def main(args):
    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, graph in enumerate(train_loader):
            if step >= 3:
                t0 = time.time()
            label = graph.pop_n_repr('y')
            # traverse graph
            giter = list(tensor_topo_traverse(graph, False, args))
            logits = model(graph, zero_initializer, iterator=giter, train=True)
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
@@ -89,11 +87,11 @@ def main(args):
            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(label, pred))
                mean_dur = np.mean(dur)
                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                      "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
                          epoch, step, loss.item(), acc.item() / len(label),
                          mean_dur, args.batch_size / mean_dur))
        print("Epoch time(s):", time.time() - t_epoch)
......
@@ -10,23 +10,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl

class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
@@ -39,7 +23,7 @@ class ChildSumTreeLSTMCell(nn.Module):
        self.ut = 0.

    def message_func(self, src, edge):
        return {'h' : src['h'], 'c' : src['c']}
    def reduce_func(self, node, msgs):
        # equation (2)
@@ -83,13 +67,13 @@ class TreeLSTM(nn.Module):
        else:
            raise RuntimeError('Unknown cell type:', cell_type)

    def forward(self, graph, zero_initializer, h=None, c=None, iterator=None, train=True):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        graph : dgl.DGLGraph
            The batched trees.
        zero_initializer : callable
            Function to return zero value tensor.
        h : Tensor, optional
@@ -104,15 +88,17 @@ class TreeLSTM(nn.Module):
        logits : Tensor
            The prediction of each node.
        """
        g = graph
        n = g.number_of_nodes()
        g.register_message_func(self.cell.message_func)
        g.register_reduce_func(self.cell.reduce_func)
        g.register_apply_node_func(self.cell.apply_func)
        # feed embedding
        wordid = g.pop_n_repr('x')
        mask = (wordid != dgl.data.SST.PAD_WORD)
        wordid = wordid * mask.long()
        embeds = self.embedding(wordid)
        x = embeds * th.unsqueeze(mask, 1).float()
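        # Note (added): non-leaf nodes carry the PAD_WORD id; the mask zeroes
        # their embeddings so only leaf nodes contribute word vectors, while
        # wordid * mask keeps every embedding lookup index in range.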
        if h is None:
            h = zero_initializer((n, self.h_size))
            h_tild = zero_initializer((n, self.h_size))
......
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph.h
* \brief DGL graph index class.
*/
#ifndef DGL_GRAPH_H_
#define DGL_GRAPH_H_
#include <vector>
#include <cstdint>
#include "runtime/ndarray.h"
namespace dgl {
typedef uint64_t dgl_id_t;
typedef tvm::runtime::NDArray IdArray;
typedef tvm::runtime::NDArray DegreeArray;
typedef tvm::runtime::NDArray BoolArray;
class Graph;
class GraphOp;
struct Subgraph;
/*!
* \brief Base dgl graph index class.
*
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not
* supported.
*
* Removal of vertices/edges is not allowed. Instead, the graph can only be "cleared"
* by removing all the vertices and edges.
*
 * When calling functions supporting multiple edges (e.g. AddEdges, HasEdges),
 * the input edges are represented by two id arrays for source and destination
 * vertex ids. In the general case, the two arrays should have the same length.
 * If the length of the src id array is one, it represents one-to-many connections.
 * If the length of the dst id array is one, it represents many-to-one connections.
*/
class Graph {
public:
  /*! \brief structure used to represent a list of edges */
typedef struct {
/* \brief the two endpoints and the id of the edge */
IdArray src, dst, id;
} EdgeArray;
/*! \brief default constructor */
explicit Graph(bool multigraph = false) : is_multigraph_(multigraph) {}
/*! \brief default copy constructor */
Graph(const Graph& other) = default;
#ifndef _MSC_VER
/*! \brief default move constructor */
Graph(Graph&& other) = default;
#else
Graph(Graph&& other) {
adjlist_ = other.adjlist_;
reverse_adjlist_ = other.reverse_adjlist_;
all_edges_src_ = other.all_edges_src_;
all_edges_dst_ = other.all_edges_dst_;
read_only_ = other.read_only_;
is_multigraph_ = other.is_multigraph_;
num_edges_ = other.num_edges_;
other.Clear();
}
#endif // _MSC_VER
  /*! \brief default assignment operator */
Graph& operator=(const Graph& other) = default;
/*! \brief default destructor */
~Graph() = default;
/*!
* \brief Add vertices to the graph.
* \note Since vertices are integers enumerated from zero, only the number of
* vertices to be added needs to be specified.
* \param num_vertices The number of vertices to be added.
*/
void AddVertices(uint64_t num_vertices);
/*!
* \brief Add one edge to the graph.
* \param src The source vertex.
* \param dst The destination vertex.
*/
void AddEdge(dgl_id_t src, dgl_id_t dst);
/*!
* \brief Add edges to the graph.
* \param src_ids The source vertex id array.
* \param dst_ids The destination vertex id array.
*/
void AddEdges(IdArray src_ids, IdArray dst_ids);
/*!
* \brief Clear the graph. Remove all vertices/edges.
*/
void Clear() {
adjlist_.clear();
reverse_adjlist_.clear();
all_edges_src_.clear();
all_edges_dst_.clear();
read_only_ = false;
num_edges_ = 0;
}
  /*!
   * \return whether the graph is a multigraph
   */
bool IsMultigraph() const {
return is_multigraph_;
}
/*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const {
return adjlist_.size();
}
/*! \return the number of edges in the graph.*/
uint64_t NumEdges() const {
return num_edges_;
}
/*! \return true if the given vertex is in the graph.*/
bool HasVertex(dgl_id_t vid) const {
return vid < NumVertices();
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
BoolArray HasVertices(IdArray vids) const;
/*! \return true if the given edge is in the graph.*/
bool HasEdgeBetween(dgl_id_t src, dgl_id_t dst) const;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/
BoolArray HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const;
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \return the predecessor id array.
*/
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const;
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \return the successor id array.
*/
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const;
/*!
* \brief Get all edge ids between the two given endpoints
   * \note Edges are associated with an integer id starting from zero.
   *       The id is assigned when the edge is added to the graph.
* \param src The source vertex.
* \param dst The destination vertex.
* \return the edge id array.
*/
IdArray EdgeId(dgl_id_t src, dgl_id_t dst) const;
/*!
* \brief Get all edge ids between the given endpoint pairs.
   * \note Edges are associated with an integer id starting from zero.
   *       The id is assigned when the edge is added to the graph.
* If duplicate pairs exist, the returned edge IDs will also duplicate.
* The order of returned edge IDs will follow the order of src-dst pairs
* first, and ties are broken by the order of edge ID.
* \return EdgeArray containing all edges between all pairs.
*/
EdgeArray EdgeIds(IdArray src, IdArray dst) const;
/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
*/
EdgeArray FindEdges(IdArray eids) const;
/*!
* \brief Get the in edges of the vertex.
* \note The returned dst id array is filled with vid.
* \param vid The vertex id.
* \return the edges
*/
EdgeArray InEdges(dgl_id_t vid) const;
/*!
* \brief Get the in edges of the vertices.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray InEdges(IdArray vids) const;
/*!
* \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray OutEdges(dgl_id_t vid) const;
/*!
* \brief Get the out edges of the vertices.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray OutEdges(IdArray vids) const;
/*!
* \brief Get all the edges in the graph.
* \note If sorted is true, the returned edges list is sorted by their src and
* dst ids. Otherwise, they are in their edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray Edges(bool sorted = false) const;
/*!
* \brief Get the in degree of the given vertex.
* \param vid The vertex id.
* \return the in degree
*/
uint64_t InDegree(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
return reverse_adjlist_[vid].succ.size();
}
/*!
* \brief Get the in degrees of the given vertices.
* \param vid The vertex id array.
* \return the in degree array
*/
DegreeArray InDegrees(IdArray vids) const;
/*!
* \brief Get the out degree of the given vertex.
* \param vid The vertex id.
* \return the out degree
*/
uint64_t OutDegree(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
return adjlist_[vid].succ.size();
}
/*!
* \brief Get the out degrees of the given vertices.
* \param vid The vertex id array.
* \return the out degree array
*/
DegreeArray OutDegrees(IdArray vids) const;
/*!
* \brief Construct the induced subgraph of the given vertices.
*
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
* selecting all of the edges from the original graph that connect two vertices in V'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the vertices preserve the order of the given id array, while the local index
* of the edges preserve the index order in the original graph. Vertices not in the
* original graph are ignored.
*
* The result subgraph is read-only.
*
* \param vids The vertices in the subgraph.
* \return the induced subgraph
*/
Subgraph VertexSubgraph(IdArray vids) const;
/*!
* \brief Construct the induced edge subgraph of the given edges.
*
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
* selecting all of the nodes from the original graph that are endpoints in E'.
*
* Vertices and edges in the original graph will be "reindexed" to local index. The local
* index of the edges preserve the order of the given id array, while the local index
* of the vertices preserve the index order in the original graph. Edges not in the
* original graph are ignored.
*
* The result subgraph is read-only.
*
* \param eids The edges in the subgraph.
* \return the induced edge subgraph
*/
Subgraph EdgeSubgraph(IdArray eids) const;
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
*
* \return the reversed graph
*/
Graph Reverse() const;
protected:
friend class GraphOp;
/*! \brief Internal edge list type */
struct EdgeList {
/*! \brief successor vertex list */
std::vector<dgl_id_t> succ;
    /*! \brief edge id list */
std::vector<dgl_id_t> edge_id;
};
typedef std::vector<EdgeList> AdjacencyList;
/*! \brief adjacency list using vector storage */
AdjacencyList adjlist_;
/*! \brief reverse adjacency list using vector storage */
AdjacencyList reverse_adjlist_;
/*! \brief all edges' src endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_src_;
/*! \brief all edges' dst endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_dst_;
/*! \brief read only flag */
bool read_only_ = false;
/*!
   * \brief Whether this is a multigraph.
*
* When a multiedge is added, this flag switches to true.
*/
bool is_multigraph_ = false;
/*! \brief number of edges */
uint64_t num_edges_ = 0;
};
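// Usage sketch (added, illustrative):
//   dgl::Graph g;
//   g.AddVertices(3);
//   g.AddEdge(0, 1);
//   g.AddEdge(1, 2);
//   // g.NumEdges() == 2, g.HasEdgeBetween(0, 1) == true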
/*! \brief Subgraph data structure */
struct Subgraph {
/*! \brief The graph. */
Graph graph;
/*!
* \brief The induced vertex ids.
* \note This is also a map from the new vertex id to the vertex id in the parent graph.
*/
IdArray induced_vertices;
/*!
* \brief The induced edge ids.
* \note This is also a map from the new edge id to the edge id in the parent graph.
*/
IdArray induced_edges;
};
} // namespace dgl
#endif // DGL_GRAPH_H_
/*!
* Copyright (c) 2018 by Contributors
* \file dgl/graph_op.h
* \brief Operations on graph index.
*/
#ifndef DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_
#include <vector>
#include "graph.h"
namespace dgl {
class GraphOp {
public:
/*!
* \brief Return the line graph.
*
   * If i~j and j~i are two edges in the original graph G, then
   * (i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges in
   * the line graph.
*
* \param graph The input graph.
* \param backtracking Whether the backtracking edges are included or not
* \return the line graph
*/
static Graph LineGraph(const Graph* graph, bool backtracking);
/*!
* \brief Return a disjoint union of the input graphs.
*
   * The new graph will include all the nodes/edges in the given graphs.
   * Nodes/Edges will be relabeled by adding the cumsum of the previous graph sizes
   * in the given sequence order. For example, given input [g1, g2, g3], where
   * they have 5, 6, 7 nodes respectively, node#2 of g2 will become node#7
   * in the result graph. Edge ids are re-assigned similarly.
*
* \param graphs A list of input graphs to be unioned.
* \return the disjoint union of the graphs
*/
static Graph DisjointUnion(std::vector<const Graph*> graphs);
/*!
* \brief Partition the graph into several subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
   * into num graphs. This requires the given number of partitions to evenly
   * divide the number of nodes in the graph.
*
* \param graph The graph to be partitioned.
* \param num The number of partitions.
* \return a list of partitioned graphs
*/
static std::vector<Graph> DisjointPartitionByNum(const Graph* graph, int64_t num);
/*!
* \brief Partition the graph into several subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
   * based on the given sizes. This requires that the sum of the given sizes
   * equals the number of nodes in the graph.
*
* \param graph The graph to be partitioned.
   * \param sizes The sizes of the partitions.
* \return a list of partitioned graphs
*/
static std::vector<Graph> DisjointPartitionBySizes(const Graph* graph, IdArray sizes);
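  // Example (added, illustrative): given three 5-node graphs, DisjointUnion
  // yields one 15-node graph, and DisjointPartitionByNum on that union with
  // num = 3 recovers three 5-node graphs (15 divides evenly by 3).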
};
} // namespace dgl
#endif // DGL_GRAPH_OP_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/c_backend_api.h
* \brief TVM runtime backend API.
*
 * The functions defined in this header are intended to be
 * used by compiled tvm operators; usually users do not need to use these
 * functions directly.
*/
#ifndef DGL_RUNTIME_C_BACKEND_API_H_
#define DGL_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// Backend related functions.
/*!
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
 * The user should not call TVMFuncFree on func.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function to register system-wide library symbol.
*
* \param name The name of the symbol
* \param ptr The symbol address.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*!
* \brief Backend function to allocate temporal workspace.
*
 * \note The allocated space is guaranteed to be aligned to kTempAllocaAlignment.
*
* \param nbytes The size of the space requested.
 * \param device_type The device type on which the space will be allocated.
 * \param device_id The device id on which the space will be allocated.
* \param dtype_code_hint The type code of the array elements. Only used in
* certain backends such as OpenGL.
* \param dtype_bits_hint The type bits of the array elements. Only used in
* certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success
*/
TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
int device_id,
uint64_t nbytes,
int dtype_code_hint,
int dtype_bits_hint);
/*!
* \brief Backend function to free temporal workspace.
*
 * \param ptr The allocated space pointer.
 * \param device_type The device type on which the space was allocated.
 * \param device_id The device id on which the space was allocated.
* \return 0 when no error is thrown, -1 when failure happens
*
* \sa TVMBackendAllocWorkspace
*/
TVM_DLL int TVMBackendFreeWorkspace(int device_type,
int device_id,
void* ptr);
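// Typical pairing (added, illustrative):
//   void* ws = TVMBackendAllocWorkspace(kDLCPU, 0, nbytes, 2, 32);
//   /* ... use the workspace ... */
//   TVMBackendFreeWorkspace(kDLCPU, 0, ws);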
/*!
* \brief Environment for TVM parallel task.
*/
typedef struct {
/*!
* \brief Auxiliary used for synchronization
*/
void* sync_handle;
  /*! \brief total number of tasks */
int32_t num_task;
} TVMParallelGroupEnv;
/*!
* \brief The callback function to execute a parallel lambda
* \param task_id the task id of the function.
 * \param penv The parallel environment backing the execution.
* \param cdata The supporting closure data.
*/
typedef int (*FTVMParallelLambda)(
int task_id, TVMParallelGroupEnv* penv, void* cdata);
/*!
* \brief Backend function for running parallel jobs.
*
* \param flambda The parallel function to be launched.
* \param cdata The closure data.
 * \param num_task Number of tasks to launch; can be 0, which means launching
 *        with all available threads.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
void* cdata,
int num_task);
/*!
 * \brief BSP barrier between parallel threads
* \param task_id the task id of the function.
 * \param penv The parallel environment backing the execution.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
/*!
 * \brief Simple static initialization function.
 *        Run f once and set handle to be not null.
 *        This function is mainly used for testing purposes.
*
 * \param handle A global address to indicate f
 * \param f The function to be run
* \param cdata The closure data to pass to the function.
* \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRunOnce(void** handle,
int (*f)(void*),
void *cdata,
int nbytes);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
#endif // DGL_RUNTIME_C_BACKEND_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file dgl/runtime/c_runtime_api.h
* \brief TVM runtime library.
*
 * The philosophy of the TVM project is to customize the compilation
 * stage to generate code that can be used by other projects transparently.
 * So this is a minimal runtime for code gluing, plus some limited
 * memory management code to enable quick testing.
*
* The runtime API is independent from TVM compilation stack and can
* be linked via libtvm_runtime.
*
* The common flow is:
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/
#ifndef DGL_RUNTIME_C_RUNTIME_API_H_
#define DGL_RUNTIME_C_RUNTIME_API_H_
// Macros to do weak linking
#ifdef _MSC_VER
#define TVM_WEAK __declspec(selectany)
#else
#define TVM_WEAK __attribute__((weak))
#endif
#ifdef __EMSCRIPTEN__
#include <emscripten/emscripten.h>
#define TVM_DLL EMSCRIPTEN_KEEPALIVE
#endif
#ifndef TVM_DLL
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL __declspec(dllexport)
#else
#define TVM_DLL __declspec(dllimport)
#endif
#else
#define TVM_DLL
#endif
#endif
// TVM version
#define TVM_VERSION "0.5.dev"
// TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include <stddef.h>
/*! \brief type of array index. */
typedef int64_t tvm_index_t;
/*! \brief Extension device types in TVM */
typedef enum {
kDLAOCL = 5,
kDLSDAccel = 6,
kOpenGL = 11,
// Extension DRAM type, used for quickly test extension device
// The device api can differ depending on the xpu driver registered.
kExtDev = 12,
// AddExtraTVMType which is not in DLPack here
} TVMDeviceExtType;
/*!
* \brief The type code in TVMType
* \note TVMType is used in two places.
*/
typedef enum {
// The type code of other types are compatible with DLPack.
// The next few fields are extension types
// that is used by TVM API calls.
kHandle = 3U,
kNull = 4U,
kTVMType = 5U,
kTVMContext = 6U,
kArrayHandle = 7U,
kNodeHandle = 8U,
kModuleHandle = 9U,
kFuncHandle = 10U,
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
  // To make sure each framework's ids do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kExtBegin = 15U,
kNNVMFirst = 16U,
kNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U
} TVMTypeCode;
/*!
* \brief The data type used in TVM Runtime.
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*
 * \note Arguments to TVM API functions always take bits=64 and lanes=1
*/
typedef DLDataType TVMType;
/*!
* \brief The Device information, abstract away common device types.
*/
typedef DLContext TVMContext;
/*!
 * \brief The tensor array structure passed to the TVM API.
*/
typedef DLTensor TVMArray;
/*! \brief the array handle */
typedef TVMArray* TVMArrayHandle;
/*!
* \brief Union type of values
* being passed through API and function calls.
*/
typedef union {
int64_t v_int64;
double v_float64;
void* v_handle;
const char* v_str;
TVMType v_type;
TVMContext v_ctx;
} TVMValue;
/*!
* \brief Byte array type used to pass in byte array
* When kBytes is used as data type.
*/
typedef struct {
const char* data;
size_t size;
} TVMByteArray;
/*! \brief Handle to TVM runtime modules. */
typedef void* TVMModuleHandle;
/*! \brief Handle to packed function handle. */
typedef void* TVMFunctionHandle;
/*! \brief Handle to hold return value. */
typedef void* TVMRetValueHandle;
/*!
* \brief The stream that is specific to device
* can be NULL, which indicates the default one.
*/
typedef void* TVMStreamHandle;
/*!
* \brief Used for implementing C API function.
* Set last error message before return.
* \param msg The error message to be set.
*/
TVM_DLL void TVMAPISetLastError(const char* msg);
/*!
 * \brief Return the string message of the last error.
 * All functions in this file return 0 on success
 * and -1 when an error occurs;
 * TVMGetLastError can be called to retrieve the error.
 *
 * This function is thread-safe and can be called from different threads.
 * \return error info
*/
TVM_DLL const char *TVMGetLastError(void);
/*!
* \brief Load module from file.
* \param file_name The file name to load the module from.
* \param format The format of the module.
* \param out The result module
*
* \return 0 when success, -1 when failure happens
 * \note The resulting module does not contain import relations.
 *       They can be reconstructed by TVMModImport.
*/
TVM_DLL int TVMModLoadFromFile(const char* file_name,
const char* format,
TVMModuleHandle* out);
/*!
* \brief Add dep to mod's dependency.
* This allows functions in this module to use modules.
*
* \param mod The module handle.
* \param dep The dependent module to be imported.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMModImport(TVMModuleHandle mod,
TVMModuleHandle dep);
/*!
* \brief Get function from the module.
* \param mod The module handle.
* \param func_name The name of the function.
* \param query_imports Whether to query imported modules
* \param out The result function, can be NULL if it is not available.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
const char* func_name,
int query_imports,
TVMFunctionHandle *out);
/*!
* \brief Free front-end extension type resource.
* \param handle The extension handle.
 * \param type_code The type code of the extension type.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMExtTypeFree(void* handle, int type_code);
/*!
* \brief Free the Module
* \param mod The module to be freed.
*
 * \note This may not free up the module's resources
 *       if there is an active TVMFunctionHandle that uses the module,
 *       or if this module is imported by another active module.
 *
 *       All functions remain valid until TVMFuncFree is called.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMModFree(TVMModuleHandle mod);
/*!
* \brief Free the function when it is no longer needed.
* \param func The function handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
/*!
* \brief Call a Packed TVM Function.
*
* \param func node handle of the function.
* \param arg_values The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
*
* \param ret_val The return value.
* \param ret_type_code the type code of return value.
*
* \return 0 when success, -1 when failure happens
 * \note API calls always exchange with type bits=64, lanes=1.
 *       If an API call returns container handles (e.g. FunctionHandle),
 *       these handles should be managed by the front-end;
 *       the front-end needs to call the corresponding free function
 *       (e.g. TVMFuncFree) to free them.
*/
TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
TVMValue* arg_values,
int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code);
/*!
* \brief Set the return value of TVMPackedCFunc.
*
* This function is called by TVMPackedCFunc to set the return value.
* When this function is not called, the function returns null by default.
*
* \param ret The return value handle, pass by ret in TVMPackedCFunc
* \param value The value to be returned.
* \param type_code The type of the value to be returned.
* \param num_ret Number of return values, for now only 1 is supported.
*/
TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value,
int* type_code,
int num_ret);
/*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
/*!
* \brief C type of packed function.
*
* \param args The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
* \param ret The return value handle.
 * \param resource_handle The additional resource handle from the front-end.
* \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
* \sa TVMCFuncSetReturn
*/
typedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* resource_handle);
/*!
* \brief C callback to free the resource handle in C packed function.
 * \param resource_handle The additional resource handle from the front-end.
*/
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle);
/*!
* \brief Signature for extension function declarer.
*
 * TVM calls this function to get the extension functions.
 * The declarer will call register_func to register the functions and their names.
*
* \param register_func_handle The register function
* \return 0 if success, -1 if failure happens
*/
typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle);
/*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
*
* The resource_handle will be managed by TVM API, until the function is no longer used.
*
* \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL.
* \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL
* \param out the result function handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out);
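/*
 * Example (a minimal sketch): exposing a C callback as a packed function.
 * The callback echoes its first argument back through TVMCFuncSetReturn.
 *
 * \code
 * int EchoFirstArg(TVMValue* args, int* type_codes, int num_args,
 *                  TVMRetValueHandle ret, void* resource_handle) {
 *   TVMCFuncSetReturn(ret, &args[0], &type_codes[0], 1);
 *   return 0;
 * }
 *
 * TVMFunctionHandle fh;
 * TVMFuncCreateFromCFunc(EchoFirstArg, nullptr, nullptr, &fh);
 * // ... use fh via TVMFuncCall, then release it:
 * TVMFuncFree(fh);
 * \endcode
 */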
/*!
* \brief Register the function to runtime's global table.
*
 * The registered function can then be retrieved by the backend by its name.
 *
 * \param name The name of the function.
 * \param f The function to be registered.
 * \param override Whether to allow overriding an already registered function.
 * \return 0 when success, -1 when failure happens
 */
TVM_DLL int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override);
/*!
* \brief Get a global function.
*
* \param name The name of the function.
 * \param out The result function pointer; NULL if it does not exist.
 *
 * \note The handle of a global function is managed by the TVM runtime,
 * so TVMFuncFree should not be called on it.
*/
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
/*!
* \brief List all the globally registered function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncListGlobalNames(int* out_size,
const char*** out_array);
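/*
 * Example (a minimal sketch): enumerating every globally registered
 * function name. The name array is owned by the runtime.
 *
 * \code
 * int num_names;
 * const char** names;
 * TVMFuncListGlobalNames(&num_names, &names);
 * for (int i = 0; i < num_names; ++i) {
 *   printf("%s\n", names[i]);
 * }
 * \endcode
 */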
// Array related APIs for quick prototyping
/*!
 * \brief Allocate an nd-array's memory,
 *  including the space for its shape, with the given spec.
 *
 * \param shape The shape of the array; the shape content is copied into out.
* \param ndim The number of dimension of the array.
* \param dtype_code The type code of the dtype
* \param dtype_bits The number of bits of dtype
* \param dtype_lanes The number of lanes in the dtype.
* \param device_type The device type of context
* \param device_id The device id of context.
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
TVMArrayHandle* out);
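/*
 * Example (a minimal sketch): allocating a 2x3 float32 array on CPU device 0,
 * filling it from a host buffer, and freeing it.
 *
 * \code
 * tvm_index_t shape[2] = {2, 3};
 * TVMArrayHandle arr;
 * TVMArrayAlloc(shape, 2, kDLFloat, 32, 1, kDLCPU, 0, &arr);
 * float host[6] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
 * TVMArrayCopyFromBytes(arr, host, sizeof(host));
 * // ... use the array ...
 * TVMArrayFree(arr);
 * \endcode
 */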
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
/*!
* \brief Copy array data from CPU byte array.
* \param handle The array handle.
* \param data the data pointer
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data,
size_t nbytes);
/*!
* \brief Copy array data to CPU byte array.
* \param handle The array handle.
* \param data the data pointer
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle,
void* data,
size_t nbytes);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
* \param from The array to be copied from.
* \param to The target space.
* \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
TVMStreamHandle stream);
/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* with the DLManagedTensor.
* \param from The source DLManagedTensor.
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out);
/*!
 * \brief Produce a DLManagedTensor from the array that shares data memory with
* the array.
* \param from The source array.
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out);
/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor.
*/
TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
/*!
* \brief Create a new runtime stream.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param out The new stream handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out);
/*!
* \brief Free a created stream handle.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param stream The stream to be freed
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream);
/*!
* \brief Set the runtime stream of current thread to be stream.
 * Subsequent calls to the same device_type
 * will use the stream handle that has been set.
* The specific type of stream is runtime device dependent.
*
* \param device_type The device type of context
* \param device_id The device id of context.
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle);
/*!
 * \brief Wait until all computations on the stream complete.
*
* \param device_type The device type of context
* \param device_id The device id of context.
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream);
/*!
* \brief Synchronize two streams of execution.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param src The source stream to synchronize.
* \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst);
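/*
 * Example (a minimal sketch): driving asynchronous execution on GPU 0
 * with the stream APIs declared above.
 *
 * \code
 * TVMStreamHandle s;
 * TVMStreamCreate(kDLGPU, 0, &s);
 * TVMSetStream(kDLGPU, 0, s);   // subsequent calls on this device use s
 * // ... launch asynchronous copies or kernels ...
 * TVMSynchronize(kDLGPU, 0, s); // block until all work on s completes
 * TVMStreamFree(kDLGPU, 0, s);
 * \endcode
 */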
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
#endif // DGL_RUNTIME_C_RUNTIME_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file dgl/runtime/device_api.h
* \brief Abstract device memory management API
*/
#ifndef DGL_RUNTIME_DEVICE_API_H_
#define DGL_RUNTIME_DEVICE_API_H_
#include <string>
#include "packed_func.h"
#include "c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief the query type into GetAttr
*/
enum DeviceAttrKind : int {
kExist = 0,
kMaxThreadsPerBlock = 1,
kWarpSize = 2,
kMaxSharedMemoryPerBlock = 3,
kComputeVersion = 4,
kDeviceName = 5,
kMaxClockRate = 6,
kMultiProcessorCount = 7,
kMaxThreadDimensions = 8
};
/*! \brief Number of bytes each allocation must align to */
constexpr int kAllocAlignment = 64;
/*! \brief Number of bytes each allocation must align to in temporary allocation */
constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */
constexpr int kMaxStackAlloca = 1024;
/*!
* \brief TVM Runtime Device API, abstracts the device
* specific interface for memory management.
*/
class DeviceAPI {
public:
/*! \brief virtual destructor */
virtual ~DeviceAPI() {}
/*!
* \brief Set the environment device id to ctx
* \param ctx The context to be set.
*/
virtual void SetDevice(TVMContext ctx) = 0;
/*!
* \brief Get attribute of specified device.
* \param ctx The device context
* \param kind The result kind
* \param rv The return value.
* \sa DeviceAttrKind
*/
virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0;
/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
* \param nbytes The number of bytes in memory.
* \param alignment The alignment of the memory.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes & alignment are sufficient for most backends.
* \return The allocated device pointer.
*/
virtual void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) = 0;
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
* \param ptr The data space.
*/
virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0;
/*!
* \brief copy data from one place to another
* \param from The source array.
 * \param from_offset The byte offset in the source array.
 * \param to The target array.
 * \param to_offset The byte offset in the target array.
* \param num_bytes The size of the memory in bytes
* \param ctx_from The source context
* \param ctx_to The target context
 * \param type_hint The type of elements, only needed by certain backends;
 *  can be useful for cross-device endian conversion.
* \param stream Optional stream object.
*/
virtual void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t num_bytes,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) = 0;
/*!
* \brief Create a new stream of execution.
*
* \param ctx The context of allocation.
*/
TVM_DLL virtual TVMStreamHandle CreateStream(TVMContext ctx);
/*!
* \brief Free a stream of execution
*
* \param ctx The context of the stream
* \param stream The pointer to be freed.
*/
TVM_DLL virtual void FreeStream(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief Synchronize the stream
* \param ctx The context to perform operation.
* \param stream The stream to be sync.
*/
virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0;
/*!
* \brief Set the stream
* \param ctx The context to set stream.
* \param stream The stream to be set.
*/
virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {}
/*!
* \brief Synchronize 2 streams of execution.
*
 * An event is created in the event_src stream, which the event_dst
 * stream then waits on. Neither event_src nor event_dst need to have
 * the same device ID as the context, but they must be of the same
 * device type.
*
* \param ctx The context of the streams.
* \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize.
*/
TVM_DLL virtual void SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
* \brief Allocate temporal workspace for backend execution.
*
* \note We have the following assumption about backend temporal
* workspace allocation, and backend will optimize for such assumption:
*
 * - Only a few allocations will happen, and space will be released after use.
 * - The release order is usually the reverse of the allocation order (stack style).
 * - The same allocation pattern repeats over different runs.
 * - Workspaces should not overlap between different threads (i.e. be thread-local).
*
* \param ctx The context of allocation.
* \param nbytes The size to be allocated.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends.
*/
TVM_DLL virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes,
TVMType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
* \param ctx The context of allocation.
* \param ptr The pointer to be freed.
*/
TVM_DLL virtual void FreeWorkspace(TVMContext ctx, void* ptr);
/*!
   * \brief Get the device API based on the context.
   * \param ctx The context.
   * \param allow_missing Whether to allow a missing device API.
* \return The corresponding device API.
*/
TVM_DLL static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);
};
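/*
 * Example (a minimal sketch): using the DeviceAPI registry to allocate and
 * free raw device memory for a CPU context.
 *
 * \code
 * TVMContext ctx;
 * ctx.device_type = kDLCPU;
 * ctx.device_id = 0;
 * DeviceAPI* api = DeviceAPI::Get(ctx);
 * TVMType type_hint{kDLFloat, 32, 1};
 * void* data = api->AllocDataSpace(ctx, 256, kAllocAlignment, type_hint);
 * api->FreeDataSpace(ctx, data);
 * \endcode
 */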
/*! \brief Device types bigger than this are RPC devices */
constexpr int kRPCSessMask = 128;
} // namespace runtime
} // namespace tvm
#endif // DGL_RUNTIME_DEVICE_API_H_
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/module.h
 * \brief Runtime container of the functions generated by TVM.
 * This is used to support dynamic linking, loading and saving of
 * functions that follow different conventions, under a unified API.
*/
#ifndef DGL_RUNTIME_MODULE_H_
#define DGL_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "c_runtime_api.h"
namespace tvm {
namespace runtime {
// The internal container of module.
class ModuleNode;
class PackedFunc;
/*!
* \brief Module container of TVM.
*/
class Module {
public:
Module() {}
// constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n)
: node_(n) {}
/*!
* \brief Get packed function from current module by name.
*
* \param name The name of the function.
* \param query_imports Whether also query dependency modules.
* \return The result function.
   * This function will return PackedFunc(nullptr) if the function does not exist.
* \note Implemented in packed_func.cc
*/
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
  // The following functions require linking with the runtime.
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
*
   * \note Cyclic dependencies are not allowed among modules;
   *  an error will be thrown when one is detected.
*/
TVM_DLL void Import(Module other);
/*!
* \brief Load a module from file.
* \param file_name The name of the host function module.
* \param format The format of the file.
   * \note This function does not load the import relationship.
   *  Re-create the import relationship by calling Import.
*/
TVM_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = "");
private:
std::shared_ptr<ModuleNode> node_;
};
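/*
 * Example (a minimal sketch; "mylib.so" and "myadd" are hypothetical):
 * loading a compiled module and invoking one of its packed functions.
 *
 * \code
 * Module mod = Module::LoadFromFile("mylib.so");
 * PackedFunc f = mod.GetFunction("myadd");
 * if (f != nullptr) {
 *   int result = f(1, 2);  // the return value converts automatically
 * }
 * \endcode
 */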
/*!
* \brief Base node container of module.
* Do not create this directly, instead use Module.
*/
class ModuleNode {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
/*! \return The module type key */
virtual const char* type_key() const = 0;
/*!
* \brief Get a PackedFunc from module.
*
   * The PackedFunc may not be fully initialized;
   * there might still be first-time running overhead when
   * executing the function on certain devices.
   * For benchmarking, run the function once first to eliminate this overhead.
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc(nullptr) when it is not available.
*
   * \note The function will always remain valid.
   *  If the function needs resources from the module (e.g. late linking),
   *  it should capture sptr_to_self.
*/
virtual PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
virtual void SaveToFile(const std::string& file_name,
const std::string& format);
/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
   * \note It is recommended to implement this for device modules,
   *  but not necessarily for host modules.
* We can use this to do AOT loading of bundled device functions.
*/
TVM_DLL virtual void SaveToBinary(dmlc::Stream* stream);
/*!
* \brief Get the source code of module, when available.
* \param format Format of the source code, can be empty by default.
* \return Possible source code when available.
*/
TVM_DLL virtual std::string GetSource(const std::string& format = "");
/*!
* \brief Get a function from current environment
* The environment includes all the imports as well as Global functions.
*
* \param name name of the function.
* \return The corresponding function.
*/
TVM_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
  /*! \return The modules that this module imports */
const std::vector<Module>& imports() const {
return imports_;
}
protected:
friend class Module;
/*! \brief The modules this module depend on */
std::vector<Module> imports_;
private:
  /*! \brief Cache used by GetFuncFromEnv */
std::unordered_map<std::string,
std::unique_ptr<PackedFunc> > import_cache_;
};
/*! \brief namespace for constant symbols */
namespace symbol {
/*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Global variable to store device module blob */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Number of bytes of device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that use the global barrier. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
} // namespace symbol
// implementations of inline functions.
inline ModuleNode* Module::operator->() {
return node_.get();
}
inline const ModuleNode* Module::operator->() const {
return node_.get();
}
} // namespace runtime
} // namespace tvm
#include "packed_func.h"
#endif // DGL_RUNTIME_MODULE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/ndarray.h
* \brief Abstract device memory management API
*/
#ifndef DGL_RUNTIME_NDARRAY_H_
#define DGL_RUNTIME_NDARRAY_H_
#include <atomic>
#include <vector>
#include <utility>
#include "c_runtime_api.h"
#include "serializer.h"
namespace tvm {
namespace runtime {
/*!
* \brief Managed NDArray.
* The array is backed by reference counted blocks.
*/
class NDArray {
public:
// internal container type
struct Container;
/*! \brief default constructor */
NDArray() {}
/*!
   * \brief construct an NDArray that refers to data
* \param data The data this NDArray refers to
*/
explicit inline NDArray(Container* data);
/*!
* \brief copy constructor
* \param other The value to be copied
*/
inline NDArray(const NDArray& other); // NOLINT(*)
/*!
* \brief move constructor
* \param other The value to be moved
*/
NDArray(NDArray&& other) // NOLINT(*)
: data_(other.data_) {
other.data_ = nullptr;
}
/*! \brief destructor */
~NDArray() {
this->reset();
}
/*!
* \brief Swap this array with another NDArray
* \param other The other NDArray
*/
void swap(NDArray& other) { // NOLINT(*)
std::swap(data_, other.data_);
}
/*!
   * \brief copy assignment
* \param other The value to be assigned.
* \return reference to self.
*/
NDArray& operator=(const NDArray& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(other).swap(*this); // NOLINT(*)
return *this;
}
/*!
   * \brief move assignment
* \param other The value to be assigned.
* \return reference to self.
*/
NDArray& operator=(NDArray&& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*! \return If NDArray is defined */
bool defined() const {
return data_ != nullptr;
}
/*! \return If both NDArray reference the same container */
bool same_as(const NDArray& other) const {
return data_ == other.data_;
}
/*! \brief reset the content of NDArray to be nullptr */
inline void reset();
/*!
* \return the reference counter
   * \note this number is approximate in a multi-threaded setting.
*/
inline int use_count() const;
/*! \return Pointer to content of DLTensor */
inline const DLTensor* operator->() const;
/*!
* \brief Copy data content from another array.
* \param other The source array to be copied from.
   * \note The copy may happen asynchronously if it involves a GPU context.
   *  TVMSynchronize is necessary.
*/
inline void CopyFrom(DLTensor* other);
inline void CopyFrom(const NDArray& other);
/*!
* \brief Copy data content into another array.
   * \param other The target array to be copied to.
   * \note The copy may happen asynchronously if it involves a GPU context.
   *  TVMSynchronize is necessary.
*/
inline void CopyTo(DLTensor* other) const;
inline void CopyTo(const NDArray& other) const;
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The array under another context.
*/
inline NDArray CopyTo(const DLContext& ctx) const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
* \return Whether load is successful
*/
inline bool Load(dmlc::Stream* stream);
/*!
* \brief Save NDArray to stream
* \param stream The output data stream
*/
inline void Save(dmlc::Stream* stream) const;
/*!
* \brief Create a NDArray that shares the data memory with the current one.
* \param shape The shape of the new array.
* \param dtype The data type of the new array.
   * \note The memory size of the new array must not exceed that of the current one.
*/
TVM_DLL NDArray CreateView(
std::vector<int64_t> shape, DLDataType dtype);
/*!
   * \brief Create a reference view of the NDArray,
   *  represented as a DLManagedTensor.
   * \return A DLManagedTensor.
*/
TVM_DLL DLManagedTensor* ToDLPack() const;
/*!
* \brief Create an empty NDArray.
* \param shape The shape of the new array.
* \param dtype The data type of the new array.
* \param ctx The context of the Array.
* \return The created Array
*/
TVM_DLL static NDArray Empty(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx);
/*!
* \brief Create a NDArray backed by a dlpack tensor.
*
* This allows us to create a NDArray using the memory
* allocated by an external deep learning framework
* that is DLPack compatible.
*
   * The memory is retained until the NDArray goes out of scope.
   * \param tensor The DLPack tensor to convert.
* \return The created NDArray view.
*/
TVM_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
/*!
* \brief Function to copy data from one array to another.
* \param from The source array.
* \param to The target array.
* \param stream The stream used in copy.
*/
TVM_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
// internal namespace
struct Internal;
private:
/*! \brief Internal Data content */
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend class TVMRetValue;
friend class TVMArgsSetter;
};
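/*
 * Example (a minimal sketch): creating CPU arrays and copying between them.
 *
 * \code
 * DLContext cpu_ctx{kDLCPU, 0};
 * DLDataType f32{kDLFloat, 32, 1};
 * NDArray a = NDArray::Empty({2, 3}, f32, cpu_ctx);
 * NDArray b = NDArray::Empty({2, 3}, f32, cpu_ctx);
 * b.CopyFrom(a);                           // element-wise data copy
 * NDArray flat = a.CreateView({6}, f32);   // shares memory with a
 * DLManagedTensor* dlpack = a.ToDLPack();  // zero-copy DLPack export
 * \endcode
 */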
/*!
* \brief Save a DLTensor to stream
 * \param strm The output stream
* \param tensor The tensor to be saved.
*/
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
/*!
* \brief Reference counted Container object used to back NDArray.
*
 * This object is DLTensor compatible:
 * a pointer to the NDArray::Container can be directly
 * interpreted as a DLTensor*.
 *
 * \note Do not use this struct directly; use NDArray instead.
*/
struct NDArray::Container {
public:
  // NOTE: the first part of this structure is the same as
  // DLManagedTensor; note, however, that the deleter
  // is only called when the reference counter goes to 0.
/*!
* \brief The corresponding dl_tensor field.
* \note it is important that the first field is DLTensor
* So that this data structure is DLTensor compatible.
* The head ptr of this struct can be viewed as DLTensor*.
*/
DLTensor dl_tensor;
/*!
   * \brief additional context, reserved for recycling
   * \note We can attach additional content here
   *  which the current container depends on
   *  (e.g. reference to original memory when creating views).
*/
void* manager_ctx{nullptr};
/*!
* \brief Customized deleter
*
   * \note The customized deleter is helpful to enable
   *  different kinds of memory allocators that are not
   *  currently defined by the system.
*/
void (*deleter)(Container* self) = nullptr;
/*! \brief default constructor */
Container() {
dl_tensor.data = nullptr;
dl_tensor.ndim = 0;
dl_tensor.shape = nullptr;
dl_tensor.strides = nullptr;
dl_tensor.byte_offset = 0;
}
/*! \brief developer function, increases reference counter */
void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
/*! \brief developer function, decrease reference counter */
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter != nullptr) {
(*this->deleter)(this);
}
}
}
private:
friend class NDArray;
friend class RPCWrappedFunc;
/*!
* \brief The shape container,
* can be used for shape data.
*/
std::vector<int64_t> shape_;
/*!
* \brief The stride container,
* can be used for stride data.
*/
std::vector<int64_t> stride_;
  /*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
};
// implementations of inline functions
// the usages of functions are documented in place.
inline NDArray::NDArray(Container* data)
: data_(data) {
data_->IncRef();
}
inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) {
data_->IncRef();
}
inline void NDArray::reset() {
if (data_ != nullptr) {
data_->DecRef();
data_ = nullptr;
}
}
inline void NDArray::CopyFrom(DLTensor* other) {
CHECK(data_ != nullptr);
CopyFromTo(other, &(data_->dl_tensor));
}
inline void NDArray::CopyFrom(const NDArray& other) {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor));
}
inline void NDArray::CopyTo(DLTensor* other) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
}
inline void NDArray::CopyTo(const NDArray& other) const {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor));
}
inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx);
this->CopyTo(ret);
return ret;
}
inline int NDArray::use_count() const {
if (data_ == nullptr) return 0;
return data_->ref_counter_.load(std::memory_order_relaxed);
}
inline const DLTensor* NDArray::operator->() const {
return &(data_->dl_tensor);
}
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDLTensor(dmlc::Stream* strm,
DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
// Always save data as CPU context
//
  // Parameters that get serialized should be on CPU by default.
  // So even if the array's context is GPU, it will be stored as a CPU array.
  // This prevents the case where another user loads the parameters
  // back on a machine that does not have a GPU or the related context.
//
// We can always do array.CopyTo(target_ctx) to get a corresponding
// array in the target context.
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
strm->Write(cpu_ctx);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->WriteArray(tensor->shape, ndim);
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
if (DMLC_IO_NO_ENDIAN_SWAP &&
tensor->ctx.device_type == kDLCPU &&
tensor->strides == nullptr &&
tensor->byte_offset == 0) {
// quick path
strm->Write(tensor->data, data_byte_size);
} else {
std::vector<uint8_t> bytes(data_byte_size);
CHECK_EQ(TVMArrayCopyToBytes(
tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)
<< TVMGetLastError();
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
}
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true;
}
inline void NDArray::Save(dmlc::Stream* strm) const {
SaveDLTensor(strm, const_cast<DLTensor*>(operator->()));
}
inline bool NDArray::Load(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLContext ctx;
int ndim;
DLDataType dtype;
CHECK(strm->Read(&ctx))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&ndim))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&dtype))
<< "Invalid DLTensor file format";
CHECK_EQ(ctx.device_type, kDLCPU)
<< "Invalid DLTensor context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim))
<< "Invalid DLTensor file format";
}
NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
*this = ret;
return true;
}
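/*
 * Example (a minimal sketch, assuming dmlc::MemoryStringStream from
 * dmlc-core and an existing CPU NDArray named arr): round-tripping an
 * NDArray through the binary format defined above.
 *
 * \code
 * std::string blob;
 * dmlc::MemoryStringStream strm(&blob);
 * arr.Save(&strm);          // writes kTVMNDArrayMagic, shape, dtype and data
 * strm.Seek(0);
 * NDArray loaded;
 * CHECK(loaded.Load(&strm));
 * \endcode
 */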
} // namespace runtime
} // namespace tvm
#endif // DGL_RUNTIME_NDARRAY_H_
/*!
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/packed_func.h
* \brief Type-erased function used across TVM API.
*/
#ifndef DGL_RUNTIME_PACKED_FUNC_H_
#define DGL_RUNTIME_PACKED_FUNC_H_
#include <dmlc/logging.h>
#include <functional>
#include <tuple>
#include <vector>
#include <string>
#include <limits>
#include <memory>
#include <type_traits>
#include "c_runtime_api.h"
#include "module.h"
#include "ndarray.h"
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
#endif
namespace tvm {
// Forward declare NodeRef and Node for extensions.
// This header works fine without depend on NodeRef
// as long as it is not used.
class Node;
class NodeRef;
namespace runtime {
// forward declarations
class TVMArgs;
class TVMArgValue;
class TVMRetValue;
class TVMArgsSetter;
/*!
* \brief Packed function is a type-erased function.
* The arguments are passed by packed format.
*
 * This is a useful unified interface to call generated functions.
 * It is the unified function type of TVM.
 * It corresponds to TVMFunctionHandle in the C runtime API.
*/
class PackedFunc {
public:
/*!
* \brief The internal std::function
* \param args The arguments to the function.
* \param rv The return value.
*
* \code
   * // Example code on how to implement FType
* void MyPackedFunc(TVMArgs args, TVMRetValue* rv) {
* // automatically convert arguments to desired type.
* int a0 = args[0];
* float a1 = args[1];
* ...
* // automatically assign values to rv
* std::string my_return_value = "x";
* *rv = my_return_value;
* }
* \endcode
*/
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
*/
explicit PackedFunc(FType body) : body_(body) {}
/*!
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
*
* \code
* // Example code on how to call packed function
* void CallPacked(PackedFunc f) {
* // call like normal functions by pass in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* }
* \endcode
*/
template<typename... Args>
inline TVMRetValue operator()(Args&& ...args) const;
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param rv The return value.
*/
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
/*! \return the internal body function */
inline FType body() const;
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const {
return body_ == nullptr;
}
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const {
return body_ != nullptr;
}
private:
/*! \brief internal container of packed function */
FType body_;
};
/*!
 * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args...)>"
*/
template<typename FType>
class TypedPackedFunc;
/*!
* \anchor TypedPackedFuncAnchor
* \brief A PackedFunc wrapper to provide typed function signature.
* It is backed by a PackedFunc internally.
*
* TypedPackedFunc enables compile time type checking.
* TypedPackedFunc works with the runtime system:
* - It can be passed as an argument of PackedFunc.
* - It can be assigned to TVMRetValue.
* - It can be directly converted to a type-erased PackedFunc.
*
* Developers should prefer TypedPackedFunc over PackedFunc in C++ code
* as it enables compile time checking.
* We can construct a TypedPackedFunc from a lambda function
* with the same signature.
*
* \code
* // user defined lambda function.
* auto addone = [](int x)->int {
* return x + 1;
* };
* // We can directly convert
* // lambda function to TypedPackedFunc
* TypedPackedFunc<int(int)> ftyped(addone);
* // invoke the function.
* int y = ftyped(1);
* // Can be directly converted to PackedFunc
 * PackedFunc packed = ftyped;
* \endcode
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
*/
template<typename R, typename ...Args>
class TypedPackedFunc<R(Args...)> {
public:
/*! \brief short hand for this function type */
using TSelf = TypedPackedFunc<R(Args...)>;
/*! \brief default constructor */
TypedPackedFunc() {}
/*!
   * \brief construct by wrapping a PackedFunc
*
* Example usage:
* \code
* PackedFunc packed([](TVMArgs args, TVMRetValue *rv) {
* int x = args[0];
* *rv = x + 1;
* });
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped(packed);
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param packed The packed function
*/
inline explicit TypedPackedFunc(PackedFunc packed);
/*!
* \brief construct from a lambda function with the same signature.
*
* Example usage:
* \code
   * auto typed_lambda = [](int x)->int { return x + 1; };
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped(typed_lambda);
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
*/
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
explicit TypedPackedFunc(const FLambda& typed_lambda) {
this->AssignTypedLambda(typed_lambda);
}
/*!
* \brief copy assignment operator from typed lambda
*
* Example usage:
* \code
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped;
   * ftyped = [](int x) { return x + 1; };
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
* \returns reference to self.
*/
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
return *this;
}
/*!
* \brief copy assignment operator from PackedFunc.
* \param packed The packed function.
* \returns reference to self.
*/
TSelf& operator=(PackedFunc packed) {
packed_ = packed;
return *this;
}
/*!
* \brief Invoke the operator.
* \param args The arguments
* \returns The return value.
*/
inline R operator()(Args ...args) const;
/*!
* \brief convert to PackedFunc
* \return the internal PackedFunc
*/
operator PackedFunc() const {
return packed();
}
/*!
   * \return Reference to the internal PackedFunc.
*/
const PackedFunc& packed() const {
return packed_;
}
private:
friend class TVMRetValue;
/*! \brief The internal packed function */
PackedFunc packed_;
/*!
* \brief Assign the packed field using a typed lambda function.
*
* \param flambda The lambda function.
* \tparam FLambda The lambda function type.
* \note We capture the lambda when possible for maximum efficiency.
*/
template<typename FLambda>
inline void AssignTypedLambda(FLambda flambda);
};
/*! \brief Arguments into TVM functions. */
class TVMArgs {
public:
const TVMValue* values;
const int* type_codes;
int num_args;
/*!
* \brief constructor
* \param values The argument values
* \param type_codes The argument type codes
* \param num_args number of arguments.
*/
TVMArgs(const TVMValue* values,
const int* type_codes,
int num_args)
: values(values),
type_codes(type_codes),
num_args(num_args) { }
/*! \return size of the arguments */
inline int size() const;
/*!
* \brief Get i-th argument
* \param i the index.
* \return the ith argument.
*/
inline TVMArgValue operator[](int i) const;
};
/*!
* \brief Convert type code to its name
 * \param type_code The type code.
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);
/*!
* \brief convert a string to TVM type.
* \param s The string to be converted.
* \return The corresponding tvm type.
*/
inline TVMType String2TVMType(std::string s);
/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string TVMType2String(TVMType t);
// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
    << TypeCode2Str(T) << " but got " << TypeCode2Str(CODE)
/*!
 * \brief Type traits to mark if a class is a tvm extension type.
 *
 * To enable an extension type in C++, it must be registered via the macro
 * TVM_REGISTER_EXT_TYPE(TypeName) after defining this trait for the type.
 *
 * Extension classes can be passed and returned via PackedFunc throughout the tvm runtime.
 * Internally an extension class is stored as a T*.
*
* \tparam T the typename
*/
template<typename T>
struct extension_class_info {
static const int code = 0;
};
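/*
 * Example (a minimal sketch; MyExt and its code value are hypothetical):
 * marking a user-defined class as an extension type so instances can be
 * passed through PackedFunc. The code must be unique and lie in the
 * extension range (>= kExtBegin).
 *
 * \code
 * class MyExt { ... };
 *
 * namespace tvm {
 * namespace runtime {
 * template<>
 * struct extension_class_info<MyExt> {
 *   static const int code = kExtBegin + 1;
 * };
 * }  // namespace runtime
 * }  // namespace tvm
 *
 * TVM_REGISTER_EXT_TYPE(MyExt);  // the macro referenced above
 * \endcode
 */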
/*!
* \brief Runtime function table about extension type.
*/
class ExtTypeVTable {
public:
/*! \brief function to be called to delete a handle */
void (*destroy)(void* handle);
/*! \brief function to be called when clone a handle */
void* (*clone)(void* handle);
/*!
* \brief Register type
   * \tparam T The type to be registered.
* \return The registered vtable.
*/
template <typename T>
static inline ExtTypeVTable* Register_();
/*!
* \brief Get a vtable based on type code.
* \param type_code The type code
* \return The registered vtable.
*/
TVM_DLL static ExtTypeVTable* Get(int type_code);
private:
// Internal registration function.
TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
};
/*!
* \brief Internal base class to
* handle conversion to POD values.
*/
class TVMPODValue_ {
public:
operator double() const {
    // Allow automatic conversion from int to float.
    // This avoids errors when the user passes in an int from
    // the frontend while the API expects a float.
if (type_code_ == kDLInt) {
return static_cast<double>(value_.v_int64);
}
TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
return value_.v_float64;
}
operator int64_t() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64;
}
operator uint64_t() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64;
}
operator int() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
CHECK_LE(value_.v_int64,
std::numeric_limits<int>::max());
return static_cast<int>(value_.v_int64);
}
operator bool() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64 != 0;
}
operator void*() const {
if (type_code_ == kNull) return nullptr;
if (type_code_ == kArrayHandle) return value_.v_handle;
TVM_CHECK_TYPE_CODE(type_code_, kHandle);
return value_.v_handle;
}
operator DLTensor*() const {
if (type_code_ == kArrayHandle ||
type_code_ == kNDArrayContainer) {
return static_cast<DLTensor*>(value_.v_handle);
} else {
if (type_code_ == kNull) return nullptr;
LOG(FATAL) << "Expected "
<< "DLTensor* or NDArray but get "
<< TypeCode2Str(type_code_);
return nullptr;
}
}
operator NDArray() const {
if (type_code_ == kNull) return NDArray();
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0];
}
int type_code() const {
return type_code_;
}
/*!
* \brief return handle as specific pointer type.
* \tparam T the data type.
* \return The pointer type.
*/
template<typename T>
T* ptr() const {
return static_cast<T*>(value_.v_handle);
}
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code)
: value_(value), type_code_(type_code) {}
/*! \brief The value */
TVMValue value_;
/*! \brief the type code */
int type_code_;
};
/*!
* \brief A single argument value to PackedFunc.
 * It contains both the type_code and the TVMValue.
 *
 * Provides utilities to cast the value into other types.
*/
class TVMArgValue : public TVMPODValue_ {
public:
/*! \brief default constructor */
TVMArgValue() {}
/*!
* \brief constructor
   * \param value The argument value.
* \param type_code The type code.
*/
TVMArgValue(TVMValue value, int type_code)
: TVMPODValue_(value, type_code) {
}
// reuse converter from parent
using TVMPODValue_::operator double;
using TVMPODValue_::operator int64_t;
using TVMPODValue_::operator uint64_t;
using TVMPODValue_::operator int;
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
// conversion operator.
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
return std::string(arr->data, arr->size);
} else {
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str);
}
}
operator TVMType() const {
if (type_code_ == kStr) {
return String2TVMType(operator std::string());
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type;
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
const TVMValue& value() const {
return value_;
}
// Deferred extension handler.
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef,
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
inline bool IsNodeType() const;
// get internal node ptr, if it is node
inline std::shared_ptr<Node>& node_sptr();
};
/*!
 * \brief Return value container.
 *
 * Unlike TVMArgValue, which only holds a reference and does not delete
 * the underlying container during destruction,
 * TVMRetValue holds the value and will manage the underlying containers
 * when it stores a complicated data type.
*/
class TVMRetValue : public TVMPODValue_ {
public:
/*! \brief default constructor */
TVMRetValue() {}
/*!
   * \brief move constructor from another return value.
* \param other The other return value.
*/
TVMRetValue(TVMRetValue&& other)
: TVMPODValue_(other.value_, other.type_code_) {
other.value_.v_handle = nullptr;
other.type_code_ = kNull;
}
/*! \brief destructor */
~TVMRetValue() {
this->Clear();
}
// reuse converter from parent
using TVMPODValue_::operator double;
using TVMPODValue_::operator int64_t;
using TVMPODValue_::operator uint64_t;
using TVMPODValue_::operator int;
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
  // copy constructor: deep-copies the stored value via Assign.
TVMRetValue(const TVMRetValue& other) {
this->Assign(other);
}
// conversion operators
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
return *ptr<std::string>();
}
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>();
}
operator TVMType() const {
if (type_code_ == kStr) {
return String2TVMType(operator std::string());
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type;
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
// Assign operators
TVMRetValue& operator=(TVMRetValue&& other) {
this->Clear();
value_ = other.value_;
type_code_ = other.type_code_;
other.type_code_ = kNull;
return *this;
}
TVMRetValue& operator=(double value) {
this->SwitchToPOD(kDLFloat);
value_.v_float64 = value;
return *this;
}
TVMRetValue& operator=(std::nullptr_t value) {
this->SwitchToPOD(kNull);
value_.v_handle = value;
return *this;
}
TVMRetValue& operator=(void* value) {
this->SwitchToPOD(kHandle);
value_.v_handle = value;
return *this;
}
TVMRetValue& operator=(int64_t value) {
this->SwitchToPOD(kDLInt);
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(int value) {
this->SwitchToPOD(kDLInt);
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(TVMType t) {
this->SwitchToPOD(kTVMType);
value_.v_type = t;
return *this;
}
TVMRetValue& operator=(bool value) {
this->SwitchToPOD(kDLInt);
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(std::string value) {
this->SwitchToClass(kStr, value);
return *this;
}
TVMRetValue& operator=(TVMByteArray value) {
this->SwitchToClass(kBytes, std::string(value.data, value.size));
return *this;
}
TVMRetValue& operator=(NDArray other) {
this->Clear();
type_code_ = kNDArrayContainer;
value_.v_handle = other.data_;
other.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
}
template<typename FType>
TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed());
}
TVMRetValue& operator=(Module m) {
this->SwitchToClass(kModuleHandle, m);
return *this;
}
  TVMRetValue& operator=(const TVMRetValue& other) {  // NOLINT(*)
this->Assign(other);
return *this;
}
TVMRetValue& operator=(const TVMArgValue& other) {
this->Assign(other);
return *this;
}
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
return *this;
}
/*!
* \brief Move the value back to front-end via C API.
* This marks the current container as null.
   * The managed resources are moved to the front-end, and
   * the front-end should take charge of managing them.
*
* \param ret_value The return value.
* \param ret_type_code The return type code.
*/
void MoveToCHost(TVMValue* ret_value,
int* ret_type_code) {
    // cannot move str; it needs special handling.
CHECK(type_code_ != kStr && type_code_ != kBytes);
*ret_value = value_;
*ret_type_code = type_code_;
type_code_ = kNull;
}
/*! \return The value field, if the data is POD */
const TVMValue& value() const {
CHECK(type_code_ != kNodeHandle &&
type_code_ != kFuncHandle &&
type_code_ != kModuleHandle &&
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_;
}
  // NodeRef related extensions: in tvm/packed_func_ext.h
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
private:
template<typename T>
void Assign(const T& other) {
switch (other.type_code()) {
case kStr: {
SwitchToClass<std::string>(kStr, other);
break;
}
case kBytes: {
SwitchToClass<std::string>(kBytes, other);
break;
}
case kFuncHandle: {
SwitchToClass<PackedFunc>(kFuncHandle, other);
break;
}
case kModuleHandle: {
SwitchToClass<Module>(kModuleHandle, other);
break;
}
case kNDArrayContainer: {
*this = other.operator NDArray();
break;
}
case kNodeHandle: {
SwitchToClass<std::shared_ptr<Node> >(
kNodeHandle, *other.template ptr<std::shared_ptr<Node> >());
break;
}
default: {
if (other.type_code() < kExtBegin) {
SwitchToPOD(other.type_code());
value_ = other.value_;
} else {
#if TVM_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
#else
this->Clear();
type_code_ = other.type_code();
value_.v_handle =
(*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
#endif
}
break;
}
}
}
// get the internal container.
void SwitchToPOD(int type_code) {
if (type_code_ != type_code) {
this->Clear();
type_code_ = type_code;
}
}
template<typename T>
void SwitchToClass(int type_code, T v) {
if (type_code_ != type_code) {
this->Clear();
type_code_ = type_code;
value_.v_handle = new T(v);
} else {
*static_cast<T*>(value_.v_handle) = v;
}
}
void Clear() {
if (type_code_ == kNull) return;
switch (type_code_) {
case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break;
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
}
if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
#else
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
#endif
}
type_code_ = kNull;
}
};
// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kStr: return "str";
case kBytes: return "bytes";
case kHandle: return "handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle";
case kArrayHandle: return "ArrayHandle";
case kTVMType: return "TVMType";
case kTVMContext: return "TVMContext";
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
os << TypeCode2Str(t.code);
if (t.code == kHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}
#endif
inline std::string TVMType2String(TVMType t) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os;
os << t;
return os.str();
#else
std::string repr = "";
repr += TypeCode2Str(t.code);
if (t.code == kHandle) return repr;
repr += std::to_string(static_cast<int>(t.bits));
if (t.lanes != 1) {
repr += "x" + std::to_string(static_cast<int>(t.lanes));
}
return repr;
#endif
}
inline TVMType String2TVMType(std::string s) {
TVMType t;
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, nullptr, 10));
}
return t;
}
inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
<< "not enough argument passed, "
<< num_args << " passed"
<< " but request arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]);
}
inline int TVMArgs::size() const {
return num_args;
}
inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
body_(args, rv);
}
inline PackedFunc::FType PackedFunc::body() const {
return body_;
}
// internal namespace
namespace detail {
template<bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
template<typename T, typename ...Args>
static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*)
f(I, std::forward<T>(value));
for_each_dispatcher<sizeof...(Args) == 0, (I+1), F>
::run(f, std::forward<Args>(args)...);
}
};
template<std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
static void run(const F& f) {} // NOLINT(*)
};
template<typename F, typename ...Args>
inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F>
::run(f, std::forward<Args>(args)...);
}
} // namespace detail
/*! \brief Argument setter for PackedFunc. */
class TVMArgsSetter {
public:
TVMArgsSetter(TVMValue* values, int* type_codes)
: values_(values), type_codes_(type_codes) {}
// setters for POD types
template<typename T,
typename = typename std::enable_if<
std::is_integral<T>::value>::type>
void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kDLInt;
}
void operator()(size_t i, uint64_t value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
CHECK_LE(value,
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
type_codes_[i] = kDLInt;
}
void operator()(size_t i, double value) const {
values_[i].v_float64 = value;
type_codes_[i] = kDLFloat;
}
void operator()(size_t i, std::nullptr_t value) const {
values_[i].v_handle = value;
type_codes_[i] = kNull;
}
void operator()(size_t i, const TVMArgValue& value) const {
values_[i] = value.value_;
type_codes_[i] = value.type_code_;
}
void operator()(size_t i, void* value) const {
values_[i].v_handle = value;
type_codes_[i] = kHandle;
}
void operator()(size_t i, DLTensor* value) const {
values_[i].v_handle = value;
type_codes_[i] = kArrayHandle;
}
void operator()(size_t i, TVMContext value) const {
values_[i].v_ctx = value;
type_codes_[i] = kTVMContext;
}
void operator()(size_t i, TVMType value) const {
values_[i].v_type = value;
type_codes_[i] = kTVMType;
}
void operator()(size_t i, const char* value) const {
values_[i].v_str = value;
type_codes_[i] = kStr;
}
// setters for container type
  // They must be references (instead of const ref)
  // to make sure they stay alive in the tuple (instead of getting converted)
void operator()(size_t i, const std::string& value) const { // NOLINT(*)
values_[i].v_str = value.c_str();
type_codes_[i] = kStr;
}
void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kBytes;
}
void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle;
}
template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
operator()(i, value.packed());
}
void operator()(size_t i, const Module& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<Module*>(&value);
type_codes_[i] = kModuleHandle;
}
void operator()(size_t i, const NDArray& value) const { // NOLINT(*)
values_[i].v_handle = value.data_;
type_codes_[i] = kNDArrayContainer;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr;
} else {
CHECK_NE(value.type_code(), kBytes) << "not handled.";
values_[i] = value.value_;
type_codes_[i] = value.type_code();
}
}
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
  // NodeRef related extensions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
private:
/*! \brief The values fields */
TVMValue* values_;
/*! \brief The type code fields */
int* type_codes_;
};
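/*!
 * \brief Example (a minimal sketch; normally done internally by
 *  PackedFunc::operator()): pack two arguments into the TVMValue and
 *  type-code arrays that the packed calling convention consumes.
 * \code
 *   TVMValue values[2];
 *   int type_codes[2];
 *   TVMArgsSetter setter(values, type_codes);
 *   setter(0, 42);    // stored as kDLInt
 *   setter(1, 2.5);   // stored as kDLFloat
 *   TVMArgs args(values, type_codes, 2);
 * \endcode
 */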
template<typename... Args>
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
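/*!
 * \brief Example (a minimal sketch): the variadic call above packs its
 *  arguments with TVMArgsSetter via detail::for_each, then invokes the
 *  stored body exactly once.
 * \code
 *   PackedFunc add([](TVMArgs args, TVMRetValue* rv) {
 *     int a = args[0];
 *     int b = args[1];
 *     *rv = a + b;
 *   });
 *   int c = add(1, 2);  // c == 3
 * \endcode
 */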
namespace detail {
template<typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
unpack_call_dispatcher<R, nleft - 1, index + 1, F>
::run(f, args_pack, rv,
std::forward<Args>(unpacked_args)...,
args_pack[index]);
}
};
template<typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
*rv = R(f(std::forward<Args>(unpacked_args)...));
}
};
template<int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
f(std::forward<Args>(unpacked_args)...);
}
};
template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}
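/*!
 * \brief Example (illustrative): unpack_call converts each TVMArgs entry
 *  to the parameter type expected by f and stores f's result into *rv.
 *  For f = [](int a, double b) { return a + b; },
 *  unpack_call<double, 2>(f, args, &rv) effectively expands to
 *  rv = double(f(args[0], args[1])).
 */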
template<typename R, typename ...Args>
inline R call_packed(const PackedFunc& pf, Args&& ...args) {
return R(pf(std::forward<Args>(args)...));
}
template<typename R>
struct typed_packed_call_dispatcher {
template<typename ...Args>
static inline R run(const PackedFunc& pf, Args&& ...args) {
return pf(std::forward<Args>(args)...);
}
};
template<>
struct typed_packed_call_dispatcher<void> {
template<typename ...Args>
static inline void run(const PackedFunc& pf, Args&& ...args) {
pf(std::forward<Args>(args)...);
}
};
} // namespace detail
template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)
: packed_(packed) {}
template<typename R, typename ...Args>
template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
});
}
template<typename R, typename ...Args>
inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
return detail::typed_packed_call_dispatcher<R>
::run(packed_, std::forward<Args>(args)...);
}
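/*!
 * \brief Example (a sketch; relies on the typed-lambda constructor declared
 *  earlier in this header): TypedPackedFunc exposes a typed signature on
 *  top of the packed calling convention.
 * \code
 *   TypedPackedFunc<int(int, int)> add([](int a, int b) {
 *     return a + b;
 *   });
 *   int c = add(1, 2);                // typed call; c == 3
 *   PackedFunc packed = add.packed(); // the underlying PackedFunc
 * \endcode
 */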
// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
struct TVMValueCast {
static T Apply(const TSrc* self) {
return self->template AsNodeRef<T>();
}
};
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};
} // namespace detail
template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0>
::Apply(this);
}
template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0>
::Apply(this);
}
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
    static_assert(extension_class_info<T>::code != 0,
                  "Need to have extension code");
type_codes_[i] = extension_class_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}
// extension type handling
template<typename T>
struct ExtTypeInfo {
static void destroy(void* handle) {
delete static_cast<T*>(handle);
}
static void* clone(void* handle) {
return new T(*static_cast<T*>(handle));
}
};
template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
  static_assert(code != 0,
                "require the extension_class_info trait to be declared with a non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
return ExtTypeVTable::RegisterInternal(code, vt);
}
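/*!
 * \brief Example (a sketch; MyExt and its code value are hypothetical):
 *  registering a user-defined extension type so it can travel through
 *  PackedFunc arguments. extension_class_info must be specialized with a
 *  non-zero code before Register_ is instantiated.
 * \code
 *   struct MyExt { int payload; };
 *   namespace tvm {
 *   namespace runtime {
 *   template<>
 *   struct extension_class_info<MyExt> {
 *     static const int code = 17;  // hypothetical non-zero code
 *   };
 *   }  // namespace runtime
 *   }  // namespace tvm
 *
 *   // At startup:
 *   tvm::runtime::ExtTypeVTable::Register_<MyExt>();
 * \endcode
 */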
// Implementation of Module::GetFunction.
// It lives in this file so that the full definition of PackedFunc is visible.
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
PackedFunc pf = node_->GetFunction(name, node_);
if (pf != nullptr) return pf;
if (query_imports) {
for (const Module& m : node_->imports_) {
pf = m.node_->GetFunction(name, m.node_);
if (pf != nullptr) return pf;
}
}
return pf;
}
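/*!
 * \brief Example (a sketch; the library and function names are
 *  hypothetical): fetch a function from a module, falling back to the
 *  module's imports when query_imports is true.
 * \code
 *   Module mod = Module::LoadFromFile("mylib.so");
 *   PackedFunc f = mod.GetFunction("myfunc", true);
 *   if (f != nullptr) {
 *     TVMRetValue rv = f(1, 2);
 *   }
 * \endcode
 */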
} // namespace runtime
} // namespace tvm
#endif // DGL_RUNTIME_PACKED_FUNC_H_