Unverified Commit 11d12f3c authored by peizhou001, committed by GitHub

[Refactor] Extract common code in GPU and CPU ToBlock (#5305)

parent 2238386a
graph/transform/cuda/cuda_to_block.cu

@@ -16,6 +16,8 @@
  * @file graph/transform/cuda/cuda_to_block.cu
  * @brief Functions to convert a set of edges into a graph block with local
  * ids.
+ *
+ * Tested via python wrapper: python/dgl/path/to/to_block.py
  */

 #include <cuda_runtime.h>
@@ -28,7 +30,7 @@
 #include "../../../runtime/cuda/cuda_common.h"
 #include "../../heterograph.h"
-#include "../to_bipartite.h"
+#include "../to_block.h"
 #include "cuda_map_edges.cuh"

 using namespace dgl::aten;
@@ -137,209 +139,74 @@ class DeviceNodeMapMaker {
   IdType max_num_nodes_;
 };

-template <typename IdType>
-std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU(
-    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
-    const bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr) {
-  std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;
-  const bool generate_lhs_nodes = lhs_nodes.empty();
-
-  const auto& ctx = graph->Context();
-  auto device = runtime::DeviceAPI::Get(ctx);
-  cudaStream_t stream = runtime::getCurrentCUDAStream();
-
-  CHECK_EQ(ctx.device_type, kDGLCUDA);
-  for (const auto& nodes : rhs_nodes) {
-    CHECK_EQ(ctx.device_type, nodes->ctx.device_type);
-  }
-
-  // Since DST nodes are included in SRC nodes, a common requirement is to fetch
-  // the DST node features from the SRC nodes features. To avoid expensive
-  // sparse lookup, the function assures that the DST nodes in both SRC and DST
-  // sets have the same ids. As a result, given the node feature tensor ``X`` of
-  // type ``utype``, the following code finds the corresponding DST node
-  // features of type ``vtype``:
-
-  const int64_t num_etypes = graph->NumEdgeTypes();
-  const int64_t num_ntypes = graph->NumVertexTypes();
-
-  CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
-      << "rhs_nodes not given for every node type";
-
-  std::vector<EdgeArray> edge_arrays(num_etypes);
-  for (int64_t etype = 0; etype < num_etypes; ++etype) {
-    const auto src_dst_types = graph->GetEndpointTypes(etype);
-    const dgl_type_t dsttype = src_dst_types.second;
-    if (!aten::IsNullArray(rhs_nodes[dsttype])) {
-      edge_arrays[etype] = graph->Edges(etype);
-    }
-  }
-
-  // count lhs and rhs nodes
-  std::vector<int64_t> maxNodesPerType(num_ntypes * 2, 0);
-  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-    maxNodesPerType[ntype + num_ntypes] += rhs_nodes[ntype]->shape[0];
-
-    if (generate_lhs_nodes) {
-      if (include_rhs_in_lhs) {
-        maxNodesPerType[ntype] += rhs_nodes[ntype]->shape[0];
-      }
-    } else {
-      maxNodesPerType[ntype] += lhs_nodes[ntype]->shape[0];
-    }
-  }
-  if (generate_lhs_nodes) {
-    // we don't have lhs_nodes, see we need to count inbound edges to get an
-    // upper bound
-    for (int64_t etype = 0; etype < num_etypes; ++etype) {
-      const auto src_dst_types = graph->GetEndpointTypes(etype);
-      const dgl_type_t srctype = src_dst_types.first;
-      if (edge_arrays[etype].src.defined()) {
-        maxNodesPerType[srctype] += edge_arrays[etype].src->shape[0];
-      }
-    }
-  }
-
-  // gather lhs_nodes
-  std::vector<IdArray> src_nodes(num_ntypes);
-  if (generate_lhs_nodes) {
-    std::vector<int64_t> src_node_offsets(num_ntypes, 0);
-    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-      src_nodes[ntype] =
-          NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8);
-      if (include_rhs_in_lhs) {
-        // place rhs nodes first
-        device->CopyDataFromTo(
-            rhs_nodes[ntype].Ptr<IdType>(), 0, src_nodes[ntype].Ptr<IdType>(),
-            src_node_offsets[ntype],
-            sizeof(IdType) * rhs_nodes[ntype]->shape[0], rhs_nodes[ntype]->ctx,
-            src_nodes[ntype]->ctx, rhs_nodes[ntype]->dtype);
-        src_node_offsets[ntype] += sizeof(IdType) * rhs_nodes[ntype]->shape[0];
-      }
-    }
-    for (int64_t etype = 0; etype < num_etypes; ++etype) {
-      const auto src_dst_types = graph->GetEndpointTypes(etype);
-      const dgl_type_t srctype = src_dst_types.first;
-      if (edge_arrays[etype].src.defined()) {
-        device->CopyDataFromTo(
-            edge_arrays[etype].src.Ptr<IdType>(), 0,
-            src_nodes[srctype].Ptr<IdType>(), src_node_offsets[srctype],
-            sizeof(IdType) * edge_arrays[etype].src->shape[0],
-            rhs_nodes[srctype]->ctx, src_nodes[srctype]->ctx,
-            rhs_nodes[srctype]->dtype);
-        src_node_offsets[srctype] +=
-            sizeof(IdType) * edge_arrays[etype].src->shape[0];
-      }
-    }
-  } else {
-    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-      src_nodes[ntype] = lhs_nodes[ntype];
-    }
-  }
-
-  // allocate space for map creation process
-  DeviceNodeMapMaker<IdType> maker(maxNodesPerType);
-  DeviceNodeMap<IdType> node_maps(maxNodesPerType, num_ntypes, ctx, stream);
-  if (generate_lhs_nodes) {
-    lhs_nodes.reserve(num_ntypes);
-    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-      lhs_nodes.emplace_back(
-          NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8));
-    }
-  }
-
-  std::vector<int64_t> num_nodes_per_type(num_ntypes * 2);
-  // populate RHS nodes from what we already know
-  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-    num_nodes_per_type[num_ntypes + ntype] = rhs_nodes[ntype]->shape[0];
-  }
-
-  // populate the mappings
-  if (generate_lhs_nodes) {
-    int64_t* count_lhs_device = static_cast<int64_t*>(
-        device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes * 2));
-
-    maker.Make(
-        src_nodes, rhs_nodes, &node_maps, count_lhs_device, &lhs_nodes, stream);
-
-    device->CopyDataFromTo(
-        count_lhs_device, 0, num_nodes_per_type.data(), 0,
-        sizeof(*num_nodes_per_type.data()) * num_ntypes, ctx,
-        DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});
-    device->StreamSync(ctx, stream);
-
-    // wait for the node counts to finish transferring
-    device->FreeWorkspace(ctx, count_lhs_device);
-  } else {
-    maker.Make(lhs_nodes, rhs_nodes, &node_maps, stream);
-
-    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-      num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0];
-    }
-  }
-
-  std::vector<IdArray> induced_edges;
-  induced_edges.reserve(num_etypes);
-  for (int64_t etype = 0; etype < num_etypes; ++etype) {
-    if (edge_arrays[etype].id.defined()) {
-      induced_edges.push_back(edge_arrays[etype].id);
-    } else {
-      induced_edges.push_back(
-          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
-    }
-  }
-
-  // build metagraph -- small enough to be done on CPU
-  const auto meta_graph = graph->meta_graph();
-  const EdgeArray etypes = meta_graph->Edges("eid");
-  const IdArray new_dst = Add(etypes.dst, num_ntypes);
-  const auto new_meta_graph =
-      ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);
-
-  // allocate vector for graph relations while GPU is busy
-  std::vector<HeteroGraphPtr> rel_graphs;
-  rel_graphs.reserve(num_etypes);
-
-  // map node numberings from global to local, and build pointer for CSR
-  std::vector<IdArray> new_lhs;
-  std::vector<IdArray> new_rhs;
-  std::tie(new_lhs, new_rhs) = MapEdges(graph, edge_arrays, node_maps, stream);
-
-  // resize lhs nodes
-  if (generate_lhs_nodes) {
-    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
-      lhs_nodes[ntype]->shape[0] = num_nodes_per_type[ntype];
-    }
-  }
-
-  // build the heterograph
-  for (int64_t etype = 0; etype < num_etypes; ++etype) {
-    const auto src_dst_types = graph->GetEndpointTypes(etype);
-    const dgl_type_t srctype = src_dst_types.first;
-    const dgl_type_t dsttype = src_dst_types.second;
-    if (rhs_nodes[dsttype]->shape[0] == 0) {
-      // No rhs nodes are given for this edge type. Create an empty graph.
-      rel_graphs.push_back(CreateFromCOO(
-          2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
-          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx),
-          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx)));
-    } else {
-      rel_graphs.push_back(CreateFromCOO(
-          2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
-          new_lhs[etype], new_rhs[etype]));
-    }
-  }
-
-  HeteroGraphPtr new_graph =
-      CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);
-
-  // return the new graph, the new src nodes, and new edges
-  return std::make_tuple(new_graph, induced_edges);
-}
+// Since partial specialization is not allowed for functions, use this as an
+// intermediate for ToBlock where XPU = kDGLCUDA.
+template <typename IdType>
+struct CUDAIdsMapper {
+  std::tuple<std::vector<IdArray>, std::vector<IdArray>> operator()(
+      const HeteroGraphPtr& graph, bool include_rhs_in_lhs, int64_t num_ntypes,
+      const DGLContext& ctx, const std::vector<int64_t>& maxNodesPerType,
+      const std::vector<EdgeArray>& edge_arrays,
+      const std::vector<IdArray>& src_nodes,
+      const std::vector<IdArray>& rhs_nodes,
+      std::vector<IdArray>* const lhs_nodes_ptr,
+      std::vector<int64_t>* const num_nodes_per_type_ptr) {
+    std::vector<IdArray>& lhs_nodes = *lhs_nodes_ptr;
+    std::vector<int64_t>& num_nodes_per_type = *num_nodes_per_type_ptr;
+    const bool generate_lhs_nodes = lhs_nodes.empty();
+    auto device = runtime::DeviceAPI::Get(ctx);
+    cudaStream_t stream = runtime::getCurrentCUDAStream();
+
+    // Allocate space for map creation process.
+    DeviceNodeMapMaker<IdType> maker(maxNodesPerType);
+    DeviceNodeMap<IdType> node_maps(maxNodesPerType, num_ntypes, ctx, stream);
+    if (generate_lhs_nodes) {
+      lhs_nodes.reserve(num_ntypes);
+      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
+        lhs_nodes.emplace_back(
+            NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8));
+      }
+    }
+    // Populate the mappings.
+    if (generate_lhs_nodes) {
+      int64_t* count_lhs_device = static_cast<int64_t*>(
+          device->AllocWorkspace(ctx, sizeof(int64_t) * num_ntypes * 2));
+
+      maker.Make(
+          src_nodes, rhs_nodes, &node_maps, count_lhs_device, &lhs_nodes,
+          stream);
+
+      device->CopyDataFromTo(
+          count_lhs_device, 0, num_nodes_per_type.data(), 0,
+          sizeof(*num_nodes_per_type.data()) * num_ntypes, ctx,
+          DGLContext{kDGLCPU, 0}, DGLDataType{kDGLInt, 64, 1});
+      device->StreamSync(ctx, stream);
+
+      // Wait for the node counts to finish transferring.
+      device->FreeWorkspace(ctx, count_lhs_device);
+    } else {
+      maker.Make(lhs_nodes, rhs_nodes, &node_maps, stream);
+
+      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
+        num_nodes_per_type[ntype] = lhs_nodes[ntype]->shape[0];
+      }
+    }
+    // Resize lhs nodes.
+    if (generate_lhs_nodes) {
+      for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
+        lhs_nodes[ntype]->shape[0] = num_nodes_per_type[ntype];
+      }
+    }
+    // Map node numberings from global to local, and build pointer for CSR.
+    return MapEdges(graph, edge_arrays, node_maps, stream);
+  }
+};
+
+template <typename IdType>
+std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockGPU(
+    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
+    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr) {
+  return dgl::transform::ProcessToBlock<IdType>(
+      graph, rhs_nodes, include_rhs_in_lhs, lhs_nodes_ptr,
+      CUDAIdsMapper<IdType>());
+}

 }  // namespace
...
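The comment introducing `CUDAIdsMapper` points at a classic C++ constraint: function templates cannot be partially specialized, so the device-specific half of `ToBlock` is packaged as a class-template functor and handed to the shared driver. Below is a minimal, self-contained sketch of that pattern; `GpuMapper` and `ProcessToBlockSketch` are hypothetical stand-ins for `CUDAIdsMapper` and `ProcessToBlock`, not DGL API:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for CUDAIdsMapper<IdType>: a class template can be
// specialized per device, unlike a function template.
template <typename IdType>
struct GpuMapper {
  IdType operator()(IdType id) const { return id * 2; }
};

// Shared driver: accepts any callable, so each backend injects its own
// mapper instead of partially specializing a function template.
template <typename IdType, typename MapperT>
IdType ProcessToBlockSketch(IdType id, MapperT&& mapper) {
  return mapper(id);
}

int main() {
  // Mirrors the call shape of ToBlockGPU above:
  // ProcessToBlock<IdType>(..., CUDAIdsMapper<IdType>()).
  std::cout << ProcessToBlockSketch<int64_t>(21, GpuMapper<int64_t>()) << "\n";  // 42
  return 0;
}
```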
graph/transform/to_block.cc (renamed from graph/transform/to_bipartite.cc)

@@ -13,17 +13,20 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  *
- * @file graph/transform/to_bipartite.cc
+ * @file graph/transform/to_block.cc
  * @brief Convert a graph to a bipartite-structured graph.
+ *
+ * Tested via python wrapper: python/dgl/path/to/to_block.py
  */

-#include "to_bipartite.h"
+#include "to_block.h"

 #include <dgl/array.h>
 #include <dgl/base_heterograph.h>
 #include <dgl/immutable_graph.h>
 #include <dgl/packed_func_ext.h>
 #include <dgl/runtime/container.h>
+#include <dgl/runtime/device_api.h>
 #include <dgl/runtime/registry.h>
 #include <dgl/transform.h>

@@ -144,6 +147,174 @@ std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlockCPU(
 }  // namespace

+template <typename IdType>
+std::tuple<HeteroGraphPtr, std::vector<IdArray>> ProcessToBlock(
+    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
+    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr,
+    IdsMapper &&ids_mapper) {
+  std::vector<IdArray> &lhs_nodes = *lhs_nodes_ptr;
+  const bool generate_lhs_nodes = lhs_nodes.empty();
+  const auto &ctx = graph->Context();
+  auto device = runtime::DeviceAPI::Get(ctx);
+
+  // Since DST nodes are included in SRC nodes, a common requirement is to fetch
+  // the DST node features from the SRC nodes features. To avoid expensive
+  // sparse lookup, the function assures that the DST nodes in both SRC and DST
+  // sets have the same ids. As a result, given the node feature tensor ``X`` of
+  // type ``utype``, the following code finds the corresponding DST node
+  // features of type ``vtype``:
+
+  const int64_t num_etypes = graph->NumEdgeTypes();
+  const int64_t num_ntypes = graph->NumVertexTypes();
+
+  CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
+      << "rhs_nodes not given for every node type";
+
+  std::vector<EdgeArray> edge_arrays(num_etypes);
+  for (int64_t etype = 0; etype < num_etypes; ++etype) {
+    const auto src_dst_types = graph->GetEndpointTypes(etype);
+    const dgl_type_t dsttype = src_dst_types.second;
+    if (!aten::IsNullArray(rhs_nodes[dsttype])) {
+      edge_arrays[etype] = graph->Edges(etype);
+    }
+  }
+
+  // Count lhs and rhs nodes.
+  std::vector<int64_t> maxNodesPerType(num_ntypes * 2, 0);
+  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
+    maxNodesPerType[ntype + num_ntypes] += rhs_nodes[ntype]->shape[0];
+
+    if (generate_lhs_nodes) {
+      if (include_rhs_in_lhs) {
+        maxNodesPerType[ntype] += rhs_nodes[ntype]->shape[0];
+      }
+    } else {
+      maxNodesPerType[ntype] += lhs_nodes[ntype]->shape[0];
+    }
+  }
+  if (generate_lhs_nodes) {
+    // We don't have lhs_nodes, so we need to count inbound edges to get an
+    // upper bound.
+    for (int64_t etype = 0; etype < num_etypes; ++etype) {
+      const auto src_dst_types = graph->GetEndpointTypes(etype);
+      const dgl_type_t srctype = src_dst_types.first;
+      if (edge_arrays[etype].src.defined()) {
+        maxNodesPerType[srctype] += edge_arrays[etype].src->shape[0];
+      }
+    }
+  }
+
+  // Gather lhs_nodes.
+  std::vector<IdArray> src_nodes(num_ntypes);
+  if (generate_lhs_nodes) {
+    std::vector<int64_t> src_node_offsets(num_ntypes, 0);
+    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
+      src_nodes[ntype] =
+          NewIdArray(maxNodesPerType[ntype], ctx, sizeof(IdType) * 8);
+      if (include_rhs_in_lhs) {
+        // Place rhs nodes first.
+        device->CopyDataFromTo(
+            rhs_nodes[ntype].Ptr<IdType>(), 0, src_nodes[ntype].Ptr<IdType>(),
+            src_node_offsets[ntype],
+            sizeof(IdType) * rhs_nodes[ntype]->shape[0], rhs_nodes[ntype]->ctx,
+            src_nodes[ntype]->ctx, rhs_nodes[ntype]->dtype);
+        src_node_offsets[ntype] += sizeof(IdType) * rhs_nodes[ntype]->shape[0];
+      }
+    }
+    for (int64_t etype = 0; etype < num_etypes; ++etype) {
+      const auto src_dst_types = graph->GetEndpointTypes(etype);
+      const dgl_type_t srctype = src_dst_types.first;
+      if (edge_arrays[etype].src.defined()) {
+        device->CopyDataFromTo(
+            edge_arrays[etype].src.Ptr<IdType>(), 0,
+            src_nodes[srctype].Ptr<IdType>(), src_node_offsets[srctype],
+            sizeof(IdType) * edge_arrays[etype].src->shape[0],
+            rhs_nodes[srctype]->ctx, src_nodes[srctype]->ctx,
+            rhs_nodes[srctype]->dtype);
+        src_node_offsets[srctype] +=
+            sizeof(IdType) * edge_arrays[etype].src->shape[0];
+      }
+    }
+  } else {
+    for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
+      src_nodes[ntype] = lhs_nodes[ntype];
+    }
+  }
+
+  std::vector<int64_t> num_nodes_per_type(num_ntypes * 2);
+  // Populate RHS nodes from what we already know.
+  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
+    num_nodes_per_type[num_ntypes + ntype] = rhs_nodes[ntype]->shape[0];
+  }
+
+  std::vector<IdArray> new_lhs;
+  std::vector<IdArray> new_rhs;
+  std::tie(new_lhs, new_rhs) = ids_mapper(
+      graph, include_rhs_in_lhs, num_ntypes, ctx, maxNodesPerType, edge_arrays,
+      src_nodes, rhs_nodes, lhs_nodes_ptr, &num_nodes_per_type);
+
+  std::vector<IdArray> induced_edges;
+  induced_edges.reserve(num_etypes);
+  for (int64_t etype = 0; etype < num_etypes; ++etype) {
+    if (edge_arrays[etype].id.defined()) {
+      induced_edges.push_back(edge_arrays[etype].id);
+    } else {
+      induced_edges.push_back(
+          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
+    }
+  }
+
+  // Build metagraph.
+  const auto meta_graph = graph->meta_graph();
+  const EdgeArray etypes = meta_graph->Edges("eid");
+  const IdArray new_dst = Add(etypes.dst, num_ntypes);
+  const auto new_meta_graph =
+      ImmutableGraph::CreateFromCOO(num_ntypes * 2, etypes.src, new_dst);
+
+  // Allocate vector for graph relations while GPU is busy.
+  std::vector<HeteroGraphPtr> rel_graphs;
+  rel_graphs.reserve(num_etypes);
+
+  // Build the heterograph.
+  for (int64_t etype = 0; etype < num_etypes; ++etype) {
+    const auto src_dst_types = graph->GetEndpointTypes(etype);
+    const dgl_type_t srctype = src_dst_types.first;
+    const dgl_type_t dsttype = src_dst_types.second;
+    if (rhs_nodes[dsttype]->shape[0] == 0) {
+      // No rhs nodes are given for this edge type. Create an empty graph.
+      rel_graphs.push_back(CreateFromCOO(
+          2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
+          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx),
+          aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx)));
+    } else {
+      rel_graphs.push_back(CreateFromCOO(
+          2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
+          new_lhs[etype], new_rhs[etype]));
+    }
+  }
+
+  HeteroGraphPtr new_graph =
+      CreateHeteroGraph(new_meta_graph, rel_graphs, num_nodes_per_type);
+
+  // Return the new graph, the new src nodes, and new edges.
+  return std::make_tuple(new_graph, induced_edges);
+}
+
+template std::tuple<HeteroGraphPtr, std::vector<IdArray>>
+ProcessToBlock<int32_t>(
+    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
+    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr,
+    IdsMapper &&get_mapping_ids);
+
+template std::tuple<HeteroGraphPtr, std::vector<IdArray>>
+ProcessToBlock<int64_t>(
+    HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
+    bool include_rhs_in_lhs, std::vector<IdArray> *const lhs_nodes_ptr,
+    IdsMapper &&get_mapping_ids);
 template <>
 std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock<kDGLCPU, int32_t>(
     HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
...
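The comment carried over into `ProcessToBlock` ("the DST nodes in both SRC and DST sets have the same ids") rests on the gather step above: RHS ids are copied to the front of the SRC buffer before deduplication, so the first |rhs| entries of the LHS mapping are exactly the RHS ids, and DST features can be read as a prefix slice of SRC features. Here is a self-contained sketch of that invariant, using plain `std::vector` stand-ins rather than DGL's `IdArray` and hash-table node maps:

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_set>
#include <vector>

// First-occurrence deduplication, a simplified stand-in for the node-map
// construction performed by DeviceNodeMapMaker.
std::vector<int64_t> UniqueKeepFirst(const std::vector<int64_t>& ids) {
  std::unordered_set<int64_t> seen;
  std::vector<int64_t> out;
  for (int64_t id : ids) {
    if (seen.insert(id).second) out.push_back(id);  // keep first occurrence
  }
  return out;
}

int main() {
  std::vector<int64_t> rhs = {7, 3};          // DST (rhs) node ids
  std::vector<int64_t> edge_src = {3, 9, 7};  // SRC endpoints of sampled edges
  std::vector<int64_t> buf = rhs;             // "place rhs nodes first"
  buf.insert(buf.end(), edge_src.begin(), edge_src.end());
  std::vector<int64_t> lhs = UniqueKeepFirst(buf);  // lhs == {7, 3, 9}
  // The invariant: lhs[0:|rhs|] == rhs, so DST rows are a prefix of SRC rows.
  for (size_t i = 0; i < rhs.size(); ++i) assert(lhs[i] == rhs[i]);
  return 0;
}
```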
graph/transform/to_block.h (renamed from graph/transform/to_bipartite.h)

@@ -13,23 +13,34 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  *
- * @file graph/transform/to_bipartite.h
+ * @file graph/transform/to_block.h
  * @brief Functions to convert a set of edges into a graph block with local
  * ids.
  */

-#ifndef DGL_GRAPH_TRANSFORM_TO_BIPARTITE_H_
-#define DGL_GRAPH_TRANSFORM_TO_BIPARTITE_H_
+#ifndef DGL_GRAPH_TRANSFORM_TO_BLOCK_H_
+#define DGL_GRAPH_TRANSFORM_TO_BLOCK_H_

 #include <dgl/array.h>
 #include <dgl/base_heterograph.h>

+#include <functional>
 #include <tuple>
 #include <vector>

 namespace dgl {
 namespace transform {
+/** @brief Mapper used in block generation, which maps the left- and
+ * right-hand Id arrays of the original MFG to new arrays with consecutive
+ * ids.
+ */
+using IdsMapper =
+    std::function<std::tuple<std::vector<IdArray>, std::vector<IdArray>>(
+        const HeteroGraphPtr&, bool, int64_t, const DGLContext&,
+        const std::vector<int64_t>&, const std::vector<EdgeArray>&,
+        const std::vector<IdArray>&, const std::vector<IdArray>&,
+        std::vector<IdArray>* const, std::vector<int64_t>* const)>;
 /**
  * @brief Create a graph block from the set of
  * src and dst nodes (lhs and rhs respectively).

@@ -49,7 +60,27 @@ std::tuple<HeteroGraphPtr, std::vector<IdArray>> ToBlock(
     HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
     bool include_rhs_in_lhs, std::vector<IdArray>* lhs_nodes);
+/**
+ * @brief A wrapper function shared by the CPU and GPU `ToBlock`
+ * implementations, which handles their common pre- and post-processing work.
+ *
+ * @tparam IdType The type to use as an index.
+ * @param graph The graph from which to extract the block.
+ * @param rhs_nodes The destination nodes of the block.
+ * @param include_rhs_in_lhs Whether or not to include the
+ * destination nodes of the block in the source nodes.
+ * @param [in/out] lhs_nodes The source nodes of the block.
+ * @param get_mapping_ids The function that maps original ids to new ids.
+ *
+ * @return The block and the induced edges.
+ */
+template <typename IdType>
+std::tuple<HeteroGraphPtr, std::vector<IdArray>> ProcessToBlock(
+    HeteroGraphPtr graph, const std::vector<IdArray>& rhs_nodes,
+    bool include_rhs_in_lhs, std::vector<IdArray>* const lhs_nodes_ptr,
+    IdsMapper&& get_mapping_ids);
 }  // namespace transform
 }  // namespace dgl

-#endif  // DGL_GRAPH_TRANSFORM_TO_BIPARTITE_H_
+#endif  // DGL_GRAPH_TRANSFORM_TO_BLOCK_H_
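To see how the pieces of this header fit together, here is a compile-ready sketch, with simplified types, of the two mechanics it relies on: an `IdsMapper`-style `std::function` customization point, and a template defined out of line with explicit instantiations for the two id widths (as `ProcessToBlock<int32_t>` and `ProcessToBlock<int64_t>` are instantiated in `to_block.cc`). `Mapper` and `Process` are hypothetical names, not DGL API:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Simplified analogue of IdsMapper: the driver sees only a std::function.
template <typename IdType>
using Mapper = std::function<std::vector<IdType>(const std::vector<IdType>&)>;

// Shared pre/post-processing with the device-specific step injected,
// mirroring the shape of ProcessToBlock.
template <typename IdType>
std::vector<IdType> Process(const std::vector<IdType>& ids, Mapper<IdType>&& m) {
  std::vector<IdType> work = ids;  // common preprocessing would go here
  return m(work);                  // delegate to the injected mapper
}

// Explicit instantiations so the definition can live in a .cc file,
// as done for ProcessToBlock<int32_t> / ProcessToBlock<int64_t>.
template std::vector<int32_t> Process<int32_t>(
    const std::vector<int32_t>&, Mapper<int32_t>&&);
template std::vector<int64_t> Process<int64_t>(
    const std::vector<int64_t>&, Mapper<int64_t>&&);

int main() {
  auto echo = [](const std::vector<int64_t>& v) { return v; };
  std::cout << Process<int64_t>({1, 2, 3}, echo).size() << "\n";  // prints 3
  return 0;
}
```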