Unverified Commit 81831111 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4805)



* [Misc] clang-format auto fix.

* fix
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 16e771c0
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
* \brief Geometry operator CPU implementation * \brief Geometry operator CPU implementation
*/ */
#include <dgl/random.h> #include <dgl/random.h>
#include <numeric> #include <numeric>
#include <vector>
#include <utility> #include <utility>
#include <vector>
#include "../geometry_op.h" #include "../geometry_op.h"
namespace dgl { namespace dgl {
...@@ -25,8 +27,9 @@ void IndexShuffle(IdType *idxs, int64_t num_elems) { ...@@ -25,8 +27,9 @@ void IndexShuffle(IdType *idxs, int64_t num_elems) {
template void IndexShuffle<int32_t>(int32_t *idxs, int64_t num_elems); template void IndexShuffle<int32_t>(int32_t *idxs, int64_t num_elems);
template void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems); template void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems);
/*! \brief Groupwise index shuffle algorithm. This function will perform shuffle in subarrays /*! \brief Groupwise index shuffle algorithm. This function will perform shuffle
* indicated by group index. The group index is similar to indptr in CSRMatrix. * in subarrays indicated by group index. The group index is similar to indptr
* in CSRMatrix.
* *
* \param group_idxs group index array. * \param group_idxs group index array.
* \param idxs index array for shuffle. * \param idxs index array for shuffle.
...@@ -34,64 +37,73 @@ template void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems); ...@@ -34,64 +37,73 @@ template void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems);
* \param num_elems length of idxs * \param num_elems length of idxs
*/ */
template <typename IdType> template <typename IdType>
void GroupIndexShuffle(const IdType *group_idxs, IdType *idxs, void GroupIndexShuffle(
int64_t num_groups_idxs, int64_t num_elems) { const IdType *group_idxs, IdType *idxs, int64_t num_groups_idxs,
int64_t num_elems) {
if (num_groups_idxs < 2) return; // empty idxs array if (num_groups_idxs < 2) return; // empty idxs array
CHECK_LE(group_idxs[num_groups_idxs - 1], num_elems) << "group_idxs out of range"; CHECK_LE(group_idxs[num_groups_idxs - 1], num_elems)
<< "group_idxs out of range";
for (int64_t i = 0; i < num_groups_idxs - 1; ++i) { for (int64_t i = 0; i < num_groups_idxs - 1; ++i) {
auto subarray_len = group_idxs[i + 1] - group_idxs[i]; auto subarray_len = group_idxs[i + 1] - group_idxs[i];
IndexShuffle(idxs + group_idxs[i], subarray_len); IndexShuffle(idxs + group_idxs[i], subarray_len);
} }
} }
template void GroupIndexShuffle<int32_t>( template void GroupIndexShuffle<int32_t>(
const int32_t *group_idxs, int32_t *idxs, int64_t num_groups_idxs, int64_t num_elems); const int32_t *group_idxs, int32_t *idxs, int64_t num_groups_idxs,
int64_t num_elems);
template void GroupIndexShuffle<int64_t>( template void GroupIndexShuffle<int64_t>(
const int64_t *group_idxs, int64_t *idxs, int64_t num_groups_idxs, int64_t num_elems); const int64_t *group_idxs, int64_t *idxs, int64_t num_groups_idxs,
int64_t num_elems);
template <typename IdType> template <typename IdType>
IdArray RandomPerm(int64_t num_nodes) { IdArray RandomPerm(int64_t num_nodes) {
IdArray perm = aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8); IdArray perm =
IdType* perm_data = static_cast<IdType*>(perm->data); aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType *perm_data = static_cast<IdType *>(perm->data);
std::iota(perm_data, perm_data + num_nodes, 0); std::iota(perm_data, perm_data + num_nodes, 0);
IndexShuffle(perm_data, num_nodes); IndexShuffle(perm_data, num_nodes);
return perm; return perm;
} }
template <typename IdType> template <typename IdType>
IdArray GroupRandomPerm(const IdType *group_idxs, int64_t num_group_idxs, int64_t num_nodes) { IdArray GroupRandomPerm(
IdArray perm = aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8); const IdType *group_idxs, int64_t num_group_idxs, int64_t num_nodes) {
IdType* perm_data = static_cast<IdType*>(perm->data); IdArray perm =
aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType *perm_data = static_cast<IdType *>(perm->data);
std::iota(perm_data, perm_data + num_nodes, 0); std::iota(perm_data, perm_data + num_nodes, 0);
GroupIndexShuffle(group_idxs, perm_data, num_group_idxs, num_nodes); GroupIndexShuffle(group_idxs, perm_data, num_group_idxs, num_nodes);
return perm; return perm;
} }
/*! /*!
* \brief Farthest Point Sampler without the need to compute all pairs of distance. * \brief Farthest Point Sampler without the need to compute all pairs of
* distance.
* *
* The input array has shape (N, d), where N is the number of points, and d is the dimension. * The input array has shape (N, d), where N is the number of points, and d is
* It consists of a (flatten) batch of point clouds. * the dimension. It consists of a (flatten) batch of point clouds.
* *
* In each batch, the algorithm starts with the sample index specified by ``start_idx``. * In each batch, the algorithm starts with the sample index specified by
* Then for each point, we maintain the minimum to-sample distance. * ``start_idx``. Then for each point, we maintain the minimum to-sample
* Finally, we pick the point with the maximum such distance. * distance. Finally, we pick the point with the maximum such distance. This
* This process will be repeated for ``sample_points`` - 1 times. * process will be repeated for ``sample_points`` - 1 times.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(
NDArray dist, IdArray start_idx, IdArray result) { NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
const FloatType* array_data = static_cast<FloatType*>(array->data); IdArray start_idx, IdArray result) {
const FloatType *array_data = static_cast<FloatType *>(array->data);
const int64_t point_in_batch = array->shape[0] / batch_size; const int64_t point_in_batch = array->shape[0] / batch_size;
const int64_t dim = array->shape[1]; const int64_t dim = array->shape[1];
// distance // distance
FloatType* dist_data = static_cast<FloatType*>(dist->data); FloatType *dist_data = static_cast<FloatType *>(dist->data);
// sample for each cloud in the batch // sample for each cloud in the batch
IdType* start_idx_data = static_cast<IdType*>(start_idx->data); IdType *start_idx_data = static_cast<IdType *>(start_idx->data);
// return value // return value
IdType* ret_data = static_cast<IdType*>(result->data); IdType *ret_data = static_cast<IdType *>(result->data);
int64_t array_start = 0, ret_start = 0; int64_t array_start = 0, ret_start = 0;
// loop for each point cloud sample in this batch // loop for each point cloud sample in this batch
...@@ -136,29 +148,30 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin ...@@ -136,29 +148,30 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
} }
} }
template void FarthestPointSampler<kDGLCPU, float, int32_t>( template void FarthestPointSampler<kDGLCPU, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, float, int64_t>( template void FarthestPointSampler<kDGLCPU, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, double, int32_t>( template void FarthestPointSampler<kDGLCPU, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, double, int64_t>( template void FarthestPointSampler<kDGLCPU, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) { void WeightedNeighborMatching(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
const int64_t num_nodes = result->shape[0]; const int64_t num_nodes = result->shape[0];
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType*>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data); IdType *result_data = static_cast<IdType *>(result->data);
FloatType *weight_data = static_cast<FloatType*>(weight->data); FloatType *weight_data = static_cast<FloatType *>(weight->data);
// build node visiting order // build node visiting order
IdArray vis_order = RandomPerm<IdType>(num_nodes); IdArray vis_order = RandomPerm<IdType>(num_nodes);
IdType *vis_order_data = static_cast<IdType*>(vis_order->data); IdType *vis_order_data = static_cast<IdType *>(vis_order->data);
for (int64_t n = 0; n < num_nodes; ++n) { for (int64_t n = 0; n < num_nodes; ++n) {
auto u = vis_order_data[n]; auto u = vis_order_data[n];
...@@ -193,16 +206,16 @@ template void WeightedNeighborMatching<kDGLCPU, double, int64_t>( ...@@ -193,16 +206,16 @@ template void WeightedNeighborMatching<kDGLCPU, double, int64_t>(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) { void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_nodes = result->shape[0]; const int64_t num_nodes = result->shape[0];
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType*>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data); IdType *result_data = static_cast<IdType *>(result->data);
// build vis order // build vis order
IdArray u_vis_order = RandomPerm<IdType>(num_nodes); IdArray u_vis_order = RandomPerm<IdType>(num_nodes);
IdType *u_vis_order_data = static_cast<IdType*>(u_vis_order->data); IdType *u_vis_order_data = static_cast<IdType *>(u_vis_order->data);
IdArray v_vis_order = GroupRandomPerm<IdType>( IdArray v_vis_order = GroupRandomPerm<IdType>(
indptr_data, csr.indptr->shape[0], csr.indices->shape[0]); indptr_data, csr.indptr->shape[0], csr.indices->shape[0]);
IdType *v_vis_order_data = static_cast<IdType*>(v_vis_order->data); IdType *v_vis_order_data = static_cast<IdType *>(v_vis_order->data);
for (int64_t n = 0; n < num_nodes; ++n) { for (int64_t n = 0; n < num_nodes; ++n) {
auto u = u_vis_order_data[n]; auto u = u_vis_order_data[n];
...@@ -221,8 +234,10 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) { ...@@ -221,8 +234,10 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
} }
} }
} }
template void NeighborMatching<kDGLCPU, int32_t>(const aten::CSRMatrix &csr, IdArray result); template void NeighborMatching<kDGLCPU, int32_t>(
template void NeighborMatching<kDGLCPU, int64_t>(const aten::CSRMatrix &csr, IdArray result); const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCPU, int64_t>(
const aten::CSRMatrix &csr, IdArray result);
} // namespace impl } // namespace impl
} // namespace geometry } // namespace geometry
......
...@@ -3,14 +3,16 @@ ...@@ -3,14 +3,16 @@
* \file geometry/cuda/edge_coarsening_impl.cu * \file geometry/cuda/edge_coarsening_impl.cu
* \brief Edge coarsening CUDA implementation * \brief Edge coarsening CUDA implementation
*/ */
#include <curand_kernel.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <curand_kernel.h>
#include <cstdint> #include <cstdint>
#include "../geometry_op.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../../array/cuda/utils.h" #include "../../array/cuda/utils.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../geometry_op.h"
#define BLOCKS(N, T) (N + T - 1) / T #define BLOCKS(N, T) (N + T - 1) / T
...@@ -26,7 +28,8 @@ constexpr int EMPTY_IDX = -1; ...@@ -26,7 +28,8 @@ constexpr int EMPTY_IDX = -1;
__device__ bool done_d; __device__ bool done_d;
__global__ void init_done_kernel() { done_d = true; } __global__ void init_done_kernel() { done_d = true; }
__global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t seed) { __global__ void generate_uniform_kernel(
float *ret_values, size_t num, uint64_t seed) {
size_t id = blockIdx.x * blockDim.x + threadIdx.x; size_t id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < num) { if (id < num) {
curandState state; curandState state;
...@@ -36,7 +39,8 @@ __global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t ...@@ -36,7 +39,8 @@ __global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t
} }
template <typename IdType> template <typename IdType>
__global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *result) { __global__ void colorize_kernel(
const float *prop, int64_t num_elem, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) { if (idx < num_elem) {
if (result[idx] < 0) { // if unmatched if (result[idx] < 0) { // if unmatched
...@@ -47,9 +51,9 @@ __global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *res ...@@ -47,9 +51,9 @@ __global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *res
} }
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indices, __global__ void weighted_propose_kernel(
const FloatType *weights, int64_t num_elem, const IdType *indptr, const IdType *indices, const FloatType *weights,
IdType *proposal, IdType *result) { int64_t num_elem, IdType *proposal, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) { if (idx < num_elem) {
if (result[idx] != BLUE) return; if (result[idx] != BLUE) return;
...@@ -61,8 +65,7 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi ...@@ -61,8 +65,7 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi
for (IdType i = indptr[idx]; i < indptr[idx + 1]; ++i) { for (IdType i = indptr[idx]; i < indptr[idx + 1]; ++i) {
auto v = indices[i]; auto v = indices[i];
if (result[v] < 0) if (result[v] < 0) has_unmatched_neighbor = true;
has_unmatched_neighbor = true;
if (result[v] == RED && weights[i] >= weight_max) { if (result[v] == RED && weights[i] >= weight_max) {
v_max = v; v_max = v;
weight_max = weights[i]; weight_max = weights[i];
...@@ -70,15 +73,14 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi ...@@ -70,15 +73,14 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi
} }
proposal[idx] = v_max; proposal[idx] = v_max;
if (!has_unmatched_neighbor) if (!has_unmatched_neighbor) result[idx] = idx;
result[idx] = idx;
} }
} }
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indices, __global__ void weighted_respond_kernel(
const FloatType *weights, int64_t num_elem, const IdType *indptr, const IdType *indices, const FloatType *weights,
IdType *proposal, IdType *result) { int64_t num_elem, IdType *proposal, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) { if (idx < num_elem) {
if (result[idx] != RED) return; if (result[idx] != RED) return;
...@@ -93,9 +95,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi ...@@ -93,9 +95,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
if (result[v] < 0) { if (result[v] < 0) {
has_unmatched_neighbors = true; has_unmatched_neighbors = true;
} }
if (result[v] == BLUE if (result[v] == BLUE && proposal[v] == idx && weights[i] >= weight_max) {
&& proposal[v] == idx
&& weights[i] >= weight_max) {
v_max = v; v_max = v;
weight_max = weights[i]; weight_max = weights[i];
} }
...@@ -105,8 +105,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi ...@@ -105,8 +105,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
result[idx] = min(idx, v_max); result[idx] = min(idx, v_max);
} }
if (!has_unmatched_neighbors) if (!has_unmatched_neighbors) result[idx] = idx;
result[idx] = idx;
} }
} }
...@@ -114,8 +113,8 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi ...@@ -114,8 +113,8 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
* nodes with BLUE(-1) and RED(-2) and checks whether the node matching * nodes with BLUE(-1) and RED(-2) and checks whether the node matching
* process has finished. * process has finished.
*/ */
template<typename IdType> template <typename IdType>
bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) { bool Colorize(IdType *result_data, int64_t num_nodes, float *const prop) {
// initial done signal // initial done signal
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, stream); CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, stream);
...@@ -124,14 +123,17 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) { ...@@ -124,14 +123,17 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX); uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_nodes); auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads)); auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream, CUDA_KERNEL_CALL(
prop, num_nodes, seed); generate_uniform_kernel, num_blocks, num_threads, 0, stream, prop,
num_nodes, seed);
// call kernel // call kernel
CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, stream, CUDA_KERNEL_CALL(
prop, num_nodes, result_data); colorize_kernel, num_blocks, num_threads, 0, stream, prop, num_nodes,
result_data);
bool done_h = false; bool done_h = false;
CUDA_CALL(cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost)); CUDA_CALL(cudaMemcpyFromSymbol(
&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost));
return done_h; return done_h;
} }
...@@ -151,9 +153,10 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) { ...@@ -151,9 +153,10 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
* pair and mark them with the smaller id between them. * pair and mark them with the smaller id between them.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) { void WeightedNeighborMatching(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = result->ctx; const auto &ctx = result->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
device->SetDevice(ctx); device->SetDevice(ctx);
...@@ -162,23 +165,27 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, ...@@ -162,23 +165,27 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, ctx); IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, ctx);
// get data ptrs // get data ptrs
IdType *indptr_data = static_cast<IdType*>(csr.indptr->data); IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
IdType *indices_data = static_cast<IdType*>(csr.indices->data); IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data); IdType *result_data = static_cast<IdType *>(result->data);
IdType *proposal_data = static_cast<IdType*>(proposal->data); IdType *proposal_data = static_cast<IdType *>(proposal->data);
FloatType *weight_data = static_cast<FloatType*>(weight->data); FloatType *weight_data = static_cast<FloatType *>(weight->data);
// allocate workspace for prop used in Colorize() // allocate workspace for prop used in Colorize()
float *prop = static_cast<float*>( float *prop = static_cast<float *>(
device->AllocWorkspace(ctx, num_nodes * sizeof(float))); device->AllocWorkspace(ctx, num_nodes * sizeof(float)));
auto num_threads = cuda::FindNumThreads(num_nodes); auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads)); auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
while (!Colorize<IdType>(result_data, num_nodes, prop)) { while (!Colorize<IdType>(result_data, num_nodes, prop)) {
CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, stream, CUDA_KERNEL_CALL(
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data); weighted_propose_kernel, num_blocks, num_threads, 0, stream,
CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, stream, indptr_data, indices_data, weight_data, num_nodes, proposal_data,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data); result_data);
CUDA_KERNEL_CALL(
weighted_respond_kernel, num_blocks, num_threads, 0, stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data,
result_data);
} }
device->FreeWorkspace(ctx, prop); device->FreeWorkspace(ctx, prop);
} }
...@@ -204,7 +211,7 @@ template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>( ...@@ -204,7 +211,7 @@ template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) { void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_edges = csr.indices->shape[0]; const int64_t num_edges = csr.indices->shape[0];
const auto& ctx = result->ctx; const auto &ctx = result->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
device->SetDevice(ctx); device->SetDevice(ctx);
...@@ -212,17 +219,20 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) { ...@@ -212,17 +219,20 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
NDArray weight = NDArray::Empty( NDArray weight = NDArray::Empty(
{num_edges}, DGLDataType{kDGLFloat, sizeof(float) * 8, 1}, ctx); {num_edges}, DGLDataType{kDGLFloat, sizeof(float) * 8, 1}, ctx);
float *weight_data = static_cast<float*>(weight->data); float *weight_data = static_cast<float *>(weight->data);
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX); uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_edges); auto num_threads = cuda::FindNumThreads(num_edges);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads)); auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream, CUDA_KERNEL_CALL(
weight_data, num_edges, seed); generate_uniform_kernel, num_blocks, num_threads, 0, stream, weight_data,
num_edges, seed);
WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result); WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
} }
template void NeighborMatching<kDGLCUDA, int32_t>(const aten::CSRMatrix &csr, IdArray result); template void NeighborMatching<kDGLCUDA, int32_t>(
template void NeighborMatching<kDGLCUDA, int64_t>(const aten::CSRMatrix &csr, IdArray result); const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCUDA, int64_t>(
const aten::CSRMatrix &csr, IdArray result);
} // namespace impl } // namespace impl
} // namespace geometry } // namespace geometry
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../geometry_op.h" #include "../geometry_op.h"
#define THREADS 1024 #define THREADS 1024
...@@ -16,21 +16,23 @@ namespace geometry { ...@@ -16,21 +16,23 @@ namespace geometry {
namespace impl { namespace impl {
/*! /*!
* \brief Farthest Point Sampler without the need to compute all pairs of distance. * \brief Farthest Point Sampler without the need to compute all pairs of
* distance.
* *
* The input array has shape (N, d), where N is the number of points, and d is the dimension. * The input array has shape (N, d), where N is the number of points, and d is
* It consists of a (flatten) batch of point clouds. * the dimension. It consists of a (flatten) batch of point clouds.
* *
* In each batch, the algorithm starts with the sample index specified by ``start_idx``. * In each batch, the algorithm starts with the sample index specified by
* Then for each point, we maintain the minimum to-sample distance. * ``start_idx``. Then for each point, we maintain the minimum to-sample
* Finally, we pick the point with the maximum such distance. * distance. Finally, we pick the point with the maximum such distance. This
* This process will be repeated for ``sample_points`` - 1 times. * process will be repeated for ``sample_points`` - 1 times.
*/ */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size, __global__ void fps_kernel(
const FloatType* array_data, const int64_t batch_size,
const int64_t sample_points, const int64_t point_in_batch, const int64_t sample_points, const int64_t point_in_batch,
const int64_t dim, const IdType *start_idx, const int64_t dim, const IdType* start_idx, FloatType* dist_data,
FloatType *dist_data, IdType *ret_data) { IdType* ret_data) {
const int64_t thread_idx = threadIdx.x; const int64_t thread_idx = threadIdx.x;
const int64_t batch_idx = blockIdx.x; const int64_t batch_idx = blockIdx.x;
...@@ -90,8 +92,9 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size ...@@ -90,8 +92,9 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
} }
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(
NDArray dist, IdArray start_idx, IdArray result) { NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const FloatType* array_data = static_cast<FloatType*>(array->data); const FloatType* array_data = static_cast<FloatType*>(array->data);
...@@ -109,24 +112,23 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin ...@@ -109,24 +112,23 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
IdType* start_idx_data = static_cast<IdType*>(start_idx->data); IdType* start_idx_data = static_cast<IdType*>(start_idx->data);
CUDA_CALL(cudaSetDevice(array->ctx.device_id)); CUDA_CALL(cudaSetDevice(array->ctx.device_id));
CUDA_KERNEL_CALL(fps_kernel, CUDA_KERNEL_CALL(
batch_size, THREADS, 0, stream, fps_kernel, batch_size, THREADS, 0, stream, array_data, batch_size,
array_data, batch_size, sample_points, sample_points, point_in_batch, dim, start_idx_data, dist_data, ret_data);
point_in_batch, dim, start_idx_data, dist_data, ret_data);
} }
template void FarthestPointSampler<kDGLCUDA, float, int32_t>( template void FarthestPointSampler<kDGLCUDA, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, float, int64_t>( template void FarthestPointSampler<kDGLCUDA, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, double, int32_t>( template void FarthestPointSampler<kDGLCUDA, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, double, int64_t>( template void FarthestPointSampler<kDGLCUDA, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
} // namespace impl } // namespace impl
} // namespace geometry } // namespace geometry
......
...@@ -4,28 +4,34 @@ ...@@ -4,28 +4,34 @@
* \brief DGL geometry utilities implementation * \brief DGL geometry utilities implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/runtime/ndarray.h>
#include "../array/check.h"
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./geometry_op.h" #include "./geometry_op.h"
#include "../array/check.h"
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace geometry { namespace geometry {
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(
NDArray dist, IdArray start_idx, IdArray result) { NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result) {
CHECK_EQ(array->ctx, result->ctx) << "Array and the result should be on the same device."; CHECK_EQ(array->ctx, result->ctx)
CHECK_EQ(array->shape[0], dist->shape[0]) << "Shape of array and dist mismatch"; << "Array and the result should be on the same device.";
CHECK_EQ(start_idx->shape[0], batch_size) << "Shape of start_idx and batch_size mismatch"; CHECK_EQ(array->shape[0], dist->shape[0])
CHECK_EQ(result->shape[0], batch_size * sample_points) << "Invalid shape of result"; << "Shape of array and dist mismatch";
CHECK_EQ(start_idx->shape[0], batch_size)
<< "Shape of start_idx and batch_size mismatch";
CHECK_EQ(result->shape[0], batch_size * sample_points)
<< "Invalid shape of result";
ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", { ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, { ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "FarthestPointSampler", { ATEN_XPU_SWITCH_CUDA(
array->ctx.device_type, XPU, "FarthestPointSampler", {
impl::FarthestPointSampler<XPU, FloatType, IdType>( impl::FarthestPointSampler<XPU, FloatType, IdType>(
array, batch_size, sample_points, dist, start_idx, result); array, batch_size, sample_points, dist, start_idx, result);
}); });
...@@ -33,9 +39,11 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin ...@@ -33,9 +39,11 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
}); });
} }
void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result) { void NeighborMatching(
HeteroGraphPtr graph, const NDArray weight, IdArray result) {
if (!aten::IsNullArray(weight)) { if (!aten::IsNullArray(weight)) {
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "NeighborMatching", { ATEN_XPU_SWITCH_CUDA(
graph->Context().device_type, XPU, "NeighborMatching", {
ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", { ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
impl::WeightedNeighborMatching<XPU, FloatType, IdType>( impl::WeightedNeighborMatching<XPU, FloatType, IdType>(
...@@ -44,10 +52,10 @@ void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result ...@@ -44,10 +52,10 @@ void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result
}); });
}); });
} else { } else {
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "NeighborMatching", { ATEN_XPU_SWITCH_CUDA(
graph->Context().device_type, XPU, "NeighborMatching", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
impl::NeighborMatching<XPU, IdType>( impl::NeighborMatching<XPU, IdType>(graph->GetCSRMatrix(0), result);
graph->GetCSRMatrix(0), result);
}); });
}); });
} }
...@@ -56,7 +64,7 @@ void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result ...@@ -56,7 +64,7 @@ void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result
///////////////////////// C APIs ///////////////////////// ///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler") DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const NDArray data = args[0]; const NDArray data = args[0];
const int64_t batch_size = args[1]; const int64_t batch_size = args[1];
const int64_t sample_points = args[2]; const int64_t sample_points = args[2];
...@@ -64,24 +72,28 @@ DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler") ...@@ -64,24 +72,28 @@ DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
IdArray start_idx = args[4]; IdArray start_idx = args[4];
IdArray result = args[5]; IdArray result = args[5];
FarthestPointSampler(data, batch_size, sample_points, dist, start_idx, result); FarthestPointSampler(
data, batch_size, sample_points, dist, start_idx, result);
}); });
DGL_REGISTER_GLOBAL("geometry._CAPI_NeighborMatching") DGL_REGISTER_GLOBAL("geometry._CAPI_NeighborMatching")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const NDArray weight = args[1]; const NDArray weight = args[1];
IdArray result = args[2]; IdArray result = args[2];
// sanity check // sanity check
aten::CheckCtx(graph->Context(), {weight, result}, {"edge_weight, result"}); aten::CheckCtx(
graph->Context(), {weight, result}, {"edge_weight, result"});
aten::CheckContiguous({weight, result}, {"edge_weight", "result"}); aten::CheckContiguous({weight, result}, {"edge_weight", "result"});
CHECK_EQ(graph->NumEdgeTypes(), 1) << "homogeneous graph has only one edge type"; CHECK_EQ(graph->NumEdgeTypes(), 1)
<< "homogeneous graph has only one edge type";
CHECK_EQ(result->ndim, 1) << "result should be an 1D tensor."; CHECK_EQ(result->ndim, 1) << "result should be an 1D tensor.";
auto pair = graph->meta_graph()->FindEdge(0); auto pair = graph->meta_graph()->FindEdge(0);
const dgl_type_t node_type = pair.first; const dgl_type_t node_type = pair.first;
CHECK_EQ(graph->NumVertices(node_type), result->shape[0]) CHECK_EQ(graph->NumVertices(node_type), result->shape[0])
<< "The number of nodes should be the same as the length of result tensor."; << "The number of nodes should be the same as the length of result "
"tensor.";
if (!aten::IsNullArray(weight)) { if (!aten::IsNullArray(weight)) {
CHECK_EQ(weight->ndim, 1) << "weight should be an 1D tensor."; CHECK_EQ(weight->ndim, 1) << "weight should be an 1D tensor.";
CHECK_EQ(graph->NumEdges(0), weight->shape[0]) CHECK_EQ(graph->NumEdges(0), weight->shape[0])
......
...@@ -13,16 +13,19 @@ namespace geometry { ...@@ -13,16 +13,19 @@ namespace geometry {
namespace impl { namespace impl {
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(
NDArray dist, IdArray start_idx, IdArray result); NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result);
/*! \brief Implementation of weighted neighbor matching process of edge coarsening used /*! \brief Implementation of weighted neighbor matching process of edge
* in Metis and Graclus for homogeneous graph coarsening. This procedure keeps * coarsening used in Metis and Graclus for homogeneous graph coarsening. This
* picking an unmarked vertex and matching it with one its unmarked neighbors * procedure keeps picking an unmarked vertex and matching it with one its
* (that maximizes its edge weight) until no match can be done. * unmarked neighbors (that maximizes its edge weight) until no match can be
* done.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result); void WeightedNeighborMatching(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
/*! \brief Implementation of neighbor matching process of edge coarsening used /*! \brief Implementation of neighbor matching process of edge coarsening used
* in Metis and Graclus for homogeneous graph coarsening. This procedure keeps * in Metis and Graclus for homogeneous graph coarsening. This procedure keeps
......
...@@ -10,60 +10,51 @@ namespace dgl { ...@@ -10,60 +10,51 @@ namespace dgl {
// creator implementation // creator implementation
HeteroGraphPtr CreateHeteroGraph( HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) { const std::vector<int64_t>& num_nodes_per_type) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type)); return HeteroGraphPtr(
new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type));
} }
HeteroGraphPtr CreateFromCOO( HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
IdArray row, IdArray col, IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCOO( auto unit_g = UnitGraph::CreateFromCOO(
num_vtypes, num_src, num_dst, row, col, row_sorted, col_sorted, formats); num_vtypes, num_src, num_dst, row, col, row_sorted, col_sorted, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
HeteroGraphPtr CreateFromCOO( HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, formats); auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
HeteroGraphPtr CreateFromCSR( HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSR( auto unit_g = UnitGraph::CreateFromCSR(
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats); num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
HeteroGraphPtr CreateFromCSR( HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, formats); auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, formats);
auto ret = HeteroGraphPtr(new HeteroGraph( auto ret = HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
unit_g->meta_graph(), return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
{unit_g}));
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(),
{unit_g}));
} }
HeteroGraphPtr CreateFromCSC( HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSC( auto unit_g = UnitGraph::CreateFromCSC(
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats); num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
HeteroGraphPtr CreateFromCSC( HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, formats); auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
......
...@@ -37,16 +37,18 @@ gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) { ...@@ -37,16 +37,18 @@ gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) {
size_t num_ptrs; size_t num_ptrs;
if (is_row) { if (is_row) {
num_ptrs = gk_csr->nrows + 1; num_ptrs = gk_csr->nrows + 1;
gk_indptr = gk_csr->rowptr = gk_zmalloc(gk_csr->nrows+1, gk_indptr = gk_csr->rowptr = gk_zmalloc(
const_cast<char*>("gk_csr_ExtractPartition: rowptr")); gk_csr->nrows + 1,
gk_indices = gk_csr->rowind = gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: rowptr"));
const_cast<char*>("gk_csr_ExtractPartition: rowind")); gk_indices = gk_csr->rowind =
gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: rowind"));
} else { } else {
num_ptrs = gk_csr->ncols + 1; num_ptrs = gk_csr->ncols + 1;
gk_indptr = gk_csr->colptr = gk_zmalloc(gk_csr->ncols+1, gk_indptr = gk_csr->colptr = gk_zmalloc(
const_cast<char*>("gk_csr_ExtractPartition: colptr")); gk_csr->ncols + 1,
gk_indices = gk_csr->colind = gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: colptr"));
const_cast<char*>("gk_csr_ExtractPartition: colind")); gk_indices = gk_csr->colind =
gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: colind"));
} }
for (size_t i = 0; i < num_ptrs; i++) { for (size_t i = 0; i < num_ptrs; i++) {
...@@ -98,7 +100,8 @@ aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) { ...@@ -98,7 +100,8 @@ aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) {
eids[i] = i; eids[i] = i;
} }
return aten::CSRMatrix(gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr); return aten::CSRMatrix(
gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr);
} }
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
......
...@@ -5,11 +5,13 @@ ...@@ -5,11 +5,13 @@
*/ */
#include <dgl/graph.h> #include <dgl/graph.h>
#include <dgl/sampler.h> #include <dgl/sampler.h>
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include <set>
#include <functional> #include <functional>
#include <set>
#include <tuple> #include <tuple>
#include <unordered_map>
#include "../c_api_common.h" #include "../c_api_common.h"
namespace dgl { namespace dgl {
...@@ -21,8 +23,8 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) { ...@@ -21,8 +23,8 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) {
num_edges_ = src_ids->shape[0]; num_edges_ = src_ids->shape[0];
CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0]) CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])
<< "vectors in COO must have the same length"; << "vectors in COO must have the same length";
const dgl_id_t *src_data = static_cast<dgl_id_t*>(src_ids->data); const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t*>(dst_ids->data); const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data);
all_edges_src_.reserve(num_edges_); all_edges_src_.reserve(num_edges_);
all_edges_dst_.reserve(num_edges_); all_edges_dst_.reserve(num_edges_);
for (uint64_t i = 0; i < num_edges_; i++) { for (uint64_t i = 0; i < num_edges_; i++) {
...@@ -53,15 +55,15 @@ bool Graph::IsMultigraph() const { ...@@ -53,15 +55,15 @@ bool Graph::IsMultigraph() const {
pairs.emplace_back(all_edges_src_[eid], all_edges_dst_[eid]); pairs.emplace_back(all_edges_src_[eid], all_edges_dst_[eid]);
} }
// sort according to src and dst ids // sort according to src and dst ids
std::sort(pairs.begin(), pairs.end(), std::sort(pairs.begin(), pairs.end(), [](const Pair& t1, const Pair& t2) {
[] (const Pair& t1, const Pair& t2) { return std::get<0>(t1) < std::get<0>(t2) ||
return std::get<0>(t1) < std::get<0>(t2) (std::get<0>(t1) == std::get<0>(t2) &&
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2)); std::get<1>(t1) < std::get<1>(t2));
}); });
for (uint64_t eid = 0; eid < num_edges_-1; ++eid) { for (uint64_t eid = 0; eid < num_edges_ - 1; ++eid) {
// As src and dst are all sorted, we only need to compare i and i+1 // As src and dst are all sorted, we only need to compare i and i+1
if (std::get<0>(pairs[eid]) == std::get<0>(pairs[eid+1]) && if (std::get<0>(pairs[eid]) == std::get<0>(pairs[eid + 1]) &&
std::get<1>(pairs[eid]) == std::get<1>(pairs[eid+1])) std::get<1>(pairs[eid]) == std::get<1>(pairs[eid + 1]))
return true; return true;
} }
...@@ -125,7 +127,7 @@ BoolArray Graph::HasVertices(IdArray vids) const { ...@@ -125,7 +127,7 @@ BoolArray Graph::HasVertices(IdArray vids) const {
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t nverts = NumVertices(); const int64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts && vid_data[i] >= 0)? 1 : 0; rst_data[i] = (vid_data[i] < nverts && vid_data[i] >= 0) ? 1 : 0;
} }
return rst; return rst;
} }
...@@ -151,18 +153,18 @@ BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const { ...@@ -151,18 +153,18 @@ BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
if (srclen == 1) { if (srclen == 1) {
// one-many // one-many
for (int64_t i = 0; i < dstlen; ++i) { for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i]) ? 1 : 0;
} }
} else if (dstlen == 1) { } else if (dstlen == 1) {
// many-one // many-one
for (int64_t i = 0; i < srclen; ++i) { for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0]) ? 1 : 0;
} }
} else { } else {
// many-many // many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array."; CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) { for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i]) ? 1 : 0;
} }
} }
return rst; return rst;
...@@ -174,11 +176,11 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const { ...@@ -174,11 +176,11 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(radius >= 1) << "invalid radius: " << radius; CHECK(radius >= 1) << "invalid radius: " << radius;
std::set<dgl_id_t> vset; std::set<dgl_id_t> vset;
for (auto& it : reverse_adjlist_[vid].succ) for (auto& it : reverse_adjlist_[vid].succ) vset.insert(it);
vset.insert(it);
const int64_t len = vset.size(); const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(vset.begin(), vset.end(), rst_data); std::copy(vset.begin(), vset.end(), rst_data);
...@@ -191,11 +193,11 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const { ...@@ -191,11 +193,11 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(radius >= 1) << "invalid radius: " << radius; CHECK(radius >= 1) << "invalid radius: " << radius;
std::set<dgl_id_t> vset; std::set<dgl_id_t> vset;
for (auto& it : adjlist_[vid].succ) for (auto& it : adjlist_[vid].succ) vset.insert(it);
vset.insert(it);
const int64_t len = vset.size(); const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(vset.begin(), vset.end(), rst_data); std::copy(vset.begin(), vset.end(), rst_data);
...@@ -204,19 +206,20 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const { ...@@ -204,19 +206,20 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
// O(E) // O(E)
IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const { IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src) && HasVertex(dst)) << "invalid edge: " << src << " -> " << dst; CHECK(HasVertex(src) && HasVertex(dst))
<< "invalid edge: " << src << " -> " << dst;
const auto& succ = adjlist_[src].succ; const auto& succ = adjlist_[src].succ;
std::vector<dgl_id_t> edgelist; std::vector<dgl_id_t> edgelist;
for (size_t i = 0; i < succ.size(); ++i) { for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst) if (succ[i] == dst) edgelist.push_back(adjlist_[src].edge_id[i]);
edgelist.push_back(adjlist_[src].edge_id[i]);
} }
// FIXME: signed? Also it seems that we are using int64_t everywhere... // FIXME: signed? Also it seems that we are using int64_t everywhere...
const int64_t len = edgelist.size(); const int64_t len = edgelist.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
// FIXME: signed? // FIXME: signed?
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
...@@ -243,10 +246,11 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -243,10 +246,11 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
std::vector<dgl_id_t> src, dst, eid; std::vector<dgl_id_t> src, dst, eid;
for (i = 0, j = 0; i < srclen && j < dstlen; i += src_stride, j += dst_stride) { for (i = 0, j = 0; i < srclen && j < dstlen;
i += src_stride, j += dst_stride) {
const dgl_id_t src_id = src_data[i], dst_id = dst_data[j]; const dgl_id_t src_id = src_data[i], dst_id = dst_data[j];
CHECK(HasVertex(src_id) && HasVertex(dst_id)) << CHECK(HasVertex(src_id) && HasVertex(dst_id))
"invalid edge: " << src_id << " -> " << dst_id; << "invalid edge: " << src_id << " -> " << dst_id;
const auto& succ = adjlist_[src_id].succ; const auto& succ = adjlist_[src_id].succ;
for (size_t k = 0; k < succ.size(); ++k) { for (size_t k = 0; k < succ.size(); ++k) {
if (succ[k] == dst_id) { if (succ[k] == dst_id) {
...@@ -286,8 +290,7 @@ EdgeArray Graph::FindEdges(IdArray eids) const { ...@@ -286,8 +290,7 @@ EdgeArray Graph::FindEdges(IdArray eids) const {
for (uint64_t i = 0; i < (uint64_t)len; ++i) { for (uint64_t i = 0; i < (uint64_t)len; ++i) {
dgl_id_t eid = eid_data[i]; dgl_id_t eid = eid_data[i];
if (eid >= num_edges_) if (eid >= num_edges_) LOG(FATAL) << "invalid edge id:" << eid;
LOG(FATAL) << "invalid edge id:" << eid;
rst_src_data[i] = all_edges_src_[eid]; rst_src_data[i] = all_edges_src_[eid];
rst_dst_data[i] = all_edges_dst_[eid]; rst_dst_data[i] = all_edges_dst_[eid];
...@@ -301,9 +304,12 @@ EdgeArray Graph::FindEdges(IdArray eids) const { ...@@ -301,9 +304,12 @@ EdgeArray Graph::FindEdges(IdArray eids) const {
EdgeArray Graph::InEdges(dgl_id_t vid) const { EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size(); const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray src = IdArray::Empty(
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data); int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data); int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
...@@ -347,9 +353,12 @@ EdgeArray Graph::InEdges(IdArray vids) const { ...@@ -347,9 +353,12 @@ EdgeArray Graph::InEdges(IdArray vids) const {
EdgeArray Graph::OutEdges(dgl_id_t vid) const { EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size(); const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray src = IdArray::Empty(
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data); int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data); int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
...@@ -390,11 +399,14 @@ EdgeArray Graph::OutEdges(IdArray vids) const { ...@@ -390,11 +399,14 @@ EdgeArray Graph::OutEdges(IdArray vids) const {
} }
// O(E*log(E)) if sort is required; otherwise, O(E) // O(E*log(E)) if sort is required; otherwise, O(E)
EdgeArray Graph::Edges(const std::string &order) const { EdgeArray Graph::Edges(const std::string& order) const {
const int64_t len = num_edges_; const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray src = IdArray::Empty(
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
if (order == "srcdst") { if (order == "srcdst") {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple; typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
...@@ -404,10 +416,11 @@ EdgeArray Graph::Edges(const std::string &order) const { ...@@ -404,10 +416,11 @@ EdgeArray Graph::Edges(const std::string &order) const {
tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid); tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid);
} }
// sort according to src and dst ids // sort according to src and dst ids
std::sort(tuples.begin(), tuples.end(), std::sort(
[] (const Tuple& t1, const Tuple& t2) { tuples.begin(), tuples.end(), [](const Tuple& t1, const Tuple& t2) {
return std::get<0>(t1) < std::get<0>(t2) return std::get<0>(t1) < std::get<0>(t2) ||
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2)); (std::get<0>(t1) == std::get<0>(t2) &&
std::get<1>(t1) < std::get<1>(t2));
}); });
// make return arrays // make return arrays
...@@ -488,8 +501,11 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const { ...@@ -488,8 +501,11 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
} }
} }
} }
rst.induced_edges = IdArray::Empty({static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx); rst.induced_edges = IdArray::Empty(
std::copy(edges.begin(), edges.end(), static_cast<int64_t*>(rst.induced_edges->data)); {static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);
std::copy(
edges.begin(), edges.end(),
static_cast<int64_t*>(rst.induced_edges->data));
return rst; return rst;
} }
...@@ -524,7 +540,9 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -524,7 +540,9 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
rst.induced_vertices = IdArray::Empty( rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx); {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data)); std::copy(
nodes.begin(), nodes.end(),
static_cast<int64_t*>(rst.induced_vertices->data));
} else { } else {
rst.graph = std::make_shared<Graph>(); rst.graph = std::make_shared<Graph>();
rst.induced_edges = eids; rst.induced_edges = eids;
...@@ -536,59 +554,58 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -536,59 +554,58 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
rst.graph->AddEdge(src_id, dst_id); rst.graph->AddEdge(src_id, dst_id);
} }
for (uint64_t i = 0; i < NumVertices(); ++i) for (uint64_t i = 0; i < NumVertices(); ++i) nodes.push_back(i);
nodes.push_back(i);
rst.induced_vertices = IdArray::Empty( rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx); {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data)); std::copy(
nodes.begin(), nodes.end(),
static_cast<int64_t*>(rst.induced_vertices->data));
} }
return rst; return rst;
} }
std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const { std::vector<IdArray> Graph::GetAdj(
bool transpose, const std::string& fmt) const {
uint64_t num_edges = NumEdges(); uint64_t num_edges = NumEdges();
uint64_t num_nodes = NumVertices(); uint64_t num_nodes = NumVertices();
if (fmt == "coo") { if (fmt == "coo") {
IdArray idx = IdArray::Empty( IdArray idx = IdArray::Empty(
{2 * static_cast<int64_t>(num_edges)}, {2 * static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *idx_data = static_cast<int64_t*>(idx->data); int64_t* idx_data = static_cast<int64_t*>(idx->data);
if (transpose) { if (transpose) {
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data); std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data);
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data + num_edges); std::copy(
all_edges_dst_.begin(), all_edges_dst_.end(), idx_data + num_edges);
} else { } else {
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data); std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data);
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges); std::copy(
all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges);
} }
IdArray eid = IdArray::Empty( IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (uint64_t eid = 0; eid < num_edges; ++eid) { for (uint64_t eid = 0; eid < num_edges; ++eid) {
eid_data[eid] = eid; eid_data[eid] = eid;
} }
return std::vector<IdArray>{idx, eid}; return std::vector<IdArray>{idx, eid};
} else if (fmt == "csr") { } else if (fmt == "csr") {
IdArray indptr = IdArray::Empty( IdArray indptr = IdArray::Empty(
{static_cast<int64_t>(num_nodes) + 1}, {static_cast<int64_t>(num_nodes) + 1}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
IdArray indices = IdArray::Empty( IdArray indices = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty( IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *indptr_data = static_cast<int64_t*>(indptr->data); int64_t* indptr_data = static_cast<int64_t*>(indptr->data);
int64_t *indices_data = static_cast<int64_t*>(indices->data); int64_t* indices_data = static_cast<int64_t*>(indices->data);
int64_t *eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
const AdjacencyList *adjlist; const AdjacencyList* adjlist;
if (transpose) { if (transpose) {
// Out-edges. // Out-edges.
adjlist = &adjlist_; adjlist = &adjlist_;
...@@ -599,9 +616,11 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const ...@@ -599,9 +616,11 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
indptr_data[0] = 0; indptr_data[0] = 0;
for (size_t i = 0; i < adjlist->size(); i++) { for (size_t i = 0; i < adjlist->size(); i++) {
indptr_data[i + 1] = indptr_data[i] + adjlist->at(i).succ.size(); indptr_data[i + 1] = indptr_data[i] + adjlist->at(i).succ.size();
std::copy(adjlist->at(i).succ.begin(), adjlist->at(i).succ.end(), std::copy(
adjlist->at(i).succ.begin(), adjlist->at(i).succ.end(),
indices_data + indptr_data[i]); indices_data + indptr_data[i]);
std::copy(adjlist->at(i).edge_id.begin(), adjlist->at(i).edge_id.end(), std::copy(
adjlist->at(i).edge_id.begin(), adjlist->at(i).edge_id.end(),
eid_data + indptr_data[i]); eid_data + indptr_data[i]);
} }
return std::vector<IdArray>{indptr, indices, eid}; return std::vector<IdArray>{indptr, indices, eid};
......
...@@ -3,72 +3,74 @@ ...@@ -3,72 +3,74 @@
* \file graph/graph.cc * \file graph/graph.cc
* \brief DGL graph index APIs * \brief DGL graph index APIs
*/ */
#include <dgl/packed_func_ext.h>
#include <dgl/graph.h> #include <dgl/graph.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_op.h> #include <dgl/graph_op.h>
#include <dgl/sampler.h> #include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h> #include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>
#include <dgl/sampler.h>
#include "../c_api_common.h" #include "../c_api_common.h"
using dgl::runtime::DGLArgs; using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue; using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue; using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
using dgl::runtime::PackedFunc;
namespace dgl { namespace dgl {
///////////////////////////// Graph API /////////////////////////////////// ///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = GraphRef(Graph::Create()); *rv = GraphRef(Graph::Create());
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray src_ids = args[0]; const IdArray src_ids = args[0];
const IdArray dst_ids = args[1]; const IdArray dst_ids = args[1];
const int64_t num_nodes = args[2]; const int64_t num_nodes = args[2];
const bool readonly = args[3]; const bool readonly = args[3];
if (readonly) { if (readonly) {
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids)); *rv = GraphRef(
ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));
} else { } else {
*rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids)); *rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids));
} }
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray indptr = args[0]; const IdArray indptr = args[0];
const IdArray indices = args[1]; const IdArray indices = args[1];
const std::string edge_dir = args[2]; const std::string edge_dir = args[2];
IdArray edge_ids = IdArray::Empty({indices->shape[0]}, IdArray edge_ids = IdArray::Empty(
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); {indices->shape[0]}, DGLDataType{kDGLInt, 64, 1},
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data); DGLContext{kDGLCPU, 0});
for (int64_t i = 0; i < edge_ids->shape[0]; i++) int64_t* edge_data = static_cast<int64_t*>(edge_ids->data);
edge_data[i] = i; for (int64_t i = 0; i < edge_ids->shape[0]; i++) edge_data[i] = i;
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir)); *rv = GraphRef(
ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string shared_mem_name = args[0]; const std::string shared_mem_name = args[0];
*rv = GraphRef(ImmutableGraph::CreateFromCSR(shared_mem_name)); *rv = GraphRef(ImmutableGraph::CreateFromCSR(shared_mem_name));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
uint64_t num_vertices = args[1]; uint64_t num_vertices = args[1];
g->AddVertices(num_vertices); g->AddVertices(num_vertices);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
...@@ -76,7 +78,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge") ...@@ -76,7 +78,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
const IdArray dst = args[2]; const IdArray dst = args[2];
...@@ -84,51 +86,51 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges") ...@@ -84,51 +86,51 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
g->Clear(); g->Clear();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = g->IsMultigraph(); *rv = g->IsMultigraph();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsReadonly") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = g->IsReadonly(); *rv = g->IsReadonly();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = static_cast<int64_t>(g->NumVertices()); *rv = static_cast<int64_t>(g->NumVertices());
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = static_cast<int64_t>(g->NumEdges()); *rv = static_cast<int64_t>(g->NumEdges());
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = g->HasVertex(vid); *rv = g->HasVertex(vid);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = g->HasVertices(vids); *rv = g->HasVertices(vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
...@@ -136,7 +138,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween") ...@@ -136,7 +138,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
const IdArray dst = args[2]; const IdArray dst = args[2];
...@@ -144,7 +146,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween") ...@@ -144,7 +146,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
const uint64_t radius = args[2]; const uint64_t radius = args[2];
...@@ -152,7 +154,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors") ...@@ -152,7 +154,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
const uint64_t radius = args[2]; const uint64_t radius = args[2];
...@@ -160,7 +162,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors") ...@@ -160,7 +162,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
...@@ -168,7 +170,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId") ...@@ -168,7 +170,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
const IdArray dst = args[2]; const IdArray dst = args[2];
...@@ -176,89 +178,89 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds") ...@@ -176,89 +178,89 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdge") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t eid = args[1]; const dgl_id_t eid = args[1];
const auto& pair = g->FindEdge(eid); const auto& pair = g->FindEdge(eid);
*rv = PackedFunc([pair] (DGLArgs args, DGLRetValue* rv) { *rv = PackedFunc([pair](DGLArgs args, DGLRetValue* rv) {
const int choice = args[0]; const int choice = args[0];
const int64_t ret = (choice == 0? pair.first : pair.second); const int64_t ret = (choice == 0 ? pair.first : pair.second);
*rv = ret; *rv = ret;
}); });
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray eids = args[1]; const IdArray eids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->FindEdges(eids)); *rv = ConvertEdgeArrayToPackedFunc(g->FindEdges(eids));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vid)); *rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vids)); *rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vids));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vid)); *rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vids)); *rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vids));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
std::string order = args[1]; std::string order = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->Edges(order)); *rv = ConvertEdgeArrayToPackedFunc(g->Edges(order));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(g->InDegree(vid)); *rv = static_cast<int64_t>(g->InDegree(vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = g->InDegrees(vids); *rv = g->InDegrees(vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(g->OutDegree(vid)); *rv = static_cast<int64_t>(g->OutDegree(vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = g->OutDegrees(vids); *rv = g->OutDegrees(vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
std::shared_ptr<Subgraph> subg(new Subgraph(g->VertexSubgraph(vids))); std::shared_ptr<Subgraph> subg(new Subgraph(g->VertexSubgraph(vids)));
...@@ -266,7 +268,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph") ...@@ -266,7 +268,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray eids = args[1]; const IdArray eids = args[1];
bool preserve_nodes = args[2]; bool preserve_nodes = args[2];
...@@ -276,7 +278,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph") ...@@ -276,7 +278,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
bool transpose = args[1]; bool transpose = args[1];
std::string format = args[2]; std::string format = args[2];
...@@ -285,13 +287,13 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj") ...@@ -285,13 +287,13 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphContext") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = g->Context(); *rv = g->Context();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = g->NumBits(); *rv = g->NumBits();
}); });
...@@ -299,25 +301,25 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits") ...@@ -299,25 +301,25 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
// Subgraph C APIs // Subgraph C APIs
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetGraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0]; SubgraphRef subg = args[0];
*rv = GraphRef(subg->graph); *rv = GraphRef(subg->graph);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0]; SubgraphRef subg = args[0];
*rv = subg->induced_vertices; *rv = subg->induced_vertices;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0]; SubgraphRef subg = args[0];
*rv = subg->induced_edges; *rv = subg->induced_edges;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSortAdj") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSortAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
g->SortCSR(); g->SortCSR();
}); });
......
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
* \file graph/graph.cc * \file graph/graph.cc
* \brief Graph operation implementation * \brief Graph operation implementation
*/ */
#include <dgl/graph_op.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/graph_op.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <algorithm> #include <algorithm>
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -19,7 +21,7 @@ namespace { ...@@ -19,7 +21,7 @@ namespace {
// generate consecutive dgl ids // generate consecutive dgl ids
class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> { class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> {
public: public:
explicit RangeIter(dgl_id_t from): cur_(from) {} explicit RangeIter(dgl_id_t from) : cur_(from) {}
RangeIter& operator++() { RangeIter& operator++() {
++cur_; ++cur_;
...@@ -31,15 +33,9 @@ class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> { ...@@ -31,15 +33,9 @@ class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> {
++cur_; ++cur_;
return retval; return retval;
} }
bool operator==(RangeIter other) const { bool operator==(RangeIter other) const { return cur_ == other.cur_; }
return cur_ == other.cur_; bool operator!=(RangeIter other) const { return cur_ != other.cur_; }
} dgl_id_t operator*() const { return cur_; }
bool operator!=(RangeIter other) const {
return cur_ != other.cur_;
}
dgl_id_t operator*() const {
return cur_;
}
private: private:
dgl_id_t cur_; dgl_id_t cur_;
...@@ -88,7 +84,8 @@ GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) { ...@@ -88,7 +84,8 @@ GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) {
rst->AddVertices(gr->NumVertices()); rst->AddVertices(gr->NumVertices());
for (uint64_t i = 0; i < gr->NumEdges(); ++i) { for (uint64_t i = 0; i < gr->NumEdges(); ++i) {
// TODO(minjie): quite ugly to expose internal members // TODO(minjie): quite ugly to expose internal members
rst->AddEdge(mg->all_edges_src_[i] + cumsum, mg->all_edges_dst_[i] + cumsum); rst->AddEdge(
mg->all_edges_src_[i] + cumsum, mg->all_edges_dst_[i] + cumsum);
} }
cumsum += gr->NumVertices(); cumsum += gr->NumVertices();
} }
...@@ -136,14 +133,17 @@ GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) { ...@@ -136,14 +133,17 @@ GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) {
cum_num_edges += g_num_edges; cum_num_edges += g_num_edges;
} }
return ImmutableGraph::CreateFromCSR(indptr_arr, indices_arr, edge_ids_arr, "in"); return ImmutableGraph::CreateFromCSR(
indptr_arr, indices_arr, edge_ids_arr, "in");
} }
} }
std::vector<GraphPtr> GraphOp::DisjointPartitionByNum(GraphPtr graph, int64_t num) { std::vector<GraphPtr> GraphOp::DisjointPartitionByNum(
GraphPtr graph, int64_t num) {
CHECK(num != 0 && graph->NumVertices() % num == 0) CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes."; << "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray sizes = IdArray::Empty(
{num}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* sizes_data = static_cast<int64_t*>(sizes->data); int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num); std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes); return DisjointPartitionBySizes(graph, sizes);
...@@ -170,10 +170,11 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes( ...@@ -170,10 +170,11 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
MutableGraphPtr mg = Graph::Create(); MutableGraphPtr mg = Graph::Create();
// TODO(minjie): quite ugly to expose internal members // TODO(minjie): quite ugly to expose internal members
// copy adj // copy adj
mg->adjlist_.insert(mg->adjlist_.end(), mg->adjlist_.insert(
graph->adjlist_.begin() + node_offset, mg->adjlist_.end(), graph->adjlist_.begin() + node_offset,
graph->adjlist_.begin() + node_offset + sizes_data[i]); graph->adjlist_.begin() + node_offset + sizes_data[i]);
mg->reverse_adjlist_.insert(mg->reverse_adjlist_.end(), mg->reverse_adjlist_.insert(
mg->reverse_adjlist_.end(),
graph->reverse_adjlist_.begin() + node_offset, graph->reverse_adjlist_.begin() + node_offset,
graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]); graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
// relabel adjs // relabel adjs
...@@ -209,12 +210,15 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes( ...@@ -209,12 +210,15 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
} }
} else { } else {
// Input is an immutable graph. Partition it into several multiple graphs. // Input is an immutable graph. Partition it into several multiple graphs.
ImmutableGraphPtr graph = std::dynamic_pointer_cast<ImmutableGraph>(batched_graph); ImmutableGraphPtr graph =
std::dynamic_pointer_cast<ImmutableGraph>(batched_graph);
// TODO(minjie): why in csr? // TODO(minjie): why in csr?
CSRPtr in_csr_ptr = graph->GetInCSR(); CSRPtr in_csr_ptr = graph->GetInCSR();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(in_csr_ptr->indptr()->data); const dgl_id_t* indptr = static_cast<dgl_id_t*>(in_csr_ptr->indptr()->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(in_csr_ptr->indices()->data); const dgl_id_t* indices =
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data); static_cast<dgl_id_t*>(in_csr_ptr->indices()->data);
const dgl_id_t* edge_ids =
static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data);
dgl_id_t cum_sum_edges = 0; dgl_id_t cum_sum_edges = 0;
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
const int64_t start_pos = cumsum[i]; const int64_t start_pos = cumsum[i];
...@@ -257,7 +261,8 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) { ...@@ -257,7 +261,8 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
const auto query_len = query->shape[0]; const auto query_len = query->shape[0];
const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data); const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data);
const dgl_id_t* query_data = static_cast<dgl_id_t*>(query->data); const dgl_id_t* query_data = static_cast<dgl_id_t*>(query->data);
IdArray rst = IdArray::Empty({query_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray rst = IdArray::Empty(
{query_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data); dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
const bool is_sorted = std::is_sorted(parent_data, parent_data + parent_len); const bool is_sorted = std::is_sorted(parent_data, parent_data + parent_len);
...@@ -300,11 +305,12 @@ IdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) { ...@@ -300,11 +305,12 @@ IdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) {
const auto id_len = ids->shape[0]; const auto id_len = ids->shape[0];
const auto off_len = offset->shape[0]; const auto off_len = offset->shape[0];
CHECK_EQ(id_len + 1, off_len); CHECK_EQ(id_len + 1, off_len);
const dgl_id_t *id_data = static_cast<dgl_id_t*>(ids->data); const dgl_id_t* id_data = static_cast<dgl_id_t*>(ids->data);
const dgl_id_t *off_data = static_cast<dgl_id_t*>(offset->data); const dgl_id_t* off_data = static_cast<dgl_id_t*>(offset->data);
const int64_t len = off_data[off_len - 1]; const int64_t len = off_data[off_len - 1];
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray rst = IdArray::Empty(
dgl_id_t *rst_data = static_cast<dgl_id_t*>(rst->data); {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
for (int64_t i = 0; i < id_len; i++) { for (int64_t i = 0; i < id_len; i++) {
const int64_t local_len = off_data[i + 1] - off_data[i]; const int64_t local_len = off_data[i + 1] - off_data[i];
for (int64_t j = 0; j < local_len; j++) { for (int64_t j = 0; j < local_len; j++) {
...@@ -325,10 +331,11 @@ GraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) { ...@@ -325,10 +331,11 @@ GraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) {
hashmap.insert(dst); hashmap.insert(dst);
} }
} }
indptr[src+1] = indices.size(); indptr[src + 1] = indices.size();
} }
CSRPtr csr(new CSR(graph->NumVertices(), indices.size(), CSRPtr csr(new CSR(
indptr.begin(), indices.begin(), RangeIter(0))); graph->NumVertices(), indices.size(), indptr.begin(), indices.begin(),
RangeIter(0)));
return std::make_shared<ImmutableGraph>(csr); return std::make_shared<ImmutableGraph>(csr);
} }
...@@ -403,8 +410,9 @@ GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) { ...@@ -403,8 +410,9 @@ GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) {
g->NumVertices(), srcs_array, dsts_array); g->NumVertices(), srcs_array, dsts_array);
} }
HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hops) { HaloSubgraph GraphOp::GetSubgraphWithHalo(
const dgl_id_t *nid = static_cast<dgl_id_t *>(nodes->data); GraphPtr g, IdArray nodes, int num_hops) {
const dgl_id_t* nid = static_cast<dgl_id_t*>(nodes->data);
const auto id_len = nodes->shape[0]; const auto id_len = nodes->shape[0];
// A map contains all nodes in the subgraph. // A map contains all nodes in the subgraph.
// The key is the old node Ids, the value indicates whether a node is a inner // The key is the old node Ids, the value indicates whether a node is a inner
...@@ -414,8 +422,7 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop ...@@ -414,8 +422,7 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
// vector. The first few nodes are the inner nodes in the subgraph. // vector. The first few nodes are the inner nodes in the subgraph.
std::vector<dgl_id_t> old_node_ids(nid, nid + id_len); std::vector<dgl_id_t> old_node_ids(nid, nid + id_len);
std::vector<std::vector<dgl_id_t>> outer_nodes(num_hops); std::vector<std::vector<dgl_id_t>> outer_nodes(num_hops);
for (int64_t i = 0; i < id_len; i++) for (int64_t i = 0; i < id_len; i++) all_nodes[nid[i]] = true;
all_nodes[nid[i]] = true;
auto orig_nodes = all_nodes; auto orig_nodes = all_nodes;
std::vector<dgl_id_t> edge_src, edge_dst, edge_eid; std::vector<dgl_id_t> edge_src, edge_dst, edge_eid;
...@@ -428,9 +435,9 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop ...@@ -428,9 +435,9 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
auto dst = in_edges.dst; auto dst = in_edges.dst;
auto eid = in_edges.id; auto eid = in_edges.id;
auto num_edges = eid->shape[0]; auto num_edges = eid->shape[0];
const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data); const dgl_id_t* src_data = static_cast<dgl_id_t*>(src->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data); const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst->data);
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data); const dgl_id_t* eid_data = static_cast<dgl_id_t*>(eid->data);
for (int64_t i = 0; i < num_edges; i++) { for (int64_t i = 0; i < num_edges; i++) {
// We check if the source node is in the original node. // We check if the source node is in the original node.
auto it1 = orig_nodes.find(src_data[i]); auto it1 = orig_nodes.find(src_data[i]);
...@@ -451,15 +458,15 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop ...@@ -451,15 +458,15 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
// Now we need to traverse the graph with the in-edges to access nodes // Now we need to traverse the graph with the in-edges to access nodes
// and edges more hops away. // and edges more hops away.
for (int k = 1; k < num_hops; k++) { for (int k = 1; k < num_hops; k++) {
const std::vector<dgl_id_t> &nodes = outer_nodes[k-1]; const std::vector<dgl_id_t>& nodes = outer_nodes[k - 1];
EdgeArray in_edges = g->InEdges(aten::VecToIdArray(nodes)); EdgeArray in_edges = g->InEdges(aten::VecToIdArray(nodes));
auto src = in_edges.src; auto src = in_edges.src;
auto dst = in_edges.dst; auto dst = in_edges.dst;
auto eid = in_edges.id; auto eid = in_edges.id;
auto num_edges = eid->shape[0]; auto num_edges = eid->shape[0];
const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data); const dgl_id_t* src_data = static_cast<dgl_id_t*>(src->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data); const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst->data);
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data); const dgl_id_t* eid_data = static_cast<dgl_id_t*>(eid->data);
for (int64_t i = 0; i < num_edges; i++) { for (int64_t i = 0; i < num_edges; i++) {
edge_src.push_back(src_data[i]); edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]); edge_dst.push_back(dst_data[i]);
...@@ -482,12 +489,12 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop ...@@ -482,12 +489,12 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
} }
num_edges = edge_src.size(); num_edges = edge_src.size();
IdArray new_src = IdArray::Empty({num_edges}, DGLDataType{kDGLInt, 64, 1}, IdArray new_src = IdArray::Empty(
DGLContext{kDGLCPU, 0}); {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray new_dst = IdArray::Empty({num_edges}, DGLDataType{kDGLInt, 64, 1}, IdArray new_dst = IdArray::Empty(
DGLContext{kDGLCPU, 0}); {num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data); dgl_id_t* new_src_data = static_cast<dgl_id_t*>(new_src->data);
dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data); dgl_id_t* new_dst_data = static_cast<dgl_id_t*>(new_dst->data);
for (size_t i = 0; i < edge_src.size(); i++) { for (size_t i = 0; i < edge_src.size(); i++) {
new_src_data[i] = old2new[edge_src[i]]; new_src_data[i] = old2new[edge_src[i]];
new_dst_data[i] = old2new[edge_dst[i]]; new_dst_data[i] = old2new[edge_dst[i]];
...@@ -499,7 +506,8 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop ...@@ -499,7 +506,8 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
inner_nodes[i] = all_nodes[old_nid]; inner_nodes[i] = all_nodes[old_nid];
} }
GraphPtr subg = ImmutableGraph::CreateFromCOO(old_node_ids.size(), new_src, new_dst); GraphPtr subg =
ImmutableGraph::CreateFromCOO(old_node_ids.size(), new_src, new_dst);
HaloSubgraph halo_subg; HaloSubgraph halo_subg;
halo_subg.graph = subg; halo_subg.graph = subg;
halo_subg.induced_vertices = aten::VecToIdArray(old_node_ids); halo_subg.induced_vertices = aten::VecToIdArray(old_node_ids);
...@@ -509,7 +517,8 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop ...@@ -509,7 +517,8 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
return halo_subg; return halo_subg;
} }
GraphPtr GraphOp::ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order) { GraphPtr GraphOp::ReorderImmutableGraph(
ImmutableGraphPtr ig, IdArray new_order) {
CSRPtr in_csr, out_csr; CSRPtr in_csr, out_csr;
COOPtr coo; COOPtr coo;
// We only need to reorder one of the graph structure. // We only need to reorder one of the graph structure.
...@@ -517,12 +526,14 @@ GraphPtr GraphOp::ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order) ...@@ -517,12 +526,14 @@ GraphPtr GraphOp::ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order)
in_csr = ig->GetInCSR(); in_csr = ig->GetInCSR();
auto csrmat = in_csr->ToCSRMatrix(); auto csrmat = in_csr->ToCSRMatrix();
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order); auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
in_csr = CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data)); in_csr =
CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data));
} else if (ig->HasOutCSR()) { } else if (ig->HasOutCSR()) {
out_csr = ig->GetOutCSR(); out_csr = ig->GetOutCSR();
auto csrmat = out_csr->ToCSRMatrix(); auto csrmat = out_csr->ToCSRMatrix();
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order); auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
out_csr = CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data)); out_csr =
CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data));
} else { } else {
coo = ig->GetCOO(); coo = ig->GetCOO();
auto coomat = coo->ToCOOMatrix(); auto coomat = coo->ToCOOMatrix();
...@@ -536,14 +547,14 @@ GraphPtr GraphOp::ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order) ...@@ -536,14 +547,14 @@ GraphPtr GraphOp::ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order)
} }
DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo") DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef graph = args[0]; GraphRef graph = args[0];
IdArray node_parts = args[1]; IdArray node_parts = args[1];
int num_hops = args[2]; int num_hops = args[2];
const dgl_id_t *part_data = static_cast<dgl_id_t *>(node_parts->data); const dgl_id_t* part_data = static_cast<dgl_id_t*>(node_parts->data);
int64_t num_nodes = node_parts->shape[0]; int64_t num_nodes = node_parts->shape[0];
std::unordered_map<int, std::vector<dgl_id_t> > part_map; std::unordered_map<int, std::vector<dgl_id_t>> part_map;
for (int64_t i = 0; i < num_nodes; i++) { for (int64_t i = 0; i < num_nodes; i++) {
dgl_id_t part_id = part_data[i]; dgl_id_t part_id = part_data[i];
auto it = part_map.find(part_id); auto it = part_map.find(part_id);
...@@ -556,7 +567,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo") ...@@ -556,7 +567,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
} }
} }
std::vector<int> part_ids; std::vector<int> part_ids;
std::vector<std::vector<dgl_id_t> > part_nodes; std::vector<std::vector<dgl_id_t>> part_nodes;
int max_part_id = 0; int max_part_id = 0;
for (auto it = part_map.begin(); it != part_map.end(); it++) { for (auto it = part_map.begin(); it != part_map.end(); it++) {
max_part_id = std::max(it->first, max_part_id); max_part_id = std::max(it->first, max_part_id);
...@@ -570,12 +581,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo") ...@@ -570,12 +581,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
// try to construct in-CSR in openmp for loop, which will lead // try to construct in-CSR in openmp for loop, which will lead
// to some unexpected results. // to some unexpected results.
graph_ptr->GetInCSR(); graph_ptr->GetInCSR();
std::vector<std::shared_ptr<HaloSubgraph> > subgs(max_part_id + 1); std::vector<std::shared_ptr<HaloSubgraph>> subgs(max_part_id + 1);
int num_partitions = part_nodes.size(); int num_partitions = part_nodes.size();
runtime::parallel_for(0, num_partitions, [&](size_t b, size_t e) { runtime::parallel_for(0, num_partitions, [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
auto nodes = aten::VecToIdArray(part_nodes[i]); auto nodes = aten::VecToIdArray(part_nodes[i]);
HaloSubgraph subg = GraphOp::GetSubgraphWithHalo(graph_ptr, nodes, num_hops); HaloSubgraph subg =
GraphOp::GetSubgraphWithHalo(graph_ptr, nodes, num_hops);
std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg)); std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg));
int part_id = part_ids[i]; int part_id = part_ids[i];
subgs[part_id] = subg_ptr; subgs[part_id] = subg_ptr;
...@@ -589,25 +601,26 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo") ...@@ -589,25 +601,26 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGetSubgraphWithHalo") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGetSubgraphWithHalo")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef graph = args[0]; GraphRef graph = args[0];
IdArray nodes = args[1]; IdArray nodes = args[1];
int num_hops = args[2]; int num_hops = args[2];
HaloSubgraph subg = GraphOp::GetSubgraphWithHalo(graph.sptr(), nodes, num_hops); HaloSubgraph subg =
GraphOp::GetSubgraphWithHalo(graph.sptr(), nodes, num_hops);
std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg)); std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg));
*rv = SubgraphRef(subg_ptr); *rv = SubgraphRef(subg_ptr);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerNodes") DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0]; SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<HaloSubgraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<HaloSubgraph>(g.sptr());
CHECK(gptr) << "The input graph has to be immutable graph"; CHECK(gptr) << "The input graph has to be immutable graph";
*rv = gptr->inner_nodes; *rv = gptr->inner_nodes;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
List<GraphRef> graphs = args[0]; List<GraphRef> graphs = args[0];
std::vector<GraphPtr> ptrs(graphs.size()); std::vector<GraphPtr> ptrs(graphs.size());
for (size_t i = 0; i < graphs.size(); ++i) { for (size_t i = 0; i < graphs.size(); ++i) {
...@@ -617,7 +630,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion") ...@@ -617,7 +630,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
int64_t num = args[1]; int64_t num = args[1];
const auto& ret = GraphOp::DisjointPartitionByNum(g.sptr(), num); const auto& ret = GraphOp::DisjointPartitionByNum(g.sptr(), num);
...@@ -629,7 +642,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum") ...@@ -629,7 +642,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray sizes = args[1]; const IdArray sizes = args[1];
const auto& ret = GraphOp::DisjointPartitionBySizes(g.sptr(), sizes); const auto& ret = GraphOp::DisjointPartitionBySizes(g.sptr(), sizes);
...@@ -638,35 +651,35 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes") ...@@ -638,35 +651,35 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
ret_list.push_back(GraphRef(gp)); ret_list.push_back(GraphRef(gp));
} }
*rv = ret_list; *rv = ret_list;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
bool backtracking = args[1]; bool backtracking = args[1];
*rv = GraphOp::LineGraph(g.sptr(), backtracking); *rv = GraphOp::LineGraph(g.sptr(), backtracking);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLToImmutable") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLToImmutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = ImmutableGraph::ToImmutable(g.sptr()); *rv = ImmutableGraph::ToImmutable(g.sptr());
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleGraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = GraphOp::ToSimpleGraph(g.sptr()); *rv = GraphOp::ToSimpleGraph(g.sptr());
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedMutableGraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedMutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = GraphOp::ToBidirectedMutableGraph(g.sptr()); *rv = GraphOp::ToBidirectedMutableGraph(g.sptr());
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLReorderGraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLReorderGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray new_order = args[1]; const IdArray new_order = args[1];
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
...@@ -675,7 +688,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReorderGraph") ...@@ -675,7 +688,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReorderGraph")
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges") DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef graph = args[0]; GraphRef graph = args[0];
bool is_incsr = args[1]; bool is_incsr = args[1];
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(graph.sptr()); auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(graph.sptr());
...@@ -683,7 +696,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges") ...@@ -683,7 +696,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges")
CSRPtr csr = is_incsr ? gptr->GetInCSR() : gptr->GetOutCSR(); CSRPtr csr = is_incsr ? gptr->GetInCSR() : gptr->GetOutCSR();
auto csrmat = csr->ToCSRMatrix(); auto csrmat = csr->ToCSRMatrix();
int64_t num_edges = csrmat.data->shape[0]; int64_t num_edges = csrmat.data->shape[0];
IdArray new_data = IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx); IdArray new_data =
IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx);
// Return the original edge Ids. // Return the original edge Ids.
*rv = new_data; *rv = new_data;
// TODO(zhengda) I need to invalidate out-CSR and COO. // TODO(zhengda) I need to invalidate out-CSR and COO.
...@@ -692,8 +706,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges") ...@@ -692,8 +706,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges")
// TODO(zhengda) after assignment, we actually don't need to store them // TODO(zhengda) after assignment, we actually don't need to store them
// physically. // physically.
ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, { ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, {
IdType *typed_new_data = static_cast<IdType*>(new_data->data); IdType* typed_new_data = static_cast<IdType*>(new_data->data);
IdType *typed_data = static_cast<IdType*>(csrmat.data->data); IdType* typed_data = static_cast<IdType*>(csrmat.data->data);
for (int64_t i = 0; i < num_edges; i++) { for (int64_t i = 0; i < num_edges; i++) {
typed_new_data[i] = typed_data[i]; typed_new_data[i] = typed_data[i];
typed_data[i] = i; typed_data[i] = i;
...@@ -702,7 +716,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges") ...@@ -702,7 +716,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges")
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
auto gptr = g.sptr(); auto gptr = g.sptr();
auto immutable_g = std::dynamic_pointer_cast<ImmutableGraph>(gptr); auto immutable_g = std::dynamic_pointer_cast<ImmutableGraph>(gptr);
...@@ -719,29 +733,31 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph") ...@@ -719,29 +733,31 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray parent_vids = args[0]; const IdArray parent_vids = args[0];
const IdArray query = args[1]; const IdArray query = args[1];
*rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query); *rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query);
}); });
template<class IdType> template <class IdType>
IdArray MapIds(IdArray ids, IdArray range_starts, IdArray range_ends, IdArray typed_map, IdArray MapIds(
IdArray ids, IdArray range_starts, IdArray range_ends, IdArray typed_map,
int num_parts, int num_types) { int num_parts, int num_types) {
int64_t num_ids = ids->shape[0]; int64_t num_ids = ids->shape[0];
int64_t num_ranges = range_starts->shape[0]; int64_t num_ranges = range_starts->shape[0];
IdArray ret = IdArray::Empty({num_ids * 2}, ids->dtype, ids->ctx); IdArray ret = IdArray::Empty({num_ids * 2}, ids->dtype, ids->ctx);
const IdType *range_start_data = static_cast<IdType *>(range_starts->data); const IdType* range_start_data = static_cast<IdType*>(range_starts->data);
const IdType *range_end_data = static_cast<IdType *>(range_ends->data); const IdType* range_end_data = static_cast<IdType*>(range_ends->data);
const IdType *ids_data = static_cast<IdType *>(ids->data); const IdType* ids_data = static_cast<IdType*>(ids->data);
const IdType *typed_map_data = static_cast<IdType *>(typed_map->data); const IdType* typed_map_data = static_cast<IdType*>(typed_map->data);
IdType *types_data = static_cast<IdType *>(ret->data); IdType* types_data = static_cast<IdType*>(ret->data);
IdType *per_type_ids_data = static_cast<IdType *>(ret->data) + num_ids; IdType* per_type_ids_data = static_cast<IdType*>(ret->data) + num_ids;
runtime::parallel_for(0, ids->shape[0], [&](size_t b, size_t e) { runtime::parallel_for(0, ids->shape[0], [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
IdType id = ids_data[i]; IdType id = ids_data[i];
auto it = std::lower_bound(range_end_data, range_end_data + num_ranges, id); auto it =
std::lower_bound(range_end_data, range_end_data + num_ranges, id);
// The range must exist. // The range must exist.
BUG_IF_FAIL(it != range_end_data + num_ranges); BUG_IF_FAIL(it != range_end_data + num_ranges);
size_t range_id = it - range_end_data; size_t range_id = it - range_end_data;
...@@ -752,8 +768,9 @@ IdArray MapIds(IdArray ids, IdArray range_starts, IdArray range_ends, IdArray ty ...@@ -752,8 +768,9 @@ IdArray MapIds(IdArray ids, IdArray range_starts, IdArray range_ends, IdArray ty
if (part_id == 0) { if (part_id == 0) {
per_type_ids_data[i] = id - range_start_data[range_id]; per_type_ids_data[i] = id - range_start_data[range_id];
} else { } else {
per_type_ids_data[i] = id - range_start_data[range_id] per_type_ids_data[i] =
+ typed_map_data[num_parts * type_id + part_id - 1]; id - range_start_data[range_id] +
typed_map_data[num_parts * type_id + part_id - 1];
} }
} }
}); });
...@@ -761,7 +778,7 @@ IdArray MapIds(IdArray ids, IdArray range_starts, IdArray range_ends, IdArray ty ...@@ -761,7 +778,7 @@ IdArray MapIds(IdArray ids, IdArray range_starts, IdArray range_ends, IdArray ty
} }
DGL_REGISTER_GLOBAL("distributed.id_map._CAPI_DGLHeteroMapIds") DGL_REGISTER_GLOBAL("distributed.id_map._CAPI_DGLHeteroMapIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray ids = args[0]; const IdArray ids = args[0];
const IdArray range_starts = args[1]; const IdArray range_starts = args[1];
const IdArray range_ends = args[2]; const IdArray range_ends = args[2];
...@@ -778,7 +795,8 @@ DGL_REGISTER_GLOBAL("distributed.id_map._CAPI_DGLHeteroMapIds") ...@@ -778,7 +795,8 @@ DGL_REGISTER_GLOBAL("distributed.id_map._CAPI_DGLHeteroMapIds")
IdArray ret; IdArray ret;
ATEN_ID_TYPE_SWITCH(ids->dtype, IdType, { ATEN_ID_TYPE_SWITCH(ids->dtype, IdType, {
ret = MapIds<IdType>(ids, range_starts, range_ends, typed_map, num_parts, num_types); ret = MapIds<IdType>(
ids, range_starts, range_ends, typed_map, num_parts, num_types);
}); });
*rv = ret; *rv = ret;
}); });
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dgl/graph_traversal.h> #include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -13,7 +14,7 @@ namespace dgl { ...@@ -13,7 +14,7 @@ namespace dgl {
namespace traverse { namespace traverse {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
...@@ -28,7 +29,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2") ...@@ -28,7 +29,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
...@@ -44,7 +45,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2") ...@@ -44,7 +45,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
bool reversed = args[1]; bool reversed = args[1];
aten::CSRMatrix csr; aten::CSRMatrix csr;
...@@ -58,9 +59,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2") ...@@ -58,9 +59,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections}); *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
...@@ -76,7 +76,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2") ...@@ -76,7 +76,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
...@@ -90,14 +90,12 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2") ...@@ -90,14 +90,12 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
csr = g.sptr()->GetCSRMatrix(0); csr = g.sptr()->GetCSRMatrix(0);
} }
const auto& front = aten::DGLDFSLabeledEdges(csr, const auto& front = aten::DGLDFSLabeledEdges(
source, csr, source, has_reverse_edge, has_nontree_edge, return_labels);
has_reverse_edge,
has_nontree_edge,
return_labels);
if (return_labels) { if (return_labels) {
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.tags, front.sections}); *rv = ConvertNDArrayVectorToPackedFunc(
{front.ids, front.tags, front.sections});
} else { } else {
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections}); *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
} }
......
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
* \brief Heterograph implementation * \brief Heterograph implementation
*/ */
#include "./heterograph.h" #include "./heterograph.h"
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h> #include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include <memory> #include <memory>
#include <vector>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include <vector>
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -23,7 +25,8 @@ using dgl::ImmutableGraph; ...@@ -23,7 +25,8 @@ using dgl::ImmutableGraph;
HeteroSubgraph EdgeSubgraphPreserveNodes( HeteroSubgraph EdgeSubgraphPreserveNodes(
const HeteroGraph* hg, const std::vector<IdArray>& eids) { const HeteroGraph* hg, const std::vector<IdArray>& eids) {
CHECK_EQ(eids.size(), hg->NumEdgeTypes()) CHECK_EQ(eids.size(), hg->NumEdgeTypes())
<< "Invalid input: the input list size must be the same as the number of edge type."; << "Invalid input: the input list size must be the same as the number of "
"edge type.";
HeteroSubgraph ret; HeteroSubgraph ret;
ret.induced_vertices.resize(hg->NumVertexTypes()); ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = eids; ret.induced_edges = eids;
...@@ -33,14 +36,14 @@ HeteroSubgraph EdgeSubgraphPreserveNodes( ...@@ -33,14 +36,14 @@ HeteroSubgraph EdgeSubgraphPreserveNodes(
auto pair = hg->meta_graph()->FindEdge(etype); auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const auto& rel_vsg = hg->GetRelationGraph(etype)->EdgeSubgraph( const auto& rel_vsg =
{eids[etype]}, true); hg->GetRelationGraph(etype)->EdgeSubgraph({eids[etype]}, true);
subrels[etype] = rel_vsg.graph; subrels[etype] = rel_vsg.graph;
ret.induced_vertices[src_vtype] = rel_vsg.induced_vertices[0]; ret.induced_vertices[src_vtype] = rel_vsg.induced_vertices[0];
ret.induced_vertices[dst_vtype] = rel_vsg.induced_vertices[1]; ret.induced_vertices[dst_vtype] = rel_vsg.induced_vertices[1];
} }
ret.graph = HeteroGraphPtr(new HeteroGraph( ret.graph = HeteroGraphPtr(
hg->meta_graph(), subrels, hg->NumVerticesPerType())); new HeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType()));
return ret; return ret;
} }
...@@ -49,34 +52,36 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -49,34 +52,36 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
// TODO(minjie): In general, all relabeling should be separated with subgraph // TODO(minjie): In general, all relabeling should be separated with subgraph
// operations. // operations.
CHECK_EQ(eids.size(), hg->NumEdgeTypes()) CHECK_EQ(eids.size(), hg->NumEdgeTypes())
<< "Invalid input: the input list size must be the same as the number of edge type."; << "Invalid input: the input list size must be the same as the number of "
"edge type.";
HeteroSubgraph ret; HeteroSubgraph ret;
ret.induced_vertices.resize(hg->NumVertexTypes()); ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = eids; ret.induced_edges = eids;
// NOTE(minjie): EdgeSubgraph when preserve_nodes is false is quite complicated in // NOTE(minjie): EdgeSubgraph when preserve_nodes is false is quite
// heterograph. This is because we need to make sure bipartite graphs that incident // complicated in heterograph. This is because we need to make sure bipartite
// on the same vertex type must have the same ID space. For example, suppose we have // graphs that incident on the same vertex type must have the same ID space.
// following heterograph: // For example, suppose we have following heterograph:
// //
// Meta graph: A -> B -> C // Meta graph: A -> B -> C
// UnitGraph graphs: // UnitGraph graphs:
// * A -> B: (0, 0), (0, 1) // * A -> B: (0, 0), (0, 1)
// * B -> C: (1, 0), (1, 1) // * B -> C: (1, 0), (1, 1)
// //
// Suppose for A->B, we only keep edge (0, 0), while for B->C we only keep (1, 0). We need // Suppose for A->B, we only keep edge (0, 0), while for B->C we only keep (1,
// to make sure that in the result subgraph, node type B still has two nodes. This means // 0). We need to make sure that in the result subgraph, node type B still has
// we cannot simply compute EdgeSubgraph for B->C which will relabel node#1 of type B to be // two nodes. This means we cannot simply compute EdgeSubgraph for B->C which
// node #0. // will relabel node#1 of type B to be node #0.
// //
// One implementation is as follows: // One implementation is as follows:
// (1) For each bipartite graph, slice out the edges using the given eids. // (1) For each bipartite graph, slice out the edges using the given eids.
// (2) Make a dictionary map<vtype, vector<IdArray>>, where the key is the vertex type // (2) Make a dictionary map<vtype, vector<IdArray>>, where the key is the
// and the value is the incident nodes from the bipartite graphs that has the vertex // vertex type
// type as either srctype or dsttype. // and the value is the incident nodes from the bipartite graphs that has
// the vertex type as either srctype or dsttype.
// (3) Then for each vertex type, use aten::Relabel_ on its vector<IdArray>. // (3) Then for each vertex type, use aten::Relabel_ on its vector<IdArray>.
// aten::Relabel_ computes the union of the vertex sets and relabel // aten::Relabel_ computes the union of the vertex sets and relabel
// the unique elements from zero. The returned mapping array is the final induced // the unique elements from zero. The returned mapping array is the final
// vertex set for that vertex type. // induced vertex set for that vertex type.
// (4) Use the relabeled edges to construct the bipartite graph. // (4) Use the relabeled edges to construct the bipartite graph.
// step (1) & (2) // step (1) & (2)
std::vector<EdgeArray> subedges(hg->NumEdgeTypes()); std::vector<EdgeArray> subedges(hg->NumEdgeTypes());
...@@ -103,10 +108,9 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -103,10 +108,9 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
subrels[etype] = UnitGraph::CreateFromCOO( subrels[etype] = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2, (src_vtype == dst_vtype) ? 1 : 2,
ret.induced_vertices[src_vtype]->shape[0], ret.induced_vertices[src_vtype]->shape[0],
ret.induced_vertices[dst_vtype]->shape[0], ret.induced_vertices[dst_vtype]->shape[0], subedges[etype].src,
subedges[etype].src,
subedges[etype].dst); subedges[etype].dst);
} }
ret.graph = HeteroGraphPtr(new HeteroGraph( ret.graph = HeteroGraphPtr(new HeteroGraph(
...@@ -114,36 +118,39 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes( ...@@ -114,36 +118,39 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
return ret; return ret;
} }
void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) { void HeteroGraphSanityCheck(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
// Sanity check // Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size()); CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed."; CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
// all relation graphs must have only one edge type // all relation graphs must have only one edge type
for (const auto &rg : rel_graphs) { for (const auto& rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type."; CHECK_EQ(rg->NumEdgeTypes(), 1)
<< "Each relation graph must have only one edge type.";
} }
auto ctx = rel_graphs[0]->Context(); auto ctx = rel_graphs[0]->Context();
for (const auto &rg : rel_graphs) { for (const auto& rg : rel_graphs) {
CHECK_EQ(rg->Context(), ctx) << "Each relation graph must have the same context."; CHECK_EQ(rg->Context(), ctx)
<< "Each relation graph must have the same context.";
} }
} }
std::vector<int64_t> std::vector<int64_t> InferNumVerticesPerType(
InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) { GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
// create num verts per type // create num verts per type
std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1); std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);
EdgeArray etype_array = meta_graph->Edges(); EdgeArray etype_array = meta_graph->Edges();
dgl_type_t *srctypes = static_cast<dgl_type_t *>(etype_array.src->data); dgl_type_t* srctypes = static_cast<dgl_type_t*>(etype_array.src->data);
dgl_type_t *dsttypes = static_cast<dgl_type_t *>(etype_array.dst->data); dgl_type_t* dsttypes = static_cast<dgl_type_t*>(etype_array.dst->data);
dgl_type_t *etypes = static_cast<dgl_type_t *>(etype_array.id->data); dgl_type_t* etypes = static_cast<dgl_type_t*>(etype_array.id->data);
for (size_t i = 0; i < meta_graph->NumEdges(); ++i) { for (size_t i = 0; i < meta_graph->NumEdges(); ++i) {
dgl_type_t srctype = srctypes[i]; dgl_type_t srctype = srctypes[i];
dgl_type_t dsttype = dsttypes[i]; dgl_type_t dsttype = dsttypes[i];
dgl_type_t etype = etypes[i]; dgl_type_t etype = etypes[i];
const auto& rg = rel_graphs[etype]; const auto& rg = rel_graphs[etype];
const auto sty = 0; const auto sty = 0;
const auto dty = rg->NumVertexTypes() == 1? 0 : 1; const auto dty = rg->NumVertexTypes() == 1 ? 0 : 1;
size_t nv; size_t nv;
// # nodes of source type // # nodes of source type
...@@ -164,7 +171,8 @@ InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& ...@@ -164,7 +171,8 @@ InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
return num_verts_per_type; return num_verts_per_type;
} }
std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& rel_graphs) { std::vector<UnitGraphPtr> CastToUnitGraphs(
const std::vector<HeteroGraphPtr>& rel_graphs) {
std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size()); std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
for (size_t i = 0; i < rel_graphs.size(); ++i) { for (size_t i = 0; i < rel_graphs.size(); ++i) {
HeteroGraphPtr relg = rel_graphs[i]; HeteroGraphPtr relg = rel_graphs[i];
...@@ -181,9 +189,9 @@ std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& re ...@@ -181,9 +189,9 @@ std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& re
} // namespace } // namespace
HeteroGraph::HeteroGraph( HeteroGraph::HeteroGraph(
GraphPtr meta_graph, GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<HeteroGraphPtr>& rel_graphs, const std::vector<int64_t>& num_nodes_per_type)
const std::vector<int64_t>& num_nodes_per_type) : BaseHeteroGraph(meta_graph) { : BaseHeteroGraph(meta_graph) {
if (num_nodes_per_type.size() == 0) if (num_nodes_per_type.size() == 0)
num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs); num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
else else
...@@ -193,7 +201,7 @@ HeteroGraph::HeteroGraph( ...@@ -193,7 +201,7 @@ HeteroGraph::HeteroGraph(
} }
bool HeteroGraph::IsMultigraph() const { bool HeteroGraph::IsMultigraph() const {
for (const auto &hg : relation_graphs_) { for (const auto& hg : relation_graphs_) {
if (hg->IsMultigraph()) { if (hg->IsMultigraph()) {
return true; return true;
} }
...@@ -206,9 +214,11 @@ BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const { ...@@ -206,9 +214,11 @@ BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
return aten::LT(vids, NumVertices(vtype)); return aten::LT(vids, NumVertices(vtype));
} }
HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) const { HeteroSubgraph HeteroGraph::VertexSubgraph(
const std::vector<IdArray>& vids) const {
CHECK_EQ(vids.size(), NumVertexTypes()) CHECK_EQ(vids.size(), NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types."; << "Invalid input: the input list size must be the same as the number of "
"vertex types.";
HeteroSubgraph ret; HeteroSubgraph ret;
ret.induced_vertices = vids; ret.induced_vertices = vids;
std::vector<int64_t> num_vertices_per_type(NumVertexTypes()); std::vector<int64_t> num_vertices_per_type(NumVertexTypes());
...@@ -220,15 +230,16 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con ...@@ -220,15 +230,16 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
auto pair = meta_graph_->FindEdge(etype); auto pair = meta_graph_->FindEdge(etype);
const dgl_type_t src_vtype = pair.first; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second; const dgl_type_t dst_vtype = pair.second;
const std::vector<IdArray> rel_vids = (src_vtype == dst_vtype) ? const std::vector<IdArray> rel_vids =
std::vector<IdArray>({vids[src_vtype]}) : (src_vtype == dst_vtype)
std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]}); ? std::vector<IdArray>({vids[src_vtype]})
: std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]});
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(rel_vids); const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
subrels[etype] = rel_vsg.graph; subrels[etype] = rel_vsg.graph;
ret.induced_edges[etype] = rel_vsg.induced_edges[0]; ret.induced_edges[etype] = rel_vsg.induced_edges[0];
} }
ret.graph = HeteroGraphPtr(new HeteroGraph( ret.graph = HeteroGraphPtr(
meta_graph_, subrels, std::move(num_vertices_per_type))); new HeteroGraph(meta_graph_, subrels, std::move(num_vertices_per_type)));
return ret; return ret;
} }
...@@ -248,11 +259,11 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { ...@@ -248,11 +259,11 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
for (auto g : hgindex->relation_graphs_) { for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::AsNumBits(g, bits)); rel_graphs.push_back(UnitGraph::AsNumBits(g, bits));
} }
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs, return HeteroGraphPtr(new HeteroGraph(
hgindex->num_verts_per_type_)); hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_));
} }
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) { HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} }
...@@ -262,23 +273,20 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) { ...@@ -262,23 +273,20 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
for (auto g : hgindex->relation_graphs_) { for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::CopyTo(g, ctx)); rel_graphs.push_back(UnitGraph::CopyTo(g, ctx));
} }
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs, return HeteroGraphPtr(new HeteroGraph(
hgindex->num_verts_per_type_)); hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_));
} }
void HeteroGraph::PinMemory_() { void HeteroGraph::PinMemory_() {
for (auto g : relation_graphs_) for (auto g : relation_graphs_) g->PinMemory_();
g->PinMemory_();
} }
void HeteroGraph::UnpinMemory_() { void HeteroGraph::UnpinMemory_() {
for (auto g : relation_graphs_) for (auto g : relation_graphs_) g->UnpinMemory_();
g->UnpinMemory_();
} }
void HeteroGraph::RecordStream(DGLStreamHandle stream) { void HeteroGraph::RecordStream(DGLStreamHandle stream) {
for (auto g : relation_graphs_) for (auto g : relation_graphs_) g->RecordStream(stream);
g->RecordStream(stream);
} }
std::string HeteroGraph::SharedMemName() const { std::string HeteroGraph::SharedMemName() const {
...@@ -286,13 +294,13 @@ std::string HeteroGraph::SharedMemName() const { ...@@ -286,13 +294,13 @@ std::string HeteroGraph::SharedMemName() const {
} }
HeteroGraphPtr HeteroGraph::CopyToSharedMem( HeteroGraphPtr HeteroGraph::CopyToSharedMem(
HeteroGraphPtr g, const std::string& name, const std::vector<std::string>& ntypes, HeteroGraphPtr g, const std::string& name,
const std::vector<std::string>& ntypes,
const std::vector<std::string>& etypes, const std::set<std::string>& fmts) { const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {
// TODO(JJ): Raise error when calling shared_memory if graph index is on gpu // TODO(JJ): Raise error when calling shared_memory if graph index is on gpu
auto hg = std::dynamic_pointer_cast<HeteroGraph>(g); auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);
CHECK_NOTNULL(hg); CHECK_NOTNULL(hg);
if (hg->SharedMemName() == name) if (hg->SharedMemName() == name) return g;
return g;
// Copy buffer to share memory // Copy buffer to share memory
auto mem = std::make_shared<SharedMemory>(name); auto mem = std::make_shared<SharedMemory>(name);
...@@ -312,7 +320,7 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem( ...@@ -312,7 +320,7 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
std::vector<HeteroGraphPtr> relgraphs(g->NumEdgeTypes()); std::vector<HeteroGraphPtr> relgraphs(g->NumEdgeTypes());
for (dgl_type_t etype = 0 ; etype < g->NumEdgeTypes() ; ++etype) { for (dgl_type_t etype = 0; etype < g->NumEdgeTypes(); ++etype) {
auto src_dst_type = g->GetEndpointTypes(etype); auto src_dst_type = g->GetEndpointTypes(etype);
int num_vtypes = (src_dst_type.first == src_dst_type.second ? 1 : 2); int num_vtypes = (src_dst_type.first == src_dst_type.second ? 1 : 2);
aten::COOMatrix coo; aten::COOMatrix coo;
...@@ -341,10 +349,11 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem( ...@@ -341,10 +349,11 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
} }
std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>> std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
HeteroGraph::CreateFromSharedMem(const std::string &name) { HeteroGraph::CreateFromSharedMem(const std::string& name) {
bool exist = SharedMemory::Exist(name); bool exist = SharedMemory::Exist(name);
if (!exist) { if (!exist) {
return std::make_tuple(nullptr, std::vector<std::string>(), std::vector<std::string>()); return std::make_tuple(
nullptr, std::vector<std::string>(), std::vector<std::string>());
} }
auto mem = std::make_shared<SharedMemory>(name); auto mem = std::make_shared<SharedMemory>(name);
auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX); auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
...@@ -367,7 +376,7 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>> ...@@ -367,7 +376,7 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type"; CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type";
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges()); std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
for (dgl_type_t etype = 0 ; etype < metagraph->NumEdges() ; ++etype) { for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
auto src_dst = metagraph->FindEdge(etype); auto src_dst = metagraph->FindEdge(etype);
int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2; int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
aten::COOMatrix coo; aten::COOMatrix coo;
...@@ -387,7 +396,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>> ...@@ -387,7 +396,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo); num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
} }
auto ret = std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type); auto ret =
std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
ret->shared_mem_ = mem; ret->shared_mem_ = mem;
std::vector<std::string> ntypes; std::vector<std::string> ntypes;
...@@ -400,11 +410,12 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>> ...@@ -400,11 +410,12 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const { HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes()); std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
auto relgraph = std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype)); auto relgraph =
std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype));
format_rels[etype] = relgraph->GetGraphInFormat(formats); format_rels[etype] = relgraph->GetGraphInFormat(formats);
} }
return HeteroGraphPtr(new HeteroGraph( return HeteroGraphPtr(
meta_graph_, format_rels, NumVerticesPerType())); new HeteroGraph(meta_graph_, format_rels, NumVerticesPerType()));
} }
FlattenedHeteroGraphPtr HeteroGraph::Flatten( FlattenedHeteroGraphPtr HeteroGraph::Flatten(
...@@ -418,15 +429,16 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten( ...@@ -418,15 +429,16 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(
} }
template <class IdType> template <class IdType>
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& etypes) const { FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(
const std::vector<dgl_type_t>& etypes) const {
std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets; std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
size_t src_nodes = 0, dst_nodes = 0; size_t src_nodes = 0, dst_nodes = 0;
std::vector<dgl_type_t> induced_srctype, induced_dsttype; std::vector<dgl_type_t> induced_srctype, induced_dsttype;
std::vector<IdType> induced_srcid, induced_dstid; std::vector<IdType> induced_srcid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set; std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type and number of nodes after this // XXXtype_offsets contain the mapping from node type and number of nodes
// loop. // after this loop.
for (dgl_type_t etype : etypes) { for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype); auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first; dgl_type_t srctype = src_dsttype.first;
...@@ -443,15 +455,16 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& ...@@ -443,15 +455,16 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
dsttype_set.push_back(dsttype); dsttype_set.push_back(dsttype);
} }
} }
// Sort the node types so that we can compare the sets and decide whether a homogeneous graph // Sort the node types so that we can compare the sets and decide whether a
// should be returned. // homogeneous graph should be returned.
std::sort(srctype_set.begin(), srctype_set.end()); std::sort(srctype_set.begin(), srctype_set.end());
std::sort(dsttype_set.begin(), dsttype_set.end()); std::sort(dsttype_set.begin(), dsttype_set.end());
bool homograph = (srctype_set.size() == dsttype_set.size()) && bool homograph =
(srctype_set.size() == dsttype_set.size()) &&
std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin()); std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
// XXXtype_offsets contain the mapping from node type to node ID offsets after these // XXXtype_offsets contain the mapping from node type to node ID offsets after
// two loops. // these two loops.
for (size_t i = 0; i < srctype_set.size(); ++i) { for (size_t i = 0; i < srctype_set.size(); ++i) {
dgl_type_t ntype = srctype_set[i]; dgl_type_t ntype = srctype_set[i];
size_t num_nodes = srctype_offsets[ntype]; size_t num_nodes = srctype_offsets[ntype];
...@@ -492,14 +505,12 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& ...@@ -492,14 +505,12 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
src_arrs.push_back(edges.src + srctype_offset); src_arrs.push_back(edges.src + srctype_offset);
dst_arrs.push_back(edges.dst + dsttype_offset); dst_arrs.push_back(edges.dst + dsttype_offset);
eid_arrs.push_back(edges.id); eid_arrs.push_back(edges.id);
induced_etypes.push_back(aten::Full(etype, num_edges, NumBits(), Context())); induced_etypes.push_back(
aten::Full(etype, num_edges, NumBits(), Context()));
} }
HeteroGraphPtr gptr = UnitGraph::CreateFromCOO( HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
homograph ? 1 : 2, homograph ? 1 : 2, src_nodes, dst_nodes, aten::Concat(src_arrs),
src_nodes,
dst_nodes,
aten::Concat(src_arrs),
aten::Concat(dst_arrs)); aten::Concat(dst_arrs));
// Sanity check // Sanity check
...@@ -507,15 +518,20 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& ...@@ -507,15 +518,20 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
CHECK_EQ(gptr->NumBits(), NumBits()); CHECK_EQ(gptr->NumBits(), NumBits());
FlattenedHeteroGraph* result = new FlattenedHeteroGraph; FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
result->graph = HeteroGraphRef(HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr}))); result->graph = HeteroGraphRef(
result->induced_srctype = aten::VecToIdArray(induced_srctype).CopyTo(Context()); HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
result->induced_srctype_set = aten::VecToIdArray(srctype_set).CopyTo(Context()); result->induced_srctype =
aten::VecToIdArray(induced_srctype).CopyTo(Context());
result->induced_srctype_set =
aten::VecToIdArray(srctype_set).CopyTo(Context());
result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context()); result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
result->induced_etype = aten::Concat(induced_etypes); result->induced_etype = aten::Concat(induced_etypes);
result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context()); result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
result->induced_eid = aten::Concat(eid_arrs); result->induced_eid = aten::Concat(eid_arrs);
result->induced_dsttype = aten::VecToIdArray(induced_dsttype).CopyTo(Context()); result->induced_dsttype =
result->induced_dsttype_set = aten::VecToIdArray(dsttype_set).CopyTo(Context()); aten::VecToIdArray(induced_dsttype).CopyTo(Context());
result->induced_dsttype_set =
aten::VecToIdArray(dsttype_set).CopyTo(Context());
result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context()); result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
return FlattenedHeteroGraphPtr(result); return FlattenedHeteroGraphPtr(result);
} }
...@@ -545,21 +561,25 @@ void HeteroGraph::Save(dmlc::Stream* fs) const { ...@@ -545,21 +561,25 @@ void HeteroGraph::Save(dmlc::Stream* fs) const {
GraphPtr HeteroGraph::AsImmutableGraph() const { GraphPtr HeteroGraph::AsImmutableGraph() const {
CHECK(NumVertexTypes() == 1) << "graph has more than one node types"; CHECK(NumVertexTypes() == 1) << "graph has more than one node types";
CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types"; CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types";
auto unit_graph = CHECK_NOTNULL( auto unit_graph =
std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0))); CHECK_NOTNULL(std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
return unit_graph->AsImmutableGraph(); return unit_graph->AsImmutableGraph();
} }
HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const { HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {
CHECK_EQ(1, meta_graph_->NumEdges()) << "Only support Homogeneous graph now (one edge type)"; CHECK_EQ(1, meta_graph_->NumEdges())
CHECK_EQ(1, meta_graph_->NumVertices()) << "Only support Homogeneous graph now (one node type)"; << "Only support Homogeneous graph now (one edge type)";
CHECK_EQ(1, meta_graph_->NumVertices())
<< "Only support Homogeneous graph now (one node type)";
CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now"; CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now";
UnitGraphPtr ug = relation_graphs_[0]; UnitGraphPtr ug = relation_graphs_[0];
const auto &ulg = ug->LineGraph(backtracking); const auto& ulg = ug->LineGraph(backtracking);
std::vector<HeteroGraphPtr> rel_graph = {ulg}; std::vector<HeteroGraphPtr> rel_graph = {ulg};
std::vector<int64_t> num_nodes_per_type = {static_cast<int64_t>(ulg->NumVertices(0))}; std::vector<int64_t> num_nodes_per_type = {
return HeteroGraphPtr(new HeteroGraph(meta_graph_, rel_graph, std::move(num_nodes_per_type))); static_cast<int64_t>(ulg->NumVertices(0))};
return HeteroGraphPtr(
new HeteroGraph(meta_graph_, rel_graph, std::move(num_nodes_per_type)));
} }
} // namespace dgl } // namespace dgl
...@@ -5,11 +5,12 @@ ...@@ -5,11 +5,12 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/aten/coo.h> #include <dgl/aten/coo.h>
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <dgl/runtime/c_runtime_api.h>
#include <set> #include <set>
#include "../c_api_common.h" #include "../c_api_common.h"
...@@ -25,7 +26,7 @@ namespace dgl { ...@@ -25,7 +26,7 @@ namespace dgl {
// XXX(minjie): Ideally, Unitgraph should be invisible to python side // XXX(minjie): Ideally, Unitgraph should be invisible to python side
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t nvtypes = args[0]; int64_t nvtypes = args[0];
int64_t num_src = args[1]; int64_t num_src = args[1];
int64_t num_dst = args[2]; int64_t num_dst = args[2];
...@@ -40,13 +41,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO") ...@@ -40,13 +41,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
formats_vec.push_back(ParseSparseFormat(fmt)); formats_vec.push_back(ParseSparseFormat(fmt));
} }
const auto code = SparseFormatsToCode(formats_vec); const auto code = SparseFormatsToCode(formats_vec);
auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col, auto hgptr = CreateFromCOO(
row_sorted, col_sorted, code); nvtypes, num_src, num_dst, row, col, row_sorted, col_sorted, code);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t nvtypes = args[0]; int64_t nvtypes = args[0];
int64_t num_src = args[1]; int64_t num_src = args[1];
int64_t num_dst = args[2]; int64_t num_dst = args[2];
...@@ -62,16 +63,18 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR") ...@@ -62,16 +63,18 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
} }
const auto code = SparseFormatsToCode(formats_vec); const auto code = SparseFormatsToCode(formats_vec);
if (!transpose) { if (!transpose) {
auto hgptr = CreateFromCSR(nvtypes, num_src, num_dst, indptr, indices, edge_ids, code); auto hgptr = CreateFromCSR(
nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
} else { } else {
auto hgptr = CreateFromCSC(nvtypes, num_src, num_dst, indptr, indices, edge_ids, code); auto hgptr = CreateFromCSC(
nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
} }
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0]; GraphRef meta_graph = args[0];
List<HeteroGraphRef> rel_graphs = args[1]; List<HeteroGraphRef> rel_graphs = args[1];
std::vector<HeteroGraphPtr> rel_ptrs; std::vector<HeteroGraphPtr> rel_ptrs;
...@@ -83,8 +86,9 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph") ...@@ -83,8 +86,9 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes") DGL_REGISTER_GLOBAL(
.set_body([] (DGLArgs args, DGLRetValue* rv) { "heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes")
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0]; GraphRef meta_graph = args[0];
List<HeteroGraphRef> rel_graphs = args[1]; List<HeteroGraphRef> rel_graphs = args[1];
IdArray num_nodes_per_type = args[2]; IdArray num_nodes_per_type = args[2];
...@@ -101,20 +105,20 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNo ...@@ -101,20 +105,20 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNo
///////////////////////// HeteroGraph member functions ///////////////////////// ///////////////////////// HeteroGraph member functions /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->meta_graph(); *rv = hg->meta_graph();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMetaGraphUniBipartite") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMetaGraphUniBipartite")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
GraphPtr mg = hg->meta_graph(); GraphPtr mg = hg->meta_graph();
*rv = mg->IsUniBipartite(); *rv = mg->IsUniBipartite();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype; CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
...@@ -126,12 +130,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph") ...@@ -126,12 +130,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> etypes = args[1]; List<Value> etypes = args[1];
std::vector<dgl_id_t> etypes_vec; std::vector<dgl_id_t> etypes_vec;
for (Value val : etypes) { for (Value val : etypes) {
// (gq) have to decompose it into two statements because of a weird MSVC internal error // (gq) have to decompose it into two statements because of a weird MSVC
// internal error
dgl_id_t id = val->data; dgl_id_t id = val->data;
etypes_vec.push_back(id); etypes_vec.push_back(id);
} }
...@@ -140,7 +145,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph") ...@@ -140,7 +145,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1]; dgl_type_t vtype = args[1];
int64_t num = args[2]; int64_t num = args[2];
...@@ -148,7 +153,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices") ...@@ -148,7 +153,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t src = args[2]; dgl_id_t src = args[2];
...@@ -157,7 +162,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge") ...@@ -157,7 +162,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray src = args[2]; IdArray src = args[2];
...@@ -166,63 +171,63 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges") ...@@ -166,63 +171,63 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroClear") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
hg->Clear(); hg->Clear();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->DataType(); *rv = hg->DataType();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->Context(); *rv = hg->Context();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->IsPinned(); *rv = hg->IsPinned();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumBits") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->NumBits(); *rv = hg->NumBits();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMultigraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->IsMultigraph(); *rv = hg->IsMultigraph();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsReadonly") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = hg->IsReadonly(); *rv = hg->IsReadonly();
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1]; dgl_type_t vtype = args[1];
*rv = static_cast<int64_t>(hg->NumVertices(vtype)); *rv = static_cast<int64_t>(hg->NumVertices(vtype));
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
*rv = static_cast<int64_t>(hg->NumEdges(etype)); *rv = static_cast<int64_t>(hg->NumEdges(etype));
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1]; dgl_type_t vtype = args[1];
dgl_id_t vid = args[2]; dgl_id_t vid = args[2];
...@@ -230,7 +235,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex") ...@@ -230,7 +235,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1]; dgl_type_t vtype = args[1];
IdArray vids = args[2]; IdArray vids = args[2];
...@@ -238,7 +243,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices") ...@@ -238,7 +243,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t src = args[2]; dgl_id_t src = args[2];
...@@ -247,7 +252,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween") ...@@ -247,7 +252,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray src = args[2]; IdArray src = args[2];
...@@ -256,7 +261,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween") ...@@ -256,7 +261,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t dst = args[2]; dgl_id_t dst = args[2];
...@@ -264,7 +269,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors") ...@@ -264,7 +269,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t src = args[2]; dgl_id_t src = args[2];
...@@ -272,7 +277,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors") ...@@ -272,7 +277,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t src = args[2]; dgl_id_t src = args[2];
...@@ -281,7 +286,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId") ...@@ -281,7 +286,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsAll") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsAll")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray src = args[2]; IdArray src = args[2];
...@@ -290,9 +295,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsAll") ...@@ -290,9 +295,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsAll")
*rv = ConvertEdgeArrayToPackedFunc(ret); *rv = ConvertEdgeArrayToPackedFunc(ret);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsOne") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsOne")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray src = args[2]; IdArray src = args[2];
...@@ -301,7 +305,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsOne") ...@@ -301,7 +305,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsOne")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray eids = args[2]; IdArray eids = args[2];
...@@ -310,7 +314,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges") ...@@ -310,7 +314,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t vid = args[2]; dgl_id_t vid = args[2];
...@@ -319,7 +323,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1") ...@@ -319,7 +323,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray vids = args[2]; IdArray vids = args[2];
...@@ -328,7 +332,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2") ...@@ -328,7 +332,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t vid = args[2]; dgl_id_t vid = args[2];
...@@ -337,7 +341,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1") ...@@ -337,7 +341,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray vids = args[2]; IdArray vids = args[2];
...@@ -346,7 +350,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2") ...@@ -346,7 +350,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
std::string order = args[2]; std::string order = args[2];
...@@ -355,7 +359,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges") ...@@ -355,7 +359,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t vid = args[2]; dgl_id_t vid = args[2];
...@@ -363,7 +367,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree") ...@@ -363,7 +367,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray vids = args[2]; IdArray vids = args[2];
...@@ -371,7 +375,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees") ...@@ -371,7 +375,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
dgl_id_t vid = args[2]; dgl_id_t vid = args[2];
...@@ -379,7 +383,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree") ...@@ -379,7 +383,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
IdArray vids = args[2]; IdArray vids = args[2];
...@@ -387,17 +391,16 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees") ...@@ -387,17 +391,16 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAdj") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1]; dgl_type_t etype = args[1];
bool transpose = args[2]; bool transpose = args[2];
std::string fmt = args[3]; std::string fmt = args[3];
*rv = ConvertNDArrayVectorToPackedFunc( *rv = ConvertNDArrayVectorToPackedFunc(hg->GetAdj(etype, transpose, fmt));
hg->GetAdj(etype, transpose, fmt));
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> vids = args[1]; List<Value> vids = args[1];
bool relabel_nodes = args[2]; bool relabel_nodes = args[2];
...@@ -413,7 +416,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph") ...@@ -413,7 +416,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> eids = args[1]; List<Value> eids = args[1];
bool preserve_nodes = args[2]; bool preserve_nodes = args[2];
...@@ -430,13 +433,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph") ...@@ -430,13 +433,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph")
///////////////////////// HeteroSubgraph members ///////////////////////// ///////////////////////// HeteroSubgraph members /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0]; HeteroSubgraphRef subg = args[0];
*rv = HeteroGraphRef(subg->graph); *rv = HeteroGraphRef(subg->graph);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices") DGL_REGISTER_GLOBAL(
.set_body([] (DGLArgs args, DGLRetValue* rv) { "heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0]; HeteroSubgraphRef subg = args[0];
List<Value> induced_verts; List<Value> induced_verts;
for (IdArray arr : subg->induced_vertices) { for (IdArray arr : subg->induced_vertices) {
...@@ -446,7 +450,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices ...@@ -446,7 +450,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0]; HeteroSubgraphRef subg = args[0];
List<Value> induced_edges; List<Value> induced_edges;
for (IdArray arr : subg->induced_edges) { for (IdArray arr : subg->induced_edges) {
...@@ -455,10 +459,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges") ...@@ -455,10 +459,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
*rv = induced_edges; *rv = induced_edges;
}); });
///////////////////////// Global functions and algorithms ///////////////////////// ///////////////////////// Global functions and algorithms
////////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
int bits = args[1]; int bits = args[1];
HeteroGraphPtr bhg_ptr = hg.sptr(); HeteroGraphPtr bhg_ptr = hg.sptr();
...@@ -473,7 +478,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits") ...@@ -473,7 +478,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
int device_type = args[1]; int device_type = args[1];
int device_id = args[2]; int device_id = args[2];
...@@ -485,7 +490,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo") ...@@ -485,7 +490,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPinMemory_") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPinMemory_")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr()); auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
hgindex->PinMemory_(); hgindex->PinMemory_();
...@@ -493,7 +498,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPinMemory_") ...@@ -493,7 +498,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPinMemory_")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr()); auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
hgindex->UnpinMemory_(); hgindex->UnpinMemory_();
...@@ -501,7 +506,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_") ...@@ -501,7 +506,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
DGLStreamHandle stream = args[1]; DGLStreamHandle stream = args[1];
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr()); auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
...@@ -510,7 +515,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream") ...@@ -510,7 +515,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
std::string name = args[1]; std::string name = args[1];
List<Value> ntypes = args[2]; List<Value> ntypes = args[2];
...@@ -519,7 +524,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem") ...@@ -519,7 +524,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem")
auto ntypes_vec = ListValueToVector<std::string>(ntypes); auto ntypes_vec = ListValueToVector<std::string>(ntypes);
auto etypes_vec = ListValueToVector<std::string>(etypes); auto etypes_vec = ListValueToVector<std::string>(etypes);
std::set<std::string> fmts_set; std::set<std::string> fmts_set;
for (const auto &fmt : fmts) { for (const auto& fmt : fmts) {
std::string fmt_data = fmt->data; std::string fmt_data = fmt->data;
fmts_set.insert(fmt_data); fmts_set.insert(fmt_data);
} }
...@@ -529,7 +534,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem") ...@@ -529,7 +534,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
std::string name = args[0]; std::string name = args[0];
HeteroGraphPtr hg; HeteroGraphPtr hg;
std::vector<std::string> ntypes; std::vector<std::string> ntypes;
...@@ -537,9 +542,9 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem") ...@@ -537,9 +542,9 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem")
std::tie(hg, ntypes, etypes) = HeteroGraph::CreateFromSharedMem(name); std::tie(hg, ntypes, etypes) = HeteroGraph::CreateFromSharedMem(name);
List<Value> ntypes_list; List<Value> ntypes_list;
List<Value> etypes_list; List<Value> etypes_list;
for (const auto &ntype : ntypes) for (const auto& ntype : ntypes)
ntypes_list.push_back(Value(MakeValue(ntype))); ntypes_list.push_back(Value(MakeValue(ntype)));
for (const auto &etype : etypes) for (const auto& etype : etypes)
etypes_list.push_back(Value(MakeValue(etype))); etypes_list.push_back(Value(MakeValue(etype)));
List<ObjectRef> ret; List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(hg)); ret.push_back(HeteroGraphRef(hg));
...@@ -549,7 +554,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem") ...@@ -549,7 +554,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0]; GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1]; List<HeteroGraphRef> component_graphs = args[1];
CHECK(component_graphs.size() > 1) CHECK(component_graphs.size() > 1)
...@@ -561,8 +566,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion") ...@@ -561,8 +566,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion")
for (const auto& component : component_graphs) { for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr()); component_ptrs.push_back(component.sptr());
CHECK_EQ(component->NumBits(), bits) CHECK_EQ(component->NumBits(), bits)
<< "Expect graphs to joint union have the same index dtype(int" << bits << "Expect graphs to joint union have the same index dtype(int"
<< "), but got int" << component->NumBits(); << bits << "), but got int" << component->NumBits();
CHECK_EQ(component->Context(), ctx) CHECK_EQ(component->Context(), ctx)
<< "Expect graphs to joint union have the same context" << ctx << "Expect graphs to joint union have the same context" << ctx
<< "), but got " << component->Context(); << "), but got " << component->Context();
...@@ -570,10 +575,10 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion") ...@@ -570,10 +575,10 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion")
auto hgptr = JointUnionHeteroGraph(meta_graph.sptr(), component_ptrs); auto hgptr = JointUnionHeteroGraph(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0]; GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1]; List<HeteroGraphRef> component_graphs = args[1];
CHECK(component_graphs.size() > 0) CHECK(component_graphs.size() > 0)
...@@ -594,86 +599,86 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2") ...@@ -594,86 +599,86 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2")
auto hgptr = DisjointUnionHeteroGraph2(meta_graph.sptr(), component_ptrs); auto hgptr = DisjointUnionHeteroGraph2(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v2") DGL_REGISTER_GLOBAL(
.set_body([] (DGLArgs args, DGLRetValue* rv) { "heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v2")
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const IdArray vertex_sizes = args[1]; const IdArray vertex_sizes = args[1];
const IdArray edge_sizes = args[2]; const IdArray edge_sizes = args[2];
std::vector<HeteroGraphPtr> ret; std::vector<HeteroGraphPtr> ret;
ret = DisjointPartitionHeteroBySizes2(hg->meta_graph(), hg.sptr(), ret = DisjointPartitionHeteroBySizes2(
vertex_sizes, edge_sizes); hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);
List<HeteroGraphRef> ret_list; List<HeteroGraphRef> ret_list;
for (HeteroGraphPtr hgptr : ret) { for (HeteroGraphPtr hgptr : ret) {
ret_list.push_back(HeteroGraphRef(hgptr)); ret_list.push_back(HeteroGraphRef(hgptr));
} }
*rv = ret_list; *rv = ret_list;
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const IdArray vertex_sizes = args[1]; const IdArray vertex_sizes = args[1];
const IdArray edge_sizes = args[2]; const IdArray edge_sizes = args[2];
const int64_t bits = hg->NumBits(); const int64_t bits = hg->NumBits();
std::vector<HeteroGraphPtr> ret; std::vector<HeteroGraphPtr> ret;
ATEN_ID_BITS_SWITCH(bits, IdType, { ATEN_ID_BITS_SWITCH(bits, IdType, {
ret = DisjointPartitionHeteroBySizes<IdType>(hg->meta_graph(), hg.sptr(), ret = DisjointPartitionHeteroBySizes<IdType>(
vertex_sizes, edge_sizes); hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);
}); });
List<HeteroGraphRef> ret_list; List<HeteroGraphRef> ret_list;
for (HeteroGraphPtr hgptr : ret) { for (HeteroGraphPtr hgptr : ret) {
ret_list.push_back(HeteroGraphRef(hgptr)); ret_list.push_back(HeteroGraphRef(hgptr));
} }
*rv = ret_list; *rv = ret_list;
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSlice") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSlice")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const IdArray num_nodes_per_type = args[1]; const IdArray num_nodes_per_type = args[1];
const IdArray start_nid_per_type = args[2]; const IdArray start_nid_per_type = args[2];
const IdArray num_edges_per_type = args[3]; const IdArray num_edges_per_type = args[3];
const IdArray start_eid_per_type = args[4]; const IdArray start_eid_per_type = args[4];
auto hgptr = SliceHeteroGraph(hg->meta_graph(), hg.sptr(), num_nodes_per_type, auto hgptr = SliceHeteroGraph(
start_nid_per_type, num_edges_per_type, start_eid_per_type); hg->meta_graph(), hg.sptr(), num_nodes_per_type, start_nid_per_type,
num_edges_per_type, start_eid_per_type);
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetCreatedFormats") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetCreatedFormats")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> format_list; List<Value> format_list;
dgl_format_code_t code = hg->GetRelationGraph(0)->GetCreatedFormats(); dgl_format_code_t code = hg->GetRelationGraph(0)->GetCreatedFormats();
for (auto format : CodeToSparseFormats(code)) { for (auto format : CodeToSparseFormats(code)) {
format_list.push_back( format_list.push_back(Value(MakeValue(ToStringSparseFormat(format))));
Value(MakeValue(ToStringSparseFormat(format))));
} }
*rv = format_list; *rv = format_list;
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAllowedFormats") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAllowedFormats")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> format_list; List<Value> format_list;
dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats(); dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats();
for (auto format : CodeToSparseFormats(code)) { for (auto format : CodeToSparseFormats(code)) {
format_list.push_back( format_list.push_back(Value(MakeValue(ToStringSparseFormat(format))));
Value(MakeValue(ToStringSparseFormat(format))));
} }
*rv = format_list; *rv = format_list;
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFormat") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats(); dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats();
auto get_format_f = [&](size_t etype_b, size_t etype_e) { auto get_format_f = [&](size_t etype_b, size_t etype_e) {
for (auto etype = etype_b; etype < etype_e; ++etype) { for (auto etype = etype_b; etype < etype_e; ++etype) {
auto bg = std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype)); auto bg =
for (auto format : CodeToSparseFormats(code)) std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype));
bg->GetFormat(format); for (auto format : CodeToSparseFormats(code)) bg->GetFormat(format);
} }
}; };
...@@ -682,10 +687,10 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFormat") ...@@ -682,10 +687,10 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFormat")
#else #else
get_format_f(0, hg->NumEdgeTypes()); get_format_f(0, hg->NumEdgeTypes());
#endif #endif
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
List<Value> formats = args[1]; List<Value> formats = args[1];
std::vector<SparseFormat> formats_vec; std::vector<SparseFormat> formats_vec;
...@@ -693,13 +698,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph") ...@@ -693,13 +698,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph")
std::string fmt = val->data; std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt)); formats_vec.push_back(ParseSparseFormat(fmt));
} }
auto hgptr = hg->GetGraphInFormat( auto hgptr = hg->GetGraphInFormat(SparseFormatsToCode(formats_vec));
SparseFormatsToCode(formats_vec));
*rv = HeteroGraphRef(hgptr); *rv = HeteroGraphRef(hgptr);
}); });
DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph") DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]); const auto& nodes = ListValueToVector<IdArray>(args[1]);
bool relabel_nodes = args[2]; bool relabel_nodes = args[2];
...@@ -709,7 +713,7 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph") ...@@ -709,7 +713,7 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph")
}); });
DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph") DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]); const auto& nodes = ListValueToVector<IdArray>(args[1]);
bool relabel_nodes = args[2]; bool relabel_nodes = args[2];
...@@ -719,27 +723,30 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph") ...@@ -719,27 +723,30 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph")
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
*rv = GraphRef(hg->AsImmutableGraph()); *rv = GraphRef(hg->AsImmutableGraph());
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges") DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
NDArray tag = args[1]; NDArray tag = args[1];
int64_t num_tag = args[2]; int64_t num_tag = args[2];
CHECK_EQ(hg->Context().device_type, kDGLCPU) << "Only support sorting by tag on cpu"; CHECK_EQ(hg->Context().device_type, kDGLCPU)
<< "Only support sorting by tag on cpu";
CHECK(aten::IsValidIdArray(tag)); CHECK(aten::IsValidIdArray(tag));
CHECK_EQ(tag->ctx.device_type, kDGLCPU) << "Only support sorting by tag on cpu"; CHECK_EQ(tag->ctx.device_type, kDGLCPU)
<< "Only support sorting by tag on cpu";
const auto csr = hg->GetCSRMatrix(0); const auto csr = hg->GetCSRMatrix(0);
NDArray tag_pos = aten::NullArray(); NDArray tag_pos = aten::NullArray();
aten::CSRMatrix output; aten::CSRMatrix output;
std::tie(output, tag_pos) = aten::CSRSortByTag(csr, tag, num_tag); std::tie(output, tag_pos) = aten::CSRSortByTag(csr, tag, num_tag);
HeteroGraphPtr output_hg = CreateFromCSR(hg->NumVertexTypes(), output, ALL_CODE); HeteroGraphPtr output_hg =
CreateFromCSR(hg->NumVertexTypes(), output, ALL_CODE);
List<ObjectRef> ret; List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(output_hg)); ret.push_back(HeteroGraphRef(output_hg));
ret.push_back(Value(MakeValue(tag_pos))); ret.push_back(Value(MakeValue(tag_pos)));
...@@ -747,14 +754,16 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges") ...@@ -747,14 +754,16 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges")
}); });
DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges") DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
NDArray tag = args[1]; NDArray tag = args[1];
int64_t num_tag = args[2]; int64_t num_tag = args[2];
CHECK_EQ(hg->Context().device_type, kDGLCPU) << "Only support sorting by tag on cpu"; CHECK_EQ(hg->Context().device_type, kDGLCPU)
<< "Only support sorting by tag on cpu";
CHECK(aten::IsValidIdArray(tag)); CHECK(aten::IsValidIdArray(tag));
CHECK_EQ(tag->ctx.device_type, kDGLCPU) << "Only support sorting by tag on cpu"; CHECK_EQ(tag->ctx.device_type, kDGLCPU)
<< "Only support sorting by tag on cpu";
const auto csc = hg->GetCSCMatrix(0); const auto csc = hg->GetCSCMatrix(0);
...@@ -762,7 +771,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges") ...@@ -762,7 +771,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges")
aten::CSRMatrix output; aten::CSRMatrix output;
std::tie(output, tag_pos) = aten::CSRSortByTag(csc, tag, num_tag); std::tie(output, tag_pos) = aten::CSRSortByTag(csc, tag, num_tag);
HeteroGraphPtr output_hg = CreateFromCSC(hg->NumVertexTypes(), output, ALL_CODE); HeteroGraphPtr output_hg =
CreateFromCSC(hg->NumVertexTypes(), output, ALL_CODE);
List<ObjectRef> ret; List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(output_hg)); ret.push_back(HeteroGraphRef(output_hg));
ret.push_back(Value(MakeValue(tag_pos))); ret.push_back(Value(MakeValue(tag_pos)));
...@@ -770,7 +780,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges") ...@@ -770,7 +780,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges")
}); });
DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes") DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0]; GraphRef metagraph = args[0];
std::unordered_set<uint64_t> dst_set; std::unordered_set<uint64_t> dst_set;
std::unordered_set<uint64_t> src_set; std::unordered_set<uint64_t> src_set;
...@@ -793,7 +803,8 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes") ...@@ -793,7 +803,8 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
else if (is_dst) else if (is_dst)
dstlist.push_back(Value(MakeValue(static_cast<int64_t>(nid)))); dstlist.push_back(Value(MakeValue(static_cast<int64_t>(nid))));
else else
// If a node type is isolated, put it in srctype as defined in the Python docstring. // If a node type is isolated, put it in srctype as defined in the
// Python docstring.
srclist.push_back(Value(MakeValue(static_cast<int64_t>(nid)))); srclist.push_back(Value(MakeValue(static_cast<int64_t>(nid))));
} }
ret_list.push_back(srclist); ret_list.push_back(srclist);
...@@ -802,24 +813,24 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes") ...@@ -802,24 +813,24 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
CHECK_GT(hg->NumEdgeTypes(), 0); CHECK_GT(hg->NumEdgeTypes(), 0);
auto g = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr()); auto g = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
std::vector<HeteroGraphPtr> rev_ugs; std::vector<HeteroGraphPtr> rev_ugs;
const auto &ugs = g->relation_graphs(); const auto& ugs = g->relation_graphs();
rev_ugs.resize(ugs.size()); rev_ugs.resize(ugs.size());
for (size_t i = 0; i < ugs.size(); ++i) { for (size_t i = 0; i < ugs.size(); ++i) {
const auto &rev_ug = ugs[i]->Reverse(); const auto& rev_ug = ugs[i]->Reverse();
rev_ugs[i] = rev_ug; rev_ugs[i] = rev_ug;
} }
// node types are not changed // node types are not changed
const auto& num_nodes = g->NumVerticesPerType(); const auto& num_nodes = g->NumVerticesPerType();
const auto& meta_edges = hg->meta_graph()->Edges("eid"); const auto& meta_edges = hg->meta_graph()->Edges("eid");
// reverse the metagraph // reverse the metagraph
const auto& rev_meta = ImmutableGraph::CreateFromCOO(hg->meta_graph()->NumVertices(), const auto& rev_meta = ImmutableGraph::CreateFromCOO(
meta_edges.dst, meta_edges.src); hg->meta_graph()->NumVertices(), meta_edges.dst, meta_edges.src);
*rv = CreateHeteroGraph(rev_meta, rev_ugs, num_nodes); *rv = CreateHeteroGraph(rev_meta, rev_ugs, num_nodes);
}); });
} // namespace dgl } // namespace dgl
...@@ -4,13 +4,14 @@ ...@@ -4,13 +4,14 @@
* \brief DGL immutable graph index implementation * \brief DGL immutable graph index implementation
*/ */
#include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h> #include <dgl/runtime/smart_ptr_serializer.h>
#include <dgl/base_heterograph.h>
#include <dmlc/io.h> #include <dmlc/io.h>
#include <dmlc/type_traits.h> #include <dmlc/type_traits.h>
#include <string.h> #include <string.h>
#include <bitset> #include <bitset>
#include <numeric> #include <numeric>
#include <tuple> #include <tuple>
...@@ -23,7 +24,8 @@ using namespace dgl::runtime; ...@@ -23,7 +24,8 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace { namespace {
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) { inline std::string GetSharedMemName(
const std::string &name, const std::string &edge_dir) {
return name + "_" + edge_dir; return name + "_" + edge_dir;
} }
...@@ -39,8 +41,9 @@ struct GraphIndexMetadata { ...@@ -39,8 +41,9 @@ struct GraphIndexMetadata {
}; };
/* /*
* Serialize the metadata of a graph index and place it in a shared-memory tensor. * Serialize the metadata of a graph index and place it in a shared-memory
* In this way, another process can reconstruct a GraphIndex from a shared-memory tensor. * tensor. In this way, another process can reconstruct a GraphIndex from a
* shared-memory tensor.
*/ */
NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) { NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
#ifndef _WIN32 #ifndef _WIN32
...@@ -51,8 +54,9 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) { ...@@ -51,8 +54,9 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
meta.has_out_csr = gidx->HasOutCSR(); meta.has_out_csr = gidx->HasOutCSR();
meta.has_coo = false; meta.has_coo = false;
NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, NDArray meta_arr = NDArray::EmptyShared(
DGLContext{kDGLCPU, 0}, true); name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0},
true);
memcpy(meta_arr->data, &meta, sizeof(meta)); memcpy(meta_arr->data, &meta, sizeof(meta));
return meta_arr; return meta_arr;
#else #else
...@@ -67,8 +71,9 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) { ...@@ -67,8 +71,9 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
GraphIndexMetadata DeserializeMetadata(const std::string &name) { GraphIndexMetadata DeserializeMetadata(const std::string &name) {
GraphIndexMetadata meta; GraphIndexMetadata meta;
#ifndef _WIN32 #ifndef _WIN32
NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, NDArray meta_arr = NDArray::EmptyShared(
DGLContext{kDGLCPU, 0}, false); name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0},
false);
memcpy(&meta, meta_arr->data, sizeof(meta)); memcpy(&meta, meta_arr->data, sizeof(meta));
#else #else
LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet"; LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
...@@ -77,18 +82,23 @@ GraphIndexMetadata DeserializeMetadata(const std::string &name) { ...@@ -77,18 +82,23 @@ GraphIndexMetadata DeserializeMetadata(const std::string &name) {
} }
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) { const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges,
bool is_create) {
#ifndef _WIN32 #ifndef _WIN32
const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t); const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);
IdArray sm_array = IdArray::EmptyShared( IdArray sm_array = IdArray::EmptyShared(
shared_mem_name, {file_size}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0}, is_create); shared_mem_name, {file_size}, DGLDataType{kDGLInt, 8, 1},
DGLContext{kDGLCPU, 0}, is_create);
// Create views from the shared memory array. Note that we don't need to save // Create views from the shared memory array. Note that we don't need to save
// the sm_array because the refcount is maintained by the view arrays. // the sm_array because the refcount is maintained by the view arrays.
IdArray indptr = sm_array.CreateView({num_verts + 1}, DGLDataType{kDGLInt, 64, 1}); IdArray indptr =
IdArray indices = sm_array.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1}, sm_array.CreateView({num_verts + 1}, DGLDataType{kDGLInt, 64, 1});
IdArray indices = sm_array.CreateView(
{num_edges}, DGLDataType{kDGLInt, 64, 1},
(num_verts + 1) * sizeof(dgl_id_t)); (num_verts + 1) * sizeof(dgl_id_t));
IdArray edge_ids = sm_array.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1}, IdArray edge_ids = sm_array.CreateView(
{num_edges}, DGLDataType{kDGLInt, 64, 1},
(num_verts + 1 + num_edges) * sizeof(dgl_id_t)); (num_verts + 1 + num_edges) * sizeof(dgl_id_t));
return std::make_tuple(indptr, indices, edge_ids); return std::make_tuple(indptr, indices, edge_ids);
#else #else
...@@ -106,10 +116,9 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( ...@@ -106,10 +116,9 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
CSR::CSR(int64_t num_vertices, int64_t num_edges) { CSR::CSR(int64_t num_vertices, int64_t num_edges) {
CHECK(!(num_vertices == 0 && num_edges != 0)); CHECK(!(num_vertices == 0 && num_edges != 0));
adj_ = aten::CSRMatrix{num_vertices, num_vertices, adj_ = aten::CSRMatrix{
aten::NewIdArray(num_vertices + 1), num_vertices, num_vertices, aten::NewIdArray(num_vertices + 1),
aten::NewIdArray(num_edges), aten::NewIdArray(num_edges), aten::NewIdArray(num_edges)};
aten::NewIdArray(num_edges)};
adj_.sorted = false; adj_.sorted = false;
} }
...@@ -123,8 +132,10 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) { ...@@ -123,8 +132,10 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
adj_.sorted = false; adj_.sorted = false;
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, CSR::CSR(
const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) { IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name)
: shared_mem_name_(shared_mem_name) {
CHECK(aten::IsValidIdArray(indptr)); CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices)); CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids)); CHECK(aten::IsValidIdArray(edge_ids));
...@@ -133,8 +144,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, ...@@ -133,8 +144,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const int64_t num_edges = indices->shape[0]; const int64_t num_edges = indices->shape[0];
adj_.num_rows = num_verts; adj_.num_rows = num_verts;
adj_.num_cols = num_verts; adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory( std::tie(adj_.indptr, adj_.indices, adj_.data) =
shared_mem_name, num_verts, num_edges, true); MapFromSharedMemory(shared_mem_name, num_verts, num_edges, true);
// copy the given data into the shared memory arrays // copy the given data into the shared memory arrays
adj_.indptr.CopyFrom(indptr); adj_.indptr.CopyFrom(indptr);
adj_.indices.CopyFrom(indices); adj_.indices.CopyFrom(indices);
...@@ -142,19 +153,18 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, ...@@ -142,19 +153,18 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
adj_.sorted = false; adj_.sorted = false;
} }
CSR::CSR(const std::string &shared_mem_name, CSR::CSR(
int64_t num_verts, int64_t num_edges): shared_mem_name_(shared_mem_name) { const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges)
: shared_mem_name_(shared_mem_name) {
CHECK(!(num_verts == 0 && num_edges != 0)); CHECK(!(num_verts == 0 && num_edges != 0));
adj_.num_rows = num_verts; adj_.num_rows = num_verts;
adj_.num_cols = num_verts; adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory( std::tie(adj_.indptr, adj_.indices, adj_.data) =
shared_mem_name, num_verts, num_edges, false); MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);
adj_.sorted = false; adj_.sorted = false;
} }
bool CSR::IsMultigraph() const { bool CSR::IsMultigraph() const { return aten::CSRHasDuplicate(adj_); }
return aten::CSRHasDuplicate(adj_);
}
EdgeArray CSR::OutEdges(dgl_id_t vid) const { EdgeArray CSR::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
...@@ -204,7 +214,7 @@ IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const { ...@@ -204,7 +214,7 @@ IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
} }
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const { EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids); const auto &arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
return EdgeArray{arrs[0], arrs[1], arrs[2]}; return EdgeArray{arrs[0], arrs[1], arrs[2]};
} }
...@@ -212,14 +222,15 @@ EdgeArray CSR::Edges(const std::string &order) const { ...@@ -212,14 +222,15 @@ EdgeArray CSR::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("srcdst")) CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\"," << "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\"."; << " but got \"" << order << "\".";
const auto& coo = aten::CSRToCOO(adj_, false); const auto &coo = aten::CSRToCOO(adj_, false);
return EdgeArray{coo.row, coo.col, coo.data}; return EdgeArray{coo.row, coo.col, coo.data};
} }
Subgraph CSR::VertexSubgraph(IdArray vids) const { Subgraph CSR::VertexSubgraph(IdArray vids) const {
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array."; CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids); const auto &submat = aten::CSRSliceMatrix(adj_, vids, vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context()); IdArray sub_eids =
aten::Range(0, submat.data->shape[0], NumBits(), Context());
CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids)); CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
subcsr->adj_.sorted = this->adj_.sorted; subcsr->adj_.sorted = this->adj_.sorted;
Subgraph subg; Subgraph subg;
...@@ -230,21 +241,21 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const { ...@@ -230,21 +241,21 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
} }
CSRPtr CSR::Transpose() const { CSRPtr CSR::Transpose() const {
const auto& trans = aten::CSRTranspose(adj_); const auto &trans = aten::CSRTranspose(adj_);
return CSRPtr(new CSR(trans.indptr, trans.indices, trans.data)); return CSRPtr(new CSR(trans.indptr, trans.indices, trans.data));
} }
COOPtr CSR::ToCOO() const { COOPtr CSR::ToCOO() const {
const auto& coo = aten::CSRToCOO(adj_, true); const auto &coo = aten::CSRToCOO(adj_, true);
return COOPtr(new COO(NumVertices(), coo.row, coo.col)); return COOPtr(new COO(NumVertices(), coo.row, coo.col));
} }
CSR CSR::CopyTo(const DGLContext& ctx) const { CSR CSR::CopyTo(const DGLContext &ctx) const {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
CSR ret(adj_.indptr.CopyTo(ctx), CSR ret(
adj_.indices.CopyTo(ctx), adj_.indptr.CopyTo(ctx), adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx)); adj_.data.CopyTo(ctx));
return ret; return ret;
} }
...@@ -264,8 +275,8 @@ CSR CSR::AsNumBits(uint8_t bits) const { ...@@ -264,8 +275,8 @@ CSR CSR::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) { if (NumBits() == bits) {
return *this; return *this;
} else { } else {
CSR ret(aten::AsNumBits(adj_.indptr, bits), CSR ret(
aten::AsNumBits(adj_.indices, bits), aten::AsNumBits(adj_.indptr, bits), aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits)); aten::AsNumBits(adj_.data, bits));
return ret; return ret;
} }
...@@ -274,8 +285,8 @@ CSR CSR::AsNumBits(uint8_t bits) const { ...@@ -274,8 +285,8 @@ CSR CSR::AsNumBits(uint8_t bits) const {
DGLIdIters CSR::SuccVec(dgl_id_t vid) const { DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
// TODO(minjie): This still assumes the data type and device context // TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later. // of this graph. Should fix later.
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data); const dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data); const dgl_id_t *indices_data = static_cast<dgl_id_t *>(adj_.indices->data);
const dgl_id_t start = indptr_data[vid]; const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1]; const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(indices_data + start, indices_data + end); return DGLIdIters(indices_data + start, indices_data + end);
...@@ -284,29 +295,28 @@ DGLIdIters CSR::SuccVec(dgl_id_t vid) const { ...@@ -284,29 +295,28 @@ DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const { DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
// TODO(minjie): This still assumes the data type and device context // TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later. // of this graph. Should fix later.
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data); const dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data); const dgl_id_t *eid_data = static_cast<dgl_id_t *>(adj_.data->data);
const dgl_id_t start = indptr_data[vid]; const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1]; const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(eid_data + start, eid_data + end); return DGLIdIters(eid_data + start, eid_data + end);
} }
bool CSR::Load(dmlc::Stream *fs) { bool CSR::Load(dmlc::Stream *fs) {
fs->Read(const_cast<dgl::aten::CSRMatrix*>(&adj_)); fs->Read(const_cast<dgl::aten::CSRMatrix *>(&adj_));
return true; return true;
} }
void CSR::Save(dmlc::Stream *fs) const { void CSR::Save(dmlc::Stream *fs) const { fs->Write(adj_); }
fs->Write(adj_);
}
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
// //
// COO graph implementation // COO graph implementation
// //
////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////
COO::COO(int64_t num_vertices, IdArray src, IdArray dst, COO::COO(
bool row_sorted, bool col_sorted) { int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted,
bool col_sorted) {
CHECK(aten::IsValidIdArray(src)); CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst)); CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]); CHECK_EQ(src->shape[0], dst->shape[0]);
...@@ -314,9 +324,7 @@ COO::COO(int64_t num_vertices, IdArray src, IdArray dst, ...@@ -314,9 +324,7 @@ COO::COO(int64_t num_vertices, IdArray src, IdArray dst,
aten::NullArray(), row_sorted, col_sorted}; aten::NullArray(), row_sorted, col_sorted};
} }
bool COO::IsMultigraph() const { bool COO::IsMultigraph() const { return aten::COOHasDuplicate(adj_); }
return aten::COOHasDuplicate(adj_);
}
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const { std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
CHECK(eid < NumEdges()) << "Invalid edge id: " << eid; CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
...@@ -327,17 +335,17 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const { ...@@ -327,17 +335,17 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
EdgeArray COO::FindEdges(IdArray eids) const { EdgeArray COO::FindEdges(IdArray eids) const {
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array"; CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
BUG_IF_FAIL(aten::IsNullArray(adj_.data)) << BUG_IF_FAIL(aten::IsNullArray(adj_.data))
"FindEdges requires the internal COO matrix not having EIDs."; << "FindEdges requires the internal COO matrix not having EIDs.";
return EdgeArray{aten::IndexSelect(adj_.row, eids), return EdgeArray{
aten::IndexSelect(adj_.col, eids), aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),
eids}; eids};
} }
EdgeArray COO::Edges(const std::string &order) const { EdgeArray COO::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("eid")) CHECK(order.empty() || order == std::string("eid"))
<< "COO only support Edges of order \"eid\", but got \"" << "COO only support Edges of order \"eid\", but got \"" << order
<< order << "\"."; << "\".";
IdArray rst_eid = aten::Range(0, NumEdges(), NumBits(), Context()); IdArray rst_eid = aten::Range(0, NumEdges(), NumBits(), Context());
return EdgeArray{adj_.row, adj_.col, rst_eid}; return EdgeArray{adj_.row, adj_.col, rst_eid};
} }
...@@ -366,17 +374,15 @@ Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -366,17 +374,15 @@ Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
} }
CSRPtr COO::ToCSR() const { CSRPtr COO::ToCSR() const {
const auto& csr = aten::COOToCSR(adj_); const auto &csr = aten::COOToCSR(adj_);
return CSRPtr(new CSR(csr.indptr, csr.indices, csr.data)); return CSRPtr(new CSR(csr.indptr, csr.indices, csr.data));
} }
COO COO::CopyTo(const DGLContext& ctx) const { COO COO::CopyTo(const DGLContext &ctx) const {
if (Context() == ctx) { if (Context() == ctx) {
return *this; return *this;
} else { } else {
COO ret(NumVertices(), COO ret(NumVertices(), adj_.row.CopyTo(ctx), adj_.col.CopyTo(ctx));
adj_.row.CopyTo(ctx),
adj_.col.CopyTo(ctx));
return ret; return ret;
} }
} }
...@@ -390,8 +396,8 @@ COO COO::AsNumBits(uint8_t bits) const { ...@@ -390,8 +396,8 @@ COO COO::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) { if (NumBits() == bits) {
return *this; return *this;
} else { } else {
COO ret(NumVertices(), COO ret(
aten::AsNumBits(adj_.row, bits), NumVertices(), aten::AsNumBits(adj_.row, bits),
aten::AsNumBits(adj_.col, bits)); aten::AsNumBits(adj_.col, bits));
return ret; return ret;
} }
...@@ -411,13 +417,14 @@ BoolArray ImmutableGraph::HasVertices(IdArray vids) const { ...@@ -411,13 +417,14 @@ BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
CSRPtr ImmutableGraph::GetInCSR() const { CSRPtr ImmutableGraph::GetInCSR() const {
if (!in_csr_) { if (!in_csr_) {
if (out_csr_) { if (out_csr_) {
const_cast<ImmutableGraph*>(this)->in_csr_ = out_csr_->Transpose(); const_cast<ImmutableGraph *>(this)->in_csr_ = out_csr_->Transpose();
if (out_csr_->IsSharedMem()) if (out_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. " LOG(WARNING)
<< "We just construct an in-CSR from a shared-memory out CSR. "
<< "It may dramatically increase memory consumption."; << "It may dramatically increase memory consumption.";
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->in_csr_ = coo_->Transpose()->ToCSR(); const_cast<ImmutableGraph *>(this)->in_csr_ = coo_->Transpose()->ToCSR();
} }
} }
return in_csr_; return in_csr_;
...@@ -427,13 +434,14 @@ CSRPtr ImmutableGraph::GetInCSR() const { ...@@ -427,13 +434,14 @@ CSRPtr ImmutableGraph::GetInCSR() const {
CSRPtr ImmutableGraph::GetOutCSR() const { CSRPtr ImmutableGraph::GetOutCSR() const {
if (!out_csr_) { if (!out_csr_) {
if (in_csr_) { if (in_csr_) {
const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose(); const_cast<ImmutableGraph *>(this)->out_csr_ = in_csr_->Transpose();
if (in_csr_->IsSharedMem()) if (in_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. " LOG(WARNING)
<< "We just construct an out-CSR from a shared-memory in CSR. "
<< "It may dramatically increase memory consumption."; << "It may dramatically increase memory consumption.";
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR(); const_cast<ImmutableGraph *>(this)->out_csr_ = coo_->ToCSR();
} }
} }
return out_csr_; return out_csr_;
...@@ -443,10 +451,10 @@ CSRPtr ImmutableGraph::GetOutCSR() const { ...@@ -443,10 +451,10 @@ CSRPtr ImmutableGraph::GetOutCSR() const {
COOPtr ImmutableGraph::GetCOO() const { COOPtr ImmutableGraph::GetCOO() const {
if (!coo_) { if (!coo_) {
if (in_csr_) { if (in_csr_) {
const_cast<ImmutableGraph*>(this)->coo_ = in_csr_->ToCOO()->Transpose(); const_cast<ImmutableGraph *>(this)->coo_ = in_csr_->ToCOO()->Transpose();
} else { } else {
CHECK(out_csr_) << "Both CSR are missing."; CHECK(out_csr_) << "Both CSR are missing.";
const_cast<ImmutableGraph*>(this)->coo_ = out_csr_->ToCOO(); const_cast<ImmutableGraph *>(this)->coo_ = out_csr_->ToCOO();
} }
} }
return coo_; return coo_;
...@@ -457,7 +465,7 @@ EdgeArray ImmutableGraph::Edges(const std::string &order) const { ...@@ -457,7 +465,7 @@ EdgeArray ImmutableGraph::Edges(const std::string &order) const {
// arbitrary order // arbitrary order
if (in_csr_) { if (in_csr_) {
// transpose // transpose
const auto& edges = in_csr_->Edges(order); const auto &edges = in_csr_->Edges(order);
return EdgeArray{edges.dst, edges.src, edges.id}; return EdgeArray{edges.dst, edges.src, edges.id};
} else { } else {
return AnyGraph()->Edges(order); return AnyGraph()->Edges(order);
...@@ -489,16 +497,19 @@ Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -489,16 +497,19 @@ Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
return sg; return sg;
} }
std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const { std::vector<IdArray> ImmutableGraph::GetAdj(
// TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for bool transpose, const std::string &fmt) const {
// src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False // TODO(minjie): Our current semantics of adjacency matrix is row for dst
// is equal to in edge CSR. // nodes and col for
// We have this behavior because previously we use framework's SPMM and we don't cache // src nodes. Therefore, we need to flip the transpose flag. For example,
// reverse adj. This is not intuitive and also not consistent with networkx's // transpose=False is equal to in edge CSR. We have this behavior because
// to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the // previously we use framework's SPMM and we don't cache reverse adj. This
// behavior and make row for src and col for dst. // is not intuitive and also not consistent with networkx's
// to_scipy_sparse_matrix. With the upcoming custom kernel change, we should
// change the behavior and make row for src and col for dst.
if (fmt == std::string("csr")) { if (fmt == std::string("csr")) {
return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr"); return transpose ? GetOutCSR()->GetAdj(false, "csr")
: GetInCSR()->GetAdj(false, "csr");
} else if (fmt == std::string("coo")) { } else if (fmt == std::string("coo")) {
return GetCOO()->GetAdj(!transpose, fmt); return GetCOO()->GetAdj(!transpose, fmt);
} else { } else {
...@@ -508,7 +519,8 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f ...@@ -508,7 +519,8 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f
} }
ImmutableGraphPtr ImmutableGraph::CreateFromCSR( ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) { IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids)); CSRPtr csr(new CSR(indptr, indices, edge_ids));
if (edge_dir == "in") { if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr)); return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
...@@ -530,17 +542,19 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR(const std::string &name) { ...@@ -530,17 +542,19 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR(const std::string &name) {
GraphIndexMetadata meta = DeserializeMetadata(GetSharedMemName(name, "meta")); GraphIndexMetadata meta = DeserializeMetadata(GetSharedMemName(name, "meta"));
CSRPtr in_csr, out_csr; CSRPtr in_csr, out_csr;
if (meta.has_in_csr) { if (meta.has_in_csr) {
in_csr = CSRPtr(new CSR(GetSharedMemName(name, "in"), meta.num_nodes, meta.num_edges)); in_csr = CSRPtr(
new CSR(GetSharedMemName(name, "in"), meta.num_nodes, meta.num_edges));
} }
if (meta.has_out_csr) { if (meta.has_out_csr) {
out_csr = CSRPtr(new CSR(GetSharedMemName(name, "out"), meta.num_nodes, meta.num_edges)); out_csr = CSRPtr(
new CSR(GetSharedMemName(name, "out"), meta.num_nodes, meta.num_edges));
} }
return ImmutableGraphPtr(new ImmutableGraph(in_csr, out_csr, name)); return ImmutableGraphPtr(new ImmutableGraph(in_csr, out_csr, name));
} }
ImmutableGraphPtr ImmutableGraph::CreateFromCOO( ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst, int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted,
bool row_sorted, bool col_sorted) { bool col_sorted) {
COOPtr coo(new COO(num_vertices, src, dst, row_sorted, col_sorted)); COOPtr coo(new COO(num_vertices, src, dst, row_sorted, col_sorted));
return std::make_shared<ImmutableGraph>(coo); return std::make_shared<ImmutableGraph>(coo);
} }
...@@ -550,13 +564,14 @@ ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) { ...@@ -550,13 +564,14 @@ ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
if (ig) { if (ig) {
return ig; return ig;
} else { } else {
const auto& adj = graph->GetAdj(true, "csr"); const auto &adj = graph->GetAdj(true, "csr");
CSRPtr csr(new CSR(adj[0], adj[1], adj[2])); CSRPtr csr(new CSR(adj[0], adj[1], adj[2]));
return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out"); return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
} }
} }
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DGLContext& ctx) { ImmutableGraphPtr ImmutableGraph::CopyTo(
ImmutableGraphPtr g, const DGLContext &ctx) {
if (ctx == g->Context()) { if (ctx == g->Context()) {
return g; return g;
} }
...@@ -569,16 +584,20 @@ ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DGLContext& ...@@ -569,16 +584,20 @@ ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DGLContext&
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr)); return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
} }
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g, const std::string &name) { ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(
ImmutableGraphPtr g, const std::string &name) {
CSRPtr new_incsr, new_outcsr; CSRPtr new_incsr, new_outcsr;
std::string shared_mem_name = GetSharedMemName(name, "in"); std::string shared_mem_name = GetSharedMemName(name, "in");
new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name))); new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
shared_mem_name = GetSharedMemName(name, "out"); shared_mem_name = GetSharedMemName(name, "out");
new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name))); new_outcsr =
CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
auto new_g = ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name)); auto new_g =
new_g->serialized_shared_meta_ = SerializeMetadata(new_g, GetSharedMemName(name, "meta")); ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
new_g->serialized_shared_meta_ =
SerializeMetadata(new_g, GetSharedMemName(name, "meta"));
return new_g; return new_g;
} }
...@@ -598,8 +617,8 @@ ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) { ...@@ -598,8 +617,8 @@ ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
ImmutableGraphPtr ImmutableGraph::Reverse() const { ImmutableGraphPtr ImmutableGraph::Reverse() const {
if (coo_) { if (coo_) {
return ImmutableGraphPtr(new ImmutableGraph( return ImmutableGraphPtr(
out_csr_, in_csr_, coo_->Transpose())); new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));
} else { } else {
return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_)); return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
} }
...@@ -628,54 +647,53 @@ HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const { ...@@ -628,54 +647,53 @@ HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
aten::CSRMatrix in_csr, out_csr; aten::CSRMatrix in_csr, out_csr;
aten::COOMatrix coo; aten::COOMatrix coo;
if (in_csr_) if (in_csr_) in_csr = GetInCSR()->ToCSRMatrix();
in_csr = GetInCSR()->ToCSRMatrix(); if (out_csr_) out_csr = GetOutCSR()->ToCSRMatrix();
if (out_csr_) if (coo_) coo = GetCOO()->ToCOOMatrix();
out_csr = GetOutCSR()->ToCSRMatrix();
if (coo_)
coo = GetCOO()->ToCOOMatrix();
auto g = UnitGraph::CreateUnitGraphFrom( auto g = UnitGraph::CreateUnitGraphFrom(
1, in_csr, out_csr, coo, 1, in_csr, out_csr, coo, in_csr_ != nullptr, out_csr_ != nullptr,
in_csr_ != nullptr,
out_csr_ != nullptr,
coo_ != nullptr); coo_ != nullptr);
return HeteroGraphPtr(new HeteroGraph(g->meta_graph(), {g})); return HeteroGraphPtr(new HeteroGraph(g->meta_graph(), {g}));
} }
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsHeteroGraph") DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); ImmutableGraphPtr ig =
std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ig) << "graph is not readonly"; CHECK(ig) << "graph is not readonly";
*rv = HeteroGraphRef(ig->AsHeteroGraph()); *rv = HeteroGraphRef(ig->AsHeteroGraph());
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const int device_type = args[1]; const int device_type = args[1];
const int device_id = args[2]; const int device_id = args[2];
DGLContext ctx; DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type); ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr())); ImmutableGraphPtr ig =
CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyTo(ig, ctx); *rv = ImmutableGraph::CopyTo(ig, ctx);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
std::string name = args[1]; std::string name = args[1];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr())); ImmutableGraphPtr ig =
CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyToSharedMem(ig, name); *rv = ImmutableGraph::CopyToSharedMem(ig, name);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
int bits = args[1]; int bits = args[1];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr())); ImmutableGraphPtr ig =
CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::AsNumBits(ig, bits); *rv = ImmutableGraph::AsNumBits(ig, bits);
}); });
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
* \brief Call Metis partitioning * \brief Call Metis partitioning
*/ */
#include <metis.h>
#include <dgl/graph_op.h> #include <dgl/graph_op.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <metis.h>
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -25,12 +26,12 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) { ...@@ -25,12 +26,12 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
idx_t nvtxs = g->NumVertices(); idx_t nvtxs = g->NumVertices();
idx_t ncon = 1; // # balacing constraints. idx_t ncon = 1; // # balacing constraints.
idx_t *xadj = static_cast<idx_t*>(mat.indptr->data); idx_t *xadj = static_cast<idx_t *>(mat.indptr->data);
idx_t *adjncy = static_cast<idx_t*>(mat.indices->data); idx_t *adjncy = static_cast<idx_t *>(mat.indices->data);
idx_t nparts = k; idx_t nparts = k;
IdArray part_arr = aten::NewIdArray(nvtxs); IdArray part_arr = aten::NewIdArray(nvtxs);
idx_t objval = 0; idx_t objval = 0;
idx_t *part = static_cast<idx_t*>(part_arr->data); idx_t *part = static_cast<idx_t *>(part_arr->data);
int64_t vwgt_len = vwgt_arr->shape[0]; int64_t vwgt_len = vwgt_arr->shape[0];
CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8) CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8)
...@@ -40,7 +41,7 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) { ...@@ -40,7 +41,7 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
idx_t *vwgt = NULL; idx_t *vwgt = NULL;
if (vwgt_len > 0) { if (vwgt_len > 0) {
ncon = vwgt_len / g->NumVertices(); ncon = vwgt_len / g->NumVertices();
vwgt = static_cast<idx_t*>(vwgt_arr->data); vwgt = static_cast<idx_t *>(vwgt_arr->data);
} }
idx_t options[METIS_NOPTIONS]; idx_t options[METIS_NOPTIONS];
...@@ -56,7 +57,8 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) { ...@@ -56,7 +57,8 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_VOL; options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_VOL;
} }
int ret = METIS_PartGraphKway(&nvtxs, // The number of vertices int ret = METIS_PartGraphKway(
&nvtxs, // The number of vertices
&ncon, // The number of balancing constraints. &ncon, // The number of balancing constraints.
xadj, // indptr xadj, // indptr
adjncy, // indices adjncy, // indices
...@@ -99,7 +101,7 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) { ...@@ -99,7 +101,7 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition") DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
int k = args[1]; int k = args[1];
NDArray vwgt = args[2]; NDArray vwgt = args[2];
......
...@@ -5,21 +5,20 @@ ...@@ -5,21 +5,20 @@
*/ */
#include "./network.h" #include "./network.h"
#include <stdlib.h> #include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <dgl/packed_func_ext.h> #include <stdlib.h>
#include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h>
#include <unordered_map> #include <unordered_map>
#include "../rpc/network/common.h"
#include "../rpc/network/communicator.h" #include "../rpc/network/communicator.h"
#include "../rpc/network/socket_communicator.h"
#include "../rpc/network/msg_queue.h" #include "../rpc/network/msg_queue.h"
#include "../rpc/network/common.h" #include "../rpc/network/socket_communicator.h"
using dgl::network::StringPrintf; using dgl::network::StringPrintf;
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -29,11 +28,8 @@ const bool AUTO_FREE = true; ...@@ -29,11 +28,8 @@ const bool AUTO_FREE = true;
namespace dgl { namespace dgl {
namespace network { namespace network {
NDArray CreateNDArrayFromRaw(
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx, void* raw,
DGLDataType dtype,
DGLContext ctx,
void* raw,
bool auto_free) { bool auto_free) {
return NDArray::CreateFromRaw(shape, dtype, ctx, raw, auto_free); return NDArray::CreateFromRaw(shape, dtype, ctx, raw, auto_free);
} }
...@@ -74,16 +70,16 @@ char* ArrayMeta::Serialize(int64_t* size) { ...@@ -74,16 +70,16 @@ char* ArrayMeta::Serialize(int64_t* size) {
*(reinterpret_cast<int*>(pointer)) = ndarray_count_; *(reinterpret_cast<int*>(pointer)) = ndarray_count_;
pointer += sizeof(ndarray_count_); pointer += sizeof(ndarray_count_);
// Write data type // Write data type
memcpy(pointer, memcpy(
reinterpret_cast<DGLDataType*>(data_type_.data()), pointer, reinterpret_cast<DGLDataType*>(data_type_.data()),
sizeof(DGLDataType) * data_type_.size()); sizeof(DGLDataType) * data_type_.size());
pointer += (sizeof(DGLDataType) * data_type_.size()); pointer += (sizeof(DGLDataType) * data_type_.size());
// Write size of data_shape_ // Write size of data_shape_
*(reinterpret_cast<size_t*>(pointer)) = data_shape_.size(); *(reinterpret_cast<size_t*>(pointer)) = data_shape_.size();
pointer += sizeof(data_shape_.size()); pointer += sizeof(data_shape_.size());
// Write data of data_shape_ // Write data of data_shape_
memcpy(pointer, memcpy(
reinterpret_cast<char*>(data_shape_.data()), pointer, reinterpret_cast<char*>(data_shape_.data()),
sizeof(int64_t) * data_shape_.size()); sizeof(int64_t) * data_shape_.size());
} }
*size = buffer_size; *size = buffer_size;
...@@ -103,8 +99,7 @@ void ArrayMeta::Deserialize(char* buffer, int64_t size) { ...@@ -103,8 +99,7 @@ void ArrayMeta::Deserialize(char* buffer, int64_t size) {
data_size += sizeof(int); data_size += sizeof(int);
// Read data type // Read data type
data_type_.resize(ndarray_count_); data_type_.resize(ndarray_count_);
memcpy(data_type_.data(), buffer, memcpy(data_type_.data(), buffer, ndarray_count_ * sizeof(DGLDataType));
ndarray_count_ * sizeof(DGLDataType));
buffer += ndarray_count_ * sizeof(DGLDataType); buffer += ndarray_count_ * sizeof(DGLDataType);
data_size += ndarray_count_ * sizeof(DGLDataType); data_size += ndarray_count_ * sizeof(DGLDataType);
// Read size of data_shape_ // Read size of data_shape_
...@@ -113,8 +108,7 @@ void ArrayMeta::Deserialize(char* buffer, int64_t size) { ...@@ -113,8 +108,7 @@ void ArrayMeta::Deserialize(char* buffer, int64_t size) {
data_size += sizeof(size_t); data_size += sizeof(size_t);
data_shape_.resize(count); data_shape_.resize(count);
// Read data of data_shape_ // Read data of data_shape_
memcpy(data_shape_.data(), buffer, memcpy(data_shape_.data(), buffer, count * sizeof(int64_t));
count * sizeof(int64_t));
data_size += count * sizeof(int64_t); data_size += count * sizeof(int64_t);
} }
CHECK_EQ(data_size, size); CHECK_EQ(data_size, size);
...@@ -170,11 +164,11 @@ void KVStoreMsg::Deserialize(char* buffer, int64_t size) { ...@@ -170,11 +164,11 @@ void KVStoreMsg::Deserialize(char* buffer, int64_t size) {
CHECK_EQ(data_size, size); CHECK_EQ(data_size, size);
} }
////////////////////////////////// Basic Networking Components //////////////////////////////// ////////////////////////////////// Basic Networking Components
///////////////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
std::string type = args[0]; std::string type = args[0];
int64_t msg_queue_size = args[1]; int64_t msg_queue_size = args[1];
network::Sender* sender = nullptr; network::Sender* sender = nullptr;
...@@ -188,7 +182,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") ...@@ -188,7 +182,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
std::string type = args[0]; std::string type = args[0];
int64_t msg_queue_size = args[1]; int64_t msg_queue_size = args[1];
network::Receiver* receiver = nullptr; network::Receiver* receiver = nullptr;
...@@ -202,21 +196,22 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate") ...@@ -202,21 +196,22 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender") DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
sender->Finalize(); sender->Finalize();
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver") DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle);
receiver->Finalize(); receiver->Finalize();
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
std::string ip = args[1]; std::string ip = args[1];
int port = args[2]; int port = args[2];
...@@ -232,7 +227,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver") ...@@ -232,7 +227,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
const int max_try_times = 1024; const int max_try_times = 1024;
...@@ -242,12 +237,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect") ...@@ -242,12 +237,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait") DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
std::string ip = args[1]; std::string ip = args[1];
int port = args[2]; int port = args[2];
int num_sender = args[3]; int num_sender = args[3];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle);
std::string addr; std::string addr;
if (receiver->NetType() == "socket") { if (receiver->NetType() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port); addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
...@@ -259,12 +255,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait") ...@@ -259,12 +255,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
} }
}); });
////////////////////////// Distributed Sampler Components
////////////////////////// Distributed Sampler Components //////////////////////////////// ///////////////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow") DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
int recv_id = args[1]; int recv_id = args[1];
GraphRef g = args[2]; GraphRef g = args[2];
...@@ -348,7 +343,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow") ...@@ -348,7 +343,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow")
}); });
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSamplerEndSignal") DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSamplerEndSignal")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
int recv_id = args[1]; int recv_id = args[1];
ArrayMeta meta(kFinalMsg); ArrayMeta meta(kFinalMsg);
...@@ -361,9 +356,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSamplerEndSignal") ...@@ -361,9 +356,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSamplerEndSignal")
}); });
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle);
int send_id = 0; int send_id = 0;
Message recv_msg; Message recv_msg;
CHECK_EQ(receiver->Recv(&recv_msg, &send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->Recv(&recv_msg, &send_id), REMOVE_SUCCESS);
...@@ -377,71 +373,50 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -377,71 +373,50 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(receiver->RecvFrom(&array_0, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_0, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1); CHECK_EQ(meta.data_shape_[0], 1);
nf->node_mapping = CreateNDArrayFromRaw( nf->node_mapping = CreateNDArrayFromRaw(
{meta.data_shape_[1]}, {meta.data_shape_[1]}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}, array_0.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
array_0.data,
AUTO_FREE);
// edge_mapping // edge_mapping
Message array_1; Message array_1;
CHECK_EQ(receiver->RecvFrom(&array_1, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_1, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[2], 1); CHECK_EQ(meta.data_shape_[2], 1);
nf->edge_mapping = CreateNDArrayFromRaw( nf->edge_mapping = CreateNDArrayFromRaw(
{meta.data_shape_[3]}, {meta.data_shape_[3]}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}, array_1.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
array_1.data,
AUTO_FREE);
// layer_offset // layer_offset
Message array_2; Message array_2;
CHECK_EQ(receiver->RecvFrom(&array_2, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_2, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[4], 1); CHECK_EQ(meta.data_shape_[4], 1);
nf->layer_offsets = CreateNDArrayFromRaw( nf->layer_offsets = CreateNDArrayFromRaw(
{meta.data_shape_[5]}, {meta.data_shape_[5]}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}, array_2.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
array_2.data,
AUTO_FREE);
// flow_offset // flow_offset
Message array_3; Message array_3;
CHECK_EQ(receiver->RecvFrom(&array_3, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_3, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[6], 1); CHECK_EQ(meta.data_shape_[6], 1);
nf->flow_offsets = CreateNDArrayFromRaw( nf->flow_offsets = CreateNDArrayFromRaw(
{meta.data_shape_[7]}, {meta.data_shape_[7]}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}, array_3.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
array_3.data,
AUTO_FREE);
// CSR indptr // CSR indptr
Message array_4; Message array_4;
CHECK_EQ(receiver->RecvFrom(&array_4, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_4, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[8], 1); CHECK_EQ(meta.data_shape_[8], 1);
NDArray indptr = CreateNDArrayFromRaw( NDArray indptr = CreateNDArrayFromRaw(
{meta.data_shape_[9]}, {meta.data_shape_[9]}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}, array_4.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
array_4.data,
AUTO_FREE);
// CSR indice // CSR indice
Message array_5; Message array_5;
CHECK_EQ(receiver->RecvFrom(&array_5, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_5, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[10], 1); CHECK_EQ(meta.data_shape_[10], 1);
NDArray indice = CreateNDArrayFromRaw( NDArray indice = CreateNDArrayFromRaw(
{meta.data_shape_[11]}, {meta.data_shape_[11]}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}, array_5.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
array_5.data,
AUTO_FREE);
// CSR edge_ids // CSR edge_ids
Message array_6; Message array_6;
CHECK_EQ(receiver->RecvFrom(&array_6, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_6, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[12], 1); CHECK_EQ(meta.data_shape_[12], 1);
NDArray edge_ids = CreateNDArrayFromRaw( NDArray edge_ids = CreateNDArrayFromRaw(
{meta.data_shape_[13]}, {meta.data_shape_[13]}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}, array_6.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
array_6.data,
AUTO_FREE);
// Create CSR // Create CSR
CSRPtr csr(new CSR(indptr, indice, edge_ids)); CSRPtr csr(new CSR(indptr, indice, edge_ids));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr)); nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
...@@ -453,14 +428,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -453,14 +428,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
} }
}); });
////////////////////////// Distributed KVStore Components
///////////////////////////////////
////////////////////////// Distributed KVStore Components //////////////////////////////// static void send_kv_message(
network::Sender* sender, KVStoreMsg* kv_msg, int recv_id, bool auto_free) {
static void send_kv_message(network::Sender* sender,
KVStoreMsg* kv_msg,
int recv_id,
bool auto_free) {
int64_t kv_size = 0; int64_t kv_size = 0;
char* kv_data = kv_msg->Serialize(&kv_size); char* kv_data = kv_msg->Serialize(&kv_size);
// Send kv_data // Send kv_data
...@@ -471,23 +443,18 @@ static void send_kv_message(network::Sender* sender, ...@@ -471,23 +443,18 @@ static void send_kv_message(network::Sender* sender,
send_kv_msg.deallocator = DefaultMessageDeleter; send_kv_msg.deallocator = DefaultMessageDeleter;
} }
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS); CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
if (kv_msg->msg_type != kFinalMsg && if (kv_msg->msg_type != kFinalMsg && kv_msg->msg_type != kBarrierMsg &&
kv_msg->msg_type != kBarrierMsg && kv_msg->msg_type != kIPIDMsg && kv_msg->msg_type != kGetShapeMsg) {
kv_msg->msg_type != kIPIDMsg &&
kv_msg->msg_type != kGetShapeMsg) {
// Send ArrayMeta // Send ArrayMeta
ArrayMeta meta(kv_msg->msg_type); ArrayMeta meta(kv_msg->msg_type);
if (kv_msg->msg_type != kInitMsg && if (kv_msg->msg_type != kInitMsg && kv_msg->msg_type != kGetShapeBackMsg) {
kv_msg->msg_type != kGetShapeBackMsg) {
meta.AddArray(kv_msg->id); meta.AddArray(kv_msg->id);
} }
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) { kv_msg->msg_type != kGetShapeBackMsg) {
meta.AddArray(kv_msg->data); meta.AddArray(kv_msg->data);
} }
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) { kv_msg->msg_type != kPullBackMsg) {
meta.AddArray(kv_msg->shape); meta.AddArray(kv_msg->shape);
} }
...@@ -501,8 +468,7 @@ static void send_kv_message(network::Sender* sender, ...@@ -501,8 +468,7 @@ static void send_kv_message(network::Sender* sender,
} }
CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS); CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
// Send ID NDArray // Send ID NDArray
if (kv_msg->msg_type != kInitMsg && if (kv_msg->msg_type != kInitMsg && kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeBackMsg) { kv_msg->msg_type != kGetShapeBackMsg) {
Message send_id_msg; Message send_id_msg;
send_id_msg.data = static_cast<char*>(kv_msg->id->data); send_id_msg.data = static_cast<char*>(kv_msg->id->data);
...@@ -514,8 +480,7 @@ static void send_kv_message(network::Sender* sender, ...@@ -514,8 +480,7 @@ static void send_kv_message(network::Sender* sender,
CHECK_EQ(sender->Send(send_id_msg, recv_id), ADD_SUCCESS); CHECK_EQ(sender->Send(send_id_msg, recv_id), ADD_SUCCESS);
} }
// Send data NDArray // Send data NDArray
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeMsg && kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeBackMsg) { kv_msg->msg_type != kGetShapeBackMsg) {
Message send_data_msg; Message send_data_msg;
...@@ -528,8 +493,7 @@ static void send_kv_message(network::Sender* sender, ...@@ -528,8 +493,7 @@ static void send_kv_message(network::Sender* sender,
CHECK_EQ(sender->Send(send_data_msg, recv_id), ADD_SUCCESS); CHECK_EQ(sender->Send(send_data_msg, recv_id), ADD_SUCCESS);
} }
// Send shape NDArray // Send shape NDArray
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) { kv_msg->msg_type != kPullBackMsg) {
Message send_shape_msg; Message send_shape_msg;
send_shape_msg.data = static_cast<char*>(kv_msg->shape->data); send_shape_msg.data = static_cast<char*>(kv_msg->shape->data);
...@@ -544,17 +508,15 @@ static void send_kv_message(network::Sender* sender, ...@@ -544,17 +508,15 @@ static void send_kv_message(network::Sender* sender,
} }
static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
KVStoreMsg *kv_msg = new KVStoreMsg(); KVStoreMsg* kv_msg = new KVStoreMsg();
// Recv kv_Msg // Recv kv_Msg
Message recv_kv_msg; Message recv_kv_msg;
int send_id; int send_id;
CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS);
kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size); kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size);
recv_kv_msg.deallocator(&recv_kv_msg); recv_kv_msg.deallocator(&recv_kv_msg);
if (kv_msg->msg_type == kFinalMsg || if (kv_msg->msg_type == kFinalMsg || kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kBarrierMsg || kv_msg->msg_type == kIPIDMsg || kv_msg->msg_type == kGetShapeMsg) {
kv_msg->msg_type == kIPIDMsg ||
kv_msg->msg_type == kGetShapeMsg) {
return kv_msg; return kv_msg;
} }
// Recv ArrayMeta // Recv ArrayMeta
...@@ -563,21 +525,16 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -563,21 +525,16 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size); ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
recv_meta_msg.deallocator(&recv_meta_msg); recv_meta_msg.deallocator(&recv_meta_msg);
// Recv ID NDArray // Recv ID NDArray
if (kv_msg->msg_type != kInitMsg && if (kv_msg->msg_type != kInitMsg && kv_msg->msg_type != kGetShapeBackMsg) {
kv_msg->msg_type != kGetShapeBackMsg) {
Message recv_id_msg; Message recv_id_msg;
CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1); CHECK_EQ(meta.data_shape_[0], 1);
kv_msg->id = CreateNDArrayFromRaw( kv_msg->id = CreateNDArrayFromRaw(
{meta.data_shape_[1]}, {meta.data_shape_[1]}, meta.data_type_[0], DGLContext{kDGLCPU, 0},
meta.data_type_[0], recv_id_msg.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
recv_id_msg.data,
AUTO_FREE);
} }
// Recv Data NDArray // Recv Data NDArray
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) { kv_msg->msg_type != kGetShapeBackMsg) {
Message recv_data_msg; Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
...@@ -585,18 +542,14 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -585,18 +542,14 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
CHECK_GE(ndim, 1); CHECK_GE(ndim, 1);
std::vector<int64_t> vec_shape; std::vector<int64_t> vec_shape;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
vec_shape.push_back(meta.data_shape_[3+i]); vec_shape.push_back(meta.data_shape_[3 + i]);
} }
kv_msg->data = CreateNDArrayFromRaw( kv_msg->data = CreateNDArrayFromRaw(
vec_shape, vec_shape, meta.data_type_[1], DGLContext{kDGLCPU, 0},
meta.data_type_[1], recv_data_msg.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
recv_data_msg.data,
AUTO_FREE);
} }
// Recv Shape // Recv Shape
if (kv_msg->msg_type != kPullMsg && if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) { kv_msg->msg_type != kPullBackMsg) {
Message recv_shape_msg; Message recv_shape_msg;
CHECK_EQ(receiver->RecvFrom(&recv_shape_msg, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&recv_shape_msg, send_id), REMOVE_SUCCESS);
...@@ -604,20 +557,17 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) { ...@@ -604,20 +557,17 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
CHECK_GE(ndim, 1); CHECK_GE(ndim, 1);
std::vector<int64_t> vec_shape; std::vector<int64_t> vec_shape;
for (int i = 0; i < ndim; ++i) { for (int i = 0; i < ndim; ++i) {
vec_shape.push_back(meta.data_shape_[1+i]); vec_shape.push_back(meta.data_shape_[1 + i]);
} }
kv_msg->shape = CreateNDArrayFromRaw( kv_msg->shape = CreateNDArrayFromRaw(
vec_shape, vec_shape, meta.data_type_[0], DGLContext{kDGLCPU, 0},
meta.data_type_[0], recv_shape_msg.data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
recv_shape_msg.data,
AUTO_FREE);
} }
return kv_msg; return kv_msg;
} }
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg") DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
int args_count = 0; int args_count = 0;
CommunicatorHandle chandle = args[args_count++]; CommunicatorHandle chandle = args[args_count++];
int recv_id = args[args_count++]; int recv_id = args[args_count++];
...@@ -628,23 +578,18 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg") ...@@ -628,23 +578,18 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
if (kv_msg.msg_type != kFinalMsg && kv_msg.msg_type != kBarrierMsg) { if (kv_msg.msg_type != kFinalMsg && kv_msg.msg_type != kBarrierMsg) {
std::string name = args[args_count++]; std::string name = args[args_count++];
kv_msg.name = name; kv_msg.name = name;
if (kv_msg.msg_type != kIPIDMsg && if (kv_msg.msg_type != kIPIDMsg && kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kGetShapeMsg && kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kGetShapeBackMsg) { kv_msg.msg_type != kGetShapeBackMsg) {
kv_msg.id = args[args_count++]; kv_msg.id = args[args_count++];
} }
if (kv_msg.msg_type != kPullMsg && if (kv_msg.msg_type != kPullMsg && kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kIPIDMsg && kv_msg.msg_type != kInitMsg && kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kGetShapeBackMsg) { kv_msg.msg_type != kGetShapeBackMsg) {
kv_msg.data = args[args_count++]; kv_msg.data = args[args_count++];
} }
if (kv_msg.msg_type != kIPIDMsg && if (kv_msg.msg_type != kIPIDMsg && kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kPullMsg && kv_msg.msg_type != kPushMsg && kv_msg.msg_type != kPullBackMsg &&
kv_msg.msg_type != kPushMsg &&
kv_msg.msg_type != kPullBackMsg &&
kv_msg.msg_type != kGetShapeMsg) { kv_msg.msg_type != kGetShapeMsg) {
kv_msg.shape = args[args_count++]; kv_msg.shape = args[args_count++];
} }
...@@ -653,63 +598,64 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg") ...@@ -653,63 +598,64 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
}); });
DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg") DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle);
*rv = recv_kv_message(receiver); *rv = recv_kv_message(receiver);
}); });
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgType") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgType")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0]; KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle); network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->msg_type; *rv = msg->msg_type;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgRank") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgRank")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0]; KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle); network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->rank; *rv = msg->rank;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgName") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgName")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0]; KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle); network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->name; *rv = msg->name;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgID") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgID")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0]; KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle); network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->id; *rv = msg->id;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0]; KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle); network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->data; *rv = msg->data;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgShape") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgShape")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0]; KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle); network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->shape; *rv = msg->shape;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg") DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0]; KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle); network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
delete msg; delete msg;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_FastPull") DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
std::string name = args[0]; std::string name = args[0];
int local_machine_id = args[1]; int local_machine_id = args[1];
int machine_count = args[2]; int machine_count = args[2];
...@@ -722,7 +668,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull") ...@@ -722,7 +668,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
CommunicatorHandle chandle_receiver = args[9]; CommunicatorHandle chandle_receiver = args[9];
std::string str_flag = args[10]; std::string str_flag = args[10];
network::Sender* sender = static_cast<network::Sender*>(chandle_sender); network::Sender* sender = static_cast<network::Sender*>(chandle_sender);
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle_receiver); network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle_receiver);
int64_t ID_size = ID.GetSize() / sizeof(int64_t); int64_t ID_size = ID.GetSize() / sizeof(int64_t);
int64_t* ID_data = static_cast<int64_t*>(ID->data); int64_t* ID_data = static_cast<int64_t*>(ID->data);
int64_t* pb_data = static_cast<int64_t*>(pb->data); int64_t* pb_data = static_cast<int64_t*>(pb->data);
...@@ -785,15 +732,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull") ...@@ -785,15 +732,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
kv_msg.msg_type = MessageType::kPullMsg; kv_msg.msg_type = MessageType::kPullMsg;
kv_msg.rank = client_id; kv_msg.rank = client_id;
kv_msg.name = name; kv_msg.name = name;
kv_msg.id = CreateNDArrayFromRaw({static_cast<int64_t>(remote_ids[i].size())}, kv_msg.id = CreateNDArrayFromRaw(
ID->dtype, {static_cast<int64_t>(remote_ids[i].size())}, ID->dtype,
DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0}, remote_ids[i].data(), !AUTO_FREE);
remote_ids[i].data(), int lower = i * group_count;
!AUTO_FREE); int higher = (i + 1) * group_count - 1;
int lower = i*group_count;
int higher = (i+1)*group_count-1;
#ifndef _WIN32 // windows does not support rand_r() #ifndef _WIN32 // windows does not support rand_r()
int s_id = (rand_r(&seed) % (higher-lower+1))+lower; int s_id = (rand_r(&seed) % (higher - lower + 1)) + lower;
send_kv_message(sender, &kv_msg, s_id, !AUTO_FREE); send_kv_message(sender, &kv_msg, s_id, !AUTO_FREE);
#else #else
LOG(FATAL) << "KVStore does not support Windows yet."; LOG(FATAL) << "KVStore does not support Windows yet.";
...@@ -801,40 +746,38 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull") ...@@ -801,40 +746,38 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
msg_count++; msg_count++;
} }
} }
char *return_data = new char[ID_size*row_size]; char* return_data = new char[ID_size * row_size];
const int64_t local_ids_size = local_ids.size(); const int64_t local_ids_size = local_ids.size();
// Copy local data // Copy local data
runtime::parallel_for(0, local_ids_size, [&](size_t b, size_t e) { runtime::parallel_for(0, local_ids_size, [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) { for (auto i = b; i < e; ++i) {
CHECK_GE(ID_size*row_size, local_ids_orginal[i] * row_size + row_size); CHECK_GE(
ID_size * row_size, local_ids_orginal[i] * row_size + row_size);
CHECK_GE(data_size, local_ids[i] * row_size + row_size); CHECK_GE(data_size, local_ids[i] * row_size + row_size);
CHECK_GE(local_ids[i], 0); CHECK_GE(local_ids[i], 0);
memcpy(return_data + local_ids_orginal[i] * row_size, memcpy(
local_data_char + local_ids[i] * row_size, return_data + local_ids_orginal[i] * row_size,
row_size); local_data_char + local_ids[i] * row_size, row_size);
} }
}); });
// Recv remote message // Recv remote message
for (int i = 0; i < msg_count; ++i) { for (int i = 0; i < msg_count; ++i) {
KVStoreMsg *kv_msg = recv_kv_message(receiver); KVStoreMsg* kv_msg = recv_kv_message(receiver);
int64_t id_size = kv_msg->id.GetSize() / sizeof(int64_t); int64_t id_size = kv_msg->id.GetSize() / sizeof(int64_t);
int part_id = kv_msg->rank / group_count; int part_id = kv_msg->rank / group_count;
char* data_char = static_cast<char*>(kv_msg->data->data); char* data_char = static_cast<char*>(kv_msg->data->data);
for (int64_t n = 0; n < id_size; ++n) { for (int64_t n = 0; n < id_size; ++n) {
memcpy(return_data + remote_ids_original[part_id][n] * row_size, memcpy(
data_char + n * row_size, return_data + remote_ids_original[part_id][n] * row_size,
row_size); data_char + n * row_size, row_size);
} }
delete kv_msg; delete kv_msg;
} }
// Get final tensor // Get final tensor
local_data_shape[0] = ID_size; local_data_shape[0] = ID_size;
NDArray res_tensor = CreateNDArrayFromRaw( NDArray res_tensor = CreateNDArrayFromRaw(
local_data_shape, local_data_shape, local_data->dtype, DGLContext{kDGLCPU, 0},
local_data->dtype, return_data, AUTO_FREE);
DGLContext{kDGLCPU, 0},
return_data,
AUTO_FREE);
*rv = res_tensor; *rv = res_tensor;
}); });
......
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
#ifndef DGL_GRAPH_NETWORK_H_ #ifndef DGL_GRAPH_NETWORK_H_
#define DGL_GRAPH_NETWORK_H_ #define DGL_GRAPH_NETWORK_H_
#include <dmlc/logging.h>
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <dmlc/logging.h>
#include <string.h> #include <string.h>
#include <vector>
#include <string> #include <string>
#include <vector>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "../rpc/network/msg_queue.h" #include "../rpc/network/msg_queue.h"
...@@ -24,10 +24,8 @@ namespace network { ...@@ -24,10 +24,8 @@ namespace network {
/*! /*!
* \brief Create NDArray from raw data * \brief Create NDArray from raw data
*/ */
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, NDArray CreateNDArrayFromRaw(
DGLDataType dtype, std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx, void* raw);
DGLContext ctx,
void* raw);
/*! /*!
* \brief Message type for DGL distributed training * \brief Message type for DGL distributed training
...@@ -75,7 +73,6 @@ enum MessageType { ...@@ -75,7 +73,6 @@ enum MessageType {
kGetShapeBackMsg = 9 kGetShapeBackMsg = 9
}; };
/*! /*!
* \brief Meta data for NDArray message * \brief Meta data for NDArray message
*/ */
...@@ -85,8 +82,7 @@ class ArrayMeta { ...@@ -85,8 +82,7 @@ class ArrayMeta {
* \brief ArrayMeta constructor. * \brief ArrayMeta constructor.
* \param msg_type type of message * \param msg_type type of message
*/ */
explicit ArrayMeta(int msg_type) explicit ArrayMeta(int msg_type) : msg_type_(msg_type), ndarray_count_(0) {}
: msg_type_(msg_type), ndarray_count_(0) {}
/*! /*!
* \brief Construct ArrayMeta from binary data buffer. * \brief Construct ArrayMeta from binary data buffer.
...@@ -101,16 +97,12 @@ class ArrayMeta { ...@@ -101,16 +97,12 @@ class ArrayMeta {
/*! /*!
* \return message type * \return message type
*/ */
inline int msg_type() const { inline int msg_type() const { return msg_type_; }
return msg_type_;
}
/*! /*!
* \return count of ndarray * \return count of ndarray
*/ */
inline int ndarray_count() const { inline int ndarray_count() const { return ndarray_count_; }
return ndarray_count_;
}
/*! /*!
* \brief Add NDArray meta data to ArrayMeta * \brief Add NDArray meta data to ArrayMeta
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
*/ */
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/nodeflow.h> #include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>
#include <string> #include <string>
...@@ -19,15 +19,16 @@ using dgl::runtime::PackedFunc; ...@@ -19,15 +19,16 @@ using dgl::runtime::PackedFunc;
namespace dgl { namespace dgl {
std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::string &fmt, std::vector<IdArray> GetNodeFlowSlice(
size_t layer0_size, size_t layer1_start, const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,
size_t layer1_end, bool remap) { size_t layer1_start, size_t layer1_end, bool remap) {
CHECK_GE(layer1_start, layer0_size); CHECK_GE(layer1_start, layer0_size);
if (fmt == std::string("csr")) { if (fmt == std::string("csr")) {
dgl_id_t first_vid = layer1_start - layer0_size; dgl_id_t first_vid = layer1_start - layer0_size;
auto csr = aten::CSRSliceRows(graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end); auto csr = aten::CSRSliceRows(
graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);
if (remap) { if (remap) {
dgl_id_t *eid_data = static_cast<dgl_id_t*>(csr.data->data); dgl_id_t *eid_data = static_cast<dgl_id_t *>(csr.data->data);
const dgl_id_t first_eid = eid_data[0]; const dgl_id_t first_eid = eid_data[0];
IdArray new_indices = aten::Sub(csr.indices, first_vid); IdArray new_indices = aten::Sub(csr.indices, first_vid);
IdArray new_data = aten::Sub(csr.data, first_eid); IdArray new_data = aten::Sub(csr.data, first_eid);
...@@ -37,14 +38,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st ...@@ -37,14 +38,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
} }
} else if (fmt == std::string("coo")) { } else if (fmt == std::string("coo")) {
auto csr = graph.GetInCSR()->ToCSRMatrix(); auto csr = graph.GetInCSR()->ToCSRMatrix();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(csr.indptr->data); const dgl_id_t *indptr = static_cast<dgl_id_t *>(csr.indptr->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(csr.indices->data); const dgl_id_t *indices = static_cast<dgl_id_t *>(csr.indices->data);
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(csr.data->data); const dgl_id_t *edge_ids = static_cast<dgl_id_t *>(csr.data->data);
int64_t nnz = indptr[layer1_end] - indptr[layer1_start]; int64_t nnz = indptr[layer1_end] - indptr[layer1_start];
IdArray idx = aten::NewIdArray(2 * nnz); IdArray idx = aten::NewIdArray(2 * nnz);
IdArray eid = aten::NewIdArray(nnz); IdArray eid = aten::NewIdArray(nnz);
int64_t *idx_data = static_cast<int64_t*>(idx->data); int64_t *idx_data = static_cast<int64_t *>(idx->data);
dgl_id_t *eid_data = static_cast<dgl_id_t*>(eid->data); dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
size_t num_edges = 0; size_t num_edges = 0;
for (size_t i = layer1_start; i < layer1_end; i++) { for (size_t i = layer1_start; i < layer1_end; i++) {
for (dgl_id_t j = indptr[i]; j < indptr[i + 1]; j++) { for (dgl_id_t j = indptr[i]; j < indptr[i + 1]; j++) {
...@@ -65,10 +66,12 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st ...@@ -65,10 +66,12 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
eid_data[i] = edge_ids[edge_start + i] - first_eid; eid_data[i] = edge_ids[edge_start + i] - first_eid;
} }
} else { } else {
std::copy(indices + indptr[layer1_start], std::copy(
indices + indptr[layer1_end], idx_data + nnz); indices + indptr[layer1_start], indices + indptr[layer1_end],
std::copy(edge_ids + indptr[layer1_start], idx_data + nnz);
edge_ids + indptr[layer1_end], eid_data); std::copy(
edge_ids + indptr[layer1_start], edge_ids + indptr[layer1_end],
eid_data);
} }
return std::vector<IdArray>{idx, eid}; return std::vector<IdArray>{idx, eid};
} else { } else {
...@@ -78,14 +81,15 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st ...@@ -78,14 +81,15 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
} }
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockAdj") DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
std::string format = args[1]; std::string format = args[1];
int64_t layer0_size = args[2]; int64_t layer0_size = args[2];
int64_t start = args[3]; int64_t start = args[3];
int64_t end = args[4]; int64_t end = args[4];
const bool remap = args[5]; const bool remap = args[5];
auto ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr())); auto ig =
CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap); auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap);
*rv = ConvertNDArrayVectorToPackedFunc(res); *rv = ConvertNDArrayVectorToPackedFunc(res);
}); });
......
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
* \file graph/pickle.cc * \file graph/pickle.cc
* \brief Functions for pickle and unpickle a graph * \brief Functions for pickle and unpickle a graph
*/ */
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h>
#include <dmlc/memory_io.h> #include <dmlc/memory_io.h>
#include "./heterograph.h"
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./heterograph.h"
#include "unit_graph.h" #include "unit_graph.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -91,7 +92,7 @@ HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) { ...@@ -91,7 +92,7 @@ HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {
return states; return states;
} }
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates &states) {
char *buf = const_cast<char *>(states.meta.c_str()); // a readonly stream? char *buf = const_cast<char *>(states.meta.c_str()); // a readonly stream?
dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size()); dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
dmlc::Stream *strm = &ifs; dmlc::Stream *strm = &ifs;
...@@ -108,10 +109,10 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -108,10 +109,10 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
auto array_itr = states.arrays.begin(); auto array_itr = states.arrays.begin();
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
const auto& pair = metagraph->FindEdge(etype); const auto &pair = metagraph->FindEdge(etype);
const dgl_type_t srctype = pair.first; const dgl_type_t srctype = pair.first;
const dgl_type_t dsttype = pair.second; const dgl_type_t dsttype = pair.second;
const int64_t num_vtypes = (srctype == dsttype)? 1 : 2; const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
int64_t num_src = num_nodes_per_type[srctype]; int64_t num_src = num_nodes_per_type[srctype];
int64_t num_dst = num_nodes_per_type[dsttype]; int64_t num_dst = num_nodes_per_type[dsttype];
SparseFormat fmt; SparseFormat fmt;
...@@ -126,7 +127,8 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -126,7 +127,8 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
bool csorted; bool csorted;
CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'"; CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'"; CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted); auto coo = aten::COOMatrix(
num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
// TODO(zihao) fix // TODO(zihao) fix
relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE); relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE);
break; break;
...@@ -138,7 +140,8 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -138,7 +140,8 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
const auto &edge_id = *(array_itr++); const auto &edge_id = *(array_itr++);
bool sorted; bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'"; CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted); auto csr =
aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
// TODO(zihao) fix // TODO(zihao) fix
relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE); relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE);
break; break;
...@@ -157,17 +160,18 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) { ...@@ -157,17 +160,18 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
} }
// For backward compatibility // For backward compatibility
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) { HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates &states) {
const auto metagraph = states.metagraph; const auto metagraph = states.metagraph;
const auto &num_nodes_per_type = states.num_nodes_per_type; const auto &num_nodes_per_type = states.num_nodes_per_type;
CHECK_EQ(states.adjs.size(), metagraph->NumEdges()); CHECK_EQ(states.adjs.size(), metagraph->NumEdges());
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges()); std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
const auto& pair = metagraph->FindEdge(etype); const auto &pair = metagraph->FindEdge(etype);
const dgl_type_t srctype = pair.first; const dgl_type_t srctype = pair.first;
const dgl_type_t dsttype = pair.second; const dgl_type_t dsttype = pair.second;
const int64_t num_vtypes = (srctype == dsttype)? 1 : 2; const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
const SparseFormat fmt = static_cast<SparseFormat>(states.adjs[etype]->format); const SparseFormat fmt =
static_cast<SparseFormat>(states.adjs[etype]->format);
switch (fmt) { switch (fmt) {
case SparseFormat::kCOO: case SparseFormat::kCOO:
relgraphs[etype] = UnitGraph::CreateFromCOO( relgraphs[etype] = UnitGraph::CreateFromCOO(
...@@ -202,7 +206,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -202,7 +206,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
auto array_itr = states.arrays.begin(); auto array_itr = states.arrays.begin();
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) { for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
const auto& pair = metagraph->FindEdge(etype); const auto &pair = metagraph->FindEdge(etype);
const dgl_type_t srctype = pair.first; const dgl_type_t srctype = pair.first;
const dgl_type_t dsttype = pair.second; const dgl_type_t dsttype = pair.second;
const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2; const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
...@@ -227,7 +231,8 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -227,7 +231,8 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
bool csorted; bool csorted;
CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'"; CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'"; CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted); coo = aten::COOMatrix(
num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
} }
if (created_formats & CSR_CODE) { if (created_formats & CSR_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 3); CHECK_GE(states.arrays.end() - array_itr, 3);
...@@ -258,13 +263,13 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) { ...@@ -258,13 +263,13 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
} }
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
*rv = st->version; *rv = st->version;
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
DGLByteArray buf; DGLByteArray buf;
buf.data = st->meta.c_str(); buf.data = st->meta.c_str();
...@@ -273,50 +278,50 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta") ...@@ -273,50 +278,50 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
*rv = ConvertNDArrayVectorToPackedFunc(st->arrays); *rv = ConvertNDArrayVectorToPackedFunc(st->arrays);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef st = args[0]; HeteroPickleStatesRef st = args[0];
*rv = static_cast<int64_t>(st->arrays.size()); *rv = static_cast<int64_t>(st->arrays.size());
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
const int version = args[0]; const int version = args[0];
std::string meta = args[1]; std::string meta = args[1];
const List<Value> arrays = args[2]; const List<Value> arrays = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates ); std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);
st->version = version == 0 ? 1 : version; st->version = version == 0 ? 1 : version;
st->meta = meta; st->meta = meta;
st->arrays.reserve(arrays.size()); st->arrays.reserve(arrays.size());
for (const auto& ref : arrays) { for (const auto &ref : arrays) {
st->arrays.push_back(ref->data); st->arrays.push_back(ref->data);
} }
*rv = HeteroPickleStatesRef(st); *rv = HeteroPickleStatesRef(st);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef ref = args[0]; HeteroGraphRef ref = args[0];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates ); std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);
*st = HeteroPickle(ref.sptr()); *st = HeteroPickle(ref.sptr());
*rv = HeteroPickleStatesRef(st); *rv = HeteroPickleStatesRef(st);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef ref = args[0]; HeteroGraphRef ref = args[0];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates ); std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);
*st = HeteroForkingPickle(ref.sptr()); *st = HeteroForkingPickle(ref.sptr());
*rv = HeteroPickleStatesRef(st); *rv = HeteroPickleStatesRef(st);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef ref = args[0]; HeteroPickleStatesRef ref = args[0];
HeteroGraphPtr graph; HeteroGraphPtr graph;
switch (ref->version) { switch (ref->version) {
...@@ -334,24 +339,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle") ...@@ -334,24 +339,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef ref = args[0]; HeteroPickleStatesRef ref = args[0];
HeteroGraphPtr graph = HeteroForkingUnpickle(*ref.sptr()); HeteroGraphPtr graph = HeteroForkingUnpickle(*ref.sptr());
*rv = HeteroGraphRef(graph); *rv = HeteroGraphRef(graph);
}); });
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld") DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef metagraph = args[0]; GraphRef metagraph = args[0];
IdArray num_nodes_per_type = args[1]; IdArray num_nodes_per_type = args[1];
List<SparseMatrixRef> adjs = args[2]; List<SparseMatrixRef> adjs = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates ); std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);
st->version = 0; st->version = 0;
st->metagraph = metagraph.sptr(); st->metagraph = metagraph.sptr();
st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>(); st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();
st->adjs.reserve(adjs.size()); st->adjs.reserve(adjs.size());
for (const auto& ref : adjs) for (const auto &ref : adjs) st->adjs.push_back(ref.sptr());
st->adjs.push_back(ref.sptr());
*rv = HeteroPickleStatesRef(st); *rv = HeteroPickleStatesRef(st);
}); });
} // namespace dgl } // namespace dgl
...@@ -3,18 +3,20 @@ ...@@ -3,18 +3,20 @@
* \file graph/sampler.cc * \file graph/sampler.cc
* \brief DGL sampler implementation * \brief DGL sampler implementation
*/ */
#include <dgl/sampler.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <dgl/sampler.h>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <algorithm> #include <algorithm>
#include <cstdlib>
#include <cmath> #include <cmath>
#include <cstdlib>
#include <numeric> #include <numeric>
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -25,21 +27,21 @@ namespace { ...@@ -25,21 +27,21 @@ namespace {
/* /*
* ArrayHeap is used to sample elements from vector * ArrayHeap is used to sample elements from vector
*/ */
template<typename ValueType> template <typename ValueType>
class ArrayHeap { class ArrayHeap {
public: public:
explicit ArrayHeap(const std::vector<ValueType>& prob) { explicit ArrayHeap(const std::vector<ValueType> &prob) {
vec_size_ = prob.size(); vec_size_ = prob.size();
bit_len_ = ceil(log2(vec_size_)); bit_len_ = ceil(log2(vec_size_));
limit_ = 1UL << bit_len_; limit_ = 1UL << bit_len_;
// allocate twice the size // allocate twice the size
heap_.resize(limit_ << 1, 0); heap_.resize(limit_ << 1, 0);
// allocate the leaves // allocate the leaves
for (size_t i = limit_; i < vec_size_+limit_; ++i) { for (size_t i = limit_; i < vec_size_ + limit_; ++i) {
heap_[i] = prob[i-limit_]; heap_[i] = prob[i - limit_];
} }
// iterate up the tree (this is O(m)) // iterate up the tree (this is O(m))
for (int i = bit_len_-1; i >= 0; --i) { for (int i = bit_len_ - 1; i >= 0; --i) {
for (size_t j = (1UL << i); j < (1UL << (i + 1)); ++j) { for (size_t j = (1UL << i); j < (1UL << (i + 1)); ++j) {
heap_[j] = heap_[j << 1] + heap_[(j << 1) + 1]; heap_[j] = heap_[j << 1] + heap_[(j << 1) + 1];
} }
...@@ -54,7 +56,7 @@ class ArrayHeap { ...@@ -54,7 +56,7 @@ class ArrayHeap {
size_t i = index + limit_; size_t i = index + limit_;
heap_[i] = 0; heap_[i] = 0;
i /= 2; i /= 2;
for (int j = bit_len_-1; j >= 0; --j) { for (int j = bit_len_ - 1; j >= 0; --j) {
// Using heap_[i] = heap_[i] - w will loss some precision in float. // Using heap_[i] = heap_[i] - w will loss some precision in float.
// Using addition to re-calculate the weight layer by layer. // Using addition to re-calculate the weight layer by layer.
heap_[i] = heap_[i << 1] + heap_[(i << 1) + 1]; heap_[i] = heap_[i << 1] + heap_[(i << 1) + 1];
...@@ -93,7 +95,7 @@ class ArrayHeap { ...@@ -93,7 +95,7 @@ class ArrayHeap {
/* /*
* Sample a vector by given the size n * Sample a vector by given the size n
*/ */
size_t SampleWithoutReplacement(size_t n, std::vector<size_t>* samples) { size_t SampleWithoutReplacement(size_t n, std::vector<size_t> *samples) {
// sample n elements // sample n elements
size_t i = 0; size_t i = 0;
for (; i < n; ++i) { for (; i < n; ++i) {
...@@ -116,20 +118,14 @@ class ArrayHeap { ...@@ -116,20 +118,14 @@ class ArrayHeap {
}; };
///////////////////////// Samplers ////////////////////////// ///////////////////////// Samplers //////////////////////////
class EdgeSamplerObject: public Object { class EdgeSamplerObject : public Object {
public: public:
EdgeSamplerObject(const GraphPtr gptr, EdgeSamplerObject(
IdArray seed_edges, const GraphPtr gptr, IdArray seed_edges, const int64_t batch_size,
const int64_t batch_size, const int64_t num_workers, const bool replacement, const bool reset,
const int64_t num_workers, const std::string neg_mode, const int64_t neg_sample_size,
const bool replacement, const int64_t chunk_size, const bool exclude_positive,
const bool reset, const bool check_false_neg, IdArray relations) {
const std::string neg_mode,
const int64_t neg_sample_size,
const int64_t chunk_size,
const bool exclude_positive,
const bool check_false_neg,
IdArray relations) {
gptr_ = gptr; gptr_ = gptr;
seed_edges_ = seed_edges; seed_edges_ = seed_edges;
relations_ = relations; relations_ = relations;
...@@ -147,24 +143,22 @@ class EdgeSamplerObject: public Object { ...@@ -147,24 +143,22 @@ class EdgeSamplerObject: public Object {
~EdgeSamplerObject() {} ~EdgeSamplerObject() {}
virtual void Fetch(DGLRetValue* rv) = 0; virtual void Fetch(DGLRetValue *rv) = 0;
virtual void Reset() = 0; virtual void Reset() = 0;
protected: protected:
virtual void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) = 0; virtual void randomSample(
virtual void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude, size_t set_size, size_t num, std::vector<size_t> *out) = 0;
std::vector<size_t>* out) = 0; virtual void randomSample(
size_t set_size, size_t num, const std::vector<size_t> &exclude,
NegSubgraph genNegEdgeSubgraph(const Subgraph &pos_subg, std::vector<size_t> *out) = 0;
const std::string &neg_mode,
int64_t neg_sample_size, NegSubgraph genNegEdgeSubgraph(
bool exclude_positive, const Subgraph &pos_subg, const std::string &neg_mode,
bool check_false_neg); int64_t neg_sample_size, bool exclude_positive, bool check_false_neg);
NegSubgraph genChunkedNegEdgeSubgraph(const Subgraph &pos_subg, NegSubgraph genChunkedNegEdgeSubgraph(
const std::string &neg_mode, const Subgraph &pos_subg, const std::string &neg_mode,
int64_t neg_sample_size, int64_t neg_sample_size, bool exclude_positive, bool check_false_neg);
bool exclude_positive,
bool check_false_neg);
GraphPtr gptr_; GraphPtr gptr_;
IdArray seed_edges_; IdArray seed_edges_;
...@@ -184,7 +178,7 @@ class EdgeSamplerObject: public Object { ...@@ -184,7 +178,7 @@ class EdgeSamplerObject: public Object {
/* /*
* Uniformly sample integers from [0, set_size) without replacement. * Uniformly sample integers from [0, set_size) without replacement.
*/ */
void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) { void RandomSample(size_t set_size, size_t num, std::vector<size_t> *out) {
if (num < set_size) { if (num < set_size) {
std::unordered_set<size_t> sampled_idxs; std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) { while (sampled_idxs.size() < num) {
...@@ -194,13 +188,13 @@ void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) { ...@@ -194,13 +188,13 @@ void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
} else { } else {
// If we need to sample all elements in the set, we don't need to // If we need to sample all elements in the set, we don't need to
// generate random numbers. // generate random numbers.
for (size_t i = 0; i < set_size; i++) for (size_t i = 0; i < set_size; i++) out->push_back(i);
out->push_back(i);
} }
} }
void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude, void RandomSample(
std::vector<size_t>* out) { size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t> *out) {
std::unordered_map<size_t, int> sampled_idxs; std::unordered_map<size_t, int> sampled_idxs;
for (auto v : exclude) { for (auto v : exclude) {
sampled_idxs.insert(std::pair<size_t, int>(v, 0)); sampled_idxs.insert(std::pair<size_t, int>(v, 0));
...@@ -231,9 +225,9 @@ void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclud ...@@ -231,9 +225,9 @@ void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclud
* For a sparse array whose non-zeros are represented by nz_idxs, * For a sparse array whose non-zeros are represented by nz_idxs,
* negate the sparse array and outputs the non-zeros in the negated array. * negate the sparse array and outputs the non-zeros in the negated array.
*/ */
void NegateArray(const std::vector<size_t> &nz_idxs, void NegateArray(
size_t arr_size, const std::vector<size_t> &nz_idxs, size_t arr_size,
std::vector<size_t>* out) { std::vector<size_t> *out) {
// nz_idxs must have been sorted. // nz_idxs must have been sorted.
auto it = nz_idxs.begin(); auto it = nz_idxs.begin();
size_t i = 0; size_t i = 0;
...@@ -253,12 +247,10 @@ void NegateArray(const std::vector<size_t> &nz_idxs, ...@@ -253,12 +247,10 @@ void NegateArray(const std::vector<size_t> &nz_idxs,
/* /*
* Uniform sample vertices from a list of vertices. * Uniform sample vertices from a list of vertices.
*/ */
void GetUniformSample(const dgl_id_t* edge_id_list, void GetUniformSample(
const dgl_id_t* vid_list, const dgl_id_t *edge_id_list, const dgl_id_t *vid_list,
const size_t ver_len, const size_t ver_len, const size_t max_num_neighbor,
const size_t max_num_neighbor, std::vector<dgl_id_t> *out_ver, std::vector<dgl_id_t> *out_edge) {
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge) {
// Copy vid_list to output // Copy vid_list to output
if (ver_len <= max_num_neighbor) { if (ver_len <= max_num_neighbor) {
out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len); out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);
...@@ -292,16 +284,15 @@ void GetUniformSample(const dgl_id_t* edge_id_list, ...@@ -292,16 +284,15 @@ void GetUniformSample(const dgl_id_t* edge_id_list,
/* /*
* Non-uniform sample via ArrayHeap * Non-uniform sample via ArrayHeap
* *
* \param probability Transition probability on the entire graph, indexed by edge ID * \param probability Transition probability on the entire graph, indexed by
* edge ID
*/ */
template<typename ValueType> template <typename ValueType>
void GetNonUniformSample(const ValueType* probability, void GetNonUniformSample(
const dgl_id_t* edge_id_list, const ValueType *probability, const dgl_id_t *edge_id_list,
const dgl_id_t* vid_list, const dgl_id_t *vid_list, const size_t ver_len,
const size_t ver_len, const size_t max_num_neighbor, std::vector<dgl_id_t> *out_ver,
const size_t max_num_neighbor, std::vector<dgl_id_t> *out_edge) {
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge) {
// Copy vid_list to output // Copy vid_list to output
if (ver_len <= max_num_neighbor) { if (ver_len <= max_num_neighbor) {
out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len); out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);
...@@ -333,8 +324,8 @@ void GetNonUniformSample(const ValueType* probability, ...@@ -333,8 +324,8 @@ void GetNonUniformSample(const ValueType* probability,
struct neigh_list { struct neigh_list {
std::vector<dgl_id_t> neighs; std::vector<dgl_id_t> neighs;
std::vector<dgl_id_t> edges; std::vector<dgl_id_t> edges;
neigh_list(const std::vector<dgl_id_t> &_neighs, neigh_list(
const std::vector<dgl_id_t> &_edges) const std::vector<dgl_id_t> &_neighs, const std::vector<dgl_id_t> &_edges)
: neighs(_neighs), edges(_edges) {} : neighs(_neighs), edges(_edges) {}
}; };
...@@ -350,12 +341,11 @@ struct neighbor_info { ...@@ -350,12 +341,11 @@ struct neighbor_info {
} }
}; };
NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, NodeFlow ConstructNodeFlow(
std::vector<dgl_id_t> edge_list, std::vector<dgl_id_t> neighbor_list, std::vector<dgl_id_t> edge_list,
std::vector<size_t> layer_offsets, std::vector<size_t> layer_offsets,
std::vector<std::pair<dgl_id_t, int> > *sub_vers, std::vector<std::pair<dgl_id_t, int>> *sub_vers,
std::vector<neighbor_info> *neigh_pos, std::vector<neighbor_info> *neigh_pos, const std::string &edge_type,
const std::string &edge_type,
int64_t num_edges, int num_hops) { int64_t num_edges, int num_hops) {
NodeFlow nf = NodeFlow::Create(); NodeFlow nf = NodeFlow::Create();
uint64_t num_vertices = sub_vers->size(); uint64_t num_vertices = sub_vers->size();
...@@ -371,9 +361,9 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -371,9 +361,9 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
// Construct sub_csr_graph, we treat nodeflow as multigraph by default // Construct sub_csr_graph, we treat nodeflow as multigraph by default
auto subg_csr = CSRPtr(new CSR(num_vertices, num_edges)); auto subg_csr = CSRPtr(new CSR(num_vertices, num_edges));
dgl_id_t* indptr_out = static_cast<dgl_id_t*>(subg_csr->indptr()->data); dgl_id_t *indptr_out = static_cast<dgl_id_t *>(subg_csr->indptr()->data);
dgl_id_t* col_list_out = static_cast<dgl_id_t*>(subg_csr->indices()->data); dgl_id_t *col_list_out = static_cast<dgl_id_t *>(subg_csr->indices()->data);
dgl_id_t* eid_out = static_cast<dgl_id_t*>(subg_csr->edge_ids()->data); dgl_id_t *eid_out = static_cast<dgl_id_t *>(subg_csr->edge_ids()->data);
size_t collected_nedges = 0; size_t collected_nedges = 0;
// The data from the previous steps: // The data from the previous steps:
...@@ -385,12 +375,13 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -385,12 +375,13 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
layer_ver_maps.resize(num_hops); layer_ver_maps.resize(num_hops);
size_t out_node_idx = 0; size_t out_node_idx = 0;
for (int layer_id = num_hops - 1; layer_id >= 0; layer_id--) { for (int layer_id = num_hops - 1; layer_id >= 0; layer_id--) {
// We sort the vertices in a layer so that we don't need to sort the neighbor Ids // We sort the vertices in a layer so that we don't need to sort the
// after remap to a subgraph. However, we don't need to sort the first layer // neighbor Ids after remap to a subgraph. However, we don't need to sort
// because we want the order of the nodes in the first layer is the same as // the first layer because we want the order of the nodes in the first layer
// the input seed nodes. // is the same as the input seed nodes.
if (layer_id > 0) { if (layer_id > 0) {
std::sort(sub_vers->begin() + layer_offsets[layer_id], std::sort(
sub_vers->begin() + layer_offsets[layer_id],
sub_vers->begin() + layer_offsets[layer_id + 1], sub_vers->begin() + layer_offsets[layer_id + 1],
[](const std::pair<dgl_id_t, dgl_id_t> &a1, [](const std::pair<dgl_id_t, dgl_id_t> &a1,
const std::pair<dgl_id_t, dgl_id_t> &a2) { const std::pair<dgl_id_t, dgl_id_t> &a2) {
...@@ -399,37 +390,40 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -399,37 +390,40 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
} }
// Save the sampled vertices and its layer Id. // Save the sampled vertices and its layer Id.
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1]; i++) { for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1];
i++) {
node_map_data[out_node_idx++] = sub_vers->at(i).first; node_map_data[out_node_idx++] = sub_vers->at(i).first;
layer_ver_maps[layer_id].insert(std::pair<dgl_id_t, dgl_id_t>(sub_vers->at(i).first, layer_ver_maps[layer_id].insert(
ver_id++)); std::pair<dgl_id_t, dgl_id_t>(sub_vers->at(i).first, ver_id++));
CHECK_EQ(sub_vers->at(i).second, layer_id); CHECK_EQ(sub_vers->at(i).second, layer_id);
} }
} }
CHECK(out_node_idx == num_vertices); CHECK(out_node_idx == num_vertices);
// sampling algorithms have to start from the seed nodes, so the seed nodes are // sampling algorithms have to start from the seed nodes, so the seed nodes
// in the first layer and the input nodes are in the last layer. // are in the first layer and the input nodes are in the last layer. When we
// When we expose the sampled graph to a Python user, we say the input nodes // expose the sampled graph to a Python user, we say the input nodes are in
// are in the first layer and the seed nodes are in the last layer. // the first layer and the seed nodes are in the last layer. Thus, when we
// Thus, when we copy sampled results to a CSR, we need to reverse the order of layers. // copy sampled results to a CSR, we need to reverse the order of layers.
std::fill(indptr_out, indptr_out + num_vertices + 1, 0); std::fill(indptr_out, indptr_out + num_vertices + 1, 0);
size_t row_idx = layer_offsets[num_hops] - layer_offsets[num_hops - 1]; size_t row_idx = layer_offsets[num_hops] - layer_offsets[num_hops - 1];
layer_off_data[0] = 0; layer_off_data[0] = 0;
layer_off_data[1] = layer_offsets[num_hops] - layer_offsets[num_hops - 1]; layer_off_data[1] = layer_offsets[num_hops] - layer_offsets[num_hops - 1];
int out_layer_idx = 1; int out_layer_idx = 1;
for (int layer_id = num_hops - 2; layer_id >= 0; layer_id--) { for (int layer_id = num_hops - 2; layer_id >= 0; layer_id--) {
// Because we don't sort the vertices in the first layer above, we can't sort // Because we don't sort the vertices in the first layer above, we can't
// the neighbor positions of the vertices in the first layer either. // sort the neighbor positions of the vertices in the first layer either.
if (layer_id > 0) { if (layer_id > 0) {
std::sort(neigh_pos->begin() + layer_offsets[layer_id], std::sort(
neigh_pos->begin() + layer_offsets[layer_id],
neigh_pos->begin() + layer_offsets[layer_id + 1], neigh_pos->begin() + layer_offsets[layer_id + 1],
[](const neighbor_info &a1, const neighbor_info &a2) { [](const neighbor_info &a1, const neighbor_info &a2) {
return a1.id < a2.id; return a1.id < a2.id;
}); });
} }
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1]; i++) { for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1];
i++) {
dgl_id_t dst_id = sub_vers->at(i).first; dgl_id_t dst_id = sub_vers->at(i).first;
CHECK_EQ(dst_id, neigh_pos->at(i).id); CHECK_EQ(dst_id, neigh_pos->at(i).id);
size_t pos = neigh_pos->at(i).pos; size_t pos = neigh_pos->at(i).pos;
...@@ -441,18 +435,22 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -441,18 +435,22 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
auto neigh_it = neighbor_list.begin() + pos; auto neigh_it = neighbor_list.begin() + pos;
for (size_t i = 0; i < nedges; i++) { for (size_t i = 0; i < nedges; i++) {
dgl_id_t neigh = *(neigh_it + i); dgl_id_t neigh = *(neigh_it + i);
CHECK(layer_ver_maps[layer_id + 1].find(neigh) != layer_ver_maps[layer_id + 1].end()); CHECK(
col_list_out[collected_nedges + i] = layer_ver_maps[layer_id + 1][neigh]; layer_ver_maps[layer_id + 1].find(neigh) !=
layer_ver_maps[layer_id + 1].end());
col_list_out[collected_nedges + i] =
layer_ver_maps[layer_id + 1][neigh];
} }
// We can simply copy the edge Ids. // We can simply copy the edge Ids.
std::copy_n(edge_list.begin() + pos, std::copy_n(
nedges, edge_map_data + collected_nedges); edge_list.begin() + pos, nedges, edge_map_data + collected_nedges);
collected_nedges += nedges; collected_nedges += nedges;
indptr_out[row_idx+1] = indptr_out[row_idx] + nedges; indptr_out[row_idx + 1] = indptr_out[row_idx] + nedges;
row_idx++; row_idx++;
} }
layer_off_data[out_layer_idx + 1] = layer_off_data[out_layer_idx] layer_off_data[out_layer_idx + 1] = layer_off_data[out_layer_idx] +
+ layer_offsets[layer_id + 1] - layer_offsets[layer_id]; layer_offsets[layer_id + 1] -
layer_offsets[layer_id];
out_layer_idx++; out_layer_idx++;
} }
CHECK_EQ(row_idx, num_vertices); CHECK_EQ(row_idx, num_vertices);
...@@ -464,7 +462,8 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -464,7 +462,8 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
flow_off_data[0] = 0; flow_off_data[0] = 0;
int out_flow_idx = 0; int out_flow_idx = 0;
for (size_t i = 0; i < layer_offsets.size() - 2; i++) { for (size_t i = 0; i < layer_offsets.size() - 2; i++) {
size_t num_edges = indptr_out[layer_off_data[i + 2]] - indptr_out[layer_off_data[i + 1]]; size_t num_edges =
indptr_out[layer_off_data[i + 2]] - indptr_out[layer_off_data[i + 1]];
flow_off_data[out_flow_idx + 1] = flow_off_data[out_flow_idx] + num_edges; flow_off_data[out_flow_idx + 1] = flow_off_data[out_flow_idx] + num_edges;
out_flow_idx++; out_flow_idx++;
} }
...@@ -482,23 +481,21 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list, ...@@ -482,23 +481,21 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
return nf; return nf;
} }
template<typename ValueType> template <typename ValueType>
NodeFlow SampleSubgraph(const ImmutableGraph *graph, NodeFlow SampleSubgraph(
const std::vector<dgl_id_t>& seeds, const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,
const ValueType* probability, const ValueType *probability, const std::string &edge_type, int num_hops,
const std::string &edge_type, size_t num_neighbor, const bool add_self_loop) {
int num_hops,
size_t num_neighbor,
const bool add_self_loop) {
CHECK_EQ(graph->NumBits(), 64) << "32 bit graph is not supported yet"; CHECK_EQ(graph->NumBits(), 64) << "32 bit graph is not supported yet";
const size_t num_seeds = seeds.size(); const size_t num_seeds = seeds.size();
auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR(); auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
const dgl_id_t* val_list = static_cast<dgl_id_t*>(orig_csr->edge_ids()->data); const dgl_id_t *val_list =
const dgl_id_t* col_list = static_cast<dgl_id_t*>(orig_csr->indices()->data); static_cast<dgl_id_t *>(orig_csr->edge_ids()->data);
const dgl_id_t* indptr = static_cast<dgl_id_t*>(orig_csr->indptr()->data); const dgl_id_t *col_list = static_cast<dgl_id_t *>(orig_csr->indices()->data);
const dgl_id_t *indptr = static_cast<dgl_id_t *>(orig_csr->indptr()->data);
std::unordered_set<dgl_id_t> sub_ver_map; // The vertex Ids in a layer. std::unordered_set<dgl_id_t> sub_ver_map; // The vertex Ids in a layer.
std::vector<std::pair<dgl_id_t, int> > sub_vers; std::vector<std::pair<dgl_id_t, int>> sub_vers;
sub_vers.reserve(num_seeds * 10); sub_vers.reserve(num_seeds * 10);
// add seed vertices // add seed vertices
for (size_t i = 0; i < num_seeds; ++i) { for (size_t i = 0; i < num_seeds; ++i) {
...@@ -526,37 +523,38 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -526,37 +523,38 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
// sampled nodes in a layer, and clear it when entering a new layer. // sampled nodes in a layer, and clear it when entering a new layer.
sub_ver_map.clear(); sub_ver_map.clear();
// Previous iteration collects all nodes in sub_vers, which are collected // Previous iteration collects all nodes in sub_vers, which are collected
// in the previous layer. sub_vers is used both as a node collection and a queue. // in the previous layer. sub_vers is used both as a node collection and a
for (size_t idx = layer_offsets[layer_id - 1]; idx < layer_offsets[layer_id]; idx++) { // queue.
for (size_t idx = layer_offsets[layer_id - 1];
idx < layer_offsets[layer_id]; idx++) {
dgl_id_t dst_id = sub_vers[idx].first; dgl_id_t dst_id = sub_vers[idx].first;
const int cur_node_level = sub_vers[idx].second; const int cur_node_level = sub_vers[idx].second;
tmp_sampled_src_list.clear(); tmp_sampled_src_list.clear();
tmp_sampled_edge_list.clear(); tmp_sampled_edge_list.clear();
dgl_id_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id); dgl_id_t ver_len = *(indptr + dst_id + 1) - *(indptr + dst_id);
if (probability == nullptr) { // uniform-sample if (probability == nullptr) { // uniform-sample
GetUniformSample(val_list + *(indptr + dst_id), GetUniformSample(
col_list + *(indptr + dst_id), val_list + *(indptr + dst_id), col_list + *(indptr + dst_id),
ver_len, ver_len, num_neighbor, &tmp_sampled_src_list,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list); &tmp_sampled_edge_list);
} else { // non-uniform-sample } else { // non-uniform-sample
GetNonUniformSample(probability, GetNonUniformSample(
val_list + *(indptr + dst_id), probability, val_list + *(indptr + dst_id),
col_list + *(indptr + dst_id), col_list + *(indptr + dst_id), ver_len, num_neighbor,
ver_len, &tmp_sampled_src_list, &tmp_sampled_edge_list);
num_neighbor, }
&tmp_sampled_src_list, // If we need to add self loop and it doesn't exist in the sampled
&tmp_sampled_edge_list); // neighbor list.
} if (add_self_loop &&
// If we need to add self loop and it doesn't exist in the sampled neighbor list. std::find(
if (add_self_loop && std::find(tmp_sampled_src_list.begin(), tmp_sampled_src_list.end(), tmp_sampled_src_list.begin(), tmp_sampled_src_list.end(),
dst_id) == tmp_sampled_src_list.end()) { dst_id) == tmp_sampled_src_list.end()) {
tmp_sampled_src_list.push_back(dst_id); tmp_sampled_src_list.push_back(dst_id);
const dgl_id_t *src_list = col_list + *(indptr + dst_id); const dgl_id_t *src_list = col_list + *(indptr + dst_id);
const dgl_id_t *eid_list = val_list + *(indptr + dst_id); const dgl_id_t *eid_list = val_list + *(indptr + dst_id);
// TODO(zhengda) this operation has O(N) complexity. It can be pretty slow. // TODO(zhengda) this operation has O(N) complexity. It can be pretty
// slow.
const dgl_id_t *src = std::find(src_list, src_list + ver_len, dst_id); const dgl_id_t *src = std::find(src_list, src_list + ver_len, dst_id);
// If there doesn't exist a self loop in the graph. // If there doesn't exist a self loop in the graph.
// we have to add -1 as the edge id for the self-loop edge. // we have to add -1 as the edge id for the self-loop edge.
...@@ -566,7 +564,8 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -566,7 +564,8 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
tmp_sampled_edge_list.push_back(eid_list[src - src_list]); tmp_sampled_edge_list.push_back(eid_list[src - src_list]);
} }
CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size()); CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());
neigh_pos.emplace_back(dst_id, neighbor_list.size(), tmp_sampled_src_list.size()); neigh_pos.emplace_back(
dst_id, neighbor_list.size(), tmp_sampled_src_list.size());
// Then push the vertices // Then push the vertices
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) { for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
neighbor_list.push_back(tmp_sampled_src_list[i]); neighbor_list.push_back(tmp_sampled_src_list[i]);
...@@ -578,8 +577,8 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -578,8 +577,8 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
num_edges += tmp_sampled_src_list.size(); num_edges += tmp_sampled_src_list.size();
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) { for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
// We need to add the neighbor in the hashtable here. This ensures that // We need to add the neighbor in the hashtable here. This ensures that
// the vertex in the queue is unique. If we see a vertex before, we don't // the vertex in the queue is unique. If we see a vertex before, we
// need to add it to the queue again. // don't need to add it to the queue again.
auto ret = sub_ver_map.insert(tmp_sampled_src_list[i]); auto ret = sub_ver_map.insert(tmp_sampled_src_list[i]);
// If the sampled neighbor is inserted to the map successfully. // If the sampled neighbor is inserted to the map successfully.
if (ret.second) { if (ret.second) {
...@@ -591,76 +590,69 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph, ...@@ -591,76 +590,69 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
CHECK_EQ(layer_offsets[layer_id + 1], sub_vers.size()); CHECK_EQ(layer_offsets[layer_id + 1], sub_vers.size());
} }
return ConstructNodeFlow(neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos, return ConstructNodeFlow(
edge_type, num_edges, num_hops); neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos, edge_type,
num_edges, num_hops);
} }
} // namespace } // namespace
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetGraph") DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0]; NodeFlow nflow = args[0];
*rv = nflow->graph; *rv = nflow->graph;
}); });
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetNodeMapping") DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetNodeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0]; NodeFlow nflow = args[0];
*rv = nflow->node_mapping; *rv = nflow->node_mapping;
}); });
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetEdgeMapping") DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetEdgeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0]; NodeFlow nflow = args[0];
*rv = nflow->edge_mapping; *rv = nflow->edge_mapping;
}); });
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetLayerOffsets") DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetLayerOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0]; NodeFlow nflow = args[0];
*rv = nflow->layer_offsets; *rv = nflow->layer_offsets;
}); });
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockOffsets") DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0]; NodeFlow nflow = args[0];
*rv = nflow->flow_offsets; *rv = nflow->flow_offsets;
}); });
template<typename ValueType> template <typename ValueType>
NodeFlow SamplerOp::NeighborSample(const ImmutableGraph *graph, NodeFlow SamplerOp::NeighborSample(
const std::vector<dgl_id_t>& seeds, const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,
const std::string &edge_type, const std::string &edge_type, int num_hops, int expand_factor,
int num_hops, int expand_factor, const bool add_self_loop, const ValueType *probability) {
const bool add_self_loop, return SampleSubgraph(
const ValueType *probability) { graph, seeds, probability, edge_type, num_hops + 1, expand_factor,
return SampleSubgraph(graph,
seeds,
probability,
edge_type,
num_hops + 1,
expand_factor,
add_self_loop); add_self_loop);
} }
namespace { namespace {
void ConstructLayers(const dgl_id_t *indptr, void ConstructLayers(
const dgl_id_t *indices, const dgl_id_t *indptr, const dgl_id_t *indices,
const std::vector<dgl_id_t>& seed_array, const std::vector<dgl_id_t> &seed_array, IdArray layer_sizes,
IdArray layer_sizes, std::vector<dgl_id_t> *layer_offsets, std::vector<dgl_id_t> *node_mapping,
std::vector<dgl_id_t> *layer_offsets, std::vector<int64_t> *actl_layer_sizes, std::vector<float> *probabilities) {
std::vector<dgl_id_t> *node_mapping,
std::vector<int64_t> *actl_layer_sizes,
std::vector<float> *probabilities) {
/* /*
* Given a graph and a collection of seed nodes, this function constructs NodeFlow * Given a graph and a collection of seed nodes, this function constructs
* layers via uniform layer-wise sampling, and return the resultant layers and their * NodeFlow layers via uniform layer-wise sampling, and return the resultant
* corresponding probabilities. * layers and their corresponding probabilities.
*/ */
std::copy(seed_array.begin(), seed_array.end(), std::back_inserter(*node_mapping)); std::copy(
seed_array.begin(), seed_array.end(), std::back_inserter(*node_mapping));
actl_layer_sizes->push_back(node_mapping->size()); actl_layer_sizes->push_back(node_mapping->size());
probabilities->insert(probabilities->end(), node_mapping->size(), 1); probabilities->insert(probabilities->end(), node_mapping->size(), 1);
const int64_t* layer_sizes_data = static_cast<int64_t*>(layer_sizes->data); const int64_t *layer_sizes_data = static_cast<int64_t *>(layer_sizes->data);
const int64_t num_layers = layer_sizes->shape[0]; const int64_t num_layers = layer_sizes->shape[0];
size_t curr = 0; size_t curr = 0;
...@@ -674,14 +666,15 @@ namespace { ...@@ -674,14 +666,15 @@ namespace {
} }
std::vector<dgl_id_t> candidate_vector; std::vector<dgl_id_t> candidate_vector;
std::copy(candidate_set.begin(), candidate_set.end(), std::copy(
candidate_set.begin(), candidate_set.end(),
std::back_inserter(candidate_vector)); std::back_inserter(candidate_vector));
std::unordered_map<dgl_id_t, size_t> n_occurrences; std::unordered_map<dgl_id_t, size_t> n_occurrences;
auto n_candidates = candidate_vector.size(); auto n_candidates = candidate_vector.size();
for (int64_t j = 0; j != layer_size; ++j) { for (int64_t j = 0; j != layer_size; ++j) {
auto dst = candidate_vector[ auto dst =
RandomEngine::ThreadLocal()->RandInt(n_candidates)]; candidate_vector[RandomEngine::ThreadLocal()->RandInt(n_candidates)];
if (!n_occurrences.insert(std::make_pair(dst, 1)).second) { if (!n_occurrences.insert(std::make_pair(dst, 1)).second) {
++n_occurrences[dst]; ++n_occurrences[dst];
} }
...@@ -703,21 +696,18 @@ namespace { ...@@ -703,21 +696,18 @@ namespace {
for (const auto &size : *actl_layer_sizes) { for (const auto &size : *actl_layer_sizes) {
layer_offsets->push_back(size + layer_offsets->back()); layer_offsets->push_back(size + layer_offsets->back());
} }
} }
void ConstructFlows(const dgl_id_t *indptr, void ConstructFlows(
const dgl_id_t *indices, const dgl_id_t *indptr, const dgl_id_t *indices, const dgl_id_t *eids,
const dgl_id_t *eids,
const std::vector<dgl_id_t> &node_mapping, const std::vector<dgl_id_t> &node_mapping,
const std::vector<int64_t> &actl_layer_sizes, const std::vector<int64_t> &actl_layer_sizes,
std::vector<dgl_id_t> *sub_indptr, std::vector<dgl_id_t> *sub_indptr, std::vector<dgl_id_t> *sub_indices,
std::vector<dgl_id_t> *sub_indices, std::vector<dgl_id_t> *sub_eids, std::vector<dgl_id_t> *flow_offsets,
std::vector<dgl_id_t> *sub_eids,
std::vector<dgl_id_t> *flow_offsets,
std::vector<dgl_id_t> *edge_mapping) { std::vector<dgl_id_t> *edge_mapping) {
/* /*
* Given a graph and a sequence of NodeFlow layers, this function constructs dense * Given a graph and a sequence of NodeFlow layers, this function constructs
* subgraphs (flows) between consecutive layers. * dense subgraphs (flows) between consecutive layers.
*/ */
auto n_flows = actl_layer_sizes.size() - 1; auto n_flows = actl_layer_sizes.size() - 1;
for (int64_t i = 0; i < actl_layer_sizes.front() + 1; i++) for (int64_t i = 0; i < actl_layer_sizes.front() + 1; i++)
...@@ -742,7 +732,9 @@ namespace { ...@@ -742,7 +732,9 @@ namespace {
neighbor_indices.push_back(std::make_pair(ret->second, eids[k])); neighbor_indices.push_back(std::make_pair(ret->second, eids[k]));
} }
} }
auto cmp = [](const id_pair p, const id_pair q)->bool { return p.first < q.first; }; auto cmp = [](const id_pair p, const id_pair q) -> bool {
return p.first < q.first;
};
std::sort(neighbor_indices.begin(), neighbor_indices.end(), cmp); std::sort(neighbor_indices.begin(), neighbor_indices.end(), cmp);
for (const auto &pair : neighbor_indices) { for (const auto &pair : neighbor_indices) {
sub_indices->push_back(pair.first); sub_indices->push_back(pair.first);
...@@ -755,44 +747,32 @@ namespace { ...@@ -755,44 +747,32 @@ namespace {
} }
sub_eids->resize(sub_indices->size()); sub_eids->resize(sub_indices->size());
std::iota(sub_eids->begin(), sub_eids->end(), 0); std::iota(sub_eids->begin(), sub_eids->end(), 0);
} }
} // namespace } // namespace
NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph, NodeFlow SamplerOp::LayerUniformSample(
const std::vector<dgl_id_t>& seeds, const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,
const std::string &neighbor_type, const std::string &neighbor_type, IdArray layer_sizes) {
IdArray layer_sizes) { const auto g_csr =
const auto g_csr = neighbor_type == "in" ? graph->GetInCSR() : graph->GetOutCSR(); neighbor_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
const dgl_id_t *indptr = static_cast<dgl_id_t*>(g_csr->indptr()->data); const dgl_id_t *indptr = static_cast<dgl_id_t *>(g_csr->indptr()->data);
const dgl_id_t *indices = static_cast<dgl_id_t*>(g_csr->indices()->data); const dgl_id_t *indices = static_cast<dgl_id_t *>(g_csr->indices()->data);
const dgl_id_t *eids = static_cast<dgl_id_t*>(g_csr->edge_ids()->data); const dgl_id_t *eids = static_cast<dgl_id_t *>(g_csr->edge_ids()->data);
std::vector<dgl_id_t> layer_offsets; std::vector<dgl_id_t> layer_offsets;
std::vector<dgl_id_t> node_mapping; std::vector<dgl_id_t> node_mapping;
std::vector<int64_t> actl_layer_sizes; std::vector<int64_t> actl_layer_sizes;
std::vector<float> probabilities; std::vector<float> probabilities;
ConstructLayers(indptr, ConstructLayers(
indices, indptr, indices, seeds, layer_sizes, &layer_offsets, &node_mapping,
seeds, &actl_layer_sizes, &probabilities);
layer_sizes,
&layer_offsets,
&node_mapping,
&actl_layer_sizes,
&probabilities);
std::vector<dgl_id_t> sub_indptr, sub_indices, sub_edge_ids; std::vector<dgl_id_t> sub_indptr, sub_indices, sub_edge_ids;
std::vector<dgl_id_t> flow_offsets; std::vector<dgl_id_t> flow_offsets;
std::vector<dgl_id_t> edge_mapping; std::vector<dgl_id_t> edge_mapping;
ConstructFlows(indptr, ConstructFlows(
indices, indptr, indices, eids, node_mapping, actl_layer_sizes, &sub_indptr,
eids, &sub_indices, &sub_edge_ids, &flow_offsets, &edge_mapping);
node_mapping,
actl_layer_sizes,
&sub_indptr,
&sub_indices,
&sub_edge_ids,
&flow_offsets,
&edge_mapping);
// sanity check // sanity check
CHECK_GT(sub_indptr.size(), 0); CHECK_GT(sub_indptr.size(), 0);
CHECK_EQ(sub_indptr[0], 0); CHECK_EQ(sub_indptr[0], 0);
...@@ -800,8 +780,8 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph, ...@@ -800,8 +780,8 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
CHECK_EQ(sub_indices.size(), sub_edge_ids.size()); CHECK_EQ(sub_indices.size(), sub_edge_ids.size());
NodeFlow nf = NodeFlow::Create(); NodeFlow nf = NodeFlow::Create();
auto sub_csr = CSRPtr(new CSR(aten::VecToIdArray(sub_indptr), auto sub_csr = CSRPtr(new CSR(
aten::VecToIdArray(sub_indices), aten::VecToIdArray(sub_indptr), aten::VecToIdArray(sub_indices),
aten::VecToIdArray(sub_edge_ids))); aten::VecToIdArray(sub_edge_ids)));
if (neighbor_type == std::string("in")) { if (neighbor_type == std::string("in")) {
...@@ -830,24 +810,22 @@ void BuildCsr(const ImmutableGraph &g, const std::string neigh_type) { ...@@ -830,24 +810,22 @@ void BuildCsr(const ImmutableGraph &g, const std::string neigh_type) {
} }
} }
template<typename ValueType> template <typename ValueType>
std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr, std::vector<NodeFlow> NeighborSamplingImpl(
const IdArray seed_nodes, const ImmutableGraphPtr gptr, const IdArray seed_nodes,
const int64_t batch_start_id, const int64_t batch_start_id, const int64_t batch_size,
const int64_t batch_size, const int64_t max_num_workers, const int64_t expand_factor,
const int64_t max_num_workers, const int64_t num_hops, const std::string neigh_type,
const int64_t expand_factor, const bool add_self_loop, const ValueType *probability) {
const int64_t num_hops,
const std::string neigh_type,
const bool add_self_loop,
const ValueType *probability) {
// process args // process args
CHECK(aten::IsValidIdArray(seed_nodes)); CHECK(aten::IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data); const dgl_id_t *seed_nodes_data = static_cast<dgl_id_t *>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0]; const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers, const int64_t num_workers = std::min(
max_num_workers,
(num_seeds + batch_size - 1) / batch_size - batch_start_id); (num_seeds + batch_size - 1) / batch_size - batch_start_id);
// We need to make sure we have the right CSR before we enter parallel sampling. // We need to make sure we have the right CSR before we enter parallel
// sampling.
BuildCsr(*gptr, neigh_type); BuildCsr(*gptr, neigh_type);
// generate node flows // generate node flows
std::vector<NodeFlow> nflows(num_workers); std::vector<NodeFlow> nflows(num_workers);
...@@ -858,8 +836,8 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr, ...@@ -858,8 +836,8 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
const int64_t end = std::min(start + batch_size, num_seeds); const int64_t end = std::min(start + batch_size, num_seeds);
// TODO(minjie): the vector allocation/copy is unnecessary // TODO(minjie): the vector allocation/copy is unnecessary
std::vector<dgl_id_t> worker_seeds(end - start); std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end, std::copy(
worker_seeds.begin()); seed_nodes_data + start, seed_nodes_data + end, worker_seeds.begin());
nflows[i] = SamplerOp::NeighborSample( nflows[i] = SamplerOp::NeighborSample(
gptr.get(), worker_seeds, neigh_type, num_hops, expand_factor, gptr.get(), worker_seeds, neigh_type, num_hops, expand_factor,
add_self_loop, probability); add_self_loop, probability);
...@@ -869,7 +847,7 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr, ...@@ -869,7 +847,7 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
} }
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling") DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments // arguments
const GraphRef g = args[0]; const GraphRef g = args[0];
const IdArray seed_nodes = args[1]; const IdArray seed_nodes = args[1];
...@@ -896,7 +874,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling") ...@@ -896,7 +874,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling") DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments // arguments
const GraphRef g = args[0]; const GraphRef g = args[0];
const IdArray seed_nodes = args[1]; const IdArray seed_nodes = args[1];
...@@ -926,17 +904,17 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling") ...@@ -926,17 +904,17 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
<< "NeighborSampling only support CPU sampling"; << "NeighborSampling only support CPU sampling";
ATEN_FLOAT_TYPE_SWITCH( ATEN_FLOAT_TYPE_SWITCH(
probability->dtype, probability->dtype, FloatType, "transition probability", {
FloatType,
"transition probability",
{
const FloatType *prob; const FloatType *prob;
if (aten::IsNullArray(probability)) { if (aten::IsNullArray(probability)) {
prob = nullptr; prob = nullptr;
} else { } else {
CHECK(probability->shape[0] == static_cast<int64_t>(gptr->NumEdges())) CHECK(
<< "transition probability must have same number of elements as edges"; probability->shape[0] ==
static_cast<int64_t>(gptr->NumEdges()))
<< "transition probability must have same number of elements "
"as edges";
CHECK(probability.IsContiguous()) CHECK(probability.IsContiguous())
<< "transition probability must be contiguous tensor"; << "transition probability must be contiguous tensor";
prob = static_cast<const FloatType *>(probability->data); prob = static_cast<const FloatType *>(probability->data);
...@@ -951,7 +929,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling") ...@@ -951,7 +929,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments // arguments
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray seed_nodes = args[1]; const IdArray seed_nodes = args[1];
...@@ -971,11 +949,14 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") ...@@ -971,11 +949,14 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
CHECK_EQ(layer_sizes->ctx.device_type, kDGLCPU) CHECK_EQ(layer_sizes->ctx.device_type, kDGLCPU)
<< "LayerSampler only support CPU sampling"; << "LayerSampler only support CPU sampling";
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data); const dgl_id_t *seed_nodes_data =
static_cast<dgl_id_t *>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0]; const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers, const int64_t num_workers = std::min(
max_num_workers,
(num_seeds + batch_size - 1) / batch_size - batch_start_id); (num_seeds + batch_size - 1) / batch_size - batch_start_id);
// We need to make sure we have the right CSR before we enter parallel sampling. // We need to make sure we have the right CSR before we enter parallel
// sampling.
BuildCsr(*gptr, neigh_type); BuildCsr(*gptr, neigh_type);
// generate node flows // generate node flows
std::vector<NodeFlow> nflows(num_workers); std::vector<NodeFlow> nflows(num_workers);
...@@ -986,7 +967,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling") ...@@ -986,7 +967,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
const int64_t end = std::min(start + batch_size, num_seeds); const int64_t end = std::min(start + batch_size, num_seeds);
// TODO(minjie): the vector allocation/copy is unnecessary // TODO(minjie): the vector allocation/copy is unnecessary
std::vector<dgl_id_t> worker_seeds(end - start); std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end, std::copy(
seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin()); worker_seeds.begin());
nflows[i] = SamplerOp::LayerUniformSample( nflows[i] = SamplerOp::LayerUniformSample(
gptr.get(), worker_seeds, neigh_type, layer_sizes); gptr.get(), worker_seeds, neigh_type, layer_sizes);
...@@ -1002,9 +984,8 @@ void BuildCoo(const ImmutableGraph &g) { ...@@ -1002,9 +984,8 @@ void BuildCoo(const ImmutableGraph &g) {
assert(coo); assert(coo);
} }
dgl_id_t global2local_map(
dgl_id_t global2local_map(dgl_id_t global_id, dgl_id_t global_id, std::unordered_map<dgl_id_t, dgl_id_t> *map) {
std::unordered_map<dgl_id_t, dgl_id_t> *map) {
auto it = map->find(global_id); auto it = map->find(global_id);
if (it == map->end()) { if (it == map->end()) {
dgl_id_t local_id = map->size(); dgl_id_t local_id = map->size();
...@@ -1020,7 +1001,8 @@ inline bool IsNegativeHeadMode(const std::string &mode) { ...@@ -1020,7 +1001,8 @@ inline bool IsNegativeHeadMode(const std::string &mode) {
} }
IdArray GetGlobalVid(IdArray induced_nid, IdArray subg_nid) { IdArray GetGlobalVid(IdArray induced_nid, IdArray subg_nid) {
IdArray gnid = IdArray::Empty({subg_nid->shape[0]}, subg_nid->dtype, subg_nid->ctx); IdArray gnid =
IdArray::Empty({subg_nid->shape[0]}, subg_nid->dtype, subg_nid->ctx);
const dgl_id_t *induced_nid_data = static_cast<dgl_id_t *>(induced_nid->data); const dgl_id_t *induced_nid_data = static_cast<dgl_id_t *>(induced_nid->data);
const dgl_id_t *subg_nid_data = static_cast<dgl_id_t *>(subg_nid->data); const dgl_id_t *subg_nid_data = static_cast<dgl_id_t *>(subg_nid->data);
dgl_id_t *gnid_data = static_cast<dgl_id_t *>(gnid->data); dgl_id_t *gnid_data = static_cast<dgl_id_t *>(gnid->data);
...@@ -1030,14 +1012,14 @@ IdArray GetGlobalVid(IdArray induced_nid, IdArray subg_nid) { ...@@ -1030,14 +1012,14 @@ IdArray GetGlobalVid(IdArray induced_nid, IdArray subg_nid) {
return gnid; return gnid;
} }
IdArray CheckExistence(GraphPtr gptr, IdArray neg_src, IdArray neg_dst, IdArray CheckExistence(
IdArray induced_nid) { GraphPtr gptr, IdArray neg_src, IdArray neg_dst, IdArray induced_nid) {
return gptr->HasEdgesBetween(GetGlobalVid(induced_nid, neg_src), return gptr->HasEdgesBetween(
GetGlobalVid(induced_nid, neg_dst)); GetGlobalVid(induced_nid, neg_src), GetGlobalVid(induced_nid, neg_dst));
} }
IdArray CheckExistence(GraphPtr gptr, IdArray relations, IdArray CheckExistence(
IdArray neg_src, IdArray neg_dst, GraphPtr gptr, IdArray relations, IdArray neg_src, IdArray neg_dst,
IdArray induced_nid, IdArray neg_eid) { IdArray induced_nid, IdArray neg_eid) {
neg_src = GetGlobalVid(induced_nid, neg_src); neg_src = GetGlobalVid(induced_nid, neg_src);
neg_dst = GetGlobalVid(induced_nid, neg_dst); neg_dst = GetGlobalVid(induced_nid, neg_dst);
...@@ -1051,8 +1033,7 @@ IdArray CheckExistence(GraphPtr gptr, IdArray relations, ...@@ -1051,8 +1033,7 @@ IdArray CheckExistence(GraphPtr gptr, IdArray relations,
int64_t num_neg_edges = neg_src->shape[0]; int64_t num_neg_edges = neg_src->shape[0];
for (int64_t i = 0; i < num_neg_edges; i++) { for (int64_t i = 0; i < num_neg_edges; i++) {
// If the edge doesn't exist, we don't need to do anything. // If the edge doesn't exist, we don't need to do anything.
if (!exist_data[i]) if (!exist_data[i]) continue;
continue;
// If the edge exists, we need to double check if the relations match. // If the edge exists, we need to double check if the relations match.
// If they match, this negative edge isn't really a negative edge. // If they match, this negative edge isn't really a negative edge.
dgl_id_t eid1 = neg_eid_data[i]; dgl_id_t eid1 = neg_eid_data[i];
...@@ -1073,7 +1054,8 @@ IdArray CheckExistence(GraphPtr gptr, IdArray relations, ...@@ -1073,7 +1054,8 @@ IdArray CheckExistence(GraphPtr gptr, IdArray relations,
return exist; return exist;
} }
std::vector<dgl_id_t> Global2Local(const std::vector<size_t> &ids, std::vector<dgl_id_t> Global2Local(
const std::vector<size_t> &ids,
const std::unordered_map<dgl_id_t, dgl_id_t> &map) { const std::unordered_map<dgl_id_t, dgl_id_t> &map) {
std::vector<dgl_id_t> local_ids(ids.size()); std::vector<dgl_id_t> local_ids(ids.size());
for (size_t i = 0; i < ids.size(); i++) { for (size_t i = 0; i < ids.size(); i++) {
...@@ -1084,35 +1066,36 @@ std::vector<dgl_id_t> Global2Local(const std::vector<size_t> &ids, ...@@ -1084,35 +1066,36 @@ std::vector<dgl_id_t> Global2Local(const std::vector<size_t> &ids,
return local_ids; return local_ids;
} }
NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg, NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(
const std::string &neg_mode, const Subgraph &pos_subg, const std::string &neg_mode,
int64_t neg_sample_size, int64_t neg_sample_size, bool exclude_positive, bool check_false_neg) {
bool exclude_positive,
bool check_false_neg) {
int64_t num_tot_nodes = gptr_->NumVertices(); int64_t num_tot_nodes = gptr_->NumVertices();
if (neg_sample_size > num_tot_nodes) if (neg_sample_size > num_tot_nodes) neg_sample_size = num_tot_nodes;
neg_sample_size = num_tot_nodes;
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo"); std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0]; IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2; int64_t num_pos_edges = coo->shape[0] / 2;
int64_t num_neg_edges = num_pos_edges * neg_sample_size; int64_t num_neg_edges = num_pos_edges * neg_sample_size;
IdArray neg_dst = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx); IdArray neg_dst = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
IdArray neg_src = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx); IdArray neg_src = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
IdArray induced_neg_eid = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx); IdArray induced_neg_eid =
IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
// These are vids in the positive subgraph. // These are vids in the positive subgraph.
const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data); const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);
const dgl_id_t *src_data = static_cast<const dgl_id_t *>(coo->data) + num_pos_edges; const dgl_id_t *src_data =
static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *induced_vid_data = const dgl_id_t *induced_vid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data); static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data = const dgl_id_t *induced_eid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_edges->data); static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
size_t num_pos_nodes = pos_subg.graph->NumVertices(); size_t num_pos_nodes = pos_subg.graph->NumVertices();
std::vector<size_t> pos_nodes(induced_vid_data, induced_vid_data + num_pos_nodes); std::vector<size_t> pos_nodes(
induced_vid_data, induced_vid_data + num_pos_nodes);
dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data); dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);
dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data); dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);
dgl_id_t *induced_neg_eid_data = static_cast<dgl_id_t *>(induced_neg_eid->data); dgl_id_t *induced_neg_eid_data =
static_cast<dgl_id_t *>(induced_neg_eid->data);
const dgl_id_t *unchanged; const dgl_id_t *unchanged;
dgl_id_t *neg_unchanged; dgl_id_t *neg_unchanged;
...@@ -1206,7 +1189,8 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg, ...@@ -1206,7 +1189,8 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
for (int64_t j = 0; j < neg_sample_size; j++) { for (int64_t j = 0; j < neg_sample_size; j++) {
neg_unchanged[neg_idx + j] = local_unchanged; neg_unchanged[neg_idx + j] = local_unchanged;
dgl_id_t local_changed = global2local_map(neg_vids[j + prev_neg_offset], &neg_map); dgl_id_t local_changed =
global2local_map(neg_vids[j + prev_neg_offset], &neg_map);
neg_changed[neg_idx + j] = local_changed; neg_changed[neg_idx + j] = local_changed;
// induced negative eid references to the positive one. // induced negative eid references to the positive one.
induced_neg_eid_data[neg_idx + j] = induced_eid_data[i]; induced_neg_eid_data[neg_idx + j] = induced_eid_data[i];
...@@ -1215,8 +1199,10 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg, ...@@ -1215,8 +1199,10 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
// Now we know the number of vertices in the negative graph. // Now we know the number of vertices in the negative graph.
int64_t num_neg_nodes = neg_map.size(); int64_t num_neg_nodes = neg_map.size();
IdArray induced_neg_vid = IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx); IdArray induced_neg_vid =
dgl_id_t *induced_neg_vid_data = static_cast<dgl_id_t *>(induced_neg_vid->data); IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);
dgl_id_t *induced_neg_vid_data =
static_cast<dgl_id_t *>(induced_neg_vid->data);
for (auto it = neg_map.begin(); it != neg_map.end(); it++) { for (auto it = neg_map.begin(); it != neg_map.end(); it++) {
induced_neg_vid_data[it->second] = it->first; induced_neg_vid_data[it->second] = it->first;
} }
...@@ -1241,24 +1227,22 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg, ...@@ -1241,24 +1227,22 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
if (aten::IsNullArray(relations_)) { if (aten::IsNullArray(relations_)) {
neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid); neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);
} else { } else {
neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst, neg_subg.exist = CheckExistence(
induced_neg_vid, induced_neg_eid); gptr_, relations_, neg_src, neg_dst, induced_neg_vid,
induced_neg_eid);
} }
} }
return neg_subg; return neg_subg;
} }
NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_subg, NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(
const std::string &neg_mode, const Subgraph &pos_subg, const std::string &neg_mode,
int64_t neg_sample_size, int64_t neg_sample_size, bool exclude_positive, bool check_false_neg) {
bool exclude_positive,
bool check_false_neg) {
int64_t num_tot_nodes = gptr_->NumVertices(); int64_t num_tot_nodes = gptr_->NumVertices();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo"); std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0]; IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2; int64_t num_pos_edges = coo->shape[0] / 2;
if (neg_sample_size > num_tot_nodes) if (neg_sample_size > num_tot_nodes) neg_sample_size = num_tot_nodes;
neg_sample_size = num_tot_nodes;
int64_t chunk_size = chunk_size_; int64_t chunk_size = chunk_size_;
CHECK_GT(chunk_size, 0) << "chunk size has to be positive"; CHECK_GT(chunk_size, 0) << "chunk size has to be positive";
...@@ -1275,26 +1259,29 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub ...@@ -1275,26 +1259,29 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
int64_t num_all_neg_edges = num_neg_edges + num_neg_edges_last_chunk; int64_t num_all_neg_edges = num_neg_edges + num_neg_edges_last_chunk;
// We should include the last chunk. // We should include the last chunk.
if (last_chunk_size > 0) if (last_chunk_size > 0) num_chunks++;
num_chunks++;
IdArray neg_dst = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx); IdArray neg_dst = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);
IdArray neg_src = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx); IdArray neg_src = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);
IdArray induced_neg_eid = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx); IdArray induced_neg_eid =
IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);
// These are vids in the positive subgraph. // These are vids in the positive subgraph.
const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data); const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);
const dgl_id_t *src_data = static_cast<const dgl_id_t *>(coo->data) + num_pos_edges; const dgl_id_t *src_data =
static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *induced_vid_data = const dgl_id_t *induced_vid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data); static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data = const dgl_id_t *induced_eid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_edges->data); static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
int64_t num_pos_nodes = pos_subg.graph->NumVertices(); int64_t num_pos_nodes = pos_subg.graph->NumVertices();
std::vector<dgl_id_t> pos_nodes(induced_vid_data, induced_vid_data + num_pos_nodes); std::vector<dgl_id_t> pos_nodes(
induced_vid_data, induced_vid_data + num_pos_nodes);
dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data); dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);
dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data); dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);
dgl_id_t *induced_neg_eid_data = static_cast<dgl_id_t *>(induced_neg_eid->data); dgl_id_t *induced_neg_eid_data =
static_cast<dgl_id_t *>(induced_neg_eid->data);
const dgl_id_t *unchanged; const dgl_id_t *unchanged;
dgl_id_t *neg_unchanged; dgl_id_t *neg_unchanged;
...@@ -1312,9 +1299,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub ...@@ -1312,9 +1299,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
// We first sample all negative edges. // We first sample all negative edges.
std::vector<size_t> global_neg_vids; std::vector<size_t> global_neg_vids;
std::vector<size_t> local_neg_vids; std::vector<size_t> local_neg_vids;
randomSample(num_tot_nodes, randomSample(num_tot_nodes, num_chunks * neg_sample_size, &global_neg_vids);
num_chunks * neg_sample_size,
&global_neg_vids);
CHECK_EQ(num_chunks * neg_sample_size, global_neg_vids.size()); CHECK_EQ(num_chunks * neg_sample_size, global_neg_vids.size());
std::unordered_map<dgl_id_t, dgl_id_t> neg_map; std::unordered_map<dgl_id_t, dgl_id_t> neg_map;
...@@ -1336,7 +1321,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub ...@@ -1336,7 +1321,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
// to reduce computation overhead. // to reduce computation overhead.
local_neg_vids.resize(global_neg_vids.size()); local_neg_vids.resize(global_neg_vids.size());
for (size_t i = 0; i < global_neg_vids.size(); i++) { for (size_t i = 0; i < global_neg_vids.size(); i++) {
local_neg_vids[i] = global2local_map(global_neg_vids[i], &neg_map);; local_neg_vids[i] = global2local_map(global_neg_vids[i], &neg_map);
} }
for (int64_t i_chunk = 0; i_chunk < num_chunks; i_chunk++) { for (int64_t i_chunk = 0; i_chunk < num_chunks; i_chunk++) {
...@@ -1353,12 +1338,14 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub ...@@ -1353,12 +1338,14 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
for (int64_t in_chunk = 0; in_chunk != chunk_size1; ++in_chunk) { for (int64_t in_chunk = 0; in_chunk != chunk_size1; ++in_chunk) {
// For each positive node in a chunk. // For each positive node in a chunk.
dgl_id_t global_unchanged = induced_vid_data[unchanged[pos_edge_idx + in_chunk]]; dgl_id_t global_unchanged =
induced_vid_data[unchanged[pos_edge_idx + in_chunk]];
dgl_id_t local_unchanged = global2local_map(global_unchanged, &neg_map); dgl_id_t local_unchanged = global2local_map(global_unchanged, &neg_map);
for (int64_t j = 0; j < neg_sample_size; ++j) { for (int64_t j = 0; j < neg_sample_size; ++j) {
neg_unchanged[neg_idx] = local_unchanged; neg_unchanged[neg_idx] = local_unchanged;
neg_changed[neg_idx] = local_neg_vids[neg_node_idx + j]; neg_changed[neg_idx] = local_neg_vids[neg_node_idx + j];
induced_neg_eid_data[neg_idx] = induced_eid_data[pos_edge_idx + in_chunk]; induced_neg_eid_data[neg_idx] =
induced_eid_data[pos_edge_idx + in_chunk];
neg_idx++; neg_idx++;
} }
} }
...@@ -1366,8 +1353,10 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub ...@@ -1366,8 +1353,10 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
// Now we know the number of vertices in the negative graph. // Now we know the number of vertices in the negative graph.
int64_t num_neg_nodes = neg_map.size(); int64_t num_neg_nodes = neg_map.size();
IdArray induced_neg_vid = IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx); IdArray induced_neg_vid =
dgl_id_t *induced_neg_vid_data = static_cast<dgl_id_t *>(induced_neg_vid->data); IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);
dgl_id_t *induced_neg_vid_data =
static_cast<dgl_id_t *>(induced_neg_vid->data);
for (auto it = neg_map.begin(); it != neg_map.end(); it++) { for (auto it = neg_map.begin(); it != neg_map.end(); it++) {
induced_neg_vid_data[it->second] = it->first; induced_neg_vid_data[it->second] = it->first;
} }
...@@ -1380,18 +1369,21 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub ...@@ -1380,18 +1369,21 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
neg_subg.induced_vertices = induced_neg_vid; neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid; neg_subg.induced_edges = induced_neg_eid;
if (IsNegativeHeadMode(neg_mode)) { if (IsNegativeHeadMode(neg_mode)) {
neg_subg.head_nid = aten::VecToIdArray(Global2Local(global_neg_vids, neg_map)); neg_subg.head_nid =
aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));
neg_subg.tail_nid = aten::VecToIdArray(local_pos_vids); neg_subg.tail_nid = aten::VecToIdArray(local_pos_vids);
} else { } else {
neg_subg.head_nid = aten::VecToIdArray(local_pos_vids); neg_subg.head_nid = aten::VecToIdArray(local_pos_vids);
neg_subg.tail_nid = aten::VecToIdArray(Global2Local(global_neg_vids, neg_map)); neg_subg.tail_nid =
aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));
} }
if (check_false_neg) { if (check_false_neg) {
if (aten::IsNullArray(relations_)) { if (aten::IsNullArray(relations_)) {
neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid); neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);
} else { } else {
neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst, neg_subg.exist = CheckExistence(
induced_neg_vid, induced_neg_eid); gptr_, relations_, neg_src, neg_dst, induced_neg_vid,
induced_neg_eid);
} }
} }
return neg_subg; return neg_subg;
...@@ -1408,52 +1400,38 @@ inline SubgraphRef ConvertRef(const NegSubgraph &subg) { ...@@ -1408,52 +1400,38 @@ inline SubgraphRef ConvertRef(const NegSubgraph &subg) {
} // namespace } // namespace
DGL_REGISTER_GLOBAL("sampling._CAPI_GetNegEdgeExistence") DGL_REGISTER_GLOBAL("sampling._CAPI_GetNegEdgeExistence")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
SubgraphRef g = args[0]; SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->exist; *rv = gptr->exist;
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphHead") DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphHead")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
SubgraphRef g = args[0]; SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->head_nid; *rv = gptr->head_nid;
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphTail") DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphTail")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
SubgraphRef g = args[0]; SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr()); auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->tail_nid; *rv = gptr->tail_nid;
}); });
class UniformEdgeSamplerObject: public EdgeSamplerObject { class UniformEdgeSamplerObject : public EdgeSamplerObject {
public: public:
explicit UniformEdgeSamplerObject(const GraphPtr gptr, explicit UniformEdgeSamplerObject(
IdArray seed_edges, const GraphPtr gptr, IdArray seed_edges, const int64_t batch_size,
const int64_t batch_size, const int64_t num_workers, const bool replacement, const bool reset,
const int64_t num_workers, const std::string neg_mode, const int64_t neg_sample_size,
const bool replacement, const int64_t chunk_size, const bool exclude_positive,
const bool reset, const bool check_false_neg, IdArray relations)
const std::string neg_mode, : EdgeSamplerObject(
const int64_t neg_sample_size, gptr, seed_edges, batch_size, num_workers, replacement, reset,
const int64_t chunk_size, neg_mode, neg_sample_size, chunk_size, exclude_positive,
const bool exclude_positive, check_false_neg, relations) {
const bool check_false_neg,
IdArray relations)
: EdgeSamplerObject(gptr,
seed_edges,
batch_size,
num_workers,
replacement,
reset,
neg_mode,
neg_sample_size,
chunk_size,
exclude_positive,
check_false_neg,
relations) {
batch_curr_id_ = 0; batch_curr_id_ = 0;
num_seeds_ = seed_edges->shape[0]; num_seeds_ = seed_edges->shape[0];
max_batch_id_ = (num_seeds_ + batch_size - 1) / batch_size; max_batch_id_ = (num_seeds_ + batch_size - 1) / batch_size;
...@@ -1463,8 +1441,9 @@ public: ...@@ -1463,8 +1441,9 @@ public:
} }
~UniformEdgeSamplerObject() {} ~UniformEdgeSamplerObject() {}
void Fetch(DGLRetValue* rv) { void Fetch(DGLRetValue *rv) {
const int64_t num_workers = std::min(num_workers_, max_batch_id_ - batch_curr_id_); const int64_t num_workers =
std::min(num_workers_, max_batch_id_ - batch_curr_id_);
// generate subgraphs. // generate subgraphs.
std::vector<SubgraphRef> positive_subgs(num_workers); std::vector<SubgraphRef> positive_subgs(num_workers);
std::vector<SubgraphRef> negative_subgs(num_workers); std::vector<SubgraphRef> negative_subgs(num_workers);
...@@ -1477,11 +1456,13 @@ public: ...@@ -1477,11 +1456,13 @@ public:
IdArray worker_seeds; IdArray worker_seeds;
if (replacement_ == false) { if (replacement_ == false) {
worker_seeds = seed_edges_.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1}, worker_seeds = seed_edges_.CreateView(
{num_edges}, DGLDataType{kDGLInt, 64, 1},
sizeof(dgl_id_t) * start); sizeof(dgl_id_t) * start);
} else { } else {
std::vector<dgl_id_t> seeds; std::vector<dgl_id_t> seeds;
const dgl_id_t *seed_edge_ids = static_cast<const dgl_id_t *>(seed_edges_->data); const dgl_id_t *seed_edge_ids =
static_cast<const dgl_id_t *>(seed_edges_->data);
// sampling of each edge is a standalone event // sampling of each edge is a standalone event
for (int64_t i = 0; i < num_edges; ++i) { for (int64_t i = 0; i < num_edges; ++i) {
int64_t seed = static_cast<const int64_t>( int64_t seed = static_cast<const int64_t>(
...@@ -1497,29 +1478,29 @@ public: ...@@ -1497,29 +1478,29 @@ public:
const dgl_id_t *dst_ids = static_cast<const dgl_id_t *>(arr.dst->data); const dgl_id_t *dst_ids = static_cast<const dgl_id_t *>(arr.dst->data);
std::vector<dgl_id_t> src_vec(src_ids, src_ids + num_edges); std::vector<dgl_id_t> src_vec(src_ids, src_ids + num_edges);
std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + num_edges); std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + num_edges);
// TODO(zhengda) what if there are duplicates in the src and dst vectors. // TODO(zhengda) what if there are duplicates in the src and dst
// vectors.
Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false); Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);
positive_subgs[i] = ConvertRef(subg); positive_subgs[i] = ConvertRef(subg);
// For chunked negative sampling, we accept "chunk-head" for corrupting head // For chunked negative sampling, we accept "chunk-head" for corrupting
// nodes and "chunk-tail" for corrupting tail nodes. // head nodes and "chunk-tail" for corrupting tail nodes.
if (neg_mode_.substr(0, 5) == "chunk") { if (neg_mode_.substr(0, 5) == "chunk") {
NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(subg, neg_mode_.substr(6), NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(
neg_sample_size_, subg, neg_mode_.substr(6), neg_sample_size_, exclude_positive_,
exclude_positive_,
check_false_neg_); check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg); negative_subgs[i] = ConvertRef(neg_subg);
} else if (neg_mode_ == "head" || neg_mode_ == "tail") { } else if (neg_mode_ == "head" || neg_mode_ == "tail") {
NegSubgraph neg_subg = genNegEdgeSubgraph(subg, neg_mode_, NegSubgraph neg_subg = genNegEdgeSubgraph(
neg_sample_size_, subg, neg_mode_, neg_sample_size_, exclude_positive_,
exclude_positive_,
check_false_neg_); check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg); negative_subgs[i] = ConvertRef(neg_subg);
} }
} }
}); });
if (neg_mode_.size() > 0) { if (neg_mode_.size() > 0) {
positive_subgs.insert(positive_subgs.end(), negative_subgs.begin(), negative_subgs.end()); positive_subgs.insert(
positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
} }
batch_curr_id_ += num_workers; batch_curr_id_ += num_workers;
...@@ -1535,20 +1516,22 @@ public: ...@@ -1535,20 +1516,22 @@ public:
if (replacement_ == false) { if (replacement_ == false) {
// Now we should shuffle the data and reset the sampler. // Now we should shuffle the data and reset the sampler.
dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seed_edges_->data); dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seed_edges_->data);
std::shuffle(seed_ids, seed_ids + seed_edges_->shape[0], std::shuffle(
seed_ids, seed_ids + seed_edges_->shape[0],
std::default_random_engine()); std::default_random_engine());
} }
} }
DGL_DECLARE_OBJECT_TYPE_INFO(UniformEdgeSamplerObject, Object); DGL_DECLARE_OBJECT_TYPE_INFO(UniformEdgeSamplerObject, Object);
private: private:
void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) { void randomSample(size_t set_size, size_t num, std::vector<size_t> *out) {
RandomSample(set_size, num, out); RandomSample(set_size, num, out);
} }
void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude, void randomSample(
std::vector<size_t>* out) { size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t> *out) {
RandomSample(set_size, num, exclude, out); RandomSample(set_size, num, exclude, out);
} }
...@@ -1557,17 +1540,19 @@ private: ...@@ -1557,17 +1540,19 @@ private:
int64_t num_seeds_; int64_t num_seeds_;
}; };
class UniformEdgeSampler: public ObjectRef { class UniformEdgeSampler : public ObjectRef {
public: public:
UniformEdgeSampler() {} UniformEdgeSampler() {}
explicit UniformEdgeSampler(std::shared_ptr<runtime::Object> obj): ObjectRef(obj) {} explicit UniformEdgeSampler(std::shared_ptr<runtime::Object> obj)
: ObjectRef(obj) {}
UniformEdgeSamplerObject* operator->() const { UniformEdgeSamplerObject *operator->() const {
return static_cast<UniformEdgeSamplerObject*>(obj_.get()); return static_cast<UniformEdgeSamplerObject *>(obj_.get());
} }
std::shared_ptr<UniformEdgeSamplerObject> sptr() const { std::shared_ptr<UniformEdgeSamplerObject> sptr() const {
return CHECK_NOTNULL(std::dynamic_pointer_cast<UniformEdgeSamplerObject>(obj_)); return CHECK_NOTNULL(
std::dynamic_pointer_cast<UniformEdgeSamplerObject>(obj_));
} }
operator bool() const { return this->defined(); } operator bool() const { return this->defined(); }
...@@ -1575,7 +1560,7 @@ class UniformEdgeSampler: public ObjectRef { ...@@ -1575,7 +1560,7 @@ class UniformEdgeSampler: public ObjectRef {
}; };
DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler") DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments // arguments
GraphRef g = args[0]; GraphRef g = args[0];
IdArray seed_edges = args[1]; IdArray seed_edges = args[1];
...@@ -1603,64 +1588,42 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler") ...@@ -1603,64 +1588,42 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler")
} }
BuildCoo(*gptr); BuildCoo(*gptr);
auto o = std::make_shared<UniformEdgeSamplerObject>(gptr, auto o = std::make_shared<UniformEdgeSamplerObject>(
seed_edges, gptr, seed_edges, batch_size, max_num_workers, replacement, reset,
batch_size, neg_mode, neg_sample_size, chunk_size, exclude_positive,
max_num_workers, check_false_neg, relations);
replacement,
reset,
neg_mode,
neg_sample_size,
chunk_size,
exclude_positive,
check_false_neg,
relations);
*rv = o; *rv = o;
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_FetchUniformEdgeSample") DGL_REGISTER_GLOBAL("sampling._CAPI_FetchUniformEdgeSample")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
UniformEdgeSampler sampler = args[0]; UniformEdgeSampler sampler = args[0];
sampler->Fetch(rv); sampler->Fetch(rv);
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_ResetUniformEdgeSample") DGL_REGISTER_GLOBAL("sampling._CAPI_ResetUniformEdgeSample")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
UniformEdgeSampler sampler = args[0]; UniformEdgeSampler sampler = args[0];
sampler->Reset(); sampler->Reset();
}); });
template<typename ValueType> template <typename ValueType>
class WeightedEdgeSamplerObject: public EdgeSamplerObject { class WeightedEdgeSamplerObject : public EdgeSamplerObject {
public: public:
explicit WeightedEdgeSamplerObject(const GraphPtr gptr, explicit WeightedEdgeSamplerObject(
IdArray seed_edges, const GraphPtr gptr, IdArray seed_edges, NDArray edge_weight,
NDArray edge_weight, NDArray node_weight, const int64_t batch_size, const int64_t num_workers,
NDArray node_weight, const bool replacement, const bool reset, const std::string neg_mode,
const int64_t batch_size, const int64_t neg_sample_size, const int64_t chunk_size,
const int64_t num_workers, const bool exclude_positive, const bool check_false_neg,
const bool replacement,
const bool reset,
const std::string neg_mode,
const int64_t neg_sample_size,
const int64_t chunk_size,
const bool exclude_positive,
const bool check_false_neg,
IdArray relations) IdArray relations)
: EdgeSamplerObject(gptr, : EdgeSamplerObject(
seed_edges, gptr, seed_edges, batch_size, num_workers, replacement, reset,
batch_size, neg_mode, neg_sample_size, chunk_size, exclude_positive,
num_workers, check_false_neg, relations) {
replacement,
reset,
neg_mode,
neg_sample_size,
chunk_size,
exclude_positive,
check_false_neg,
relations) {
const int64_t num_edges = edge_weight->shape[0]; const int64_t num_edges = edge_weight->shape[0];
const ValueType *edge_prob = static_cast<const ValueType*>(edge_weight->data); const ValueType *edge_prob =
static_cast<const ValueType *>(edge_weight->data);
std::vector<ValueType> eprob(num_edges); std::vector<ValueType> eprob(num_edges);
for (int64_t i = 0; i < num_edges; ++i) { for (int64_t i = 0; i < num_edges; ++i) {
eprob[i] = edge_prob[i]; eprob[i] = edge_prob[i];
...@@ -1672,7 +1635,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject { ...@@ -1672,7 +1635,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
if (num_nodes == 0) { if (num_nodes == 0) {
node_selector_ = nullptr; node_selector_ = nullptr;
} else { } else {
const ValueType *node_prob = static_cast<const ValueType*>(node_weight->data); const ValueType *node_prob =
static_cast<const ValueType *>(node_weight->data);
std::vector<ValueType> nprob(num_nodes); std::vector<ValueType> nprob(num_nodes);
for (size_t i = 0; i < num_nodes; ++i) { for (size_t i = 0; i < num_nodes; ++i) {
nprob[i] = node_prob[i]; nprob[i] = node_prob[i];
...@@ -1687,27 +1651,26 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject { ...@@ -1687,27 +1651,26 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
gptr_->FindEdge(0); gptr_->FindEdge(0);
} }
~WeightedEdgeSamplerObject() { ~WeightedEdgeSamplerObject() {}
}
void Fetch(DGLRetValue* rv) { void Fetch(DGLRetValue *rv) {
const int64_t num_workers = std::min(num_workers_, max_batch_id_ - curr_batch_id_); const int64_t num_workers =
std::min(num_workers_, max_batch_id_ - curr_batch_id_);
// generate subgraphs. // generate subgraphs.
std::vector<SubgraphRef> positive_subgs(num_workers); std::vector<SubgraphRef> positive_subgs(num_workers);
std::vector<SubgraphRef> negative_subgs(num_workers); std::vector<SubgraphRef> negative_subgs(num_workers);
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < num_workers; i++) { for (int i = 0; i < num_workers; i++) {
const dgl_id_t *seed_edge_ids = static_cast<const dgl_id_t *>(seed_edges_->data); const dgl_id_t *seed_edge_ids =
static_cast<const dgl_id_t *>(seed_edges_->data);
std::vector<size_t> edge_ids(batch_size_); std::vector<size_t> edge_ids(batch_size_);
if (replacement_ == false) { if (replacement_ == false) {
size_t n = batch_size_; size_t n = batch_size_;
size_t num_ids = 0; size_t num_ids = 0;
#pragma omp critical #pragma omp critical
{ { num_ids = edge_selector_->SampleWithoutReplacement(n, &edge_ids); }
num_ids = edge_selector_->SampleWithoutReplacement(n, &edge_ids);
}
edge_ids.resize(num_ids); edge_ids.resize(num_ids);
for (size_t i = 0; i < num_ids; ++i) { for (size_t i = 0; i < num_ids; ++i) {
edge_ids[i] = seed_edge_ids[edge_ids[i]]; edge_ids[i] = seed_edge_ids[edge_ids[i]];
...@@ -1730,18 +1693,16 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject { ...@@ -1730,18 +1693,16 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
// TODO(zhengda) what if there are duplicates in the src and dst vectors. // TODO(zhengda) what if there are duplicates in the src and dst vectors.
Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false); Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);
positive_subgs[i] = ConvertRef(subg); positive_subgs[i] = ConvertRef(subg);
// For chunked negative sampling, we accept "chunk-head" for corrupting head // For chunked negative sampling, we accept "chunk-head" for corrupting
// nodes and "chunk-tail" for corrupting tail nodes. // head nodes and "chunk-tail" for corrupting tail nodes.
if (neg_mode_.substr(0, 5) == "chunk") { if (neg_mode_.substr(0, 5) == "chunk") {
NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(subg, neg_mode_.substr(6), NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(
neg_sample_size_, subg, neg_mode_.substr(6), neg_sample_size_, exclude_positive_,
exclude_positive_,
check_false_neg_); check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg); negative_subgs[i] = ConvertRef(neg_subg);
} else if (neg_mode_ == "head" || neg_mode_ == "tail") { } else if (neg_mode_ == "head" || neg_mode_ == "tail") {
NegSubgraph neg_subg = genNegEdgeSubgraph(subg, neg_mode_, NegSubgraph neg_subg = genNegEdgeSubgraph(
neg_sample_size_, subg, neg_mode_, neg_sample_size_, exclude_positive_,
exclude_positive_,
check_false_neg_); check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg); negative_subgs[i] = ConvertRef(neg_subg);
} }
...@@ -1753,7 +1714,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject { ...@@ -1753,7 +1714,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
} }
if (neg_mode_.size() > 0) { if (neg_mode_.size() > 0) {
positive_subgs.insert(positive_subgs.end(), negative_subgs.begin(), negative_subgs.end()); positive_subgs.insert(
positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
} }
*rv = List<SubgraphRef>(positive_subgs); *rv = List<SubgraphRef>(positive_subgs);
} }
...@@ -1762,7 +1724,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject { ...@@ -1762,7 +1724,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
curr_batch_id_ = 0; curr_batch_id_ = 0;
if (replacement_ == false) { if (replacement_ == false) {
const int64_t num_edges = edge_weight_->shape[0]; const int64_t num_edges = edge_weight_->shape[0];
const ValueType *edge_prob = static_cast<const ValueType*>(edge_weight_->data); const ValueType *edge_prob =
static_cast<const ValueType *>(edge_weight_->data);
std::vector<ValueType> eprob(num_edges); std::vector<ValueType> eprob(num_edges);
for (int64_t i = 0; i < num_edges; ++i) { for (int64_t i = 0; i < num_edges; ++i) {
eprob[i] = edge_prob[i]; eprob[i] = edge_prob[i];
...@@ -1775,8 +1738,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject { ...@@ -1775,8 +1738,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
DGL_DECLARE_OBJECT_TYPE_INFO(WeightedEdgeSamplerObject<ValueType>, Object); DGL_DECLARE_OBJECT_TYPE_INFO(WeightedEdgeSamplerObject<ValueType>, Object);
private: private:
void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) { void randomSample(size_t set_size, size_t num, std::vector<size_t> *out) {
if (num < set_size) { if (num < set_size) {
std::unordered_set<size_t> sampled_idxs; std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) { while (sampled_idxs.size() < num) {
...@@ -1792,13 +1755,13 @@ private: ...@@ -1792,13 +1755,13 @@ private:
} else { } else {
// If we need to sample all elements in the set, we don't need to // If we need to sample all elements in the set, we don't need to
// generate random numbers. // generate random numbers.
for (size_t i = 0; i < set_size; i++) for (size_t i = 0; i < set_size; i++) out->push_back(i);
out->push_back(i);
} }
} }
void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude, void randomSample(
std::vector<size_t>* out) { size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t> *out) {
std::unordered_map<size_t, int> sampled_idxs; std::unordered_map<size_t, int> sampled_idxs;
for (auto v : exclude) { for (auto v : exclude) {
sampled_idxs.insert(std::pair<size_t, int>(v, 0)); sampled_idxs.insert(std::pair<size_t, int>(v, 0));
...@@ -1830,7 +1793,7 @@ private: ...@@ -1830,7 +1793,7 @@ private:
} }
} }
private: private:
std::shared_ptr<ArrayHeap<ValueType>> edge_selector_; std::shared_ptr<ArrayHeap<ValueType>> edge_selector_;
std::shared_ptr<ArrayHeap<ValueType>> node_selector_; std::shared_ptr<ArrayHeap<ValueType>> node_selector_;
...@@ -1841,17 +1804,19 @@ private: ...@@ -1841,17 +1804,19 @@ private:
template class WeightedEdgeSamplerObject<float>; template class WeightedEdgeSamplerObject<float>;
class FloatWeightedEdgeSampler: public ObjectRef { class FloatWeightedEdgeSampler : public ObjectRef {
public: public:
FloatWeightedEdgeSampler() {} FloatWeightedEdgeSampler() {}
explicit FloatWeightedEdgeSampler(std::shared_ptr<runtime::Object> obj): ObjectRef(obj) {} explicit FloatWeightedEdgeSampler(std::shared_ptr<runtime::Object> obj)
: ObjectRef(obj) {}
WeightedEdgeSamplerObject<float>* operator->() const { WeightedEdgeSamplerObject<float> *operator->() const {
return static_cast<WeightedEdgeSamplerObject<float>*>(obj_.get()); return static_cast<WeightedEdgeSamplerObject<float> *>(obj_.get());
} }
std::shared_ptr<WeightedEdgeSamplerObject<float>> sptr() const { std::shared_ptr<WeightedEdgeSamplerObject<float>> sptr() const {
return CHECK_NOTNULL(std::dynamic_pointer_cast<WeightedEdgeSamplerObject<float>>(obj_)); return CHECK_NOTNULL(
std::dynamic_pointer_cast<WeightedEdgeSamplerObject<float>>(obj_));
} }
operator bool() const { return this->defined(); } operator bool() const { return this->defined(); }
...@@ -1859,7 +1824,7 @@ class FloatWeightedEdgeSampler: public ObjectRef { ...@@ -1859,7 +1824,7 @@ class FloatWeightedEdgeSampler: public ObjectRef {
}; };
DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler") DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments // arguments
GraphRef g = args[0]; GraphRef g = args[0];
IdArray seed_edges = args[1]; IdArray seed_edges = args[1];
...@@ -1881,13 +1846,17 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler") ...@@ -1881,13 +1846,17 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
CHECK(aten::IsValidIdArray(seed_edges)); CHECK(aten::IsValidIdArray(seed_edges));
CHECK_EQ(seed_edges->ctx.device_type, kDGLCPU) CHECK_EQ(seed_edges->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling"; << "WeightedEdgeSampler only support CPU sampling";
CHECK(edge_weight->dtype.code == kDGLFloat) << "edge_weight should be FloatType"; CHECK(edge_weight->dtype.code == kDGLFloat)
CHECK(edge_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight"; << "edge_weight should be FloatType";
CHECK(edge_weight->dtype.bits == 32)
<< "WeightedEdgeSampler only support float weight";
CHECK_EQ(edge_weight->ctx.device_type, kDGLCPU) CHECK_EQ(edge_weight->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling"; << "WeightedEdgeSampler only support CPU sampling";
if (node_weight->shape[0] > 0) { if (node_weight->shape[0] > 0) {
CHECK(node_weight->dtype.code == kDGLFloat) << "node_weight should be FloatType"; CHECK(node_weight->dtype.code == kDGLFloat)
CHECK(node_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight"; << "node_weight should be FloatType";
CHECK(node_weight->dtype.bits == 32)
<< "WeightedEdgeSampler only support float weight";
CHECK_EQ(node_weight->ctx.device_type, kDGLCPU) CHECK_EQ(node_weight->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling"; << "WeightedEdgeSampler only support CPU sampling";
} }
...@@ -1899,36 +1868,26 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler") ...@@ -1899,36 +1868,26 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
BuildCoo(*gptr); BuildCoo(*gptr);
const int64_t num_seeds = seed_edges->shape[0]; const int64_t num_seeds = seed_edges->shape[0];
const int64_t num_workers = std::min(max_num_workers, const int64_t num_workers =
(num_seeds + batch_size - 1) / batch_size); std::min(max_num_workers, (num_seeds + batch_size - 1) / batch_size);
auto o = std::make_shared<WeightedEdgeSamplerObject<float>>(gptr, auto o = std::make_shared<WeightedEdgeSamplerObject<float>>(
seed_edges, gptr, seed_edges, edge_weight, node_weight, batch_size, num_workers,
edge_weight, replacement, reset, neg_mode, neg_sample_size, chunk_size,
node_weight, exclude_positive, check_false_neg, relations);
batch_size,
num_workers,
replacement,
reset,
neg_mode,
neg_sample_size,
chunk_size,
exclude_positive,
check_false_neg,
relations);
*rv = o; *rv = o;
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_FetchWeightedEdgeSample") DGL_REGISTER_GLOBAL("sampling._CAPI_FetchWeightedEdgeSample")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
FloatWeightedEdgeSampler sampler = args[0]; FloatWeightedEdgeSampler sampler = args[0];
sampler->Fetch(rv); sampler->Fetch(rv);
}); });
DGL_REGISTER_GLOBAL("sampling._CAPI_ResetWeightedEdgeSample") DGL_REGISTER_GLOBAL("sampling._CAPI_ResetWeightedEdgeSample")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
FloatWeightedEdgeSampler sampler = args[0]; FloatWeightedEdgeSampler sampler = args[0];
sampler->Reset(); sampler->Reset();
}); });
} // namespace dgl } // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment