Unverified commit 81831111, authored by Hongzhi (Steve) Chen, committed by GitHub
Browse files

[Misc] clang-format auto fix. (#4805)



* [Misc] clang-format auto fix.

* fix
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 16e771c0
......@@ -4,9 +4,11 @@
* \brief Geometry operator CPU implementation
*/
#include <dgl/random.h>
#include <numeric>
#include <vector>
#include <utility>
#include <vector>
#include "../geometry_op.h"
namespace dgl {
......@@ -25,8 +27,9 @@ void IndexShuffle(IdType *idxs, int64_t num_elems) {
template void IndexShuffle<int32_t>(int32_t *idxs, int64_t num_elems);
template void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems);
/*! \brief Groupwise index shuffle algorithm. This function will perform shuffle in subarrays
* indicated by group index. The group index is similar to indptr in CSRMatrix.
/*! \brief Groupwise index shuffle algorithm. This function will perform shuffle
* in subarrays indicated by group index. The group index is similar to indptr
* in CSRMatrix.
*
* \param group_idxs group index array.
* \param idxs index array for shuffle.
......@@ -34,64 +37,73 @@ template void IndexShuffle<int64_t>(int64_t *idxs, int64_t num_elems);
* \param num_elems length of idxs
*/
template <typename IdType>
void GroupIndexShuffle(const IdType *group_idxs, IdType *idxs,
int64_t num_groups_idxs, int64_t num_elems) {
void GroupIndexShuffle(
const IdType *group_idxs, IdType *idxs, int64_t num_groups_idxs,
int64_t num_elems) {
if (num_groups_idxs < 2) return; // empty idxs array
CHECK_LE(group_idxs[num_groups_idxs - 1], num_elems) << "group_idxs out of range";
CHECK_LE(group_idxs[num_groups_idxs - 1], num_elems)
<< "group_idxs out of range";
for (int64_t i = 0; i < num_groups_idxs - 1; ++i) {
auto subarray_len = group_idxs[i + 1] - group_idxs[i];
IndexShuffle(idxs + group_idxs[i], subarray_len);
}
}
// Explicit instantiations for the two supported index widths.
template void GroupIndexShuffle<int32_t>(
    const int32_t *group_idxs, int32_t *idxs, int64_t num_groups_idxs,
    int64_t num_elems);
template void GroupIndexShuffle<int64_t>(
    const int64_t *group_idxs, int64_t *idxs, int64_t num_groups_idxs,
    int64_t num_elems);
/*! \brief Return a uniformly shuffled permutation of [0, num_nodes) on CPU. */
template <typename IdType>
IdArray RandomPerm(int64_t num_nodes) {
  IdArray perm =
      aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
  IdType *perm_data = static_cast<IdType *>(perm->data);
  // Fill with the identity permutation, then shuffle in place.
  std::iota(perm_data, perm_data + num_nodes, 0);
  IndexShuffle(perm_data, num_nodes);
  return perm;
}
/*! \brief Return a permutation of [0, num_nodes) shuffled independently
 * within each group delimited by group_idxs (indptr-like boundaries).
 */
template <typename IdType>
IdArray GroupRandomPerm(
    const IdType *group_idxs, int64_t num_group_idxs, int64_t num_nodes) {
  IdArray perm =
      aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
  IdType *perm_data = static_cast<IdType *>(perm->data);
  // Start from the identity permutation, then shuffle each group subarray.
  std::iota(perm_data, perm_data + num_nodes, 0);
  GroupIndexShuffle(group_idxs, perm_data, num_group_idxs, num_nodes);
  return perm;
}
/*!
* \brief Farthest Point Sampler without the need to compute all pairs of distance.
* \brief Farthest Point Sampler without the need to compute all pairs of
* distance.
*
* The input array has shape (N, d), where N is the number of points, and d is the dimension.
* It consists of a (flatten) batch of point clouds.
* The input array has shape (N, d), where N is the number of points, and d is
* the dimension. It consists of a (flatten) batch of point clouds.
*
* In each batch, the algorithm starts with the sample index specified by ``start_idx``.
* Then for each point, we maintain the minimum to-sample distance.
* Finally, we pick the point with the maximum such distance.
* This process will be repeated for ``sample_points`` - 1 times.
* In each batch, the algorithm starts with the sample index specified by
* ``start_idx``. Then for each point, we maintain the minimum to-sample
* distance. Finally, we pick the point with the maximum such distance. This
* process will be repeated for ``sample_points`` - 1 times.
*/
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result) {
const FloatType* array_data = static_cast<FloatType*>(array->data);
void FarthestPointSampler(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result) {
const FloatType *array_data = static_cast<FloatType *>(array->data);
const int64_t point_in_batch = array->shape[0] / batch_size;
const int64_t dim = array->shape[1];
// distance
FloatType* dist_data = static_cast<FloatType*>(dist->data);
FloatType *dist_data = static_cast<FloatType *>(dist->data);
// sample for each cloud in the batch
IdType* start_idx_data = static_cast<IdType*>(start_idx->data);
IdType *start_idx_data = static_cast<IdType *>(start_idx->data);
// return value
IdType* ret_data = static_cast<IdType*>(result->data);
IdType *ret_data = static_cast<IdType *>(result->data);
int64_t array_start = 0, ret_start = 0;
// loop for each point cloud sample in this batch
......@@ -136,29 +148,30 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
}
}
// Explicit instantiations: float/double values crossed with 32/64-bit ids.
template void FarthestPointSampler<kDGLCPU, float, int32_t>(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, float, int64_t>(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, double, int32_t>(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCPU, double, int64_t>(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
void WeightedNeighborMatching(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
const int64_t num_nodes = result->shape[0];
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType*>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data);
FloatType *weight_data = static_cast<FloatType*>(weight->data);
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType *>(result->data);
FloatType *weight_data = static_cast<FloatType *>(weight->data);
// build node visiting order
IdArray vis_order = RandomPerm<IdType>(num_nodes);
IdType *vis_order_data = static_cast<IdType*>(vis_order->data);
IdType *vis_order_data = static_cast<IdType *>(vis_order->data);
for (int64_t n = 0; n < num_nodes; ++n) {
auto u = vis_order_data[n];
......@@ -193,16 +206,16 @@ template void WeightedNeighborMatching<kDGLCPU, double, int64_t>(
template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_nodes = result->shape[0];
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType*>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data);
const IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
const IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType *>(result->data);
// build vis order
IdArray u_vis_order = RandomPerm<IdType>(num_nodes);
IdType *u_vis_order_data = static_cast<IdType*>(u_vis_order->data);
IdType *u_vis_order_data = static_cast<IdType *>(u_vis_order->data);
IdArray v_vis_order = GroupRandomPerm<IdType>(
indptr_data, csr.indptr->shape[0], csr.indices->shape[0]);
IdType *v_vis_order_data = static_cast<IdType*>(v_vis_order->data);
IdType *v_vis_order_data = static_cast<IdType *>(v_vis_order->data);
for (int64_t n = 0; n < num_nodes; ++n) {
auto u = u_vis_order_data[n];
......@@ -221,8 +234,10 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
}
}
}
// Explicit instantiations for the two supported index widths.
template void NeighborMatching<kDGLCPU, int32_t>(
    const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCPU, int64_t>(
    const aten::CSRMatrix &csr, IdArray result);
} // namespace impl
} // namespace geometry
......
......@@ -3,14 +3,16 @@
* \file geometry/cuda/edge_coarsening_impl.cu
* \brief Edge coarsening CUDA implementation
*/
#include <curand_kernel.h>
#include <dgl/array.h>
#include <dgl/random.h>
#include <dmlc/thread_local.h>
#include <curand_kernel.h>
#include <cstdint>
#include "../geometry_op.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../../array/cuda/utils.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../geometry_op.h"
#define BLOCKS(N, T) (N + T - 1) / T
......@@ -26,7 +28,8 @@ constexpr int EMPTY_IDX = -1;
__device__ bool done_d;
__global__ void init_done_kernel() { done_d = true; }
__global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t seed) {
__global__ void generate_uniform_kernel(
float *ret_values, size_t num, uint64_t seed) {
size_t id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < num) {
curandState state;
......@@ -36,7 +39,8 @@ __global__ void generate_uniform_kernel(float *ret_values, size_t num, uint64_t
}
template <typename IdType>
__global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *result) {
__global__ void colorize_kernel(
const float *prop, int64_t num_elem, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) {
if (result[idx] < 0) { // if unmatched
......@@ -47,9 +51,9 @@ __global__ void colorize_kernel(const float *prop, int64_t num_elem, IdType *res
}
template <typename FloatType, typename IdType>
__global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indices,
const FloatType *weights, int64_t num_elem,
IdType *proposal, IdType *result) {
__global__ void weighted_propose_kernel(
const IdType *indptr, const IdType *indices, const FloatType *weights,
int64_t num_elem, IdType *proposal, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) {
if (result[idx] != BLUE) return;
......@@ -61,8 +65,7 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi
for (IdType i = indptr[idx]; i < indptr[idx + 1]; ++i) {
auto v = indices[i];
if (result[v] < 0)
has_unmatched_neighbor = true;
if (result[v] < 0) has_unmatched_neighbor = true;
if (result[v] == RED && weights[i] >= weight_max) {
v_max = v;
weight_max = weights[i];
......@@ -70,15 +73,14 @@ __global__ void weighted_propose_kernel(const IdType *indptr, const IdType *indi
}
proposal[idx] = v_max;
if (!has_unmatched_neighbor)
result[idx] = idx;
if (!has_unmatched_neighbor) result[idx] = idx;
}
}
template <typename FloatType, typename IdType>
__global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indices,
const FloatType *weights, int64_t num_elem,
IdType *proposal, IdType *result) {
__global__ void weighted_respond_kernel(
const IdType *indptr, const IdType *indices, const FloatType *weights,
int64_t num_elem, IdType *proposal, IdType *result) {
const IdType idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_elem) {
if (result[idx] != RED) return;
......@@ -93,9 +95,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
if (result[v] < 0) {
has_unmatched_neighbors = true;
}
if (result[v] == BLUE
&& proposal[v] == idx
&& weights[i] >= weight_max) {
if (result[v] == BLUE && proposal[v] == idx && weights[i] >= weight_max) {
v_max = v;
weight_max = weights[i];
}
......@@ -105,8 +105,7 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
result[idx] = min(idx, v_max);
}
if (!has_unmatched_neighbors)
result[idx] = idx;
if (!has_unmatched_neighbors) result[idx] = idx;
}
}
......@@ -114,8 +113,8 @@ __global__ void weighted_respond_kernel(const IdType *indptr, const IdType *indi
* nodes with BLUE(-1) and RED(-2) and checks whether the node matching
* process has finished.
*/
template<typename IdType>
bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
template <typename IdType>
bool Colorize(IdType *result_data, int64_t num_nodes, float *const prop) {
// initial done signal
cudaStream_t stream = runtime::getCurrentCUDAStream();
CUDA_KERNEL_CALL(init_done_kernel, 1, 1, 0, stream);
......@@ -124,14 +123,17 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream,
prop, num_nodes, seed);
CUDA_KERNEL_CALL(
generate_uniform_kernel, num_blocks, num_threads, 0, stream, prop,
num_nodes, seed);
// call kernel
CUDA_KERNEL_CALL(colorize_kernel, num_blocks, num_threads, 0, stream,
prop, num_nodes, result_data);
CUDA_KERNEL_CALL(
colorize_kernel, num_blocks, num_threads, 0, stream, prop, num_nodes,
result_data);
bool done_h = false;
CUDA_CALL(cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost));
CUDA_CALL(cudaMemcpyFromSymbol(
&done_h, done_d, sizeof(done_h), 0, cudaMemcpyDeviceToHost));
return done_h;
}
......@@ -151,9 +153,10 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
* pair and mark them with the smaller id between them.
*/
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
void WeightedNeighborMatching(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = result->ctx;
const auto &ctx = result->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
device->SetDevice(ctx);
......@@ -162,23 +165,27 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
IdArray proposal = aten::Full(-1, num_nodes, sizeof(IdType) * 8, ctx);
// get data ptrs
IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
IdType *indices_data = static_cast<IdType*>(csr.indices->data);
IdType *result_data = static_cast<IdType*>(result->data);
IdType *proposal_data = static_cast<IdType*>(proposal->data);
FloatType *weight_data = static_cast<FloatType*>(weight->data);
IdType *indptr_data = static_cast<IdType *>(csr.indptr->data);
IdType *indices_data = static_cast<IdType *>(csr.indices->data);
IdType *result_data = static_cast<IdType *>(result->data);
IdType *proposal_data = static_cast<IdType *>(proposal->data);
FloatType *weight_data = static_cast<FloatType *>(weight->data);
// allocate workspace for prop used in Colorize()
float *prop = static_cast<float*>(
float *prop = static_cast<float *>(
device->AllocWorkspace(ctx, num_nodes * sizeof(float)));
auto num_threads = cuda::FindNumThreads(num_nodes);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_nodes, num_threads));
while (!Colorize<IdType>(result_data, num_nodes, prop)) {
CUDA_KERNEL_CALL(weighted_propose_kernel, num_blocks, num_threads, 0, stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
CUDA_KERNEL_CALL(weighted_respond_kernel, num_blocks, num_threads, 0, stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data, result_data);
CUDA_KERNEL_CALL(
weighted_propose_kernel, num_blocks, num_threads, 0, stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data,
result_data);
CUDA_KERNEL_CALL(
weighted_respond_kernel, num_blocks, num_threads, 0, stream,
indptr_data, indices_data, weight_data, num_nodes, proposal_data,
result_data);
}
device->FreeWorkspace(ctx, prop);
}
......@@ -204,7 +211,7 @@ template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>(
template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_edges = csr.indices->shape[0];
const auto& ctx = result->ctx;
const auto &ctx = result->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
device->SetDevice(ctx);
......@@ -212,17 +219,20 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
NDArray weight = NDArray::Empty(
{num_edges}, DGLDataType{kDGLFloat, sizeof(float) * 8, 1}, ctx);
float *weight_data = static_cast<float*>(weight->data);
float *weight_data = static_cast<float *>(weight->data);
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_edges);
auto num_blocks = cuda::FindNumBlocks<'x'>(BLOCKS(num_edges, num_threads));
CUDA_KERNEL_CALL(generate_uniform_kernel, num_blocks, num_threads, 0, stream,
weight_data, num_edges, seed);
CUDA_KERNEL_CALL(
generate_uniform_kernel, num_blocks, num_threads, 0, stream, weight_data,
num_edges, seed);
WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
}
// Explicit instantiations for the two supported index widths.
template void NeighborMatching<kDGLCUDA, int32_t>(
    const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCUDA, int64_t>(
    const aten::CSRMatrix &csr, IdArray result);
} // namespace impl
} // namespace geometry
......
......@@ -5,8 +5,8 @@
*/
#include <dgl/array.h>
#include "../../runtime/cuda/cuda_common.h"
#include "../../c_api_common.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../geometry_op.h"
#define THREADS 1024
......@@ -16,21 +16,23 @@ namespace geometry {
namespace impl {
/*!
* \brief Farthest Point Sampler without the need to compute all pairs of distance.
* \brief Farthest Point Sampler without the need to compute all pairs of
* distance.
*
* The input array has shape (N, d), where N is the number of points, and d is the dimension.
* It consists of a (flatten) batch of point clouds.
* The input array has shape (N, d), where N is the number of points, and d is
* the dimension. It consists of a (flatten) batch of point clouds.
*
* In each batch, the algorithm starts with the sample index specified by ``start_idx``.
* Then for each point, we maintain the minimum to-sample distance.
* Finally, we pick the point with the maximum such distance.
* This process will be repeated for ``sample_points`` - 1 times.
* In each batch, the algorithm starts with the sample index specified by
* ``start_idx``. Then for each point, we maintain the minimum to-sample
* distance. Finally, we pick the point with the maximum such distance. This
* process will be repeated for ``sample_points`` - 1 times.
*/
template <typename FloatType, typename IdType>
__global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size,
__global__ void fps_kernel(
const FloatType* array_data, const int64_t batch_size,
const int64_t sample_points, const int64_t point_in_batch,
const int64_t dim, const IdType *start_idx,
FloatType *dist_data, IdType *ret_data) {
const int64_t dim, const IdType* start_idx, FloatType* dist_data,
IdType* ret_data) {
const int64_t thread_idx = threadIdx.x;
const int64_t batch_idx = blockIdx.x;
......@@ -90,8 +92,9 @@ __global__ void fps_kernel(const FloatType *array_data, const int64_t batch_size
}
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result) {
void FarthestPointSampler(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const FloatType* array_data = static_cast<FloatType*>(array->data);
......@@ -109,24 +112,23 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
IdType* start_idx_data = static_cast<IdType*>(start_idx->data);
CUDA_CALL(cudaSetDevice(array->ctx.device_id));
CUDA_KERNEL_CALL(fps_kernel,
batch_size, THREADS, 0, stream,
array_data, batch_size, sample_points,
point_in_batch, dim, start_idx_data, dist_data, ret_data);
CUDA_KERNEL_CALL(
fps_kernel, batch_size, THREADS, 0, stream, array_data, batch_size,
sample_points, point_in_batch, dim, start_idx_data, dist_data, ret_data);
}
// Explicit instantiations: float/double values crossed with 32/64-bit ids.
template void FarthestPointSampler<kDGLCUDA, float, int32_t>(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, float, int64_t>(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, double, int32_t>(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDGLCUDA, double, int64_t>(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
} // namespace impl
} // namespace geometry
......
......@@ -4,28 +4,34 @@
* \brief DGL geometry utilities implementation
*/
#include <dgl/array.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/ndarray.h>
#include "../array/check.h"
#include "../c_api_common.h"
#include "./geometry_op.h"
#include "../array/check.h"
using namespace dgl::runtime;
namespace dgl {
namespace geometry {
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result) {
CHECK_EQ(array->ctx, result->ctx) << "Array and the result should be on the same device.";
CHECK_EQ(array->shape[0], dist->shape[0]) << "Shape of array and dist mismatch";
CHECK_EQ(start_idx->shape[0], batch_size) << "Shape of start_idx and batch_size mismatch";
CHECK_EQ(result->shape[0], batch_size * sample_points) << "Invalid shape of result";
void FarthestPointSampler(
NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
IdArray start_idx, IdArray result) {
CHECK_EQ(array->ctx, result->ctx)
<< "Array and the result should be on the same device.";
CHECK_EQ(array->shape[0], dist->shape[0])
<< "Shape of array and dist mismatch";
CHECK_EQ(start_idx->shape[0], batch_size)
<< "Shape of start_idx and batch_size mismatch";
CHECK_EQ(result->shape[0], batch_size * sample_points)
<< "Invalid shape of result";
ATEN_FLOAT_TYPE_SWITCH(array->dtype, FloatType, "values", {
ATEN_ID_TYPE_SWITCH(result->dtype, IdType, {
ATEN_XPU_SWITCH_CUDA(array->ctx.device_type, XPU, "FarthestPointSampler", {
ATEN_XPU_SWITCH_CUDA(
array->ctx.device_type, XPU, "FarthestPointSampler", {
impl::FarthestPointSampler<XPU, FloatType, IdType>(
array, batch_size, sample_points, dist, start_idx, result);
});
......@@ -33,9 +39,11 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
});
}
void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result) {
void NeighborMatching(
HeteroGraphPtr graph, const NDArray weight, IdArray result) {
if (!aten::IsNullArray(weight)) {
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "NeighborMatching", {
ATEN_XPU_SWITCH_CUDA(
graph->Context().device_type, XPU, "NeighborMatching", {
ATEN_FLOAT_TYPE_SWITCH(weight->dtype, FloatType, "weight", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
impl::WeightedNeighborMatching<XPU, FloatType, IdType>(
......@@ -44,10 +52,10 @@ void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result
});
});
} else {
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "NeighborMatching", {
ATEN_XPU_SWITCH_CUDA(
graph->Context().device_type, XPU, "NeighborMatching", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
impl::NeighborMatching<XPU, IdType>(
graph->GetCSRMatrix(0), result);
impl::NeighborMatching<XPU, IdType>(graph->GetCSRMatrix(0), result);
});
});
}
......@@ -56,7 +64,7 @@ void NeighborMatching(HeteroGraphPtr graph, const NDArray weight, IdArray result
///////////////////////// C APIs /////////////////////////
DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const NDArray data = args[0];
const int64_t batch_size = args[1];
const int64_t sample_points = args[2];
......@@ -64,24 +72,28 @@ DGL_REGISTER_GLOBAL("geometry._CAPI_FarthestPointSampler")
IdArray start_idx = args[4];
IdArray result = args[5];
FarthestPointSampler(data, batch_size, sample_points, dist, start_idx, result);
FarthestPointSampler(
data, batch_size, sample_points, dist, start_idx, result);
});
DGL_REGISTER_GLOBAL("geometry._CAPI_NeighborMatching")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const NDArray weight = args[1];
IdArray result = args[2];
// sanity check
aten::CheckCtx(graph->Context(), {weight, result}, {"edge_weight, result"});
aten::CheckCtx(
graph->Context(), {weight, result}, {"edge_weight, result"});
aten::CheckContiguous({weight, result}, {"edge_weight", "result"});
CHECK_EQ(graph->NumEdgeTypes(), 1) << "homogeneous graph has only one edge type";
CHECK_EQ(graph->NumEdgeTypes(), 1)
<< "homogeneous graph has only one edge type";
CHECK_EQ(result->ndim, 1) << "result should be an 1D tensor.";
auto pair = graph->meta_graph()->FindEdge(0);
const dgl_type_t node_type = pair.first;
CHECK_EQ(graph->NumVertices(node_type), result->shape[0])
<< "The number of nodes should be the same as the length of result tensor.";
<< "The number of nodes should be the same as the length of result "
"tensor.";
if (!aten::IsNullArray(weight)) {
CHECK_EQ(weight->ndim, 1) << "weight should be an 1D tensor.";
CHECK_EQ(graph->NumEdges(0), weight->shape[0])
......
......@@ -13,16 +13,19 @@ namespace geometry {
namespace impl {
/*! \brief Device- and type-dispatched farthest point sampler entry point. */
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(
    NDArray array, int64_t batch_size, int64_t sample_points, NDArray dist,
    IdArray start_idx, IdArray result);
/*! \brief Implementation of the weighted neighbor matching process of edge
 * coarsening used in Metis and Graclus for homogeneous graph coarsening. This
 * procedure keeps picking an unmarked vertex and matching it with one of its
 * unmarked neighbors (the one that maximizes its edge weight) until no match
 * can be done.
 */
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(
    const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
/*! \brief Implementation of neighbor matching process of edge coarsening used
* in Metis and Graclus for homogeneous graph coarsening. This procedure keeps
......
......@@ -10,60 +10,51 @@ namespace dgl {
// creator implementation
/*! \brief Construct a heterograph from a metagraph and its relation graphs. */
HeteroGraphPtr CreateHeteroGraph(
    GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
    const std::vector<int64_t>& num_nodes_per_type) {
  return HeteroGraphPtr(
      new HeteroGraph(meta_graph, rel_graphs, num_nodes_per_type));
}
/*! \brief Build a (unit) heterograph from raw COO arrays. */
HeteroGraphPtr CreateFromCOO(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray row,
    IdArray col, bool row_sorted, bool col_sorted, dgl_format_code_t formats) {
  auto unit_g = UnitGraph::CreateFromCOO(
      num_vtypes, num_src, num_dst, row, col, row_sorted, col_sorted, formats);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
/*! \brief Build a (unit) heterograph from a COOMatrix. */
HeteroGraphPtr CreateFromCOO(
    int64_t num_vtypes, const aten::COOMatrix& mat, dgl_format_code_t formats) {
  auto unit_g = UnitGraph::CreateFromCOO(num_vtypes, mat, formats);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
/*! \brief Build a (unit) heterograph from raw CSR arrays. */
HeteroGraphPtr CreateFromCSR(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  auto unit_g = UnitGraph::CreateFromCSR(
      num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
/*! \brief Build a (unit) heterograph from a CSRMatrix.
 *
 * Bug fix: the previous body constructed the HeteroGraph twice — once into an
 * unused local `ret` and once more in the return statement — doing the whole
 * graph construction work twice. Construct it exactly once.
 */
HeteroGraphPtr CreateFromCSR(
    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
  auto unit_g = UnitGraph::CreateFromCSR(num_vtypes, mat, formats);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
/*! \brief Build a (unit) heterograph from raw CSC arrays. */
HeteroGraphPtr CreateFromCSC(
    int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
    IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
  auto unit_g = UnitGraph::CreateFromCSC(
      num_vtypes, num_src, num_dst, indptr, indices, edge_ids, formats);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
/*! \brief Build a (unit) heterograph from a CSC-ordered CSRMatrix. */
HeteroGraphPtr CreateFromCSC(
    int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats) {
  auto unit_g = UnitGraph::CreateFromCSC(num_vtypes, mat, formats);
  return HeteroGraphPtr(new HeteroGraph(unit_g->meta_graph(), {unit_g}));
}
......
......@@ -37,16 +37,18 @@ gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) {
size_t num_ptrs;
if (is_row) {
num_ptrs = gk_csr->nrows + 1;
gk_indptr = gk_csr->rowptr = gk_zmalloc(gk_csr->nrows+1,
const_cast<char*>("gk_csr_ExtractPartition: rowptr"));
gk_indices = gk_csr->rowind = gk_imalloc(nnz,
const_cast<char*>("gk_csr_ExtractPartition: rowind"));
gk_indptr = gk_csr->rowptr = gk_zmalloc(
gk_csr->nrows + 1,
const_cast<char *>("gk_csr_ExtractPartition: rowptr"));
gk_indices = gk_csr->rowind =
gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: rowind"));
} else {
num_ptrs = gk_csr->ncols + 1;
gk_indptr = gk_csr->colptr = gk_zmalloc(gk_csr->ncols+1,
const_cast<char*>("gk_csr_ExtractPartition: colptr"));
gk_indices = gk_csr->colind = gk_imalloc(nnz,
const_cast<char*>("gk_csr_ExtractPartition: colind"));
gk_indptr = gk_csr->colptr = gk_zmalloc(
gk_csr->ncols + 1,
const_cast<char *>("gk_csr_ExtractPartition: colptr"));
gk_indices = gk_csr->colind =
gk_imalloc(nnz, const_cast<char *>("gk_csr_ExtractPartition: colind"));
}
for (size_t i = 0; i < num_ptrs; i++) {
......@@ -98,7 +100,8 @@ aten::CSRMatrix Convert2DGLCsr(gk_csr_t *gk_csr, bool is_row) {
eids[i] = i;
}
return aten::CSRMatrix(gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr);
return aten::CSRMatrix(
gk_csr->nrows, gk_csr->ncols, indptr_arr, indices_arr, eids_arr);
}
#endif // !defined(_WIN32)
......
......@@ -5,11 +5,13 @@
*/
#include <dgl/graph.h>
#include <dgl/sampler.h>
#include <algorithm>
#include <unordered_map>
#include <set>
#include <functional>
#include <set>
#include <tuple>
#include <unordered_map>
#include "../c_api_common.h"
namespace dgl {
......@@ -21,8 +23,8 @@ Graph::Graph(IdArray src_ids, IdArray dst_ids, size_t num_nodes) {
num_edges_ = src_ids->shape[0];
CHECK(static_cast<int64_t>(num_edges_) == dst_ids->shape[0])
<< "vectors in COO must have the same length";
const dgl_id_t *src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t*>(dst_ids->data);
const dgl_id_t* src_data = static_cast<dgl_id_t*>(src_ids->data);
const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst_ids->data);
all_edges_src_.reserve(num_edges_);
all_edges_dst_.reserve(num_edges_);
for (uint64_t i = 0; i < num_edges_; i++) {
......@@ -53,15 +55,15 @@ bool Graph::IsMultigraph() const {
pairs.emplace_back(all_edges_src_[eid], all_edges_dst_[eid]);
}
// sort according to src and dst ids
std::sort(pairs.begin(), pairs.end(),
[] (const Pair& t1, const Pair& t2) {
return std::get<0>(t1) < std::get<0>(t2)
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
std::sort(pairs.begin(), pairs.end(), [](const Pair& t1, const Pair& t2) {
return std::get<0>(t1) < std::get<0>(t2) ||
(std::get<0>(t1) == std::get<0>(t2) &&
std::get<1>(t1) < std::get<1>(t2));
});
for (uint64_t eid = 0; eid < num_edges_-1; ++eid) {
for (uint64_t eid = 0; eid < num_edges_ - 1; ++eid) {
// As src and dst are all sorted, we only need to compare i and i+1
if (std::get<0>(pairs[eid]) == std::get<0>(pairs[eid+1]) &&
std::get<1>(pairs[eid]) == std::get<1>(pairs[eid+1]))
if (std::get<0>(pairs[eid]) == std::get<0>(pairs[eid + 1]) &&
std::get<1>(pairs[eid]) == std::get<1>(pairs[eid + 1]))
return true;
}
......@@ -125,7 +127,7 @@ BoolArray Graph::HasVertices(IdArray vids) const {
int64_t* rst_data = static_cast<int64_t*>(rst->data);
const int64_t nverts = NumVertices();
for (int64_t i = 0; i < len; ++i) {
rst_data[i] = (vid_data[i] < nverts && vid_data[i] >= 0)? 1 : 0;
rst_data[i] = (vid_data[i] < nverts && vid_data[i] >= 0) ? 1 : 0;
}
return rst;
}
......@@ -151,18 +153,18 @@ BoolArray Graph::HasEdgesBetween(IdArray src_ids, IdArray dst_ids) const {
if (srclen == 1) {
// one-many
for (int64_t i = 0; i < dstlen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i])? 1 : 0;
rst_data[i] = HasEdgeBetween(src_data[0], dst_data[i]) ? 1 : 0;
}
} else if (dstlen == 1) {
// many-one
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0])? 1 : 0;
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[0]) ? 1 : 0;
}
} else {
// many-many
CHECK(srclen == dstlen) << "Invalid src and dst id array.";
for (int64_t i = 0; i < srclen; ++i) {
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i])? 1 : 0;
rst_data[i] = HasEdgeBetween(src_data[i], dst_data[i]) ? 1 : 0;
}
}
return rst;
......@@ -174,11 +176,11 @@ IdArray Graph::Predecessors(dgl_id_t vid, uint64_t radius) const {
CHECK(radius >= 1) << "invalid radius: " << radius;
std::set<dgl_id_t> vset;
for (auto& it : reverse_adjlist_[vid].succ)
vset.insert(it);
for (auto& it : reverse_adjlist_[vid].succ) vset.insert(it);
const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(vset.begin(), vset.end(), rst_data);
......@@ -191,11 +193,11 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
CHECK(radius >= 1) << "invalid radius: " << radius;
std::set<dgl_id_t> vset;
for (auto& it : adjlist_[vid].succ)
vset.insert(it);
for (auto& it : adjlist_[vid].succ) vset.insert(it);
const int64_t len = vset.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* rst_data = static_cast<int64_t*>(rst->data);
std::copy(vset.begin(), vset.end(), rst_data);
......@@ -204,19 +206,20 @@ IdArray Graph::Successors(dgl_id_t vid, uint64_t radius) const {
// O(E)
IdArray Graph::EdgeId(dgl_id_t src, dgl_id_t dst) const {
CHECK(HasVertex(src) && HasVertex(dst)) << "invalid edge: " << src << " -> " << dst;
CHECK(HasVertex(src) && HasVertex(dst))
<< "invalid edge: " << src << " -> " << dst;
const auto& succ = adjlist_[src].succ;
std::vector<dgl_id_t> edgelist;
for (size_t i = 0; i < succ.size(); ++i) {
if (succ[i] == dst)
edgelist.push_back(adjlist_[src].edge_id[i]);
if (succ[i] == dst) edgelist.push_back(adjlist_[src].edge_id[i]);
}
// FIXME: signed? Also it seems that we are using int64_t everywhere...
const int64_t len = edgelist.size();
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
// FIXME: signed?
int64_t* rst_data = static_cast<int64_t*>(rst->data);
......@@ -243,10 +246,11 @@ EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
std::vector<dgl_id_t> src, dst, eid;
for (i = 0, j = 0; i < srclen && j < dstlen; i += src_stride, j += dst_stride) {
for (i = 0, j = 0; i < srclen && j < dstlen;
i += src_stride, j += dst_stride) {
const dgl_id_t src_id = src_data[i], dst_id = dst_data[j];
CHECK(HasVertex(src_id) && HasVertex(dst_id)) <<
"invalid edge: " << src_id << " -> " << dst_id;
CHECK(HasVertex(src_id) && HasVertex(dst_id))
<< "invalid edge: " << src_id << " -> " << dst_id;
const auto& succ = adjlist_[src_id].succ;
for (size_t k = 0; k < succ.size(); ++k) {
if (succ[k] == dst_id) {
......@@ -286,8 +290,7 @@ EdgeArray Graph::FindEdges(IdArray eids) const {
for (uint64_t i = 0; i < (uint64_t)len; ++i) {
dgl_id_t eid = eid_data[i];
if (eid >= num_edges_)
LOG(FATAL) << "invalid edge id:" << eid;
if (eid >= num_edges_) LOG(FATAL) << "invalid edge id:" << eid;
rst_src_data[i] = all_edges_src_[eid];
rst_dst_data[i] = all_edges_dst_[eid];
......@@ -301,9 +304,12 @@ EdgeArray Graph::FindEdges(IdArray eids) const {
EdgeArray Graph::InEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = reverse_adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray src = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
......@@ -347,9 +353,12 @@ EdgeArray Graph::InEdges(IdArray vids) const {
EdgeArray Graph::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
const int64_t len = adjlist_[vid].succ.size();
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray src = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* src_data = static_cast<int64_t*>(src->data);
int64_t* dst_data = static_cast<int64_t*>(dst->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
......@@ -390,11 +399,14 @@ EdgeArray Graph::OutEdges(IdArray vids) const {
}
// O(E*log(E)) if sort is required; otherwise, O(E)
EdgeArray Graph::Edges(const std::string &order) const {
EdgeArray Graph::Edges(const std::string& order) const {
const int64_t len = num_edges_;
IdArray src = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray src = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray dst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
if (order == "srcdst") {
typedef std::tuple<int64_t, int64_t, int64_t> Tuple;
......@@ -404,10 +416,11 @@ EdgeArray Graph::Edges(const std::string &order) const {
tuples.emplace_back(all_edges_src_[eid], all_edges_dst_[eid], eid);
}
// sort according to src and dst ids
std::sort(tuples.begin(), tuples.end(),
[] (const Tuple& t1, const Tuple& t2) {
return std::get<0>(t1) < std::get<0>(t2)
|| (std::get<0>(t1) == std::get<0>(t2) && std::get<1>(t1) < std::get<1>(t2));
std::sort(
tuples.begin(), tuples.end(), [](const Tuple& t1, const Tuple& t2) {
return std::get<0>(t1) < std::get<0>(t2) ||
(std::get<0>(t1) == std::get<0>(t2) &&
std::get<1>(t1) < std::get<1>(t2));
});
// make return arrays
......@@ -488,8 +501,11 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
}
}
}
rst.induced_edges = IdArray::Empty({static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);
std::copy(edges.begin(), edges.end(), static_cast<int64_t*>(rst.induced_edges->data));
rst.induced_edges = IdArray::Empty(
{static_cast<int64_t>(edges.size())}, vids->dtype, vids->ctx);
std::copy(
edges.begin(), edges.end(),
static_cast<int64_t*>(rst.induced_edges->data));
return rst;
}
......@@ -524,7 +540,9 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
std::copy(
nodes.begin(), nodes.end(),
static_cast<int64_t*>(rst.induced_vertices->data));
} else {
rst.graph = std::make_shared<Graph>();
rst.induced_edges = eids;
......@@ -536,59 +554,58 @@ Subgraph Graph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
rst.graph->AddEdge(src_id, dst_id);
}
for (uint64_t i = 0; i < NumVertices(); ++i)
nodes.push_back(i);
for (uint64_t i = 0; i < NumVertices(); ++i) nodes.push_back(i);
rst.induced_vertices = IdArray::Empty(
{static_cast<int64_t>(nodes.size())}, eids->dtype, eids->ctx);
std::copy(nodes.begin(), nodes.end(), static_cast<int64_t*>(rst.induced_vertices->data));
std::copy(
nodes.begin(), nodes.end(),
static_cast<int64_t*>(rst.induced_vertices->data));
}
return rst;
}
std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const {
std::vector<IdArray> Graph::GetAdj(
bool transpose, const std::string& fmt) const {
uint64_t num_edges = NumEdges();
uint64_t num_nodes = NumVertices();
if (fmt == "coo") {
IdArray idx = IdArray::Empty(
{2 * static_cast<int64_t>(num_edges)},
DGLDataType{kDGLInt, 64, 1},
{2 * static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
int64_t *idx_data = static_cast<int64_t*>(idx->data);
int64_t* idx_data = static_cast<int64_t*>(idx->data);
if (transpose) {
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data);
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data + num_edges);
std::copy(
all_edges_dst_.begin(), all_edges_dst_.end(), idx_data + num_edges);
} else {
std::copy(all_edges_dst_.begin(), all_edges_dst_.end(), idx_data);
std::copy(all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges);
std::copy(
all_edges_src_.begin(), all_edges_src_.end(), idx_data + num_edges);
}
IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DGLDataType{kDGLInt, 64, 1},
{static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
int64_t *eid_data = static_cast<int64_t*>(eid->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
for (uint64_t eid = 0; eid < num_edges; ++eid) {
eid_data[eid] = eid;
}
return std::vector<IdArray>{idx, eid};
} else if (fmt == "csr") {
IdArray indptr = IdArray::Empty(
{static_cast<int64_t>(num_nodes) + 1},
DGLDataType{kDGLInt, 64, 1},
{static_cast<int64_t>(num_nodes) + 1}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
IdArray indices = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DGLDataType{kDGLInt, 64, 1},
{static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
IdArray eid = IdArray::Empty(
{static_cast<int64_t>(num_edges)},
DGLDataType{kDGLInt, 64, 1},
{static_cast<int64_t>(num_edges)}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
int64_t *indptr_data = static_cast<int64_t*>(indptr->data);
int64_t *indices_data = static_cast<int64_t*>(indices->data);
int64_t *eid_data = static_cast<int64_t*>(eid->data);
const AdjacencyList *adjlist;
int64_t* indptr_data = static_cast<int64_t*>(indptr->data);
int64_t* indices_data = static_cast<int64_t*>(indices->data);
int64_t* eid_data = static_cast<int64_t*>(eid->data);
const AdjacencyList* adjlist;
if (transpose) {
// Out-edges.
adjlist = &adjlist_;
......@@ -599,9 +616,11 @@ std::vector<IdArray> Graph::GetAdj(bool transpose, const std::string &fmt) const
indptr_data[0] = 0;
for (size_t i = 0; i < adjlist->size(); i++) {
indptr_data[i + 1] = indptr_data[i] + adjlist->at(i).succ.size();
std::copy(adjlist->at(i).succ.begin(), adjlist->at(i).succ.end(),
std::copy(
adjlist->at(i).succ.begin(), adjlist->at(i).succ.end(),
indices_data + indptr_data[i]);
std::copy(adjlist->at(i).edge_id.begin(), adjlist->at(i).edge_id.end(),
std::copy(
adjlist->at(i).edge_id.begin(), adjlist->at(i).edge_id.end(),
eid_data + indptr_data[i]);
}
return std::vector<IdArray>{indptr, indices, eid};
......
......@@ -3,72 +3,74 @@
* \file graph/graph.cc
* \brief DGL graph index APIs
*/
#include <dgl/packed_func_ext.h>
#include <dgl/graph.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_op.h>
#include <dgl/sampler.h>
#include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>
#include <dgl/sampler.h>
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
using dgl::runtime::PackedFunc;
namespace dgl {
///////////////////////////// Graph API ///////////////////////////////////
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreateMutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
*rv = GraphRef(Graph::Create());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray src_ids = args[0];
const IdArray dst_ids = args[1];
const int64_t num_nodes = args[2];
const bool readonly = args[3];
if (readonly) {
*rv = GraphRef(ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));
*rv = GraphRef(
ImmutableGraph::CreateFromCOO(num_nodes, src_ids, dst_ids));
} else {
*rv = GraphRef(Graph::CreateFromCOO(num_nodes, src_ids, dst_ids));
}
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray indptr = args[0];
const IdArray indices = args[1];
const std::string edge_dir = args[2];
IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (int64_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i;
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
IdArray edge_ids = IdArray::Empty(
{indices->shape[0]}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
int64_t* edge_data = static_cast<int64_t*>(edge_ids->data);
for (int64_t i = 0; i < edge_ids->shape[0]; i++) edge_data[i] = i;
*rv = GraphRef(
ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string shared_mem_name = args[0];
*rv = GraphRef(ImmutableGraph::CreateFromCSR(shared_mem_name));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
uint64_t num_vertices = args[1];
g->AddVertices(num_vertices);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
......@@ -76,7 +78,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdge")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray src = args[1];
const IdArray dst = args[2];
......@@ -84,51 +86,51 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddEdges")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
g->Clear();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = g->IsMultigraph();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = g->IsReadonly();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = static_cast<int64_t>(g->NumVertices());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = static_cast<int64_t>(g->NumEdges());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = g->HasVertex(vid);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = g->HasVertices(vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
......@@ -136,7 +138,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgeBetween")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray src = args[1];
const IdArray dst = args[2];
......@@ -144,7 +146,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphHasEdgesBetween")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
......@@ -152,7 +154,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphPredecessors")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t vid = args[1];
const uint64_t radius = args[2];
......@@ -160,7 +162,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphSuccessors")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t src = args[1];
const dgl_id_t dst = args[2];
......@@ -168,7 +170,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeId")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray src = args[1];
const IdArray dst = args[2];
......@@ -176,89 +178,89 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeIds")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t eid = args[1];
const auto& pair = g->FindEdge(eid);
*rv = PackedFunc([pair] (DGLArgs args, DGLRetValue* rv) {
*rv = PackedFunc([pair](DGLArgs args, DGLRetValue* rv) {
const int choice = args[0];
const int64_t ret = (choice == 0? pair.first : pair.second);
const int64_t ret = (choice == 0 ? pair.first : pair.second);
*rv = ret;
});
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray eids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->FindEdges(eids));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->InEdges(vids));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->OutEdges(vids));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
std::string order = args[1];
*rv = ConvertEdgeArrayToPackedFunc(g->Edges(order));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(g->InDegree(vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = g->InDegrees(vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const dgl_id_t vid = args[1];
*rv = static_cast<int64_t>(g->OutDegree(vid));
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray vids = args[1];
*rv = g->OutDegrees(vids);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray vids = args[1];
std::shared_ptr<Subgraph> subg(new Subgraph(g->VertexSubgraph(vids)));
......@@ -266,7 +268,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphVertexSubgraph")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray eids = args[1];
bool preserve_nodes = args[2];
......@@ -276,7 +278,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphEdgeSubgraph")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
bool transpose = args[1];
std::string format = args[2];
......@@ -285,13 +287,13 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphGetAdj")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = g->Context();
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = g->NumBits();
});
......@@ -299,25 +301,25 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
// Subgraph C APIs
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = GraphRef(subg->graph);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = subg->induced_vertices;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef subg = args[0];
*rv = subg->induced_edges;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLSortAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
g->SortCSR();
});
......
......@@ -3,13 +3,15 @@
* \file graph/graph.cc
* \brief Graph operation implementation
*/
#include <dgl/graph_op.h>
#include <dgl/array.h>
#include <dgl/graph_op.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <algorithm>
#include "../c_api_common.h"
using namespace dgl::runtime;
......@@ -19,7 +21,7 @@ namespace {
// generate consecutive dgl ids
class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> {
public:
explicit RangeIter(dgl_id_t from): cur_(from) {}
explicit RangeIter(dgl_id_t from) : cur_(from) {}
RangeIter& operator++() {
++cur_;
......@@ -31,15 +33,9 @@ class RangeIter : public std::iterator<std::input_iterator_tag, dgl_id_t> {
++cur_;
return retval;
}
bool operator==(RangeIter other) const {
return cur_ == other.cur_;
}
bool operator!=(RangeIter other) const {
return cur_ != other.cur_;
}
dgl_id_t operator*() const {
return cur_;
}
bool operator==(RangeIter other) const { return cur_ == other.cur_; }
bool operator!=(RangeIter other) const { return cur_ != other.cur_; }
dgl_id_t operator*() const { return cur_; }
private:
dgl_id_t cur_;
......@@ -88,7 +84,8 @@ GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) {
rst->AddVertices(gr->NumVertices());
for (uint64_t i = 0; i < gr->NumEdges(); ++i) {
// TODO(minjie): quite ugly to expose internal members
rst->AddEdge(mg->all_edges_src_[i] + cumsum, mg->all_edges_dst_[i] + cumsum);
rst->AddEdge(
mg->all_edges_src_[i] + cumsum, mg->all_edges_dst_[i] + cumsum);
}
cumsum += gr->NumVertices();
}
......@@ -136,14 +133,17 @@ GraphPtr GraphOp::DisjointUnion(std::vector<GraphPtr> graphs) {
cum_num_edges += g_num_edges;
}
return ImmutableGraph::CreateFromCSR(indptr_arr, indices_arr, edge_ids_arr, "in");
return ImmutableGraph::CreateFromCSR(
indptr_arr, indices_arr, edge_ids_arr, "in");
}
}
std::vector<GraphPtr> GraphOp::DisjointPartitionByNum(GraphPtr graph, int64_t num) {
std::vector<GraphPtr> GraphOp::DisjointPartitionByNum(
GraphPtr graph, int64_t num) {
CHECK(num != 0 && graph->NumVertices() % num == 0)
<< "Number of partitions must evenly divide the number of nodes.";
IdArray sizes = IdArray::Empty({num}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray sizes = IdArray::Empty(
{num}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* sizes_data = static_cast<int64_t*>(sizes->data);
std::fill(sizes_data, sizes_data + num, graph->NumVertices() / num);
return DisjointPartitionBySizes(graph, sizes);
......@@ -170,10 +170,11 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
MutableGraphPtr mg = Graph::Create();
// TODO(minjie): quite ugly to expose internal members
// copy adj
mg->adjlist_.insert(mg->adjlist_.end(),
graph->adjlist_.begin() + node_offset,
mg->adjlist_.insert(
mg->adjlist_.end(), graph->adjlist_.begin() + node_offset,
graph->adjlist_.begin() + node_offset + sizes_data[i]);
mg->reverse_adjlist_.insert(mg->reverse_adjlist_.end(),
mg->reverse_adjlist_.insert(
mg->reverse_adjlist_.end(),
graph->reverse_adjlist_.begin() + node_offset,
graph->reverse_adjlist_.begin() + node_offset + sizes_data[i]);
// relabel adjs
......@@ -209,12 +210,15 @@ std::vector<GraphPtr> GraphOp::DisjointPartitionBySizes(
}
} else {
// Input is an immutable graph. Partition it into several multiple graphs.
ImmutableGraphPtr graph = std::dynamic_pointer_cast<ImmutableGraph>(batched_graph);
ImmutableGraphPtr graph =
std::dynamic_pointer_cast<ImmutableGraph>(batched_graph);
// TODO(minjie): why in csr?
CSRPtr in_csr_ptr = graph->GetInCSR();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(in_csr_ptr->indptr()->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(in_csr_ptr->indices()->data);
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data);
const dgl_id_t* indices =
static_cast<dgl_id_t*>(in_csr_ptr->indices()->data);
const dgl_id_t* edge_ids =
static_cast<dgl_id_t*>(in_csr_ptr->edge_ids()->data);
dgl_id_t cum_sum_edges = 0;
for (int64_t i = 0; i < len; ++i) {
const int64_t start_pos = cumsum[i];
......@@ -257,7 +261,8 @@ IdArray GraphOp::MapParentIdToSubgraphId(IdArray parent_vids, IdArray query) {
const auto query_len = query->shape[0];
const dgl_id_t* parent_data = static_cast<dgl_id_t*>(parent_vids->data);
const dgl_id_t* query_data = static_cast<dgl_id_t*>(query->data);
IdArray rst = IdArray::Empty({query_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray rst = IdArray::Empty(
{query_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
const bool is_sorted = std::is_sorted(parent_data, parent_data + parent_len);
......@@ -300,11 +305,12 @@ IdArray GraphOp::ExpandIds(IdArray ids, IdArray offset) {
const auto id_len = ids->shape[0];
const auto off_len = offset->shape[0];
CHECK_EQ(id_len + 1, off_len);
const dgl_id_t *id_data = static_cast<dgl_id_t*>(ids->data);
const dgl_id_t *off_data = static_cast<dgl_id_t*>(offset->data);
const dgl_id_t* id_data = static_cast<dgl_id_t*>(ids->data);
const dgl_id_t* off_data = static_cast<dgl_id_t*>(offset->data);
const int64_t len = off_data[off_len - 1];
IdArray rst = IdArray::Empty({len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t *rst_data = static_cast<dgl_id_t*>(rst->data);
IdArray rst = IdArray::Empty(
{len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t* rst_data = static_cast<dgl_id_t*>(rst->data);
for (int64_t i = 0; i < id_len; i++) {
const int64_t local_len = off_data[i + 1] - off_data[i];
for (int64_t j = 0; j < local_len; j++) {
......@@ -325,10 +331,11 @@ GraphPtr GraphOp::ToSimpleGraph(GraphPtr graph) {
hashmap.insert(dst);
}
}
indptr[src+1] = indices.size();
indptr[src + 1] = indices.size();
}
CSRPtr csr(new CSR(graph->NumVertices(), indices.size(),
indptr.begin(), indices.begin(), RangeIter(0)));
CSRPtr csr(new CSR(
graph->NumVertices(), indices.size(), indptr.begin(), indices.begin(),
RangeIter(0)));
return std::make_shared<ImmutableGraph>(csr);
}
......@@ -403,8 +410,9 @@ GraphPtr GraphOp::ToBidirectedImmutableGraph(GraphPtr g) {
g->NumVertices(), srcs_array, dsts_array);
}
HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hops) {
const dgl_id_t *nid = static_cast<dgl_id_t *>(nodes->data);
HaloSubgraph GraphOp::GetSubgraphWithHalo(
GraphPtr g, IdArray nodes, int num_hops) {
const dgl_id_t* nid = static_cast<dgl_id_t*>(nodes->data);
const auto id_len = nodes->shape[0];
// A map contains all nodes in the subgraph.
// The key is the old node Ids, the value indicates whether a node is a inner
......@@ -414,8 +422,7 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
// vector. The first few nodes are the inner nodes in the subgraph.
std::vector<dgl_id_t> old_node_ids(nid, nid + id_len);
std::vector<std::vector<dgl_id_t>> outer_nodes(num_hops);
for (int64_t i = 0; i < id_len; i++)
all_nodes[nid[i]] = true;
for (int64_t i = 0; i < id_len; i++) all_nodes[nid[i]] = true;
auto orig_nodes = all_nodes;
std::vector<dgl_id_t> edge_src, edge_dst, edge_eid;
......@@ -428,9 +435,9 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
auto dst = in_edges.dst;
auto eid = in_edges.id;
auto num_edges = eid->shape[0];
const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data);
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
const dgl_id_t* src_data = static_cast<dgl_id_t*>(src->data);
const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(eid->data);
for (int64_t i = 0; i < num_edges; i++) {
// We check if the source node is in the original node.
auto it1 = orig_nodes.find(src_data[i]);
......@@ -451,15 +458,15 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
// Now we need to traverse the graph with the in-edges to access nodes
// and edges more hops away.
for (int k = 1; k < num_hops; k++) {
const std::vector<dgl_id_t> &nodes = outer_nodes[k-1];
const std::vector<dgl_id_t>& nodes = outer_nodes[k - 1];
EdgeArray in_edges = g->InEdges(aten::VecToIdArray(nodes));
auto src = in_edges.src;
auto dst = in_edges.dst;
auto eid = in_edges.id;
auto num_edges = eid->shape[0];
const dgl_id_t *src_data = static_cast<dgl_id_t *>(src->data);
const dgl_id_t *dst_data = static_cast<dgl_id_t *>(dst->data);
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
const dgl_id_t* src_data = static_cast<dgl_id_t*>(src->data);
const dgl_id_t* dst_data = static_cast<dgl_id_t*>(dst->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(eid->data);
for (int64_t i = 0; i < num_edges; i++) {
edge_src.push_back(src_data[i]);
edge_dst.push_back(dst_data[i]);
......@@ -482,12 +489,12 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
}
num_edges = edge_src.size();
IdArray new_src = IdArray::Empty({num_edges}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
IdArray new_dst = IdArray::Empty({num_edges}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0});
dgl_id_t *new_src_data = static_cast<dgl_id_t *>(new_src->data);
dgl_id_t *new_dst_data = static_cast<dgl_id_t *>(new_dst->data);
IdArray new_src = IdArray::Empty(
{num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
IdArray new_dst = IdArray::Empty(
{num_edges}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
dgl_id_t* new_src_data = static_cast<dgl_id_t*>(new_src->data);
dgl_id_t* new_dst_data = static_cast<dgl_id_t*>(new_dst->data);
for (size_t i = 0; i < edge_src.size(); i++) {
new_src_data[i] = old2new[edge_src[i]];
new_dst_data[i] = old2new[edge_dst[i]];
......@@ -499,7 +506,8 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
inner_nodes[i] = all_nodes[old_nid];
}
GraphPtr subg = ImmutableGraph::CreateFromCOO(old_node_ids.size(), new_src, new_dst);
GraphPtr subg =
ImmutableGraph::CreateFromCOO(old_node_ids.size(), new_src, new_dst);
HaloSubgraph halo_subg;
halo_subg.graph = subg;
halo_subg.induced_vertices = aten::VecToIdArray(old_node_ids);
......@@ -509,7 +517,8 @@ HaloSubgraph GraphOp::GetSubgraphWithHalo(GraphPtr g, IdArray nodes, int num_hop
return halo_subg;
}
GraphPtr GraphOp::ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order) {
GraphPtr GraphOp::ReorderImmutableGraph(
ImmutableGraphPtr ig, IdArray new_order) {
CSRPtr in_csr, out_csr;
COOPtr coo;
// We only need to reorder one of the graph structure.
......@@ -517,12 +526,14 @@ GraphPtr GraphOp::ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order)
in_csr = ig->GetInCSR();
auto csrmat = in_csr->ToCSRMatrix();
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
in_csr = CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data));
in_csr =
CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data));
} else if (ig->HasOutCSR()) {
out_csr = ig->GetOutCSR();
auto csrmat = out_csr->ToCSRMatrix();
auto new_csrmat = aten::CSRReorder(csrmat, new_order, new_order);
out_csr = CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data));
out_csr =
CSRPtr(new CSR(new_csrmat.indptr, new_csrmat.indices, new_csrmat.data));
} else {
coo = ig->GetCOO();
auto coomat = coo->ToCOOMatrix();
......@@ -536,14 +547,14 @@ GraphPtr GraphOp::ReorderImmutableGraph(ImmutableGraphPtr ig, IdArray new_order)
}
DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef graph = args[0];
IdArray node_parts = args[1];
int num_hops = args[2];
const dgl_id_t *part_data = static_cast<dgl_id_t *>(node_parts->data);
const dgl_id_t* part_data = static_cast<dgl_id_t*>(node_parts->data);
int64_t num_nodes = node_parts->shape[0];
std::unordered_map<int, std::vector<dgl_id_t> > part_map;
std::unordered_map<int, std::vector<dgl_id_t>> part_map;
for (int64_t i = 0; i < num_nodes; i++) {
dgl_id_t part_id = part_data[i];
auto it = part_map.find(part_id);
......@@ -556,7 +567,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
}
}
std::vector<int> part_ids;
std::vector<std::vector<dgl_id_t> > part_nodes;
std::vector<std::vector<dgl_id_t>> part_nodes;
int max_part_id = 0;
for (auto it = part_map.begin(); it != part_map.end(); it++) {
max_part_id = std::max(it->first, max_part_id);
......@@ -570,12 +581,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
// try to construct in-CSR in openmp for loop, which will lead
// to some unexpected results.
graph_ptr->GetInCSR();
std::vector<std::shared_ptr<HaloSubgraph> > subgs(max_part_id + 1);
std::vector<std::shared_ptr<HaloSubgraph>> subgs(max_part_id + 1);
int num_partitions = part_nodes.size();
runtime::parallel_for(0, num_partitions, [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) {
auto nodes = aten::VecToIdArray(part_nodes[i]);
HaloSubgraph subg = GraphOp::GetSubgraphWithHalo(graph_ptr, nodes, num_hops);
HaloSubgraph subg =
GraphOp::GetSubgraphWithHalo(graph_ptr, nodes, num_hops);
std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg));
int part_id = part_ids[i];
subgs[part_id] = subg_ptr;
......@@ -589,25 +601,26 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLPartitionWithHalo")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGetSubgraphWithHalo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef graph = args[0];
IdArray nodes = args[1];
int num_hops = args[2];
HaloSubgraph subg = GraphOp::GetSubgraphWithHalo(graph.sptr(), nodes, num_hops);
HaloSubgraph subg =
GraphOp::GetSubgraphWithHalo(graph.sptr(), nodes, num_hops);
std::shared_ptr<HaloSubgraph> subg_ptr(new HaloSubgraph(subg));
*rv = SubgraphRef(subg_ptr);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_GetHaloSubgraphInnerNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<HaloSubgraph>(g.sptr());
CHECK(gptr) << "The input graph has to be immutable graph";
*rv = gptr->inner_nodes;
});
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
List<GraphRef> graphs = args[0];
std::vector<GraphPtr> ptrs(graphs.size());
for (size_t i = 0; i < graphs.size(); ++i) {
......@@ -617,7 +630,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointUnion")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
int64_t num = args[1];
const auto& ret = GraphOp::DisjointPartitionByNum(g.sptr(), num);
......@@ -629,7 +642,7 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionByNum")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray sizes = args[1];
const auto& ret = GraphOp::DisjointPartitionBySizes(g.sptr(), sizes);
......@@ -638,35 +651,35 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLDisjointPartitionBySizes")
ret_list.push_back(GraphRef(gp));
}
*rv = ret_list;
});
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphLineGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
bool backtracking = args[1];
*rv = GraphOp::LineGraph(g.sptr(), backtracking);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLToImmutable")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = ImmutableGraph::ToImmutable(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToSimpleGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = GraphOp::ToSimpleGraph(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedMutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
*rv = GraphOp::ToBidirectedMutableGraph(g.sptr());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLReorderGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
const IdArray new_order = args[1];
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
......@@ -675,7 +688,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReorderGraph")
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef graph = args[0];
bool is_incsr = args[1];
auto gptr = std::dynamic_pointer_cast<ImmutableGraph>(graph.sptr());
......@@ -683,7 +696,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges")
CSRPtr csr = is_incsr ? gptr->GetInCSR() : gptr->GetOutCSR();
auto csrmat = csr->ToCSRMatrix();
int64_t num_edges = csrmat.data->shape[0];
IdArray new_data = IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx);
IdArray new_data =
IdArray::Empty({num_edges}, csrmat.data->dtype, csrmat.data->ctx);
// Return the original edge Ids.
*rv = new_data;
// TODO(zhengda) I need to invalidate out-CSR and COO.
......@@ -692,8 +706,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges")
// TODO(zhengda) after assignment, we actually don't need to store them
// physically.
ATEN_ID_TYPE_SWITCH(new_data->dtype, IdType, {
IdType *typed_new_data = static_cast<IdType*>(new_data->data);
IdType *typed_data = static_cast<IdType*>(csrmat.data->data);
IdType* typed_new_data = static_cast<IdType*>(new_data->data);
IdType* typed_data = static_cast<IdType*>(csrmat.data->data);
for (int64_t i = 0; i < num_edges; i++) {
typed_new_data[i] = typed_data[i];
typed_data[i] = i;
......@@ -702,7 +716,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLReassignEdges")
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0];
auto gptr = g.sptr();
auto immutable_g = std::dynamic_pointer_cast<ImmutableGraph>(gptr);
......@@ -719,29 +733,31 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBidirectedImmutableGraph")
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLMapSubgraphNID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray parent_vids = args[0];
const IdArray query = args[1];
*rv = GraphOp::MapParentIdToSubgraphId(parent_vids, query);
});
template<class IdType>
IdArray MapIds(IdArray ids, IdArray range_starts, IdArray range_ends, IdArray typed_map,
template <class IdType>
IdArray MapIds(
IdArray ids, IdArray range_starts, IdArray range_ends, IdArray typed_map,
int num_parts, int num_types) {
int64_t num_ids = ids->shape[0];
int64_t num_ranges = range_starts->shape[0];
IdArray ret = IdArray::Empty({num_ids * 2}, ids->dtype, ids->ctx);
const IdType *range_start_data = static_cast<IdType *>(range_starts->data);
const IdType *range_end_data = static_cast<IdType *>(range_ends->data);
const IdType *ids_data = static_cast<IdType *>(ids->data);
const IdType *typed_map_data = static_cast<IdType *>(typed_map->data);
IdType *types_data = static_cast<IdType *>(ret->data);
IdType *per_type_ids_data = static_cast<IdType *>(ret->data) + num_ids;
const IdType* range_start_data = static_cast<IdType*>(range_starts->data);
const IdType* range_end_data = static_cast<IdType*>(range_ends->data);
const IdType* ids_data = static_cast<IdType*>(ids->data);
const IdType* typed_map_data = static_cast<IdType*>(typed_map->data);
IdType* types_data = static_cast<IdType*>(ret->data);
IdType* per_type_ids_data = static_cast<IdType*>(ret->data) + num_ids;
runtime::parallel_for(0, ids->shape[0], [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) {
IdType id = ids_data[i];
auto it = std::lower_bound(range_end_data, range_end_data + num_ranges, id);
auto it =
std::lower_bound(range_end_data, range_end_data + num_ranges, id);
// The range must exist.
BUG_IF_FAIL(it != range_end_data + num_ranges);
size_t range_id = it - range_end_data;
......@@ -752,8 +768,9 @@ IdArray MapIds(IdArray ids, IdArray range_starts, IdArray range_ends, IdArray ty
if (part_id == 0) {
per_type_ids_data[i] = id - range_start_data[range_id];
} else {
per_type_ids_data[i] = id - range_start_data[range_id]
+ typed_map_data[num_parts * type_id + part_id - 1];
per_type_ids_data[i] =
id - range_start_data[range_id] +
typed_map_data[num_parts * type_id + part_id - 1];
}
}
});
......@@ -761,7 +778,7 @@ IdArray MapIds(IdArray ids, IdArray range_starts, IdArray range_ends, IdArray ty
}
DGL_REGISTER_GLOBAL("distributed.id_map._CAPI_DGLHeteroMapIds")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
const IdArray ids = args[0];
const IdArray range_starts = args[1];
const IdArray range_ends = args[2];
......@@ -778,7 +795,8 @@ DGL_REGISTER_GLOBAL("distributed.id_map._CAPI_DGLHeteroMapIds")
IdArray ret;
ATEN_ID_TYPE_SWITCH(ids->dtype, IdType, {
ret = MapIds<IdType>(ids, range_starts, range_ends, typed_map, num_parts, num_types);
ret = MapIds<IdType>(
ids, range_starts, range_ends, typed_map, num_parts, num_types);
});
*rv = ret;
});
......
......@@ -5,6 +5,7 @@
*/
#include <dgl/graph_traversal.h>
#include <dgl/packed_func_ext.h>
#include "../c_api_common.h"
using namespace dgl::runtime;
......@@ -13,7 +14,7 @@ namespace dgl {
namespace traverse {
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0];
const IdArray src = args[1];
bool reversed = args[2];
......@@ -28,7 +29,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSNodes_v2")
});
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0];
const IdArray src = args[1];
bool reversed = args[2];
......@@ -44,7 +45,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLBFSEdges_v2")
});
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0];
bool reversed = args[1];
aten::CSRMatrix csr;
......@@ -58,9 +59,8 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLTopologicalNodes_v2")
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
});
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
......@@ -76,7 +76,7 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSEdges_v2")
});
DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef g = args[0];
const IdArray source = args[1];
const bool reversed = args[2];
......@@ -90,14 +90,12 @@ DGL_REGISTER_GLOBAL("traversal._CAPI_DGLDFSLabeledEdges_v2")
csr = g.sptr()->GetCSRMatrix(0);
}
const auto& front = aten::DGLDFSLabeledEdges(csr,
source,
has_reverse_edge,
has_nontree_edge,
return_labels);
const auto& front = aten::DGLDFSLabeledEdges(
csr, source, has_reverse_edge, has_nontree_edge, return_labels);
if (return_labels) {
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.tags, front.sections});
*rv = ConvertNDArrayVectorToPackedFunc(
{front.ids, front.tags, front.sections});
} else {
*rv = ConvertNDArrayVectorToPackedFunc({front.ids, front.sections});
}
......
......@@ -4,14 +4,16 @@
* \brief Heterograph implementation
*/
#include "./heterograph.h"
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
#include <dmlc/memory_io.h>
#include <memory>
#include <vector>
#include <tuple>
#include <utility>
#include <vector>
using namespace dgl::runtime;
......@@ -23,7 +25,8 @@ using dgl::ImmutableGraph;
HeteroSubgraph EdgeSubgraphPreserveNodes(
const HeteroGraph* hg, const std::vector<IdArray>& eids) {
CHECK_EQ(eids.size(), hg->NumEdgeTypes())
<< "Invalid input: the input list size must be the same as the number of edge type.";
<< "Invalid input: the input list size must be the same as the number of "
"edge type.";
HeteroSubgraph ret;
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = eids;
......@@ -33,14 +36,14 @@ HeteroSubgraph EdgeSubgraphPreserveNodes(
auto pair = hg->meta_graph()->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const auto& rel_vsg = hg->GetRelationGraph(etype)->EdgeSubgraph(
{eids[etype]}, true);
const auto& rel_vsg =
hg->GetRelationGraph(etype)->EdgeSubgraph({eids[etype]}, true);
subrels[etype] = rel_vsg.graph;
ret.induced_vertices[src_vtype] = rel_vsg.induced_vertices[0];
ret.induced_vertices[dst_vtype] = rel_vsg.induced_vertices[1];
}
ret.graph = HeteroGraphPtr(new HeteroGraph(
hg->meta_graph(), subrels, hg->NumVerticesPerType()));
ret.graph = HeteroGraphPtr(
new HeteroGraph(hg->meta_graph(), subrels, hg->NumVerticesPerType()));
return ret;
}
......@@ -49,34 +52,36 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
// TODO(minjie): In general, all relabeling should be separated with subgraph
// operations.
CHECK_EQ(eids.size(), hg->NumEdgeTypes())
<< "Invalid input: the input list size must be the same as the number of edge type.";
<< "Invalid input: the input list size must be the same as the number of "
"edge type.";
HeteroSubgraph ret;
ret.induced_vertices.resize(hg->NumVertexTypes());
ret.induced_edges = eids;
// NOTE(minjie): EdgeSubgraph when preserve_nodes is false is quite complicated in
// heterograph. This is because we need to make sure bipartite graphs that incident
// on the same vertex type must have the same ID space. For example, suppose we have
// following heterograph:
// NOTE(minjie): EdgeSubgraph when preserve_nodes is false is quite
// complicated in heterograph. This is because we need to make sure bipartite
// graphs that incident on the same vertex type must have the same ID space.
// For example, suppose we have following heterograph:
//
// Meta graph: A -> B -> C
// UnitGraph graphs:
// * A -> B: (0, 0), (0, 1)
// * B -> C: (1, 0), (1, 1)
//
// Suppose for A->B, we only keep edge (0, 0), while for B->C we only keep (1, 0). We need
// to make sure that in the result subgraph, node type B still has two nodes. This means
// we cannot simply compute EdgeSubgraph for B->C which will relabel node#1 of type B to be
// node #0.
// Suppose for A->B, we only keep edge (0, 0), while for B->C we only keep (1,
// 0). We need to make sure that in the result subgraph, node type B still has
// two nodes. This means we cannot simply compute EdgeSubgraph for B->C which
// will relabel node#1 of type B to be node #0.
//
// One implementation is as follows:
// (1) For each bipartite graph, slice out the edges using the given eids.
// (2) Make a dictionary map<vtype, vector<IdArray>>, where the key is the vertex type
// and the value is the incident nodes from the bipartite graphs that has the vertex
// type as either srctype or dsttype.
// (2) Make a dictionary map<vtype, vector<IdArray>>, where the key is the
// vertex type
// and the value is the incident nodes from the bipartite graphs that has
// the vertex type as either srctype or dsttype.
// (3) Then for each vertex type, use aten::Relabel_ on its vector<IdArray>.
// aten::Relabel_ computes the union of the vertex sets and relabel
// the unique elements from zero. The returned mapping array is the final induced
// vertex set for that vertex type.
// the unique elements from zero. The returned mapping array is the final
// induced vertex set for that vertex type.
// (4) Use the relabeled edges to construct the bipartite graph.
// step (1) & (2)
std::vector<EdgeArray> subedges(hg->NumEdgeTypes());
......@@ -103,10 +108,9 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
subrels[etype] = UnitGraph::CreateFromCOO(
(src_vtype == dst_vtype)? 1 : 2,
(src_vtype == dst_vtype) ? 1 : 2,
ret.induced_vertices[src_vtype]->shape[0],
ret.induced_vertices[dst_vtype]->shape[0],
subedges[etype].src,
ret.induced_vertices[dst_vtype]->shape[0], subedges[etype].src,
subedges[etype].dst);
}
ret.graph = HeteroGraphPtr(new HeteroGraph(
......@@ -114,36 +118,39 @@ HeteroSubgraph EdgeSubgraphNoPreserveNodes(
return ret;
}
void HeteroGraphSanityCheck(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
void HeteroGraphSanityCheck(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
// Sanity check
CHECK_EQ(meta_graph->NumEdges(), rel_graphs.size());
CHECK(!rel_graphs.empty()) << "Empty heterograph is not allowed.";
// all relation graphs must have only one edge type
for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1) << "Each relation graph must have only one edge type.";
for (const auto& rg : rel_graphs) {
CHECK_EQ(rg->NumEdgeTypes(), 1)
<< "Each relation graph must have only one edge type.";
}
auto ctx = rel_graphs[0]->Context();
for (const auto &rg : rel_graphs) {
CHECK_EQ(rg->Context(), ctx) << "Each relation graph must have the same context.";
for (const auto& rg : rel_graphs) {
CHECK_EQ(rg->Context(), ctx)
<< "Each relation graph must have the same context.";
}
}
std::vector<int64_t>
InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
std::vector<int64_t> InferNumVerticesPerType(
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs) {
// create num verts per type
std::vector<int64_t> num_verts_per_type(meta_graph->NumVertices(), -1);
EdgeArray etype_array = meta_graph->Edges();
dgl_type_t *srctypes = static_cast<dgl_type_t *>(etype_array.src->data);
dgl_type_t *dsttypes = static_cast<dgl_type_t *>(etype_array.dst->data);
dgl_type_t *etypes = static_cast<dgl_type_t *>(etype_array.id->data);
dgl_type_t* srctypes = static_cast<dgl_type_t*>(etype_array.src->data);
dgl_type_t* dsttypes = static_cast<dgl_type_t*>(etype_array.dst->data);
dgl_type_t* etypes = static_cast<dgl_type_t*>(etype_array.id->data);
for (size_t i = 0; i < meta_graph->NumEdges(); ++i) {
dgl_type_t srctype = srctypes[i];
dgl_type_t dsttype = dsttypes[i];
dgl_type_t etype = etypes[i];
const auto& rg = rel_graphs[etype];
const auto sty = 0;
const auto dty = rg->NumVertexTypes() == 1? 0 : 1;
const auto dty = rg->NumVertexTypes() == 1 ? 0 : 1;
size_t nv;
// # nodes of source type
......@@ -164,7 +171,8 @@ InferNumVerticesPerType(GraphPtr meta_graph, const std::vector<HeteroGraphPtr>&
return num_verts_per_type;
}
std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& rel_graphs) {
std::vector<UnitGraphPtr> CastToUnitGraphs(
const std::vector<HeteroGraphPtr>& rel_graphs) {
std::vector<UnitGraphPtr> relation_graphs(rel_graphs.size());
for (size_t i = 0; i < rel_graphs.size(); ++i) {
HeteroGraphPtr relg = rel_graphs[i];
......@@ -181,9 +189,9 @@ std::vector<UnitGraphPtr> CastToUnitGraphs(const std::vector<HeteroGraphPtr>& re
} // namespace
HeteroGraph::HeteroGraph(
GraphPtr meta_graph,
const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type) : BaseHeteroGraph(meta_graph) {
GraphPtr meta_graph, const std::vector<HeteroGraphPtr>& rel_graphs,
const std::vector<int64_t>& num_nodes_per_type)
: BaseHeteroGraph(meta_graph) {
if (num_nodes_per_type.size() == 0)
num_verts_per_type_ = InferNumVerticesPerType(meta_graph, rel_graphs);
else
......@@ -193,7 +201,7 @@ HeteroGraph::HeteroGraph(
}
bool HeteroGraph::IsMultigraph() const {
for (const auto &hg : relation_graphs_) {
for (const auto& hg : relation_graphs_) {
if (hg->IsMultigraph()) {
return true;
}
......@@ -206,9 +214,11 @@ BoolArray HeteroGraph::HasVertices(dgl_type_t vtype, IdArray vids) const {
return aten::LT(vids, NumVertices(vtype));
}
HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) const {
HeteroSubgraph HeteroGraph::VertexSubgraph(
const std::vector<IdArray>& vids) const {
CHECK_EQ(vids.size(), NumVertexTypes())
<< "Invalid input: the input list size must be the same as the number of vertex types.";
<< "Invalid input: the input list size must be the same as the number of "
"vertex types.";
HeteroSubgraph ret;
ret.induced_vertices = vids;
std::vector<int64_t> num_vertices_per_type(NumVertexTypes());
......@@ -220,15 +230,16 @@ HeteroSubgraph HeteroGraph::VertexSubgraph(const std::vector<IdArray>& vids) con
auto pair = meta_graph_->FindEdge(etype);
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
const std::vector<IdArray> rel_vids = (src_vtype == dst_vtype) ?
std::vector<IdArray>({vids[src_vtype]}) :
std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]});
const std::vector<IdArray> rel_vids =
(src_vtype == dst_vtype)
? std::vector<IdArray>({vids[src_vtype]})
: std::vector<IdArray>({vids[src_vtype], vids[dst_vtype]});
const auto& rel_vsg = GetRelationGraph(etype)->VertexSubgraph(rel_vids);
subrels[etype] = rel_vsg.graph;
ret.induced_edges[etype] = rel_vsg.induced_edges[0];
}
ret.graph = HeteroGraphPtr(new HeteroGraph(
meta_graph_, subrels, std::move(num_vertices_per_type)));
ret.graph = HeteroGraphPtr(
new HeteroGraph(meta_graph_, subrels, std::move(num_vertices_per_type)));
return ret;
}
......@@ -248,11 +259,11 @@ HeteroGraphPtr HeteroGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::AsNumBits(g, bits));
}
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
hgindex->num_verts_per_type_));
return HeteroGraphPtr(new HeteroGraph(
hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_));
}
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext& ctx) {
if (ctx == g->Context()) {
return g;
}
......@@ -262,23 +273,20 @@ HeteroGraphPtr HeteroGraph::CopyTo(HeteroGraphPtr g, const DGLContext &ctx) {
for (auto g : hgindex->relation_graphs_) {
rel_graphs.push_back(UnitGraph::CopyTo(g, ctx));
}
return HeteroGraphPtr(new HeteroGraph(hgindex->meta_graph_, rel_graphs,
hgindex->num_verts_per_type_));
return HeteroGraphPtr(new HeteroGraph(
hgindex->meta_graph_, rel_graphs, hgindex->num_verts_per_type_));
}
void HeteroGraph::PinMemory_() {
for (auto g : relation_graphs_)
g->PinMemory_();
for (auto g : relation_graphs_) g->PinMemory_();
}
void HeteroGraph::UnpinMemory_() {
for (auto g : relation_graphs_)
g->UnpinMemory_();
for (auto g : relation_graphs_) g->UnpinMemory_();
}
void HeteroGraph::RecordStream(DGLStreamHandle stream) {
for (auto g : relation_graphs_)
g->RecordStream(stream);
for (auto g : relation_graphs_) g->RecordStream(stream);
}
std::string HeteroGraph::SharedMemName() const {
......@@ -286,13 +294,13 @@ std::string HeteroGraph::SharedMemName() const {
}
HeteroGraphPtr HeteroGraph::CopyToSharedMem(
HeteroGraphPtr g, const std::string& name, const std::vector<std::string>& ntypes,
HeteroGraphPtr g, const std::string& name,
const std::vector<std::string>& ntypes,
const std::vector<std::string>& etypes, const std::set<std::string>& fmts) {
// TODO(JJ): Raise error when calling shared_memory if graph index is on gpu
auto hg = std::dynamic_pointer_cast<HeteroGraph>(g);
CHECK_NOTNULL(hg);
if (hg->SharedMemName() == name)
return g;
if (hg->SharedMemName() == name) return g;
// Copy buffer to share memory
auto mem = std::make_shared<SharedMemory>(name);
......@@ -312,7 +320,7 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
std::vector<HeteroGraphPtr> relgraphs(g->NumEdgeTypes());
for (dgl_type_t etype = 0 ; etype < g->NumEdgeTypes() ; ++etype) {
for (dgl_type_t etype = 0; etype < g->NumEdgeTypes(); ++etype) {
auto src_dst_type = g->GetEndpointTypes(etype);
int num_vtypes = (src_dst_type.first == src_dst_type.second ? 1 : 2);
aten::COOMatrix coo;
......@@ -341,10 +349,11 @@ HeteroGraphPtr HeteroGraph::CopyToSharedMem(
}
std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
HeteroGraph::CreateFromSharedMem(const std::string &name) {
HeteroGraph::CreateFromSharedMem(const std::string& name) {
bool exist = SharedMemory::Exist(name);
if (!exist) {
return std::make_tuple(nullptr, std::vector<std::string>(), std::vector<std::string>());
return std::make_tuple(
nullptr, std::vector<std::string>(), std::vector<std::string>());
}
auto mem = std::make_shared<SharedMemory>(name);
auto mem_buf = mem->Open(SHARED_MEM_METAINFO_SIZE_MAX);
......@@ -367,7 +376,7 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
CHECK(shm.Read(&num_verts_per_type)) << "Invalid number of vertices per type";
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
for (dgl_type_t etype = 0 ; etype < metagraph->NumEdges() ; ++etype) {
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
auto src_dst = metagraph->FindEdge(etype);
int num_vtypes = (src_dst.first == src_dst.second) ? 1 : 2;
aten::COOMatrix coo;
......@@ -387,7 +396,8 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
num_vtypes, csc, csr, coo, has_csc, has_csr, has_coo);
}
auto ret = std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
auto ret =
std::make_shared<HeteroGraph>(metagraph, relgraphs, num_verts_per_type);
ret->shared_mem_ = mem;
std::vector<std::string> ntypes;
......@@ -400,11 +410,12 @@ std::tuple<HeteroGraphPtr, std::vector<std::string>, std::vector<std::string>>
HeteroGraphPtr HeteroGraph::GetGraphInFormat(dgl_format_code_t formats) const {
std::vector<HeteroGraphPtr> format_rels(NumEdgeTypes());
for (dgl_type_t etype = 0; etype < NumEdgeTypes(); ++etype) {
auto relgraph = std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype));
auto relgraph =
std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(etype));
format_rels[etype] = relgraph->GetGraphInFormat(formats);
}
return HeteroGraphPtr(new HeteroGraph(
meta_graph_, format_rels, NumVerticesPerType()));
return HeteroGraphPtr(
new HeteroGraph(meta_graph_, format_rels, NumVerticesPerType()));
}
FlattenedHeteroGraphPtr HeteroGraph::Flatten(
......@@ -418,15 +429,16 @@ FlattenedHeteroGraphPtr HeteroGraph::Flatten(
}
template <class IdType>
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>& etypes) const {
FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(
const std::vector<dgl_type_t>& etypes) const {
std::unordered_map<dgl_type_t, size_t> srctype_offsets, dsttype_offsets;
size_t src_nodes = 0, dst_nodes = 0;
std::vector<dgl_type_t> induced_srctype, induced_dsttype;
std::vector<IdType> induced_srcid, induced_dstid;
std::vector<dgl_type_t> srctype_set, dsttype_set;
// XXXtype_offsets contain the mapping from node type and number of nodes after this
// loop.
// XXXtype_offsets contain the mapping from node type and number of nodes
// after this loop.
for (dgl_type_t etype : etypes) {
auto src_dsttype = meta_graph_->FindEdge(etype);
dgl_type_t srctype = src_dsttype.first;
......@@ -443,15 +455,16 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
dsttype_set.push_back(dsttype);
}
}
// Sort the node types so that we can compare the sets and decide whether a homogeneous graph
// should be returned.
// Sort the node types so that we can compare the sets and decide whether a
// homogeneous graph should be returned.
std::sort(srctype_set.begin(), srctype_set.end());
std::sort(dsttype_set.begin(), dsttype_set.end());
bool homograph = (srctype_set.size() == dsttype_set.size()) &&
bool homograph =
(srctype_set.size() == dsttype_set.size()) &&
std::equal(srctype_set.begin(), srctype_set.end(), dsttype_set.begin());
// XXXtype_offsets contain the mapping from node type to node ID offsets after these
// two loops.
// XXXtype_offsets contain the mapping from node type to node ID offsets after
// these two loops.
for (size_t i = 0; i < srctype_set.size(); ++i) {
dgl_type_t ntype = srctype_set[i];
size_t num_nodes = srctype_offsets[ntype];
......@@ -492,14 +505,12 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
src_arrs.push_back(edges.src + srctype_offset);
dst_arrs.push_back(edges.dst + dsttype_offset);
eid_arrs.push_back(edges.id);
induced_etypes.push_back(aten::Full(etype, num_edges, NumBits(), Context()));
induced_etypes.push_back(
aten::Full(etype, num_edges, NumBits(), Context()));
}
HeteroGraphPtr gptr = UnitGraph::CreateFromCOO(
homograph ? 1 : 2,
src_nodes,
dst_nodes,
aten::Concat(src_arrs),
homograph ? 1 : 2, src_nodes, dst_nodes, aten::Concat(src_arrs),
aten::Concat(dst_arrs));
// Sanity check
......@@ -507,15 +518,20 @@ FlattenedHeteroGraphPtr HeteroGraph::FlattenImpl(const std::vector<dgl_type_t>&
CHECK_EQ(gptr->NumBits(), NumBits());
FlattenedHeteroGraph* result = new FlattenedHeteroGraph;
result->graph = HeteroGraphRef(HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
result->induced_srctype = aten::VecToIdArray(induced_srctype).CopyTo(Context());
result->induced_srctype_set = aten::VecToIdArray(srctype_set).CopyTo(Context());
result->graph = HeteroGraphRef(
HeteroGraphPtr(new HeteroGraph(gptr->meta_graph(), {gptr})));
result->induced_srctype =
aten::VecToIdArray(induced_srctype).CopyTo(Context());
result->induced_srctype_set =
aten::VecToIdArray(srctype_set).CopyTo(Context());
result->induced_srcid = aten::VecToIdArray(induced_srcid).CopyTo(Context());
result->induced_etype = aten::Concat(induced_etypes);
result->induced_etype_set = aten::VecToIdArray(etypes).CopyTo(Context());
result->induced_eid = aten::Concat(eid_arrs);
result->induced_dsttype = aten::VecToIdArray(induced_dsttype).CopyTo(Context());
result->induced_dsttype_set = aten::VecToIdArray(dsttype_set).CopyTo(Context());
result->induced_dsttype =
aten::VecToIdArray(induced_dsttype).CopyTo(Context());
result->induced_dsttype_set =
aten::VecToIdArray(dsttype_set).CopyTo(Context());
result->induced_dstid = aten::VecToIdArray(induced_dstid).CopyTo(Context());
return FlattenedHeteroGraphPtr(result);
}
......@@ -545,21 +561,25 @@ void HeteroGraph::Save(dmlc::Stream* fs) const {
GraphPtr HeteroGraph::AsImmutableGraph() const {
CHECK(NumVertexTypes() == 1) << "graph has more than one node types";
CHECK(NumEdgeTypes() == 1) << "graph has more than one edge types";
auto unit_graph = CHECK_NOTNULL(
std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
auto unit_graph =
CHECK_NOTNULL(std::dynamic_pointer_cast<UnitGraph>(GetRelationGraph(0)));
return unit_graph->AsImmutableGraph();
}
HeteroGraphPtr HeteroGraph::LineGraph(bool backtracking) const {
CHECK_EQ(1, meta_graph_->NumEdges()) << "Only support Homogeneous graph now (one edge type)";
CHECK_EQ(1, meta_graph_->NumVertices()) << "Only support Homogeneous graph now (one node type)";
CHECK_EQ(1, meta_graph_->NumEdges())
<< "Only support Homogeneous graph now (one edge type)";
CHECK_EQ(1, meta_graph_->NumVertices())
<< "Only support Homogeneous graph now (one node type)";
CHECK_EQ(1, relation_graphs_.size()) << "Only support Homogeneous graph now";
UnitGraphPtr ug = relation_graphs_[0];
const auto &ulg = ug->LineGraph(backtracking);
const auto& ulg = ug->LineGraph(backtracking);
std::vector<HeteroGraphPtr> rel_graph = {ulg};
std::vector<int64_t> num_nodes_per_type = {static_cast<int64_t>(ulg->NumVertices(0))};
return HeteroGraphPtr(new HeteroGraph(meta_graph_, rel_graph, std::move(num_nodes_per_type)));
std::vector<int64_t> num_nodes_per_type = {
static_cast<int64_t>(ulg->NumVertices(0))};
return HeteroGraphPtr(
new HeteroGraph(meta_graph_, rel_graph, std::move(num_nodes_per_type)));
}
} // namespace dgl
......@@ -5,11 +5,12 @@
*/
#include <dgl/array.h>
#include <dgl/aten/coo.h>
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/runtime/c_runtime_api.h>
#include <set>
#include "../c_api_common.h"
......@@ -25,7 +26,7 @@ namespace dgl {
// XXX(minjie): Ideally, Unitgraph should be invisible to python side
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
......@@ -40,13 +41,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCOO")
formats_vec.push_back(ParseSparseFormat(fmt));
}
const auto code = SparseFormatsToCode(formats_vec);
auto hgptr = CreateFromCOO(nvtypes, num_src, num_dst, row, col,
row_sorted, col_sorted, code);
auto hgptr = CreateFromCOO(
nvtypes, num_src, num_dst, row, col, row_sorted, col_sorted, code);
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
int64_t nvtypes = args[0];
int64_t num_src = args[1];
int64_t num_dst = args[2];
......@@ -62,16 +63,18 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateUnitGraphFromCSR")
}
const auto code = SparseFormatsToCode(formats_vec);
if (!transpose) {
auto hgptr = CreateFromCSR(nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);
auto hgptr = CreateFromCSR(
nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);
*rv = HeteroGraphRef(hgptr);
} else {
auto hgptr = CreateFromCSC(nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);
auto hgptr = CreateFromCSC(
nvtypes, num_src, num_dst, indptr, indices, edge_ids, code);
*rv = HeteroGraphRef(hgptr);
}
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> rel_graphs = args[1];
std::vector<HeteroGraphPtr> rel_ptrs;
......@@ -83,8 +86,9 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraph")
*rv = HeteroGraphRef(hgptr);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
DGL_REGISTER_GLOBAL(
"heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNodes")
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> rel_graphs = args[1];
IdArray num_nodes_per_type = args[2];
......@@ -101,20 +105,20 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateHeteroGraphWithNumNo
///////////////////////// HeteroGraph member functions /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetMetaGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->meta_graph();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMetaGraphUniBipartite")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
GraphPtr mg = hg->meta_graph();
*rv = mg->IsUniBipartite();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
......@@ -126,12 +130,13 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetRelationGraph")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> etypes = args[1];
std::vector<dgl_id_t> etypes_vec;
for (Value val : etypes) {
// (gq) have to decompose it into two statements because of a weird MSVC internal error
// (gq) have to decompose it into two statements because of a weird MSVC
// internal error
dgl_id_t id = val->data;
etypes_vec.push_back(id);
}
......@@ -140,7 +145,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFlattenedGraph")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1];
int64_t num = args[2];
......@@ -148,7 +153,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddVertices")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t src = args[2];
......@@ -157,7 +162,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdge")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray src = args[2];
......@@ -166,63 +171,63 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAddEdges")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroClear")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
hg->Clear();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDataType")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->DataType();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroContext")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->Context();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsPinned")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->IsPinned();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->NumBits();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsMultigraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->IsMultigraph();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroIsReadonly")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = hg->IsReadonly();
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1];
*rv = static_cast<int64_t>(hg->NumVertices(vtype));
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroNumEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
*rv = static_cast<int64_t>(hg->NumEdges(etype));
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1];
dgl_id_t vid = args[2];
......@@ -230,7 +235,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertex")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t vtype = args[1];
IdArray vids = args[2];
......@@ -238,7 +243,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasVertices")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t src = args[2];
......@@ -247,7 +252,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgeBetween")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray src = args[2];
......@@ -256,7 +261,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroHasEdgesBetween")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t dst = args[2];
......@@ -264,7 +269,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPredecessors")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t src = args[2];
......@@ -272,7 +277,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSuccessors")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t src = args[2];
......@@ -281,7 +286,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeId")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsAll")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray src = args[2];
......@@ -290,9 +295,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsAll")
*rv = ConvertEdgeArrayToPackedFunc(ret);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsOne")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray src = args[2];
......@@ -301,7 +305,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeIdsOne")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray eids = args[2];
......@@ -310,7 +314,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroFindEdges")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t vid = args[2];
......@@ -319,7 +323,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_1")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray vids = args[2];
......@@ -328,7 +332,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInEdges_2")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t vid = args[2];
......@@ -337,7 +341,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_1")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray vids = args[2];
......@@ -346,7 +350,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutEdges_2")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
std::string order = args[2];
......@@ -355,7 +359,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdges")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t vid = args[2];
......@@ -363,7 +367,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegree")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray vids = args[2];
......@@ -371,7 +375,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroInDegrees")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
dgl_id_t vid = args[2];
......@@ -379,7 +383,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegree")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
IdArray vids = args[2];
......@@ -387,17 +391,16 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroOutDegrees")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_type_t etype = args[1];
bool transpose = args[2];
std::string fmt = args[3];
*rv = ConvertNDArrayVectorToPackedFunc(
hg->GetAdj(etype, transpose, fmt));
*rv = ConvertNDArrayVectorToPackedFunc(hg->GetAdj(etype, transpose, fmt));
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> vids = args[1];
bool relabel_nodes = args[2];
......@@ -413,7 +416,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroVertexSubgraph")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> eids = args[1];
bool preserve_nodes = args[2];
......@@ -430,13 +433,14 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroEdgeSubgraph")
///////////////////////// HeteroSubgraph members /////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0];
*rv = HeteroGraphRef(subg->graph);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
DGL_REGISTER_GLOBAL(
"heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices")
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0];
List<Value> induced_verts;
for (IdArray arr : subg->induced_vertices) {
......@@ -446,7 +450,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedVertices
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroSubgraphRef subg = args[0];
List<Value> induced_edges;
for (IdArray arr : subg->induced_edges) {
......@@ -455,10 +459,11 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSubgraphGetInducedEdges")
*rv = induced_edges;
});
///////////////////////// Global functions and algorithms /////////////////////////
///////////////////////// Global functions and algorithms
////////////////////////////
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
int bits = args[1];
HeteroGraphPtr bhg_ptr = hg.sptr();
......@@ -473,7 +478,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroAsNumBits")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
int device_type = args[1];
int device_id = args[2];
......@@ -485,7 +490,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyTo")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPinMemory_")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
hgindex->PinMemory_();
......@@ -493,7 +498,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPinMemory_")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
hgindex->UnpinMemory_();
......@@ -501,7 +506,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpinMemory_")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
DGLStreamHandle stream = args[1];
auto hgindex = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
......@@ -510,7 +515,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroRecordStream")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
std::string name = args[1];
List<Value> ntypes = args[2];
......@@ -519,7 +524,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem")
auto ntypes_vec = ListValueToVector<std::string>(ntypes);
auto etypes_vec = ListValueToVector<std::string>(etypes);
std::set<std::string> fmts_set;
for (const auto &fmt : fmts) {
for (const auto& fmt : fmts) {
std::string fmt_data = fmt->data;
fmts_set.insert(fmt_data);
}
......@@ -529,7 +534,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCopyToSharedMem")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string name = args[0];
HeteroGraphPtr hg;
std::vector<std::string> ntypes;
......@@ -537,9 +542,9 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem")
std::tie(hg, ntypes, etypes) = HeteroGraph::CreateFromSharedMem(name);
List<Value> ntypes_list;
List<Value> etypes_list;
for (const auto &ntype : ntypes)
for (const auto& ntype : ntypes)
ntypes_list.push_back(Value(MakeValue(ntype)));
for (const auto &etype : etypes)
for (const auto& etype : etypes)
etypes_list.push_back(Value(MakeValue(etype)));
List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(hg));
......@@ -549,7 +554,7 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFromSharedMem")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1];
CHECK(component_graphs.size() > 1)
......@@ -561,8 +566,8 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion")
for (const auto& component : component_graphs) {
component_ptrs.push_back(component.sptr());
CHECK_EQ(component->NumBits(), bits)
<< "Expect graphs to joint union have the same index dtype(int" << bits
<< "), but got int" << component->NumBits();
<< "Expect graphs to joint union have the same index dtype(int"
<< bits << "), but got int" << component->NumBits();
CHECK_EQ(component->Context(), ctx)
<< "Expect graphs to joint union have the same context" << ctx
<< "), but got " << component->Context();
......@@ -570,10 +575,10 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroJointUnion")
auto hgptr = JointUnionHeteroGraph(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr);
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef meta_graph = args[0];
List<HeteroGraphRef> component_graphs = args[1];
CHECK(component_graphs.size() > 0)
......@@ -594,86 +599,86 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointUnion_v2")
auto hgptr = DisjointUnionHeteroGraph2(meta_graph.sptr(), component_ptrs);
*rv = HeteroGraphRef(hgptr);
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v2")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
DGL_REGISTER_GLOBAL(
"heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes_v2")
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const IdArray vertex_sizes = args[1];
const IdArray edge_sizes = args[2];
std::vector<HeteroGraphPtr> ret;
ret = DisjointPartitionHeteroBySizes2(hg->meta_graph(), hg.sptr(),
vertex_sizes, edge_sizes);
ret = DisjointPartitionHeteroBySizes2(
hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);
List<HeteroGraphRef> ret_list;
for (HeteroGraphPtr hgptr : ret) {
ret_list.push_back(HeteroGraphRef(hgptr));
}
*rv = ret_list;
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroDisjointPartitionBySizes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const IdArray vertex_sizes = args[1];
const IdArray edge_sizes = args[2];
const int64_t bits = hg->NumBits();
std::vector<HeteroGraphPtr> ret;
ATEN_ID_BITS_SWITCH(bits, IdType, {
ret = DisjointPartitionHeteroBySizes<IdType>(hg->meta_graph(), hg.sptr(),
vertex_sizes, edge_sizes);
ret = DisjointPartitionHeteroBySizes<IdType>(
hg->meta_graph(), hg.sptr(), vertex_sizes, edge_sizes);
});
List<HeteroGraphRef> ret_list;
for (HeteroGraphPtr hgptr : ret) {
ret_list.push_back(HeteroGraphRef(hgptr));
}
*rv = ret_list;
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroSlice")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const IdArray num_nodes_per_type = args[1];
const IdArray start_nid_per_type = args[2];
const IdArray num_edges_per_type = args[3];
const IdArray start_eid_per_type = args[4];
auto hgptr = SliceHeteroGraph(hg->meta_graph(), hg.sptr(), num_nodes_per_type,
start_nid_per_type, num_edges_per_type, start_eid_per_type);
auto hgptr = SliceHeteroGraph(
hg->meta_graph(), hg.sptr(), num_nodes_per_type, start_nid_per_type,
num_edges_per_type, start_eid_per_type);
*rv = HeteroGraphRef(hgptr);
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetCreatedFormats")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> format_list;
dgl_format_code_t code = hg->GetRelationGraph(0)->GetCreatedFormats();
for (auto format : CodeToSparseFormats(code)) {
format_list.push_back(
Value(MakeValue(ToStringSparseFormat(format))));
format_list.push_back(Value(MakeValue(ToStringSparseFormat(format))));
}
*rv = format_list;
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetAllowedFormats")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> format_list;
dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats();
for (auto format : CodeToSparseFormats(code)) {
format_list.push_back(
Value(MakeValue(ToStringSparseFormat(format))));
format_list.push_back(Value(MakeValue(ToStringSparseFormat(format))));
}
*rv = format_list;
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFormat")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
dgl_format_code_t code = hg->GetRelationGraph(0)->GetAllowedFormats();
auto get_format_f = [&](size_t etype_b, size_t etype_e) {
for (auto etype = etype_b; etype < etype_e; ++etype) {
auto bg = std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype));
for (auto format : CodeToSparseFormats(code))
bg->GetFormat(format);
auto bg =
std::dynamic_pointer_cast<UnitGraph>(hg->GetRelationGraph(etype));
for (auto format : CodeToSparseFormats(code)) bg->GetFormat(format);
}
};
......@@ -682,10 +687,10 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroCreateFormat")
#else
get_format_f(0, hg->NumEdgeTypes());
#endif
});
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
List<Value> formats = args[1];
std::vector<SparseFormat> formats_vec;
......@@ -693,13 +698,12 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroGetFormatGraph")
std::string fmt = val->data;
formats_vec.push_back(ParseSparseFormat(fmt));
}
auto hgptr = hg->GetGraphInFormat(
SparseFormatsToCode(formats_vec));
auto hgptr = hg->GetGraphInFormat(SparseFormatsToCode(formats_vec));
*rv = HeteroGraphRef(hgptr);
});
});
DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
bool relabel_nodes = args[2];
......@@ -709,7 +713,7 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLInSubgraph")
});
DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
const auto& nodes = ListValueToVector<IdArray>(args[1]);
bool relabel_nodes = args[2];
......@@ -719,27 +723,30 @@ DGL_REGISTER_GLOBAL("subgraph._CAPI_DGLOutSubgraph")
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsImmutableGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
*rv = GraphRef(hg->AsImmutableGraph());
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
NDArray tag = args[1];
int64_t num_tag = args[2];
CHECK_EQ(hg->Context().device_type, kDGLCPU) << "Only support sorting by tag on cpu";
CHECK_EQ(hg->Context().device_type, kDGLCPU)
<< "Only support sorting by tag on cpu";
CHECK(aten::IsValidIdArray(tag));
CHECK_EQ(tag->ctx.device_type, kDGLCPU) << "Only support sorting by tag on cpu";
CHECK_EQ(tag->ctx.device_type, kDGLCPU)
<< "Only support sorting by tag on cpu";
const auto csr = hg->GetCSRMatrix(0);
NDArray tag_pos = aten::NullArray();
aten::CSRMatrix output;
std::tie(output, tag_pos) = aten::CSRSortByTag(csr, tag, num_tag);
HeteroGraphPtr output_hg = CreateFromCSR(hg->NumVertexTypes(), output, ALL_CODE);
HeteroGraphPtr output_hg =
CreateFromCSR(hg->NumVertexTypes(), output, ALL_CODE);
List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(output_hg));
ret.push_back(Value(MakeValue(tag_pos)));
......@@ -747,14 +754,16 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortOutEdges")
});
DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
NDArray tag = args[1];
int64_t num_tag = args[2];
CHECK_EQ(hg->Context().device_type, kDGLCPU) << "Only support sorting by tag on cpu";
CHECK_EQ(hg->Context().device_type, kDGLCPU)
<< "Only support sorting by tag on cpu";
CHECK(aten::IsValidIdArray(tag));
CHECK_EQ(tag->ctx.device_type, kDGLCPU) << "Only support sorting by tag on cpu";
CHECK_EQ(tag->ctx.device_type, kDGLCPU)
<< "Only support sorting by tag on cpu";
const auto csc = hg->GetCSCMatrix(0);
......@@ -762,7 +771,8 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges")
aten::CSRMatrix output;
std::tie(output, tag_pos) = aten::CSRSortByTag(csc, tag, num_tag);
HeteroGraphPtr output_hg = CreateFromCSC(hg->NumVertexTypes(), output, ALL_CODE);
HeteroGraphPtr output_hg =
CreateFromCSC(hg->NumVertexTypes(), output, ALL_CODE);
List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(output_hg));
ret.push_back(Value(MakeValue(tag_pos)));
......@@ -770,7 +780,7 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLHeteroSortInEdges")
});
DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
GraphRef metagraph = args[0];
std::unordered_set<uint64_t> dst_set;
std::unordered_set<uint64_t> src_set;
......@@ -793,7 +803,8 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
else if (is_dst)
dstlist.push_back(Value(MakeValue(static_cast<int64_t>(nid))));
else
// If a node type is isolated, put it in srctype as defined in the Python docstring.
// If a node type is isolated, put it in srctype as defined in the
// Python docstring.
srclist.push_back(Value(MakeValue(static_cast<int64_t>(nid))));
}
ret_list.push_back(srclist);
......@@ -802,24 +813,24 @@ DGL_REGISTER_GLOBAL("heterograph._CAPI_DGLFindSrcDstNtypes")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroReverse")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef hg = args[0];
CHECK_GT(hg->NumEdgeTypes(), 0);
auto g = std::dynamic_pointer_cast<HeteroGraph>(hg.sptr());
std::vector<HeteroGraphPtr> rev_ugs;
const auto &ugs = g->relation_graphs();
const auto& ugs = g->relation_graphs();
rev_ugs.resize(ugs.size());
for (size_t i = 0; i < ugs.size(); ++i) {
const auto &rev_ug = ugs[i]->Reverse();
const auto& rev_ug = ugs[i]->Reverse();
rev_ugs[i] = rev_ug;
}
// node types are not changed
const auto& num_nodes = g->NumVerticesPerType();
const auto& meta_edges = hg->meta_graph()->Edges("eid");
// reverse the metagraph
const auto& rev_meta = ImmutableGraph::CreateFromCOO(hg->meta_graph()->NumVertices(),
meta_edges.dst, meta_edges.src);
const auto& rev_meta = ImmutableGraph::CreateFromCOO(
hg->meta_graph()->NumVertices(), meta_edges.dst, meta_edges.src);
*rv = CreateHeteroGraph(rev_meta, rev_ugs, num_nodes);
});
} // namespace dgl
......@@ -4,13 +4,14 @@
* \brief DGL immutable graph index implementation
*/
#include <dgl/base_heterograph.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/smart_ptr_serializer.h>
#include <dgl/base_heterograph.h>
#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <string.h>
#include <bitset>
#include <numeric>
#include <tuple>
......@@ -23,7 +24,8 @@ using namespace dgl::runtime;
namespace dgl {
namespace {
inline std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
inline std::string GetSharedMemName(
const std::string &name, const std::string &edge_dir) {
return name + "_" + edge_dir;
}
......@@ -39,8 +41,9 @@ struct GraphIndexMetadata {
};
/*
* Serialize the metadata of a graph index and place it in a shared-memory tensor.
* In this way, another process can reconstruct a GraphIndex from a shared-memory tensor.
* Serialize the metadata of a graph index and place it in a shared-memory
* tensor. In this way, another process can reconstruct a GraphIndex from a
* shared-memory tensor.
*/
NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
#ifndef _WIN32
......@@ -51,8 +54,9 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
meta.has_out_csr = gidx->HasOutCSR();
meta.has_coo = false;
NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1},
DGLContext{kDGLCPU, 0}, true);
NDArray meta_arr = NDArray::EmptyShared(
name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0},
true);
memcpy(meta_arr->data, &meta, sizeof(meta));
return meta_arr;
#else
......@@ -67,8 +71,9 @@ NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
GraphIndexMetadata DeserializeMetadata(const std::string &name) {
GraphIndexMetadata meta;
#ifndef _WIN32
NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1},
DGLContext{kDGLCPU, 0}, false);
NDArray meta_arr = NDArray::EmptyShared(
name, {sizeof(meta)}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0},
false);
memcpy(&meta, meta_arr->data, sizeof(meta));
#else
LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
......@@ -77,18 +82,23 @@ GraphIndexMetadata DeserializeMetadata(const std::string &name) {
}
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges,
bool is_create) {
#ifndef _WIN32
const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);
IdArray sm_array = IdArray::EmptyShared(
shared_mem_name, {file_size}, DGLDataType{kDGLInt, 8, 1}, DGLContext{kDGLCPU, 0}, is_create);
shared_mem_name, {file_size}, DGLDataType{kDGLInt, 8, 1},
DGLContext{kDGLCPU, 0}, is_create);
// Create views from the shared memory array. Note that we don't need to save
// the sm_array because the refcount is maintained by the view arrays.
IdArray indptr = sm_array.CreateView({num_verts + 1}, DGLDataType{kDGLInt, 64, 1});
IdArray indices = sm_array.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1},
IdArray indptr =
sm_array.CreateView({num_verts + 1}, DGLDataType{kDGLInt, 64, 1});
IdArray indices = sm_array.CreateView(
{num_edges}, DGLDataType{kDGLInt, 64, 1},
(num_verts + 1) * sizeof(dgl_id_t));
IdArray edge_ids = sm_array.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1},
IdArray edge_ids = sm_array.CreateView(
{num_edges}, DGLDataType{kDGLInt, 64, 1},
(num_verts + 1 + num_edges) * sizeof(dgl_id_t));
return std::make_tuple(indptr, indices, edge_ids);
#else
......@@ -106,10 +116,9 @@ std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
CSR::CSR(int64_t num_vertices, int64_t num_edges) {
CHECK(!(num_vertices == 0 && num_edges != 0));
adj_ = aten::CSRMatrix{num_vertices, num_vertices,
aten::NewIdArray(num_vertices + 1),
aten::NewIdArray(num_edges),
aten::NewIdArray(num_edges)};
adj_ = aten::CSRMatrix{
num_vertices, num_vertices, aten::NewIdArray(num_vertices + 1),
aten::NewIdArray(num_edges), aten::NewIdArray(num_edges)};
adj_.sorted = false;
}
......@@ -123,8 +132,10 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) {
adj_.sorted = false;
}
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
CSR::CSR(
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name)
: shared_mem_name_(shared_mem_name) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
......@@ -133,8 +144,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const int64_t num_edges = indices->shape[0];
adj_.num_rows = num_verts;
adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges, true);
std::tie(adj_.indptr, adj_.indices, adj_.data) =
MapFromSharedMemory(shared_mem_name, num_verts, num_edges, true);
// copy the given data into the shared memory arrays
adj_.indptr.CopyFrom(indptr);
adj_.indices.CopyFrom(indices);
......@@ -142,19 +153,18 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
adj_.sorted = false;
}
CSR::CSR(const std::string &shared_mem_name,
int64_t num_verts, int64_t num_edges): shared_mem_name_(shared_mem_name) {
CSR::CSR(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges)
: shared_mem_name_(shared_mem_name) {
CHECK(!(num_verts == 0 && num_edges != 0));
adj_.num_rows = num_verts;
adj_.num_cols = num_verts;
std::tie(adj_.indptr, adj_.indices, adj_.data) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges, false);
std::tie(adj_.indptr, adj_.indices, adj_.data) =
MapFromSharedMemory(shared_mem_name, num_verts, num_edges, false);
adj_.sorted = false;
}
bool CSR::IsMultigraph() const {
return aten::CSRHasDuplicate(adj_);
}
bool CSR::IsMultigraph() const { return aten::CSRHasDuplicate(adj_); }
EdgeArray CSR::OutEdges(dgl_id_t vid) const {
CHECK(HasVertex(vid)) << "invalid vertex: " << vid;
......@@ -204,7 +214,7 @@ IdArray CSR::EdgeId(dgl_id_t src, dgl_id_t dst) const {
}
EdgeArray CSR::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
const auto& arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
const auto &arrs = aten::CSRGetDataAndIndices(adj_, src_ids, dst_ids);
return EdgeArray{arrs[0], arrs[1], arrs[2]};
}
......@@ -212,14 +222,15 @@ EdgeArray CSR::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("srcdst"))
<< "CSR only support Edges of order \"srcdst\","
<< " but got \"" << order << "\".";
const auto& coo = aten::CSRToCOO(adj_, false);
const auto &coo = aten::CSRToCOO(adj_, false);
return EdgeArray{coo.row, coo.col, coo.data};
}
Subgraph CSR::VertexSubgraph(IdArray vids) const {
CHECK(aten::IsValidIdArray(vids)) << "Invalid vertex id array.";
const auto& submat = aten::CSRSliceMatrix(adj_, vids, vids);
IdArray sub_eids = aten::Range(0, submat.data->shape[0], NumBits(), Context());
const auto &submat = aten::CSRSliceMatrix(adj_, vids, vids);
IdArray sub_eids =
aten::Range(0, submat.data->shape[0], NumBits(), Context());
CSRPtr subcsr(new CSR(submat.indptr, submat.indices, sub_eids));
subcsr->adj_.sorted = this->adj_.sorted;
Subgraph subg;
......@@ -230,21 +241,21 @@ Subgraph CSR::VertexSubgraph(IdArray vids) const {
}
CSRPtr CSR::Transpose() const {
const auto& trans = aten::CSRTranspose(adj_);
const auto &trans = aten::CSRTranspose(adj_);
return CSRPtr(new CSR(trans.indptr, trans.indices, trans.data));
}
COOPtr CSR::ToCOO() const {
const auto& coo = aten::CSRToCOO(adj_, true);
const auto &coo = aten::CSRToCOO(adj_, true);
return COOPtr(new COO(NumVertices(), coo.row, coo.col));
}
CSR CSR::CopyTo(const DGLContext& ctx) const {
CSR CSR::CopyTo(const DGLContext &ctx) const {
if (Context() == ctx) {
return *this;
} else {
CSR ret(adj_.indptr.CopyTo(ctx),
adj_.indices.CopyTo(ctx),
CSR ret(
adj_.indptr.CopyTo(ctx), adj_.indices.CopyTo(ctx),
adj_.data.CopyTo(ctx));
return ret;
}
......@@ -264,8 +275,8 @@ CSR CSR::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) {
return *this;
} else {
CSR ret(aten::AsNumBits(adj_.indptr, bits),
aten::AsNumBits(adj_.indices, bits),
CSR ret(
aten::AsNumBits(adj_.indptr, bits), aten::AsNumBits(adj_.indices, bits),
aten::AsNumBits(adj_.data, bits));
return ret;
}
......@@ -274,8 +285,8 @@ CSR CSR::AsNumBits(uint8_t bits) const {
DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* indices_data = static_cast<dgl_id_t*>(adj_.indices->data);
const dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);
const dgl_id_t *indices_data = static_cast<dgl_id_t *>(adj_.indices->data);
const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(indices_data + start, indices_data + end);
......@@ -284,29 +295,28 @@ DGLIdIters CSR::SuccVec(dgl_id_t vid) const {
DGLIdIters CSR::OutEdgeVec(dgl_id_t vid) const {
// TODO(minjie): This still assumes the data type and device context
// of this graph. Should fix later.
const dgl_id_t* indptr_data = static_cast<dgl_id_t*>(adj_.indptr->data);
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(adj_.data->data);
const dgl_id_t *indptr_data = static_cast<dgl_id_t *>(adj_.indptr->data);
const dgl_id_t *eid_data = static_cast<dgl_id_t *>(adj_.data->data);
const dgl_id_t start = indptr_data[vid];
const dgl_id_t end = indptr_data[vid + 1];
return DGLIdIters(eid_data + start, eid_data + end);
}
bool CSR::Load(dmlc::Stream *fs) {
fs->Read(const_cast<dgl::aten::CSRMatrix*>(&adj_));
fs->Read(const_cast<dgl::aten::CSRMatrix *>(&adj_));
return true;
}
void CSR::Save(dmlc::Stream *fs) const {
fs->Write(adj_);
}
void CSR::Save(dmlc::Stream *fs) const { fs->Write(adj_); }
//////////////////////////////////////////////////////////
//
// COO graph implementation
//
//////////////////////////////////////////////////////////
COO::COO(int64_t num_vertices, IdArray src, IdArray dst,
bool row_sorted, bool col_sorted) {
COO::COO(
int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted,
bool col_sorted) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]);
......@@ -314,9 +324,7 @@ COO::COO(int64_t num_vertices, IdArray src, IdArray dst,
aten::NullArray(), row_sorted, col_sorted};
}
bool COO::IsMultigraph() const {
return aten::COOHasDuplicate(adj_);
}
bool COO::IsMultigraph() const { return aten::COOHasDuplicate(adj_); }
std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
CHECK(eid < NumEdges()) << "Invalid edge id: " << eid;
......@@ -327,17 +335,17 @@ std::pair<dgl_id_t, dgl_id_t> COO::FindEdge(dgl_id_t eid) const {
EdgeArray COO::FindEdges(IdArray eids) const {
CHECK(aten::IsValidIdArray(eids)) << "Invalid edge id array";
BUG_IF_FAIL(aten::IsNullArray(adj_.data)) <<
"FindEdges requires the internal COO matrix not having EIDs.";
return EdgeArray{aten::IndexSelect(adj_.row, eids),
aten::IndexSelect(adj_.col, eids),
BUG_IF_FAIL(aten::IsNullArray(adj_.data))
<< "FindEdges requires the internal COO matrix not having EIDs.";
return EdgeArray{
aten::IndexSelect(adj_.row, eids), aten::IndexSelect(adj_.col, eids),
eids};
}
EdgeArray COO::Edges(const std::string &order) const {
CHECK(order.empty() || order == std::string("eid"))
<< "COO only support Edges of order \"eid\", but got \""
<< order << "\".";
<< "COO only support Edges of order \"eid\", but got \"" << order
<< "\".";
IdArray rst_eid = aten::Range(0, NumEdges(), NumBits(), Context());
return EdgeArray{adj_.row, adj_.col, rst_eid};
}
......@@ -366,17 +374,15 @@ Subgraph COO::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
}
CSRPtr COO::ToCSR() const {
const auto& csr = aten::COOToCSR(adj_);
const auto &csr = aten::COOToCSR(adj_);
return CSRPtr(new CSR(csr.indptr, csr.indices, csr.data));
}
COO COO::CopyTo(const DGLContext& ctx) const {
COO COO::CopyTo(const DGLContext &ctx) const {
if (Context() == ctx) {
return *this;
} else {
COO ret(NumVertices(),
adj_.row.CopyTo(ctx),
adj_.col.CopyTo(ctx));
COO ret(NumVertices(), adj_.row.CopyTo(ctx), adj_.col.CopyTo(ctx));
return ret;
}
}
......@@ -390,8 +396,8 @@ COO COO::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) {
return *this;
} else {
COO ret(NumVertices(),
aten::AsNumBits(adj_.row, bits),
COO ret(
NumVertices(), aten::AsNumBits(adj_.row, bits),
aten::AsNumBits(adj_.col, bits));
return ret;
}
......@@ -411,13 +417,14 @@ BoolArray ImmutableGraph::HasVertices(IdArray vids) const {
CSRPtr ImmutableGraph::GetInCSR() const {
if (!in_csr_) {
if (out_csr_) {
const_cast<ImmutableGraph*>(this)->in_csr_ = out_csr_->Transpose();
const_cast<ImmutableGraph *>(this)->in_csr_ = out_csr_->Transpose();
if (out_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
LOG(WARNING)
<< "We just construct an in-CSR from a shared-memory out CSR. "
<< "It may dramatically increase memory consumption.";
} else {
CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->in_csr_ = coo_->Transpose()->ToCSR();
const_cast<ImmutableGraph *>(this)->in_csr_ = coo_->Transpose()->ToCSR();
}
}
return in_csr_;
......@@ -427,13 +434,14 @@ CSRPtr ImmutableGraph::GetInCSR() const {
CSRPtr ImmutableGraph::GetOutCSR() const {
if (!out_csr_) {
if (in_csr_) {
const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose();
const_cast<ImmutableGraph *>(this)->out_csr_ = in_csr_->Transpose();
if (in_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
LOG(WARNING)
<< "We just construct an out-CSR from a shared-memory in CSR. "
<< "It may dramatically increase memory consumption.";
} else {
CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR();
const_cast<ImmutableGraph *>(this)->out_csr_ = coo_->ToCSR();
}
}
return out_csr_;
......@@ -443,10 +451,10 @@ CSRPtr ImmutableGraph::GetOutCSR() const {
COOPtr ImmutableGraph::GetCOO() const {
if (!coo_) {
if (in_csr_) {
const_cast<ImmutableGraph*>(this)->coo_ = in_csr_->ToCOO()->Transpose();
const_cast<ImmutableGraph *>(this)->coo_ = in_csr_->ToCOO()->Transpose();
} else {
CHECK(out_csr_) << "Both CSR are missing.";
const_cast<ImmutableGraph*>(this)->coo_ = out_csr_->ToCOO();
const_cast<ImmutableGraph *>(this)->coo_ = out_csr_->ToCOO();
}
}
return coo_;
......@@ -457,7 +465,7 @@ EdgeArray ImmutableGraph::Edges(const std::string &order) const {
// arbitrary order
if (in_csr_) {
// transpose
const auto& edges = in_csr_->Edges(order);
const auto &edges = in_csr_->Edges(order);
return EdgeArray{edges.dst, edges.src, edges.id};
} else {
return AnyGraph()->Edges(order);
......@@ -489,16 +497,19 @@ Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids, bool preserve_nodes) const {
return sg;
}
std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &fmt) const {
// TODO(minjie): Our current semantics of adjacency matrix is row for dst nodes and col for
// src nodes. Therefore, we need to flip the transpose flag. For example, transpose=False
// is equal to in edge CSR.
// We have this behavior because previously we use framework's SPMM and we don't cache
// reverse adj. This is not intuitive and also not consistent with networkx's
// to_scipy_sparse_matrix. With the upcoming custom kernel change, we should change the
// behavior and make row for src and col for dst.
std::vector<IdArray> ImmutableGraph::GetAdj(
bool transpose, const std::string &fmt) const {
// TODO(minjie): Our current semantics of adjacency matrix is row for dst
// nodes and col for
// src nodes. Therefore, we need to flip the transpose flag. For example,
// transpose=False is equal to in edge CSR. We have this behavior because
// previously we use framework's SPMM and we don't cache reverse adj. This
// is not intuitive and also not consistent with networkx's
// to_scipy_sparse_matrix. With the upcoming custom kernel change, we should
// change the behavior and make row for src and col for dst.
if (fmt == std::string("csr")) {
return transpose? GetOutCSR()->GetAdj(false, "csr") : GetInCSR()->GetAdj(false, "csr");
return transpose ? GetOutCSR()->GetAdj(false, "csr")
: GetInCSR()->GetAdj(false, "csr");
} else if (fmt == std::string("coo")) {
return GetCOO()->GetAdj(!transpose, fmt);
} else {
......@@ -508,7 +519,8 @@ std::vector<IdArray> ImmutableGraph::GetAdj(bool transpose, const std::string &f
}
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir) {
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr));
......@@ -530,17 +542,19 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR(const std::string &name) {
GraphIndexMetadata meta = DeserializeMetadata(GetSharedMemName(name, "meta"));
CSRPtr in_csr, out_csr;
if (meta.has_in_csr) {
in_csr = CSRPtr(new CSR(GetSharedMemName(name, "in"), meta.num_nodes, meta.num_edges));
in_csr = CSRPtr(
new CSR(GetSharedMemName(name, "in"), meta.num_nodes, meta.num_edges));
}
if (meta.has_out_csr) {
out_csr = CSRPtr(new CSR(GetSharedMemName(name, "out"), meta.num_nodes, meta.num_edges));
out_csr = CSRPtr(
new CSR(GetSharedMemName(name, "out"), meta.num_nodes, meta.num_edges));
}
return ImmutableGraphPtr(new ImmutableGraph(in_csr, out_csr, name));
}
ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
int64_t num_vertices, IdArray src, IdArray dst,
bool row_sorted, bool col_sorted) {
int64_t num_vertices, IdArray src, IdArray dst, bool row_sorted,
bool col_sorted) {
COOPtr coo(new COO(num_vertices, src, dst, row_sorted, col_sorted));
return std::make_shared<ImmutableGraph>(coo);
}
......@@ -550,13 +564,14 @@ ImmutableGraphPtr ImmutableGraph::ToImmutable(GraphPtr graph) {
if (ig) {
return ig;
} else {
const auto& adj = graph->GetAdj(true, "csr");
const auto &adj = graph->GetAdj(true, "csr");
CSRPtr csr(new CSR(adj[0], adj[1], adj[2]));
return ImmutableGraph::CreateFromCSR(adj[0], adj[1], adj[2], "out");
}
}
ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DGLContext& ctx) {
ImmutableGraphPtr ImmutableGraph::CopyTo(
ImmutableGraphPtr g, const DGLContext &ctx) {
if (ctx == g->Context()) {
return g;
}
......@@ -569,16 +584,20 @@ ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DGLContext&
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
}
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g, const std::string &name) {
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(
ImmutableGraphPtr g, const std::string &name) {
CSRPtr new_incsr, new_outcsr;
std::string shared_mem_name = GetSharedMemName(name, "in");
new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
shared_mem_name = GetSharedMemName(name, "out");
new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
new_outcsr =
CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
auto new_g = ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
new_g->serialized_shared_meta_ = SerializeMetadata(new_g, GetSharedMemName(name, "meta"));
auto new_g =
ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
new_g->serialized_shared_meta_ =
SerializeMetadata(new_g, GetSharedMemName(name, "meta"));
return new_g;
}
......@@ -598,8 +617,8 @@ ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
ImmutableGraphPtr ImmutableGraph::Reverse() const {
if (coo_) {
return ImmutableGraphPtr(new ImmutableGraph(
out_csr_, in_csr_, coo_->Transpose()));
return ImmutableGraphPtr(
new ImmutableGraph(out_csr_, in_csr_, coo_->Transpose()));
} else {
return ImmutableGraphPtr(new ImmutableGraph(out_csr_, in_csr_));
}
......@@ -628,54 +647,53 @@ HeteroGraphPtr ImmutableGraph::AsHeteroGraph() const {
aten::CSRMatrix in_csr, out_csr;
aten::COOMatrix coo;
if (in_csr_)
in_csr = GetInCSR()->ToCSRMatrix();
if (out_csr_)
out_csr = GetOutCSR()->ToCSRMatrix();
if (coo_)
coo = GetCOO()->ToCOOMatrix();
if (in_csr_) in_csr = GetInCSR()->ToCSRMatrix();
if (out_csr_) out_csr = GetOutCSR()->ToCSRMatrix();
if (coo_) coo = GetCOO()->ToCOOMatrix();
auto g = UnitGraph::CreateUnitGraphFrom(
1, in_csr, out_csr, coo,
in_csr_ != nullptr,
out_csr_ != nullptr,
1, in_csr, out_csr, coo, in_csr_ != nullptr, out_csr_ != nullptr,
coo_ != nullptr);
return HeteroGraphPtr(new HeteroGraph(g->meta_graph(), {g}));
}
DGL_REGISTER_GLOBAL("transform._CAPI_DGLAsHeteroGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0];
ImmutableGraphPtr ig = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
ImmutableGraphPtr ig =
std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ig) << "graph is not readonly";
*rv = HeteroGraphRef(ig->AsHeteroGraph());
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0];
const int device_type = args[1];
const int device_id = args[2];
DGLContext ctx;
ctx.device_type = static_cast<DGLDeviceType>(device_type);
ctx.device_id = device_id;
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
ImmutableGraphPtr ig =
CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyTo(ig, ctx);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0];
std::string name = args[1];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
ImmutableGraphPtr ig =
CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyToSharedMem(ig, name);
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0];
int bits = args[1];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
ImmutableGraphPtr ig =
CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::AsNumBits(ig, bits);
});
......
......@@ -4,9 +4,10 @@
* \brief Call Metis partitioning
*/
#include <metis.h>
#include <dgl/graph_op.h>
#include <dgl/packed_func_ext.h>
#include <metis.h>
#include "../c_api_common.h"
using namespace dgl::runtime;
......@@ -25,12 +26,12 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
idx_t nvtxs = g->NumVertices();
idx_t ncon = 1; // # balacing constraints.
idx_t *xadj = static_cast<idx_t*>(mat.indptr->data);
idx_t *adjncy = static_cast<idx_t*>(mat.indices->data);
idx_t *xadj = static_cast<idx_t *>(mat.indptr->data);
idx_t *adjncy = static_cast<idx_t *>(mat.indices->data);
idx_t nparts = k;
IdArray part_arr = aten::NewIdArray(nvtxs);
idx_t objval = 0;
idx_t *part = static_cast<idx_t*>(part_arr->data);
idx_t *part = static_cast<idx_t *>(part_arr->data);
int64_t vwgt_len = vwgt_arr->shape[0];
CHECK_EQ(sizeof(idx_t), vwgt_arr->dtype.bits / 8)
......@@ -40,7 +41,7 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
idx_t *vwgt = NULL;
if (vwgt_len > 0) {
ncon = vwgt_len / g->NumVertices();
vwgt = static_cast<idx_t*>(vwgt_arr->data);
vwgt = static_cast<idx_t *>(vwgt_arr->data);
}
idx_t options[METIS_NOPTIONS];
......@@ -56,7 +57,8 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
options[METIS_OPTION_OBJTYPE] = METIS_OBJTYPE_VOL;
}
int ret = METIS_PartGraphKway(&nvtxs, // The number of vertices
int ret = METIS_PartGraphKway(
&nvtxs, // The number of vertices
&ncon, // The number of balancing constraints.
xadj, // indptr
adjncy, // indices
......@@ -99,7 +101,7 @@ IdArray MetisPartition(GraphPtr g, int k, NDArray vwgt_arr, bool obj_cut) {
#endif // !defined(_WIN32)
DGL_REGISTER_GLOBAL("transform._CAPI_DGLMetisPartition")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0];
int k = args[1];
NDArray vwgt = args[2];
......
......@@ -5,21 +5,20 @@
*/
#include "./network.h"
#include <stdlib.h>
#include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/packed_func_ext.h>
#include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h>
#include <stdlib.h>
#include <unordered_map>
#include "../rpc/network/common.h"
#include "../rpc/network/communicator.h"
#include "../rpc/network/socket_communicator.h"
#include "../rpc/network/msg_queue.h"
#include "../rpc/network/common.h"
#include "../rpc/network/socket_communicator.h"
using dgl::network::StringPrintf;
using namespace dgl::runtime;
......@@ -29,11 +28,8 @@ const bool AUTO_FREE = true;
namespace dgl {
namespace network {
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
DGLDataType dtype,
DGLContext ctx,
void* raw,
NDArray CreateNDArrayFromRaw(
std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx, void* raw,
bool auto_free) {
return NDArray::CreateFromRaw(shape, dtype, ctx, raw, auto_free);
}
......@@ -74,16 +70,16 @@ char* ArrayMeta::Serialize(int64_t* size) {
*(reinterpret_cast<int*>(pointer)) = ndarray_count_;
pointer += sizeof(ndarray_count_);
// Write data type
memcpy(pointer,
reinterpret_cast<DGLDataType*>(data_type_.data()),
memcpy(
pointer, reinterpret_cast<DGLDataType*>(data_type_.data()),
sizeof(DGLDataType) * data_type_.size());
pointer += (sizeof(DGLDataType) * data_type_.size());
// Write size of data_shape_
*(reinterpret_cast<size_t*>(pointer)) = data_shape_.size();
pointer += sizeof(data_shape_.size());
// Write data of data_shape_
memcpy(pointer,
reinterpret_cast<char*>(data_shape_.data()),
memcpy(
pointer, reinterpret_cast<char*>(data_shape_.data()),
sizeof(int64_t) * data_shape_.size());
}
*size = buffer_size;
......@@ -103,8 +99,7 @@ void ArrayMeta::Deserialize(char* buffer, int64_t size) {
data_size += sizeof(int);
// Read data type
data_type_.resize(ndarray_count_);
memcpy(data_type_.data(), buffer,
ndarray_count_ * sizeof(DGLDataType));
memcpy(data_type_.data(), buffer, ndarray_count_ * sizeof(DGLDataType));
buffer += ndarray_count_ * sizeof(DGLDataType);
data_size += ndarray_count_ * sizeof(DGLDataType);
// Read size of data_shape_
......@@ -113,8 +108,7 @@ void ArrayMeta::Deserialize(char* buffer, int64_t size) {
data_size += sizeof(size_t);
data_shape_.resize(count);
// Read data of data_shape_
memcpy(data_shape_.data(), buffer,
count * sizeof(int64_t));
memcpy(data_shape_.data(), buffer, count * sizeof(int64_t));
data_size += count * sizeof(int64_t);
}
CHECK_EQ(data_size, size);
......@@ -170,11 +164,11 @@ void KVStoreMsg::Deserialize(char* buffer, int64_t size) {
CHECK_EQ(data_size, size);
}
////////////////////////////////// Basic Networking Components ////////////////////////////////
////////////////////////////////// Basic Networking Components
///////////////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string type = args[0];
int64_t msg_queue_size = args[1];
network::Sender* sender = nullptr;
......@@ -188,7 +182,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string type = args[0];
int64_t msg_queue_size = args[1];
network::Receiver* receiver = nullptr;
......@@ -202,21 +196,22 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle);
sender->Finalize();
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle);
receiver->Finalize();
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
std::string ip = args[1];
int port = args[2];
......@@ -232,7 +227,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle);
const int max_try_times = 1024;
......@@ -242,12 +237,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
std::string ip = args[1];
int port = args[2];
int num_sender = args[3];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle);
std::string addr;
if (receiver->NetType() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
......@@ -259,12 +255,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
}
});
////////////////////////// Distributed Sampler Components ////////////////////////////////
////////////////////////// Distributed Sampler Components
///////////////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
GraphRef g = args[2];
......@@ -348,7 +343,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow")
});
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSamplerEndSignal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
ArrayMeta meta(kFinalMsg);
......@@ -361,9 +356,10 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSamplerEndSignal")
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle);
int send_id = 0;
Message recv_msg;
CHECK_EQ(receiver->Recv(&recv_msg, &send_id), REMOVE_SUCCESS);
......@@ -377,71 +373,50 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
CHECK_EQ(receiver->RecvFrom(&array_0, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1);
nf->node_mapping = CreateNDArrayFromRaw(
{meta.data_shape_[1]},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0},
array_0.data,
AUTO_FREE);
{meta.data_shape_[1]}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}, array_0.data, AUTO_FREE);
// edge_mapping
Message array_1;
CHECK_EQ(receiver->RecvFrom(&array_1, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[2], 1);
nf->edge_mapping = CreateNDArrayFromRaw(
{meta.data_shape_[3]},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0},
array_1.data,
AUTO_FREE);
{meta.data_shape_[3]}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}, array_1.data, AUTO_FREE);
// layer_offset
Message array_2;
CHECK_EQ(receiver->RecvFrom(&array_2, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[4], 1);
nf->layer_offsets = CreateNDArrayFromRaw(
{meta.data_shape_[5]},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0},
array_2.data,
AUTO_FREE);
{meta.data_shape_[5]}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}, array_2.data, AUTO_FREE);
// flow_offset
Message array_3;
CHECK_EQ(receiver->RecvFrom(&array_3, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[6], 1);
nf->flow_offsets = CreateNDArrayFromRaw(
{meta.data_shape_[7]},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0},
array_3.data,
AUTO_FREE);
{meta.data_shape_[7]}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}, array_3.data, AUTO_FREE);
// CSR indptr
Message array_4;
CHECK_EQ(receiver->RecvFrom(&array_4, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[8], 1);
NDArray indptr = CreateNDArrayFromRaw(
{meta.data_shape_[9]},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0},
array_4.data,
AUTO_FREE);
{meta.data_shape_[9]}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}, array_4.data, AUTO_FREE);
// CSR indice
Message array_5;
CHECK_EQ(receiver->RecvFrom(&array_5, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[10], 1);
NDArray indice = CreateNDArrayFromRaw(
{meta.data_shape_[11]},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0},
array_5.data,
AUTO_FREE);
{meta.data_shape_[11]}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}, array_5.data, AUTO_FREE);
// CSR edge_ids
Message array_6;
CHECK_EQ(receiver->RecvFrom(&array_6, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[12], 1);
NDArray edge_ids = CreateNDArrayFromRaw(
{meta.data_shape_[13]},
DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0},
array_6.data,
AUTO_FREE);
{meta.data_shape_[13]}, DGLDataType{kDGLInt, 64, 1},
DGLContext{kDGLCPU, 0}, array_6.data, AUTO_FREE);
// Create CSR
CSRPtr csr(new CSR(indptr, indice, edge_ids));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
......@@ -453,14 +428,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
}
});
////////////////////////// Distributed KVStore Components
///////////////////////////////////
////////////////////////// Distributed KVStore Components ////////////////////////////////
static void send_kv_message(network::Sender* sender,
KVStoreMsg* kv_msg,
int recv_id,
bool auto_free) {
static void send_kv_message(
network::Sender* sender, KVStoreMsg* kv_msg, int recv_id, bool auto_free) {
int64_t kv_size = 0;
char* kv_data = kv_msg->Serialize(&kv_size);
// Send kv_data
......@@ -471,23 +443,18 @@ static void send_kv_message(network::Sender* sender,
send_kv_msg.deallocator = DefaultMessageDeleter;
}
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
if (kv_msg->msg_type != kFinalMsg &&
kv_msg->msg_type != kBarrierMsg &&
kv_msg->msg_type != kIPIDMsg &&
kv_msg->msg_type != kGetShapeMsg) {
if (kv_msg->msg_type != kFinalMsg && kv_msg->msg_type != kBarrierMsg &&
kv_msg->msg_type != kIPIDMsg && kv_msg->msg_type != kGetShapeMsg) {
// Send ArrayMeta
ArrayMeta meta(kv_msg->msg_type);
if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
if (kv_msg->msg_type != kInitMsg && kv_msg->msg_type != kGetShapeBackMsg) {
meta.AddArray(kv_msg->id);
}
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg &&
if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
meta.AddArray(kv_msg->data);
}
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kPushMsg &&
if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) {
meta.AddArray(kv_msg->shape);
}
......@@ -501,8 +468,7 @@ static void send_kv_message(network::Sender* sender,
}
CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
// Send ID NDArray
if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeMsg &&
if (kv_msg->msg_type != kInitMsg && kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message send_id_msg;
send_id_msg.data = static_cast<char*>(kv_msg->id->data);
......@@ -514,8 +480,7 @@ static void send_kv_message(network::Sender* sender,
CHECK_EQ(sender->Send(send_id_msg, recv_id), ADD_SUCCESS);
}
// Send data NDArray
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg &&
if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message send_data_msg;
......@@ -528,8 +493,7 @@ static void send_kv_message(network::Sender* sender,
CHECK_EQ(sender->Send(send_data_msg, recv_id), ADD_SUCCESS);
}
// Send shape NDArray
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kPushMsg &&
if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) {
Message send_shape_msg;
send_shape_msg.data = static_cast<char*>(kv_msg->shape->data);
......@@ -544,17 +508,15 @@ static void send_kv_message(network::Sender* sender,
}
static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
KVStoreMsg *kv_msg = new KVStoreMsg();
KVStoreMsg* kv_msg = new KVStoreMsg();
// Recv kv_Msg
Message recv_kv_msg;
int send_id;
CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS);
kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size);
recv_kv_msg.deallocator(&recv_kv_msg);
if (kv_msg->msg_type == kFinalMsg ||
kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kIPIDMsg ||
kv_msg->msg_type == kGetShapeMsg) {
if (kv_msg->msg_type == kFinalMsg || kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kIPIDMsg || kv_msg->msg_type == kGetShapeMsg) {
return kv_msg;
}
// Recv ArrayMeta
......@@ -563,21 +525,16 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
recv_meta_msg.deallocator(&recv_meta_msg);
// Recv ID NDArray
if (kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
if (kv_msg->msg_type != kInitMsg && kv_msg->msg_type != kGetShapeBackMsg) {
Message recv_id_msg;
CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1);
kv_msg->id = CreateNDArrayFromRaw(
{meta.data_shape_[1]},
meta.data_type_[0],
DGLContext{kDGLCPU, 0},
recv_id_msg.data,
AUTO_FREE);
{meta.data_shape_[1]}, meta.data_type_[0], DGLContext{kDGLCPU, 0},
recv_id_msg.data, AUTO_FREE);
}
// Recv Data NDArray
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kInitMsg &&
if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kInitMsg &&
kv_msg->msg_type != kGetShapeBackMsg) {
Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
......@@ -585,18 +542,14 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
CHECK_GE(ndim, 1);
std::vector<int64_t> vec_shape;
for (int i = 0; i < ndim; ++i) {
vec_shape.push_back(meta.data_shape_[3+i]);
vec_shape.push_back(meta.data_shape_[3 + i]);
}
kv_msg->data = CreateNDArrayFromRaw(
vec_shape,
meta.data_type_[1],
DGLContext{kDGLCPU, 0},
recv_data_msg.data,
AUTO_FREE);
vec_shape, meta.data_type_[1], DGLContext{kDGLCPU, 0},
recv_data_msg.data, AUTO_FREE);
}
// Recv Shape
if (kv_msg->msg_type != kPullMsg &&
kv_msg->msg_type != kPushMsg &&
if (kv_msg->msg_type != kPullMsg && kv_msg->msg_type != kPushMsg &&
kv_msg->msg_type != kPullBackMsg) {
Message recv_shape_msg;
CHECK_EQ(receiver->RecvFrom(&recv_shape_msg, send_id), REMOVE_SUCCESS);
......@@ -604,20 +557,17 @@ static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
CHECK_GE(ndim, 1);
std::vector<int64_t> vec_shape;
for (int i = 0; i < ndim; ++i) {
vec_shape.push_back(meta.data_shape_[1+i]);
vec_shape.push_back(meta.data_shape_[1 + i]);
}
kv_msg->shape = CreateNDArrayFromRaw(
vec_shape,
meta.data_type_[0],
DGLContext{kDGLCPU, 0},
recv_shape_msg.data,
AUTO_FREE);
vec_shape, meta.data_type_[0], DGLContext{kDGLCPU, 0},
recv_shape_msg.data, AUTO_FREE);
}
return kv_msg;
}
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
int args_count = 0;
CommunicatorHandle chandle = args[args_count++];
int recv_id = args[args_count++];
......@@ -628,23 +578,18 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
if (kv_msg.msg_type != kFinalMsg && kv_msg.msg_type != kBarrierMsg) {
std::string name = args[args_count++];
kv_msg.name = name;
if (kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg &&
if (kv_msg.msg_type != kIPIDMsg && kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kGetShapeBackMsg) {
kv_msg.id = args[args_count++];
}
if (kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg &&
kv_msg.msg_type != kGetShapeMsg &&
if (kv_msg.msg_type != kPullMsg && kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kInitMsg && kv_msg.msg_type != kGetShapeMsg &&
kv_msg.msg_type != kGetShapeBackMsg) {
kv_msg.data = args[args_count++];
}
if (kv_msg.msg_type != kIPIDMsg &&
kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kPushMsg &&
kv_msg.msg_type != kPullBackMsg &&
if (kv_msg.msg_type != kIPIDMsg && kv_msg.msg_type != kPullMsg &&
kv_msg.msg_type != kPushMsg && kv_msg.msg_type != kPullBackMsg &&
kv_msg.msg_type != kGetShapeMsg) {
kv_msg.shape = args[args_count++];
}
......@@ -653,63 +598,64 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
});
DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle);
*rv = recv_kv_message(receiver);
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgType")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->msg_type;
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgRank")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->rank;
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgName")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->name;
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->id;
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgData")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->data;
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgShape")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
*rv = msg->shape;
});
DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
KVMsgHandle chandle = args[0];
network::KVStoreMsg* msg = static_cast<KVStoreMsg*>(chandle);
delete msg;
});
DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue* rv) {
std::string name = args[0];
int local_machine_id = args[1];
int machine_count = args[2];
......@@ -722,7 +668,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
CommunicatorHandle chandle_receiver = args[9];
std::string str_flag = args[10];
network::Sender* sender = static_cast<network::Sender*>(chandle_sender);
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle_receiver);
network::Receiver* receiver =
static_cast<network::SocketReceiver*>(chandle_receiver);
int64_t ID_size = ID.GetSize() / sizeof(int64_t);
int64_t* ID_data = static_cast<int64_t*>(ID->data);
int64_t* pb_data = static_cast<int64_t*>(pb->data);
......@@ -785,15 +732,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
kv_msg.msg_type = MessageType::kPullMsg;
kv_msg.rank = client_id;
kv_msg.name = name;
kv_msg.id = CreateNDArrayFromRaw({static_cast<int64_t>(remote_ids[i].size())},
ID->dtype,
DGLContext{kDGLCPU, 0},
remote_ids[i].data(),
!AUTO_FREE);
int lower = i*group_count;
int higher = (i+1)*group_count-1;
kv_msg.id = CreateNDArrayFromRaw(
{static_cast<int64_t>(remote_ids[i].size())}, ID->dtype,
DGLContext{kDGLCPU, 0}, remote_ids[i].data(), !AUTO_FREE);
int lower = i * group_count;
int higher = (i + 1) * group_count - 1;
#ifndef _WIN32 // windows does not support rand_r()
int s_id = (rand_r(&seed) % (higher-lower+1))+lower;
int s_id = (rand_r(&seed) % (higher - lower + 1)) + lower;
send_kv_message(sender, &kv_msg, s_id, !AUTO_FREE);
#else
LOG(FATAL) << "KVStore does not support Windows yet.";
......@@ -801,40 +746,38 @@ DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
msg_count++;
}
}
char *return_data = new char[ID_size*row_size];
char* return_data = new char[ID_size * row_size];
const int64_t local_ids_size = local_ids.size();
// Copy local data
runtime::parallel_for(0, local_ids_size, [&](size_t b, size_t e) {
for (auto i = b; i < e; ++i) {
CHECK_GE(ID_size*row_size, local_ids_orginal[i] * row_size + row_size);
CHECK_GE(
ID_size * row_size, local_ids_orginal[i] * row_size + row_size);
CHECK_GE(data_size, local_ids[i] * row_size + row_size);
CHECK_GE(local_ids[i], 0);
memcpy(return_data + local_ids_orginal[i] * row_size,
local_data_char + local_ids[i] * row_size,
row_size);
memcpy(
return_data + local_ids_orginal[i] * row_size,
local_data_char + local_ids[i] * row_size, row_size);
}
});
// Recv remote message
for (int i = 0; i < msg_count; ++i) {
KVStoreMsg *kv_msg = recv_kv_message(receiver);
KVStoreMsg* kv_msg = recv_kv_message(receiver);
int64_t id_size = kv_msg->id.GetSize() / sizeof(int64_t);
int part_id = kv_msg->rank / group_count;
char* data_char = static_cast<char*>(kv_msg->data->data);
for (int64_t n = 0; n < id_size; ++n) {
memcpy(return_data + remote_ids_original[part_id][n] * row_size,
data_char + n * row_size,
row_size);
memcpy(
return_data + remote_ids_original[part_id][n] * row_size,
data_char + n * row_size, row_size);
}
delete kv_msg;
}
// Get final tensor
local_data_shape[0] = ID_size;
NDArray res_tensor = CreateNDArrayFromRaw(
local_data_shape,
local_data->dtype,
DGLContext{kDGLCPU, 0},
return_data,
AUTO_FREE);
local_data_shape, local_data->dtype, DGLContext{kDGLCPU, 0},
return_data, AUTO_FREE);
*rv = res_tensor;
});
......
......@@ -6,12 +6,12 @@
#ifndef DGL_GRAPH_NETWORK_H_
#define DGL_GRAPH_NETWORK_H_
#include <dmlc/logging.h>
#include <dgl/runtime/ndarray.h>
#include <dmlc/logging.h>
#include <string.h>
#include <vector>
#include <string>
#include <vector>
#include "../c_api_common.h"
#include "../rpc/network/msg_queue.h"
......@@ -24,10 +24,8 @@ namespace network {
/*!
* \brief Create NDArray from raw data
*/
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
DGLDataType dtype,
DGLContext ctx,
void* raw);
NDArray CreateNDArrayFromRaw(
std::vector<int64_t> shape, DGLDataType dtype, DGLContext ctx, void* raw);
/*!
* \brief Message type for DGL distributed training
......@@ -75,7 +73,6 @@ enum MessageType {
kGetShapeBackMsg = 9
};
/*!
* \brief Meta data for NDArray message
*/
......@@ -85,8 +82,7 @@ class ArrayMeta {
* \brief ArrayMeta constructor.
* \param msg_type type of message
*/
explicit ArrayMeta(int msg_type)
: msg_type_(msg_type), ndarray_count_(0) {}
explicit ArrayMeta(int msg_type) : msg_type_(msg_type), ndarray_count_(0) {}
/*!
* \brief Construct ArrayMeta from binary data buffer.
......@@ -101,16 +97,12 @@ class ArrayMeta {
/*!
* \return message type
*/
inline int msg_type() const {
return msg_type_;
}
inline int msg_type() const { return msg_type_; }
/*!
* \return count of ndarray
*/
inline int ndarray_count() const {
return ndarray_count_;
}
inline int ndarray_count() const { return ndarray_count_; }
/*!
* \brief Add NDArray meta data to ArrayMeta
......
......@@ -5,8 +5,8 @@
*/
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/nodeflow.h>
#include <dgl/packed_func_ext.h>
#include <string>
......@@ -19,15 +19,16 @@ using dgl::runtime::PackedFunc;
namespace dgl {
std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::string &fmt,
size_t layer0_size, size_t layer1_start,
size_t layer1_end, bool remap) {
std::vector<IdArray> GetNodeFlowSlice(
const ImmutableGraph &graph, const std::string &fmt, size_t layer0_size,
size_t layer1_start, size_t layer1_end, bool remap) {
CHECK_GE(layer1_start, layer0_size);
if (fmt == std::string("csr")) {
dgl_id_t first_vid = layer1_start - layer0_size;
auto csr = aten::CSRSliceRows(graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);
auto csr = aten::CSRSliceRows(
graph.GetInCSR()->ToCSRMatrix(), layer1_start, layer1_end);
if (remap) {
dgl_id_t *eid_data = static_cast<dgl_id_t*>(csr.data->data);
dgl_id_t *eid_data = static_cast<dgl_id_t *>(csr.data->data);
const dgl_id_t first_eid = eid_data[0];
IdArray new_indices = aten::Sub(csr.indices, first_vid);
IdArray new_data = aten::Sub(csr.data, first_eid);
......@@ -37,14 +38,14 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
}
} else if (fmt == std::string("coo")) {
auto csr = graph.GetInCSR()->ToCSRMatrix();
const dgl_id_t* indptr = static_cast<dgl_id_t*>(csr.indptr->data);
const dgl_id_t* indices = static_cast<dgl_id_t*>(csr.indices->data);
const dgl_id_t* edge_ids = static_cast<dgl_id_t*>(csr.data->data);
const dgl_id_t *indptr = static_cast<dgl_id_t *>(csr.indptr->data);
const dgl_id_t *indices = static_cast<dgl_id_t *>(csr.indices->data);
const dgl_id_t *edge_ids = static_cast<dgl_id_t *>(csr.data->data);
int64_t nnz = indptr[layer1_end] - indptr[layer1_start];
IdArray idx = aten::NewIdArray(2 * nnz);
IdArray eid = aten::NewIdArray(nnz);
int64_t *idx_data = static_cast<int64_t*>(idx->data);
dgl_id_t *eid_data = static_cast<dgl_id_t*>(eid->data);
int64_t *idx_data = static_cast<int64_t *>(idx->data);
dgl_id_t *eid_data = static_cast<dgl_id_t *>(eid->data);
size_t num_edges = 0;
for (size_t i = layer1_start; i < layer1_end; i++) {
for (dgl_id_t j = indptr[i]; j < indptr[i + 1]; j++) {
......@@ -65,10 +66,12 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
eid_data[i] = edge_ids[edge_start + i] - first_eid;
}
} else {
std::copy(indices + indptr[layer1_start],
indices + indptr[layer1_end], idx_data + nnz);
std::copy(edge_ids + indptr[layer1_start],
edge_ids + indptr[layer1_end], eid_data);
std::copy(
indices + indptr[layer1_start], indices + indptr[layer1_end],
idx_data + nnz);
std::copy(
edge_ids + indptr[layer1_start], edge_ids + indptr[layer1_end],
eid_data);
}
return std::vector<IdArray>{idx, eid};
} else {
......@@ -78,14 +81,15 @@ std::vector<IdArray> GetNodeFlowSlice(const ImmutableGraph &graph, const std::st
}
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockAdj")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef g = args[0];
std::string format = args[1];
int64_t layer0_size = args[2];
int64_t start = args[3];
int64_t end = args[4];
const bool remap = args[5];
auto ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
auto ig =
CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
auto res = GetNodeFlowSlice(*ig, format, layer0_size, start, end, remap);
*rv = ConvertNDArrayVectorToPackedFunc(res);
});
......
......@@ -3,13 +3,14 @@
* \file graph/pickle.cc
* \brief Functions for pickle and unpickle a graph
*/
#include <dgl/graph_serializer.h>
#include <dgl/immutable_graph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/immutable_graph.h>
#include <dgl/graph_serializer.h>
#include <dmlc/memory_io.h>
#include "./heterograph.h"
#include "../c_api_common.h"
#include "./heterograph.h"
#include "unit_graph.h"
using namespace dgl::runtime;
......@@ -91,7 +92,7 @@ HeteroPickleStates HeteroForkingPickle(HeteroGraphPtr graph) {
return states;
}
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates &states) {
char *buf = const_cast<char *>(states.meta.c_str()); // a readonly stream?
dmlc::MemoryFixedSizeStream ifs(buf, states.meta.size());
dmlc::Stream *strm = &ifs;
......@@ -108,10 +109,10 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
auto array_itr = states.arrays.begin();
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
const auto& pair = metagraph->FindEdge(etype);
const auto &pair = metagraph->FindEdge(etype);
const dgl_type_t srctype = pair.first;
const dgl_type_t dsttype = pair.second;
const int64_t num_vtypes = (srctype == dsttype)? 1 : 2;
const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
int64_t num_src = num_nodes_per_type[srctype];
int64_t num_dst = num_nodes_per_type[dsttype];
SparseFormat fmt;
......@@ -126,7 +127,8 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
bool csorted;
CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
auto coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
auto coo = aten::COOMatrix(
num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
// TODO(zihao) fix
relgraph = CreateFromCOO(num_vtypes, coo, ALL_CODE);
break;
......@@ -138,7 +140,8 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
const auto &edge_id = *(array_itr++);
bool sorted;
CHECK(strm->Read(&sorted)) << "Invalid flag 'sorted'";
auto csr = aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
auto csr =
aten::CSRMatrix(num_src, num_dst, indptr, indices, edge_id, sorted);
// TODO(zihao) fix
relgraph = CreateFromCSR(num_vtypes, csr, ALL_CODE);
break;
......@@ -157,17 +160,18 @@ HeteroGraphPtr HeteroUnpickle(const HeteroPickleStates& states) {
}
// For backward compatibility
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates& states) {
HeteroGraphPtr HeteroUnpickleOld(const HeteroPickleStates &states) {
const auto metagraph = states.metagraph;
const auto &num_nodes_per_type = states.num_nodes_per_type;
CHECK_EQ(states.adjs.size(), metagraph->NumEdges());
std::vector<HeteroGraphPtr> relgraphs(metagraph->NumEdges());
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
const auto& pair = metagraph->FindEdge(etype);
const auto &pair = metagraph->FindEdge(etype);
const dgl_type_t srctype = pair.first;
const dgl_type_t dsttype = pair.second;
const int64_t num_vtypes = (srctype == dsttype)? 1 : 2;
const SparseFormat fmt = static_cast<SparseFormat>(states.adjs[etype]->format);
const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
const SparseFormat fmt =
static_cast<SparseFormat>(states.adjs[etype]->format);
switch (fmt) {
case SparseFormat::kCOO:
relgraphs[etype] = UnitGraph::CreateFromCOO(
......@@ -202,7 +206,7 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
auto array_itr = states.arrays.begin();
for (dgl_type_t etype = 0; etype < metagraph->NumEdges(); ++etype) {
const auto& pair = metagraph->FindEdge(etype);
const auto &pair = metagraph->FindEdge(etype);
const dgl_type_t srctype = pair.first;
const dgl_type_t dsttype = pair.second;
const int64_t num_vtypes = (srctype == dsttype) ? 1 : 2;
......@@ -227,7 +231,8 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
bool csorted;
CHECK(strm->Read(&rsorted)) << "Invalid flag 'rsorted'";
CHECK(strm->Read(&csorted)) << "Invalid flag 'csorted'";
coo = aten::COOMatrix(num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
coo = aten::COOMatrix(
num_src, num_dst, row, col, aten::NullArray(), rsorted, csorted);
}
if (created_formats & CSR_CODE) {
CHECK_GE(states.arrays.end() - array_itr, 3);
......@@ -258,13 +263,13 @@ HeteroGraphPtr HeteroForkingUnpickle(const HeteroPickleStates &states) {
}
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetVersion")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef st = args[0];
*rv = st->version;
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef st = args[0];
DGLByteArray buf;
buf.data = st->meta.c_str();
......@@ -273,50 +278,50 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetMeta")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArrays")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef st = args[0];
*rv = ConvertNDArrayVectorToPackedFunc(st->arrays);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickleStatesGetArraysNum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef st = args[0];
*rv = static_cast<int64_t>(st->arrays.size());
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStates")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
const int version = args[0];
std::string meta = args[1];
const List<Value> arrays = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);
st->version = version == 0 ? 1 : version;
st->meta = meta;
st->arrays.reserve(arrays.size());
for (const auto& ref : arrays) {
for (const auto &ref : arrays) {
st->arrays.push_back(ref->data);
}
*rv = HeteroPickleStatesRef(st);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroPickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef ref = args[0];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);
*st = HeteroPickle(ref.sptr());
*rv = HeteroPickleStatesRef(st);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingPickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef ref = args[0];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);
*st = HeteroForkingPickle(ref.sptr());
*rv = HeteroPickleStatesRef(st);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef ref = args[0];
HeteroGraphPtr graph;
switch (ref->version) {
......@@ -334,24 +339,23 @@ DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroUnpickle")
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLHeteroForkingUnpickle")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroPickleStatesRef ref = args[0];
HeteroGraphPtr graph = HeteroForkingUnpickle(*ref.sptr());
*rv = HeteroGraphRef(graph);
});
DGL_REGISTER_GLOBAL("heterograph_index._CAPI_DGLCreateHeteroPickleStatesOld")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
GraphRef metagraph = args[0];
IdArray num_nodes_per_type = args[1];
List<SparseMatrixRef> adjs = args[2];
std::shared_ptr<HeteroPickleStates> st( new HeteroPickleStates );
std::shared_ptr<HeteroPickleStates> st(new HeteroPickleStates);
st->version = 0;
st->metagraph = metagraph.sptr();
st->num_nodes_per_type = num_nodes_per_type.ToVector<int64_t>();
st->adjs.reserve(adjs.size());
for (const auto& ref : adjs)
st->adjs.push_back(ref.sptr());
for (const auto &ref : adjs) st->adjs.push_back(ref.sptr());
*rv = HeteroPickleStatesRef(st);
});
} // namespace dgl
......@@ -3,18 +3,20 @@
* \file graph/sampler.cc
* \brief DGL sampler implementation
*/
#include <dgl/sampler.h>
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/random.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/sampler.h>
#include <dmlc/omp.h>
#include <algorithm>
#include <cstdlib>
#include <cmath>
#include <cstdlib>
#include <numeric>
#include "../c_api_common.h"
using namespace dgl::runtime;
......@@ -25,21 +27,21 @@ namespace {
/*
* ArrayHeap is used to sample elements from vector
*/
template<typename ValueType>
template <typename ValueType>
class ArrayHeap {
public:
explicit ArrayHeap(const std::vector<ValueType>& prob) {
explicit ArrayHeap(const std::vector<ValueType> &prob) {
vec_size_ = prob.size();
bit_len_ = ceil(log2(vec_size_));
limit_ = 1UL << bit_len_;
// allocate twice the size
heap_.resize(limit_ << 1, 0);
// allocate the leaves
for (size_t i = limit_; i < vec_size_+limit_; ++i) {
heap_[i] = prob[i-limit_];
for (size_t i = limit_; i < vec_size_ + limit_; ++i) {
heap_[i] = prob[i - limit_];
}
// iterate up the tree (this is O(m))
for (int i = bit_len_-1; i >= 0; --i) {
for (int i = bit_len_ - 1; i >= 0; --i) {
for (size_t j = (1UL << i); j < (1UL << (i + 1)); ++j) {
heap_[j] = heap_[j << 1] + heap_[(j << 1) + 1];
}
......@@ -54,7 +56,7 @@ class ArrayHeap {
size_t i = index + limit_;
heap_[i] = 0;
i /= 2;
for (int j = bit_len_-1; j >= 0; --j) {
for (int j = bit_len_ - 1; j >= 0; --j) {
// Using heap_[i] = heap_[i] - w will loss some precision in float.
// Using addition to re-calculate the weight layer by layer.
heap_[i] = heap_[i << 1] + heap_[(i << 1) + 1];
......@@ -93,7 +95,7 @@ class ArrayHeap {
/*
* Sample a vector by given the size n
*/
size_t SampleWithoutReplacement(size_t n, std::vector<size_t>* samples) {
size_t SampleWithoutReplacement(size_t n, std::vector<size_t> *samples) {
// sample n elements
size_t i = 0;
for (; i < n; ++i) {
......@@ -116,20 +118,14 @@ class ArrayHeap {
};
///////////////////////// Samplers //////////////////////////
class EdgeSamplerObject: public Object {
class EdgeSamplerObject : public Object {
public:
EdgeSamplerObject(const GraphPtr gptr,
IdArray seed_edges,
const int64_t batch_size,
const int64_t num_workers,
const bool replacement,
const bool reset,
const std::string neg_mode,
const int64_t neg_sample_size,
const int64_t chunk_size,
const bool exclude_positive,
const bool check_false_neg,
IdArray relations) {
EdgeSamplerObject(
const GraphPtr gptr, IdArray seed_edges, const int64_t batch_size,
const int64_t num_workers, const bool replacement, const bool reset,
const std::string neg_mode, const int64_t neg_sample_size,
const int64_t chunk_size, const bool exclude_positive,
const bool check_false_neg, IdArray relations) {
gptr_ = gptr;
seed_edges_ = seed_edges;
relations_ = relations;
......@@ -147,24 +143,22 @@ class EdgeSamplerObject: public Object {
~EdgeSamplerObject() {}
virtual void Fetch(DGLRetValue* rv) = 0;
virtual void Fetch(DGLRetValue *rv) = 0;
virtual void Reset() = 0;
protected:
virtual void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) = 0;
virtual void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t>* out) = 0;
NegSubgraph genNegEdgeSubgraph(const Subgraph &pos_subg,
const std::string &neg_mode,
int64_t neg_sample_size,
bool exclude_positive,
bool check_false_neg);
NegSubgraph genChunkedNegEdgeSubgraph(const Subgraph &pos_subg,
const std::string &neg_mode,
int64_t neg_sample_size,
bool exclude_positive,
bool check_false_neg);
virtual void randomSample(
size_t set_size, size_t num, std::vector<size_t> *out) = 0;
virtual void randomSample(
size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t> *out) = 0;
NegSubgraph genNegEdgeSubgraph(
const Subgraph &pos_subg, const std::string &neg_mode,
int64_t neg_sample_size, bool exclude_positive, bool check_false_neg);
NegSubgraph genChunkedNegEdgeSubgraph(
const Subgraph &pos_subg, const std::string &neg_mode,
int64_t neg_sample_size, bool exclude_positive, bool check_false_neg);
GraphPtr gptr_;
IdArray seed_edges_;
......@@ -184,7 +178,7 @@ class EdgeSamplerObject: public Object {
/*
* Uniformly sample integers from [0, set_size) without replacement.
*/
void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
void RandomSample(size_t set_size, size_t num, std::vector<size_t> *out) {
if (num < set_size) {
std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) {
......@@ -194,13 +188,13 @@ void RandomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
} else {
// If we need to sample all elements in the set, we don't need to
// generate random numbers.
for (size_t i = 0; i < set_size; i++)
out->push_back(i);
for (size_t i = 0; i < set_size; i++) out->push_back(i);
}
}
void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t>* out) {
void RandomSample(
size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t> *out) {
std::unordered_map<size_t, int> sampled_idxs;
for (auto v : exclude) {
sampled_idxs.insert(std::pair<size_t, int>(v, 0));
......@@ -231,9 +225,9 @@ void RandomSample(size_t set_size, size_t num, const std::vector<size_t> &exclud
* For a sparse array whose non-zeros are represented by nz_idxs,
* negate the sparse array and outputs the non-zeros in the negated array.
*/
void NegateArray(const std::vector<size_t> &nz_idxs,
size_t arr_size,
std::vector<size_t>* out) {
void NegateArray(
const std::vector<size_t> &nz_idxs, size_t arr_size,
std::vector<size_t> *out) {
// nz_idxs must have been sorted.
auto it = nz_idxs.begin();
size_t i = 0;
......@@ -253,12 +247,10 @@ void NegateArray(const std::vector<size_t> &nz_idxs,
/*
* Uniform sample vertices from a list of vertices.
*/
void GetUniformSample(const dgl_id_t* edge_id_list,
const dgl_id_t* vid_list,
const size_t ver_len,
const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge) {
void GetUniformSample(
const dgl_id_t *edge_id_list, const dgl_id_t *vid_list,
const size_t ver_len, const size_t max_num_neighbor,
std::vector<dgl_id_t> *out_ver, std::vector<dgl_id_t> *out_edge) {
// Copy vid_list to output
if (ver_len <= max_num_neighbor) {
out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);
......@@ -292,16 +284,15 @@ void GetUniformSample(const dgl_id_t* edge_id_list,
/*
* Non-uniform sample via ArrayHeap
*
* \param probability Transition probability on the entire graph, indexed by edge ID
* \param probability Transition probability on the entire graph, indexed by
* edge ID
*/
template<typename ValueType>
void GetNonUniformSample(const ValueType* probability,
const dgl_id_t* edge_id_list,
const dgl_id_t* vid_list,
const size_t ver_len,
const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge) {
template <typename ValueType>
void GetNonUniformSample(
const ValueType *probability, const dgl_id_t *edge_id_list,
const dgl_id_t *vid_list, const size_t ver_len,
const size_t max_num_neighbor, std::vector<dgl_id_t> *out_ver,
std::vector<dgl_id_t> *out_edge) {
// Copy vid_list to output
if (ver_len <= max_num_neighbor) {
out_ver->insert(out_ver->end(), vid_list, vid_list + ver_len);
......@@ -333,8 +324,8 @@ void GetNonUniformSample(const ValueType* probability,
struct neigh_list {
std::vector<dgl_id_t> neighs;
std::vector<dgl_id_t> edges;
neigh_list(const std::vector<dgl_id_t> &_neighs,
const std::vector<dgl_id_t> &_edges)
neigh_list(
const std::vector<dgl_id_t> &_neighs, const std::vector<dgl_id_t> &_edges)
: neighs(_neighs), edges(_edges) {}
};
......@@ -350,12 +341,11 @@ struct neighbor_info {
}
};
NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
std::vector<dgl_id_t> edge_list,
NodeFlow ConstructNodeFlow(
std::vector<dgl_id_t> neighbor_list, std::vector<dgl_id_t> edge_list,
std::vector<size_t> layer_offsets,
std::vector<std::pair<dgl_id_t, int> > *sub_vers,
std::vector<neighbor_info> *neigh_pos,
const std::string &edge_type,
std::vector<std::pair<dgl_id_t, int>> *sub_vers,
std::vector<neighbor_info> *neigh_pos, const std::string &edge_type,
int64_t num_edges, int num_hops) {
NodeFlow nf = NodeFlow::Create();
uint64_t num_vertices = sub_vers->size();
......@@ -371,9 +361,9 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
// Construct sub_csr_graph, we treat nodeflow as multigraph by default
auto subg_csr = CSRPtr(new CSR(num_vertices, num_edges));
dgl_id_t* indptr_out = static_cast<dgl_id_t*>(subg_csr->indptr()->data);
dgl_id_t* col_list_out = static_cast<dgl_id_t*>(subg_csr->indices()->data);
dgl_id_t* eid_out = static_cast<dgl_id_t*>(subg_csr->edge_ids()->data);
dgl_id_t *indptr_out = static_cast<dgl_id_t *>(subg_csr->indptr()->data);
dgl_id_t *col_list_out = static_cast<dgl_id_t *>(subg_csr->indices()->data);
dgl_id_t *eid_out = static_cast<dgl_id_t *>(subg_csr->edge_ids()->data);
size_t collected_nedges = 0;
// The data from the previous steps:
......@@ -385,12 +375,13 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
layer_ver_maps.resize(num_hops);
size_t out_node_idx = 0;
for (int layer_id = num_hops - 1; layer_id >= 0; layer_id--) {
// We sort the vertices in a layer so that we don't need to sort the neighbor Ids
// after remap to a subgraph. However, we don't need to sort the first layer
// because we want the order of the nodes in the first layer is the same as
// the input seed nodes.
// We sort the vertices in a layer so that we don't need to sort the
// neighbor Ids after remap to a subgraph. However, we don't need to sort
// the first layer because we want the order of the nodes in the first layer
// is the same as the input seed nodes.
if (layer_id > 0) {
std::sort(sub_vers->begin() + layer_offsets[layer_id],
std::sort(
sub_vers->begin() + layer_offsets[layer_id],
sub_vers->begin() + layer_offsets[layer_id + 1],
[](const std::pair<dgl_id_t, dgl_id_t> &a1,
const std::pair<dgl_id_t, dgl_id_t> &a2) {
......@@ -399,37 +390,40 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
}
// Save the sampled vertices and its layer Id.
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1]; i++) {
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1];
i++) {
node_map_data[out_node_idx++] = sub_vers->at(i).first;
layer_ver_maps[layer_id].insert(std::pair<dgl_id_t, dgl_id_t>(sub_vers->at(i).first,
ver_id++));
layer_ver_maps[layer_id].insert(
std::pair<dgl_id_t, dgl_id_t>(sub_vers->at(i).first, ver_id++));
CHECK_EQ(sub_vers->at(i).second, layer_id);
}
}
CHECK(out_node_idx == num_vertices);
// sampling algorithms have to start from the seed nodes, so the seed nodes are
// in the first layer and the input nodes are in the last layer.
// When we expose the sampled graph to a Python user, we say the input nodes
// are in the first layer and the seed nodes are in the last layer.
// Thus, when we copy sampled results to a CSR, we need to reverse the order of layers.
// sampling algorithms have to start from the seed nodes, so the seed nodes
// are in the first layer and the input nodes are in the last layer. When we
// expose the sampled graph to a Python user, we say the input nodes are in
// the first layer and the seed nodes are in the last layer. Thus, when we
// copy sampled results to a CSR, we need to reverse the order of layers.
std::fill(indptr_out, indptr_out + num_vertices + 1, 0);
size_t row_idx = layer_offsets[num_hops] - layer_offsets[num_hops - 1];
layer_off_data[0] = 0;
layer_off_data[1] = layer_offsets[num_hops] - layer_offsets[num_hops - 1];
int out_layer_idx = 1;
for (int layer_id = num_hops - 2; layer_id >= 0; layer_id--) {
// Because we don't sort the vertices in the first layer above, we can't sort
// the neighbor positions of the vertices in the first layer either.
// Because we don't sort the vertices in the first layer above, we can't
// sort the neighbor positions of the vertices in the first layer either.
if (layer_id > 0) {
std::sort(neigh_pos->begin() + layer_offsets[layer_id],
std::sort(
neigh_pos->begin() + layer_offsets[layer_id],
neigh_pos->begin() + layer_offsets[layer_id + 1],
[](const neighbor_info &a1, const neighbor_info &a2) {
return a1.id < a2.id;
});
}
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1]; i++) {
for (size_t i = layer_offsets[layer_id]; i < layer_offsets[layer_id + 1];
i++) {
dgl_id_t dst_id = sub_vers->at(i).first;
CHECK_EQ(dst_id, neigh_pos->at(i).id);
size_t pos = neigh_pos->at(i).pos;
......@@ -441,18 +435,22 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
auto neigh_it = neighbor_list.begin() + pos;
for (size_t i = 0; i < nedges; i++) {
dgl_id_t neigh = *(neigh_it + i);
CHECK(layer_ver_maps[layer_id + 1].find(neigh) != layer_ver_maps[layer_id + 1].end());
col_list_out[collected_nedges + i] = layer_ver_maps[layer_id + 1][neigh];
CHECK(
layer_ver_maps[layer_id + 1].find(neigh) !=
layer_ver_maps[layer_id + 1].end());
col_list_out[collected_nedges + i] =
layer_ver_maps[layer_id + 1][neigh];
}
// We can simply copy the edge Ids.
std::copy_n(edge_list.begin() + pos,
nedges, edge_map_data + collected_nedges);
std::copy_n(
edge_list.begin() + pos, nedges, edge_map_data + collected_nedges);
collected_nedges += nedges;
indptr_out[row_idx+1] = indptr_out[row_idx] + nedges;
indptr_out[row_idx + 1] = indptr_out[row_idx] + nedges;
row_idx++;
}
layer_off_data[out_layer_idx + 1] = layer_off_data[out_layer_idx]
+ layer_offsets[layer_id + 1] - layer_offsets[layer_id];
layer_off_data[out_layer_idx + 1] = layer_off_data[out_layer_idx] +
layer_offsets[layer_id + 1] -
layer_offsets[layer_id];
out_layer_idx++;
}
CHECK_EQ(row_idx, num_vertices);
......@@ -464,7 +462,8 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
flow_off_data[0] = 0;
int out_flow_idx = 0;
for (size_t i = 0; i < layer_offsets.size() - 2; i++) {
size_t num_edges = indptr_out[layer_off_data[i + 2]] - indptr_out[layer_off_data[i + 1]];
size_t num_edges =
indptr_out[layer_off_data[i + 2]] - indptr_out[layer_off_data[i + 1]];
flow_off_data[out_flow_idx + 1] = flow_off_data[out_flow_idx] + num_edges;
out_flow_idx++;
}
......@@ -482,23 +481,21 @@ NodeFlow ConstructNodeFlow(std::vector<dgl_id_t> neighbor_list,
return nf;
}
template<typename ValueType>
NodeFlow SampleSubgraph(const ImmutableGraph *graph,
const std::vector<dgl_id_t>& seeds,
const ValueType* probability,
const std::string &edge_type,
int num_hops,
size_t num_neighbor,
const bool add_self_loop) {
template <typename ValueType>
NodeFlow SampleSubgraph(
const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,
const ValueType *probability, const std::string &edge_type, int num_hops,
size_t num_neighbor, const bool add_self_loop) {
CHECK_EQ(graph->NumBits(), 64) << "32 bit graph is not supported yet";
const size_t num_seeds = seeds.size();
auto orig_csr = edge_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
const dgl_id_t* val_list = static_cast<dgl_id_t*>(orig_csr->edge_ids()->data);
const dgl_id_t* col_list = static_cast<dgl_id_t*>(orig_csr->indices()->data);
const dgl_id_t* indptr = static_cast<dgl_id_t*>(orig_csr->indptr()->data);
const dgl_id_t *val_list =
static_cast<dgl_id_t *>(orig_csr->edge_ids()->data);
const dgl_id_t *col_list = static_cast<dgl_id_t *>(orig_csr->indices()->data);
const dgl_id_t *indptr = static_cast<dgl_id_t *>(orig_csr->indptr()->data);
std::unordered_set<dgl_id_t> sub_ver_map; // The vertex Ids in a layer.
std::vector<std::pair<dgl_id_t, int> > sub_vers;
std::vector<std::pair<dgl_id_t, int>> sub_vers;
sub_vers.reserve(num_seeds * 10);
// add seed vertices
for (size_t i = 0; i < num_seeds; ++i) {
......@@ -526,37 +523,38 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
// sampled nodes in a layer, and clear it when entering a new layer.
sub_ver_map.clear();
// Previous iteration collects all nodes in sub_vers, which are collected
// in the previous layer. sub_vers is used both as a node collection and a queue.
for (size_t idx = layer_offsets[layer_id - 1]; idx < layer_offsets[layer_id]; idx++) {
// in the previous layer. sub_vers is used both as a node collection and a
// queue.
for (size_t idx = layer_offsets[layer_id - 1];
idx < layer_offsets[layer_id]; idx++) {
dgl_id_t dst_id = sub_vers[idx].first;
const int cur_node_level = sub_vers[idx].second;
tmp_sampled_src_list.clear();
tmp_sampled_edge_list.clear();
dgl_id_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
dgl_id_t ver_len = *(indptr + dst_id + 1) - *(indptr + dst_id);
if (probability == nullptr) { // uniform-sample
GetUniformSample(val_list + *(indptr + dst_id),
col_list + *(indptr + dst_id),
ver_len,
num_neighbor,
&tmp_sampled_src_list,
GetUniformSample(
val_list + *(indptr + dst_id), col_list + *(indptr + dst_id),
ver_len, num_neighbor, &tmp_sampled_src_list,
&tmp_sampled_edge_list);
} else { // non-uniform-sample
GetNonUniformSample(probability,
val_list + *(indptr + dst_id),
col_list + *(indptr + dst_id),
ver_len,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list);
}
// If we need to add self loop and it doesn't exist in the sampled neighbor list.
if (add_self_loop && std::find(tmp_sampled_src_list.begin(), tmp_sampled_src_list.end(),
GetNonUniformSample(
probability, val_list + *(indptr + dst_id),
col_list + *(indptr + dst_id), ver_len, num_neighbor,
&tmp_sampled_src_list, &tmp_sampled_edge_list);
}
// If we need to add self loop and it doesn't exist in the sampled
// neighbor list.
if (add_self_loop &&
std::find(
tmp_sampled_src_list.begin(), tmp_sampled_src_list.end(),
dst_id) == tmp_sampled_src_list.end()) {
tmp_sampled_src_list.push_back(dst_id);
const dgl_id_t *src_list = col_list + *(indptr + dst_id);
const dgl_id_t *eid_list = val_list + *(indptr + dst_id);
// TODO(zhengda) this operation has O(N) complexity. It can be pretty slow.
// TODO(zhengda) this operation has O(N) complexity. It can be pretty
// slow.
const dgl_id_t *src = std::find(src_list, src_list + ver_len, dst_id);
// If there doesn't exist a self loop in the graph.
// we have to add -1 as the edge id for the self-loop edge.
......@@ -566,7 +564,8 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
tmp_sampled_edge_list.push_back(eid_list[src - src_list]);
}
CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());
neigh_pos.emplace_back(dst_id, neighbor_list.size(), tmp_sampled_src_list.size());
neigh_pos.emplace_back(
dst_id, neighbor_list.size(), tmp_sampled_src_list.size());
// Then push the vertices
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
neighbor_list.push_back(tmp_sampled_src_list[i]);
......@@ -578,8 +577,8 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
num_edges += tmp_sampled_src_list.size();
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
// We need to add the neighbor in the hashtable here. This ensures that
// the vertex in the queue is unique. If we see a vertex before, we don't
// need to add it to the queue again.
// the vertex in the queue is unique. If we see a vertex before, we
// don't need to add it to the queue again.
auto ret = sub_ver_map.insert(tmp_sampled_src_list[i]);
// If the sampled neighbor is inserted to the map successfully.
if (ret.second) {
......@@ -591,76 +590,69 @@ NodeFlow SampleSubgraph(const ImmutableGraph *graph,
CHECK_EQ(layer_offsets[layer_id + 1], sub_vers.size());
}
return ConstructNodeFlow(neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos,
edge_type, num_edges, num_hops);
return ConstructNodeFlow(
neighbor_list, edge_list, layer_offsets, &sub_vers, &neigh_pos, edge_type,
num_edges, num_hops);
}
} // namespace
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetGraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0];
*rv = nflow->graph;
});
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetNodeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0];
*rv = nflow->node_mapping;
});
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetEdgeMapping")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0];
*rv = nflow->edge_mapping;
});
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetLayerOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0];
*rv = nflow->layer_offsets;
});
DGL_REGISTER_GLOBAL("_deprecate.nodeflow._CAPI_NodeFlowGetBlockOffsets")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
NodeFlow nflow = args[0];
*rv = nflow->flow_offsets;
});
template<typename ValueType>
NodeFlow SamplerOp::NeighborSample(const ImmutableGraph *graph,
const std::vector<dgl_id_t>& seeds,
const std::string &edge_type,
int num_hops, int expand_factor,
const bool add_self_loop,
const ValueType *probability) {
return SampleSubgraph(graph,
seeds,
probability,
edge_type,
num_hops + 1,
expand_factor,
template <typename ValueType>
NodeFlow SamplerOp::NeighborSample(
const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,
const std::string &edge_type, int num_hops, int expand_factor,
const bool add_self_loop, const ValueType *probability) {
return SampleSubgraph(
graph, seeds, probability, edge_type, num_hops + 1, expand_factor,
add_self_loop);
}
namespace {
void ConstructLayers(const dgl_id_t *indptr,
const dgl_id_t *indices,
const std::vector<dgl_id_t>& seed_array,
IdArray layer_sizes,
std::vector<dgl_id_t> *layer_offsets,
std::vector<dgl_id_t> *node_mapping,
std::vector<int64_t> *actl_layer_sizes,
std::vector<float> *probabilities) {
void ConstructLayers(
const dgl_id_t *indptr, const dgl_id_t *indices,
const std::vector<dgl_id_t> &seed_array, IdArray layer_sizes,
std::vector<dgl_id_t> *layer_offsets, std::vector<dgl_id_t> *node_mapping,
std::vector<int64_t> *actl_layer_sizes, std::vector<float> *probabilities) {
/*
* Given a graph and a collection of seed nodes, this function constructs NodeFlow
* layers via uniform layer-wise sampling, and return the resultant layers and their
* corresponding probabilities.
* Given a graph and a collection of seed nodes, this function constructs
* NodeFlow layers via uniform layer-wise sampling, and return the resultant
* layers and their corresponding probabilities.
*/
std::copy(seed_array.begin(), seed_array.end(), std::back_inserter(*node_mapping));
std::copy(
seed_array.begin(), seed_array.end(), std::back_inserter(*node_mapping));
actl_layer_sizes->push_back(node_mapping->size());
probabilities->insert(probabilities->end(), node_mapping->size(), 1);
const int64_t* layer_sizes_data = static_cast<int64_t*>(layer_sizes->data);
const int64_t *layer_sizes_data = static_cast<int64_t *>(layer_sizes->data);
const int64_t num_layers = layer_sizes->shape[0];
size_t curr = 0;
......@@ -674,14 +666,15 @@ namespace {
}
std::vector<dgl_id_t> candidate_vector;
std::copy(candidate_set.begin(), candidate_set.end(),
std::copy(
candidate_set.begin(), candidate_set.end(),
std::back_inserter(candidate_vector));
std::unordered_map<dgl_id_t, size_t> n_occurrences;
auto n_candidates = candidate_vector.size();
for (int64_t j = 0; j != layer_size; ++j) {
auto dst = candidate_vector[
RandomEngine::ThreadLocal()->RandInt(n_candidates)];
auto dst =
candidate_vector[RandomEngine::ThreadLocal()->RandInt(n_candidates)];
if (!n_occurrences.insert(std::make_pair(dst, 1)).second) {
++n_occurrences[dst];
}
......@@ -703,21 +696,18 @@ namespace {
for (const auto &size : *actl_layer_sizes) {
layer_offsets->push_back(size + layer_offsets->back());
}
}
}
void ConstructFlows(const dgl_id_t *indptr,
const dgl_id_t *indices,
const dgl_id_t *eids,
void ConstructFlows(
const dgl_id_t *indptr, const dgl_id_t *indices, const dgl_id_t *eids,
const std::vector<dgl_id_t> &node_mapping,
const std::vector<int64_t> &actl_layer_sizes,
std::vector<dgl_id_t> *sub_indptr,
std::vector<dgl_id_t> *sub_indices,
std::vector<dgl_id_t> *sub_eids,
std::vector<dgl_id_t> *flow_offsets,
std::vector<dgl_id_t> *sub_indptr, std::vector<dgl_id_t> *sub_indices,
std::vector<dgl_id_t> *sub_eids, std::vector<dgl_id_t> *flow_offsets,
std::vector<dgl_id_t> *edge_mapping) {
/*
* Given a graph and a sequence of NodeFlow layers, this function constructs dense
* subgraphs (flows) between consecutive layers.
* Given a graph and a sequence of NodeFlow layers, this function constructs
* dense subgraphs (flows) between consecutive layers.
*/
auto n_flows = actl_layer_sizes.size() - 1;
for (int64_t i = 0; i < actl_layer_sizes.front() + 1; i++)
......@@ -742,7 +732,9 @@ namespace {
neighbor_indices.push_back(std::make_pair(ret->second, eids[k]));
}
}
auto cmp = [](const id_pair p, const id_pair q)->bool { return p.first < q.first; };
auto cmp = [](const id_pair p, const id_pair q) -> bool {
return p.first < q.first;
};
std::sort(neighbor_indices.begin(), neighbor_indices.end(), cmp);
for (const auto &pair : neighbor_indices) {
sub_indices->push_back(pair.first);
......@@ -755,44 +747,32 @@ namespace {
}
sub_eids->resize(sub_indices->size());
std::iota(sub_eids->begin(), sub_eids->end(), 0);
}
}
} // namespace
NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
const std::vector<dgl_id_t>& seeds,
const std::string &neighbor_type,
IdArray layer_sizes) {
const auto g_csr = neighbor_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
const dgl_id_t *indptr = static_cast<dgl_id_t*>(g_csr->indptr()->data);
const dgl_id_t *indices = static_cast<dgl_id_t*>(g_csr->indices()->data);
const dgl_id_t *eids = static_cast<dgl_id_t*>(g_csr->edge_ids()->data);
NodeFlow SamplerOp::LayerUniformSample(
const ImmutableGraph *graph, const std::vector<dgl_id_t> &seeds,
const std::string &neighbor_type, IdArray layer_sizes) {
const auto g_csr =
neighbor_type == "in" ? graph->GetInCSR() : graph->GetOutCSR();
const dgl_id_t *indptr = static_cast<dgl_id_t *>(g_csr->indptr()->data);
const dgl_id_t *indices = static_cast<dgl_id_t *>(g_csr->indices()->data);
const dgl_id_t *eids = static_cast<dgl_id_t *>(g_csr->edge_ids()->data);
std::vector<dgl_id_t> layer_offsets;
std::vector<dgl_id_t> node_mapping;
std::vector<int64_t> actl_layer_sizes;
std::vector<float> probabilities;
ConstructLayers(indptr,
indices,
seeds,
layer_sizes,
&layer_offsets,
&node_mapping,
&actl_layer_sizes,
&probabilities);
ConstructLayers(
indptr, indices, seeds, layer_sizes, &layer_offsets, &node_mapping,
&actl_layer_sizes, &probabilities);
std::vector<dgl_id_t> sub_indptr, sub_indices, sub_edge_ids;
std::vector<dgl_id_t> flow_offsets;
std::vector<dgl_id_t> edge_mapping;
ConstructFlows(indptr,
indices,
eids,
node_mapping,
actl_layer_sizes,
&sub_indptr,
&sub_indices,
&sub_edge_ids,
&flow_offsets,
&edge_mapping);
ConstructFlows(
indptr, indices, eids, node_mapping, actl_layer_sizes, &sub_indptr,
&sub_indices, &sub_edge_ids, &flow_offsets, &edge_mapping);
// sanity check
CHECK_GT(sub_indptr.size(), 0);
CHECK_EQ(sub_indptr[0], 0);
......@@ -800,8 +780,8 @@ NodeFlow SamplerOp::LayerUniformSample(const ImmutableGraph *graph,
CHECK_EQ(sub_indices.size(), sub_edge_ids.size());
NodeFlow nf = NodeFlow::Create();
auto sub_csr = CSRPtr(new CSR(aten::VecToIdArray(sub_indptr),
aten::VecToIdArray(sub_indices),
auto sub_csr = CSRPtr(new CSR(
aten::VecToIdArray(sub_indptr), aten::VecToIdArray(sub_indices),
aten::VecToIdArray(sub_edge_ids)));
if (neighbor_type == std::string("in")) {
......@@ -830,24 +810,22 @@ void BuildCsr(const ImmutableGraph &g, const std::string neigh_type) {
}
}
template<typename ValueType>
std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
const IdArray seed_nodes,
const int64_t batch_start_id,
const int64_t batch_size,
const int64_t max_num_workers,
const int64_t expand_factor,
const int64_t num_hops,
const std::string neigh_type,
const bool add_self_loop,
const ValueType *probability) {
template <typename ValueType>
std::vector<NodeFlow> NeighborSamplingImpl(
const ImmutableGraphPtr gptr, const IdArray seed_nodes,
const int64_t batch_start_id, const int64_t batch_size,
const int64_t max_num_workers, const int64_t expand_factor,
const int64_t num_hops, const std::string neigh_type,
const bool add_self_loop, const ValueType *probability) {
// process args
CHECK(aten::IsValidIdArray(seed_nodes));
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const dgl_id_t *seed_nodes_data = static_cast<dgl_id_t *>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
const int64_t num_workers = std::min(
max_num_workers,
(num_seeds + batch_size - 1) / batch_size - batch_start_id);
// We need to make sure we have the right CSR before we enter parallel sampling.
// We need to make sure we have the right CSR before we enter parallel
// sampling.
BuildCsr(*gptr, neigh_type);
// generate node flows
std::vector<NodeFlow> nflows(num_workers);
......@@ -858,8 +836,8 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
const int64_t end = std::min(start + batch_size, num_seeds);
// TODO(minjie): the vector allocation/copy is unnecessary
std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin());
std::copy(
seed_nodes_data + start, seed_nodes_data + end, worker_seeds.begin());
nflows[i] = SamplerOp::NeighborSample(
gptr.get(), worker_seeds, neigh_type, num_hops, expand_factor,
add_self_loop, probability);
......@@ -869,7 +847,7 @@ std::vector<NodeFlow> NeighborSamplingImpl(const ImmutableGraphPtr gptr,
}
DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments
const GraphRef g = args[0];
const IdArray seed_nodes = args[1];
......@@ -896,7 +874,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_UniformSampling")
});
DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments
const GraphRef g = args[0];
const IdArray seed_nodes = args[1];
......@@ -926,17 +904,17 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
<< "NeighborSampling only support CPU sampling";
ATEN_FLOAT_TYPE_SWITCH(
probability->dtype,
FloatType,
"transition probability",
{
probability->dtype, FloatType, "transition probability", {
const FloatType *prob;
if (aten::IsNullArray(probability)) {
prob = nullptr;
} else {
CHECK(probability->shape[0] == static_cast<int64_t>(gptr->NumEdges()))
<< "transition probability must have same number of elements as edges";
CHECK(
probability->shape[0] ==
static_cast<int64_t>(gptr->NumEdges()))
<< "transition probability must have same number of elements "
"as edges";
CHECK(probability.IsContiguous())
<< "transition probability must be contiguous tensor";
prob = static_cast<const FloatType *>(probability->data);
......@@ -951,7 +929,7 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_NeighborSampling")
});
DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments
GraphRef g = args[0];
const IdArray seed_nodes = args[1];
......@@ -971,11 +949,14 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
CHECK_EQ(layer_sizes->ctx.device_type, kDGLCPU)
<< "LayerSampler only support CPU sampling";
const dgl_id_t* seed_nodes_data = static_cast<dgl_id_t*>(seed_nodes->data);
const dgl_id_t *seed_nodes_data =
static_cast<dgl_id_t *>(seed_nodes->data);
const int64_t num_seeds = seed_nodes->shape[0];
const int64_t num_workers = std::min(max_num_workers,
const int64_t num_workers = std::min(
max_num_workers,
(num_seeds + batch_size - 1) / batch_size - batch_start_id);
// We need to make sure we have the right CSR before we enter parallel sampling.
// We need to make sure we have the right CSR before we enter parallel
// sampling.
BuildCsr(*gptr, neigh_type);
// generate node flows
std::vector<NodeFlow> nflows(num_workers);
......@@ -986,7 +967,8 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_LayerSampling")
const int64_t end = std::min(start + batch_size, num_seeds);
// TODO(minjie): the vector allocation/copy is unnecessary
std::vector<dgl_id_t> worker_seeds(end - start);
std::copy(seed_nodes_data + start, seed_nodes_data + end,
std::copy(
seed_nodes_data + start, seed_nodes_data + end,
worker_seeds.begin());
nflows[i] = SamplerOp::LayerUniformSample(
gptr.get(), worker_seeds, neigh_type, layer_sizes);
......@@ -1002,9 +984,8 @@ void BuildCoo(const ImmutableGraph &g) {
assert(coo);
}
dgl_id_t global2local_map(dgl_id_t global_id,
std::unordered_map<dgl_id_t, dgl_id_t> *map) {
dgl_id_t global2local_map(
dgl_id_t global_id, std::unordered_map<dgl_id_t, dgl_id_t> *map) {
auto it = map->find(global_id);
if (it == map->end()) {
dgl_id_t local_id = map->size();
......@@ -1020,7 +1001,8 @@ inline bool IsNegativeHeadMode(const std::string &mode) {
}
IdArray GetGlobalVid(IdArray induced_nid, IdArray subg_nid) {
IdArray gnid = IdArray::Empty({subg_nid->shape[0]}, subg_nid->dtype, subg_nid->ctx);
IdArray gnid =
IdArray::Empty({subg_nid->shape[0]}, subg_nid->dtype, subg_nid->ctx);
const dgl_id_t *induced_nid_data = static_cast<dgl_id_t *>(induced_nid->data);
const dgl_id_t *subg_nid_data = static_cast<dgl_id_t *>(subg_nid->data);
dgl_id_t *gnid_data = static_cast<dgl_id_t *>(gnid->data);
......@@ -1030,14 +1012,14 @@ IdArray GetGlobalVid(IdArray induced_nid, IdArray subg_nid) {
return gnid;
}
/*!
 * \brief Check which of the proposed negative edges already exist in the
 *  parent graph (i.e. are false negatives).
 *
 * \param gptr parent graph to query.
 * \param neg_src source node ids, local to the negative subgraph.
 * \param neg_dst destination node ids, local to the negative subgraph.
 * \param induced_nid mapping from subgraph-local node ids to global node ids.
 * \return boolean IdArray marking the edges that exist in the parent graph.
 */
IdArray CheckExistence(
    GraphPtr gptr, IdArray neg_src, IdArray neg_dst, IdArray induced_nid) {
  // Translate subgraph-local endpoints to global vertex ids before querying
  // the parent graph for edge existence.
  return gptr->HasEdgesBetween(
      GetGlobalVid(induced_nid, neg_src), GetGlobalVid(induced_nid, neg_dst));
}
IdArray CheckExistence(GraphPtr gptr, IdArray relations,
IdArray neg_src, IdArray neg_dst,
IdArray CheckExistence(
GraphPtr gptr, IdArray relations, IdArray neg_src, IdArray neg_dst,
IdArray induced_nid, IdArray neg_eid) {
neg_src = GetGlobalVid(induced_nid, neg_src);
neg_dst = GetGlobalVid(induced_nid, neg_dst);
......@@ -1051,8 +1033,7 @@ IdArray CheckExistence(GraphPtr gptr, IdArray relations,
int64_t num_neg_edges = neg_src->shape[0];
for (int64_t i = 0; i < num_neg_edges; i++) {
// If the edge doesn't exist, we don't need to do anything.
if (!exist_data[i])
continue;
if (!exist_data[i]) continue;
// If the edge exists, we need to double check if the relations match.
// If they match, this negative edge isn't really a negative edge.
dgl_id_t eid1 = neg_eid_data[i];
......@@ -1073,7 +1054,8 @@ IdArray CheckExistence(GraphPtr gptr, IdArray relations,
return exist;
}
std::vector<dgl_id_t> Global2Local(const std::vector<size_t> &ids,
std::vector<dgl_id_t> Global2Local(
const std::vector<size_t> &ids,
const std::unordered_map<dgl_id_t, dgl_id_t> &map) {
std::vector<dgl_id_t> local_ids(ids.size());
for (size_t i = 0; i < ids.size(); i++) {
......@@ -1084,35 +1066,36 @@ std::vector<dgl_id_t> Global2Local(const std::vector<size_t> &ids,
return local_ids;
}
NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
const std::string &neg_mode,
int64_t neg_sample_size,
bool exclude_positive,
bool check_false_neg) {
NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(
const Subgraph &pos_subg, const std::string &neg_mode,
int64_t neg_sample_size, bool exclude_positive, bool check_false_neg) {
int64_t num_tot_nodes = gptr_->NumVertices();
if (neg_sample_size > num_tot_nodes)
neg_sample_size = num_tot_nodes;
if (neg_sample_size > num_tot_nodes) neg_sample_size = num_tot_nodes;
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2;
int64_t num_neg_edges = num_pos_edges * neg_sample_size;
IdArray neg_dst = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
IdArray neg_src = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
IdArray induced_neg_eid = IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
IdArray induced_neg_eid =
IdArray::Empty({num_neg_edges}, coo->dtype, coo->ctx);
// These are vids in the positive subgraph.
const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);
const dgl_id_t *src_data = static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *src_data =
static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *induced_vid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
size_t num_pos_nodes = pos_subg.graph->NumVertices();
std::vector<size_t> pos_nodes(induced_vid_data, induced_vid_data + num_pos_nodes);
std::vector<size_t> pos_nodes(
induced_vid_data, induced_vid_data + num_pos_nodes);
dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);
dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);
dgl_id_t *induced_neg_eid_data = static_cast<dgl_id_t *>(induced_neg_eid->data);
dgl_id_t *induced_neg_eid_data =
static_cast<dgl_id_t *>(induced_neg_eid->data);
const dgl_id_t *unchanged;
dgl_id_t *neg_unchanged;
......@@ -1206,7 +1189,8 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
for (int64_t j = 0; j < neg_sample_size; j++) {
neg_unchanged[neg_idx + j] = local_unchanged;
dgl_id_t local_changed = global2local_map(neg_vids[j + prev_neg_offset], &neg_map);
dgl_id_t local_changed =
global2local_map(neg_vids[j + prev_neg_offset], &neg_map);
neg_changed[neg_idx + j] = local_changed;
// induced negative eid references to the positive one.
induced_neg_eid_data[neg_idx + j] = induced_eid_data[i];
......@@ -1215,8 +1199,10 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
// Now we know the number of vertices in the negative graph.
int64_t num_neg_nodes = neg_map.size();
IdArray induced_neg_vid = IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);
dgl_id_t *induced_neg_vid_data = static_cast<dgl_id_t *>(induced_neg_vid->data);
IdArray induced_neg_vid =
IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);
dgl_id_t *induced_neg_vid_data =
static_cast<dgl_id_t *>(induced_neg_vid->data);
for (auto it = neg_map.begin(); it != neg_map.end(); it++) {
induced_neg_vid_data[it->second] = it->first;
}
......@@ -1241,24 +1227,22 @@ NegSubgraph EdgeSamplerObject::genNegEdgeSubgraph(const Subgraph &pos_subg,
if (aten::IsNullArray(relations_)) {
neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);
} else {
neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst,
induced_neg_vid, induced_neg_eid);
neg_subg.exist = CheckExistence(
gptr_, relations_, neg_src, neg_dst, induced_neg_vid,
induced_neg_eid);
}
}
return neg_subg;
}
NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_subg,
const std::string &neg_mode,
int64_t neg_sample_size,
bool exclude_positive,
bool check_false_neg) {
NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(
const Subgraph &pos_subg, const std::string &neg_mode,
int64_t neg_sample_size, bool exclude_positive, bool check_false_neg) {
int64_t num_tot_nodes = gptr_->NumVertices();
std::vector<IdArray> adj = pos_subg.graph->GetAdj(false, "coo");
IdArray coo = adj[0];
int64_t num_pos_edges = coo->shape[0] / 2;
if (neg_sample_size > num_tot_nodes)
neg_sample_size = num_tot_nodes;
if (neg_sample_size > num_tot_nodes) neg_sample_size = num_tot_nodes;
int64_t chunk_size = chunk_size_;
CHECK_GT(chunk_size, 0) << "chunk size has to be positive";
......@@ -1275,26 +1259,29 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
int64_t num_all_neg_edges = num_neg_edges + num_neg_edges_last_chunk;
// We should include the last chunk.
if (last_chunk_size > 0)
num_chunks++;
if (last_chunk_size > 0) num_chunks++;
IdArray neg_dst = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);
IdArray neg_src = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);
IdArray induced_neg_eid = IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);
IdArray induced_neg_eid =
IdArray::Empty({num_all_neg_edges}, coo->dtype, coo->ctx);
// These are vids in the positive subgraph.
const dgl_id_t *dst_data = static_cast<const dgl_id_t *>(coo->data);
const dgl_id_t *src_data = static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *src_data =
static_cast<const dgl_id_t *>(coo->data) + num_pos_edges;
const dgl_id_t *induced_vid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_vertices->data);
const dgl_id_t *induced_eid_data =
static_cast<const dgl_id_t *>(pos_subg.induced_edges->data);
int64_t num_pos_nodes = pos_subg.graph->NumVertices();
std::vector<dgl_id_t> pos_nodes(induced_vid_data, induced_vid_data + num_pos_nodes);
std::vector<dgl_id_t> pos_nodes(
induced_vid_data, induced_vid_data + num_pos_nodes);
dgl_id_t *neg_dst_data = static_cast<dgl_id_t *>(neg_dst->data);
dgl_id_t *neg_src_data = static_cast<dgl_id_t *>(neg_src->data);
dgl_id_t *induced_neg_eid_data = static_cast<dgl_id_t *>(induced_neg_eid->data);
dgl_id_t *induced_neg_eid_data =
static_cast<dgl_id_t *>(induced_neg_eid->data);
const dgl_id_t *unchanged;
dgl_id_t *neg_unchanged;
......@@ -1312,9 +1299,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
// We first sample all negative edges.
std::vector<size_t> global_neg_vids;
std::vector<size_t> local_neg_vids;
randomSample(num_tot_nodes,
num_chunks * neg_sample_size,
&global_neg_vids);
randomSample(num_tot_nodes, num_chunks * neg_sample_size, &global_neg_vids);
CHECK_EQ(num_chunks * neg_sample_size, global_neg_vids.size());
std::unordered_map<dgl_id_t, dgl_id_t> neg_map;
......@@ -1336,7 +1321,7 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
// to reduce computation overhead.
local_neg_vids.resize(global_neg_vids.size());
for (size_t i = 0; i < global_neg_vids.size(); i++) {
local_neg_vids[i] = global2local_map(global_neg_vids[i], &neg_map);;
local_neg_vids[i] = global2local_map(global_neg_vids[i], &neg_map);
}
for (int64_t i_chunk = 0; i_chunk < num_chunks; i_chunk++) {
......@@ -1353,12 +1338,14 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
for (int64_t in_chunk = 0; in_chunk != chunk_size1; ++in_chunk) {
// For each positive node in a chunk.
dgl_id_t global_unchanged = induced_vid_data[unchanged[pos_edge_idx + in_chunk]];
dgl_id_t global_unchanged =
induced_vid_data[unchanged[pos_edge_idx + in_chunk]];
dgl_id_t local_unchanged = global2local_map(global_unchanged, &neg_map);
for (int64_t j = 0; j < neg_sample_size; ++j) {
neg_unchanged[neg_idx] = local_unchanged;
neg_changed[neg_idx] = local_neg_vids[neg_node_idx + j];
induced_neg_eid_data[neg_idx] = induced_eid_data[pos_edge_idx + in_chunk];
induced_neg_eid_data[neg_idx] =
induced_eid_data[pos_edge_idx + in_chunk];
neg_idx++;
}
}
......@@ -1366,8 +1353,10 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
// Now we know the number of vertices in the negative graph.
int64_t num_neg_nodes = neg_map.size();
IdArray induced_neg_vid = IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);
dgl_id_t *induced_neg_vid_data = static_cast<dgl_id_t *>(induced_neg_vid->data);
IdArray induced_neg_vid =
IdArray::Empty({num_neg_nodes}, coo->dtype, coo->ctx);
dgl_id_t *induced_neg_vid_data =
static_cast<dgl_id_t *>(induced_neg_vid->data);
for (auto it = neg_map.begin(); it != neg_map.end(); it++) {
induced_neg_vid_data[it->second] = it->first;
}
......@@ -1380,18 +1369,21 @@ NegSubgraph EdgeSamplerObject::genChunkedNegEdgeSubgraph(const Subgraph &pos_sub
neg_subg.induced_vertices = induced_neg_vid;
neg_subg.induced_edges = induced_neg_eid;
if (IsNegativeHeadMode(neg_mode)) {
neg_subg.head_nid = aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));
neg_subg.head_nid =
aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));
neg_subg.tail_nid = aten::VecToIdArray(local_pos_vids);
} else {
neg_subg.head_nid = aten::VecToIdArray(local_pos_vids);
neg_subg.tail_nid = aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));
neg_subg.tail_nid =
aten::VecToIdArray(Global2Local(global_neg_vids, neg_map));
}
if (check_false_neg) {
if (aten::IsNullArray(relations_)) {
neg_subg.exist = CheckExistence(gptr_, neg_src, neg_dst, induced_neg_vid);
} else {
neg_subg.exist = CheckExistence(gptr_, relations_, neg_src, neg_dst,
induced_neg_vid, induced_neg_eid);
neg_subg.exist = CheckExistence(
gptr_, relations_, neg_src, neg_dst, induced_neg_vid,
induced_neg_eid);
}
}
return neg_subg;
......@@ -1408,52 +1400,38 @@ inline SubgraphRef ConvertRef(const NegSubgraph &subg) {
} // namespace
DGL_REGISTER_GLOBAL("sampling._CAPI_GetNegEdgeExistence")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Return the false-negative existence mask of a negative edge subgraph.
      SubgraphRef g = args[0];
      auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
      *rv = gptr->exist;
    });
DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphHead")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Return the head node ids of a negative edge subgraph.
      SubgraphRef g = args[0];
      auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
      *rv = gptr->head_nid;
    });
DGL_REGISTER_GLOBAL("sampling._CAPI_GetEdgeSubgraphTail")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
SubgraphRef g = args[0];
auto gptr = std::dynamic_pointer_cast<NegSubgraph>(g.sptr());
*rv = gptr->tail_nid;
});
class UniformEdgeSamplerObject: public EdgeSamplerObject {
public:
explicit UniformEdgeSamplerObject(const GraphPtr gptr,
IdArray seed_edges,
const int64_t batch_size,
const int64_t num_workers,
const bool replacement,
const bool reset,
const std::string neg_mode,
const int64_t neg_sample_size,
const int64_t chunk_size,
const bool exclude_positive,
const bool check_false_neg,
IdArray relations)
: EdgeSamplerObject(gptr,
seed_edges,
batch_size,
num_workers,
replacement,
reset,
neg_mode,
neg_sample_size,
chunk_size,
exclude_positive,
check_false_neg,
relations) {
});
class UniformEdgeSamplerObject : public EdgeSamplerObject {
public:
explicit UniformEdgeSamplerObject(
const GraphPtr gptr, IdArray seed_edges, const int64_t batch_size,
const int64_t num_workers, const bool replacement, const bool reset,
const std::string neg_mode, const int64_t neg_sample_size,
const int64_t chunk_size, const bool exclude_positive,
const bool check_false_neg, IdArray relations)
: EdgeSamplerObject(
gptr, seed_edges, batch_size, num_workers, replacement, reset,
neg_mode, neg_sample_size, chunk_size, exclude_positive,
check_false_neg, relations) {
batch_curr_id_ = 0;
num_seeds_ = seed_edges->shape[0];
max_batch_id_ = (num_seeds_ + batch_size - 1) / batch_size;
......@@ -1463,8 +1441,9 @@ public:
}
~UniformEdgeSamplerObject() {}
void Fetch(DGLRetValue* rv) {
const int64_t num_workers = std::min(num_workers_, max_batch_id_ - batch_curr_id_);
void Fetch(DGLRetValue *rv) {
const int64_t num_workers =
std::min(num_workers_, max_batch_id_ - batch_curr_id_);
// generate subgraphs.
std::vector<SubgraphRef> positive_subgs(num_workers);
std::vector<SubgraphRef> negative_subgs(num_workers);
......@@ -1477,11 +1456,13 @@ public:
IdArray worker_seeds;
if (replacement_ == false) {
worker_seeds = seed_edges_.CreateView({num_edges}, DGLDataType{kDGLInt, 64, 1},
worker_seeds = seed_edges_.CreateView(
{num_edges}, DGLDataType{kDGLInt, 64, 1},
sizeof(dgl_id_t) * start);
} else {
std::vector<dgl_id_t> seeds;
const dgl_id_t *seed_edge_ids = static_cast<const dgl_id_t *>(seed_edges_->data);
const dgl_id_t *seed_edge_ids =
static_cast<const dgl_id_t *>(seed_edges_->data);
// sampling of each edge is a standalone event
for (int64_t i = 0; i < num_edges; ++i) {
int64_t seed = static_cast<const int64_t>(
......@@ -1497,29 +1478,29 @@ public:
const dgl_id_t *dst_ids = static_cast<const dgl_id_t *>(arr.dst->data);
std::vector<dgl_id_t> src_vec(src_ids, src_ids + num_edges);
std::vector<dgl_id_t> dst_vec(dst_ids, dst_ids + num_edges);
// TODO(zhengda) what if there are duplicates in the src and dst vectors.
// TODO(zhengda) what if there are duplicates in the src and dst
// vectors.
Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);
positive_subgs[i] = ConvertRef(subg);
// For chunked negative sampling, we accept "chunk-head" for corrupting head
// nodes and "chunk-tail" for corrupting tail nodes.
// For chunked negative sampling, we accept "chunk-head" for corrupting
// head nodes and "chunk-tail" for corrupting tail nodes.
if (neg_mode_.substr(0, 5) == "chunk") {
NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(subg, neg_mode_.substr(6),
neg_sample_size_,
exclude_positive_,
NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(
subg, neg_mode_.substr(6), neg_sample_size_, exclude_positive_,
check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg);
} else if (neg_mode_ == "head" || neg_mode_ == "tail") {
NegSubgraph neg_subg = genNegEdgeSubgraph(subg, neg_mode_,
neg_sample_size_,
exclude_positive_,
NegSubgraph neg_subg = genNegEdgeSubgraph(
subg, neg_mode_, neg_sample_size_, exclude_positive_,
check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg);
}
}
});
if (neg_mode_.size() > 0) {
positive_subgs.insert(positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
positive_subgs.insert(
positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
}
batch_curr_id_ += num_workers;
......@@ -1535,20 +1516,22 @@ public:
if (replacement_ == false) {
// Now we should shuffle the data and reset the sampler.
dgl_id_t *seed_ids = static_cast<dgl_id_t *>(seed_edges_->data);
std::shuffle(seed_ids, seed_ids + seed_edges_->shape[0],
std::shuffle(
seed_ids, seed_ids + seed_edges_->shape[0],
std::default_random_engine());
}
}
DGL_DECLARE_OBJECT_TYPE_INFO(UniformEdgeSamplerObject, Object);
private:
void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
private:
void randomSample(size_t set_size, size_t num, std::vector<size_t> *out) {
RandomSample(set_size, num, out);
}
void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t>* out) {
void randomSample(
size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t> *out) {
RandomSample(set_size, num, exclude, out);
}
......@@ -1557,17 +1540,19 @@ private:
int64_t num_seeds_;
};
class UniformEdgeSampler: public ObjectRef {
class UniformEdgeSampler : public ObjectRef {
public:
UniformEdgeSampler() {}
explicit UniformEdgeSampler(std::shared_ptr<runtime::Object> obj): ObjectRef(obj) {}
explicit UniformEdgeSampler(std::shared_ptr<runtime::Object> obj)
: ObjectRef(obj) {}
UniformEdgeSamplerObject* operator->() const {
return static_cast<UniformEdgeSamplerObject*>(obj_.get());
UniformEdgeSamplerObject *operator->() const {
return static_cast<UniformEdgeSamplerObject *>(obj_.get());
}
std::shared_ptr<UniformEdgeSamplerObject> sptr() const {
return CHECK_NOTNULL(std::dynamic_pointer_cast<UniformEdgeSamplerObject>(obj_));
return CHECK_NOTNULL(
std::dynamic_pointer_cast<UniformEdgeSamplerObject>(obj_));
}
operator bool() const { return this->defined(); }
......@@ -1575,7 +1560,7 @@ class UniformEdgeSampler: public ObjectRef {
};
DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments
GraphRef g = args[0];
IdArray seed_edges = args[1];
......@@ -1603,64 +1588,42 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateUniformEdgeSampler")
}
BuildCoo(*gptr);
auto o = std::make_shared<UniformEdgeSamplerObject>(gptr,
seed_edges,
batch_size,
max_num_workers,
replacement,
reset,
neg_mode,
neg_sample_size,
chunk_size,
exclude_positive,
check_false_neg,
relations);
auto o = std::make_shared<UniformEdgeSamplerObject>(
gptr, seed_edges, batch_size, max_num_workers, replacement, reset,
neg_mode, neg_sample_size, chunk_size, exclude_positive,
check_false_neg, relations);
*rv = o;
});
});
DGL_REGISTER_GLOBAL("sampling._CAPI_FetchUniformEdgeSample")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Produce the next batch of sampled subgraphs from the sampler.
      UniformEdgeSampler sampler = args[0];
      sampler->Fetch(rv);
    });
DGL_REGISTER_GLOBAL("sampling._CAPI_ResetUniformEdgeSample")
    .set_body([](DGLArgs args, DGLRetValue *rv) {
      // Rewind the sampler so iteration restarts from the first batch.
      UniformEdgeSampler sampler = args[0];
      sampler->Reset();
    });
template<typename ValueType>
class WeightedEdgeSamplerObject: public EdgeSamplerObject {
template <typename ValueType>
class WeightedEdgeSamplerObject : public EdgeSamplerObject {
public:
explicit WeightedEdgeSamplerObject(const GraphPtr gptr,
IdArray seed_edges,
NDArray edge_weight,
NDArray node_weight,
const int64_t batch_size,
const int64_t num_workers,
const bool replacement,
const bool reset,
const std::string neg_mode,
const int64_t neg_sample_size,
const int64_t chunk_size,
const bool exclude_positive,
const bool check_false_neg,
explicit WeightedEdgeSamplerObject(
const GraphPtr gptr, IdArray seed_edges, NDArray edge_weight,
NDArray node_weight, const int64_t batch_size, const int64_t num_workers,
const bool replacement, const bool reset, const std::string neg_mode,
const int64_t neg_sample_size, const int64_t chunk_size,
const bool exclude_positive, const bool check_false_neg,
IdArray relations)
: EdgeSamplerObject(gptr,
seed_edges,
batch_size,
num_workers,
replacement,
reset,
neg_mode,
neg_sample_size,
chunk_size,
exclude_positive,
check_false_neg,
relations) {
: EdgeSamplerObject(
gptr, seed_edges, batch_size, num_workers, replacement, reset,
neg_mode, neg_sample_size, chunk_size, exclude_positive,
check_false_neg, relations) {
const int64_t num_edges = edge_weight->shape[0];
const ValueType *edge_prob = static_cast<const ValueType*>(edge_weight->data);
const ValueType *edge_prob =
static_cast<const ValueType *>(edge_weight->data);
std::vector<ValueType> eprob(num_edges);
for (int64_t i = 0; i < num_edges; ++i) {
eprob[i] = edge_prob[i];
......@@ -1672,7 +1635,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
if (num_nodes == 0) {
node_selector_ = nullptr;
} else {
const ValueType *node_prob = static_cast<const ValueType*>(node_weight->data);
const ValueType *node_prob =
static_cast<const ValueType *>(node_weight->data);
std::vector<ValueType> nprob(num_nodes);
for (size_t i = 0; i < num_nodes; ++i) {
nprob[i] = node_prob[i];
......@@ -1687,27 +1651,26 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
gptr_->FindEdge(0);
}
~WeightedEdgeSamplerObject() {
}
~WeightedEdgeSamplerObject() {}
void Fetch(DGLRetValue* rv) {
const int64_t num_workers = std::min(num_workers_, max_batch_id_ - curr_batch_id_);
void Fetch(DGLRetValue *rv) {
const int64_t num_workers =
std::min(num_workers_, max_batch_id_ - curr_batch_id_);
// generate subgraphs.
std::vector<SubgraphRef> positive_subgs(num_workers);
std::vector<SubgraphRef> negative_subgs(num_workers);
#pragma omp parallel for
for (int i = 0; i < num_workers; i++) {
const dgl_id_t *seed_edge_ids = static_cast<const dgl_id_t *>(seed_edges_->data);
const dgl_id_t *seed_edge_ids =
static_cast<const dgl_id_t *>(seed_edges_->data);
std::vector<size_t> edge_ids(batch_size_);
if (replacement_ == false) {
size_t n = batch_size_;
size_t num_ids = 0;
#pragma omp critical
{
num_ids = edge_selector_->SampleWithoutReplacement(n, &edge_ids);
}
{ num_ids = edge_selector_->SampleWithoutReplacement(n, &edge_ids); }
edge_ids.resize(num_ids);
for (size_t i = 0; i < num_ids; ++i) {
edge_ids[i] = seed_edge_ids[edge_ids[i]];
......@@ -1730,18 +1693,16 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
// TODO(zhengda) what if there are duplicates in the src and dst vectors.
Subgraph subg = gptr_->EdgeSubgraph(worker_seeds, false);
positive_subgs[i] = ConvertRef(subg);
// For chunked negative sampling, we accept "chunk-head" for corrupting head
// nodes and "chunk-tail" for corrupting tail nodes.
// For chunked negative sampling, we accept "chunk-head" for corrupting
// head nodes and "chunk-tail" for corrupting tail nodes.
if (neg_mode_.substr(0, 5) == "chunk") {
NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(subg, neg_mode_.substr(6),
neg_sample_size_,
exclude_positive_,
NegSubgraph neg_subg = genChunkedNegEdgeSubgraph(
subg, neg_mode_.substr(6), neg_sample_size_, exclude_positive_,
check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg);
} else if (neg_mode_ == "head" || neg_mode_ == "tail") {
NegSubgraph neg_subg = genNegEdgeSubgraph(subg, neg_mode_,
neg_sample_size_,
exclude_positive_,
NegSubgraph neg_subg = genNegEdgeSubgraph(
subg, neg_mode_, neg_sample_size_, exclude_positive_,
check_false_neg_);
negative_subgs[i] = ConvertRef(neg_subg);
}
......@@ -1753,7 +1714,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
}
if (neg_mode_.size() > 0) {
positive_subgs.insert(positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
positive_subgs.insert(
positive_subgs.end(), negative_subgs.begin(), negative_subgs.end());
}
*rv = List<SubgraphRef>(positive_subgs);
}
......@@ -1762,7 +1724,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
curr_batch_id_ = 0;
if (replacement_ == false) {
const int64_t num_edges = edge_weight_->shape[0];
const ValueType *edge_prob = static_cast<const ValueType*>(edge_weight_->data);
const ValueType *edge_prob =
static_cast<const ValueType *>(edge_weight_->data);
std::vector<ValueType> eprob(num_edges);
for (int64_t i = 0; i < num_edges; ++i) {
eprob[i] = edge_prob[i];
......@@ -1775,8 +1738,8 @@ class WeightedEdgeSamplerObject: public EdgeSamplerObject {
DGL_DECLARE_OBJECT_TYPE_INFO(WeightedEdgeSamplerObject<ValueType>, Object);
private:
void randomSample(size_t set_size, size_t num, std::vector<size_t>* out) {
private:
void randomSample(size_t set_size, size_t num, std::vector<size_t> *out) {
if (num < set_size) {
std::unordered_set<size_t> sampled_idxs;
while (sampled_idxs.size() < num) {
......@@ -1792,13 +1755,13 @@ private:
} else {
// If we need to sample all elements in the set, we don't need to
// generate random numbers.
for (size_t i = 0; i < set_size; i++)
out->push_back(i);
for (size_t i = 0; i < set_size; i++) out->push_back(i);
}
}
void randomSample(size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t>* out) {
void randomSample(
size_t set_size, size_t num, const std::vector<size_t> &exclude,
std::vector<size_t> *out) {
std::unordered_map<size_t, int> sampled_idxs;
for (auto v : exclude) {
sampled_idxs.insert(std::pair<size_t, int>(v, 0));
......@@ -1830,7 +1793,7 @@ private:
}
}
private:
private:
std::shared_ptr<ArrayHeap<ValueType>> edge_selector_;
std::shared_ptr<ArrayHeap<ValueType>> node_selector_;
......@@ -1841,17 +1804,19 @@ private:
template class WeightedEdgeSamplerObject<float>;
class FloatWeightedEdgeSampler: public ObjectRef {
class FloatWeightedEdgeSampler : public ObjectRef {
public:
FloatWeightedEdgeSampler() {}
explicit FloatWeightedEdgeSampler(std::shared_ptr<runtime::Object> obj): ObjectRef(obj) {}
explicit FloatWeightedEdgeSampler(std::shared_ptr<runtime::Object> obj)
: ObjectRef(obj) {}
WeightedEdgeSamplerObject<float>* operator->() const {
return static_cast<WeightedEdgeSamplerObject<float>*>(obj_.get());
WeightedEdgeSamplerObject<float> *operator->() const {
return static_cast<WeightedEdgeSamplerObject<float> *>(obj_.get());
}
std::shared_ptr<WeightedEdgeSamplerObject<float>> sptr() const {
return CHECK_NOTNULL(std::dynamic_pointer_cast<WeightedEdgeSamplerObject<float>>(obj_));
return CHECK_NOTNULL(
std::dynamic_pointer_cast<WeightedEdgeSamplerObject<float>>(obj_));
}
operator bool() const { return this->defined(); }
......@@ -1859,7 +1824,7 @@ class FloatWeightedEdgeSampler: public ObjectRef {
};
DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
// arguments
GraphRef g = args[0];
IdArray seed_edges = args[1];
......@@ -1881,13 +1846,17 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
CHECK(aten::IsValidIdArray(seed_edges));
CHECK_EQ(seed_edges->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling";
CHECK(edge_weight->dtype.code == kDGLFloat) << "edge_weight should be FloatType";
CHECK(edge_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight";
CHECK(edge_weight->dtype.code == kDGLFloat)
<< "edge_weight should be FloatType";
CHECK(edge_weight->dtype.bits == 32)
<< "WeightedEdgeSampler only support float weight";
CHECK_EQ(edge_weight->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling";
if (node_weight->shape[0] > 0) {
CHECK(node_weight->dtype.code == kDGLFloat) << "node_weight should be FloatType";
CHECK(node_weight->dtype.bits == 32) << "WeightedEdgeSampler only support float weight";
CHECK(node_weight->dtype.code == kDGLFloat)
<< "node_weight should be FloatType";
CHECK(node_weight->dtype.bits == 32)
<< "WeightedEdgeSampler only support float weight";
CHECK_EQ(node_weight->ctx.device_type, kDGLCPU)
<< "WeightedEdgeSampler only support CPU sampling";
}
......@@ -1899,36 +1868,26 @@ DGL_REGISTER_GLOBAL("sampling._CAPI_CreateWeightedEdgeSampler")
BuildCoo(*gptr);
const int64_t num_seeds = seed_edges->shape[0];
const int64_t num_workers = std::min(max_num_workers,
(num_seeds + batch_size - 1) / batch_size);
auto o = std::make_shared<WeightedEdgeSamplerObject<float>>(gptr,
seed_edges,
edge_weight,
node_weight,
batch_size,
num_workers,
replacement,
reset,
neg_mode,
neg_sample_size,
chunk_size,
exclude_positive,
check_false_neg,
relations);
const int64_t num_workers =
std::min(max_num_workers, (num_seeds + batch_size - 1) / batch_size);
auto o = std::make_shared<WeightedEdgeSamplerObject<float>>(
gptr, seed_edges, edge_weight, node_weight, batch_size, num_workers,
replacement, reset, neg_mode, neg_sample_size, chunk_size,
exclude_positive, check_false_neg, relations);
*rv = o;
});
});
DGL_REGISTER_GLOBAL("sampling._CAPI_FetchWeightedEdgeSample")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
FloatWeightedEdgeSampler sampler = args[0];
sampler->Fetch(rv);
});
});
DGL_REGISTER_GLOBAL("sampling._CAPI_ResetWeightedEdgeSample")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
.set_body([](DGLArgs args, DGLRetValue *rv) {
FloatWeightedEdgeSampler sampler = args[0];
sampler->Reset();
});
});
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment