"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "fa1ae3b7c50c35e345ff22da02e84bc8b1351ac4"
Unverified commit 2647afc9 authored by nv-dlasalle, committed by GitHub

[Performance][Feature] Add `src_nodes` parameter to `to_block()` to avoid the cost of running unique() when available. (#2973)

* Add lhs_nodes as a parameter to to_block

* Update unit test

* Switch to simplified node conversion

* Switch lhs_nodes to be an in/out parameter

* Update docs
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent ad61a9a5
##
# Copyright 2019-2021 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module for graph transformation utilities.""" """Module for graph transformation utilities."""
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
...@@ -2053,7 +2068,7 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru ...@@ -2053,7 +2068,7 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru
return new_graphs return new_graphs
def to_block(g, dst_nodes=None, include_dst_in_src=True): def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None):
"""Convert a graph into a bipartite-structured *block* for message passing. """Convert a graph into a bipartite-structured *block* for message passing.
A block is a graph consisting of two sets of nodes: the A block is a graph consisting of two sets of nodes: the
@@ -2089,6 +2104,12 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
        (Default: True)
    src_nodes : Tensor or dict[str, Tensor], optional
        The list of source nodes (which must be prefixed by the destination
        nodes if `include_dst_in_src` is True).
        If a tensor is given, the graph must have only one node type.
    Returns
    -------
    DGLBlock
@@ -2215,15 +2236,36 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
        if g._graph.ctx != d.ctx:
            raise ValueError('g and dst_nodes need to have the same context.')
    src_node_ids = None
    src_node_ids_nd = None
    if src_nodes is not None and not isinstance(src_nodes, Mapping):
        # src_nodes is a Tensor, check if the g has only one type.
        if len(g.ntypes) > 1:
            raise DGLError(
                'Graph has more than one node type; please specify a dict for src_nodes.')
        src_nodes = {g.ntypes[0]: src_nodes}
    if src_nodes is not None:
        src_node_ids = [
            F.copy_to(F.tensor(src_nodes.get(ntype, []), dtype=g._idtype_str), \
                F.to_backend_ctx(g._graph.ctx)) \
            for ntype in g.ntypes]
        src_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in src_node_ids]
        for d in src_node_ids_nd:
            if g._graph.ctx != d.ctx:
                raise ValueError('g and src_nodes need to have the same context.')
    else:
        # use an empty list to signal we need to generate it
        src_node_ids_nd = []

    new_graph_index, src_nodes_ids_nd, induced_edges_nd = _CAPI_DGLToBlock(
        g._graph, dst_node_ids_nd, include_dst_in_src, src_node_ids_nd)
    # The new graph duplicates the original node types to SRC and DST sets.
    new_ntypes = (g.ntypes, g.ntypes)
    new_graph = DGLBlock(new_graph_index, new_ntypes, g.etypes)
    assert new_graph.is_unibipartite  # sanity check

    src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_ids_nd]
    edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges_nd]

    node_frames = utils.extract_node_subframes_for_block(g, src_node_ids, dst_node_ids)
......
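A minimal usage sketch of the new parameter (the toy graph and IDs below are hypothetical; it reuses a prior block's source IDs the same way the updated unit test does):

import dgl
import torch

# hypothetical two-type graph, IDs chosen only for illustration
g = dgl.heterograph({('A', 'r', 'B'): (torch.tensor([0, 1, 2]),
                                       torch.tensor([1, 1, 2]))})
dst = {'B': torch.tensor([1, 2])}

# without src_nodes, to_block() runs unique() over the edge endpoints
blk = dgl.to_block(g, dst_nodes=dst)

# when the (dst-prefixed) source nodes are already known, e.g. from a
# sampler, passing them skips that unique() pass
src = {nt: blk.srcnodes[nt].data[dgl.NID] for nt in blk.srctypes}
blk2 = dgl.to_block(g, dst_nodes=dst, src_nodes=src)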
/*!
 * Copyright 2020-2021 Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * \file graph/transform/cuda_to_block.cu
 * \brief Functions to convert a set of edges into a graph block with local
 * ids.
@@ -202,7 +215,8 @@ class DeviceNodeMapMaker {
  /**
   * \brief This function builds node maps for each node type, preserving the
   * order of the input nodes. Here it is assumed the lhs_nodes are not unique,
   * and thus a unique list is generated.
   *
   * \param lhs_nodes The set of source input nodes.
   * \param rhs_nodes The set of destination input nodes.
@@ -254,6 +268,49 @@ class DeviceNodeMapMaker {
    }
  }
  /**
   * \brief This function builds node maps for each node type, preserving the
   * order of the input nodes. Here it is assumed both lhs_nodes and rhs_nodes
   * are unique.
   *
   * \param lhs_nodes The set of source input nodes.
   * \param rhs_nodes The set of destination input nodes.
   * \param node_maps The node maps to be constructed.
   * \param stream The stream to operate on.
   */
  void Make(
      const std::vector<IdArray>& lhs_nodes,
      const std::vector<IdArray>& rhs_nodes,
      DeviceNodeMap<IdType> * const node_maps,
      cudaStream_t stream) {
    const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();

    // unique lhs nodes
    const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size());
    for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {
      const IdArray& nodes = lhs_nodes[ntype];
      if (nodes->shape[0] > 0) {
        CHECK_EQ(nodes->ctx.device_type, kDLGPU);
        node_maps->LhsHashTable(ntype).FillWithUnique(
            nodes.Ptr<IdType>(),
            nodes->shape[0],
            stream);
      }
    }

    // unique rhs nodes
    const int64_t rhs_num_ntypes = static_cast<int64_t>(rhs_nodes.size());
    for (int64_t ntype = 0; ntype < rhs_num_ntypes; ++ntype) {
      const IdArray& nodes = rhs_nodes[ntype];
      if (nodes->shape[0] > 0) {
        node_maps->RhsHashTable(ntype).FillWithUnique(
            nodes.Ptr<IdType>(),
            nodes->shape[0],
            stream);
      }
    }
  }
 private:
  IdType max_num_nodes_;
};
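Conceptually, the two Make() overloads differ only in whether a deduplication pass is needed. A plain-Python analogue of the node-map semantics (illustrative helper names, not DGL API):

def make_map_with_dedup(nodes):
    # generating path: input may contain duplicates; first-occurrence
    # order is preserved while assigning local IDs
    mapping = {}
    for nid in nodes:
        mapping.setdefault(nid, len(mapping))
    return mapping

def make_map_from_unique(nodes):
    # FillWithUnique path: the caller guarantees uniqueness, so local
    # IDs are assigned directly with no dedup pass
    return {nid: i for i, nid in enumerate(nodes)}

assert make_map_with_dedup([4, 3, 4, 2]) == make_map_from_unique([4, 3, 2])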
@@ -323,11 +380,15 @@ MapEdges(
// Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDLGPU.
template<typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlockGPU(
    HeteroGraphPtr graph,
    const std::vector<IdArray> &rhs_nodes,
    const bool include_rhs_in_lhs,
    std::vector<IdArray>* const lhs_nodes_ptr) {
  std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;
  const bool generate_lhs_nodes = lhs_nodes.empty();

  cudaStream_t stream = 0;
  const auto& ctx = graph->Context();
  auto device = runtime::DeviceAPI::Get(ctx);
@@ -363,80 +424,121 @@ ToBlockGPU(
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
    maxNodesPerType[ntype+num_ntypes] += rhs_nodes[ntype]->shape[0];

    if (generate_lhs_nodes) {
      if (include_rhs_in_lhs) {
        maxNodesPerType[ntype] += rhs_nodes[ntype]->shape[0];
      }
    } else {
      maxNodesPerType[ntype] += lhs_nodes[ntype]->shape[0];
    }
  }

  if (generate_lhs_nodes) {
    // we don't have lhs_nodes, so we need to count inbound edges to get an
    // upper bound
    for (int64_t etype = 0; etype < num_etypes; ++etype) {
      const auto src_dst_types = graph->GetEndpointTypes(etype);
      const dgl_type_t srctype = src_dst_types.first;
      if (edge_arrays[etype].src.defined()) {
        maxNodesPerType[srctype] += edge_arrays[etype].src->shape[0];
      }
    }
  }
  // gather lhs_nodes
  std::vector<IdArray> src_nodes(num_ntypes);
  if (generate_lhs_nodes) {
    std::vector<int64_t> src_node_offsets(num_ntypes, 0);
    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
      src_nodes[ntype] = NewIdArray(maxNodesPerType[ntype], ctx,
          sizeof(IdType)*8);
      if (include_rhs_in_lhs) {
        // place rhs nodes first
        device->CopyDataFromTo(rhs_nodes[ntype].Ptr<IdType>(), 0,
            src_nodes[ntype].Ptr<IdType>(), src_node_offsets[ntype],
            sizeof(IdType)*rhs_nodes[ntype]->shape[0],
            rhs_nodes[ntype]->ctx, src_nodes[ntype]->ctx,
            rhs_nodes[ntype]->dtype,
            stream);
        src_node_offsets[ntype] += sizeof(IdType)*rhs_nodes[ntype]->shape[0];
      }
    }
    for (int64_t etype = 0; etype < num_etypes; ++etype) {
      const auto src_dst_types = graph->GetEndpointTypes(etype);
      const dgl_type_t srctype = src_dst_types.first;
      if (edge_arrays[etype].src.defined()) {
        device->CopyDataFromTo(
            edge_arrays[etype].src.Ptr<IdType>(), 0,
            src_nodes[srctype].Ptr<IdType>(),
            src_node_offsets[srctype],
            sizeof(IdType)*edge_arrays[etype].src->shape[0],
            rhs_nodes[srctype]->ctx,
            src_nodes[srctype]->ctx,
            rhs_nodes[srctype]->dtype,
            stream);

        src_node_offsets[srctype] += sizeof(IdType)*edge_arrays[etype].src->shape[0];
      }
    }
  } else {
    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
      src_nodes[ntype] = lhs_nodes[ntype];
    }
  }
  // allocate space for map creation process
  DeviceNodeMapMaker<IdType> maker(maxNodesPerType);
  DeviceNodeMap<IdType> node_maps(maxNodesPerType, ctx, stream);

  if (generate_lhs_nodes) {
    lhs_nodes.reserve(num_ntypes);
    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
      lhs_nodes.emplace_back(NewIdArray(
          maxNodesPerType[ntype], ctx, sizeof(IdType)*8));
    }
  }

  std::vector<int64_t> num_nodes_per_type(num_ntypes*2);
  // populate RHS nodes from what we already know
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
    num_nodes_per_type[num_ntypes+ntype] = rhs_nodes[ntype]->shape[0];
  }
  // populate the mappings
  if (generate_lhs_nodes) {
    int64_t * count_lhs_device = static_cast<int64_t*>(
        device->AllocWorkspace(ctx, sizeof(int64_t)*num_ntypes*2));

    maker.Make(
        src_nodes,
        rhs_nodes,
        &node_maps,
        count_lhs_device,
        &lhs_nodes,
        stream);

    device->CopyDataFromTo(
        count_lhs_device, 0,
        num_nodes_per_type.data(), 0,
        sizeof(*num_nodes_per_type.data())*num_ntypes,
        ctx,
        DGLContext{kDLCPU, 0},
        DGLType{kDLInt, 64, 1},
        stream);
    device->StreamSync(ctx, stream);

    // wait for the node counts to finish transferring
    device->FreeWorkspace(ctx, count_lhs_device);
  } else {
    maker.Make(
        lhs_nodes,
        rhs_nodes,
        &node_maps,
        stream);

    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
      num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0];
    }
  }

  std::vector<IdArray> induced_edges;
  induced_edges.reserve(num_etypes);
@@ -460,32 +562,16 @@ ToBlockGPU(
  std::vector<HeteroGraphPtr> rel_graphs;
  rel_graphs.reserve(num_etypes);

  // map node numberings from global to local, and build pointer for CSR
  std::vector<IdArray> new_lhs;
  std::vector<IdArray> new_rhs;
  std::tie(new_lhs, new_rhs) = MapEdges(graph, edge_arrays, node_maps, stream);

  // resize lhs nodes
  if (generate_lhs_nodes) {
    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
      lhs_nodes[ntype]->shape[0] = num_nodes_per_type[ntype];
    }
  }
  // build the heterograph
@@ -514,27 +600,29 @@ ToBlockGPU(
      new_meta_graph, rel_graphs, num_nodes_per_type);

  // return the new graph, the new src nodes, and new edges
  return std::make_tuple(new_graph, induced_edges);
}

}  // namespace
template<>
std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlock<kDLGPU, int32_t>(
    HeteroGraphPtr graph,
    const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs,
    std::vector<IdArray>* const lhs_nodes) {
  return ToBlockGPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

template<>
std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlock<kDLGPU, int64_t>(
    HeteroGraphPtr graph,
    const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs,
    std::vector<IdArray>* const lhs_nodes) {
  return ToBlockGPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

}  // namespace transform
......
/*!
 * Copyright 2019-2021 Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * \file graph/transform/to_bipartite.cc
 * \brief Convert a graph to a bipartite-structured graph.
 */
@@ -15,8 +28,7 @@
#include <dgl/runtime/container.h>
#include <vector>
#include <tuple>
#include <utility>
#include "../../array/cpu/array_utils.h"

namespace dgl {
@@ -31,8 +43,12 @@ namespace {
// Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDLCPU.
template<typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr) {
  std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;
  const bool generate_lhs_nodes = lhs_nodes.empty();

  const int64_t num_etypes = graph->NumEdgeTypes();
  const int64_t num_ntypes = graph->NumVertexTypes();
  std::vector<EdgeArray> edge_arrays(num_etypes);
@@ -43,13 +59,16 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool inc
  const std::vector<IdHashMap<IdType>> rhs_node_mappings(rhs_nodes.begin(), rhs_nodes.end());
  std::vector<IdHashMap<IdType>> lhs_node_mappings;

  if (generate_lhs_nodes) {
    // build lhs_node_mappings -- if we don't have them already
    if (include_rhs_in_lhs)
      lhs_node_mappings = rhs_node_mappings;  // copy
    else
      lhs_node_mappings.resize(num_ntypes);
  } else {
    lhs_node_mappings = std::vector<IdHashMap<IdType>>(lhs_nodes.begin(), lhs_nodes.end());
  }

  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
@@ -57,11 +76,16 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool inc
    const dgl_type_t dsttype = src_dst_types.second;
    if (!aten::IsNullArray(rhs_nodes[dsttype])) {
      const EdgeArray& edges = graph->Edges(etype);
      if (generate_lhs_nodes) {
        lhs_node_mappings[srctype].Update(edges.src);
      }
      edge_arrays[etype] = edges;
    }
  }

  std::vector<int64_t> num_nodes_per_type;
  num_nodes_per_type.reserve(2 * num_ntypes);

  const auto meta_graph = graph->meta_graph();
  const EdgeArray etypes = meta_graph->Edges("eid");
  const IdArray new_dst = Add(etypes.dst, num_ntypes);
@@ -105,28 +129,34 @@ ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool inc
  const HeteroGraphPtr new_graph = CreateHeteroGraph(
      new_meta_graph, rel_graphs, num_nodes_per_type);

  if (generate_lhs_nodes) {
    CHECK_EQ(lhs_nodes.size(), 0) << "InternalError: lhs_nodes should be empty "
        "when generating them.";
    for (const IdHashMap<IdType> &lhs_map : lhs_node_mappings)
      lhs_nodes.push_back(lhs_map.Values());
  }
  return std::make_tuple(new_graph, induced_edges);
}

}  // namespace
template<>
std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlock<kDLCPU, int32_t>(HeteroGraphPtr graph,
                         const std::vector<IdArray> &rhs_nodes,
                         bool include_rhs_in_lhs,
                         std::vector<IdArray>* const lhs_nodes) {
  return ToBlockCPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

template<>
std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlock<kDLCPU, int64_t>(HeteroGraphPtr graph,
                         const std::vector<IdArray> &rhs_nodes,
                         bool include_rhs_in_lhs,
                         std::vector<IdArray>* const lhs_nodes) {
  return ToBlockCPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes);
}

DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
@@ -134,15 +164,16 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
    const HeteroGraphRef graph_ref = args[0];
    const std::vector<IdArray> &rhs_nodes = ListValueToVector<IdArray>(args[1]);
    const bool include_rhs_in_lhs = args[2];
    std::vector<IdArray> lhs_nodes = ListValueToVector<IdArray>(args[3]);

    HeteroGraphPtr new_graph;
    std::vector<IdArray> induced_edges;

    ATEN_XPU_SWITCH_CUDA(graph_ref->Context().device_type, XPU, "ToBlock", {
      ATEN_ID_TYPE_SWITCH(graph_ref->DataType(), IdType, {
        std::tie(new_graph, induced_edges) = ToBlock<XPU, IdType>(
            graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs,
            &lhs_nodes);
      });
    });
......
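The CPU logic above, reduced to a runnable Python sketch (hypothetical helper, not DGL internals): the source-node mapping is either rebuilt from the dst nodes plus edge sources, or taken verbatim from the caller-supplied lhs_nodes:

def build_lhs_mapping(edge_srcs, rhs_nodes, include_rhs_in_lhs, lhs_nodes=None):
    if lhs_nodes is not None:
        # caller supplied the source nodes: keep their order, assume uniqueness
        return {nid: i for i, nid in enumerate(lhs_nodes)}
    mapping = {}
    if include_rhs_in_lhs:
        for nid in rhs_nodes:       # dst nodes occupy the first local IDs
            mapping.setdefault(nid, len(mapping))
    for nid in edge_srcs:           # then newly seen edge sources, in order
        mapping.setdefault(nid, len(mapping))
    return mapping

# dst [1, 2] with edge sources [0, 1, 2] -> {1: 0, 2: 1, 0: 2}
print(build_lhs_mapping([0, 1, 2], [1, 2], True))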
/*!
 * Copyright 2021 Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * \file graph/transform/to_bipartite.h
 * \brief Functions to convert a set of edges into a graph block with local
 * ids.
 */

#ifndef DGL_GRAPH_TRANSFORM_TO_BIPARTITE_H_
@@ -16,10 +30,24 @@
namespace dgl {
namespace transform {
/**
 * @brief Create a graph block from the set of
 * src and dst nodes (lhs and rhs respectively).
 *
 * @tparam XPU The type of device to operate on.
 * @tparam IdType The type to use as an index.
 * @param graph The graph from which to extract the block.
 * @param rhs_nodes The destination nodes of the block.
 * @param include_rhs_in_lhs Whether or not to include the
 * destination nodes of the block in the source nodes.
 * @param [in/out] lhs_nodes The source nodes of the block; if empty,
 * it is populated with the generated source nodes.
 *
 * @return The block and the induced edges.
 */
template<DLDeviceType XPU, typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>>
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
    bool include_rhs_in_lhs, std::vector<IdArray>* lhs_nodes);
}  // namespace transform
}  // namespace dgl
......
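The lhs_nodes argument follows an in/out convention: an empty vector asks the callee to generate the source nodes and hand them back through the same pointer. A Python stand-in for just that calling contract (function name hypothetical):

def to_block_stub(rhs_nodes, include_rhs_in_lhs, lhs_nodes):
    # empty list signals "generate"; results flow back through the argument
    if not lhs_nodes:
        lhs_nodes.extend(rhs_nodes if include_rhs_in_lhs else [])
    return ('block', list(lhs_nodes))

src = []                       # ask the callee to compute the source nodes
to_block_stub([4, 3, 2], True, src)
print(src)                     # [4, 3, 2] -- generated and returned in place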
##
# Copyright 2019-2021 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from scipy import sparse as spsp
import networkx as nx
import numpy as np
@@ -935,6 +951,32 @@ def test_to_block(idtype):
    checkall(g, bg, dst_nodes)
    check_features(g, bg)
    # test specifying lhs_nodes with include_dst_in_src
    src_nodes = {}
    for ntype in dst_nodes.keys():
        # use the previous run to get the list of source nodes
        src_nodes[ntype] = bg.srcnodes[ntype].data[dgl.NID]
    bg = dgl.to_block(g, dst_nodes=dst_nodes, src_nodes=src_nodes)
    checkall(g, bg, dst_nodes)
    check_features(g, bg)

    # test without include_dst_in_src
    dst_nodes = {'A': F.tensor([4, 3, 2, 1], dtype=idtype), 'B': F.tensor([3, 5, 6, 1], dtype=idtype)}
    bg = dgl.to_block(g, dst_nodes=dst_nodes, include_dst_in_src=False)
    checkall(g, bg, dst_nodes, False)
    check_features(g, bg)

    # test specifying lhs_nodes without include_dst_in_src
    src_nodes = {}
    for ntype in dst_nodes.keys():
        # use the previous run to get the list of source nodes
        src_nodes[ntype] = bg.srcnodes[ntype].data[dgl.NID]
    bg = dgl.to_block(g, dst_nodes=dst_nodes, include_dst_in_src=False,
                      src_nodes=src_nodes)
    checkall(g, bg, dst_nodes, False)
    check_features(g, bg)
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
@parametrize_dtype
def test_remove_edges(idtype):
......