Unverified Commit 81831111 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4805)



* [Misc] clang-format auto fix.

* fix
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 16e771c0
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
* \brief Geometry operator CPU implementation * \brief Geometry operator CPU implementation
*/ */
#include <dgl/random.h> #include <dgl/random.h>
#include <numeric> #include <numeric>
#include <vector>
#include <utility> #include <utility>
#include <vector>
#include "../geometry_op.h" #include "../geometry_op.h"
namespace dgl { namespace dgl {
...@@ -25,73 +27,83 @@ void IndexShuffle(IdType *idxs, int64_t num_elems) { ...@@ -25,73 +27,83 @@ void IndexShuffle(IdType *idxs, int64_t num_elems) {
template void IndexShuffle<int32_t>(int32_t *idxs, int64_t num_elems); template void IndexShuffle<int32_t>(int32_t *idxs, int64_t num_elems);
template void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems); template void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems);
/*! \brief Groupwise index shuffle algorithm. This function will perform shuffle in subarrays /*! \brief Groupwise index shuffle algorithm. This function will perform shuffle
* indicated by group index. The group index is similar to indptr in CSRMatrix. * in subarrays indicated by group index. The group index is similar to indptr
* * in CSRMatrix.
*
* \param group_idxs group index array. * \param group_idxs group index array.
* \param idxs index array for shuffle. * \param idxs index array for shuffle.
* \param num_groups_idxs length of group_idxs * \param num_groups_idxs length of group_idxs
* \param num_elems length of idxs * \param num_elems length of idxs
*/ */
template <typename IdType> template <typename IdType>
void GroupIndexShuffle(const IdType *group_idxs, IdType *idxs, void GroupIndexShuffle(
int64_t num_groups_idxs, int64_t num_elems) { const IdType *group_idxs, IdType *idxs, int64_t num_groups_idxs,
int64_t num_elems) {
if (num_groups_idxs < 2) return; // empty idxs array if (num_groups_idxs < 2) return; // empty idxs array
CHECK_LE(group_idxs[num_groups_idxs - 1], num_elems) << "group_idxs out of range"; CHECK_LE(group_idxs[num_groups_idxs - 1], num_elems)
<< "group_idxs out of range";
for (int64_t i = 0; i < num_groups_idxs - 1; ++i) { for (int64_t i = 0; i < num_groups_idxs - 1; ++i) {
auto subarray_len = group_idxs[i + 1] - group_idxs[i]; auto subarray_len = group_idxs[i + 1] - group_idxs[i];
IndexShuffle(idxs + group_idxs[i], subarray_len); IndexShuffle(idxs + group_idxs[i], subarray_len);
} }
} }
template void GroupIndexShuffle<int32_t>( template void GroupIndexShuffle<int32_t>(
const int32_t *group_idxs, int32_t *idxs, int64_t num_groups_idxs, int64_t num_elems); const int32_t *group_idxs, int32_t *idxs, int64_t num_groups_idxs,
int64_t num_elems);
template void GroupIndexShuffle<int64_t>( template void GroupIndexShuffle<int64_t>(
const int64_t *group_idxs, int64_t *idxs, int64_t num_groups_idxs, int64_t num_elems); const int64_t *group_idxs, int64_t *idxs, int64_t num_groups_idxs,
int64_t num_elems);
template <typename IdType> template <typename IdType>
IdArray RandomPerm(int64_t num_nodes) { IdArray RandomPerm(int64_t num_nodes) {
IdArray perm = aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8); IdArray perm =
IdType* perm_data = static_cast<IdType*>(perm->data); aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType *perm_data = static_cast<IdType *>(perm->data);
std::iota(perm_data, perm_data + num_nodes, 0); std::iota(perm_data, perm_data + num_nodes, 0);
IndexShuffle(perm_data, num_nodes); IndexShuffle(perm_data, num_nodes);
return perm; return perm;
} }
template <typename IdType> template <typename IdType>
IdArray GroupRandomPerm(const IdType *group_idxs, int64_t num_group_idxs, int64_t num_nodes) { IdArray GroupRandomPerm(
IdArray perm = aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8); const IdType *group_idxs, int64_t num_group_idxs, int64_t num_nodes) {
IdType* perm_data = static_cast<IdType*>(perm->data); IdArray perm =
aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType *perm_data = static_cast<IdType *>(perm->data);
std::iota(perm_data, perm_data + num_nodes, 0); std::iota(perm_data, perm_data + num_nodes, 0);
GroupIndexShuffle(group_idxs, perm_data, num_group_idxs, num_nodes); GroupIndexShuffle(group_idxs, perm_data, num_group_idxs, num_nodes);
return perm; return perm;
} }
/*! /*!
* \brief Farthest Point Sampler without the need to compute all pairs of distance. * \brief Farthest Point Sampler without the need to compute all pairs of
* * distance.
* The input array has shape (N, d), where N is the number of points, and d is the dimension. *
* It consists of a (flatten) batch of point clouds. * The input array has shape (N, d), where N is the number of points, and d is
* the dimension. It consists of a (flatten) batch of point clouds.
* *
* In each batch, the algorithm starts with the sample index specified by ``start_idx``. * In each batch, the algorithm starts with the sample index specified by
* Then for each point, we maintain the minimum to-sample distance. * ``start_idx``. Then for each point, we maintain the minimum to-sample
* Finally, we pick the point with the maximum such distance. * distance. Finally, we pick the point with the maximum such distance. This
* This process will be repeated for ``sample_points`` - 1 times. * process will be repeated for ``sample_points`` - 1 times.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(
NDArray dist, IdArray start_idx, IdArray result) { NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
const FloatType* array_data = static_cast<FloatType*>(array->data); IdArray start_idx, IdArray result) {
const FloatType *array_data = static_cast<FloatType *>(array->data);
const int64_t point_in_batch = array->shape[0] / batch_size; const int64_t point_in_batch = array->shape[0] / batch_size;
const int64_t dim = array->shape[1]; const int64_t dim = array->shape[1];
// distance // distance
FloatType* dist_data = static_cast<FloatType*>(dist->data); FloatType *dist_data = static_cast<FloatType *>(dist->data);
// sample for each cloud in the batch // sample for each cloud in the batch
IdType* start_idx_data = static_cast<IdType*>(start_idx->data); IdType *start_idx_data = static_cast<IdType *>(start_idx->data);
// return value // return value
IdType* ret_data = static_cast<IdType*>(result->data); IdType *ret_data = static_cast<IdType *>(result->data);
int64_t array_start = 0, ret_start = 0; int64_t array_start = 0, ret_start = 0;
// loop for each point cloud sample in this batch // loop for each point cloud sample in this batch
...@@ -112,7 +124,7 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin ...@@ -112,7 +124,7 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
FloatType one_dist = 0; FloatType one_dist = 0;
for (auto d = 0; d < dim; d++) { for (auto d = 0; d < dim; d++) {
FloatType tmp = array_data[(array_start + j) * dim + d] - FloatType tmp = array_data[(array_start + j) * dim + d] -
array_data[(array_start + sample_idx) * dim + d]; array_data[(array_start + sample_idx) * dim + d];
one_dist += tmp * tmp; one_dist += tmp * tmp;
} }
...@@ -136,29 +148,30 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin ...@@ -136,29 +148,30 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
} }
} }
template void FarthestPointSampler<kDGLCPU, float, int32_t>( template void FarthestPointSampler<kDGLCPU, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, float, int64_t>( template void FarthestPointSampler<kDGLCPU, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, double, int32_t>( template void FarthestPointSampler<kDGLCPU, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, double, int64_t>( template void FarthestPointSampler<kDGLCPU, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) { void WeightedNeighborMatching(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
const int64_t num_nodes = result->shape[0]; const int64_t num_nodes = result->shape[0];
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType*>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data); IdType *result_data = static_cast<IdType *>(result->data);
FloatType *weight_data = static_cast<FloatType*>(weight->data); FloatType *weight_data = static_cast<FloatType *>(weight->data);
// build node visiting order // build node visiting order
IdArray vis_order = RandomPerm<IdType>(num_nodes); IdArray vis_order = RandomPerm<IdType>(num_nodes);
IdType *vis_order_data = static_cast<IdType*>(vis_order->data); IdType *vis_order_data = static_cast<IdType *>(vis_order->data);
for (int64_t n = 0; n < num_nodes; ++n) { for (int64_t n = 0; n < num_nodes; ++n) {
auto u = vis_order_data[n]; auto u = vis_order_data[n];
...@@ -193,16 +206,16 @@ template void WeightedNeighborMatching<kDGLCPU, double, int64_t>( ...@@ -193,16 +206,16 @@ template void WeightedNeighborMatching<kDGLCPU, double, int64_t>(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) { void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_nodes = result->shape[0]; const int64_t num_nodes = result->shape[0];
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data); const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType*>(csr.indices->data); const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data); IdType *result_data = static_cast<IdType *>(result->data);
// build vis order // build vis order
IdArray u_vis_order = RandomPerm<IdType>(num_nodes); IdArray u_vis_order = RandomPerm<IdType>(num_nodes);
IdType *u_vis_order_data = static_cast<IdType*>(u_vis_order->data); IdType *u_vis_order_data = static_cast<IdType *>(u_vis_order->data);
IdArray v_vis_order = GroupRandomPerm<IdType>( IdArray v_vis_order = GroupRandomPerm<IdType>(
indptr_data, csr.indptr->shape[0], csr.indices->shape[0]); indptr_data, csr.indptr->shape[0], csr.indices->shape[0]);
IdType *v_vis_order_data = static_cast<IdType*>(v_vis_order->data); IdType *v_vis_order_data = static_cast<IdType *>(v_vis_order->data);
for (int64_t n = 0; n < num_nodes; ++n) { for (int64_t n = 0; n < num_nodes; ++n) {
auto u = u_vis_order_data[n]; auto u = u_vis_order_data[n];
...@@ -221,8 +234,10 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) { ...@@ -221,8 +234,10 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
} }
} }
} }
template void NeighborMatching<kDGLCPU, int32_t>(const aten::CSRMatrix &csr, IdArray result); template void NeighborMatching<kDGLCPU, int32_t>(
template void NeighborMatching<kDGLCPU, int64_t>(const aten::CSRMatrix &csr, IdArray result); const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCPU, int64_t>(
const aten::CSRMatrix &csr, IdArray result);
} // namespace impl } // namespace impl
} // namespace geometry } // namespace geometry
......
...@@ -3,14 +3,16 @@ ...@@ -3,14 +3,16 @@
* \file geometry/cuda/edge_coarsening_impl.cu * \file geometry/cuda/edge_coarsening_impl.cu
* \brief Edge coarsening CUDA implementation * \brief Edge coarsening CUDA implementation
*/ */
#include <curand_kernel.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <curand_kernel.h>
#include <cstdint> #include <cstdint>
#include "../geometry_op.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../../array/cuda/utils.h" #include "../../array/cuda/utils.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../geometry_op.h"
#define BLOCKS(N, T) (N + T - 1) / T #define BLOCKS(N, T) (N + T - 1) / T
...@@ -26,7 +28,8 @@ constexpr int EMPTY_IDX = -1; ...@@ -26,7 +28,8 @@ constexpr int EMPTY_IDX = -1;
__device__ bool done_d; __device__ bool done_d;
__global__ void init_done_kernel() { done_d = true; } __global__ void init_done_kernel() { done_d = true; }
__global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t seed) { __global__ void generate_uniform_kernel(
float *ret_values, size_t num, uint64_t seed) {
size_t id = blockIdx.x * blockDim.x + threadIdx.x; size_t id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < num) { if (id < num) {
curandState state; curandState state;
...@@ -36,7 +39,8 @@ __global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t ...@@ -36,7 +39,8 @@ __global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t
} }
template <typename IdType> template <typename IdType>
__global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *result) { __global__ void colorize_kernel(
const float *prop, int64_t num_elem, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) { if (idx < num_elem) {
if (result[idx] < 0) { // if unmatched if (result[idx] < 0) { // if unmatched
...@@ -47,9 +51,9 @@ __global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *res ...@@ -47,9 +51,9 @@ __global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *res
} }
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indices, __global__ void weighted_propose_kernel(
const FloatType *weights, int64_t num_elem, const IdType *indptr, const IdType *indices, const FloatType *weights,
IdType *proposal, IdType *result) { int64_t num_elem, IdType *proposal, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) { if (idx < num_elem) {
if (result[idx] != BLUE) return; if (result[idx] != BLUE) return;
...@@ -61,8 +65,7 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi ...@@ -61,8 +65,7 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi
for (IdType i = indptr[idx]; i < indptr[idx + 1]; ++i) { for (IdType i = indptr[idx]; i < indptr[idx + 1]; ++i) {
auto v = indices[i]; auto v = indices[i];
if (result[v] < 0) if (result[v] < 0) has_unmatched_neighbor = true;
has_unmatched_neighbor = true;
if (result[v] == RED && weights[i] >= weight_max) { if (result[v] == RED && weights[i] >= weight_max) {
v_max = v; v_max = v;
weight_max = weights[i]; weight_max = weights[i];
...@@ -70,15 +73,14 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi ...@@ -70,15 +73,14 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi
} }
proposal[idx] = v_max; proposal[idx] = v_max;
if (!has_unmatched_neighbor) if (!has_unmatched_neighbor) result[idx] = idx;
result[idx] = idx;
} }
} }
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indices, __global__ void weighted_respond_kernel(
const FloatType *weights, int64_t num_elem, const IdType *indptr, const IdType *indices, const FloatType *weights,
IdType *proposal, IdType *result) { int64_t num_elem, IdType *proposal, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x; const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) { if (idx < num_elem) {
if (result[idx] != RED) return; if (result[idx] != RED) return;
...@@ -93,9 +95,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi ...@@ -93,9 +95,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
if (result[v] < 0) { if (result[v] < 0) {
has_unmatched_neighbors = true; has_unmatched_neighbors = true;
} }
if (result[v] == BLUE if (result[v] == BLUE && proposal[v] == idx && weights[i] >= weight_max) {
&& proposal[v] == idx
&& weights[i] >= weight_max) {
v_max = v; v_max = v;
weight_max = weights[i]; weight_max = weights[i];
} }
...@@ -105,8 +105,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi ...@@ -105,8 +105,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
result[idx] = min(idx, v_max); result[idx] = min(idx, v_max);
} }
if (!has_unmatched_neighbors) if (!has_unmatched_neighbors) result[idx] = idx;
result[idx] = idx;
} }
} }
...@@ -114,8 +113,8 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi ...@@ -114,8 +113,8 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
* nodes with BLUE(-1) and RED(-2) and checks whether the node matching * nodes with BLUE(-1) and RED(-2) and checks whether the node matching
* process has finished. * process has finished.
*/ */
template<typename IdType> template <typename IdType>
bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) { bool Colorize(IdType *result_data, int64_t num_nodes, float *const prop) {
// initial done signal // initial done signal
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, stream); CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, stream);
...@@ -124,21 +123,24 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) { ...@@ -124,21 +123,24 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX); uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_nodes); auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads)); auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream, CUDA_KERNEL_CALL(
prop, num_nodes, seed); generate_uniform_kernel, num_blocks, num_threads, 0, stream, prop,
num_nodes, seed);
// call kernel // call kernel
CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, stream, CUDA_KERNEL_CALL(
prop, num_nodes, result_data); colorize_kernel, num_blocks, num_threads, 0, stream, prop, num_nodes,
result_data);
bool done_h = false; bool done_h = false;
CUDA_CALL(cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost)); CUDA_CALL(cudaMemcpyFromSymbol(
&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost));
return done_h; return done_h;
} }
/*! \brief Weighted neighbor matching procedure (GPU version). /*! \brief Weighted neighbor matching procedure (GPU version).
* This implementation is from `A GPU Algorithm for Greedy Graph Matching * This implementation is from `A GPU Algorithm for Greedy Graph Matching
* <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`__ * <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`__
* *
* This algorithm has three parts: colorize, propose and respond. * This algorithm has three parts: colorize, propose and respond.
* In colorize procedure, each unmarked node will be marked as BLUE or * In colorize procedure, each unmarked node will be marked as BLUE or
* RED randomly. If all nodes are marked, finish and return. * RED randomly. If all nodes are marked, finish and return.
...@@ -151,9 +153,10 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) { ...@@ -151,9 +153,10 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
* pair and mark them with the smaller id between them. * pair and mark them with the smaller id between them.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) { void WeightedNeighborMatching(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = result->ctx; const auto &ctx = result->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
device->SetDevice(ctx); device->SetDevice(ctx);
...@@ -162,34 +165,38 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, ...@@ -162,34 +165,38 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, ctx); IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, ctx);
// get data ptrs // get data ptrs
IdType *indptr_data = static_cast<IdType*>(csr.indptr->data); IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
IdType *indices_data = static_cast<IdType*>(csr.indices->data); IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data); IdType *result_data = static_cast<IdType *>(result->data);
IdType *proposal_data = static_cast<IdType*>(proposal->data); IdType *proposal_data = static_cast<IdType *>(proposal->data);
FloatType *weight_data = static_cast<FloatType*>(weight->data); FloatType *weight_data = static_cast<FloatType *>(weight->data);
// allocate workspace for prop used in Colorize() // allocate workspace for prop used in Colorize()
float *prop = static_cast<float*>( float *prop = static_cast<float *>(
device->AllocWorkspace(ctx, num_nodes * sizeof(float))); device->AllocWorkspace(ctx, num_nodes * sizeof(float)));
auto num_threads = cuda::FindNumThreads(num_nodes); auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads)); auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
while (!Colorize<IdType>(result_data, num_nodes, prop)) { while (!Colorize<IdType>(result_data, num_nodes, prop)) {
CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, stream, CUDA_KERNEL_CALL(
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data); weighted_propose_kernel, num_blocks, num_threads, 0, stream,
CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, stream, indptr_data, indices_data, weight_data, num_nodes, proposal_data,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data); result_data);
CUDA_KERNEL_CALL(
weighted_respond_kernel, num_blocks, num_threads, 0, stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data,
result_data);
} }
device->FreeWorkspace(ctx, prop); device->FreeWorkspace(ctx, prop);
} }
template void WeightedNeighborMatching<kDGLCUDA, float, int32_t>( template void WeightedNeighborMatching<kDGLCUDA, float, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result); const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDGLCUDA, float, int64_t>( template void WeightedNeighborMatching<kDGLCUDA, float, int64_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result); const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDGLCUDA, double, int32_t>( template void WeightedNeighborMatching<kDGLCUDA, double, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result); const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>( template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result); const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
/*! \brief Unweighted neighbor matching procedure (GPU version). /*! \brief Unweighted neighbor matching procedure (GPU version).
* Instead of directly sample neighbors, we assign each neighbor * Instead of directly sample neighbors, we assign each neighbor
...@@ -204,25 +211,28 @@ template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>( ...@@ -204,25 +211,28 @@ template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>(
template <DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) { void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_edges = csr.indices->shape[0]; const int64_t num_edges = csr.indices->shape[0];
const auto& ctx = result->ctx; const auto &ctx = result->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
device->SetDevice(ctx); device->SetDevice(ctx);
// generate random weights // generate random weights
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
NDArray weight = NDArray::Empty( NDArray weight = NDArray::Empty(
{num_edges}, DGLDataType{kDGLFloat, sizeof(float) * 8, 1}, ctx); {num_edges}, DGLDataType{kDGLFloat, sizeof(float) * 8, 1}, ctx);
float *weight_data = static_cast<float*>(weight->data); float *weight_data = static_cast<float *>(weight->data);
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX); uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_edges); auto num_threads = cuda::FindNumThreads(num_edges);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads)); auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream, CUDA_KERNEL_CALL(
weight_data, num_edges, seed); generate_uniform_kernel, num_blocks, num_threads, 0, stream, weight_data,
num_edges, seed);
WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result); WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
} }
template void NeighborMatching<kDGLCUDA, int32_t>(const aten::CSRMatrix &csr, IdArray result); template void NeighborMatching<kDGLCUDA, int32_t>(
template void NeighborMatching<kDGLCUDA, int64_t>(const aten::CSRMatrix &csr, IdArray result); const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCUDA, int64_t>(
const aten::CSRMatrix &csr, IdArray result);
} // namespace impl } // namespace impl
} // namespace geometry } // namespace geometry
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../geometry_op.h" #include "../geometry_op.h"
#define THREADS 1024 #define THREADS 1024
...@@ -16,21 +16,23 @@ namespace geometry { ...@@ -16,21 +16,23 @@ namespace geometry {
namespace impl { namespace impl {
/*! /*!
* \brief Farthest Point Sampler without the need to compute all pairs of distance. * \brief Farthest Point Sampler without the need to compute all pairs of
* * distance.
* The input array has shape (N, d), where N is the number of points, and d is the dimension. *
* It consists of a (flatten) batch of point clouds. * The input array has shape (N, d), where N is the number of points, and d is
* the dimension. It consists of a (flatten) batch of point clouds.
* *
* In each batch, the algorithm starts with the sample index specified by ``start_idx``. * In each batch, the algorithm starts with the sample index specified by
* Then for each point, we maintain the minimum to-sample distance. * ``start_idx``. Then for each point, we maintain the minimum to-sample
* Finally, we pick the point with the maximum such distance. * distance. Finally, we pick the point with the maximum such distance. This
* This process will be repeated for ``sample_points`` - 1 times. * process will be repeated for ``sample_points`` - 1 times.
*/ */
template <typename FloatType, typename IdType> template <typename FloatType, typename IdType>
__global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size, __global__ void fps_kernel(
const int64_t sample_points, const int64_t point_in_batch, const FloatType* array_data, const int64_t batch_size,
const int64_t dim, const IdType *start_idx, const int64_t sample_points, const int64_t point_in_batch,
FloatType *dist_data, IdType *ret_data) { const int64_t dim, const IdType* start_idx, FloatType* dist_data,
IdType* ret_data) {
const int64_t thread_idx = threadIdx.x; const int64_t thread_idx = threadIdx.x;
const int64_t batch_idx = blockIdx.x; const int64_t batch_idx = blockIdx.x;
...@@ -59,7 +61,7 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size ...@@ -59,7 +61,7 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
FloatType one_dist = (FloatType)(0.); FloatType one_dist = (FloatType)(0.);
for (auto d = 0; d < dim; d++) { for (auto d = 0; d < dim; d++) {
FloatType tmp = array_data[(array_start + j) * dim + d] - FloatType tmp = array_data[(array_start + j) * dim + d] -
array_data[(array_start + sample_idx) * dim + d]; array_data[(array_start + sample_idx) * dim + d];
one_dist += tmp * tmp; one_dist += tmp * tmp;
} }
...@@ -79,10 +81,10 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size ...@@ -79,10 +81,10 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
FloatType best = dist_max_ht[0]; FloatType best = dist_max_ht[0];
int64_t best_idx = dist_argmax_ht[0]; int64_t best_idx = dist_argmax_ht[0];
for (auto j = 1; j < THREADS; j++) { for (auto j = 1; j < THREADS; j++) {
if (dist_max_ht[j] > best) { if (dist_max_ht[j] > best) {
best = dist_max_ht[j]; best = dist_max_ht[j];
best_idx = dist_argmax_ht[j]; best_idx = dist_argmax_ht[j];
} }
} }
ret_data[ret_start + i + 1] = (IdType)(best_idx); ret_data[ret_start + i + 1] = (IdType)(best_idx);
} }
...@@ -90,8 +92,9 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size ...@@ -90,8 +92,9 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
} }
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(
NDArray dist, IdArray start_idx, IdArray result) { NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const FloatType* array_data = static_cast<FloatType*>(array->data); const FloatType* array_data = static_cast<FloatType*>(array->data);
...@@ -109,24 +112,23 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin ...@@ -109,24 +112,23 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
IdType* start_idx_data = static_cast<IdType*>(start_idx->data); IdType* start_idx_data = static_cast<IdType*>(start_idx->data);
CUDA_CALL(cudaSetDevice(array->ctx.device_id)); CUDA_CALL(cudaSetDevice(array->ctx.device_id));
CUDA_KERNEL_CALL(fps_kernel, CUDA_KERNEL_CALL(
batch_size, THREADS, 0, stream, fps_kernel, batch_size, THREADS, 0, stream, array_data, batch_size,
array_data, batch_size, sample_points, sample_points, point_in_batch, dim, start_idx_data, dist_data, ret_data);
point_in_batch, dim, start_idx_data, dist_data, ret_data);
} }
template void FarthestPointSampler<kDGLCUDA, float, int32_t>( template void FarthestPointSampler<kDGLCUDA, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, float, int64_t>( template void FarthestPointSampler<kDGLCUDA, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, double, int32_t>( template void FarthestPointSampler<kDGLCUDA, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, double, int64_t>( template void FarthestPointSampler<kDGLCUDA, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
NDArray dist, IdArray start_idx, IdArray result); IdArray start_idx, IdArray result);
} // namespace impl } // namespace impl
} // namespace geometry } // namespace geometry
......
...@@ -4,94 +4,106 @@ ...@@ -4,94 +4,106 @@
* \brief DGL geometry utilities implementation * \brief DGL geometry utilities implementation
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/runtime/ndarray.h>
#include "../array/check.h"
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./geometry_op.h" #include "./geometry_op.h"
#include "../array/check.h"
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace geometry { namespace geometry {
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(
NDArray dist, IdArray start_idx, IdArray result) { NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result) {
CHECK_EQ(array->ctx, result->ctx) << "Array and the result should be on the same device."; CHECK_EQ(array->ctx, result->ctx)
CHECK_EQ(array->shape[0], dist->shape[0]) << "Shape of array and dist mismatch"; << "Array and the result should be on the same device.";
CHECK_EQ(start_idx->shape[0], batch_size) << "Shape of start_idx and batch_size mismatch"; CHECK_EQ(array->shape[0], dist->shape[0])
CHECK_EQ(result->shape[0], batch_size * sample_points) << "Invalid shape of result"; << "Shape of array and dist mismatch";
CHECK_EQ(start_idx->shape[0], batch_size)
<< "Shape of start_idx and batch_size mismatch";
CHECK_EQ(result->shape[0], batch_size * sample_points)
<< "Invalid shape of result";
ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", { ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, { ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "FarthestPointSampler", { ATEN_XPU_SWITCH_CUDA(
impl::FarthestPointSampler<XPU, FloatType, IdType>( array->ctx.device_type, XPU, "FarthestPointSampler", {
array, batch_size, sample_points, dist, start_idx, result); impl::FarthestPointSampler<XPU, FloatType, IdType>(
}); array, batch_size, sample_points, dist, start_idx, result);
});
}); });
}); });
} }
void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result) { void NeighborMatching(
HeteroGraphPtr graph, const NDArray weight, IdArray result) {
if (!aten::IsNullArray(weight)) { if (!aten::IsNullArray(weight)) {
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "NeighborMatching", { ATEN_XPU_SWITCH_CUDA(
ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", { graph->Context().device_type, XPU, "NeighborMatching", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", {
impl::WeightedNeighborMatching<XPU, FloatType, IdType>( ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
graph->GetCSRMatrix(0), weight, result); impl::WeightedNeighborMatching<XPU, FloatType, IdType>(
graph->GetCSRMatrix(0), weight, result);
});
});
}); });
});
});
} else { } else {
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "NeighborMatching", { ATEN_XPU_SWITCH_CUDA(
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { graph->Context().device_type, XPU, "NeighborMatching", {
impl::NeighborMatching<XPU, IdType>( ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
graph->GetCSRMatrix(0), result); impl::NeighborMatching<XPU, IdType>(graph->GetCSRMatrix(0), result);
}); });
}); });
} }
} }
///////////////////////// C APIs ///////////////////////// ///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler") DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const NDArray data = args[0]; const NDArray data = args[0];
const int64_t batch_size = args[1]; const int64_t batch_size = args[1];
const int64_t sample_points = args[2]; const int64_t sample_points = args[2];
NDArray dist = args[3]; NDArray dist = args[3];
IdArray start_idx = args[4]; IdArray start_idx = args[4];
IdArray result = args[5]; IdArray result = args[5];
FarthestPointSampler(data, batch_size, sample_points, dist, start_idx, result); FarthestPointSampler(
}); data, batch_size, sample_points, dist, start_idx, result);
});
DGL_REGISTER_GLOBAL("geometry._CAPI_NeighborMatching") DGL_REGISTER_GLOBAL("geometry._CAPI_NeighborMatching")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const NDArray weight = args[1]; const NDArray weight = args[1];
IdArray result = args[2]; IdArray result = args[2];
// sanity check // sanity check
aten::CheckCtx(graph->Context(), {weight, result}, {"edge_weight, result"}); aten::CheckCtx(
aten::CheckContiguous({weight, result}, {"edge_weight", "result"}); graph->Context(), {weight, result}, {"edge_weight, result"});
CHECK_EQ(graph->NumEdgeTypes(), 1) << "homogeneous graph has only one edge type"; aten::CheckContiguous({weight, result}, {"edge_weight", "result"});
CHECK_EQ(result->ndim, 1) << "result should be an 1D tensor."; CHECK_EQ(graph->NumEdgeTypes(), 1)
auto pair = graph->meta_graph()->FindEdge(0); << "homogeneous graph has only one edge type";
const dgl_type_t node_type = pair.first; CHECK_EQ(result->ndim, 1) << "result should be an 1D tensor.";
CHECK_EQ(graph->NumVertices(node_type), result->shape[0]) auto pair = graph->meta_graph()->FindEdge(0);
<< "The number of nodes should be the same as the length of result tensor."; const dgl_type_t node_type = pair.first;
if (!aten::IsNullArray(weight)) { CHECK_EQ(graph->NumVertices(node_type), result->shape[0])
CHECK_EQ(weight->ndim, 1) << "weight should be an 1D tensor."; << "The number of nodes should be the same as the length of result "
CHECK_EQ(graph->NumEdges(0), weight->shape[0]) "tensor.";
<< "number of edges in graph should be the same " if (!aten::IsNullArray(weight)) {
<< "as the length of edge weight tensor."; CHECK_EQ(weight->ndim, 1) << "weight should be an 1D tensor.";
} CHECK_EQ(graph->NumEdges(0), weight->shape[0])
<< "number of edges in graph should be the same "
<< "as the length of edge weight tensor.";
}
// call implementation // call implementation
NeighborMatching(graph.sptr(), weight, result); NeighborMatching(graph.sptr(), weight, result);
}); });
} // namespace geometry } // namespace geometry
} // namespace dgl } // namespace dgl
...@@ -13,16 +13,19 @@ namespace geometry { ...@@ -13,16 +13,19 @@ namespace geometry {
namespace impl { namespace impl {
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points, void FarthestPointSampler(
NDArray dist, IdArray start_idx, IdArray result); NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result);
/*! \brief Implementation of weighted neighbor matching process of edge coarsening used /*! \brief Implementation of weighted neighbor matching process of edge
* in Metis and Graclus for homogeneous graph coarsening. This procedure keeps * coarsening used in Metis and Graclus for homogeneous graph coarsening. This
* picking an unmarked vertex and matching it with one its unmarked neighbors * procedure keeps picking an unmarked vertex and matching it with one its
* (that maximizes its edge weight) until no match can be done. * unmarked neighbors (that maximizes its edge weight) until no match can be
* done.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result); void WeightedNeighborMatching(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
/*! \brief Implementation of neighbor matching process of edge coarsening used /*! \brief Implementation of neighbor matching process of edge coarsening used
* in Metis and Graclus for homogeneous graph coarsening. This procedure keeps * in Metis and Graclus for homogeneous graph coarsening. This procedure keeps
......
...@@ -10,60 +10,51 @@ namespace dgl { ...@@ -10,60 +10,51 @@ namespace dgl {
// creator implementation // creator implementation
HeteroGraphPtr CreateHeteroGraph( HeteroGraphPtr CreateHeteroGraph(
GraphPtr meta_graph, GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) { const std::vector<int64_t>& num_nodes_per_type) {
return HeteroGraphPtr(new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type)); return HeteroGraphPtr(
new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type));
} }
HeteroGraphPtr CreateFromCOO( HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
IdArray row, IdArray col, IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCOO( auto unit_g = UnitGraph::CreateFromCOO(
num_vtypes, num_src, num_dst, row, col, row_sorted, col_sorted, formats); num_vtypes, num_src, num_dst, row, col, row_sorted, col_sorted, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
HeteroGraphPtr CreateFromCOO( HeteroGraphPtr CreateFromCOO(
int64_t num_vtypes, const aten::COOMatrix& mat, int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, formats); auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
HeteroGraphPtr CreateFromCSR( HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSR( auto unit_g = UnitGraph::CreateFromCSR(
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats); num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
HeteroGraphPtr CreateFromCSR( HeteroGraphPtr CreateFromCSR(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, formats); auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, formats);
auto ret = HeteroGraphPtr(new HeteroGraph( auto ret = HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
unit_g->meta_graph(), return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
{unit_g}));
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(),
{unit_g}));
} }
HeteroGraphPtr CreateFromCSC( HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, int64_t num_src, int64_t num_dst, int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
IdArray indptr, IdArray indices, IdArray edge_ids, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSC( auto unit_g = UnitGraph::CreateFromCSC(
num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats); num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
HeteroGraphPtr CreateFromCSC( HeteroGraphPtr CreateFromCSC(
int64_t num_vtypes, const aten::CSRMatrix& mat, int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
dgl_format_code_t formats) {
auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, formats); auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, formats);
return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g})); return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
} }
......
...@@ -37,16 +37,18 @@ gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) { ...@@ -37,16 +37,18 @@ gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) {
size_t num_ptrs; size_t num_ptrs;
if (is_row) { if (is_row) {
num_ptrs = gk_csr->nrows + 1; num_ptrs = gk_csr->nrows + 1;
gk_indptr = gk_csr->rowptr = gk_zmalloc(gk_csr->nrows+1, gk_indptr = gk_csr->rowptr = gk_zmalloc(
const_cast<char*>("gk_csr_ExtractPartition: rowptr")); gk_csr->nrows + 1,
gk_indices = gk_csr->rowind = gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: rowptr"));
const_cast<char*>("gk_csr_ExtractPartition: rowind")); gk_indices = gk_csr->rowind =
gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: rowind"));
} else { } else {
num_ptrs = gk_csr->ncols + 1; num_ptrs = gk_csr->ncols + 1;
gk_indptr = gk_csr->colptr = gk_zmalloc(gk_csr->ncols+1, gk_indptr = gk_csr->colptr = gk_zmalloc(
const_cast<char*>("gk_csr_ExtractPartition: colptr")); gk_csr->ncols + 1,
gk_indices = gk_csr->colind = gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: colptr"));
const_cast<char*>("gk_csr_ExtractPartition: colind")); gk_indices = gk_csr->colind =
gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: colind"));
} }
for (size_t i = 0; i < num_ptrs; i++) { for (size_t i = 0; i < num_ptrs; i++) {
...@@ -98,7 +100,8 @@ aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) { ...@@ -98,7 +100,8 @@ aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) {
eids[i] = i; eids[i] = i;
} }
return aten::CSRMatrix(gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr); return aten::CSRMatrix(
gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr);
} }
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
......
...@@ -5,11 +5,13 @@ ...@@ -5,11 +5,13 @@
*/ */
#include <dgl/graph.h> #include <dgl/graph.h>
#include <dgl/sampler.h> #include <dgl/sampler.h>
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include <set>
#include <functional> #include <functional>
#include <set>
#include <tuple> #include <tuple>
#include <unordered_map>
#include "../c_api_common.h" #include "../c_api_common.h"
namespace dgl { namespace dgl {
...@@ -20,16 +22,16 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) { ...@@ -20,16 +22,16 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) {
this->AddVertices(num_nodes); this->AddVertices(num_nodes);
num_edges_ = src_ids->shape[0]; num_edges_ = src_ids->shape[0];
CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0]) CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])
<< "vectors in COO must have the same length"; << "vectors in COO must have the same length";
const dgl_id_t *src_data = static_cast<dgl_id_t*>(src_ids->data); const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t*>(dst_ids->data); const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data);
all_edges_src_.reserve(num_edges_); all_edges_src_.reserve(num_edges_);
all_edges_dst_.reserve(num_edges_); all_edges_dst_.reserve(num_edges_);
for (uint64_t i = 0; i < num_edges_; i++) { for (uint64_t i = 0; i < num_edges_; i++) {
auto src = src_data[i]; auto src = src_data[i];
auto dst = dst_data[i]; auto dst = dst_data[i];
CHECK(HasVertex(src) && HasVertex(dst)) CHECK(HasVertex(src) && HasVertex(dst))
<< "Invalid vertices: src=" << src << " dst=" << dst; << "Invalid vertices: src=" << src << " dst=" << dst;
adjlist_[src].succ.push_back(dst); adjlist_[src].succ.push_back(dst);
adjlist_[src].edge_id.push_back(i); adjlist_[src].edge_id.push_back(i);
...@@ -53,16 +55,16 @@ bool Graph::IsMultigraph() const { ...@@ -53,16 +55,16 @@ bool Graph::IsMultigraph() const {
pairs.emplace_back(all_edges_src_[eid], all_edges_dst_[eid]); pairs.emplace_back(all_edges_src_[eid], all_edges_dst_[eid]);
} }
// sort according to src and dst ids // sort according to src and dst ids
std::sort(pairs.begin(), pairs.end(), std::sort(pairs.begin(), pairs.end(), [](const Pair& t1, const Pair& t2) {
[] (const Pair& t1, const Pair& t2) { return std::get<0>(t1) < std::get<0>(t2) ||
return std::get<0>(t1) < std::get<0>(t2) (std::get<0>(t1) == std::get<0>(t2) &&
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2)); std::get<1>(t1) < std::get<1>(t2));
}); });
for (uint64_t eid = 0; eid < num_edges_-1; ++eid) { for (uint64_t eid = 0; eid < num_edges_ - 1; ++eid) {
// As src and dst are all sorted, we only need to compare i and i+1 // As src and dst are all sorted, we only need to compare i and i+1
if (std::get<0>(pairs[eid]) == std::get<0>(pairs[eid+1]) && if (std::get<0>(pairs[eid]) == std::get<0>(pairs[eid + 1]) &&
std::get<1>(pairs[eid]) == std::get<1>(pairs[eid+1])) std::get<1>(pairs[eid]) == std::get<1>(pairs[eid + 1]))
return true; return true;
} }
return false; return false;
...@@ -77,7 +79,7 @@ void Graph::AddVertices(uint64_t num_vertices) { ...@@ -77,7 +79,7 @@ void Graph::AddVertices(uint64_t num_vertices) {
void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) { void Graph::AddEdge(dgl_id_t src, dgl_id_t dst) {
CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed."; CHECK(!read_only_) << "Graph is read-only. Mutations are not allowed.";
CHECK(HasVertex(src) && HasVertex(dst)) CHECK(HasVertex(src) && HasVertex(dst))
<< "Invalid vertices: src=" << src << " dst=" << dst; << "Invalid vertices: src=" << src << " dst=" << dst;
dgl_id_t eid = num_edges_++; dgl_id_t eid = num_edges_++;
...@@ -125,7 +127,7 @@ BoolArray Graph::HasVertices(IdArray vids) const { ...@@ -125,7 +127,7 @@ BoolArray Graph::HasVertices(IdArray vids) const {
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t nverts = NumVertices(); const int64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts && vid_data[i] >= 0)? 1 : 0; rst_data[i] = (vid_data[i] < nverts && vid_data[i] >= 0) ? 1 : 0;
} }
return rst; return rst;
} }
...@@ -151,18 +153,18 @@ BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const { ...@@ -151,18 +153,18 @@ BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
if (srclen == 1) { if (srclen == 1) {
// one-many // one-many
for (int64_t i = 0; i < dstlen; ++i) { for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i]) ? 1 : 0;
} }
} else if (dstlen == 1) { } else if (dstlen == 1) {
// many-one // many-one
for (int64_t i = 0; i < srclen; ++i) { for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0]) ? 1 : 0;
} }
} else { } else {
// many-many // many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array."; CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) { for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i])? 1 : 0; rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i]) ? 1 : 0;
} }
} }
return rst; return rst;
...@@ -174,11 +176,11 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const { ...@@ -174,11 +176,11 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(radius >= 1) << "invalid radius: " << radius; CHECK(radius >= 1) << "invalid radius: " << radius;
std::set<dgl_id_t> vset; std::set<dgl_id_t> vset;
for (auto& it : reverse_adjlist_[vid].succ) for (auto& it : reverse_adjlist_[vid].succ) vset.insert(it);
vset.insert(it);
const int64_t len = vset.size(); const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(vset.begin(), vset.end(), rst_data); std::copy(vset.begin(), vset.end(), rst_data);
...@@ -191,11 +193,11 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const { ...@@ -191,11 +193,11 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(radius >= 1) << "invalid radius: " << radius; CHECK(radius >= 1) << "invalid radius: " << radius;
std::set<dgl_id_t> vset; std::set<dgl_id_t> vset;
for (auto& it : adjlist_[vid].succ) for (auto& it : adjlist_[vid].succ) vset.insert(it);
vset.insert(it);
const int64_t len = vset.size(); const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(vset.begin(), vset.end(), rst_data); std::copy(vset.begin(), vset.end(), rst_data);
...@@ -204,19 +206,20 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const { ...@@ -204,19 +206,20 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
// O(E) // O(E)
IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const { IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src) && HasVertex(dst)) << "invalid edge: " << src << " -> " << dst; CHECK(HasVertex(src) && HasVertex(dst))
<< "invalid edge: " << src << " -> " << dst;
const auto& succ = adjlist_[src].succ; const auto& succ = adjlist_[src].succ;
std::vector<dgl_id_t> edgelist; std::vector<dgl_id_t> edgelist;
for (size_t i = 0; i < succ.size(); ++i) { for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst) if (succ[i] == dst) edgelist.push_back(adjlist_[src].edge_id[i]);
edgelist.push_back(adjlist_[src].edge_id[i]);
} }
// FIXME: signed? Also it seems that we are using int64_t everywhere... // FIXME: signed? Also it seems that we are using int64_t everywhere...
const int64_t len = edgelist.size(); const int64_t len = edgelist.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
// FIXME: signed? // FIXME: signed?
int64_t* rst_data = static_cast<int64_t*>(rst->data); int64_t* rst_data = static_cast<int64_t*>(rst->data);
...@@ -234,7 +237,7 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -234,7 +237,7 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
int64_t i, j; int64_t i, j;
CHECK((srclen == dstlen) || (srclen == 1) || (dstlen == 1)) CHECK((srclen == dstlen) || (srclen == 1) || (dstlen == 1))
<< "Invalid src and dst id array."; << "Invalid src and dst id array.";
const int64_t src_stride = (srclen == 1 && dstlen != 1) ? 0 : 1; const int64_t src_stride = (srclen == 1 && dstlen != 1) ? 0 : 1;
const int64_t dst_stride = (dstlen == 1 && srclen != 1) ? 0 : 1; const int64_t dst_stride = (dstlen == 1 && srclen != 1) ? 0 : 1;
...@@ -243,10 +246,11 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const { ...@@ -243,10 +246,11 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
std::vector<dgl_id_t> src, dst, eid; std::vector<dgl_id_t> src, dst, eid;
for (i = 0, j = 0; i < srclen && j < dstlen; i += src_stride, j += dst_stride) { for (i = 0, j = 0; i < srclen && j < dstlen;
i += src_stride, j += dst_stride) {
const dgl_id_t src_id = src_data[i], dst_id = dst_data[j]; const dgl_id_t src_id = src_data[i], dst_id = dst_data[j];
CHECK(HasVertex(src_id) && HasVertex(dst_id)) << CHECK(HasVertex(src_id) && HasVertex(dst_id))
"invalid edge: " << src_id << " -> " << dst_id; << "invalid edge: " << src_id << " -> " << dst_id;
const auto& succ = adjlist_[src_id].succ; const auto& succ = adjlist_[src_id].succ;
for (size_t k = 0; k < succ.size(); ++k) { for (size_t k = 0; k < succ.size(); ++k) {
if (succ[k] == dst_id) { if (succ[k] == dst_id) {
...@@ -286,8 +290,7 @@ EdgeArray Graph::FindEdges(IdArray eids) const { ...@@ -286,8 +290,7 @@ EdgeArray Graph::FindEdges(IdArray eids) const {
for (uint64_t i = 0; i < (uint64_t)len; ++i) { for (uint64_t i = 0; i < (uint64_t)len; ++i) {
dgl_id_t eid = eid_data[i]; dgl_id_t eid = eid_data[i];
if (eid >= num_edges_) if (eid >= num_edges_) LOG(FATAL) << "invalid edge id:" << eid;
LOG(FATAL) << "invalid edge id:" << eid;
rst_src_data[i] = all_edges_src_[eid]; rst_src_data[i] = all_edges_src_[eid];
rst_dst_data[i] = all_edges_dst_[eid]; rst_dst_data[i] = all_edges_dst_[eid];
...@@ -301,9 +304,12 @@ EdgeArray Graph::FindEdges(IdArray eids) const { ...@@ -301,9 +304,12 @@ EdgeArray Graph::FindEdges(IdArray eids) const {
EdgeArray Graph::InEdges(dgl_id_t vid) const { EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size(); const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray src = IdArray::Empty(
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data); int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data); int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
...@@ -347,9 +353,12 @@ EdgeArray Graph::InEdges(IdArray vids) const { ...@@ -347,9 +353,12 @@ EdgeArray Graph::InEdges(IdArray vids) const {
EdgeArray Graph::OutEdges(dgl_id_t vid) const { EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid; CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size(); const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray src = IdArray::Empty(
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data); int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data); int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
...@@ -390,11 +399,14 @@ EdgeArray Graph::OutEdges(IdArray vids) const { ...@@ -390,11 +399,14 @@ EdgeArray Graph::OutEdges(IdArray vids) const {
} }
// O(E*log(E)) if sort is required; otherwise, O(E) // O(E*log(E)) if sort is required; otherwise, O(E)
EdgeArray Graph::Edges(const std::string &order) const { EdgeArray Graph::Edges(const std::string& order) const {
const int64_t len = num_edges_; const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray src = IdArray::Empty(
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); {len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
if (order == "srcdst") { if (order == "srcdst") {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple; typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
...@@ -404,10 +416,11 @@ EdgeArray Graph::Edges(const std::string &order) const { ...@@ -404,10 +416,11 @@ EdgeArray Graph::Edges(const std::string &order) const {
tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid); tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid);
} }
// sort according to src and dst ids // sort according to src and dst ids
std::sort(tuples.begin(), tuples.end(), std::sort(
[] (const Tuple& t1, const Tuple& t2) { tuples.begin(), tuples.end(), [](const Tuple& t1, const Tuple& t2) {
return std::get<0>(t1) < std::get<0>(t2) return std::get<0>(t1) < std::get<0>(t2) ||
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2)); (std::get<0>(t1) == std::get<0>(t2) &&
std::get<1>(t1) < std::get<1>(t2));
}); });
// make return arrays // make return arrays
...@@ -488,8 +501,11 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const { ...@@ -488,8 +501,11 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
} }
} }
} }
rst.induced_edges = IdArray::Empty({static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx); rst.induced_edges = IdArray::Empty(
std::copy(edges.begin(), edges.end(), static_cast<int64_t*>(rst.induced_edges->data)); {static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);
std::copy(
edges.begin(), edges.end(),
static_cast<int64_t*>(rst.induced_edges->data));
return rst; return rst;
} }
...@@ -524,7 +540,9 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -524,7 +540,9 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
rst.induced_vertices = IdArray::Empty( rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx); {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data)); std::copy(
nodes.begin(), nodes.end(),
static_cast<int64_t*>(rst.induced_vertices->data));
} else { } else {
rst.graph = std::make_shared<Graph>(); rst.graph = std::make_shared<Graph>();
rst.induced_edges = eids; rst.induced_edges = eids;
...@@ -536,59 +554,58 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const { ...@@ -536,59 +554,58 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
rst.graph->AddEdge(src_id, dst_id); rst.graph->AddEdge(src_id, dst_id);
} }
for (uint64_t i = 0; i < NumVertices(); ++i) for (uint64_t i = 0; i < NumVertices(); ++i) nodes.push_back(i);
nodes.push_back(i);
rst.induced_vertices = IdArray::Empty( rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx); {static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data)); std::copy(
nodes.begin(), nodes.end(),
static_cast<int64_t*>(rst.induced_vertices->data));
} }
return rst; return rst;
} }
std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const { std::vector<IdArray> Graph::GetAdj(
bool transpose, const std::string& fmt) const {
uint64_t num_edges = NumEdges(); uint64_t num_edges = NumEdges();
uint64_t num_nodes = NumVertices(); uint64_t num_nodes = NumVertices();
if (fmt == "coo") { if (fmt == "coo") {
IdArray idx = IdArray::Empty( IdArray idx = IdArray::Empty(
{2 * static_cast<int64_t>(num_edges)}, {2 * static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *idx_data = static_cast<int64_t*>(idx->data); int64_t* idx_data = static_cast<int64_t*>(idx->data);
if (transpose) { if (transpose) {
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data); std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data);
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data + num_edges); std::copy(
all_edges_dst_.begin(), all_edges_dst_.end(), idx_data + num_edges);
} else { } else {
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data); std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data);
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges); std::copy(
all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges);
} }
IdArray eid = IdArray::Empty( IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (uint64_t eid = 0; eid < num_edges; ++eid) { for (uint64_t eid = 0; eid < num_edges; ++eid) {
eid_data[eid] = eid; eid_data[eid] = eid;
} }
return std::vector<IdArray>{idx, eid}; return std::vector<IdArray>{idx, eid};
} else if (fmt == "csr") { } else if (fmt == "csr") {
IdArray indptr = IdArray::Empty( IdArray indptr = IdArray::Empty(
{static_cast<int64_t>(num_nodes) + 1}, {static_cast<int64_t>(num_nodes) + 1}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
IdArray indices = IdArray::Empty( IdArray indices = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty( IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)}, {static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}); DGLContext{kDGLCPU, 0});
int64_t *indptr_data = static_cast<int64_t*>(indptr->data); int64_t* indptr_data = static_cast<int64_t*>(indptr->data);
int64_t *indices_data = static_cast<int64_t*>(indices->data); int64_t* indices_data = static_cast<int64_t*>(indices->data);
int64_t *eid_data = static_cast<int64_t*>(eid->data); int64_t* eid_data = static_cast<int64_t*>(eid->data);
const AdjacencyList *adjlist; const AdjacencyList* adjlist;
if (transpose) { if (transpose) {
// Out-edges. // Out-edges.
adjlist = &adjlist_; adjlist = &adjlist_;
...@@ -599,10 +616,12 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const ...@@ -599,10 +616,12 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
indptr_data[0] = 0; indptr_data[0] = 0;
for (size_t i = 0; i < adjlist->size(); i++) { for (size_t i = 0; i < adjlist->size(); i++) {
indptr_data[i + 1] = indptr_data[i] + adjlist->at(i).succ.size(); indptr_data[i + 1] = indptr_data[i] + adjlist->at(i).succ.size();
std::copy(adjlist->at(i).succ.begin(), adjlist->at(i).succ.end(), std::copy(
indices_data + indptr_data[i]); adjlist->at(i).succ.begin(), adjlist->at(i).succ.end(),
std::copy(adjlist->at(i).edge_id.begin(), adjlist->at(i).edge_id.end(), indices_data + indptr_data[i]);
eid_data + indptr_data[i]); std::copy(
adjlist->at(i).edge_id.begin(), adjlist->at(i).edge_id.end(),
eid_data + indptr_data[i]);
} }
return std::vector<IdArray>{indptr, indices, eid}; return std::vector<IdArray>{indptr, indices, eid};
} else { } else {
......
...@@ -3,323 +3,325 @@ ...@@ -3,323 +3,325 @@
* \file graph/graph.cc * \file graph/graph.cc
* \brief DGL graph index APIs * \brief DGL graph index APIs
*/ */
#include <dgl/packed_func_ext.h>
#include <dgl/graph.h> #include <dgl/graph.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_op.h> #include <dgl/graph_op.h>
#include <dgl/sampler.h> #include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h> #include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>
#include <dgl/sampler.h>
#include "../c_api_common.h" #include "../c_api_common.h"
using dgl::runtime::DGLArgs; using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue; using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue; using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
using dgl::runtime::PackedFunc;
namespace dgl { namespace dgl {
///////////////////////////// Graph API /////////////////////////////////// ///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = GraphRef(Graph::Create()); *rv = GraphRef(Graph::Create());
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray src_ids = args[0]; const IdArray src_ids = args[0];
const IdArray dst_ids = args[1]; const IdArray dst_ids = args[1];
const int64_t num_nodes = args[2]; const int64_t num_nodes = args[2];
const bool readonly = args[3]; const bool readonly = args[3];
if (readonly) { if (readonly) {
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids)); *rv = GraphRef(
} else { ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));
*rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids)); } else {
} *rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids));
}); }
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray indptr = args[0]; const IdArray indptr = args[0];
const IdArray indices = args[1]; const IdArray indices = args[1];
const std::string edge_dir = args[2]; const std::string edge_dir = args[2];
IdArray edge_ids = IdArray::Empty({indices->shape[0]}, IdArray edge_ids = IdArray::Empty(
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0}); {indices->shape[0]}, DGLDataType{kDGLInt, 64, 1},
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data); DGLContext{kDGLCPU, 0});
for (int64_t i = 0; i < edge_ids->shape[0]; i++) int64_t* edge_data = static_cast<int64_t*>(edge_ids->data);
edge_data[i] = i; for (int64_t i = 0; i < edge_ids->shape[0]; i++) edge_data[i] = i;
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir)); *rv = GraphRef(
}); ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string shared_mem_name = args[0]; const std::string shared_mem_name = args[0];
*rv = GraphRef(ImmutableGraph::CreateFromCSR(shared_mem_name)); *rv = GraphRef(ImmutableGraph::CreateFromCSR(shared_mem_name));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
uint64_t num_vertices = args[1]; uint64_t num_vertices = args[1];
g->AddVertices(num_vertices); g->AddVertices(num_vertices);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
g->AddEdge(src, dst); g->AddEdge(src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
const IdArray dst = args[2]; const IdArray dst = args[2];
g->AddEdges(src, dst); g->AddEdges(src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
g->Clear(); g->Clear();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = g->IsMultigraph(); *rv = g->IsMultigraph();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsReadonly") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = g->IsReadonly(); *rv = g->IsReadonly();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = static_cast<int64_t>(g->NumVertices()); *rv = static_cast<int64_t>(g->NumVertices());
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = static_cast<int64_t>(g->NumEdges()); *rv = static_cast<int64_t>(g->NumEdges());
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = g->HasVertex(vid); *rv = g->HasVertex(vid);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = g->HasVertices(vids); *rv = g->HasVertices(vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
*rv = g->HasEdgeBetween(src, dst); *rv = g->HasEdgeBetween(src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
const IdArray dst = args[2]; const IdArray dst = args[2];
*rv = g->HasEdgesBetween(src, dst); *rv = g->HasEdgesBetween(src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
const uint64_t radius = args[2]; const uint64_t radius = args[2];
*rv = g->Predecessors(vid, radius); *rv = g->Predecessors(vid, radius);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
const uint64_t radius = args[2]; const uint64_t radius = args[2];
*rv = g->Successors(vid, radius); *rv = g->Successors(vid, radius);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t src = args[1]; const dgl_id_t src = args[1];
const dgl_id_t dst = args[2]; const dgl_id_t dst = args[2];
*rv = g->EdgeId(src, dst); *rv = g->EdgeId(src, dst);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
const IdArray dst = args[2]; const IdArray dst = args[2];
*rv = ConvertEdgeArrayToPackedFunc(g->EdgeIds(src, dst)); *rv = ConvertEdgeArrayToPackedFunc(g->EdgeIds(src, dst));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdge") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t eid = args[1]; const dgl_id_t eid = args[1];
const auto& pair = g->FindEdge(eid); const auto& pair = g->FindEdge(eid);
*rv = PackedFunc([pair] (DGLArgs args, DGLRetValue* rv) { *rv = PackedFunc([pair](DGLArgs args, DGLRetValue* rv) {
const int choice = args[0]; const int choice = args[0];
const int64_t ret = (choice == 0? pair.first : pair.second); const int64_t ret = (choice == 0 ? pair.first : pair.second);
*rv = ret; *rv = ret;
}); });
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray eids = args[1]; const IdArray eids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->FindEdges(eids)); *rv = ConvertEdgeArrayToPackedFunc(g->FindEdges(eids));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vid)); *rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vids)); *rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vids));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vid)); *rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vids)); *rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vids));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
std::string order = args[1]; std::string order = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->Edges(order)); *rv = ConvertEdgeArrayToPackedFunc(g->Edges(order));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(g->InDegree(vid)); *rv = static_cast<int64_t>(g->InDegree(vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = g->InDegrees(vids); *rv = g->InDegrees(vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const dgl_id_t vid = args[1]; const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(g->OutDegree(vid)); *rv = static_cast<int64_t>(g->OutDegree(vid));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
*rv = g->OutDegrees(vids); *rv = g->OutDegrees(vids);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray vids = args[1]; const IdArray vids = args[1];
std::shared_ptr<Subgraph> subg(new Subgraph(g->VertexSubgraph(vids))); std::shared_ptr<Subgraph> subg(new Subgraph(g->VertexSubgraph(vids)));
*rv = SubgraphRef(subg); *rv = SubgraphRef(subg);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
const IdArray eids = args[1]; const IdArray eids = args[1];
bool preserve_nodes = args[2]; bool preserve_nodes = args[2];
std::shared_ptr<Subgraph> subg( std::shared_ptr<Subgraph> subg(
new Subgraph(g->EdgeSubgraph(eids, preserve_nodes))); new Subgraph(g->EdgeSubgraph(eids, preserve_nodes)));
*rv = SubgraphRef(subg); *rv = SubgraphRef(subg);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
bool transpose = args[1]; bool transpose = args[1];
std::string format = args[2]; std::string format = args[2];
auto res = g->GetAdj(transpose, format); auto res = g->GetAdj(transpose, format);
*rv = ConvertNDArrayVectorToPackedFunc(res); *rv = ConvertNDArrayVectorToPackedFunc(res);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphContext") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = g->Context(); *rv = g->Context();
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
*rv = g->NumBits(); *rv = g->NumBits();
}); });
// Subgraph C APIs // Subgraph C APIs
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetGraph") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0]; SubgraphRef subg = args[0];
*rv = GraphRef(subg->graph); *rv = GraphRef(subg->graph);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0]; SubgraphRef subg = args[0];
*rv = subg->induced_vertices; *rv = subg->induced_vertices;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0]; SubgraphRef subg = args[0];
*rv = subg->induced_edges; *rv = subg->induced_edges;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSortAdj") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSortAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
g->SortCSR(); g->SortCSR();
}); });
} // namespace dgl } // namespace dgl
This diff is collapsed.
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <dgl/graph_traversal.h> #include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -13,95 +14,92 @@ namespace dgl { ...@@ -13,95 +14,92 @@ namespace dgl {
namespace traverse { namespace traverse {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
aten::CSRMatrix csr; aten::CSRMatrix csr;
if (reversed) { if (reversed) {
csr = g.sptr()->GetCSCMatrix(0); csr = g.sptr()->GetCSCMatrix(0);
} else { } else {
csr = g.sptr()->GetCSRMatrix(0); csr = g.sptr()->GetCSRMatrix(0);
} }
const auto& front = aten::BFSNodesFrontiers(csr, src); const auto& front = aten::BFSNodesFrontiers(csr, src);
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections}); *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
const IdArray src = args[1]; const IdArray src = args[1];
bool reversed = args[2]; bool reversed = args[2];
aten::CSRMatrix csr; aten::CSRMatrix csr;
if (reversed) { if (reversed) {
csr = g.sptr()->GetCSCMatrix(0); csr = g.sptr()->GetCSCMatrix(0);
} else { } else {
csr = g.sptr()->GetCSRMatrix(0); csr = g.sptr()->GetCSRMatrix(0);
} }
const auto& front = aten::BFSEdgesFrontiers(csr, src); const auto& front = aten::BFSEdgesFrontiers(csr, src);
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections}); *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
bool reversed = args[1]; bool reversed = args[1];
aten::CSRMatrix csr; aten::CSRMatrix csr;
if (reversed) { if (reversed) {
csr = g.sptr()->GetCSCMatrix(0); csr = g.sptr()->GetCSCMatrix(0);
} else { } else {
csr = g.sptr()->GetCSRMatrix(0); csr = g.sptr()->GetCSRMatrix(0);
} }
const auto& front = aten::TopologicalNodesFrontiers(csr);
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
});
const auto& front = aten::TopologicalNodesFrontiers(csr);
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
});
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array."; CHECK(aten::IsValidIdArray(source)) << "Invalid source node id array.";
aten::CSRMatrix csr; aten::CSRMatrix csr;
if (reversed) { if (reversed) {
csr = g.sptr()->GetCSCMatrix(0); csr = g.sptr()->GetCSCMatrix(0);
} else { } else {
csr = g.sptr()->GetCSRMatrix(0); csr = g.sptr()->GetCSRMatrix(0);
} }
const auto& front = aten::DGLDFSEdges(csr, source); const auto& front = aten::DGLDFSEdges(csr, source);
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections}); *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
}); });
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2") DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0]; HeteroGraphRef g = args[0];
const IdArray source = args[1]; const IdArray source = args[1];
const bool reversed = args[2]; const bool reversed = args[2];
const bool has_reverse_edge = args[3]; const bool has_reverse_edge = args[3];
const bool has_nontree_edge = args[4]; const bool has_nontree_edge = args[4];
const bool return_labels = args[5]; const bool return_labels = args[5];
aten::CSRMatrix csr; aten::CSRMatrix csr;
if (reversed) { if (reversed) {
csr = g.sptr()->GetCSCMatrix(0); csr = g.sptr()->GetCSCMatrix(0);
} else { } else {
csr = g.sptr()->GetCSRMatrix(0); csr = g.sptr()->GetCSRMatrix(0);
} }
const auto& front = aten::DGLDFSLabeledEdges(csr, const auto& front = aten::DGLDFSLabeledEdges(
source, csr, source, has_reverse_edge, has_nontree_edge, return_labels);
has_reverse_edge,
has_nontree_edge,
return_labels);
if (return_labels) { if (return_labels) {
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.tags, front.sections}); *rv = ConvertNDArrayVectorToPackedFunc(
} else { {front.ids, front.tags, front.sections});
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections}); } else {
} *rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
}); }
});
} // namespace traverse } // namespace traverse
} // namespace dgl } // namespace dgl
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
* \brief Call Metis partitioning * \brief Call Metis partitioning
*/ */
#include <metis.h>
#include <dgl/graph_op.h> #include <dgl/graph_op.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include <metis.h>
#include "../c_api_common.h" #include "../c_api_common.h"
using namespace dgl::runtime; using namespace dgl::runtime;
...@@ -24,13 +25,13 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) { ...@@ -24,13 +25,13 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
const auto mat = ig->GetInCSR()->ToCSRMatrix(); const auto mat = ig->GetInCSR()->ToCSRMatrix();
idx_t nvtxs = g->NumVertices(); idx_t nvtxs = g->NumVertices();
idx_t ncon = 1; // # balacing constraints. idx_t ncon = 1; // # balacing constraints.
idx_t *xadj = static_cast<idx_t*>(mat.indptr->data); idx_t *xadj = static_cast<idx_t *>(mat.indptr->data);
idx_t *adjncy = static_cast<idx_t*>(mat.indices->data); idx_t *adjncy = static_cast<idx_t *>(mat.indices->data);
idx_t nparts = k; idx_t nparts = k;
IdArray part_arr = aten::NewIdArray(nvtxs); IdArray part_arr = aten::NewIdArray(nvtxs);
idx_t objval = 0; idx_t objval = 0;
idx_t *part = static_cast<idx_t*>(part_arr->data); idx_t *part = static_cast<idx_t *>(part_arr->data);
int64_t vwgt_len = vwgt_arr->shape[0]; int64_t vwgt_len = vwgt_arr->shape[0];
CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8) CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8)
...@@ -40,7 +41,7 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) { ...@@ -40,7 +41,7 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
idx_t *vwgt = NULL; idx_t *vwgt = NULL;
if (vwgt_len > 0) { if (vwgt_len > 0) {
ncon = vwgt_len / g->NumVertices(); ncon = vwgt_len / g->NumVertices();
vwgt = static_cast<idx_t*>(vwgt_arr->data); vwgt = static_cast<idx_t *>(vwgt_arr->data);
} }
idx_t options[METIS_NOPTIONS]; idx_t options[METIS_NOPTIONS];
...@@ -56,21 +57,22 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) { ...@@ -56,21 +57,22 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_VOL; options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_VOL;
} }
int ret = METIS_PartGraphKway(&nvtxs, // The number of vertices int ret = METIS_PartGraphKway(
&ncon, // The number of balancing constraints. &nvtxs, // The number of vertices
xadj, // indptr &ncon, // The number of balancing constraints.
adjncy, // indices xadj, // indptr
vwgt, // the weights of the vertices adjncy, // indices
NULL, // The size of the vertices for computing vwgt, // the weights of the vertices
// the total communication volume NULL, // The size of the vertices for computing
NULL, // The weights of the edges // the total communication volume
&nparts, // The number of partitions. NULL, // The weights of the edges
NULL, // the desired weight for each partition and constraint &nparts, // The number of partitions.
NULL, // the allowed load imbalance tolerance NULL, // the desired weight for each partition and constraint
options, // the array of options NULL, // the allowed load imbalance tolerance
&objval, // the edge-cut or the total communication volume of options, // the array of options
// the partitioning solution &objval, // the edge-cut or the total communication volume of
part); // the partitioning solution
part);
if (obj_cut) { if (obj_cut) {
LOG(INFO) << "Partition a graph with " << g->NumVertices() << " nodes and " LOG(INFO) << "Partition a graph with " << g->NumVertices() << " nodes and "
...@@ -99,16 +101,16 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) { ...@@ -99,16 +101,16 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition") DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
int k = args[1]; int k = args[1];
NDArray vwgt = args[2]; NDArray vwgt = args[2];
bool obj_cut = args[3]; bool obj_cut = args[3];
#if !defined(_WIN32) #if !defined(_WIN32)
*rv = MetisPartition(g.sptr(), k, vwgt, obj_cut); *rv = MetisPartition(g.sptr(), k, vwgt, obj_cut);
#else #else
LOG(FATAL) << "Metis partition does not support Windows."; LOG(FATAL) << "Metis partition does not support Windows.";
#endif // !defined(_WIN32) #endif // !defined(_WIN32)
}); });
} // namespace dgl } // namespace dgl
This diff is collapsed.
...@@ -6,12 +6,12 @@ ...@@ -6,12 +6,12 @@
#ifndef DGL_GRAPH_NETWORK_H_ #ifndef DGL_GRAPH_NETWORK_H_
#define DGL_GRAPH_NETWORK_H_ #define DGL_GRAPH_NETWORK_H_
#include <dmlc/logging.h>
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <dmlc/logging.h>
#include <string.h> #include <string.h>
#include <vector>
#include <string> #include <string>
#include <vector>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "../rpc/network/msg_queue.h" #include "../rpc/network/msg_queue.h"
...@@ -24,10 +24,8 @@ namespace network { ...@@ -24,10 +24,8 @@ namespace network {
/*! /*!
* \brief Create NDArray from raw data * \brief Create NDArray from raw data
*/ */
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, NDArray CreateNDArrayFromRaw(
DGLDataType dtype, std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx, void* raw);
DGLContext ctx,
void* raw);
/*! /*!
* \brief Message type for DGL distributed training * \brief Message type for DGL distributed training
...@@ -63,19 +61,18 @@ enum MessageType { ...@@ -63,19 +61,18 @@ enum MessageType {
kBarrierMsg = 6, kBarrierMsg = 6,
/*! /*!
* \brief IP and ID msg for KVStore * \brief IP and ID msg for KVStore
*/ */
kIPIDMsg = 7, kIPIDMsg = 7,
/*! /*!
* \brief Get data shape msg for KVStore * \brief Get data shape msg for KVStore
*/ */
kGetShapeMsg = 8, kGetShapeMsg = 8,
/*! /*!
* \brief Get data shape back msg for KVStore * \brief Get data shape back msg for KVStore
*/ */
kGetShapeBackMsg = 9 kGetShapeBackMsg = 9
}; };
/*! /*!
* \brief Meta data for NDArray message * \brief Meta data for NDArray message
*/ */
...@@ -85,8 +82,7 @@ class ArrayMeta { ...@@ -85,8 +82,7 @@ class ArrayMeta {
* \brief ArrayMeta constructor. * \brief ArrayMeta constructor.
* \param msg_type type of message * \param msg_type type of message
*/ */
explicit ArrayMeta(int msg_type) explicit ArrayMeta(int msg_type) : msg_type_(msg_type), ndarray_count_(0) {}
: msg_type_(msg_type), ndarray_count_(0) {}
/*! /*!
* \brief Construct ArrayMeta from binary data buffer. * \brief Construct ArrayMeta from binary data buffer.
...@@ -101,16 +97,12 @@ class ArrayMeta { ...@@ -101,16 +97,12 @@ class ArrayMeta {
/*! /*!
* \return message type * \return message type
*/ */
inline int msg_type() const { inline int msg_type() const { return msg_type_; }
return msg_type_;
}
/*! /*!
* \return count of ndarray * \return count of ndarray
*/ */
inline int ndarray_count() const { inline int ndarray_count() const { return ndarray_count_; }
return ndarray_count_;
}
/*! /*!
* \brief Add NDArray meta data to ArrayMeta * \brief Add NDArray meta data to ArrayMeta
...@@ -148,8 +140,8 @@ class ArrayMeta { ...@@ -148,8 +140,8 @@ class ArrayMeta {
std::vector<DGLDataType> data_type_; std::vector<DGLDataType> data_type_;
/*! /*!
* \brief We first write the ndim to data_shape_ * \brief We first write the ndim to data_shape_
* and then write the data shape. * and then write the data shape.
*/ */
std::vector<int64_t> data_shape_; std::vector<int64_t> data_shape_;
}; };
...@@ -175,7 +167,7 @@ class KVStoreMsg { ...@@ -175,7 +167,7 @@ class KVStoreMsg {
} }
/*! /*!
* \brief Serialize KVStoreMsg to data buffer * \brief Serialize KVStoreMsg to data buffer
* Note that we don't serialize ID and data here. * Note that we don't serialize ID and data here.
* \param size size of serialized message * \param size size of serialized message
* \return pointer of data buffer * \return pointer of data buffer
*/ */
...@@ -188,29 +180,29 @@ class KVStoreMsg { ...@@ -188,29 +180,29 @@ class KVStoreMsg {
*/ */
void Deserialize(char* buffer, int64_t size); void Deserialize(char* buffer, int64_t size);
/*! /*!
* \brief Message type of kvstore * \brief Message type of kvstore
*/ */
int msg_type; int msg_type;
/*! /*!
* \brief Sender's ID * \brief Sender's ID
*/ */
int rank; int rank;
/*! /*!
* \brief data name * \brief data name
*/ */
std::string name; std::string name;
/*! /*!
* \brief data ID * \brief data ID
*/ */
NDArray id; NDArray id;
/*! /*!
* \brief data matrix * \brief data matrix
*/ */
NDArray data; NDArray data;
/*! /*!
* \brief data shape * \brief data shape
*/ */
NDArray shape; NDArray shape;
}; };
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
*/ */
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/nodeflow.h> #include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>
#include <string> #include <string>
...@@ -19,15 +19,16 @@ using dgl::runtime::PackedFunc; ...@@ -19,15 +19,16 @@ using dgl::runtime::PackedFunc;
namespace dgl { namespace dgl {
std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::string &fmt, std::vector<IdArray> GetNodeFlowSlice(
size_t layer0_size, size_t layer1_start, const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,
size_t layer1_end, bool remap) { size_t layer1_start, size_t layer1_end, bool remap) {
CHECK_GE(layer1_start, layer0_size); CHECK_GE(layer1_start, layer0_size);
if (fmt == std::string("csr")) { if (fmt == std::string("csr")) {
dgl_id_t first_vid = layer1_start - layer0_size; dgl_id_t first_vid = layer1_start - layer0_size;
auto csr = aten::CSRSliceRows(graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end); auto csr = aten::CSRSliceRows(
graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);
if (remap) { if (remap) {
dgl_id_t *eid_data = static_cast<dgl_id_t*>(csr.data->data); dgl_id_t *eid_data = static_cast<dgl_id_t *>(csr.data->data);
const dgl_id_t first_eid = eid_data[0]; const dgl_id_t first_eid = eid_data[0];
IdArray new_indices = aten::Sub(csr.indices, first_vid); IdArray new_indices = aten::Sub(csr.indices, first_vid);
IdArray new_data = aten::Sub(csr.data, first_eid); IdArray new_data = aten::Sub(csr.data, first_eid);
...@@ -37,14 +38,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st ...@@ -37,14 +38,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
} }
} else if (fmt == std::string("coo")) { } else if (fmt == std::string("coo")) {
auto csr = graph.GetInCSR()->ToCSRMatrix(); auto csr = graph.GetInCSR()->ToCSRMatrix();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(csr.indptr->data); const dgl_id_t *indptr = static_cast<dgl_id_t *>(csr.indptr->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(csr.indices->data); const dgl_id_t *indices = static_cast<dgl_id_t *>(csr.indices->data);
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(csr.data->data); const dgl_id_t *edge_ids = static_cast<dgl_id_t *>(csr.data->data);
int64_t nnz = indptr[layer1_end] - indptr[layer1_start]; int64_t nnz = indptr[layer1_end] - indptr[layer1_start];
IdArray idx = aten::NewIdArray(2 * nnz); IdArray idx = aten::NewIdArray(2 * nnz);
IdArray eid = aten::NewIdArray(nnz); IdArray eid = aten::NewIdArray(nnz);
int64_t *idx_data = static_cast<int64_t*>(idx->data); int64_t *idx_data = static_cast<int64_t *>(idx->data);
dgl_id_t *eid_data = static_cast<dgl_id_t*>(eid->data); dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
size_t num_edges = 0; size_t num_edges = 0;
for (size_t i = layer1_start; i < layer1_end; i++) { for (size_t i = layer1_start; i < layer1_end; i++) {
for (dgl_id_t j = indptr[i]; j < indptr[i + 1]; j++) { for (dgl_id_t j = indptr[i]; j < indptr[i + 1]; j++) {
...@@ -65,10 +66,12 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st ...@@ -65,10 +66,12 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
eid_data[i] = edge_ids[edge_start + i] - first_eid; eid_data[i] = edge_ids[edge_start + i] - first_eid;
} }
} else { } else {
std::copy(indices + indptr[layer1_start], std::copy(
indices + indptr[layer1_end], idx_data + nnz); indices + indptr[layer1_start], indices + indptr[layer1_end],
std::copy(edge_ids + indptr[layer1_start], idx_data + nnz);
edge_ids + indptr[layer1_end], eid_data); std::copy(
edge_ids + indptr[layer1_start], edge_ids + indptr[layer1_end],
eid_data);
} }
return std::vector<IdArray>{idx, eid}; return std::vector<IdArray>{idx, eid};
} else { } else {
...@@ -78,16 +81,17 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st ...@@ -78,16 +81,17 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
} }
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockAdj") DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0]; GraphRef g = args[0];
std::string format = args[1]; std::string format = args[1];
int64_t layer0_size = args[2]; int64_t layer0_size = args[2];
int64_t start = args[3]; int64_t start = args[3];
int64_t end = args[4]; int64_t end = args[4];
const bool remap = args[5]; const bool remap = args[5];
auto ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr())); auto ig =
auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap); CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ConvertNDArrayVectorToPackedFunc(res); auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap);
}); *rv = ConvertNDArrayVectorToPackedFunc(res);
});
} // namespace dgl } // namespace dgl
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment