Unverified commit cbd55eb1, authored by Zihao Ye, committed by GitHub
Browse files

[performance] Batch DGLGraph in C++ end. (#2155)



* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* fix

* upd
Co-authored-by: VoVAllen <jz1749@nyu.edu>
parent ac570c1d
...@@ -145,6 +145,7 @@ def to_backend_ctx(dglctx): ...@@ -145,6 +145,7 @@ def to_backend_ctx(dglctx):
def astype(input, ty): def astype(input, ty):
with tf.device(input.device):
return tf.cast(input, dtype=ty) return tf.cast(input, dtype=ty)
......
"""Utilities for batching/unbatching graphs.""" """Utilities for batching/unbatching graphs."""
from collections.abc import Mapping from collections.abc import Mapping
from collections import defaultdict
from . import backend as F from . import backend as F
from .base import ALL, is_all, DGLError, dgl_warning from .base import ALL, is_all, DGLError, dgl_warning
from .heterograph_index import disjoint_union
from .heterograph import DGLHeteroGraph
from . import convert from . import convert
from . import utils from . import utils
__all__ = ['batch', 'unbatch', 'batch_hetero', 'unbatch_hetero'] __all__ = ['batch', 'unbatch', 'batch_hetero', 'unbatch_hetero']
def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None): def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
...@@ -156,61 +158,44 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None): ...@@ -156,61 +158,44 @@ def batch(graphs, ndata=ALL, edata=ALL, *, node_attrs=None, edge_attrs=None):
dgl_warning('Arguments edge_attrs has been deprecated. Please use' dgl_warning('Arguments edge_attrs has been deprecated. Please use'
' edata instead.') ' edata instead.')
edata = edge_attrs edata = edge_attrs
if not (is_all(ndata) or isinstance(ndata, list)): if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format( raise DGLError('Invalid argument ndata: must be a string list but got {}.'.format(
type(ndata))) type(ndata)))
if not (is_all(edata) or isinstance(edata, list)): if not (is_all(edata) or isinstance(edata, list) or edata is None):
raise DGLError('Invalid argument edata: must be a string list but got {}.'.format( raise DGLError('Invalid argument edata: must be a string list but got {}.'.format(
type(edata))) type(edata)))
if any(g.is_block for g in graphs): if any(g.is_block for g in graphs):
raise DGLError("Batching a block is not supported.") raise DGLError("Batching a block is not supported.")
utils.check_all_same_device(graphs, 'graphs') relations = list(sorted(graphs[0].canonical_etypes))
utils.check_all_same_idtype(graphs, 'graphs') ntypes = list(sorted(graphs[0].ntypes))
relations = graphs[0].canonical_etypes etypes = [etype for _, etype, _ in relations]
ntypes = graphs[0].ntypes
idtype = graphs[0].idtype gidx = disjoint_union(graphs[0]._graph.metagraph, [g._graph for g in graphs])
device = graphs[0].device retg = DGLHeteroGraph(gidx, ntypes, etypes)
# Batch graph structure for each relation graph
edge_dict = defaultdict(list)
num_nodes_dict = defaultdict(int)
for g in graphs:
for rel in relations:
srctype, etype, dsttype = rel
u, v = g.edges(order='eid', etype=rel)
src = u + num_nodes_dict[srctype]
dst = v + num_nodes_dict[dsttype]
edge_dict[rel].append((src, dst))
for ntype in ntypes:
num_nodes_dict[ntype] += g.number_of_nodes(ntype)
for rel in relations:
src, dst = zip(*edge_dict[rel])
edge_dict[rel] = (F.cat(src, 0), F.cat(dst, 0))
retg = convert.heterograph(edge_dict, num_nodes_dict, idtype=idtype, device=device)
# Compute batch num nodes # Compute batch num nodes
bnn = {} bnn = {}
for ntype in graphs[0].ntypes: for ntype in ntypes:
bnn[ntype] = F.cat([g.batch_num_nodes(ntype) for g in graphs], 0) bnn[ntype] = F.cat([g.batch_num_nodes(ntype) for g in graphs], 0)
retg.set_batch_num_nodes(bnn) retg.set_batch_num_nodes(bnn)
# Compute batch num edges # Compute batch num edges
bne = {} bne = {}
for etype in graphs[0].canonical_etypes: for etype in relations:
bne[etype] = F.cat([g.batch_num_edges(etype) for g in graphs], 0) bne[etype] = F.cat([g.batch_num_edges(etype) for g in graphs], 0)
retg.set_batch_num_edges(bne) retg.set_batch_num_edges(bne)
# Batch node feature # Batch node feature
if ndata is not None: if ndata is not None:
for ntype in graphs[0].ntypes: for ntype in ntypes:
feat_dicts = [g.nodes[ntype].data for g in graphs if g.number_of_nodes(ntype) > 0] feat_dicts = [g.nodes[ntype].data for g in graphs if g.number_of_nodes(ntype) > 0]
ret_feat = _batch_feat_dicts(feat_dicts, ndata, 'nodes["{}"].data'.format(ntype)) ret_feat = _batch_feat_dicts(feat_dicts, ndata, 'nodes["{}"].data'.format(ntype))
retg.nodes[ntype].data.update(ret_feat) retg.nodes[ntype].data.update(ret_feat)
# Batch edge feature # Batch edge feature
if edata is not None: if edata is not None:
for etype in graphs[0].canonical_etypes: for etype in relations:
feat_dicts = [g.edges[etype].data for g in graphs if g.number_of_edges(etype) > 0] feat_dicts = [g.edges[etype].data for g in graphs if g.number_of_edges(etype) > 0]
ret_feat = _batch_feat_dicts(feat_dicts, edata, 'edges[{}].data'.format(etype)) ret_feat = _batch_feat_dicts(feat_dicts, edata, 'edges[{}].data'.format(etype))
retg.edges[etype].data.update(ret_feat) retg.edges[etype].data.update(ret_feat)
......
...@@ -202,7 +202,6 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) { ...@@ -202,7 +202,6 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
} }
NDArray Concat(const std::vector<IdArray>& arrays) { NDArray Concat(const std::vector<IdArray>& arrays) {
CHECK(arrays.size() > 1) << "Number of arrays should larger than 1";
IdArray ret; IdArray ret;
int64_t len = 0, offset = 0; int64_t len = 0, offset = 0;
......
...@@ -23,7 +23,9 @@ namespace cuda { ...@@ -23,7 +23,9 @@ namespace cuda {
* and is also power of two. * and is also power of two.
*/ */
inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) { inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {
CHECK_NE(dim, 0); CHECK_GE(dim, 0);
if (dim == 0)
return 1;
int ret = max_nthrs; int ret = max_nthrs;
while (ret > dim) { while (ret > dim) {
ret = ret >> 1; ret = ret >> 1;
......
...@@ -10,8 +10,6 @@ namespace dgl { ...@@ -10,8 +10,6 @@ namespace dgl {
namespace aten { namespace aten {
///////////////////////// COO Based Operations///////////////////////// ///////////////////////// COO Based Operations/////////////////////////
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) { COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
CHECK(coos.size() > 1) <<
"The length of input COOMatrix vector should be larger than 1";
uint64_t src_offset = 0, dst_offset = 0; uint64_t src_offset = 0, dst_offset = 0;
int64_t edge_data_offset = 0; int64_t edge_data_offset = 0;
bool has_data = false; bool has_data = false;
...@@ -114,8 +112,6 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes( ...@@ -114,8 +112,6 @@ std::vector<COOMatrix> DisjointPartitionCooBySizes(
///////////////////////// CSR Based Operations///////////////////////// ///////////////////////// CSR Based Operations/////////////////////////
CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) { CSRMatrix DisjointUnionCsr(const std::vector<CSRMatrix>& csrs) {
CHECK(csrs.size() > 1) <<
"The length of input CSRMatrix vector should be larger than 1";
uint64_t src_offset = 0, dst_offset = 0; uint64_t src_offset = 0, dst_offset = 0;
int64_t indices_offset = 0; int64_t indices_offset = 0;
bool has_data = false; bool has_data = false;
......
...@@ -124,6 +124,10 @@ void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPt ...@@ -124,6 +124,10 @@ void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPt
for (const auto &rg : rel_graphs) { for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type."; CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
} }
auto ctx = rel_graphs[0]->Context();
for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->Context(), ctx) << "Each relation graph must have the same context.";
}
} }
std::vector<int64_t> std::vector<int64_t>
......
...@@ -559,28 +559,6 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v ...@@ -559,28 +559,6 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v
*rv = ret_list; *rv = ret_list;
}); });
// C API entry point: takes a metagraph (args[0]) and a list of heterographs
// (args[1]) and returns their disjoint union. All graphs must use the same id
// bit width; the template instantiation is selected from that width.
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    GraphRef metagraph_ref = args[0];
    List<HeteroGraphRef> inputs = args[1];
    CHECK(inputs.size() > 0)
      << "Expect graph list has at least one graph";

    // The first graph's id width is the reference every other graph must match.
    const int64_t bits = inputs[0]->NumBits();

    std::vector<HeteroGraphPtr> input_ptrs;
    input_ptrs.reserve(inputs.size());
    for (const auto& g : inputs) {
      CHECK_EQ(g->NumBits(), bits)
        << "Expect graphs to batch have the same index dtype(int" << bits
        << "), but got int" << g->NumBits();
      input_ptrs.push_back(g.sptr());
    }

    // Dispatch on the id width and hand the result back through rv.
    ATEN_ID_BITS_SWITCH(bits, IdType, {
      HeteroGraphPtr unioned =
        DisjointUnionHeteroGraph<IdType>(metagraph_ref.sptr(), input_ptrs);
      *rv = HeteroGraphRef(unioned);
    });
  });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
......
...@@ -91,33 +91,34 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -91,33 +91,34 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges()); std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0); std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);
// Loop over all ntypes
for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) {
uint64_t offset = 0;
for (const auto &cg : component_graphs)
offset += cg->NumVertices(vtype);
num_nodes_per_type[vtype] = offset;
}
// Loop over all canonical etypes // Loop over all canonical etypes
for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
auto pair = meta_graph->FindEdge(etype); auto pair = meta_graph->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
uint64_t src_offset = 0, dst_offset = 0;
HeteroGraphPtr rgptr = nullptr; HeteroGraphPtr rgptr = nullptr;
const dgl_format_code_t code =\ const dgl_format_code_t code =\
component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats(); component_graphs[0]->GetRelationGraph(etype)->GetAllowedFormats();
// do some preprocess // do some preprocess
for (size_t i = 0; i < component_graphs.size(); ++i) { for (const auto &cg : component_graphs) {
const auto& cg = component_graphs[i];
const dgl_format_code_t cur_code = cg->GetRelationGraph(etype)->GetAllowedFormats(); const dgl_format_code_t cur_code = cg->GetRelationGraph(etype)->GetAllowedFormats();
if (cur_code != code) if (cur_code != code)
LOG(FATAL) << "All components should have the same formats"; LOG(FATAL) << "All components should have the same formats";
// Update offsets
src_offset += cg->NumVertices(src_vtype);
dst_offset += cg->NumVertices(dst_vtype);
} }
// prefer COO // prefer COO
if (FORMAT_HAS_COO(code)) { if (FORMAT_HAS_COO(code)) {
std::vector<aten::COOMatrix> coos; std::vector<aten::COOMatrix> coos;
for (size_t i = 0; i < component_graphs.size(); ++i) { for (const auto &cg : component_graphs) {
const auto& cg = component_graphs[i];
aten::COOMatrix coo = cg->GetCOOMatrix(etype); aten::COOMatrix coo = cg->GetCOOMatrix(etype);
coos.push_back(coo); coos.push_back(coo);
} }
...@@ -128,8 +129,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -128,8 +129,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
(src_vtype == dst_vtype) ? 1 : 2, res, code); (src_vtype == dst_vtype) ? 1 : 2, res, code);
} else if (FORMAT_HAS_CSR(code)) { } else if (FORMAT_HAS_CSR(code)) {
std::vector<aten::CSRMatrix> csrs; std::vector<aten::CSRMatrix> csrs;
for (size_t i = 0; i < component_graphs.size(); ++i) { for (const auto &cg : component_graphs) {
const auto& cg = component_graphs[i];
aten::CSRMatrix csr = cg->GetCSRMatrix(etype); aten::CSRMatrix csr = cg->GetCSRMatrix(etype);
csrs.push_back(csr); csrs.push_back(csr);
} }
...@@ -141,8 +141,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -141,8 +141,7 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
} else if (FORMAT_HAS_CSC(code)) { } else if (FORMAT_HAS_CSC(code)) {
// CSR and CSC have the same storage format, i.e. CSRMatrix // CSR and CSC have the same storage format, i.e. CSRMatrix
std::vector<aten::CSRMatrix> cscs; std::vector<aten::CSRMatrix> cscs;
for (size_t i = 0; i < component_graphs.size(); ++i) { for (const auto &cg : component_graphs) {
const auto& cg = component_graphs[i];
aten::CSRMatrix csc = cg->GetCSCMatrix(etype); aten::CSRMatrix csc = cg->GetCSCMatrix(etype);
cscs.push_back(csc); cscs.push_back(csc);
} }
...@@ -152,8 +151,6 @@ HeteroGraphPtr DisjointUnionHeteroGraph2( ...@@ -152,8 +151,6 @@ HeteroGraphPtr DisjointUnionHeteroGraph2(
(src_vtype == dst_vtype) ? 1 : 2, res, code); (src_vtype == dst_vtype) ? 1 : 2, res, code);
} }
rel_graphs[etype] = rgptr; rel_graphs[etype] = rgptr;
num_nodes_per_type[src_vtype] = src_offset;
num_nodes_per_type[dst_vtype] = dst_offset;
} }
return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type)); return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
...@@ -272,56 +269,6 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2( ...@@ -272,56 +269,6 @@ std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes2(
return rst; return rst;
} }
/*!
 * \brief Build the disjoint union of several heterographs.
 *
 * For every edge type of the metagraph, the edges of all components are
 * concatenated in input order, with each component's node ids shifted by the
 * number of nodes (of the matching node type) in the components before it.
 *
 * \tparam IdType Integer type used to read the edge id arrays. The arrays
 *         returned by Edges() are reinterpreted as IdType without a check
 *         here, so the caller must pick IdType to match the graphs' id width.
 * \param meta_graph Metagraph describing node/edge types. It is assumed to be
 *        shared by all components; this function does not verify that.
 * \param component_graphs Graphs to union; must be non-empty.
 * \return The batched heterograph.
 */
template <class IdType>
HeteroGraphPtr DisjointUnionHeteroGraph(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs) {
  CHECK_GT(component_graphs.size(), 0) << "Input graph list is empty";
  std::vector<HeteroGraphPtr> rel_graphs(meta_graph->NumEdges());
  std::vector<int64_t> num_nodes_per_type(meta_graph->NumVertices(), 0);

  // Count nodes for every node type up front. Deriving the counts only from
  // the endpoints of each relation would leave a node type that no relation
  // touches at 0 even when the components contain nodes of that type.
  for (dgl_type_t vtype = 0; vtype < meta_graph->NumVertices(); ++vtype) {
    int64_t total = 0;
    for (const auto& cg : component_graphs)
      total += static_cast<int64_t>(cg->NumVertices(vtype));
    num_nodes_per_type[vtype] = total;
  }

  // Loop over all canonical etypes
  for (dgl_type_t etype = 0; etype < meta_graph->NumEdges(); ++etype) {
    auto pair = meta_graph->FindEdge(etype);
    const dgl_type_t src_vtype = pair.first;
    const dgl_type_t dst_vtype = pair.second;

    // Size the output buffers once instead of growing them edge by edge.
    size_t total_edges = 0;
    for (const auto& cg : component_graphs)
      total_edges += cg->NumEdges(etype);
    std::vector<IdType> result_src, result_dst;
    result_src.reserve(total_edges);
    result_dst.reserve(total_edges);

    IdType src_offset = 0, dst_offset = 0;
    // Loop over all graphs
    for (const auto& cg : component_graphs) {
      EdgeArray edges = cg->Edges(etype);
      const size_t num_edges = cg->NumEdges(etype);
      const IdType* edges_src_data = static_cast<const IdType*>(edges.src->data);
      const IdType* edges_dst_data = static_cast<const IdType*>(edges.dst->data);

      // Loop over all edges
      for (size_t j = 0; j < num_edges; ++j) {
        // TODO(mufei): Should use array operations to implement this.
        result_src.push_back(edges_src_data[j] + src_offset);
        result_dst.push_back(edges_dst_data[j] + dst_offset);
      }
      // Shift the next component past the nodes consumed so far.
      src_offset += cg->NumVertices(src_vtype);
      dst_offset += cg->NumVertices(dst_vtype);
    }

    // After the loop the offsets equal the total src/dst node counts.
    rel_graphs[etype] = UnitGraph::CreateFromCOO(
      (src_vtype == dst_vtype) ? 1 : 2, src_offset, dst_offset,
      aten::VecToIdArray(result_src, sizeof(IdType) * 8),
      aten::VecToIdArray(result_dst, sizeof(IdType) * 8));
  }
  return CreateHeteroGraph(meta_graph, rel_graphs, std::move(num_nodes_per_type));
}

template HeteroGraphPtr DisjointUnionHeteroGraph<int32_t>(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
template HeteroGraphPtr DisjointUnionHeteroGraph<int64_t>(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& component_graphs);
template <class IdType> template <class IdType>
std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes( std::vector<HeteroGraphPtr> DisjointPartitionHeteroBySizes(
GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) { GraphPtr meta_graph, HeteroGraphPtr batched_graph, IdArray vertex_sizes, IdArray edge_sizes) {
......
...@@ -204,6 +204,7 @@ NDArray NDArray::Empty(std::vector<int64_t> shape, ...@@ -204,6 +204,7 @@ NDArray NDArray::Empty(std::vector<int64_t> shape,
// setup memory content // setup memory content
size_t size = GetDataSize(ret.data_->dl_tensor); size_t size = GetDataSize(ret.data_->dl_tensor);
size_t alignment = GetDataAlignment(ret.data_->dl_tensor); size_t alignment = GetDataAlignment(ret.data_->dl_tensor);
if (size > 0)
ret.data_->dl_tensor.data = ret.data_->dl_tensor.data =
DeviceAPI::Get(ret->ctx)->AllocDataSpace( DeviceAPI::Get(ret->ctx)->AllocDataSpace(
ret->ctx, size, alignment, ret->dtype); ret->ctx, size, alignment, ret->dtype);
......
...@@ -325,6 +325,6 @@ if __name__ == '__main__': ...@@ -325,6 +325,6 @@ if __name__ == '__main__':
#test_topology('int32') #test_topology('int32')
#test_batching_batched('int32') #test_batching_batched('int32')
#test_batched_features('int32') #test_batched_features('int32')
#test_empty_relation('int32') # test_empty_relation('int64')
#test_to_device('int32') #test_to_device('int32')
pass pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment