Unverified Commit 8801154b authored by VoVAllen, committed by GitHub

Merge pull request #1 from jermainewang/cpp

Cpp
parents b46abb09 b2c1c4fa
......@@ -2,6 +2,8 @@
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with batch processing
"""
import argparse
import numpy as np
......@@ -9,14 +11,15 @@ import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gcn_msg(src, edge):
return src['h']
return src
def gcn_reduce(node, msgs):
return {'h' : sum(msgs)}
return torch.sum(msgs, 1)
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
......@@ -25,14 +28,14 @@ class NodeApplyModule(nn.Module):
self.activation = activation
def forward(self, node):
h = self.linear(node['h'])
h = self.linear(node)
if self.activation:
h = self.activation(h)
return {'h' : h}
return h
class GCN(nn.Module):
def __init__(self,
nx_graph,
g,
in_feats,
n_hidden,
n_classes,
......@@ -40,7 +43,7 @@ class GCN(nn.Module):
activation,
dropout):
super(GCN, self).__init__()
self.g = DGLGraph(nx_graph)
self.g = g
self.dropout = dropout
# input layer
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
......@@ -50,31 +53,24 @@ class GCN(nn.Module):
# output layer
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features, train_nodes):
for n, feat in features.items():
self.g.nodes[n]['h'] = feat
def forward(self, features):
self.g.set_n_repr(features)
for layer in self.layers:
# apply dropout
if self.dropout:
self.g.nodes[n]['h'] = F.dropout(g.nodes[n]['h'], p=self.dropout)
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all(gcn_msg, gcn_reduce, layer)
return torch.cat([torch.unsqueeze(self.g.nodes[n]['h'], 0) for n in train_nodes])
return self.g.pop_n_repr()
def main(args):
# load and preprocess dataset
data = load_data(args)
# features of each samples
features = {}
labels = []
train_nodes = []
for n in data.graph.nodes():
features[n] = torch.FloatTensor(data.features[n, :])
if data.train_mask[n] == 1:
train_nodes.append(n)
labels.append(data.labels[n])
labels = torch.LongTensor(labels)
in_feats = data.features.shape[1]
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.ByteTensor(data.train_mask)
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = data.graph.number_of_edges()
......@@ -83,11 +79,13 @@ def main(args):
else:
cuda = True
torch.cuda.set_device(args.gpu)
features = {k : v.cuda() for k, v in features.items()}
features = features.cuda()
labels = labels.cuda()
mask = mask.cuda()
# create GCN model
model = GCN(data.graph,
g = DGLGraph(data.graph)
model = GCN(g,
in_feats,
args.n_hidden,
n_classes,
......@@ -107,9 +105,9 @@ def main(args):
if epoch >= 3:
t0 = time.time()
# forward
logits = model(features, train_nodes)
logits = model(features)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, labels)
loss = F.nll_loss(logp[mask], labels[mask])
optimizer.zero_grad()
loss.backward()
......@@ -130,7 +128,7 @@ if __name__ == '__main__':
help="gpu")
parser.add_argument("--lr", type=float, default=1e-3,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=10,
parser.add_argument("--n-epochs", type=int, default=20,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden gcn units")
......
......@@ -55,7 +55,7 @@ class GCN(nn.Module):
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all(fn.copy_src(), fn.sum(), layer, batchable=True)
self.g.update_all(fn.copy_src(), fn.sum(), layer)
return self.g.pop_n_repr()
def main(args):
......
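The net effect of the GCN changes above is that node features now live on the graph as one tensor and a full layer is a single `update_all` call. Below is a condensed sketch of that per-layer flow against the 0.x-era API used in this diff (`set_n_repr`, `update_all`, `pop_n_repr`); the toy graph and feature sizes are illustrative only.

```python
import networkx as nx
import torch
import torch.nn as nn
from dgl import DGLGraph

def gcn_msg(src, edge):
    return src                    # ship each source node's features along its edges

def gcn_reduce(node, msgs):
    return torch.sum(msgs, 1)     # sum the stacked incoming messages

class NodeApplyModule(nn.Module):
    """Per-node update: linear projection plus optional activation."""
    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node)
        return self.activation(h) if self.activation else h

g = DGLGraph(nx.path_graph(4).to_directed())   # toy graph; any networkx graph works
features = torch.randn(4, 5)                   # (N, in_feats) node features
layer = NodeApplyModule(5, 8, torch.relu)

g.set_n_repr(features)
g.update_all(gcn_msg, gcn_reduce, layer)       # message -> reduce -> node apply
h = g.pop_n_repr()                             # (N, 8) output features
```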
Community Detection with Graph Neural Networks (CDGNN)
============
Paper link: [https://arxiv.org/abs/1705.08415](https://arxiv.org/abs/1705.08415)
Author's code repo: [https://github.com/joanbruna/GNN_community](https://github.com/joanbruna/GNN_community)
This folder contains a DGL implementation of the CDGNN model.
An experiment on the Stochastic Block Model in default settings can be run with
```bash
python train.py
```
An experiment on the Stochastic Block Model in customized settings can be run with
```bash
python train.py --batch-size BATCH_SIZE --gpu GPU --n-communities N_COMMUNITIES --n-features N_FEATURES --n-graphs N_GRAPH --n-iterations N_ITERATIONS --n-layers N_LAYER --n-nodes N_NODE --model-path MODEL_PATH --radius RADIUS
```
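These flags map onto the training script included in this commit (`train.py`, shown further down). The following is a condensed sketch of what that script does with the default settings, assuming the `SBMMixture` dataset, the `gnn.GNN` model, and the `utils.cycle` helper from this folder:

```python
import torch as th
import torch.optim as optim
from torch.utils.data import DataLoader

from dgl.data import SBMMixture
import gnn
import utils

# Defaults: 6000 graphs, 1000 nodes, 2 communities, 30 layers of width 2, radius 3.
dataset = SBMMixture(6000, 1000, 2)
loader = utils.cycle(DataLoader(dataset, 1, shuffle=True,
                                collate_fn=dataset.collate_fn, drop_last=True))
feats = [1] + [2] * 30 + [2]
model = gnn.GNN(feats, 3, 2)
opt = optim.Adamax(model.parameters(), lr=0.04)

g, lg, deg_g, deg_lg, eid2nid = next(loader)   # graph, line graph, degrees, edge->node map
y_bar = model(g, lg, deg_g, deg_lg, eid2nid)   # (n_nodes, n_communities) logits
```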
......@@ -3,237 +3,93 @@ Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415
Deviations from paper:
- Addition of global aggregation operator.
- Message passing is equivalent to `A^j \cdot X`, instead of `\min(1, A^j) \cdot X`.
- Pm Pd
"""
# TODO self-loop?
# TODO in-place edit of node_reprs/edge_reprs in message_func/update_func?
# TODO batch-norm
import copy
import itertools
import dgl.graph as G
import dgl
import dgl.function as fn
import networkx as nx
import torch as th
import torch.nn as nn
import torch.nn.functional as F
class GLGModule(nn.Module):
__SHADOW__ = 'shadow'
class GNNModule(nn.Module):
def __init__(self, in_feats, out_feats, radius):
super().__init__()
self.out_feats = out_feats
self.radius = radius
new_linear = lambda: nn.Linear(in_feats, out_feats)
new_linear = lambda: nn.Linear(in_feats, out_feats * 2)
new_module_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
self.theta_x, self.theta_y, self.theta_deg, self.theta_global = \
new_linear(), new_linear(), new_linear(), new_linear()
self.theta_x, self.theta_deg, self.theta_y = \
new_linear(), new_linear(), new_linear()
self.theta_list = new_module_list()
self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global = \
new_linear(), new_linear(), new_linear(), new_linear()
self.gamma_y, self.gamma_deg, self.gamma_x = \
new_linear(), new_linear(), new_linear()
self.gamma_list = new_module_list()
@staticmethod
def copy(which):
if which == 'src':
return lambda src, trg, _: src.copy()
elif which == 'trg':
return lambda src, trg, _: trg.copy()
@staticmethod
def aggregate(msg_fld, trg_fld, normalize=False):
def a(node_reprs, edge_reprs):
node_reprs = node_reprs.copy()
node_reprs[trg_fld] = sum(msg[msg_fld] for msg in edge_reprs)
if normalize:
node_reprs[trg_fld] /= len(edge_reprs)
return node_reprs
return a
@staticmethod
def pull(msg_fld, trg_fld):
def p(node_reprs, edge_reprs):
node_reprs = node_reprs.copy()
node_reprs[trg_fld] = edge_reprs[0][msg_fld]
return node_reprs
return p
def local_aggregate(self, g):
def step():
g.register_message_func(self.copy('src'), g.edges)
g.register_update_func(self.aggregate('x', 'x'), g.nodes)
g.update_all()
step()
for reprs in g.nodes.values():
reprs[0] = reprs['x']
for i in range(1, self.radius):
for j in range(2 ** (i - 1)):
step()
for reprs in g.nodes.values():
reprs[i] = reprs['x']
@staticmethod
def global_aggregate(g):
shadow = GLGModule.__SHADOW__
copy, aggregate, pull = GLGModule.copy, GLGModule.aggregate, GLGModule.pull
node_list = list(g.nodes)
uv_list = [(node, shadow) for node in g.nodes]
vu_list = [(shadow, node) for node in g.nodes]
g.add_node(shadow) # TODO context manager
tuple(itertools.starmap(g.add_edge, uv_list))
g.register_message_func(copy('src'), uv_list)
g.register_update_func(aggregate('x', 'global', normalize=True), (shadow,))
g.update_to(shadow)
tuple(itertools.starmap(g.add_edge, vu_list))
g.register_message_func(copy('src'), vu_list)
g.register_update_func(pull('global', 'global'), node_list)
g.update_from(shadow)
g.remove_node(shadow)
@staticmethod
def multiply_by_degree(g):
g.register_message_func(lambda *args: None, g.edges)
def update_func(node_reprs, _):
node_reprs = node_reprs.copy()
node_reprs['deg'] = node_reprs['x'] * node_reprs['degree']
return node_reprs
g.register_update_func(update_func, g.nodes)
g.update_all()
@staticmethod
def message_func(src, trg, _):
return {'y' : src['x']}
def update_func(self, which):
if which == 'node':
linear_x, linear_y, linear_deg, linear_global = \
self.theta_x, self.theta_y, self.theta_deg, self.theta_global
linear_list = self.theta_list
elif which == 'edge':
linear_x, linear_y, linear_deg, linear_global = \
self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global
linear_list = self.gamma_list
def u(node_reprs, edge_reprs):
edge_reprs = filter(lambda x: x is not None, edge_reprs)
y = sum(x['y'] for x in edge_reprs)
node_reprs = node_reprs.copy()
node_reprs['x'] = linear_x(node_reprs['x']) \
+ linear_y(y) \
+ linear_deg(node_reprs['deg']) \
+ linear_global(node_reprs['global']) \
+ sum(linear(node_reprs[i]) \
for i, linear in enumerate(linear_list))
return node_reprs
return u
def forward(self, g, lg, glg):
self.local_aggregate(g)
self.local_aggregate(lg)
self.global_aggregate(g)
self.global_aggregate(lg)
self.multiply_by_degree(g)
self.multiply_by_degree(lg)
# TODO efficiency
for node, reprs in g.nodes.items():
glg.nodes[node].update(reprs)
for node, reprs in lg.nodes.items():
glg.nodes[node].update(reprs)
glg.register_message_func(self.message_func, glg.edges)
glg.register_update_func(self.update_func('node'), g.nodes)
glg.register_update_func(self.update_func('edge'), lg.nodes)
glg.update_all()
# TODO efficiency
for node, reprs in g.nodes.items():
reprs.update(glg.nodes[node])
for node, reprs in lg.nodes.items():
reprs.update(glg.nodes[node])
self.bn_x = nn.BatchNorm1d(out_feats)
self.bn_y = nn.BatchNorm1d(out_feats)
def aggregate(self, g, z):
z_list = []
g.set_n_repr(z)
g.update_all(fn.copy_src(), fn.sum())
z_list.append(g.get_n_repr())
for i in range(self.radius - 1):
for j in range(2 ** i):
g.update_all(fn.copy_src(), fn.sum())
z_list.append(g.get_n_repr())
return z_list
class GNNModule(nn.Module):
def __init__(self, in_feats, out_feats, order, radius):
super().__init__()
self.module_list = nn.ModuleList([GLGModule(in_feats, out_feats, radius)
for i in range(order)])
def forward(self, g, lg, x, y, deg_g, deg_lg, eid2nid):
xy = F.embedding(eid2nid, x)
x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]
g.set_e_repr(y)
g.update_all(fn.copy_edge(), fn.sum())
yx = g.get_n_repr()
def forward(self, pairs, fusions):
for module, (g, lg), glg in zip(self.module_list, pairs, fusions):
module(g, lg, glg)
for lhs, rhs in zip(pairs[:-1], pairs[1:]):
for node, reprs in lhs[1].nodes.items():
x_rhs = reprs['x']
reprs['x'] = x_rhs + rhs[0].nodes[node]['x']
rhs[0].nodes[node]['x'] += x_rhs
x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
x = self.bn_x(x[:, :self.out_feats] + F.relu(x[:, self.out_feats:]))
y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]
lg.set_n_repr(xy)
lg.update_all(fn.copy_src(), fn.sum())
xy = lg.get_n_repr()
y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
y = self.bn_y(y[:, :self.out_feats] + F.relu(y[:, self.out_feats:]))
return x, y
class GNN(nn.Module):
def __init__(self, feats, order, radius, n_classes):
super().__init__()
self.order = order
self.linear = nn.Linear(feats[-1], n_classes)
self.module_list = nn.ModuleList([GNNModule(in_feats, out_feats, order, radius)
for in_feats, out_feats in zip(feats[:-1], feats[1:])])
@staticmethod
def line_graph(g):
lg = nx.line_graph(g)
glg = nx.DiGraph()
glg.add_nodes_from(g.nodes)
glg.add_nodes_from(lg.nodes)
for u, v in g.edges:
glg.add_edge(u, (u, v))
glg.add_edge((u, v), u)
glg.add_edge(v, (u, v))
glg.add_edge((u, v), v)
return lg, glg
@staticmethod
def nx2dgl(g):
deg_dict = dict(nx.degree(g))
z = sum(deg_dict.values())
dgl_g = G.DGLGraph(g)
for node, reprs in dgl_g.nodes.items():
reprs['degree'] = deg_dict[node]
reprs['x'] = th.full((1, 1), reprs['degree'] / z)
reprs.update(g.nodes[node])
return dgl_g
def forward(self, g):
def __init__(self, feats, radius, n_classes):
"""
Parameters
----------
g : networkx.DiGraph
"""
pair_list, glg_list = [], []
dgl_g = self.nx2dgl(g)
origin = dgl_g
for i in range(self.order):
lg, glg = self.line_graph(g)
dgl_lg = self.nx2dgl(lg)
pair_list.append((dgl_g, copy.deepcopy(dgl_lg)))
glg_list.append(G.DGLGraph(glg))
g = lg
dgl_g = dgl_lg
super(GNN, self).__init__()
self.linear = nn.Linear(feats[-1], n_classes)
self.module_list = nn.ModuleList([GNNModule(m, n, radius)
for m, n in zip(feats[:-1], feats[1:])])
for module in self.module_list:
module(pair_list, glg_list)
def forward(self, g, lg, deg_g, deg_lg, eid2nid):
def normalize(x):
x = x - th.mean(x, 0)
x = x / th.sqrt(th.mean(x * x, 0))
return x
return self.linear(th.cat([reprs['x'] for reprs in origin.nodes.values()], 0))
x = normalize(deg_g)
y = normalize(deg_lg)
for module in self.module_list:
x, y = module(g, lg, x, y, deg_g, deg_lg, eid2nid)
return self.linear(x)
"""
By Minjie
"""
from __future__ import division
import math
import numpy as np
import scipy.sparse as sp
import networkx as nx
import matplotlib.pyplot as plt
class SSBM:
def __init__(self, n, k, a=10.0, b=2.0, regime='constant', rng=None):
"""Symmetric Stochastic Block Model.
n - number of nodes
k - number of communities
a - probability scale for intra-community edge
b - probability scale for inter-community edge
regime - If "logaritm", this generates SSBM(n, k, a*log(n)/n, b*log(n)/n)
If "constant", this generates SSBM(n, k, a/n, b/n)
If "mixed", this generates SSBM(n, k, a*log(n)/n, b/n)
"""
self.n = n
self.k = k
if regime == 'logarithm':
if math.sqrt(a) - math.sqrt(b) >= math.sqrt(k):
print('SSBM model with possible exact recovery.')
else:
print('SSBM model with impossible exact recovery.')
self.a = a * math.log(n) / n
self.b = b * math.log(n) / n
elif regime == 'constant':
snr = (a - b) ** 2 / (k * (a + (k - 1) * b))
if snr > 1:
print('SSBM model with possible detection.')
else:
print('SSBM model that may not have detection (snr=%.5f).' % snr)
self.a = a / n
self.b = b / n
elif regime == 'mixed':
self.a = a * math.log(n) / n
self.b = b / n
else:
raise ValueError('Unknown regime: %s' % regime)
if rng is None:
self.rng = np.random.RandomState()
else:
self.rng = rng
self._graph = None
def generate(self):
self.generate_communities()
print('Finished generating communities.')
self.generate_edges()
print('Finished generating edges.')
def generate_communities(self):
nodes = list(range(self.n))
size = self.n // self.k
self.block_size = size
self.comm2node = [nodes[i*size:(i+1)*size] for i in range(self.k)]
self.node2comm = [nid // size for nid in range(self.n)]
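# Example: n=6, k=2 gives comm2node = [[0, 1, 2], [3, 4, 5]] and
# node2comm = [0, 0, 0, 1, 1, 1].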
def generate_edges(self):
# TODO: dedup edges
us = []
vs = []
# generate intra-comm edges
for i in range(self.k):
sp_mat = sp.random(self.block_size, self.block_size,
density=self.a,
random_state=self.rng,
data_rvs=lambda l: np.ones(l))
u = sp_mat.row + i * self.block_size
v = sp_mat.col + i * self.block_size
us.append(u)
vs.append(v)
# generate inter-comm edges
for i in range(self.k):
for j in range(self.k):
if i == j:
continue
sp_mat = sp.random(self.block_size, self.block_size,
density=self.b,
random_state=self.rng,
data_rvs=lambda l: np.ones(l))
u = sp_mat.row + i * self.block_size
v = sp_mat.col + j * self.block_size
us.append(u)
vs.append(v)
us = np.hstack(us)
vs = np.hstack(vs)
self.sp_mat = sp.coo_matrix((np.ones(us.shape[0]), (us, vs)), shape=(self.n, self.n))
@property
def graph(self):
if self._graph is None:
self._graph = nx.from_scipy_sparse_matrix(self.sp_mat, create_using=nx.DiGraph())
return self._graph
def plot(self):
x = self.sp_mat.row
y = self.sp_mat.col
plt.scatter(x, y, s=0.5, marker='.', c='k')
plt.savefig('ssbm-%d-%d.pdf' % (self.n, self.k))
plt.clf()
# plot out degree distribution
out_degree = [d for _, d in self.graph.out_degree().items()]
plt.hist(out_degree, 100, normed=True)
plt.savefig('ssbm-%d-%d_out_degree.pdf' % (self.n, self.k))
plt.clf()
if __name__ == '__main__':
n = 1000
k = 10
ssbm = SSBM(n, k, regime='mixed', a=4, b=1)
ssbm.generate()
g = ssbm.graph
print('#nodes:', g.number_of_nodes())
print('#edges:', g.number_of_edges())
#ssbm.plot()
#lg = nx.line_graph(g)
# plot degree distribution
#degree = [d for _, d in lg.degree().items()]
#plt.hist(degree, 100, normed=True)
#plt.savefig('lg<ssbm-%d-%d>_degree.pdf' % (n, k))
#plt.clf()
"""
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --order 3 --radius 3
"""
import argparse
import networkx as nx
import torch as th
import torch.nn as nn
import torch.optim as optim
import gnn
parser = argparse.ArgumentParser()
parser.add_argument('--features', nargs='+', type=int)
parser.add_argument('--gpu', type=int)
parser.add_argument('--n-classes', type=int)
parser.add_argument('--n-iterations', type=int)
parser.add_argument('--n-nodes', type=int)
parser.add_argument('--order', type=int)
parser.add_argument('--radius', type=int)
args = parser.parse_args()
if args.gpu < 0:
cuda = False
else:
cuda = True
th.cuda.set_device(args.gpu)
g = nx.barabasi_albert_graph(args.n_nodes, 1).to_directed() # TODO SBM
y = th.multinomial(th.ones(args.n_classes), args.n_nodes, replacement=True)
network = gnn.GNN(args.features, args.order, args.radius, args.n_classes)
if cuda:
network.cuda()
ce = nn.CrossEntropyLoss()
adam = optim.Adam(network.parameters())
for i in range(args.n_iterations):
y_bar = network(g)
loss = ce(y_bar, y)
adam.zero_grad()
loss.backward()
adam.step()
print('[iteration %d]loss %f' % (i, loss))
from __future__ import division
import argparse
from itertools import permutations
import networkx as nx
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import dgl
from dgl.data import SBMMixture
import gnn
import utils
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int,
help='Batch size', default=1)
parser.add_argument('--gpu', type=int,
help='GPU', default=-1)
parser.add_argument('--n-communities', type=int,
help='Number of communities', default=2)
parser.add_argument('--n-features', type=int,
help='Number of features per layer', default=2)
parser.add_argument('--n-graphs', type=int,
help='Number of graphs', default=6000)
parser.add_argument('--n-iterations', type=int,
help='Number of iterations', default=10000)
parser.add_argument('--n-layers', type=int,
help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int,
help='Number of nodes', default=1000)
parser.add_argument('--model-path', type=str,
help='Path to the checkpoint of model', default='model')
parser.add_argument('--radius', type=int,
help='Radius', default=3)
args = parser.parse_args()
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
dataset = SBMMixture(args.n_graphs, args.n_nodes, args.n_communities)
loader = utils.cycle(DataLoader(dataset, args.batch_size,
shuffle=True, collate_fn=dataset.collate_fn, drop_last=True))
ones = th.ones(args.n_nodes // args.n_communities)
y_list = [th.cat([th.cat([x * ones for x in p])] * args.batch_size).long().to(dev)
for p in permutations(range(args.n_communities))]
feats = [1] + [args.n_features] * args.n_layers + [args.n_communities]
model = gnn.GNN(feats, args.radius, args.n_communities).to(dev)
opt = optim.Adamax(model.parameters(), lr=0.04)
for i in range(args.n_iterations):
g, lg, deg_g, deg_lg, eid2nid = next(loader)
deg_g = deg_g.to(dev)
deg_lg = deg_lg.to(dev)
eid2nid = eid2nid.to(dev)
y_bar = model(g, lg, deg_g, deg_lg, eid2nid)
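# Community labels are only defined up to permutation, so score against every
# permutation of the ground-truth assignment and keep the best (permutation-invariant loss).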
loss = min(F.cross_entropy(y_bar, y) for y in y_list)
opt.zero_grad()
loss.backward()
opt.step()
placeholder = '0' * (len(str(args.n_iterations)) - len(str(i)))
print('[iteration %s%d]loss %f' % (placeholder, i, loss))
th.save(model.state_dict(), args.model_path)
def cycle(loader):
while True:
for x in loader:
yield x
......@@ -8,23 +8,17 @@ from torch.utils.data import DataLoader
import dgl
import dgl.data as data
import dgl.ndarray as nd
from tree_lstm import TreeLSTM
def _batch_to_cuda(batch):
return data.SSTBatch(graph=batch.graph,
nid_with_word = batch.nid_with_word.cuda(),
wordid = batch.wordid.cuda(),
label = batch.label.cuda())
import dgl.context as ctx
def tensor_topo_traverse(g, cuda, args):
n = g.number_of_nodes()
if cuda:
adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
adjmat = g._graph.adjacency_matrix().get(nd.gpu(args.gpu))
mask = th.ones((n, 1)).cuda()
else:
adjmat = g.cached_graph.adjmat().get(ctx.cpu())
adjmat = g._graph.adjacency_matrix().get(nd.cpu())
mask = th.ones((n, 1))
degree = th.spmm(adjmat, mask)
while th.sum(mask) != 0.:
......@@ -39,10 +33,17 @@ def main(args):
cuda = args.gpu >= 0
if cuda:
th.cuda.set_device(args.gpu)
def _batcher(trees):
bg = dgl.batch(trees)
if cuda:
reprs = bg.get_n_repr()
reprs = {key : val.cuda() for key, val in reprs.items()}
bg.set_n_repr(reprs)
return bg
trainset = data.SST()
train_loader = DataLoader(dataset=trainset,
batch_size=args.batch_size,
collate_fn=data.SST.batcher,
collate_fn=_batcher,
shuffle=False,
num_workers=0)
#testset = data.SST(mode='test')
......@@ -69,18 +70,15 @@ def main(args):
dur = []
for epoch in range(args.epochs):
t_epoch = time.time()
for step, batch in enumerate(train_loader):
g = batch.graph
if cuda:
batch = _batch_to_cuda(batch)
for step, graph in enumerate(train_loader):
if step >= 3:
t0 = time.time()
label = graph.pop_n_repr('y')
# traverse graph
giter = list(tensor_topo_traverse(g, False, args))
logits = model(batch, zero_initializer, iterator=giter, train=True)
giter = list(tensor_topo_traverse(graph, False, args))
logits = model(graph, zero_initializer, iterator=giter, train=True)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label)
loss = F.nll_loss(logp, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
......@@ -89,11 +87,11 @@ def main(args):
if step > 0 and step % args.log_every == 0:
pred = th.argmax(logits, 1)
acc = th.sum(th.eq(batch.label, pred))
acc = th.sum(th.eq(label, pred))
mean_dur = np.mean(dur)
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
"Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
epoch, step, loss.item(), acc.item()/len(batch.label),
epoch, step, loss.item(), acc.item() / len(label),
mean_dur, args.batch_size / mean_dur))
print("Epoch time(s):", time.time() - t_epoch)
......
......@@ -10,23 +10,7 @@ import torch as th
import torch.nn as nn
import torch.nn.functional as F
def topological_traverse(G):
indegree_map = {v: d for v, d in G.in_degree() if d > 0}
# These nodes have zero indegree and ready to be returned.
zero_indegree = [v for v, d in G.in_degree() if d == 0]
while True:
yield zero_indegree
next_zero_indegree = []
while zero_indegree:
node = zero_indegree.pop()
for _, child in G.edges(node):
indegree_map[child] -= 1
if indegree_map[child] == 0:
next_zero_indegree.append(child)
del indegree_map[child]
if len(next_zero_indegree) == 0:
break
zero_indegree = next_zero_indegree
import dgl
class ChildSumTreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
......@@ -39,7 +23,7 @@ class ChildSumTreeLSTMCell(nn.Module):
self.ut = 0.
def message_func(self, src, edge):
return src
return {'h' : src['h'], 'c' : src['c']}
def reduce_func(self, node, msgs):
# equation (2)
......@@ -83,13 +67,13 @@ class TreeLSTM(nn.Module):
else:
raise RuntimeError('Unknown cell type:', cell_type)
def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
def forward(self, graph, zero_initializer, h=None, c=None, iterator=None, train=True):
"""Compute tree-lstm prediction given a batch.
Parameters
----------
batch : dgl.data.SSTBatch
The data batch.
graph : dgl.DGLGraph
The batched trees.
zero_initializer : callable
Function to return zero value tensor.
h : Tensor, optional
......@@ -104,15 +88,17 @@ class TreeLSTM(nn.Module):
logits : Tensor
The prediction of each node.
"""
g = batch.graph
g = graph
n = g.number_of_nodes()
g.register_message_func(self.cell.message_func, batchable=True)
g.register_reduce_func(self.cell.reduce_func, batchable=True)
g.register_apply_node_func(self.cell.apply_func, batchable=True)
g.register_message_func(self.cell.message_func)
g.register_reduce_func(self.cell.reduce_func)
g.register_apply_node_func(self.cell.apply_func)
# feed embedding
embeds = self.embedding(batch.wordid)
x = zero_initializer((n, self.x_size))
x = x.index_copy(0, batch.nid_with_word, embeds)
wordid = g.pop_n_repr('x')
mask = (wordid != dgl.data.SST.PAD_WORD)
wordid = wordid * mask.long()
embeds = self.embedding(wordid)
x = embeds * th.unsqueeze(mask, 1).float()
if h is None:
h = zero_initializer((n, self.h_size))
h_tild = zero_initializer((n, self.h_size))
......
// DGL Graph interface
#ifndef DGL_DGLGRAPH_H_
#define DGL_DGLGRAPH_H_
#include <stdint.h>
#include "runtime/ndarray.h"
namespace dgl {
typedef uint64_t dgl_id_t;
typedef tvm::runtime::NDArray IdArray;
typedef tvm::runtime::NDArray DegreeArray;
typedef tvm::runtime::NDArray BoolArray;
class Graph;
class GraphOp;
struct Subgraph;
/*!
* \brief Base dgl graph class.
*
* DGL's graph is directed. Vertices are integers enumerated from zero. Edges
* are uniquely identified by the two endpoints. Multi-edge is currently not
* supported.
*
* Removal of vertices/edges is not allowed. Instead, the graph can only be "cleared"
* by removing all the vertices and edges.
*
 * When calling functions supporting multiple edges (e.g. AddEdges, HasEdges),
 * the input edges are represented by two id arrays for source and destination
 * vertex ids. In the general case, the two arrays should have the same length.
 * If the length of the src id array is one, it represents one-to-many connections.
 * If the length of the dst id array is one, it represents many-to-one connections.
*/
class Graph {
public:
/* \brief structure used to represent a list of edges */
typedef struct {
/* \brief the two endpoints and the id of the edge */
IdArray src, dst, id;
} EdgeArray;
/*! \brief default constructor */
Graph() {}
/*! \brief default copy constructor */
Graph(const Graph& other) = default;
#ifndef _MSC_VER
/*! \brief default move constructor */
Graph(Graph&& other) = default;
#else
Graph(Graph&& other) {
adjlist_ = other.adjlist_;
reverse_adjlist_ = other.reverse_adjlist_;
all_edges_src_ = other.all_edges_src_;
all_edges_dst_ = other.all_edges_dst_;
read_only_ = other.read_only_;
num_edges_ = other.num_edges_;
other.clear();
}
#endif // _MSC_VER
/*! \brief default assignment operator */
Graph& operator=(const Graph& other) = default;
/*! \brief default destructor */
~Graph() = default;
/*!
* \brief Add vertices to the graph.
* \note Since vertices are integers enumerated from zero, only the number of
* vertices to be added needs to be specified.
* \param num_vertices The number of vertices to be added.
*/
void AddVertices(uint64_t num_vertices);
/*!
* \brief Add one edge to the graph.
* \param src The source vertex.
* \param dst The destination vertex.
*/
void AddEdge(dgl_id_t src, dgl_id_t dst);
/*!
* \brief Add edges to the graph.
* \param src_ids The source vertex id array.
* \param dst_ids The destination vertex id array.
*/
void AddEdges(IdArray src_ids, IdArray dst_ids);
/*!
* \brief Clear the graph. Remove all vertices/edges.
*/
void Clear() {
adjlist_.clear();
reverse_adjlist_.clear();
all_edges_src_.clear();
all_edges_dst_.clear();
read_only_ = false;
num_edges_ = 0;
}
/*! \return the number of vertices in the graph.*/
uint64_t NumVertices() const {
return adjlist_.size();
}
/*! \return the number of edges in the graph.*/
uint64_t NumEdges() const {
return num_edges_;
}
/*! \return true if the given vertex is in the graph.*/
bool HasVertex(dgl_id_t vid) const {
return vid < NumVertices();
}
/*! \return a 0-1 array indicating whether the given vertices are in the graph.*/
BoolArray HasVertices(IdArray vids) const;
/*! \return true if the given edge is in the graph.*/
bool HasEdge(dgl_id_t src, dgl_id_t dst) const;
/*! \return a 0-1 array indicating whether the given edges are in the graph.*/
BoolArray HasEdges(IdArray src_ids, IdArray dst_ids) const;
/*!
* \brief Find the predecessors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \return the predecessor id array.
*/
IdArray Predecessors(dgl_id_t vid, uint64_t radius = 1) const;
/*!
* \brief Find the successors of a vertex.
* \param vid The vertex id.
* \param radius The radius of the neighborhood. Default is immediate neighbor (radius=1).
* \return the successor id array.
*/
IdArray Successors(dgl_id_t vid, uint64_t radius = 1) const;
/*!
* \brief Get the edge id using the two endpoints
 * \note Edges are associated with an integer id starting from zero.
* The id is assigned when the edge is being added to the graph.
* \param src The source vertex.
* \param dst The destination vertex.
* \return the edge id.
*/
dgl_id_t EdgeId(dgl_id_t src, dgl_id_t dst) const;
/*!
* \brief Get the edge id using the two endpoints
 * \note Edges are associated with an integer id starting from zero.
* The id is assigned when the edge is being added to the graph.
* \return the edge id array.
*/
IdArray EdgeIds(IdArray src, IdArray dst) const;
/*!
* \brief Get the in edges of the vertex.
* \note The returned dst id array is filled with vid.
* \param vid The vertex id.
* \return the edges
*/
EdgeArray InEdges(dgl_id_t vid) const;
/*!
* \brief Get the in edges of the vertices.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray InEdges(IdArray vids) const;
/*!
* \brief Get the out edges of the vertex.
* \note The returned src id array is filled with vid.
* \param vid The vertex id.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray OutEdges(dgl_id_t vid) const;
/*!
* \brief Get the out edges of the vertices.
* \param vids The vertex id array.
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray OutEdges(IdArray vids) const;
/*!
* \brief Get all the edges in the graph.
 * \note If sorted is true, the returned edge list is sorted by their src and
 * dst ids. Otherwise, they are in edge id order.
* \param sorted Whether the returned edge list is sorted by their src and dst ids
* \return the id arrays of the two endpoints of the edges.
*/
EdgeArray Edges(bool sorted = false) const;
/*!
* \brief Get the in degree of the given vertex.
* \param vid The vertex id.
* \return the in degree
*/
uint64_t InDegree(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
return reverse_adjlist_[vid].succ.size();
}
/*!
* \brief Get the in degrees of the given vertices.
 * \param vids The vertex id array.
* \return the in degree array
*/
DegreeArray InDegrees(IdArray vids) const;
/*!
* \brief Get the out degree of the given vertex.
* \param vid The vertex id.
* \return the out degree
*/
uint64_t OutDegree(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
return adjlist_[vid].succ.size();
}
/*!
* \brief Get the out degrees of the given vertices.
 * \param vids The vertex id array.
* \return the out degree array
*/
DegreeArray OutDegrees(IdArray vids) const;
/*!
* \brief Construct the induced subgraph of the given vertices.
*
* The induced subgraph is a subgraph formed by specifying a set of vertices V' and then
* selecting all of the edges from the original graph that connect two vertices in V'.
*
 * Vertices and edges in the original graph will be "reindexed" to a local index. The local
 * index of the vertices preserves the order of the given id array, while the local index
 * of the edges preserves the index order in the original graph. Vertices not in the
 * original graph are ignored.
*
* The result subgraph is read-only.
*
* \param vids The vertices in the subgraph.
* \return the induced subgraph
*/
Subgraph VertexSubgraph(IdArray vids) const;
/*!
* \brief Construct the induced edge subgraph of the given edges.
*
* The induced edges subgraph is a subgraph formed by specifying a set of edges E' and then
* selecting all of the nodes from the original graph that are endpoints in E'.
*
 * Vertices and edges in the original graph will be "reindexed" to a local index. The local
 * index of the edges preserves the order of the given id array, while the local index
 * of the vertices preserves the index order in the original graph. Edges not in the
 * original graph are ignored.
*
* The result subgraph is read-only.
*
 * \param src The source vertices of the edges in the subgraph.
 * \param dst The destination vertices of the edges in the subgraph.
* \return the induced edge subgraph
*/
Subgraph EdgeSubgraph(IdArray src, IdArray dst) const;
/*!
* \brief Return a new graph with all the edges reversed.
*
* The returned graph preserves the vertex and edge index in the original graph.
*
* \return the reversed graph
*/
Graph Reverse() const;
protected:
friend class GraphOp;
/*! \brief Internal edge list type */
struct EdgeList {
/*! \brief successor vertex list */
std::vector<dgl_id_t> succ;
/*! \brief edge id list */
std::vector<dgl_id_t> edge_id;
};
typedef std::vector<EdgeList> AdjacencyList;
/*! \brief adjacency list using vector storage */
AdjacencyList adjlist_;
/*! \brief reverse adjacency list using vector storage */
AdjacencyList reverse_adjlist_;
/*! \brief all edges' src endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_src_;
/*! \brief all edges' dst endpoints in their edge id order */
std::vector<dgl_id_t> all_edges_dst_;
/*! \brief read only flag */
bool read_only_ = false;
/*! \brief number of edges */
uint64_t num_edges_ = 0;
};
/*! \brief Subgraph data structure */
struct Subgraph {
/*! \brief The graph. */
Graph graph;
/*!
* \brief The induced vertex ids.
* \note This is also a map from the new vertex id to the vertex id in the parent graph.
*/
IdArray induced_vertices;
/*!
* \brief The induced edge ids.
* \note This is also a map from the new edge id to the edge id in the parent graph.
*/
IdArray induced_edges;
};
} // namespace dgl
#endif // DGL_DGLGRAPH_H_
// Graph operations
#ifndef DGL_GRAPH_OP_H_
#define DGL_GRAPH_OP_H_
#include "graph.h"
namespace dgl {
class GraphOp {
public:
/*!
* \brief Return the line graph.
*
* If i~j and j~i are two edges in original graph G, then
* (i,j)~(j,i) and (j,i)~(i,j) are the "backtracking" edges on
* the line graph.
*
* \param graph The input graph.
* \param backtracking Whether the backtracking edges are included or not
* \return the line graph
*/
static Graph LineGraph(const Graph* graph, bool backtracking);
/*!
* \brief Return a disjoint union of the input graphs.
*
* The new graph will include all the nodes/edges in the given graphs.
 * Nodes/Edges will be relabeled by adding the cumulative sum of the previous graph sizes
 * in the given sequence order. For example, given input [g1, g2, g3] with 5, 6, and 7
 * nodes respectively, node#2 of g2 becomes node#7 in the result graph. Edge ids are
 * re-assigned similarly.
*
* \param graphs A list of input graphs to be unioned.
* \return the disjoint union of the graphs
*/
static Graph DisjointUnion(std::vector<const Graph*> graphs);
/*!
* \brief Partition the graph into several subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
* into num graphs. This requires the given number of partitions to evenly
 * divide the number of nodes in the graph.
*
* \param graph The graph to be partitioned.
* \param num The number of partitions.
* \return a list of partitioned graphs
*/
static std::vector<Graph> DisjointPartitionByNum(const Graph* graph, int64_t num);
/*!
* \brief Partition the graph into several subgraphs.
*
* This is a reverse operation of DisjointUnion. The graph will be partitioned
 * based on the given sizes. This requires that the sum of the given sizes equals
 * the number of nodes in the graph.
*
* \param graph The graph to be partitioned.
 * \param sizes The sizes of the partitions.
* \return a list of partitioned graphs
*/
static std::vector<Graph> DisjointPartitionBySizes(const Graph* graph, IdArray sizes);
};
} // namespace dgl
#endif // DGL_GRAPH_OP_H_
# C API and runtime
Borrowed and adapted from the TVM project.
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/c_backend_api.h
* \brief TVM runtime backend API.
*
 * The functions defined in this header are intended to be
 * used by compiled tvm operators; users usually do not need to use these
 * functions directly.
*/
#ifndef TVM_RUNTIME_C_BACKEND_API_H_
#define TVM_RUNTIME_C_BACKEND_API_H_
#include "c_runtime_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// Backend related functions.
/*!
* \brief Backend function for modules to get function
 * from its environment mod_node (its imports and global functions).
 * The user should not call TVMFuncFree on func.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Backend function to register system-wide library symbol.
*
* \param name The name of the symbol
* \param ptr The symbol address.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
/*!
* \brief Backend function to allocate temporal workspace.
*
 * \note The allocated space is guaranteed to be aligned to kTempAllocaAlignment.
*
* \param nbytes The size of the space requested.
 * \param device_type The device type on which the space will be allocated.
 * \param device_id The device id on which the space will be allocated.
* \param dtype_code_hint The type code of the array elements. Only used in
* certain backends such as OpenGL.
* \param dtype_bits_hint The type bits of the array elements. Only used in
* certain backends such as OpenGL.
* \return nullptr when error is thrown, a valid ptr if success
*/
TVM_DLL void* TVMBackendAllocWorkspace(int device_type,
int device_id,
uint64_t nbytes,
int dtype_code_hint,
int dtype_bits_hint);
/*!
* \brief Backend function to free temporal workspace.
*
* \param ptr The result allocated space pointer.
 * \param device_type The device type on which the space was allocated.
 * \param device_id The device id on which the space was allocated.
* \return 0 when no error is thrown, -1 when failure happens
*
* \sa TVMBackendAllocWorkspace
*/
TVM_DLL int TVMBackendFreeWorkspace(int device_type,
int device_id,
void* ptr);
/*!
* \brief Environment for TVM parallel task.
*/
typedef struct {
/*!
* \brief Auxiliary used for synchronization
*/
void* sync_handle;
/*! \brief total number of tasks */
int32_t num_task;
} TVMParallelGroupEnv;
/*!
* \brief The callback function to execute a parallel lambda
* \param task_id the task id of the function.
 * \param penv The parallel environment backing the execution.
* \param cdata The supporting closure data.
*/
typedef int (*FTVMParallelLambda)(
int task_id, TVMParallelGroupEnv* penv, void* cdata);
/*!
* \brief Backend function for running parallel jobs.
*
* \param flambda The parallel function to be launched.
* \param cdata The closure data.
 * \param num_task Number of tasks to launch; can be 0, which means launching
 * with all available threads.
*
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda,
void* cdata,
int num_task);
/*!
 * \brief BSP barrier between parallel threads
* \param task_id the task id of the function.
 * \param penv The parallel environment backing the execution.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv);
/*!
 * \brief Simple static initialization function.
 * Runs f once and sets handle to be non-null.
 * This function is mainly used for testing purposes.
*
 * \param handle A global address used to indicate f
 * \param f The function to be run
* \param cdata The closure data to pass to the function.
* \param nbytes Number of bytes in the closure data.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMBackendRunOnce(void** handle,
int (*f)(void*),
void *cdata,
int nbytes);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
#endif // TVM_RUNTIME_C_BACKEND_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file tvm/runtime/c_runtime_api.h
* \brief TVM runtime library.
*
 * The philosophy of the TVM project is to customize the compilation
 * stage to generate code that can be used by other projects transparently.
 * So this is a minimal runtime glue layer, plus some limited
 * memory management code to enable quick testing.
*
* The runtime API is independent from TVM compilation stack and can
* be linked via libtvm_runtime.
*
* The common flow is:
 * - Use TVMFuncListGlobalNames to get global function names
* - Use TVMFuncCall to call these functions.
*/
#ifndef TVM_RUNTIME_C_RUNTIME_API_H_
#define TVM_RUNTIME_C_RUNTIME_API_H_
// Macros to do weak linking
#ifdef _MSC_VER
#define TVM_WEAK __declspec(selectany)
#else
#define TVM_WEAK __attribute__((weak))
#endif
#ifdef __EMSCRIPTEN__
#include <emscripten/emscripten.h>
#define TVM_DLL EMSCRIPTEN_KEEPALIVE
#endif
#ifndef TVM_DLL
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL __declspec(dllexport)
#else
#define TVM_DLL __declspec(dllimport)
#endif
#else
#define TVM_DLL
#endif
#endif
// TVM version
#define TVM_VERSION "0.5.dev"
// TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
#ifdef __cplusplus
extern "C" {
#endif
#include <stdint.h>
#include <stddef.h>
/*! \brief type of array index. */
typedef int64_t tvm_index_t;
/*! \brief Extension device types in TVM */
typedef enum {
kDLAOCL = 5,
kDLSDAccel = 6,
kOpenGL = 11,
// Extension DRAM type, used for quickly test extension device
// The device api can differ depending on the xpu driver registered.
kExtDev = 12,
// AddExtraTVMType which is not in DLPack here
} TVMDeviceExtType;
/*!
* \brief The type code in TVMType
* \note TVMType is used in two places.
*/
typedef enum {
// The type code of other types are compatible with DLPack.
// The next few fields are extension types
// that are used by TVM API calls.
kHandle = 3U,
kNull = 4U,
kTVMType = 5U,
kTVMContext = 6U,
kArrayHandle = 7U,
kNodeHandle = 8U,
kModuleHandle = 9U,
kFuncHandle = 10U,
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's ids do not conflict, use the first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kExtBegin = 15U,
kNNVMFirst = 16U,
kNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U
} TVMTypeCode;
/*!
* \brief The data type used in TVM Runtime.
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*
 * \note Arguments to TVM API functions always take bits=64 and lanes=1
*/
typedef DLDataType TVMType;
/*!
* \brief The Device information, abstract away common device types.
*/
typedef DLContext TVMContext;
/*!
 * \brief The tensor array structure passed to the TVM API.
*/
typedef DLTensor TVMArray;
/*! \brief the array handle */
typedef TVMArray* TVMArrayHandle;
/*!
* \brief Union type of values
* being passed through API and function calls.
*/
typedef union {
int64_t v_int64;
double v_float64;
void* v_handle;
const char* v_str;
TVMType v_type;
TVMContext v_ctx;
} TVMValue;
/*!
 * \brief Byte array type used to pass in byte arrays
 * when kBytes is used as the data type.
*/
typedef struct {
const char* data;
size_t size;
} TVMByteArray;
/*! \brief Handle to TVM runtime modules. */
typedef void* TVMModuleHandle;
/*! \brief Handle to packed function handle. */
typedef void* TVMFunctionHandle;
/*! \brief Handle to hold return value. */
typedef void* TVMRetValueHandle;
/*!
* \brief The stream that is specific to device
* can be NULL, which indicates the default one.
*/
typedef void* TVMStreamHandle;
/*!
 * \brief Used for implementing C API functions.
 * Sets the last error message before returning.
* \param msg The error message to be set.
*/
TVM_DLL void TVMAPISetLastError(const char* msg);
/*!
 * \brief Return the string message of the last error.
 * All functions in this file return 0 on success
 * and -1 when an error occurs;
 * TVMGetLastError can be called to retrieve the error.
 *
 * This function is thread-safe and can be called from different threads.
 * \return error info
*/
TVM_DLL const char *TVMGetLastError(void);
/*!
* \brief Load module from file.
* \param file_name The file name to load the module from.
* \param format The format of the module.
* \param out The result module
*
* \return 0 when success, -1 when failure happens
 * \note The resulting module does not contain import relations.
 * They can be reconstructed by TVMModImport.
*/
TVM_DLL int TVMModLoadFromFile(const char* file_name,
const char* format,
TVMModuleHandle* out);
/*!
 * \brief Add dep to mod's dependencies.
 * This allows functions in mod to use functions from dep.
*
* \param mod The module handle.
* \param dep The dependent module to be imported.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMModImport(TVMModuleHandle mod,
TVMModuleHandle dep);
/*!
* \brief Get function from the module.
* \param mod The module handle.
* \param func_name The name of the function.
* \param query_imports Whether to query imported modules
* \param out The result function, can be NULL if it is not available.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
const char* func_name,
int query_imports,
TVMFunctionHandle *out);
/*!
* \brief Free front-end extension type resource.
* \param handle The extension handle.
 * \param type_code The type code of the extension type.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMExtTypeFree(void* handle, int type_code);
/*!
* \brief Free the Module
* \param mod The module to be freed.
*
 * \note This may not free up the module's resources
 * if there is an active TVMFunctionHandle that uses the module,
 * or if this module is imported by another active module.
 *
 * All functions remain valid until TVMFuncFree is called.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMModFree(TVMModuleHandle mod);
/*!
* \brief Free the function when it is no longer needed.
* \param func The function handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncFree(TVMFunctionHandle func);
/*!
* \brief Call a Packed TVM Function.
*
* \param func node handle of the function.
* \param arg_values The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
*
* \param ret_val The return value.
* \param ret_type_code the type code of return value.
*
* \return 0 when success, -1 when failure happens
* \note TVM calls always exchanges with type bits=64, lanes=1
*
* \note API calls always exchanges with type bits=64, lanes=1
* If API call returns container handles (e.g. FunctionHandle)
* these handles should be managed by the front-end.
* The front-end need to call free function (e.g. TVMFuncFree)
* to free these handles.
*/
TVM_DLL int TVMFuncCall(TVMFunctionHandle func,
TVMValue* arg_values,
int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code);
/*!
* \brief Set the return value of TVMPackedCFunc.
*
* This function is called by TVMPackedCFunc to set the return value.
* When this function is not called, the function returns null by default.
*
* \param ret The return value handle, pass by ret in TVMPackedCFunc
* \param value The value to be returned.
* \param type_code The type of the value to be returned.
* \param num_ret Number of return values, for now only 1 is supported.
*/
TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value,
int* type_code,
int num_ret);
/*!
* \brief Inplace translate callback argument value to return value.
* This is only needed for non-POD arguments.
*
* \param value The value to be translated.
* \param code The type code to be translated.
* \note This function will do a shallow copy when necessary.
*
* \return 0 when success, -1 when failure happens.
*/
TVM_DLL int TVMCbArgToReturn(TVMValue* value, int code);
/*!
* \brief C type of packed function.
*
* \param args The arguments
* \param type_codes The type codes of the arguments
* \param num_args Number of arguments.
* \param ret The return value handle.
 * \param resource_handle The additional resource handle from the front-end.
 * \return 0 if success, -1 if failure happens; set the error via TVMAPISetLastError.
* \sa TVMCFuncSetReturn
*/
typedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
int num_args,
TVMRetValueHandle ret,
void* resource_handle);
/*!
* \brief C callback to free the resource handle in C packed function.
 * \param resource_handle The additional resource handle from the front-end.
*/
typedef void (*TVMPackedCFuncFinalizer)(void* resource_handle);
/*!
* \brief Signature for extension function declarer.
*
 * TVM calls this function to get the extension functions.
 * The declarer will call register_func to register the functions and their names.
*
* \param register_func_handle The register function
* \return 0 if success, -1 if failure happens
*/
typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle);
/*!
* \brief Wrap a TVMPackedCFunc to become a FunctionHandle.
*
* The resource_handle will be managed by TVM API, until the function is no longer used.
*
* \param func The packed C function.
* \param resource_handle The resource handle from front-end, can be NULL.
 * \param fin The finalizer on the resource handle when the FunctionHandle gets freed, can be NULL
* \param out the result function handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out);
/*!
* \brief Register the function to runtime's global table.
*
 * The registered function can then be retrieved by the backend by name.
*
* \param name The name of the function.
* \param f The function to be registered.
 * \param override Whether to allow overriding an already registered function.
*/
TVM_DLL int TVMFuncRegisterGlobal(
const char* name, TVMFunctionHandle f, int override);
/*!
* \brief Get a global function.
*
* \param name The name of the function.
* \param out the result function pointer, NULL if it does not exist.
*
 * \note The function handle of a global function is managed by the TVM runtime,
 * so TVMFuncFree should not be called on it.
*/
TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
/*!
* \brief List all the globally registered function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMFuncListGlobalNames(int* out_size,
const char*** out_array);
// Array related APIs for quick prototyping
/*!
* \brief Allocate a nd-array's memory,
* including space of shape, of given spec.
*
* \param shape The shape of the array, the data content will be copied to out
* \param ndim The number of dimension of the array.
* \param dtype_code The type code of the dtype
* \param dtype_bits The number of bits of dtype
* \param dtype_lanes The number of lanes in the dtype.
* \param device_type The device type of context
* \param device_id The device id of context.
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
TVMArrayHandle* out);
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
/*!
* \brief Copy array data from CPU byte array.
* \param handle The array handle.
* \param data the data pointer
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle,
void* data,
size_t nbytes);
/*!
* \brief Copy array data to CPU byte array.
* \param handle The array handle.
* \param data the data pointer
* \param nbytes The number of bytes to copy.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle,
void* data,
size_t nbytes);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
* \param from The array to be copied from.
* \param to The target space.
* \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
TVMStreamHandle stream);
/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* with the DLManagedTensor.
* \param from The source DLManagedTensor.
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from,
TVMArrayHandle* out);
/*!
 * \brief Produce a DLManagedTensor from the array that shares data memory with
* the array.
* \param from The source array.
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from,
DLManagedTensor** out);
/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor.
*/
TVM_DLL void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
/*!
* \brief Create a new runtime stream.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param out The new stream handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out);
/*!
* \brief Free a created stream handle.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param stream The stream to be freed
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream);
/*!
* \brief Set the runtime stream of current thread to be stream.
* The subsequent calls to the same device_type
 * will use the stream handle that was set.
* The specific type of stream is runtime device dependent.
*
* \param device_type The device type of context
* \param device_id The device id of context.
* \param handle The stream handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSetStream(int device_type, int device_id, TVMStreamHandle handle);
/*!
 * \brief Wait until all computations on stream complete.
*
* \param device_type The device type of context
* \param device_id The device id of context.
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream);
/*!
* \brief Synchronize two streams of execution.
*
* \param device_type The device type of context
* \param device_id The device id of context
* \param src The source stream to synchronize.
* \param dst The destination stream to synchronize.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMStreamStreamSynchronize(int device_type,
int device_id,
TVMStreamHandle src,
TVMStreamHandle dst);
#ifdef __cplusplus
} // TVM_EXTERN_C
#endif
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file tvm/runtime/device_api.h
* \brief Abstract device memory management API
*/
#ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_
#include <string>
#include "packed_func.h"
#include "c_runtime_api.h"
namespace tvm {
namespace runtime {
/*!
* \brief the query type into GetAttr
*/
enum DeviceAttrKind : int {
kExist = 0,
kMaxThreadsPerBlock = 1,
kWarpSize = 2,
kMaxSharedMemoryPerBlock = 3,
kComputeVersion = 4,
kDeviceName = 5,
kMaxClockRate = 6,
kMultiProcessorCount = 7,
kMaxThreadDimensions = 8
};
/*! \brief Number of bytes each allocation must align to */
constexpr int kAllocAlignment = 64;
/*! \brief Number of bytes each allocation must align to in temporary allocation */
constexpr int kTempAllocaAlignment = 64;
/*! \brief Maximum size that can be allocated on stack */
constexpr int kMaxStackAlloca = 1024;
/*!
* \brief TVM Runtime Device API, abstracts the device
* specific interface for memory management.
*/
class DeviceAPI {
public:
/*! \brief virtual destructor */
virtual ~DeviceAPI() {}
/*!
* \brief Set the environment device id to ctx
* \param ctx The context to be set.
*/
virtual void SetDevice(TVMContext ctx) = 0;
/*!
* \brief Get attribute of specified device.
* \param ctx The device context
* \param kind The result kind
* \param rv The return value.
* \sa DeviceAttrKind
*/
virtual void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) = 0;
/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
* \param nbytes The number of bytes in memory.
* \param alignment The alignment of the memory.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes & alignment are sufficient for most backends.
* \return The allocated device pointer.
*/
virtual void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) = 0;
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
* \param ptr The data space.
*/
virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0;
/*!
* \brief copy data from one place to another
* \param from The source array.
 * \param from_offset The byte offset in the source array.
 * \param to The target array.
 * \param to_offset The byte offset in the target array.
 * \param num_bytes The size of the memory in bytes
 * \param ctx_from The source context
 * \param ctx_to The target context
 * \param type_hint The type of elements, only needed by certain backends;
 *  can be useful for cross-device endian conversion.
* \param stream Optional stream object.
*/
virtual void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t num_bytes,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
TVMStreamHandle stream) = 0;
/*!
* \brief Create a new stream of execution.
*
* \param ctx The context of allocation.
*/
TVM_DLL virtual TVMStreamHandle CreateStream(TVMContext ctx);
/*!
* \brief Free a stream of execution
*
* \param ctx The context of the stream
* \param stream The pointer to be freed.
*/
TVM_DLL virtual void FreeStream(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief Synchronize the stream
* \param ctx The context to perform operation.
* \param stream The stream to be sync.
*/
virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0;
/*!
* \brief Set the stream
* \param ctx The context to set stream.
* \param stream The stream to be set.
*/
virtual void SetStream(TVMContext ctx, TVMStreamHandle stream) {}
/*!
* \brief Synchronize 2 streams of execution.
*
 * An event is created in the event_src stream that the event_dst
 * stream then waits on. Neither event_src nor event_dst needs to be
 * of the same device ID as the context, but they must be of the same
 * device type.
*
* \param ctx The context of the streams.
* \param event_src The source stream to synchronize.
* \param event_dst The destination stream to synchronize.
*/
TVM_DLL virtual void SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_src,
TVMStreamHandle event_dst);
/*!
 * \brief Allocate temporary workspace for backend execution.
 *
 * \note We make the following assumptions about backend temporary
 * workspace allocation, and the backend will optimize for them:
 *
 * - Only a few allocations will happen, and space will be released after use.
 * - The release order is usually the reverse of the allocation order (stack style).
 * - Repetitive pattern of the same allocations over different runs.
 * - Workspaces should not overlap between different threads (i.e. be thread-local).
*
* \param ctx The context of allocation.
* \param nbytes The size to be allocated.
* \param type_hint The type of elements. Only needed by certain backends such
* as OpenGL, as nbytes is sufficient for most backends.
*/
TVM_DLL virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes,
TVMType type_hint = {});
/*!
 * \brief Free temporary workspace in backend execution.
*
* \param ctx The context of allocation.
* \param ptr The pointer to be freed.
*/
TVM_DLL virtual void FreeWorkspace(TVMContext ctx, void* ptr);
/*!
 * \brief Get device API based on context.
 * \param ctx The context
 * \param allow_missing Whether to allow a missing device API.
 * \return The corresponding device API.
*/
TVM_DLL static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);
};
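/*!
 * \brief Illustrative sketch (an assumption-laden example, not part of this header)
 *  of obtaining a DeviceAPI and using the workspace interface on a CPU context.
 *
 * \code
 * TVMContext ctx;
 * ctx.device_type = kDLCPU;
 * ctx.device_id = 0;
 * DeviceAPI* api = DeviceAPI::Get(ctx);
 * void* ws = api->AllocWorkspace(ctx, 1024);  // 1024-byte scratch buffer
 * // ... use the workspace; release in reverse order of allocation ...
 * api->FreeWorkspace(ctx, ws);
 * \endcode
 */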
/*! \brief The device type bigger than this is RPC device */
constexpr int kRPCSessMask = 128;
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/module.h
 * \brief Runtime container of the functions generated by TVM.
 *  This is used to support dynamic linking, loading and saving of
 *  functions from different conventions under a unified API.
*/
#ifndef TVM_RUNTIME_MODULE_H_
#define TVM_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "c_runtime_api.h"
namespace tvm {
namespace runtime {
// The internal container of module.
class ModuleNode;
class PackedFunc;
/*!
* \brief Module container of TVM.
*/
class Module {
public:
Module() {}
// constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n)
: node_(n) {}
/*!
* \brief Get packed function from current module by name.
*
* \param name The name of the function.
 * \param query_imports Whether to also query dependency modules.
 * \return The result function.
 *  This function will return PackedFunc(nullptr) if the function does not exist.
* \note Implemented in packed_func.cc
*/
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
/*! \return internal container */
inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
  // The following functions require linking with the runtime.
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
*
 * \note Cyclic dependencies are not allowed among modules;
 *  an error will be thrown when a cyclic dependency is detected.
*/
TVM_DLL void Import(Module other);
/*!
* \brief Load a module from file.
* \param file_name The name of the host function module.
* \param format The format of the file.
 * \note This function will not load the import relationship.
 *  Re-create the import relationship by calling Import.
*/
TVM_DLL static Module LoadFromFile(const std::string& file_name,
const std::string& format = "");
private:
std::shared_ptr<ModuleNode> node_;
};
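/*!
 * \brief Illustrative sketch of the typical load-then-lookup flow; the file name
 *  "mylib.so" and the function name "myfunc" are placeholders, not real artifacts.
 *
 * \code
 * Module mod = Module::LoadFromFile("mylib.so");
 * PackedFunc f = mod.GetFunction("myfunc", true);  // true: also query imports
 * if (f != nullptr) {
 *   // call f through the PackedFunc calling convention (see packed_func.h)
 * }
 * \endcode
 */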
/*!
* \brief Base node container of module.
* Do not create this directly, instead use Module.
*/
class ModuleNode {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
/*! \return The module type key */
virtual const char* type_key() const = 0;
/*!
* \brief Get a PackedFunc from module.
*
 * The PackedFunc may not be fully initialized;
 *  there might still be first-time running overhead when
 *  executing the function on certain devices.
 *  For benchmarking, run the function once beforehand to eliminate this overhead.
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc(nullptr) when it is not available.
*
 * \note The function will always remain valid.
 *   If the function needs resources from the module (e.g. late linking),
 *   it should capture sptr_to_self.
*/
virtual PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
virtual void SaveToFile(const std::string& file_name,
const std::string& format);
/*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
* \note It is recommended to implement this for device modules,
* but not necessarily host modules.
* We can use this to do AOT loading of bundled device functions.
*/
TVM_DLL virtual void SaveToBinary(dmlc::Stream* stream);
/*!
* \brief Get the source code of module, when available.
* \param format Format of the source code, can be empty by default.
* \return Possible source code when available.
*/
TVM_DLL virtual std::string GetSource(const std::string& format = "");
/*!
 * \brief Get a function from the current environment.
 *  The environment includes all the imports as well as global functions.
*
* \param name name of the function.
* \return The corresponding function.
*/
TVM_DLL const PackedFunc* GetFuncFromEnv(const std::string& name);
  /*! \return The modules this module imports from */
const std::vector<Module>& imports() const {
return imports_;
}
protected:
friend class Module;
  /*! \brief The modules this module depends on */
std::vector<Module> imports_;
private:
  /*! \brief Cache used by GetFuncFromEnv */
std::unordered_map<std::string,
std::unique_ptr<PackedFunc> > import_cache_;
};
/*! \brief namespace for constant symbols */
namespace symbol {
/*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Global variable to store device module blob */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Number of bytes of device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */
constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
/*! \brief Prepare the global barrier before kernels that use the global barrier. */
constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
} // namespace symbol
// implementations of inline functions.
inline ModuleNode* Module::operator->() {
return node_.get();
}
inline const ModuleNode* Module::operator->() const {
return node_.get();
}
} // namespace runtime
} // namespace tvm
#include "packed_func.h"
#endif // TVM_RUNTIME_MODULE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/ndarray.h
 * \brief Managed NDArray abstraction backed by reference-counted containers.
*/
#ifndef TVM_RUNTIME_NDARRAY_H_
#define TVM_RUNTIME_NDARRAY_H_
#include <atomic>
#include <vector>
#include <utility>
#include "c_runtime_api.h"
#include "serializer.h"
namespace tvm {
namespace runtime {
/*!
* \brief Managed NDArray.
* The array is backed by reference counted blocks.
*/
class NDArray {
public:
// internal container type
struct Container;
/*! \brief default constructor */
NDArray() {}
/*!
 * \brief construct an NDArray that refers to data
* \param data The data this NDArray refers to
*/
explicit inline NDArray(Container* data);
/*!
* \brief copy constructor
* \param other The value to be copied
*/
inline NDArray(const NDArray& other); // NOLINT(*)
/*!
* \brief move constructor
* \param other The value to be moved
*/
NDArray(NDArray&& other) // NOLINT(*)
: data_(other.data_) {
other.data_ = nullptr;
}
/*! \brief destructor */
~NDArray() {
this->reset();
}
/*!
* \brief Swap this array with another NDArray
* \param other The other NDArray
*/
void swap(NDArray& other) { // NOLINT(*)
std::swap(data_, other.data_);
}
/*!
 * \brief copy assignment
* \param other The value to be assigned.
* \return reference to self.
*/
NDArray& operator=(const NDArray& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(other).swap(*this); // NOLINT(*)
return *this;
}
/*!
 * \brief move assignment
* \param other The value to be assigned.
* \return reference to self.
*/
NDArray& operator=(NDArray&& other) { // NOLINT(*)
// copy-and-swap idiom
NDArray(std::move(other)).swap(*this); // NOLINT(*)
return *this;
}
/*! \return If NDArray is defined */
bool defined() const {
return data_ != nullptr;
}
  /*! \return Whether both NDArrays reference the same container */
bool same_as(const NDArray& other) const {
return data_ == other.data_;
}
/*! \brief reset the content of NDArray to be nullptr */
inline void reset();
/*!
* \return the reference counter
 * \note this number is approximate in a multi-threaded setting.
*/
inline int use_count() const;
/*! \return Pointer to content of DLTensor */
inline const DLTensor* operator->() const;
/*!
* \brief Copy data content from another array.
* \param other The source array to be copied from.
 * \note The copy may happen asynchronously if it involves a GPU context.
* TVMSynchronize is necessary.
*/
inline void CopyFrom(DLTensor* other);
inline void CopyFrom(const NDArray& other);
/*!
* \brief Copy data content into another array.
 * \param other The target array to be copied into.
 * \note The copy may happen asynchronously if it involves a GPU context.
* TVMSynchronize is necessary.
*/
inline void CopyTo(DLTensor* other) const;
inline void CopyTo(const NDArray& other) const;
/*!
* \brief Copy the data to another context.
* \param ctx The target context.
* \return The array under another context.
*/
inline NDArray CopyTo(const DLContext& ctx) const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
* \return Whether load is successful
*/
inline bool Load(dmlc::Stream* stream);
/*!
* \brief Save NDArray to stream
* \param stream The output data stream
*/
inline void Save(dmlc::Stream* stream) const;
/*!
* \brief Create a NDArray that shares the data memory with the current one.
* \param shape The shape of the new array.
* \param dtype The data type of the new array.
 * \note The memory size of the new array must not exceed that of the current one.
*/
TVM_DLL NDArray CreateView(
std::vector<int64_t> shape, DLDataType dtype);
/*!
 * \brief Create a reference view of the NDArray,
 *  represented as a DLManagedTensor.
* \return A DLManagedTensor
*/
TVM_DLL DLManagedTensor* ToDLPack() const;
/*!
* \brief Create an empty NDArray.
* \param shape The shape of the new array.
* \param dtype The data type of the new array.
* \param ctx The context of the Array.
* \return The created Array
*/
TVM_DLL static NDArray Empty(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx);
/*!
* \brief Create a NDArray backed by a dlpack tensor.
*
* This allows us to create a NDArray using the memory
* allocated by an external deep learning framework
* that is DLPack compatible.
*
 * The memory is retained until the NDArray goes out of scope.
 * \param tensor The DLPack tensor to create the view from.
* \return The created NDArray view.
*/
TVM_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
/*!
* \brief Function to copy data from one array to another.
* \param from The source array.
* \param to The target array.
* \param stream The stream used in copy.
*/
TVM_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
// internal namespace
struct Internal;
private:
/*! \brief Internal Data content */
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend class TVMRetValue;
friend class TVMArgsSetter;
};
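/*!
 * \brief Illustrative sketch of allocating an NDArray and copying it across
 *  contexts; the shape, dtype and the availability of a GPU context are
 *  assumptions made for the example.
 *
 * \code
 * DLDataType dtype{kDLFloat, 32, 1};              // float32
 * DLContext cpu{kDLCPU, 0};
 * NDArray a = NDArray::Empty({2, 3}, dtype, cpu);
 * // ... fill a->data on the CPU ...
 * DLContext gpu{kDLGPU, 0};                       // assumes a GPU is present
 * NDArray b = a.CopyTo(gpu);                      // may be asynchronous
 * \endcode
 */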
/*!
* \brief Save a DLTensor to stream
 * \param strm The output stream
* \param tensor The tensor to be saved.
*/
inline bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor);
/*!
* \brief Reference counted Container object used to back NDArray.
*
* This object is DLTensor compatible:
* the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor*
*
 * \note Do not use this struct directly; use NDArray instead.
*/
struct NDArray::Container {
public:
// NOTE: the first part of this structure is the same as
  // DLManagedTensor; note, however, that the deleter
  // is only called when the reference counter goes to 0.
/*!
* \brief The corresponding dl_tensor field.
* \note it is important that the first field is DLTensor
* So that this data structure is DLTensor compatible.
* The head ptr of this struct can be viewed as DLTensor*.
*/
DLTensor dl_tensor;
/*!
 * \brief additional context, reserved for recycling
 * \note We can attach additional content here
 *  which the current container depends on
* (e.g. reference to original memory when creating views).
*/
void* manager_ctx{nullptr};
/*!
* \brief Customized deleter
*
 * \note The customized deleter is helpful to enable
 *  different kinds of memory allocators that are not
 *  currently defined by the system.
*/
void (*deleter)(Container* self) = nullptr;
/*! \brief default constructor */
Container() {
dl_tensor.data = nullptr;
dl_tensor.ndim = 0;
dl_tensor.shape = nullptr;
dl_tensor.strides = nullptr;
dl_tensor.byte_offset = 0;
}
/*! \brief developer function, increases reference counter */
void IncRef() {
ref_counter_.fetch_add(1, std::memory_order_relaxed);
}
/*! \brief developer function, decrease reference counter */
void DecRef() {
if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) {
std::atomic_thread_fence(std::memory_order_acquire);
if (this->deleter != nullptr) {
(*this->deleter)(this);
}
}
}
private:
friend class NDArray;
friend class RPCWrappedFunc;
/*!
* \brief The shape container,
* can be used for shape data.
*/
std::vector<int64_t> shape_;
/*!
* \brief The stride container,
* can be used for stride data.
*/
std::vector<int64_t> stride_;
  /*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
};
// implementations of inline functions
// the usages of functions are documented in place.
inline NDArray::NDArray(Container* data)
: data_(data) {
data_->IncRef();
}
inline NDArray::NDArray(const NDArray& other)
: data_(other.data_) {
data_->IncRef();
}
inline void NDArray::reset() {
if (data_ != nullptr) {
data_->DecRef();
data_ = nullptr;
}
}
inline void NDArray::CopyFrom(DLTensor* other) {
CHECK(data_ != nullptr);
CopyFromTo(other, &(data_->dl_tensor));
}
inline void NDArray::CopyFrom(const NDArray& other) {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(other.data_->dl_tensor), &(data_->dl_tensor));
}
inline void NDArray::CopyTo(DLTensor* other) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
}
inline void NDArray::CopyTo(const NDArray& other) const {
CHECK(data_ != nullptr);
CHECK(other.data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), &(other.data_->dl_tensor));
}
inline NDArray NDArray::CopyTo(const DLContext& ctx) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx);
this->CopyTo(ret);
return ret;
}
inline int NDArray::use_count() const {
if (data_ == nullptr) return 0;
return data_->ref_counter_.load(std::memory_order_relaxed);
}
inline const DLTensor* NDArray::operator->() const {
return &(data_->dl_tensor);
}
/*! \brief Magic number for NDArray file */
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDLTensor(dmlc::Stream* strm,
DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
// Always save data as CPU context
//
// Parameters that get serialized should be in CPU by default.
  // So even if the array's context is GPU, it will be stored as a CPU array.
  // This prevents the case where another user loads the parameters
  // back on a machine that does not have a GPU or the related context.
//
// We can always do array.CopyTo(target_ctx) to get a corresponding
// array in the target context.
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
strm->Write(cpu_ctx);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->WriteArray(tensor->shape, ndim);
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_bytes * num_elems;
strm->Write(data_byte_size);
if (DMLC_IO_NO_ENDIAN_SWAP &&
tensor->ctx.device_type == kDLCPU &&
tensor->strides == nullptr &&
tensor->byte_offset == 0) {
// quick path
strm->Write(tensor->data, data_byte_size);
} else {
std::vector<uint8_t> bytes(data_byte_size);
CHECK_EQ(TVMArrayCopyToBytes(
tensor, dmlc::BeginPtr(bytes), data_byte_size), 0)
<< TVMGetLastError();
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems);
}
strm->Write(dmlc::BeginPtr(bytes), data_byte_size);
}
return true;
}
inline void NDArray::Save(dmlc::Stream* strm) const {
SaveDLTensor(strm, const_cast<DLTensor*>(operator->()));
}
inline bool NDArray::Load(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLContext ctx;
int ndim;
DLDataType dtype;
CHECK(strm->Read(&ctx))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&ndim))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&dtype))
<< "Invalid DLTensor file format";
CHECK_EQ(ctx.device_type, kDLCPU)
<< "Invalid DLTensor context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim))
<< "Invalid DLTensor file format";
}
NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
*this = ret;
return true;
}
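/*!
 * \brief Illustrative sketch of a Save/Load round trip through an in-memory
 *  stream; dmlc::MemoryStringStream (from dmlc/memory_io.h) is assumed to be
 *  available and is not provided by this header.
 *
 * \code
 * std::string blob;
 * dmlc::MemoryStringStream writer(&blob);
 * cpu_array.Save(&writer);            // cpu_array: an existing CPU NDArray
 * dmlc::MemoryStringStream reader(&blob);
 * NDArray restored;
 * CHECK(restored.Load(&reader));      // restored now holds an equivalent CPU array
 * \endcode
 */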
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_NDARRAY_H_
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/runtime/packed_func.h
* \brief Type-erased function used across TVM API.
*/
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_
#include <dmlc/logging.h>
#include <functional>
#include <tuple>
#include <vector>
#include <string>
#include <limits>
#include <memory>
#include <type_traits>
#include "c_runtime_api.h"
#include "module.h"
#include "ndarray.h"
// Whether use TVM runtime in header only mode.
#ifndef TVM_RUNTIME_HEADER_ONLY
#define TVM_RUNTIME_HEADER_ONLY 0
#endif
namespace tvm {
// Forward declare NodeRef and Node for extensions.
// This header works fine without depend on NodeRef
// as long as it is not used.
class Node;
class NodeRef;
namespace runtime {
// forward declarations
class TVMArgs;
class TVMArgValue;
class TVMRetValue;
class TVMArgsSetter;
/*!
 * \brief Packed function is a type-erased function.
 *  The arguments are passed in packed format.
 *
 *  This is a useful unified interface to call generated functions;
 *  it is the unified function type of TVM.
 *  It corresponds to TVMFunctionHandle in the C runtime API.
*/
class PackedFunc {
public:
/*!
* \brief The internal std::function
* \param args The arguments to the function.
* \param rv The return value.
*
 * \code
 * // Example code on how to implement FType
* void MyPackedFunc(TVMArgs args, TVMRetValue* rv) {
* // automatically convert arguments to desired type.
* int a0 = args[0];
* float a1 = args[1];
* ...
* // automatically assign values to rv
* std::string my_return_value = "x";
* *rv = my_return_value;
* }
* \endcode
*/
using FType = std::function<void (TVMArgs args, TVMRetValue* rv)>;
/*! \brief default constructor */
PackedFunc() {}
/*!
* \brief constructing a packed function from a std::function.
* \param body the internal container of packed function.
*/
explicit PackedFunc(FType body) : body_(body) {}
/*!
* \brief Call packed function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
*
* \code
* // Example code on how to call packed function
 * void CallPacked(PackedFunc f) {
 *   // call like normal functions by passing in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* }
* \endcode
*/
template<typename... Args>
inline TVMRetValue operator()(Args&& ...args) const;
/*!
* \brief Call the function in packed format.
* \param args The arguments
* \param rv The return value.
*/
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
/*! \return the internal body function */
inline FType body() const;
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const {
return body_ == nullptr;
}
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const {
return body_ != nullptr;
}
private:
/*! \brief internal container of packed function */
FType body_;
};
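/*!
 * \brief Illustrative sketch (added example) combining the two snippets documented
 *  above: construct a PackedFunc from a lambda body and call it with unpacked
 *  arguments.
 *
 * \code
 * PackedFunc fadd([](TVMArgs args, TVMRetValue* rv) {
 *   int a = args[0];
 *   int b = args[1];
 *   *rv = a + b;
 * });
 * int c = fadd(1, 2);  // c == 3
 * \endcode
 */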
/*!
 * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc<R(Args...)>"
*/
template<typename FType>
class TypedPackedFunc;
/*!
* \anchor TypedPackedFuncAnchor
* \brief A PackedFunc wrapper to provide typed function signature.
* It is backed by a PackedFunc internally.
*
* TypedPackedFunc enables compile time type checking.
* TypedPackedFunc works with the runtime system:
* - It can be passed as an argument of PackedFunc.
* - It can be assigned to TVMRetValue.
* - It can be directly converted to a type-erased PackedFunc.
*
* Developers should prefer TypedPackedFunc over PackedFunc in C++ code
* as it enables compile time checking.
* We can construct a TypedPackedFunc from a lambda function
* with the same signature.
*
* \code
* // user defined lambda function.
* auto addone = [](int x)->int {
* return x + 1;
* };
* // We can directly convert
* // lambda function to TypedPackedFunc
* TypedPackedFunc<int(int)> ftyped(addone);
* // invoke the function.
* int y = ftyped(1);
* // Can be directly converted to PackedFunc
 * PackedFunc packed = ftyped;
* \endcode
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
*/
template<typename R, typename ...Args>
class TypedPackedFunc<R(Args...)> {
public:
/*! \brief short hand for this function type */
using TSelf = TypedPackedFunc<R(Args...)>;
/*! \brief default constructor */
TypedPackedFunc() {}
/*!
 * \brief construct by wrapping a PackedFunc
*
* Example usage:
* \code
* PackedFunc packed([](TVMArgs args, TVMRetValue *rv) {
* int x = args[0];
* *rv = x + 1;
* });
* // construct from packed function
* TypedPackedFunc<int(int)> ftyped(packed);
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param packed The packed function
*/
inline explicit TypedPackedFunc(PackedFunc packed);
/*!
* \brief construct from a lambda function with the same signature.
*
* Example usage:
* \code
 * auto typed_lambda = [](int x)->int { return x + 1; };
 * // construct from the typed lambda
* TypedPackedFunc<int(int)> ftyped(typed_lambda);
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
*/
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
explicit TypedPackedFunc(const FLambda& typed_lambda) {
this->AssignTypedLambda(typed_lambda);
}
/*!
* \brief copy assignment operator from typed lambda
*
* Example usage:
* \code
 * // default construct, then assign a typed lambda
 * TypedPackedFunc<int(int)> ftyped;
 * ftyped = [](int x) { return x + 1; };
* // call the typed version.
* CHECK_EQ(ftyped(1), 2);
* \endcode
*
* \param typed_lambda typed lambda function.
* \tparam FLambda the type of the lambda function.
* \returns reference to self.
*/
template<typename FLambda,
typename = typename std::enable_if<
std::is_convertible<FLambda,
std::function<R(Args...)>
>::value>::type>
TSelf& operator=(FLambda typed_lambda) { // NOLINT(*)
this->AssignTypedLambda(typed_lambda);
return *this;
}
/*!
* \brief copy assignment operator from PackedFunc.
* \param packed The packed function.
* \returns reference to self.
*/
TSelf& operator=(PackedFunc packed) {
packed_ = packed;
return *this;
}
/*!
* \brief Invoke the operator.
* \param args The arguments
* \returns The return value.
*/
inline R operator()(Args ...args) const;
/*!
* \brief convert to PackedFunc
* \return the internal PackedFunc
*/
operator PackedFunc() const {
return packed();
}
/*!
 * \return reference to the internal PackedFunc
*/
const PackedFunc& packed() const {
return packed_;
}
private:
friend class TVMRetValue;
/*! \brief The internal packed function */
PackedFunc packed_;
/*!
* \brief Assign the packed field using a typed lambda function.
*
* \param flambda The lambda function.
* \tparam FLambda The lambda function type.
* \note We capture the lambda when possible for maximum efficiency.
*/
template<typename FLambda>
inline void AssignTypedLambda(FLambda flambda);
};
/*! \brief Arguments into TVM functions. */
class TVMArgs {
public:
const TVMValue* values;
const int* type_codes;
int num_args;
/*!
* \brief constructor
* \param values The argument values
* \param type_codes The argument type codes
* \param num_args number of arguments.
*/
TVMArgs(const TVMValue* values,
const int* type_codes,
int num_args)
: values(values),
type_codes(type_codes),
num_args(num_args) { }
/*! \return size of the arguments */
inline int size() const;
/*!
* \brief Get i-th argument
* \param i the index.
* \return the ith argument.
*/
inline TVMArgValue operator[](int i) const;
};
/*!
* \brief Convert type code to its name
 * \param type_code The type code.
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);
/*!
* \brief convert a string to TVM type.
* \param s The string to be converted.
* \return The corresponding tvm type.
*/
inline TVMType String2TVMType(std::string s);
/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string TVMType2String(TVMType t);
// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
    << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE)
/*!
* \brief Type traits to mark if a class is tvm extension type.
*
 * To enable an extension type in C++, it must be registered via the macro
 * TVM_REGISTER_EXT_TYPE(TypeName) after defining this trait.
 *
 * Extension classes can be passed and returned via PackedFunc throughout the TVM runtime.
 * Internally, an extension class is stored as T*.
*
* \tparam T the typename
*/
template<typename T>
struct extension_class_info {
static const int code = 0;
};
/*!
* \brief Runtime function table about extension type.
*/
class ExtTypeVTable {
public:
/*! \brief function to be called to delete a handle */
void (*destroy)(void* handle);
  /*! \brief function to be called when cloning a handle */
void* (*clone)(void* handle);
/*!
* \brief Register type
 * \tparam T The type to be registered.
* \return The registered vtable.
*/
template <typename T>
static inline ExtTypeVTable* Register_();
/*!
* \brief Get a vtable based on type code.
* \param type_code The type code
* \return The registered vtable.
*/
TVM_DLL static ExtTypeVTable* Get(int type_code);
private:
// Internal registration function.
TVM_DLL static ExtTypeVTable* RegisterInternal(int type_code, const ExtTypeVTable& vt);
};
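/*!
 * \brief Illustrative sketch of exposing a user-defined type as an extension type;
 *  MyVector and the type code value are assumptions made for the example (real code
 *  should pick an unused code >= kExtBegin, typically via the registration macro in
 *  packed_func_ext.h).
 *
 * \code
 * struct MyVector {
 *   std::vector<float> data;
 * };
 * // declare the trait so MyVector can travel through PackedFunc
 * template<>
 * struct extension_class_info<MyVector> {
 *   static const int code = 17;  // assumed unused extension type code
 * };
 * // register clone/destroy hooks once, e.g. at static initialization time
 * static ExtTypeVTable* my_vector_vtable = ExtTypeVTable::Register_<MyVector>();
 * \endcode
 */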
/*!
* \brief Internal base class to
* handle conversion to POD values.
*/
class TVMPODValue_ {
public:
operator double() const {
// Allow automatic conversion from int to float
    // This avoids errors when the user passes in an int from
// the frontend while the API expects a float.
if (type_code_ == kDLInt) {
return static_cast<double>(value_.v_int64);
}
TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
return value_.v_float64;
}
operator int64_t() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64;
}
operator uint64_t() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64;
}
operator int() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
CHECK_LE(value_.v_int64,
std::numeric_limits<int>::max());
return static_cast<int>(value_.v_int64);
}
operator bool() const {
TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
return value_.v_int64 != 0;
}
operator void*() const {
if (type_code_ == kNull) return nullptr;
if (type_code_ == kArrayHandle) return value_.v_handle;
TVM_CHECK_TYPE_CODE(type_code_, kHandle);
return value_.v_handle;
}
operator DLTensor*() const {
if (type_code_ == kArrayHandle ||
type_code_ == kNDArrayContainer) {
return static_cast<DLTensor*>(value_.v_handle);
} else {
if (type_code_ == kNull) return nullptr;
LOG(FATAL) << "Expected "
<< "DLTensor* or NDArray but get "
<< TypeCode2Str(type_code_);
return nullptr;
}
}
operator NDArray() const {
if (type_code_ == kNull) return NDArray();
TVM_CHECK_TYPE_CODE(type_code_, kNDArrayContainer);
return NDArray(static_cast<NDArray::Container*>(value_.v_handle));
}
operator TVMContext() const {
TVM_CHECK_TYPE_CODE(type_code_, kTVMContext);
return value_.v_ctx;
}
template<typename TExtension>
const TExtension& AsExtension() const {
CHECK_LT(type_code_, kExtEnd);
return static_cast<TExtension*>(value_.v_handle)[0];
}
int type_code() const {
return type_code_;
}
/*!
* \brief return handle as specific pointer type.
* \tparam T the data type.
* \return The pointer type.
*/
template<typename T>
T* ptr() const {
return static_cast<T*>(value_.v_handle);
}
protected:
friend class TVMArgsSetter;
friend class TVMRetValue;
TVMPODValue_() : type_code_(kNull) {}
TVMPODValue_(TVMValue value, int type_code)
: value_(value), type_code_(type_code) {}
/*! \brief The value */
TVMValue value_;
/*! \brief the type code */
int type_code_;
};
/*!
 * \brief A single argument value to PackedFunc.
 *  Contains both the type_code and the TVMValue.
*
* Provides utilities to do type cast into other types.
*/
class TVMArgValue : public TVMPODValue_ {
public:
/*! \brief default constructor */
TVMArgValue() {}
/*!
* \brief constructor
 * \param value The argument value.
* \param type_code The type code.
*/
TVMArgValue(TVMValue value, int type_code)
: TVMPODValue_(value, type_code) {
}
// reuse converter from parent
using TVMPODValue_::operator double;
using TVMPODValue_::operator int64_t;
using TVMPODValue_::operator uint64_t;
using TVMPODValue_::operator int;
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator NDArray;
using TVMPODValue_::operator TVMContext;
// conversion operator.
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
return std::string(arr->data, arr->size);
} else {
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str);
}
}
operator TVMType() const {
if (type_code_ == kStr) {
return String2TVMType(operator std::string());
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type;
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
const TVMValue& value() const {
return value_;
}
// Deferred extension handler.
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef,
typename = typename std::enable_if<
std::is_class<TNodeRef>::value>::type>
inline bool IsNodeType() const;
// get internal node ptr, if it is node
inline std::shared_ptr<Node>& node_sptr();
};
/*!
 * \brief Return value container.
 *  Unlike TVMArgValue, which only holds a reference and does not delete
 *  the underlying container during destruction,
 *  TVMRetValue holds the value and will manage the underlying containers
 *  when it stores a complicated data type.
*/
class TVMRetValue : public TVMPODValue_ {
public:
/*! \brief default constructor */
TVMRetValue() {}
/*!
 * \brief move constructor from another return value.
* \param other The other return value.
*/
TVMRetValue(TVMRetValue&& other)
: TVMPODValue_(other.value_, other.type_code_) {
other.value_.v_handle = nullptr;
other.type_code_ = kNull;
}
/*! \brief destructor */
~TVMRetValue() {
this->Clear();
}
// reuse converter from parent
using TVMPODValue_::operator double;
using TVMPODValue_::operator int64_t;
using TVMPODValue_::operator uint64_t;
using TVMPODValue_::operator int;
using TVMPODValue_::operator bool;
using TVMPODValue_::operator void*;
using TVMPODValue_::operator DLTensor*;
using TVMPODValue_::operator TVMContext;
using TVMPODValue_::operator NDArray;
  // Copy constructor: deep-copies the stored value via Assign.
TVMRetValue(const TVMRetValue& other) {
this->Assign(other);
}
// conversion operators
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
return *ptr<std::string>();
}
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>();
}
operator TVMType() const {
if (type_code_ == kStr) {
return String2TVMType(operator std::string());
}
TVM_CHECK_TYPE_CODE(type_code_, kTVMType);
return value_.v_type;
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
// Assign operators
TVMRetValue& operator=(TVMRetValue&& other) {
this->Clear();
value_ = other.value_;
type_code_ = other.type_code_;
other.type_code_ = kNull;
return *this;
}
TVMRetValue& operator=(double value) {
this->SwitchToPOD(kDLFloat);
value_.v_float64 = value;
return *this;
}
TVMRetValue& operator=(std::nullptr_t value) {
this->SwitchToPOD(kNull);
value_.v_handle = value;
return *this;
}
TVMRetValue& operator=(void* value) {
this->SwitchToPOD(kHandle);
value_.v_handle = value;
return *this;
}
TVMRetValue& operator=(int64_t value) {
this->SwitchToPOD(kDLInt);
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(int value) {
this->SwitchToPOD(kDLInt);
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(TVMType t) {
this->SwitchToPOD(kTVMType);
value_.v_type = t;
return *this;
}
TVMRetValue& operator=(bool value) {
this->SwitchToPOD(kDLInt);
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(std::string value) {
this->SwitchToClass(kStr, value);
return *this;
}
TVMRetValue& operator=(TVMByteArray value) {
this->SwitchToClass(kBytes, std::string(value.data, value.size));
return *this;
}
TVMRetValue& operator=(NDArray other) {
this->Clear();
type_code_ = kNDArrayContainer;
value_.v_handle = other.data_;
other.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
}
template<typename FType>
TVMRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed());
}
TVMRetValue& operator=(Module m) {
this->SwitchToClass(kModuleHandle, m);
return *this;
}
  TVMRetValue& operator=(const TVMRetValue& other) {  // NOLINT(*)
this->Assign(other);
return *this;
}
TVMRetValue& operator=(const TVMArgValue& other) {
this->Assign(other);
return *this;
}
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
return *this;
}
/*!
* \brief Move the value back to front-end via C API.
* This marks the current container as null.
 * The managed resources are moved to the front-end and
 * the front-end should take charge of managing them.
*
* \param ret_value The return value.
* \param ret_type_code The return type code.
*/
void MoveToCHost(TVMValue* ret_value,
int* ret_type_code) {
    // cannot move str; needs special handling.
CHECK(type_code_ != kStr && type_code_ != kBytes);
*ret_value = value_;
*ret_type_code = type_code_;
type_code_ = kNull;
}
/*! \return The value field, if the data is POD */
const TVMValue& value() const {
CHECK(type_code_ != kNodeHandle &&
type_code_ != kFuncHandle &&
type_code_ != kModuleHandle &&
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_;
}
  // NodeRef related extensions: in tvm/packed_func_ext.h
template<typename T,
typename = typename std::enable_if<
std::is_class<T>::value>::type>
inline operator T() const;
template<typename TNodeRef>
inline TNodeRef AsNodeRef() const;
inline TVMRetValue& operator=(const NodeRef& other);
inline TVMRetValue& operator=(const std::shared_ptr<Node>& other);
private:
template<typename T>
void Assign(const T& other) {
switch (other.type_code()) {
case kStr: {
SwitchToClass<std::string>(kStr, other);
break;
}
case kBytes: {
SwitchToClass<std::string>(kBytes, other);
break;
}
case kFuncHandle: {
SwitchToClass<PackedFunc>(kFuncHandle, other);
break;
}
case kModuleHandle: {
SwitchToClass<Module>(kModuleHandle, other);
break;
}
case kNDArrayContainer: {
*this = other.operator NDArray();
break;
}
case kNodeHandle: {
SwitchToClass<std::shared_ptr<Node> >(
kNodeHandle, *other.template ptr<std::shared_ptr<Node> >());
break;
}
default: {
if (other.type_code() < kExtBegin) {
SwitchToPOD(other.type_code());
value_ = other.value_;
} else {
#if TVM_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
#else
this->Clear();
type_code_ = other.type_code();
value_.v_handle =
(*(ExtTypeVTable::Get(other.type_code())->clone))(
other.value().v_handle);
#endif
}
break;
}
}
}
// get the internal container.
void SwitchToPOD(int type_code) {
if (type_code_ != type_code) {
this->Clear();
type_code_ = type_code;
}
}
template<typename T>
void SwitchToClass(int type_code, T v) {
if (type_code_ != type_code) {
this->Clear();
type_code_ = type_code;
value_.v_handle = new T(v);
} else {
*static_cast<T*>(value_.v_handle) = v;
}
}
void Clear() {
if (type_code_ == kNull) return;
switch (type_code_) {
case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break;
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
case kNDArrayContainer: {
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
break;
}
}
if (type_code_ > kExtBegin) {
#if TVM_RUNTIME_HEADER_ONLY
LOG(FATAL) << "Header only mode do not support ext type";
#else
(*(ExtTypeVTable::Get(type_code_)->destroy))(value_.v_handle);
#endif
}
type_code_ = kNull;
}
};
// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kStr: return "str";
case kBytes: return "bytes";
case kHandle: return "handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle";
case kArrayHandle: return "ArrayHandle";
case kTVMType: return "TVMType";
case kTVMContext: return "TVMContext";
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
case kNDArrayContainer: return "NDArrayContainer";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
os << TypeCode2Str(t.code);
if (t.code == kHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}
#endif
inline std::string TVMType2String(TVMType t) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os;
os << t;
return os.str();
#else
std::string repr = "";
repr += TypeCode2Str(t.code);
if (t.code == kHandle) return repr;
repr += std::to_string(static_cast<int>(t.bits));
if (t.lanes != 1) {
repr += "x" + std::to_string(static_cast<int>(t.lanes));
}
return repr;
#endif
}
inline TVMType String2TVMType(std::string s) {
TVMType t;
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, nullptr, 10));
}
return t;
}
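/*!
 * \brief Illustrative round trip through the conversions implemented above
 *  (added example, not part of the original header).
 *
 * \code
 * TVMType t = String2TVMType("float32x4");  // code=kDLFloat, bits=32, lanes=4
 * std::string s = TVMType2String(t);        // "float32x4"
 * TVMType h = String2TVMType("handle");     // code=kHandle, bits=64, lanes=1
 * \endcode
 */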
inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
      << "not enough arguments passed, "
      << num_args << " passed"
      << " but requesting arg[" << i << "].";
return TVMArgValue(values[i], type_codes[i]);
}
inline int TVMArgs::size() const {
return num_args;
}
inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const {
body_(args, rv);
}
inline PackedFunc::FType PackedFunc::body() const {
return body_;
}
// internal namespace
namespace detail {
template<bool stop, std::size_t I, typename F>
struct for_each_dispatcher {
template<typename T, typename ...Args>
static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*)
f(I, std::forward<T>(value));
for_each_dispatcher<sizeof...(Args) == 0, (I+1), F>
::run(f, std::forward<Args>(args)...);
}
};
template<std::size_t I, typename F>
struct for_each_dispatcher<true, I, F> {
static void run(const F& f) {} // NOLINT(*)
};
template<typename F, typename ...Args>
inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F>
::run(f, std::forward<Args>(args)...);
}
} // namespace detail
/*! \brief Argument setter to PackedFunc. */
class TVMArgsSetter {
public:
TVMArgsSetter(TVMValue* values, int* type_codes)
: values_(values), type_codes_(type_codes) {}
// setters for POD types
template<typename T,
typename = typename std::enable_if<
std::is_integral<T>::value>::type>
void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kDLInt;
}
void operator()(size_t i, uint64_t value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
CHECK_LE(value,
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
type_codes_[i] = kDLInt;
}
void operator()(size_t i, double value) const {
values_[i].v_float64 = value;
type_codes_[i] = kDLFloat;
}
void operator()(size_t i, std::nullptr_t value) const {
values_[i].v_handle = value;
type_codes_[i] = kNull;
}
void operator()(size_t i, const TVMArgValue& value) const {
values_[i] = value.value_;
type_codes_[i] = value.type_code_;
}
void operator()(size_t i, void* value) const {
values_[i].v_handle = value;
type_codes_[i] = kHandle;
}
void operator()(size_t i, DLTensor* value) const {
values_[i].v_handle = value;
type_codes_[i] = kArrayHandle;
}
void operator()(size_t i, TVMContext value) const {
values_[i].v_ctx = value;
type_codes_[i] = kTVMContext;
}
void operator()(size_t i, TVMType value) const {
values_[i].v_type = value;
type_codes_[i] = kTVMType;
}
void operator()(size_t i, const char* value) const {
values_[i].v_str = value;
type_codes_[i] = kStr;
}
// setters for container type
  // They must be references (instead of const refs)
  // to make sure they stay alive in the tuple (instead of getting converted)
void operator()(size_t i, const std::string& value) const { // NOLINT(*)
values_[i].v_str = value.c_str();
type_codes_[i] = kStr;
}
void operator()(size_t i, const TVMByteArray& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<TVMByteArray*>(&value);
type_codes_[i] = kBytes;
}
void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle;
}
template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
operator()(i, value.packed());
}
void operator()(size_t i, const Module& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<Module*>(&value);
type_codes_[i] = kModuleHandle;
}
void operator()(size_t i, const NDArray& value) const { // NOLINT(*)
values_[i].v_handle = value.data_;
type_codes_[i] = kNDArrayContainer;
}
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str();
type_codes_[i] = kStr;
} else {
CHECK_NE(value.type_code(), kBytes) << "not handled.";
values_[i] = value.value_;
type_codes_[i] = value.type_code();
}
}
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
  // NodeRef related extensions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
private:
  /*! \brief The value fields */
TVMValue* values_;
/*! \brief The type code fields */
int* type_codes_;
};
template<typename... Args>
inline TVMRetValue PackedFunc::operator()(Args&& ...args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
detail::for_each(TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
TVMRetValue rv;
body_(TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
namespace detail {
template<typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
unpack_call_dispatcher<R, nleft - 1, index + 1, F>
::run(f, args_pack, rv,
std::forward<Args>(unpacked_args)...,
args_pack[index]);
}
};
template<typename R, int index, typename F>
struct unpack_call_dispatcher<R, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
*rv = R(f(std::forward<Args>(unpacked_args)...));
}
};
template<int index, typename F>
struct unpack_call_dispatcher<void, 0, index, F> {
template<typename ...Args>
static void run(const F& f,
const TVMArgs& args_pack,
TVMRetValue* rv,
Args&&... unpacked_args) {
f(std::forward<Args>(unpacked_args)...);
}
};
template<typename R, int nargs, typename F>
inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) {
unpack_call_dispatcher<R, nargs, 0, F>::run(f, args, rv);
}
template<typename R, typename ...Args>
inline R call_packed(const PackedFunc& pf, Args&& ...args) {
return R(pf(std::forward<Args>(args)...));
}
template<typename R>
struct typed_packed_call_dispatcher {
template<typename ...Args>
static inline R run(const PackedFunc& pf, Args&& ...args) {
return pf(std::forward<Args>(args)...);
}
};
template<>
struct typed_packed_call_dispatcher<void> {
template<typename ...Args>
static inline void run(const PackedFunc& pf, Args&& ...args) {
pf(std::forward<Args>(args)...);
}
};
} // namespace detail
template<typename R, typename ...Args>
TypedPackedFunc<R(Args...)>::TypedPackedFunc(PackedFunc packed)
: packed_(packed) {}
template<typename R, typename ...Args>
template<typename FType>
inline void TypedPackedFunc<R(Args...)>::AssignTypedLambda(FType flambda) {
packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) {
detail::unpack_call<R, sizeof...(Args)>(flambda, args, rv);
});
}
template<typename R, typename ...Args>
inline R TypedPackedFunc<R(Args...)>::operator()(Args... args) const {
return detail::typed_packed_call_dispatcher<R>
::run(packed_, std::forward<Args>(args)...);
}
// extension and node type handling
namespace detail {
template<typename T, typename TSrc, bool is_ext>
struct TVMValueCast {
static T Apply(const TSrc* self) {
return self->template AsNodeRef<T>();
}
};
template<typename T, typename TSrc>
struct TVMValueCast<T, TSrc, true> {
static T Apply(const TSrc* self) {
return self->template AsExtension<T>();
}
};
} // namespace detail
template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue, extension_class_info<T>::code != 0>
::Apply(this);
}
template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue, extension_class_info<T>::code != 0>
::Apply(this);
}
template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
  static_assert(extension_class_info<T>::code != 0,
                "Need to have extension code");
type_codes_[i] = extension_class_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}
// extension type handling
template<typename T>
struct ExtTypeInfo {
static void destroy(void* handle) {
delete static_cast<T*>(handle);
}
static void* clone(void* handle) {
return new T(*static_cast<T*>(handle));
}
};
template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
return ExtTypeVTable::RegisterInternal(code, vt);
}
// Implement Module::GetFunction
// Put the implementation in this file so that the PackedFunc definition has been seen
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
PackedFunc pf = node_->GetFunction(name, node_);
if (pf != nullptr) return pf;
if (query_imports) {
for (const Module& m : node_->imports_) {
pf = m.node_->GetFunction(name, m.node_);
if (pf != nullptr) return pf;
}
}
return pf;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_PACKED_FUNC_H_