Unverified commit 81831111, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] clang-format auto fix. (#4805)

* [Misc] clang-format auto fix.
* fix

Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 16e771c0
@@ -5,11 +5,13 @@
 */
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/sampling/negative.h>

#include <utility>

#include "../../../c_api_common.h"

using namespace dgl::runtime;
@@ -19,13 +21,8 @@ namespace dgl {
namespace sampling {

std::pair<IdArray, IdArray> GlobalUniformNegativeSampling(
    HeteroGraphPtr hg, dgl_type_t etype, int64_t num_samples, int num_trials,
    bool exclude_self_loops, bool replace, double redundancy) {
  auto format = hg->SelectFormat(etype, CSC_CODE | CSR_CODE);
  if (format == SparseFormat::kCSC) {
    CSRMatrix csc = hg->GetCSCMatrix(etype);
@@ -40,28 +37,30 @@ std::pair<IdArray, IdArray> GlobalUniformNegativeSampling(
    return CSRGlobalUniformNegativeSampling(
        csr, num_samples, num_trials, exclude_self_loops, replace, redundancy);
  } else {
    LOG(FATAL)
        << "COO format is not supported in global uniform negative sampling";
    return {IdArray(), IdArray()};
  }
}

DGL_REGISTER_GLOBAL("sampling.negative._CAPI_DGLGlobalUniformNegativeSampling")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      HeteroGraphRef hg = args[0];
      dgl_type_t etype = args[1];
      CHECK_LE(etype, hg->NumEdgeTypes()) << "invalid edge type " << etype;
      int64_t num_samples = args[2];
      int num_trials = args[3];
      bool exclude_self_loops = args[4];
      bool replace = args[5];
      double redundancy = args[6];
      List<Value> result;
      std::pair<IdArray, IdArray> ret = GlobalUniformNegativeSampling(
          hg.sptr(), etype, num_samples, num_trials, exclude_self_loops,
          replace, redundancy);
      result.push_back(Value(MakeValue(ret.first)));
      result.push_back(Value(MakeValue(ret.second)));
      *rv = result;
    });

}; // namespace sampling
}; // namespace dgl
@@ -7,9 +7,10 @@
#include <algorithm>
#include <tuple>
#include <utility>

#include "../../../array/cuda/atomic.cuh"
#include "../../../array/cuda/dgl_cub.cuh"
#include "../../../runtime/cuda/cuda_common.h"
#include "frequency_hashmap.cuh"

namespace dgl {
@@ -31,11 +32,10 @@ int64_t _table_size(const int64_t num, const int64_t scale) {
  return next_pow2 << scale;
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _init_edge_table(void *edge_hashmap, int64_t edges_len) {
  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
  auto edge_hashmap_t = static_cast<EdgeItem *>(edge_hashmap);
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
#pragma unroll(4)
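A minimal, self-contained sketch of what `_table_size(num, scale)` above is assumed to compute, based only on the visible `return next_pow2 << scale;`: round `num` up to the next power of two, then scale by 2^scale to give the hash table some headroom. This is an illustration written for this note, with a hypothetical helper name, not code from the DGL sources.

#include <cstdint>

// Hypothetical re-implementation for illustration only.
int64_t table_size_sketch(int64_t num, int64_t scale) {
  int64_t next_pow2 = 1;
  while (next_pow2 < num) next_pow2 <<= 1;  // round up to a power of two
  return next_pow2 << scale;                // multiply by 2^scale
}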
@@ -48,11 +48,11 @@ __global__ void _init_edge_table(void *edge_hashmap, int64_t edges_len) {
  }
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _count_frequency(
    const IdxType *src_data, const int64_t num_edges,
    const int64_t num_edges_per_node, IdxType *edge_blocks_prefix,
    bool *is_first_position, DeviceEdgeHashmap<IdxType> device_edge_hashmap) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
@@ -101,11 +101,11 @@ struct BlockPrefixCallbackOp {
  }
};

template <typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _compact_frequency(
    const IdxType *src_data, const IdxType *dst_data, const int64_t num_edges,
    const int64_t num_edges_per_node, const IdxType *edge_blocks_prefix,
    const bool *is_first_position, IdxType *num_unique_each_node,
    IdxType *unique_src_edges, Idx64Type *unique_frequency,
    DeviceEdgeHashmap<IdxType> device_edge_hashmap) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
@@ -122,7 +122,8 @@ __global__ void _compact_frequency(
      IdxType src = src_data[idx];
      IdxType dst_idx = (idx / num_edges_per_node);
      if (idx % num_edges_per_node == 0) {
        num_unique_each_node[dst_idx] =
            device_edge_hashmap.GetDstCount(dst_idx);
      }
      if (is_first_position[idx] == true) {
        flag = 1;
@@ -132,22 +133,25 @@ __global__ void _compact_frequency(
    if (is_first_position[idx] == true) {
      const IdxType pos = (block_offset + flag);
      unique_src_edges[pos] = src;
      if (sizeof(IdxType) != sizeof(Idx64Type) &&
          sizeof(IdxType) == 4) {  // if IdxType is a 32-bit data
        unique_frequency[pos] =
            ((static_cast<Idx64Type>(num_edges / num_edges_per_node - dst_idx)
              << 32) |
             device_edge_hashmap.GetEdgeCount(src, dst_idx));
      } else {
        unique_frequency[pos] =
            device_edge_hashmap.GetEdgeCount(src, dst_idx);
      }
    }
    }
  }
}

template <typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _get_pick_num(
    IdxType *num_unique_each_node, const int64_t num_pick,
    const int64_t num_dst_nodes) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
#pragma unroll(4)
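In the 32-bit branch of `_compact_frequency` above, two values are packed into a single 64-bit key so that one descending radix sort (used later in `Topk`) orders unique edges first by destination node and then by frequency within that node. A hedged, standalone illustration of that packing, with names invented for this sketch (`num_dst_nodes` stands for `num_edges / num_edges_per_node`):

#include <cstdint>

// High 32 bits: (num_dst_nodes - dst_idx), so earlier destination nodes come
// first under a descending sort; low 32 bits: the edge's frequency count.
inline int64_t pack_sort_key(
    int64_t num_dst_nodes, int64_t dst_idx, uint32_t edge_count) {
  return (static_cast<int64_t>(num_dst_nodes - dst_idx) << 32) |
         static_cast<int64_t>(edge_count);
}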
@@ -159,12 +163,12 @@ __global__ void _get_pick_num(
  }
}

template <typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _pick_data(
    const Idx64Type *unique_frequency, const IdxType *unique_src_edges,
    const IdxType *unique_input_offsets, const IdxType *dst_data,
    const int64_t num_edges_per_node, const int64_t num_dst_nodes,
    const int64_t num_edges, const IdxType *unique_output_offsets,
    IdxType *output_src, IdxType *output_dst, IdxType *output_frequency) {
  int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
  int64_t last_idx = start_idx + TILE_SIZE;
@@ -175,13 +179,16 @@ __global__ void _pick_data(
    assert(dst_pos < num_edges);
    const IdxType dst = dst_data[dst_pos];
    const IdxType last_output_offset = unique_output_offsets[idx + 1];
    assert(
        (last_output_offset - unique_output_offsets[idx]) <=
        (unique_input_offsets[idx + 1] - unique_input_offsets[idx]));
    for (IdxType output_idx = unique_output_offsets[idx],
                 input_idx = unique_input_offsets[idx];
         output_idx < last_output_offset; ++output_idx, ++input_idx) {
      output_src[output_idx] = unique_src_edges[input_idx];
      output_dst[output_idx] = dst;
      output_frequency[output_idx] =
          static_cast<IdxType>(unique_frequency[input_idx]);
    }
  }
}
@@ -190,7 +197,7 @@ __global__ void _pick_data(
} // namespace

// return the old cnt of this edge
template <typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::InsertEdge(
    const IdxType &src, const IdxType &dst_idx) {
  IdxType start_off = dst_idx * _num_items_each_dst;
@@ -206,7 +213,8 @@ inline __device__ IdxType DeviceEdgeHashmap<IdxType>::InsertEdge(
          &_edge_hashmap[start_off + pos].cnt, static_cast<IdxType>(1));
      if (old_src == static_cast<IdxType>(-1)) {
        assert(dst_idx < _num_dst);
        dgl::aten::cuda::AtomicAdd(
            &_dst_unique_edges[dst_idx], static_cast<IdxType>(1));
      }
      break;
    }
@@ -216,12 +224,13 @@ inline __device__ IdxType DeviceEdgeHashmap<IdxType>::InsertEdge(
  return old_cnt;
}

template <typename IdxType>
inline __device__ IdxType
DeviceEdgeHashmap<IdxType>::GetDstCount(const IdxType &dst_idx) {
  return _dst_unique_edges[dst_idx];
}

template <typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::GetEdgeCount(
    const IdxType &src, const IdxType &dst_idx) {
  IdxType start_off = dst_idx * _num_items_each_dst;
@@ -242,18 +251,18 @@ FrequencyHashmap<IdxType>::FrequencyHashmap(
  _stream = stream;
  num_items_each_dst = _table_size(num_items_each_dst, edge_table_scale);
  auto device = dgl::runtime::DeviceAPI::Get(_ctx);
  auto dst_unique_edges = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, (num_dst) * sizeof(IdxType)));
  auto edge_hashmap = static_cast<EdgeItem *>(device->AllocWorkspace(
      _ctx, (num_dst * num_items_each_dst) * sizeof(EdgeItem)));
  constexpr int BLOCK_SIZE = 256;
  constexpr int TILE_SIZE = BLOCK_SIZE * 8;
  dim3 block(BLOCK_SIZE);
  dim3 grid((num_dst * num_items_each_dst + TILE_SIZE - 1) / TILE_SIZE);
  CUDA_CALL(cudaMemset(dst_unique_edges, 0, (num_dst) * sizeof(IdxType)));
  CUDA_KERNEL_CALL(
      (_init_edge_table<IdxType, BLOCK_SIZE, TILE_SIZE>), grid, block, 0,
      _stream, edge_hashmap, (num_dst * num_items_each_dst));
  _device_edge_hashmap = new DeviceEdgeHashmap<IdxType>(
      num_dst, num_items_each_dst, dst_unique_edges, edge_hashmap);
  _dst_unique_edges = dst_unique_edges;
@@ -271,155 +280,175 @@ FrequencyHashmap<IdxType>::~FrequencyHashmap() {
  _edge_hashmap = nullptr;
}

template <typename IdxType>
std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
    const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
    const int64_t num_edges, const int64_t num_edges_per_node,
    const int64_t num_pick) {
  using Idx64Type = int64_t;
  const int64_t num_dst_nodes = (num_edges / num_edges_per_node);
  constexpr int BLOCK_SIZE = 256;
  // XXX: a experienced value, best performance in GV100
  constexpr int TILE_SIZE = BLOCK_SIZE * 32;
  const dim3 block(BLOCK_SIZE);
  const dim3 edges_grid((num_edges + TILE_SIZE - 1) / TILE_SIZE);
  auto device = dgl::runtime::DeviceAPI::Get(_ctx);
  const IdxType num_edge_blocks = static_cast<IdxType>(edges_grid.x);
  IdxType num_unique_edges = 0;
  // to mark if this position of edges is the first inserting position for
  // _edge_hashmap
  bool *is_first_position = static_cast<bool *>(
      device->AllocWorkspace(_ctx, sizeof(bool) * (num_edges)));
  CUDA_CALL(cudaMemset(is_first_position, 0, sizeof(bool) * (num_edges)));
  // double space to use ExclusiveSum
  auto edge_blocks_prefix_data = static_cast<IdxType *>(device->AllocWorkspace(
      _ctx, 2 * sizeof(IdxType) * (num_edge_blocks + 1)));
  IdxType *edge_blocks_prefix = edge_blocks_prefix_data;
  IdxType *edge_blocks_prefix_alternate =
      (edge_blocks_prefix_data + (num_edge_blocks + 1));
  // triple space to use ExclusiveSum and unique_output_offsets
  auto num_unique_each_node_data = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, 3 * sizeof(IdxType) * (num_dst_nodes + 1)));
  IdxType *num_unique_each_node = num_unique_each_node_data;
  IdxType *num_unique_each_node_alternate =
      (num_unique_each_node_data + (num_dst_nodes + 1));
  IdxType *unique_output_offsets =
      (num_unique_each_node_data + 2 * (num_dst_nodes + 1));

  // 1. Scan the all edges and count the unique edges and unique edges for each
  // dst node
  CUDA_KERNEL_CALL(
      (_count_frequency<IdxType, BLOCK_SIZE, TILE_SIZE>), edges_grid, block, 0,
      _stream, src_data, num_edges, num_edges_per_node, edge_blocks_prefix,
      is_first_position, *_device_edge_hashmap);

  // 2. Compact the unique edges frequency
  // 2.1 ExclusiveSum the edge_blocks_prefix
  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, edge_blocks_prefix,
      edge_blocks_prefix_alternate, num_edge_blocks + 1, _stream));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, edge_blocks_prefix,
      edge_blocks_prefix_alternate, num_edge_blocks + 1, _stream));
  device->FreeWorkspace(_ctx, d_temp_storage);
  std::swap(edge_blocks_prefix, edge_blocks_prefix_alternate);
  device->CopyDataFromTo(
      &edge_blocks_prefix[num_edge_blocks], 0, &num_unique_edges, 0,
      sizeof(num_unique_edges), _ctx, DGLContext{kDGLCPU, 0}, dtype);
  device->StreamSync(_ctx, _stream);
  // 2.2 Allocate the data of unique edges and frequency
  // double space to use SegmentedRadixSort
  auto unique_src_edges_data = static_cast<IdxType *>(
      device->AllocWorkspace(_ctx, 2 * sizeof(IdxType) * (num_unique_edges)));
  IdxType *unique_src_edges = unique_src_edges_data;
  IdxType *unique_src_edges_alternate =
      unique_src_edges_data + num_unique_edges;
  // double space to use SegmentedRadixSort
  auto unique_frequency_data = static_cast<Idx64Type *>(
      device->AllocWorkspace(_ctx, 2 * sizeof(Idx64Type) * (num_unique_edges)));
  Idx64Type *unique_frequency = unique_frequency_data;
  Idx64Type *unique_frequency_alternate =
      unique_frequency_data + num_unique_edges;
  // 2.3 Compact the unique edges and their frequency
  CUDA_KERNEL_CALL(
      (_compact_frequency<IdxType, Idx64Type, BLOCK_SIZE, TILE_SIZE>),
      edges_grid, block, 0, _stream, src_data, dst_data, num_edges,
      num_edges_per_node, edge_blocks_prefix, is_first_position,
      num_unique_each_node, unique_src_edges, unique_frequency,
      *_device_edge_hashmap);

  // 3. SegmentedRadixSort the unique edges and unique_frequency
  // 3.1 ExclusiveSum the num_unique_each_node
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, num_unique_each_node,
      num_unique_each_node_alternate, num_dst_nodes + 1, _stream));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, num_unique_each_node,
      num_unique_each_node_alternate, num_dst_nodes + 1, _stream));
  device->FreeWorkspace(_ctx, d_temp_storage);
  // 3.2 SegmentedRadixSort the unique_src_edges and unique_frequency
  // Create a set of DoubleBuffers to wrap pairs of device pointers
  cub::DoubleBuffer<Idx64Type> d_unique_frequency(
      unique_frequency, unique_frequency_alternate);
  cub::DoubleBuffer<IdxType> d_unique_src_edges(
      unique_src_edges, unique_src_edges_alternate);
  // Determine temporary device storage requirements
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  // the DeviceRadixSort is faster than DeviceSegmentedRadixSort,
  // especially when num_dst_nodes is large (about ~10000)
  if (dtype.bits == 32) {
    CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(
        d_temp_storage, temp_storage_bytes, d_unique_frequency,
        d_unique_src_edges, num_unique_edges, 0, sizeof(Idx64Type) * 8,
        _stream));
  } else {
    CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(
        d_temp_storage, temp_storage_bytes, d_unique_frequency,
        d_unique_src_edges, num_unique_edges, num_dst_nodes,
        num_unique_each_node_alternate, num_unique_each_node_alternate + 1, 0,
        sizeof(Idx64Type) * 8, _stream));
  }
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  if (dtype.bits == 32) {
    CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(
        d_temp_storage, temp_storage_bytes, d_unique_frequency,
        d_unique_src_edges, num_unique_edges, 0, sizeof(Idx64Type) * 8,
        _stream));
  } else {
    CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(
        d_temp_storage, temp_storage_bytes, d_unique_frequency,
        d_unique_src_edges, num_unique_edges, num_dst_nodes,
        num_unique_each_node_alternate, num_unique_each_node_alternate + 1, 0,
        sizeof(Idx64Type) * 8, _stream));
  }
  device->FreeWorkspace(_ctx, d_temp_storage);

  // 4. Get the final pick number for each dst node
  // 4.1 Reset the min(num_pick, num_unique_each_node) to num_unique_each_node
  constexpr int NODE_TILE_SIZE = BLOCK_SIZE * 2;
  const dim3 nodes_grid((num_dst_nodes + NODE_TILE_SIZE - 1) / NODE_TILE_SIZE);
  CUDA_KERNEL_CALL(
      (_get_pick_num<IdxType, BLOCK_SIZE, NODE_TILE_SIZE>), nodes_grid, block,
      0, _stream, num_unique_each_node, num_pick, num_dst_nodes);
  // 4.2 ExclusiveSum the new num_unique_each_node as unique_output_offsets
  // use unique_output_offsets;
  d_temp_storage = nullptr;
  temp_storage_bytes = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, num_unique_each_node,
      unique_output_offsets, num_dst_nodes + 1, _stream));
  d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      d_temp_storage, temp_storage_bytes, num_unique_each_node,
      unique_output_offsets, num_dst_nodes + 1, _stream));
  device->FreeWorkspace(_ctx, d_temp_storage);

  // 5. Pick the data to result
  IdxType num_output = 0;
  device->CopyDataFromTo(
      &unique_output_offsets[num_dst_nodes], 0, &num_output, 0,
      sizeof(num_output), _ctx, DGLContext{kDGLCPU, 0}, dtype);
  device->StreamSync(_ctx, _stream);
  IdArray res_src =
      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  IdArray res_dst =
      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  IdArray res_cnt =
      IdArray::Empty({static_cast<int64_t>(num_output)}, dtype, _ctx);
  CUDA_KERNEL_CALL(
      (_pick_data<IdxType, Idx64Type, BLOCK_SIZE, NODE_TILE_SIZE>), nodes_grid,
      block, 0, _stream, d_unique_frequency.Current(),
      d_unique_src_edges.Current(), num_unique_each_node_alternate, dst_data,
      num_edges_per_node, num_dst_nodes, num_edges, unique_output_offsets,
      res_src.Ptr<IdxType>(), res_dst.Ptr<IdxType>(), res_cnt.Ptr<IdxType>());
  device->FreeWorkspace(_ctx, is_first_position);
@@ -431,15 +460,12 @@ std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
  return std::make_tuple(res_src, res_dst, res_cnt);
}

template class FrequencyHashmap<int64_t>;
template class FrequencyHashmap<int32_t>;

}; // namespace impl
}; // namespace sampling
}; // namespace dgl
@@ -9,13 +9,14 @@
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include <tuple>

namespace dgl {
namespace sampling {
namespace impl {

template <typename IdxType>
class DeviceEdgeHashmap {
 public:
  struct EdgeItem {
@@ -23,19 +24,24 @@ class DeviceEdgeHashmap {
    IdxType cnt;
  };
  DeviceEdgeHashmap() = delete;
  DeviceEdgeHashmap(
      int64_t num_dst, int64_t num_items_each_dst, IdxType *dst_unique_edges,
      EdgeItem *edge_hashmap)
      : _num_dst(num_dst),
        _num_items_each_dst(num_items_each_dst),
        _dst_unique_edges(dst_unique_edges),
        _edge_hashmap(edge_hashmap) {}
  // return the old cnt of this edge
  inline __device__ IdxType
  InsertEdge(const IdxType &src, const IdxType &dst_idx);
  inline __device__ IdxType GetDstCount(const IdxType &dst_idx);
  inline __device__ IdxType
  GetEdgeCount(const IdxType &src, const IdxType &dst_idx);

 private:
  int64_t _num_dst;
  int64_t _num_items_each_dst;
  IdxType *_dst_unique_edges;
  EdgeItem *_edge_hashmap;
  inline __device__ IdxType EdgeHash(const IdxType &id) const {
@@ -43,24 +49,26 @@ class DeviceEdgeHashmap {
  }
};

template <typename IdxType>
class FrequencyHashmap {
 public:
  static constexpr int64_t kDefaultEdgeTableScale = 3;
  FrequencyHashmap() = delete;
  FrequencyHashmap(
      int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,
      cudaStream_t stream, int64_t edge_table_scale = kDefaultEdgeTableScale);
  ~FrequencyHashmap();
  using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
  std::tuple<IdArray, IdArray, IdArray> Topk(
      const IdxType *src_data, const IdxType *dst_data, DGLDataType dtype,
      const int64_t num_edges, const int64_t num_edges_per_node,
      const int64_t num_pick);

 private:
  DGLContext _ctx;
  cudaStream_t _stream;
  DeviceEdgeHashmap<IdxType> *_device_edge_hashmap;
  IdxType *_dst_unique_edges;
  EdgeItem *_edge_hashmap;
};
......
@@ -6,7 +6,9 @@
#include <dgl/array.h>
#include <dgl/base_heterograph.h>

#include <utility>

#include "randomwalks_impl.h"

namespace dgl {
@@ -18,10 +20,9 @@ namespace sampling {
namespace impl {

template <DGLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath(
    const HeteroGraphPtr hg, const TypeArray metapath) {
  uint64_t num_etypes = metapath->shape[0];
  TypeArray result = TypeArray::Empty(
      {metapath->shape[0] + 1}, metapath->dtype, metapath->ctx);
@@ -38,8 +39,8 @@ TypeArray GetNodeTypesFromMetapath(
    dgl_type_t dsttype = src_dst_type.second;
    if (srctype != curr_type) {
      LOG(FATAL) << "source of edge type #" << i
                 << " does not match destination of edge type #" << i - 1;
      return result;
    }
    curr_type = dsttype;
@@ -48,14 +49,10 @@ TypeArray GetNodeTypesFromMetapath(
  return result;
}

template TypeArray GetNodeTypesFromMetapath<kDGLCPU, int32_t>(
    const HeteroGraphPtr hg, const TypeArray metapath);
template TypeArray GetNodeTypesFromMetapath<kDGLCPU, int64_t>(
    const HeteroGraphPtr hg, const TypeArray metapath);

}; // namespace impl
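As a side note on the function reformatted above: `GetNodeTypesFromMetapath` walks the metapath's edge types and emits the node types visited, checking that each edge type's source matches the previous edge type's destination. A self-contained sketch of that contract, with plain integers standing in for DGL type IDs (illustration only, not the DGL API):

#include <cassert>
#include <utility>
#include <vector>

// Each element of `metapath` is a (source node type, destination node type)
// pair for one edge type; the result has metapath.size() + 1 entries.
std::vector<int> NodeTypesFromMetapathSketch(
    const std::vector<std::pair<int, int>> &metapath) {
  std::vector<int> result;
  if (metapath.empty()) return result;
  int curr = metapath.front().first;
  result.push_back(curr);  // node type of the seed nodes
  for (const auto &hop : metapath) {
    // Mirrors the LOG(FATAL) check above: consecutive edge types must chain.
    assert(hop.first == curr && "source type must match previous destination");
    curr = hop.second;
    result.push_back(curr);
  }
  return result;
}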
......
@@ -4,11 +4,13 @@
 * \brief DGL sampler
 */
#include <cuda_runtime.h>
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/device_api.h>

#include <utility>

#include "randomwalks_impl.h"

namespace dgl {
@@ -20,19 +22,17 @@ namespace sampling {
namespace impl {

template <DGLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath(
    const HeteroGraphPtr hg, const TypeArray metapath) {
  uint64_t num_etypes = metapath->shape[0];
  auto cpu_ctx = DGLContext{kDGLCPU, 0};
  auto metapath_ctx = metapath->ctx;
  auto stream = DeviceAPI::Get(metapath_ctx)->GetStream();
  TypeArray h_result =
      TypeArray::Empty({metapath->shape[0] + 1}, metapath->dtype, cpu_ctx);
  auto h_result_data = h_result.Ptr<IdxType>();
  auto h_metapath = metapath.CopyTo(cpu_ctx);
@@ -48,8 +48,8 @@ TypeArray GetNodeTypesFromMetapath(
    dgl_type_t dsttype = src_dst_type.second;
    if (srctype != curr_type) {
      LOG(FATAL) << "source of edge type #" << i
                 << " does not match destination of edge type #" << i - 1;
    }
    curr_type = dsttype;
    h_result_data[i + 1] = dsttype;
@@ -60,14 +60,10 @@ TypeArray GetNodeTypesFromMetapath(
  return result;
}

template TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int32_t>(
    const HeteroGraphPtr hg, const TypeArray metapath);
template TypeArray GetNodeTypesFromMetapath<kDGLCUDA, int64_t>(
    const HeteroGraphPtr hg, const TypeArray metapath);

}; // namespace impl
......
/*!
 * Copyright (c) 2018 by Contributors
 * \file graph/sampler/generic_randomwalk_cpu.h
 * \brief DGL sampler - templated implementation definition of random walks on
 * CPU
 */
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_METAPATH_RANDOMWALK_H_
@@ -10,11 +11,13 @@
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/random.h>

#include <tuple>
#include <utility>
#include <vector>

#include "randomwalks_cpu.h"
#include "randomwalks_impl.h"

namespace dgl {
@@ -31,53 +34,49 @@ namespace {
// IdxType *node_ids_generated_so_far,
// dgl_id_t last_node_id_generated,
// int64_t number_of_nodes_generated_so_far)
template <typename IdxType>
using TerminatePredicate = std::function<bool(IdxType *, dgl_id_t, int64_t)>;

/*!
 * \brief Select one successor of metapath-based random walk, given the path
 * generated so far.
 *
 * \param data The path generated so far, of type \c IdxType.
 * \param curr The last node ID generated.
 * \param len The number of nodes generated so far. Note that the seed node is
 * always included as \c data[0], and the successors start from \c data[1].
 *
 * \param edges_by_type Vector of results from \c GetAdj() by edge type.
 * \param metapath_data Edge types of given metapath.
 * \param prob Transition probability per edge type.
 * \param terminate Predicate for terminating the current random walk path.
 *
 * \return A tuple of ID of next successor (-1 if not exist), the last traversed
 * edge ID, as well as whether to terminate.
 */
template <DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
    IdxType *data, dgl_id_t curr, int64_t len,
    const std::vector<CSRMatrix> &edges_by_type,
    const std::vector<bool> &csr_has_data, const IdxType *metapath_data,
    const std::vector<FloatArray> &prob,
    TerminatePredicate<IdxType> terminate) {
  dgl_type_t etype = metapath_data[len];
  // Note that since the selection of successors is very lightweight (especially
  // in the uniform case), we want to reduce the overheads (even from object
  // copies or object construction) as much as possible. Using Successors()
  // slows down by 2x. Using OutEdges() slows down by 10x.
  const CSRMatrix &csr = edges_by_type[etype];
  const IdxType *offsets = csr.indptr.Ptr<IdxType>();
  const IdxType *all_succ = csr.indices.Ptr<IdxType>();
  const IdxType *all_eids =
      csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;
  const IdxType *succ = all_succ + offsets[curr];
  const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;
  const int64_t size = offsets[curr + 1] - offsets[curr];
  if (size == 0) return std::make_tuple(-1, -1, true);

  // Use a reference to the original array instead of copying
  // This avoids updating the ref counts atomically from different threads
@@ -89,11 +88,13 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
    idx = RandomEngine::ThreadLocal()->RandInt(size);
  } else {
    ATEN_FLOAT_TYPE_SWITCH(prob_etype->dtype, DType, "probability", {
      FloatArray prob_selected =
          FloatArray::Empty({size}, prob_etype->dtype, prob_etype->ctx);
      DType *prob_selected_data = prob_selected.Ptr<DType>();
      const DType *prob_etype_data = prob_etype.Ptr<DType>();
      for (int64_t j = 0; j < size; ++j)
        prob_selected_data[j] =
            prob_etype_data[eids ? eids[j] : j + offsets[curr]];
      idx = RandomEngine::ThreadLocal()->Choice<IdxType>(prob_selected);
    });
  }
@@ -103,49 +104,47 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStep(
}

/*!
 * \brief Select one successor of metapath-based random walk, given the path
 * generated so far specifically for the uniform probability distribution.
 *
 * \param data The path generated so far, of type \c IdxType.
 * \param curr The last node ID generated.
 * \param len The number of nodes generated so far. Note that the seed node is
 * always included as \c data[0], and the successors start from \c data[1].
 *
 * \param edges_by_type Vector of results from \c GetAdj() by edge type.
 * \param metapath_data Edge types of given metapath.
 * \param prob Transition probability per edge type, for this special case this
 * will be a NullArray \param terminate Predicate for terminating the current
 * random walk path.
 *
 * \return A pair of ID of next successor (-1 if not exist), as well as whether
 * to terminate. \note This function is called only if all the probability
 * arrays are null.
 */
template <DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
    IdxType *data, dgl_id_t curr, int64_t len,
    const std::vector<CSRMatrix> &edges_by_type,
    const std::vector<bool> &csr_has_data, const IdxType *metapath_data,
    const std::vector<FloatArray> &prob,
    TerminatePredicate<IdxType> terminate) {
  dgl_type_t etype = metapath_data[len];
  // Note that since the selection of successors is very lightweight (especially
  // in the uniform case), we want to reduce the overheads (even from object
  // copies or object construction) as much as possible. Using Successors()
  // slows down by 2x. Using OutEdges() slows down by 10x.
  const CSRMatrix &csr = edges_by_type[etype];
  const IdxType *offsets = csr.indptr.Ptr<IdxType>();
  const IdxType *all_succ = csr.indices.Ptr<IdxType>();
  const IdxType *all_eids =
      csr_has_data[etype] ? csr.data.Ptr<IdxType>() : nullptr;
  const IdxType *succ = all_succ + offsets[curr];
  const IdxType *eids = all_eids ? (all_eids + offsets[curr]) : nullptr;
  const int64_t size = offsets[curr + 1] - offsets[curr];
  if (size == 0) return std::make_tuple(-1, -1, true);

  IdxType idx = 0;
  // Guaranteed uniform distribution
...@@ -158,31 +157,31 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform( ...@@ -158,31 +157,31 @@ std::tuple<dgl_id_t, dgl_id_t, bool> MetapathRandomWalkStepUniform(
/*! /*!
* \brief Metapath-based random walk. * \brief Metapath-based random walk.
* \param hg The heterograph. * \param hg The heterograph.
* \param seeds A 1D array of seed nodes, with the type the source type of the first * \param seeds A 1D array of seed nodes, with the type the source type of the
* edge type in the metapath. * first edge type in the metapath. \param metapath A 1D array of edge types
* \param metapath A 1D array of edge types representing the metapath. * representing the metapath. \param prob A vector of 1D float arrays,
* \param prob A vector of 1D float arrays, indicating the transition probability of * indicating the transition probability of each edge by edge type. An empty
* each edge by edge type. An empty float array assumes uniform transition. * float array assumes uniform transition. \param terminate Predicate for
* \param terminate Predicate for terminating a random walk path. * terminating a random walk path. \return A 2D array of shape (len(seeds),
* \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs, and * len(metapath) + 1) with node IDs, and A 2D array of shape (len(seeds),
* A 2D array of shape (len(seeds), len(metapath)) with edge IDs. * len(metapath)) with edge IDs.
*/ */
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> MetapathBasedRandomWalk( std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob, const std::vector<FloatArray> &prob,
TerminatePredicate<IdxType> terminate) { TerminatePredicate<IdxType> terminate) {
int64_t max_num_steps = metapath->shape[0]; int64_t max_num_steps = metapath->shape[0];
const IdxType *metapath_data = static_cast<IdxType *>(metapath->data); const IdxType *metapath_data = static_cast<IdxType *>(metapath->data);
const int64_t begin_ntype = hg->meta_graph()->FindEdge(metapath_data[0]).first; const int64_t begin_ntype =
hg->meta_graph()->FindEdge(metapath_data[0]).first;
const int64_t max_nodes = hg->NumVertices(begin_ntype); const int64_t max_nodes = hg->NumVertices(begin_ntype);
// Prefetch all edges. // Prefetch all edges.
// This forces the heterograph to materialize all OutCSR's before the OpenMP loop; // This forces the heterograph to materialize all OutCSR's before the OpenMP
// otherwise data races will happen. // loop; otherwise data races will happen.
// TODO(BarclayII): should we later on materialize COO/CSR/CSC anyway unless told otherwise? // TODO(BarclayII): should we later on materialize COO/CSR/CSC anyway unless
// told otherwise?
int64_t num_etypes = hg->NumEdgeTypes(); int64_t num_etypes = hg->NumEdgeTypes();
std::vector<CSRMatrix> edges_by_type(num_etypes); std::vector<CSRMatrix> edges_by_type(num_etypes);
std::vector<bool> csr_has_data(num_etypes); std::vector<bool> csr_has_data(num_etypes);
...@@ -202,21 +201,25 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk( ...@@ -202,21 +201,25 @@ std::pair<IdArray, IdArray> MetapathBasedRandomWalk(
} }
} }
if (!isUniform) { if (!isUniform) {
StepFunc<IdxType> step = StepFunc<IdxType> step = [&edges_by_type, &csr_has_data, metapath_data,
[&edges_by_type, &csr_has_data, metapath_data, &prob, terminate] &prob, terminate](
(IdxType *data, dgl_id_t curr, int64_t len) { IdxType *data, dgl_id_t curr, int64_t len) {
return MetapathRandomWalkStep<XPU, IdxType>( return MetapathRandomWalkStep<XPU, IdxType>(
data, curr, len, edges_by_type, csr_has_data, metapath_data, prob, terminate); data, curr, len, edges_by_type, csr_has_data, metapath_data, prob,
}; terminate);
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step, max_nodes); };
return GenericRandomWalk<XPU, IdxType>(
seeds, max_num_steps, step, max_nodes);
} else { } else {
StepFunc<IdxType> step = StepFunc<IdxType> step = [&edges_by_type, &csr_has_data, metapath_data,
[&edges_by_type, &csr_has_data, metapath_data, &prob, terminate] &prob, terminate](
(IdxType *data, dgl_id_t curr, int64_t len) { IdxType *data, dgl_id_t curr, int64_t len) {
return MetapathRandomWalkStepUniform<XPU, IdxType>( return MetapathRandomWalkStepUniform<XPU, IdxType>(
data, curr, len, edges_by_type, csr_has_data, metapath_data, prob, terminate); data, curr, len, edges_by_type, csr_has_data, metapath_data, prob,
}; terminate);
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step, max_nodes); };
return GenericRandomWalk<XPU, IdxType>(
seeds, max_num_steps, step, max_nodes);
} }
} }
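For orientation, here is a minimal CPU-only sketch of the contract the docstring above describes: a driver repeatedly applies a step callback that returns (next node, traversed edge ID, stop flag), records one row per seed, and pads early-terminated walks with -1. This is not DGL's actual GenericRandomWalk; the Step signature and all names are illustrative assumptions.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <tuple>
#include <vector>

// Hypothetical step callback: given the trace so far, the current node, and
// the number of steps taken, return (next node, traversed edge id, stop?).
using Step = std::function<std::tuple<int64_t, int64_t, bool>(
    const std::vector<int64_t>&, int64_t, int64_t)>;

// Sketch of a step-driven walk driver: traces is (num_seeds, max_steps + 1)
// and eids is (num_seeds, max_steps), both row-major and padded with -1
// wherever a walk stopped early.
inline void RunWalks(
    const std::vector<int64_t>& seeds, int64_t max_steps, const Step& step,
    std::vector<int64_t>* traces, std::vector<int64_t>* eids) {
  const int64_t trace_len = max_steps + 1;
  traces->assign(seeds.size() * trace_len, -1);
  eids->assign(seeds.size() * max_steps, -1);
  for (std::size_t i = 0; i < seeds.size(); ++i) {
    std::vector<int64_t> row = {seeds[i]};
    int64_t curr = seeds[i];
    for (int64_t t = 0; t < max_steps; ++t) {
      auto [next, eid, stop] = step(row, curr, t);
      if (next < 0) break;                 // no neighbor: leave -1 padding
      row.push_back(next);
      (*eids)[i * max_steps + t] = eid;
      curr = next;
      if (stop) break;                     // terminate predicate fired
    }
    std::copy(row.begin(), row.end(), traces->begin() + i * trace_len);
  }
}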
......
...@@ -20,9 +20,9 @@ namespace sampling { ...@@ -20,9 +20,9 @@ namespace sampling {
namespace { namespace {
void CheckNode2vecInputs(const HeteroGraphPtr hg, const IdArray seeds, void CheckNode2vecInputs(
const double p, const double q, const HeteroGraphPtr hg, const IdArray seeds, const double p,
const int64_t walk_length, const FloatArray &prob) { const double q, const int64_t walk_length, const FloatArray &prob) {
CHECK_INT(seeds, "seeds"); CHECK_INT(seeds, "seeds");
CHECK_NDIM(seeds, 1, "seeds"); CHECK_NDIM(seeds, 1, "seeds");
CHECK_FLOAT(prob, "probability"); CHECK_FLOAT(prob, "probability");
...@@ -31,8 +31,7 @@ void CheckNode2vecInputs(const HeteroGraphPtr hg, const IdArray seeds, ...@@ -31,8 +31,7 @@ void CheckNode2vecInputs(const HeteroGraphPtr hg, const IdArray seeds,
std::pair<IdArray, IdArray> Node2vec( std::pair<IdArray, IdArray> Node2vec(
const HeteroGraphPtr hg, const IdArray seeds, const double p, const HeteroGraphPtr hg, const IdArray seeds, const double p,
const double q, const int64_t walk_length, const double q, const int64_t walk_length, const FloatArray &prob) {
const FloatArray &prob) {
CheckNode2vecInputs(hg, seeds, p, q, walk_length, prob); CheckNode2vecInputs(hg, seeds, p, q, walk_length, prob);
std::pair<IdArray, IdArray> result; std::pair<IdArray, IdArray> result;
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <utility> #include <utility>
#include "node2vec_randomwalk.h" #include "node2vec_randomwalk.h"
...@@ -22,27 +23,20 @@ namespace impl { ...@@ -22,27 +23,20 @@ namespace impl {
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> Node2vec( std::pair<IdArray, IdArray> Node2vec(
const HeteroGraphPtr hg, const IdArray seeds, const double p, const HeteroGraphPtr hg, const IdArray seeds, const double p,
const double q, const int64_t walk_length, const double q, const int64_t walk_length, const FloatArray &prob) {
const FloatArray &prob) {
TerminatePredicate<IdxType> terminate = [](IdxType *data, dgl_id_t curr, TerminatePredicate<IdxType> terminate = [](IdxType *data, dgl_id_t curr,
int64_t len) { return false; }; int64_t len) { return false; };
return Node2vecRandomWalk<XPU, IdxType>(hg, seeds, p, q, walk_length, prob, return Node2vecRandomWalk<XPU, IdxType>(
terminate); hg, seeds, p, q, walk_length, prob, terminate);
} }
template std::pair<IdArray, IdArray> Node2vec<kDGLCPU, int32_t>( template std::pair<IdArray, IdArray> Node2vec<kDGLCPU, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const double p,
const IdArray seeds, const double p, const double q, const int64_t walk_length, const FloatArray &prob);
const double q,
const int64_t walk_length,
const FloatArray &prob);
template std::pair<IdArray, IdArray> Node2vec<kDGLCPU, int64_t>( template std::pair<IdArray, IdArray> Node2vec<kDGLCPU, int64_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const double p,
const IdArray seeds, const double p, const double q, const int64_t walk_length, const FloatArray &prob);
const double q,
const int64_t walk_length,
const FloatArray &prob);
}; // namespace impl }; // namespace impl
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <functional> #include <functional>
#include <tuple>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <tuple>
namespace dgl { namespace dgl {
...@@ -30,21 +30,18 @@ namespace impl { ...@@ -30,21 +30,18 @@ namespace impl {
* \param hg The heterograph. * \param hg The heterograph.
* \param seeds A 1D array of seed nodes, with the type the source type of the * \param seeds A 1D array of seed nodes, with the type the source type of the
* first edge type in the metapath. * first edge type in the metapath.
 * \param p Float, indicating likelihood of immediately revisiting a node in
 *        the walk.
 * \param q Float, control parameter to interpolate between breadth-first
 *        strategy and depth-first strategy.
 * \param walk_length Int, length of walk.
 * \param prob A 1D float array, indicating the transition probability of each
 *        edge. An empty float array assumes uniform transition.
 * \return A 2D array of shape (len(seeds), walk_length + 1) with node IDs.
 *         The paths that terminated early are padded with -1.
 */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> Node2vec( std::pair<IdArray, IdArray> Node2vec(
const HeteroGraphPtr hg, const IdArray seeds, const double p, const HeteroGraphPtr hg, const IdArray seeds, const double p,
const double q, const int64_t walk_length, const double q, const int64_t walk_length, const FloatArray &prob);
const FloatArray &prob);
}; // namespace impl }; // namespace impl
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <functional> #include <functional>
#include <tuple>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <tuple>
#include "metapath_randomwalk.h" // for TerminatePredicate
#include "node2vec_impl.h" #include "node2vec_impl.h"
#include "randomwalks_cpu.h" #include "randomwalks_cpu.h"
#include "metapath_randomwalk.h" // for TerminatePredicate
namespace dgl { namespace dgl {
...@@ -34,8 +34,7 @@ namespace impl { ...@@ -34,8 +34,7 @@ namespace impl {
namespace { namespace {
template <typename IdxType> template <typename IdxType>
bool has_edge_between(const CSRMatrix &csr, dgl_id_t u, bool has_edge_between(const CSRMatrix &csr, dgl_id_t u, dgl_id_t v) {
dgl_id_t v) {
const IdxType *offsets = csr.indptr.Ptr<IdxType>(); const IdxType *offsets = csr.indptr.Ptr<IdxType>();
const IdxType *all_succ = csr.indices.Ptr<IdxType>(); const IdxType *all_succ = csr.indices.Ptr<IdxType>();
const IdxType *u_succ = all_succ + offsets[u]; const IdxType *u_succ = all_succ + offsets[u];
...@@ -57,19 +56,18 @@ bool has_edge_between(const CSRMatrix &csr, dgl_id_t u, ...@@ -57,19 +56,18 @@ bool has_edge_between(const CSRMatrix &csr, dgl_id_t u,
* \param q Float, control parameter to interpolate between breadth-first * \param q Float, control parameter to interpolate between breadth-first
* strategy and depth-first strategy. * strategy and depth-first strategy.
 * \param len The number of nodes generated so far. Note that the seed node is
 *        always included as \c data[0], and the successors start from
 *        \c data[1].
 * \param csr The CSR matrix
 * \param prob Transition probability
 * \param terminate Predicate for terminating the current random walk path.
 * \return A tuple of the ID of the next successor (-1 if it does not exist),
 *         the edge ID traversed, as well as whether to terminate.
 */
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep( std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep(
IdxType *data, dgl_id_t curr, dgl_id_t pre, const double p, const double q, IdxType *data, dgl_id_t curr, dgl_id_t pre, const double p, const double q,
int64_t len, const CSRMatrix &csr, bool csr_has_data, const FloatArray &probs, int64_t len, const CSRMatrix &csr, bool csr_has_data,
TerminatePredicate<IdxType> terminate) { const FloatArray &probs, TerminatePredicate<IdxType> terminate) {
const IdxType *offsets = csr.indptr.Ptr<IdxType>(); const IdxType *offsets = csr.indptr.Ptr<IdxType>();
const IdxType *all_succ = csr.indices.Ptr<IdxType>(); const IdxType *all_succ = csr.indices.Ptr<IdxType>();
const IdxType *all_eids = csr_has_data ? csr.data.Ptr<IdxType>() : nullptr; const IdxType *all_eids = csr_has_data ? csr.data.Ptr<IdxType>() : nullptr;
...@@ -120,7 +118,8 @@ std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep( ...@@ -120,7 +118,8 @@ std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep(
DType *prob_selected_data = prob_selected.Ptr<DType>(); DType *prob_selected_data = prob_selected.Ptr<DType>();
const DType *prob_etype_data = probs.Ptr<DType>(); const DType *prob_etype_data = probs.Ptr<DType>();
for (int64_t j = 0; j < size; ++j) for (int64_t j = 0; j < size; ++j)
prob_selected_data[j] = prob_etype_data[eids ? eids[j] : j + offsets[curr]]; prob_selected_data[j] =
prob_etype_data[eids ? eids[j] : j + offsets[curr]];
}); });
if (len == 0) { if (len == 0) {
...@@ -148,22 +147,21 @@ std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep( ...@@ -148,22 +147,21 @@ std::tuple<dgl_id_t, dgl_id_t, bool> Node2vecRandomWalkStep(
template <DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> Node2vecRandomWalk( std::pair<IdArray, IdArray> Node2vecRandomWalk(
const HeteroGraphPtr g, const IdArray seeds, const HeteroGraphPtr g, const IdArray seeds, const double p, const double q,
const double p, const double q,
const int64_t max_num_steps, const FloatArray &prob, const int64_t max_num_steps, const FloatArray &prob,
TerminatePredicate<IdxType> terminate) { TerminatePredicate<IdxType> terminate) {
const CSRMatrix &edges = g->GetCSRMatrix(0); // homogeneous graph. const CSRMatrix &edges = g->GetCSRMatrix(0); // homogeneous graph.
bool csr_has_data = CSRHasData(edges); bool csr_has_data = CSRHasData(edges);
StepFunc<IdxType> step = StepFunc<IdxType> step = [&edges, csr_has_data, &prob, p, q, terminate](
[&edges, csr_has_data, &prob, p, q, terminate] IdxType *data, dgl_id_t curr, int64_t len) {
(IdxType *data, dgl_id_t curr, int64_t len) { dgl_id_t pre = (len != 0) ? data[len - 1] : curr;
dgl_id_t pre = (len != 0) ? data[len - 1] : curr; return Node2vecRandomWalkStep<XPU, IdxType>(
return Node2vecRandomWalkStep<XPU, IdxType>(data, curr, pre, p, q, len, data, curr, pre, p, q, len, edges, csr_has_data, prob, terminate);
edges, csr_has_data, prob, terminate); };
};
return GenericRandomWalk<XPU, IdxType>(seeds, max_num_steps, step, g->NumVertices(0)); return GenericRandomWalk<XPU, IdxType>(
seeds, max_num_steps, step, g->NumVertices(0));
} }
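As a rough illustration of the second-order bias that Node2vecRandomWalkStep applies, the unnormalized node2vec weight of a candidate successor depends on the previously visited node: 1/p for returning to it, 1 for a node at distance one from it, and 1/q otherwise. The sketch below is self-contained and uses hypothetical names; the real step also folds in the per-edge probabilities and the CSR membership helper shown earlier.

#include <cstdint>
#include <vector>

// Minimal CSR view; the names here are illustrative, not DGL's.
struct Csr {
  std::vector<int64_t> indptr;   // size: num_nodes + 1
  std::vector<int64_t> indices;  // concatenated neighbor lists
};

// Linear membership test: is v a successor of u?  (has_edge_between plays
// this role above; a binary search also works if each row is sorted.)
inline bool HasEdge(const Csr& g, int64_t u, int64_t v) {
  for (int64_t i = g.indptr[u]; i < g.indptr[u + 1]; ++i)
    if (g.indices[i] == v) return true;
  return false;
}

// Unnormalized node2vec transition weight for moving from curr to cand,
// given the previously visited node pre.
inline double Node2vecWeight(
    const Csr& g, int64_t pre, int64_t cand, double p, double q) {
  if (cand == pre) return 1.0 / p;        // step back to where we came from
  if (HasEdge(g, pre, cand)) return 1.0;  // cand is at distance 1 from pre
  return 1.0 / q;                         // move further away
}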
}; // namespace }; // namespace
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file graph/sampling/randomwalk_cpu.cc * \file graph/sampling/randomwalk_cpu.cc
 * \brief DGL sampler - CPU implementation of metapath-based random walk
 *        with OpenMP
 */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <vector>
#include <utility>
#include <algorithm> #include <algorithm>
#include "randomwalks_impl.h" #include <utility>
#include "randomwalks_cpu.h" #include <vector>
#include "metapath_randomwalk.h" #include "metapath_randomwalk.h"
#include "randomwalks_cpu.h"
#include "randomwalks_impl.h"
namespace dgl { namespace dgl {
...@@ -23,30 +26,25 @@ namespace sampling { ...@@ -23,30 +26,25 @@ namespace sampling {
namespace impl { namespace impl {
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> RandomWalk( std::pair<IdArray, IdArray> RandomWalk(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob) { const std::vector<FloatArray> &prob) {
TerminatePredicate<IdxType> terminate = TerminatePredicate<IdxType> terminate = [](IdxType *data, dgl_id_t curr,
[] (IdxType *data, dgl_id_t curr, int64_t len) { int64_t len) { return false; };
return false;
};
return MetapathBasedRandomWalk<XPU, IdxType>(hg, seeds, metapath, prob, terminate); return MetapathBasedRandomWalk<XPU, IdxType>(
hg, seeds, metapath, prob, terminate);
} }
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray src, const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k) { const int64_t k) {
CHECK(src->ctx.device_type == kDGLCPU) << "IdArray needs to be on CPU!";
int64_t len = src->shape[0] / num_samples_per_node; int64_t len = src->shape[0] / num_samples_per_node;
IdxType* src_data = src.Ptr<IdxType>(); IdxType *src_data = src.Ptr<IdxType>();
const IdxType* dst_data = dst.Ptr<IdxType>(); const IdxType *dst_data = dst.Ptr<IdxType>();
std::vector<IdxType> res_src_vec, res_dst_vec, res_cnt_vec; std::vector<IdxType> res_src_vec, res_dst_vec, res_cnt_vec;
for (int64_t i = 0; i < len; ++i) { for (int64_t i = 0; i < len; ++i) {
int64_t start_idx = (i * num_samples_per_node); int64_t start_idx = (i * num_samples_per_node);
...@@ -56,20 +54,20 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( ...@@ -56,20 +54,20 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
int64_t cnt = 0; int64_t cnt = 0;
std::vector<std::pair<IdxType, IdxType>> vec; std::vector<std::pair<IdxType, IdxType>> vec;
for (int64_t j = start_idx; j < end_idx; ++j) { for (int64_t j = start_idx; j < end_idx; ++j) {
if ((j != start_idx) && (src_data[j] != src_data[j-1])) { if ((j != start_idx) && (src_data[j] != src_data[j - 1])) {
if (src_data[j-1] != -1) { if (src_data[j - 1] != -1) {
vec.emplace_back(std::make_pair(cnt, src_data[j-1])); vec.emplace_back(std::make_pair(cnt, src_data[j - 1]));
} }
cnt = 0; cnt = 0;
} }
++cnt; ++cnt;
} }
// add last count // add last count
if (src_data[end_idx-1] != -1) { if (src_data[end_idx - 1] != -1) {
vec.emplace_back(std::make_pair(cnt, src_data[end_idx-1])); vec.emplace_back(std::make_pair(cnt, src_data[end_idx - 1]));
} }
std::sort(vec.begin(), vec.end(), std::sort(
std::greater<std::pair<IdxType, IdxType>>()); vec.begin(), vec.end(), std::greater<std::pair<IdxType, IdxType>>());
int64_t len = std::min(vec.size(), static_cast<size_t>(k)); int64_t len = std::min(vec.size(), static_cast<size_t>(k));
for (int64_t j = 0; j < len; ++j) { for (int64_t j = 0; j < len; ++j) {
auto pair_item = vec[j]; auto pair_item = vec[j];
...@@ -78,58 +76,45 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( ...@@ -78,58 +76,45 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
res_cnt_vec.emplace_back(pair_item.first); res_cnt_vec.emplace_back(pair_item.first);
} }
} }
IdArray res_src = IdArray::Empty({static_cast<int64_t>(res_src_vec.size())}, IdArray res_src = IdArray::Empty(
src->dtype, src->ctx); {static_cast<int64_t>(res_src_vec.size())}, src->dtype, src->ctx);
IdArray res_dst = IdArray::Empty({static_cast<int64_t>(res_dst_vec.size())}, IdArray res_dst = IdArray::Empty(
dst->dtype, dst->ctx); {static_cast<int64_t>(res_dst_vec.size())}, dst->dtype, dst->ctx);
IdArray res_cnt = IdArray::Empty({static_cast<int64_t>(res_cnt_vec.size())}, IdArray res_cnt = IdArray::Empty(
src->dtype, src->ctx); {static_cast<int64_t>(res_cnt_vec.size())}, src->dtype, src->ctx);
// copy data from vector to NDArray // copy data from vector to NDArray
auto device = runtime::DeviceAPI::Get(src->ctx); auto device = runtime::DeviceAPI::Get(src->ctx);
device->CopyDataFromTo(static_cast<IdxType*>(res_src_vec.data()), 0, device->CopyDataFromTo(
res_src.Ptr<IdxType>(), 0, static_cast<IdxType *>(res_src_vec.data()), 0, res_src.Ptr<IdxType>(), 0,
sizeof(IdxType) * res_src_vec.size(), sizeof(IdxType) * res_src_vec.size(), DGLContext{kDGLCPU, 0},
DGLContext{kDGLCPU, 0}, res_src->ctx, res_src->ctx, res_src->dtype);
res_src->dtype); device->CopyDataFromTo(
device->CopyDataFromTo(static_cast<IdxType*>(res_dst_vec.data()), 0, static_cast<IdxType *>(res_dst_vec.data()), 0, res_dst.Ptr<IdxType>(), 0,
res_dst.Ptr<IdxType>(), 0, sizeof(IdxType) * res_dst_vec.size(), DGLContext{kDGLCPU, 0},
sizeof(IdxType) * res_dst_vec.size(), res_dst->ctx, res_dst->dtype);
DGLContext{kDGLCPU, 0}, res_dst->ctx, device->CopyDataFromTo(
res_dst->dtype); static_cast<IdxType *>(res_cnt_vec.data()), 0, res_cnt.Ptr<IdxType>(), 0,
device->CopyDataFromTo(static_cast<IdxType*>(res_cnt_vec.data()), 0, sizeof(IdxType) * res_cnt_vec.size(), DGLContext{kDGLCPU, 0},
res_cnt.Ptr<IdxType>(), 0, res_cnt->ctx, res_cnt->dtype);
sizeof(IdxType) * res_cnt_vec.size(),
DGLContext{kDGLCPU, 0}, res_cnt->ctx,
res_cnt->dtype);
return std::make_tuple(res_src, res_dst, res_cnt); return std::make_tuple(res_src, res_dst, res_cnt);
} }
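The CPU SelectPinSageNeighbors above assumes the walk output is grouped into num_samples_per_node consecutive visits per destination node, with equal visited nodes adjacent inside each group, and keeps the k most frequently visited ones. The sketch below shows the same counting-and-top-k idea for a single group, simplified to a hash map instead of the run-length pass; it is illustrative only, not DGL's implementation.

#include <algorithm>
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>

// For one destination node, pick the k most frequently visited source nodes
// among its random-walk visits (-1 marks a padded / invalid visit).
// Returns (visit count, node id) pairs, most frequent first.
inline std::vector<std::pair<int64_t, int64_t>> TopkVisited(
    const std::vector<int64_t>& visits, int64_t k) {
  std::unordered_map<int64_t, int64_t> freq;
  for (int64_t v : visits)
    if (v != -1) ++freq[v];
  std::vector<std::pair<int64_t, int64_t>> ranked;  // (count, node)
  ranked.reserve(freq.size());
  for (const auto& kv : freq) ranked.emplace_back(kv.second, kv.first);
  std::sort(ranked.begin(), ranked.end(),
            std::greater<std::pair<int64_t, int64_t>>());
  if (static_cast<int64_t>(ranked.size()) > k) ranked.resize(k);
  return ranked;
}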
template template std::pair<IdArray, IdArray> RandomWalk<kDGLCPU, int32_t>(
std::pair<IdArray, IdArray> RandomWalk<kDGLCPU, int32_t>( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob); const std::vector<FloatArray> &prob);
template template std::pair<IdArray, IdArray> RandomWalk<kDGLCPU, int64_t>(
std::pair<IdArray, IdArray> RandomWalk<kDGLCPU, int64_t>( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob); const std::vector<FloatArray> &prob);
template template std::tuple<IdArray, IdArray, IdArray>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDGLCPU, int32_t>( SelectPinSageNeighbors<kDGLCPU, int32_t>(
const IdArray src, const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k); const int64_t k);
template template std::tuple<IdArray, IdArray, IdArray>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDGLCPU, int64_t>( SelectPinSageNeighbors<kDGLCPU, int64_t>(
const IdArray src, const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k); const int64_t k);
}; // namespace impl }; // namespace impl
......
...@@ -4,14 +4,15 @@ ...@@ -4,14 +4,15 @@
 * \brief CUDA random walk sampling
 */
#include <curand_kernel.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/runtime/device_api.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <curand_kernel.h> #include <dgl/runtime/device_api.h>
#include <vector>
#include <utility>
#include <tuple> #include <tuple>
#include <utility>
#include <vector>
#include "../../../array/cuda/dgl_cub.cuh" #include "../../../array/cuda/dgl_cub.cuh"
#include "../../../runtime/cuda/cuda_common.h" #include "../../../runtime/cuda/cuda_common.h"
...@@ -28,26 +29,24 @@ namespace impl { ...@@ -28,26 +29,24 @@ namespace impl {
namespace { namespace {
template<typename IdType> template <typename IdType>
struct GraphKernelData { struct GraphKernelData {
const IdType *in_ptr; const IdType *in_ptr;
const IdType *in_cols; const IdType *in_cols;
const IdType *data; const IdType *data;
}; };
template<typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE> template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkKernel( __global__ void _RandomWalkKernel(
const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds, const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
const IdType* metapath_data, const uint64_t max_num_steps, const IdType *metapath_data, const uint64_t max_num_steps,
const GraphKernelData<IdType>* graphs, const GraphKernelData<IdType> *graphs, const FloatType *restart_prob_data,
const FloatType* restart_prob_data, const int64_t restart_prob_size, const int64_t max_nodes,
const int64_t restart_prob_size, IdType *out_traces_data, IdType *out_eids_data) {
const int64_t max_nodes,
IdType *out_traces_data,
IdType *out_eids_data) {
assert(BLOCK_SIZE == blockDim.x); assert(BLOCK_SIZE == blockDim.x);
int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x; int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;
int64_t last_idx = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds); int64_t last_idx =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);
int64_t trace_length = (max_num_steps + 1); int64_t trace_length = (max_num_steps + 1);
curandState rng; curandState rng;
// reference: // reference:
...@@ -71,15 +70,20 @@ __global__ void _RandomWalkKernel( ...@@ -71,15 +70,20 @@ __global__ void _RandomWalkKernel(
} }
const int64_t num = curand(&rng) % deg; const int64_t num = curand(&rng) % deg;
IdType pick = graph.in_cols[in_row_start + num]; IdType pick = graph.in_cols[in_row_start + num];
IdType eid = (graph.data? graph.data[in_row_start + num] : in_row_start + num); IdType eid =
(graph.data ? graph.data[in_row_start + num] : in_row_start + num);
*traces_data_ptr = pick; *traces_data_ptr = pick;
*eids_data_ptr = eid; *eids_data_ptr = eid;
if ((restart_prob_size > 1) && (curand_uniform(&rng) < restart_prob_data[step_idx])) { if ((restart_prob_size > 1) &&
(curand_uniform(&rng) < restart_prob_data[step_idx])) {
break; break;
} else if ((restart_prob_size == 1) && (curand_uniform(&rng) < restart_prob_data[0])) { } else if (
(restart_prob_size == 1) &&
(curand_uniform(&rng) < restart_prob_data[0])) {
break; break;
} }
++traces_data_ptr; ++eids_data_ptr; ++traces_data_ptr;
++eids_data_ptr;
curr = pick; curr = pick;
} }
for (; step_idx < max_num_steps; ++step_idx) { for (; step_idx < max_num_steps; ++step_idx) {
...@@ -92,22 +96,16 @@ __global__ void _RandomWalkKernel( ...@@ -92,22 +96,16 @@ __global__ void _RandomWalkKernel(
template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE> template <typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkBiasedKernel( __global__ void _RandomWalkBiasedKernel(
const uint64_t rand_seed, const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
const IdType *seed_data, const IdType *metapath_data, const uint64_t max_num_steps,
const int64_t num_seeds, const GraphKernelData<IdType> *graphs, const FloatType **probs,
const IdType *metapath_data, const FloatType **prob_sums, const FloatType *restart_prob_data,
const uint64_t max_num_steps, const int64_t restart_prob_size, const int64_t max_nodes,
const GraphKernelData<IdType> *graphs, IdType *out_traces_data, IdType *out_eids_data) {
const FloatType **probs,
const FloatType **prob_sums,
const FloatType *restart_prob_data,
const int64_t restart_prob_size,
const int64_t max_nodes,
IdType *out_traces_data,
IdType *out_eids_data) {
assert(BLOCK_SIZE == blockDim.x); assert(BLOCK_SIZE == blockDim.x);
int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x; int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;
int64_t last_idx = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds); int64_t last_idx =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);
int64_t trace_length = (max_num_steps + 1); int64_t trace_length = (max_num_steps + 1);
curandState rng; curandState rng;
// reference: // reference:
...@@ -146,15 +144,20 @@ __global__ void _RandomWalkBiasedKernel( ...@@ -146,15 +144,20 @@ __global__ void _RandomWalkBiasedKernel(
} }
IdType pick = graph.in_cols[in_row_start + num]; IdType pick = graph.in_cols[in_row_start + num];
IdType eid = (graph.data? graph.data[in_row_start + num] : in_row_start + num); IdType eid =
(graph.data ? graph.data[in_row_start + num] : in_row_start + num);
*traces_data_ptr = pick; *traces_data_ptr = pick;
*eids_data_ptr = eid; *eids_data_ptr = eid;
if ((restart_prob_size > 1) && (curand_uniform(&rng) < restart_prob_data[step_idx])) { if ((restart_prob_size > 1) &&
(curand_uniform(&rng) < restart_prob_data[step_idx])) {
break; break;
} else if ((restart_prob_size == 1) && (curand_uniform(&rng) < restart_prob_data[0])) { } else if (
(restart_prob_size == 1) &&
(curand_uniform(&rng) < restart_prob_data[0])) {
break; break;
} }
++traces_data_ptr; ++eids_data_ptr; ++traces_data_ptr;
++eids_data_ptr;
curr = pick; curr = pick;
} }
for (; step_idx < max_num_steps; ++step_idx) { for (; step_idx < max_num_steps; ++step_idx) {
...@@ -168,20 +171,19 @@ __global__ void _RandomWalkBiasedKernel( ...@@ -168,20 +171,19 @@ __global__ void _RandomWalkBiasedKernel(
} // namespace } // namespace
// random walk for uniform choice // random walk for uniform choice
template<DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkUniform( std::pair<IdArray, IdArray> RandomWalkUniform(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds,
const TypeArray metapath,
FloatArray restart_prob) { FloatArray restart_prob) {
const int64_t max_num_steps = metapath->shape[0]; const int64_t max_num_steps = metapath->shape[0];
const IdType *metapath_data = static_cast<IdType *>(metapath->data); const IdType *metapath_data = static_cast<IdType *>(metapath->data);
const int64_t begin_ntype = hg->meta_graph()->FindEdge(metapath_data[0]).first; const int64_t begin_ntype =
hg->meta_graph()->FindEdge(metapath_data[0]).first;
const int64_t max_nodes = hg->NumVertices(begin_ntype); const int64_t max_nodes = hg->NumVertices(begin_ntype);
int64_t num_etypes = hg->NumEdgeTypes(); int64_t num_etypes = hg->NumEdgeTypes();
auto ctx = seeds->ctx; auto ctx = seeds->ctx;
const IdType *seed_data = static_cast<const IdType*>(seeds->data); const IdType *seed_data = static_cast<const IdType *>(seeds->data);
CHECK(seeds->ndim == 1) << "seeds shape is not one dimension."; CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
const int64_t num_seeds = seeds->shape[0]; const int64_t num_seeds = seeds->shape[0];
int64_t trace_length = max_num_steps + 1; int64_t trace_length = max_num_steps + 1;
...@@ -193,21 +195,22 @@ std::pair<IdArray, IdArray> RandomWalkUniform( ...@@ -193,21 +195,22 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
std::vector<GraphKernelData<IdType>> h_graphs(num_etypes); std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
for (int64_t etype = 0; etype < num_etypes; ++etype) { for (int64_t etype = 0; etype < num_etypes; ++etype) {
const CSRMatrix &csr = hg->GetCSRMatrix(etype); const CSRMatrix &csr = hg->GetCSRMatrix(etype);
h_graphs[etype].in_ptr = static_cast<const IdType*>(csr.indptr->data); h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);
h_graphs[etype].in_cols = static_cast<const IdType*>(csr.indices->data); h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);
h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr); h_graphs[etype].data =
(CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)
: nullptr);
} }
// use cuda stream from local thread // use cuda stream from local thread
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = DeviceAPI::Get(ctx); auto device = DeviceAPI::Get(ctx);
auto d_graphs = static_cast<GraphKernelData<IdType>*>( auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(
device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>))); ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
// copy graph metadata pointers to GPU // copy graph metadata pointers to GPU
device->CopyDataFromTo(h_graphs.data(), 0, d_graphs, 0, device->CopyDataFromTo(
(num_etypes) * sizeof(GraphKernelData<IdType>), h_graphs.data(), 0, d_graphs, 0,
DGLContext{kDGLCPU, 0}, (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},
ctx, ctx, hg->GetCSRMatrix(0).indptr->dtype);
hg->GetCSRMatrix(0).indptr->dtype);
// copy metapath to GPU // copy metapath to GPU
auto d_metapath = metapath.CopyTo(ctx); auto d_metapath = metapath.CopyTo(ctx);
const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data); const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);
...@@ -217,50 +220,41 @@ std::pair<IdArray, IdArray> RandomWalkUniform( ...@@ -217,50 +220,41 @@ std::pair<IdArray, IdArray> RandomWalkUniform(
dim3 block(256); dim3 block(256);
dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE); dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000); const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, FloatType, "random walk GPU kernel", { ATEN_FLOAT_TYPE_SWITCH(
CHECK(restart_prob->ctx.device_type == kDGLCUDA) << "restart prob should be in GPU."; restart_prob->dtype, FloatType, "random walk GPU kernel", {
CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1."; CHECK(restart_prob->ctx.device_type == kDGLCUDA)
const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>(); << "restart prob should be in GPU.";
const int64_t restart_prob_size = restart_prob->shape[0]; CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
CUDA_KERNEL_CALL( const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
(_RandomWalkKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), const int64_t restart_prob_size = restart_prob->shape[0];
grid, block, 0, stream, CUDA_KERNEL_CALL(
random_seed, (_RandomWalkKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,
seed_data, block, 0, stream, random_seed, seed_data, num_seeds,
num_seeds, d_metapath_data, max_num_steps, d_graphs, restart_prob_data,
d_metapath_data, restart_prob_size, max_nodes, traces_data, eids_data);
max_num_steps, });
d_graphs,
restart_prob_data,
restart_prob_size,
max_nodes,
traces_data,
eids_data);
});
device->FreeWorkspace(ctx, d_graphs); device->FreeWorkspace(ctx, d_graphs);
return std::make_pair(traces, eids); return std::make_pair(traces, eids);
} }
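The kernel launched above assigns one seed to each thread (TILE_SIZE seeds per block); every step draws a uniform neighbor and may stop early according to a restart probability that is either a single scalar or a per-step array. A host-side C++ sketch of that per-seed loop follows, with hypothetical names and std::mt19937 standing in for curand; an empty restart vector means the walk never restarts, mirroring the empty-array case in the dispatcher.

#include <cstdint>
#include <random>
#include <vector>

// CPU sketch of the per-seed loop each GPU thread runs: pick a uniform
// neighbor at every step and stop early with the restart probability.
inline std::vector<int64_t> UniformWalkWithRestart(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    int64_t seed_node, int64_t max_steps,
    const std::vector<double>& restart_prob, std::mt19937* rng) {
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  std::vector<int64_t> trace = {seed_node};
  int64_t curr = seed_node;
  for (int64_t step = 0; step < max_steps; ++step) {
    const int64_t deg = indptr[curr + 1] - indptr[curr];
    if (deg == 0) break;  // dead end; the caller pads the rest with -1
    std::uniform_int_distribution<int64_t> pick_nbr(0, deg - 1);
    const int64_t pick = indices[indptr[curr] + pick_nbr(*rng)];
    trace.push_back(pick);
    if (!restart_prob.empty()) {
      const double r =
          restart_prob.size() > 1 ? restart_prob[step] : restart_prob[0];
      if (unif(*rng) < r) break;  // restart: this trace ends here
    }
    curr = pick;
  }
  return trace;
}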
/** /**
* \brief Random walk for biased choice. We use inverse transform sampling to * \brief Random walk for biased choice. We use inverse transform sampling to
* choose the next step. * choose the next step.
*/ */
template <DGLDeviceType XPU, typename FloatType, typename IdType> template <DGLDeviceType XPU, typename FloatType, typename IdType>
std::pair<IdArray, IdArray> RandomWalkBiased( std::pair<IdArray, IdArray> RandomWalkBiased(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, FloatArray restart_prob) {
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob) {
const int64_t max_num_steps = metapath->shape[0]; const int64_t max_num_steps = metapath->shape[0];
const IdType *metapath_data = static_cast<IdType *>(metapath->data); const IdType *metapath_data = static_cast<IdType *>(metapath->data);
const int64_t begin_ntype = hg->meta_graph()->FindEdge(metapath_data[0]).first; const int64_t begin_ntype =
hg->meta_graph()->FindEdge(metapath_data[0]).first;
const int64_t max_nodes = hg->NumVertices(begin_ntype); const int64_t max_nodes = hg->NumVertices(begin_ntype);
int64_t num_etypes = hg->NumEdgeTypes(); int64_t num_etypes = hg->NumEdgeTypes();
auto ctx = seeds->ctx; auto ctx = seeds->ctx;
const IdType *seed_data = static_cast<const IdType*>(seeds->data); const IdType *seed_data = static_cast<const IdType *>(seeds->data);
CHECK(seeds->ndim == 1) << "seeds shape is not one dimension."; CHECK(seeds->ndim == 1) << "seeds shape is not one dimension.";
const int64_t num_seeds = seeds->shape[0]; const int64_t num_seeds = seeds->shape[0];
int64_t trace_length = max_num_steps + 1; int64_t trace_length = max_num_steps + 1;
...@@ -282,9 +276,11 @@ std::pair<IdArray, IdArray> RandomWalkBiased( ...@@ -282,9 +276,11 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
std::vector<GraphKernelData<IdType>> h_graphs(num_etypes); std::vector<GraphKernelData<IdType>> h_graphs(num_etypes);
for (int64_t etype = 0; etype < num_etypes; ++etype) { for (int64_t etype = 0; etype < num_etypes; ++etype) {
const CSRMatrix &csr = hg->GetCSRMatrix(etype); const CSRMatrix &csr = hg->GetCSRMatrix(etype);
h_graphs[etype].in_ptr = static_cast<const IdType*>(csr.indptr->data); h_graphs[etype].in_ptr = static_cast<const IdType *>(csr.indptr->data);
h_graphs[etype].in_cols = static_cast<const IdType*>(csr.indices->data); h_graphs[etype].in_cols = static_cast<const IdType *>(csr.indices->data);
h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr); h_graphs[etype].data =
(CSRHasData(csr) ? static_cast<const IdType *>(csr.data->data)
: nullptr);
int64_t num_segments = csr.indptr->shape[0] - 1; int64_t num_segments = csr.indptr->shape[0] - 1;
// will handle empty probs in the kernel // will handle empty probs in the kernel
...@@ -294,52 +290,42 @@ std::pair<IdArray, IdArray> RandomWalkBiased( ...@@ -294,52 +290,42 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
continue; continue;
} }
probs[etype] = prob[etype].Ptr<FloatType>(); probs[etype] = prob[etype].Ptr<FloatType>();
prob_sums_arr.push_back(FloatArray::Empty({num_segments}, prob[etype]->dtype, ctx)); prob_sums_arr.push_back(
FloatArray::Empty({num_segments}, prob[etype]->dtype, ctx));
prob_sums[etype] = prob_sums_arr[etype].Ptr<FloatType>(); prob_sums[etype] = prob_sums_arr[etype].Ptr<FloatType>();
// calculate the sum of the neighbor weights // calculate the sum of the neighbor weights
const IdType *d_offsets = static_cast<const IdType*>(csr.indptr->data); const IdType *d_offsets = static_cast<const IdType *>(csr.indptr->data);
size_t temp_storage_size = 0; size_t temp_storage_size = 0;
CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_size, CUDA_CALL(cub::DeviceSegmentedReduce::Sum(
probs[etype], nullptr, temp_storage_size, probs[etype], prob_sums[etype],
prob_sums[etype], num_segments, d_offsets, d_offsets + 1, stream));
num_segments,
d_offsets,
d_offsets + 1, stream));
void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size); void *temp_storage = device->AllocWorkspace(ctx, temp_storage_size);
CUDA_CALL(cub::DeviceSegmentedReduce::Sum(temp_storage, temp_storage_size, CUDA_CALL(cub::DeviceSegmentedReduce::Sum(
probs[etype], temp_storage, temp_storage_size, probs[etype], prob_sums[etype],
prob_sums[etype], num_segments, d_offsets, d_offsets + 1, stream));
num_segments,
d_offsets,
d_offsets + 1, stream));
device->FreeWorkspace(ctx, temp_storage); device->FreeWorkspace(ctx, temp_storage);
} }
// copy graph metadata pointers to GPU // copy graph metadata pointers to GPU
auto d_graphs = static_cast<GraphKernelData<IdType>*>( auto d_graphs = static_cast<GraphKernelData<IdType> *>(device->AllocWorkspace(
device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>))); ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
device->CopyDataFromTo(h_graphs.data(), 0, d_graphs, 0, device->CopyDataFromTo(
(num_etypes) * sizeof(GraphKernelData<IdType>), h_graphs.data(), 0, d_graphs, 0,
DGLContext{kDGLCPU, 0}, (num_etypes) * sizeof(GraphKernelData<IdType>), DGLContext{kDGLCPU, 0},
ctx, ctx, hg->GetCSRMatrix(0).indptr->dtype);
hg->GetCSRMatrix(0).indptr->dtype);
// copy probs pointers to GPU // copy probs pointers to GPU
const FloatType **probs_dev = static_cast<const FloatType **>( const FloatType **probs_dev = static_cast<const FloatType **>(
device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *))); device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
device->CopyDataFromTo(probs.get(), 0, probs_dev, 0, device->CopyDataFromTo(
(num_etypes) * sizeof(FloatType *), probs.get(), 0, probs_dev, 0, (num_etypes) * sizeof(FloatType *),
DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);
ctx,
prob[0]->dtype);
// copy probs_sum pointers to GPU // copy probs_sum pointers to GPU
const FloatType **prob_sums_dev = static_cast<const FloatType **>( const FloatType **prob_sums_dev = static_cast<const FloatType **>(
device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *))); device->AllocWorkspace(ctx, num_etypes * sizeof(FloatType *)));
device->CopyDataFromTo(prob_sums.get(), 0, prob_sums_dev, 0, device->CopyDataFromTo(
(num_etypes) * sizeof(FloatType *), prob_sums.get(), 0, prob_sums_dev, 0, (num_etypes) * sizeof(FloatType *),
DGLContext{kDGLCPU, 0}, DGLContext{kDGLCPU, 0}, ctx, prob[0]->dtype);
ctx,
prob[0]->dtype);
// copy metapath to GPU // copy metapath to GPU
auto d_metapath = metapath.CopyTo(ctx); auto d_metapath = metapath.CopyTo(ctx);
const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data); const IdType *d_metapath_data = static_cast<IdType *>(d_metapath->data);
...@@ -349,26 +335,16 @@ std::pair<IdArray, IdArray> RandomWalkBiased( ...@@ -349,26 +335,16 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
dim3 block(256); dim3 block(256);
dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE); dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000); const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
CHECK(restart_prob->ctx.device_type == kDGLCUDA) << "restart prob should be in GPU."; CHECK(restart_prob->ctx.device_type == kDGLCUDA)
<< "restart prob should be in GPU.";
CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1."; CHECK(restart_prob->ndim == 1) << "restart prob dimension should be 1.";
const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>(); const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
const int64_t restart_prob_size = restart_prob->shape[0]; const int64_t restart_prob_size = restart_prob->shape[0];
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
(_RandomWalkBiasedKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), (_RandomWalkBiasedKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE>), grid,
grid, block, 0, stream, block, 0, stream, random_seed, seed_data, num_seeds, d_metapath_data,
random_seed, max_num_steps, d_graphs, probs_dev, prob_sums_dev, restart_prob_data,
seed_data, restart_prob_size, max_nodes, traces_data, eids_data);
num_seeds,
d_metapath_data,
max_num_steps,
d_graphs,
probs_dev,
prob_sums_dev,
restart_prob_data,
restart_prob_size,
max_nodes,
traces_data,
eids_data);
device->FreeWorkspace(ctx, d_graphs); device->FreeWorkspace(ctx, d_graphs);
device->FreeWorkspace(ctx, probs_dev); device->FreeWorkspace(ctx, probs_dev);
...@@ -376,13 +352,10 @@ std::pair<IdArray, IdArray> RandomWalkBiased( ...@@ -376,13 +352,10 @@ std::pair<IdArray, IdArray> RandomWalkBiased(
return std::make_pair(traces, eids); return std::make_pair(traces, eids);
} }
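RandomWalkBiased precomputes, per edge type, the sum of each node's neighbor weights (the cub::DeviceSegmentedReduce::Sum calls above) so that the kernel can pick a successor by inverse transform sampling: draw u uniformly in [0, W), then scan the neighbor weights until the running sum exceeds u. The standalone sketch below shows only that selection step over one weight slice; the names are illustrative.

#include <cstddef>
#include <random>
#include <vector>

// Pick an index in [0, weights.size()) with probability proportional to
// weights[i], via inverse transform sampling over unnormalized weights.
// weight_sum must equal the sum of all entries (precomputed by the caller).
inline std::size_t SampleByWeight(
    const std::vector<double>& weights, double weight_sum, std::mt19937* rng) {
  std::uniform_real_distribution<double> unif(0.0, weight_sum);
  const double u = unif(*rng);
  double running = 0.0;
  for (std::size_t i = 0; i < weights.size(); ++i) {
    running += weights[i];
    if (u < running) return i;
  }
  return weights.size() - 1;  // guard against floating-point round-off
}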
template<DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalk( std::pair<IdArray, IdArray> RandomWalk(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob) { const std::vector<FloatArray> &prob) {
bool isUniform = true; bool isUniform = true;
for (const auto &etype_prob : prob) { for (const auto &etype_prob : prob) {
if (!IsNullArray(etype_prob)) { if (!IsNullArray(etype_prob)) {
...@@ -391,12 +364,13 @@ std::pair<IdArray, IdArray> RandomWalk( ...@@ -391,12 +364,13 @@ std::pair<IdArray, IdArray> RandomWalk(
} }
} }
auto restart_prob = NDArray::Empty( auto restart_prob =
{0}, DGLDataType{kDGLFloat, 32, 1}, DGLContext{XPU, 0}); NDArray::Empty({0}, DGLDataType{kDGLFloat, 32, 1}, DGLContext{XPU, 0});
if (!isUniform) { if (!isUniform) {
std::pair<IdArray, IdArray> ret; std::pair<IdArray, IdArray> ret;
ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", { ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob); ret = RandomWalkBiased<XPU, FloatType, IdType>(
hg, seeds, metapath, prob, restart_prob);
}); });
return ret; return ret;
} else { } else {
...@@ -404,14 +378,10 @@ std::pair<IdArray, IdArray> RandomWalk( ...@@ -404,14 +378,10 @@ std::pair<IdArray, IdArray> RandomWalk(
} }
} }
template<DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithRestart( std::pair<IdArray, IdArray> RandomWalkWithRestart(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, double restart_prob) {
const TypeArray metapath,
const std::vector<FloatArray> &prob,
double restart_prob) {
bool isUniform = true; bool isUniform = true;
for (const auto &etype_prob : prob) { for (const auto &etype_prob : prob) {
if (!IsNullArray(etype_prob)) { if (!IsNullArray(etype_prob)) {
...@@ -421,17 +391,15 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart( ...@@ -421,17 +391,15 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
} }
auto device_ctx = seeds->ctx; auto device_ctx = seeds->ctx;
auto restart_prob_array = NDArray::Empty( auto restart_prob_array =
{1}, DGLDataType{kDGLFloat, 64, 1}, device_ctx); NDArray::Empty({1}, DGLDataType{kDGLFloat, 64, 1}, device_ctx);
auto device = dgl::runtime::DeviceAPI::Get(device_ctx); auto device = dgl::runtime::DeviceAPI::Get(device_ctx);
// use cuda stream from local thread // use cuda stream from local thread
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
device->CopyDataFromTo( device->CopyDataFromTo(
&restart_prob, 0, restart_prob_array.Ptr<double>(), 0, &restart_prob, 0, restart_prob_array.Ptr<double>(), 0, sizeof(double),
sizeof(double), DGLContext{kDGLCPU, 0}, device_ctx, restart_prob_array->dtype);
DGLContext{kDGLCPU, 0}, device_ctx,
restart_prob_array->dtype);
device->StreamSync(device_ctx, stream); device->StreamSync(device_ctx, stream);
if (!isUniform) { if (!isUniform) {
...@@ -442,18 +410,15 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart( ...@@ -442,18 +410,15 @@ std::pair<IdArray, IdArray> RandomWalkWithRestart(
}); });
return ret; return ret;
} else { } else {
return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob_array); return RandomWalkUniform<XPU, IdType>(
hg, seeds, metapath, restart_prob_array);
} }
} }
template<DGLDeviceType XPU, typename IdType> template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart( std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, FloatArray restart_prob) {
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob) {
bool isUniform = true; bool isUniform = true;
for (const auto &etype_prob : prob) { for (const auto &etype_prob : prob) {
if (!IsNullArray(etype_prob)) { if (!IsNullArray(etype_prob)) {
...@@ -465,7 +430,8 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart( ...@@ -465,7 +430,8 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
if (!isUniform) { if (!isUniform) {
std::pair<IdArray, IdArray> ret; std::pair<IdArray, IdArray> ret;
ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", { ATEN_FLOAT_TYPE_SWITCH(prob[0]->dtype, FloatType, "probability", {
ret = RandomWalkBiased<XPU, FloatType, IdType>(hg, seeds, metapath, prob, restart_prob); ret = RandomWalkBiased<XPU, FloatType, IdType>(
hg, seeds, metapath, prob, restart_prob);
}); });
return ret; return ret;
} else { } else {
...@@ -473,84 +439,56 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart( ...@@ -473,84 +439,56 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
} }
} }
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray src, const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k) { const int64_t k) {
CHECK(src->ctx.device_type == kDGLCUDA) << "IdArray needs to be on GPU!";
const IdxType *src_data = src.Ptr<IdxType>();
const IdxType *dst_data = dst.Ptr<IdxType>();
const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node); const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
auto ctx = src->ctx; auto ctx = src->ctx;
// use cuda stream from local thread // use cuda stream from local thread
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
auto frequency_hashmap = FrequencyHashmap<IdxType>(num_dst_nodes, auto frequency_hashmap = FrequencyHashmap<IdxType>(
num_samples_per_node, ctx, stream); num_dst_nodes, num_samples_per_node, ctx, stream);
auto ret = frequency_hashmap.Topk(src_data, dst_data, src->dtype, auto ret = frequency_hashmap.Topk(
src->shape[0], num_samples_per_node, k); src_data, dst_data, src->dtype, src->shape[0], num_samples_per_node, k);
return ret; return ret;
} }
template template std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int32_t>(
std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int32_t>( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob); const std::vector<FloatArray> &prob);
template template std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int64_t>(
std::pair<IdArray, IdArray> RandomWalk<kDGLCUDA, int64_t>( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob); const std::vector<FloatArray> &prob);
template template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int32_t>(
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int32_t>( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const HeteroGraphPtr hg, const std::vector<FloatArray> &prob, double restart_prob);
const IdArray seeds, template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int64_t>(
const TypeArray metapath, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const std::vector<FloatArray> &prob, const std::vector<FloatArray> &prob, double restart_prob);
double restart_prob);
template template std::pair<IdArray, IdArray>
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCUDA, int64_t>( RandomWalkWithStepwiseRestart<kDGLCUDA, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, FloatArray restart_prob);
const TypeArray metapath, template std::pair<IdArray, IdArray>
const std::vector<FloatArray> &prob, RandomWalkWithStepwiseRestart<kDGLCUDA, int64_t>(
double restart_prob); const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const std::vector<FloatArray> &prob, FloatArray restart_prob);
template
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDGLCUDA, int32_t>( template std::tuple<IdArray, IdArray, IdArray>
const HeteroGraphPtr hg, SelectPinSageNeighbors<kDGLCUDA, int32_t>(
const IdArray seeds, const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob);
template
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDGLCUDA, int64_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob);
template
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDGLCUDA, int32_t>(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k); const int64_t k);
template template std::tuple<IdArray, IdArray, IdArray>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDGLCUDA, int64_t>( SelectPinSageNeighbors<kDGLCUDA, int64_t>(
const IdArray src, const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k); const int64_t k);
}; // namespace impl }; // namespace impl
}; // namespace sampling }; // namespace sampling
......
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file graph/sampling/randomwalk_with_restart_cpu.cc * \file graph/sampling/randomwalk_with_restart_cpu.cc
 * \brief DGL sampler - CPU implementation of metapath-based random walk
 *        with restart with OpenMP
 */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "randomwalks_impl.h"
#include "randomwalks_cpu.h"
#include "metapath_randomwalk.h" #include "metapath_randomwalk.h"
#include "randomwalks_cpu.h"
#include "randomwalks_impl.h"
namespace dgl { namespace dgl {
...@@ -22,70 +25,53 @@ namespace sampling { ...@@ -22,70 +25,53 @@ namespace sampling {
namespace impl { namespace impl {
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> RandomWalkWithRestart( std::pair<IdArray, IdArray> RandomWalkWithRestart(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, double restart_prob) {
const TypeArray metapath,
const std::vector<FloatArray> &prob,
double restart_prob) {
TerminatePredicate<IdxType> terminate = TerminatePredicate<IdxType> terminate =
[restart_prob] (IdxType *data, dgl_id_t curr, int64_t len) { [restart_prob](IdxType *data, dgl_id_t curr, int64_t len) {
return RandomEngine::ThreadLocal()->Uniform<double>() < restart_prob; return RandomEngine::ThreadLocal()->Uniform<double>() < restart_prob;
}; };
return MetapathBasedRandomWalk<XPU, IdxType>(hg, seeds, metapath, prob, terminate); return MetapathBasedRandomWalk<XPU, IdxType>(
hg, seeds, metapath, prob, terminate);
} }
template template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCPU, int32_t>(
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCPU, int32_t>( const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const HeteroGraphPtr hg, const std::vector<FloatArray> &prob, double restart_prob);
const IdArray seeds, template std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCPU, int64_t>(
const TypeArray metapath, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const std::vector<FloatArray> &prob, const std::vector<FloatArray> &prob, double restart_prob);
double restart_prob);
template template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDGLCPU, int64_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
double restart_prob);
template<DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart( std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, FloatArray restart_prob) {
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob) {
std::pair<IdArray, IdArray> result; std::pair<IdArray, IdArray> result;
ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, DType, "restart probability", { ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, DType, "restart probability", {
DType *restart_prob_data = static_cast<DType *>(restart_prob->data); DType *restart_prob_data = static_cast<DType *>(restart_prob->data);
TerminatePredicate<IdxType> terminate = TerminatePredicate<IdxType> terminate =
[restart_prob_data] (IdxType *data, dgl_id_t curr, int64_t len) { [restart_prob_data](IdxType *data, dgl_id_t curr, int64_t len) {
return RandomEngine::ThreadLocal()->Uniform<DType>() < restart_prob_data[len]; return RandomEngine::ThreadLocal()->Uniform<DType>() <
}; restart_prob_data[len];
result = MetapathBasedRandomWalk<XPU, IdxType>(hg, seeds, metapath, prob, terminate); };
result = MetapathBasedRandomWalk<XPU, IdxType>(
hg, seeds, metapath, prob, terminate);
}); });
return result; return result;
} }
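The stepwise variant lets the restart probability differ at every position of the walk: the terminate predicate above draws a uniform number and compares it against restart_prob_data[len] after len nodes have been generated. As a purely illustrative example of building such a schedule (the helper name and the linear shape are assumptions, and the array must cover every step of the walk), one might make early termination increasingly likely toward the end:

#include <cstdint>
#include <vector>

// Hypothetical helper: a restart schedule that grows linearly from `start`
// at step 0 to `end` at the last step.
inline std::vector<double> LinearRestartSchedule(
    int64_t num_steps, double start, double end) {
  std::vector<double> schedule(num_steps);
  for (int64_t i = 0; i < num_steps; ++i) {
    const double t =
        num_steps > 1 ? static_cast<double>(i) / (num_steps - 1) : 0.0;
    schedule[i] = start + t * (end - start);
  }
  return schedule;
}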
template template std::pair<IdArray, IdArray>
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDGLCPU, int32_t>( RandomWalkWithStepwiseRestart<kDGLCPU, int32_t>(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, FloatArray restart_prob);
const TypeArray metapath, template std::pair<IdArray, IdArray>
const std::vector<FloatArray> &prob, RandomWalkWithStepwiseRestart<kDGLCPU, int64_t>(
FloatArray restart_prob); const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
template const std::vector<FloatArray> &prob, FloatArray restart_prob);
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDGLCPU, int64_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob);
}; // namespace impl }; // namespace impl
......
...@@ -4,13 +4,15 @@ ...@@ -4,13 +4,15 @@
* \brief Dispatcher of different DGL random walks by device type * \brief Dispatcher of different DGL random walks by device type
*/ */
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/packed_func_ext.h>
#include <dgl/runtime/container.h>
#include <dgl/sampling/randomwalks.h> #include <dgl/sampling/randomwalks.h>
#include <utility>
#include <tuple> #include <tuple>
#include <utility>
#include <vector> #include <vector>
#include "../../../c_api_common.h" #include "../../../c_api_common.h"
#include "randomwalks_impl.h" #include "randomwalks_impl.h"
...@@ -24,9 +26,7 @@ namespace sampling { ...@@ -24,9 +26,7 @@ namespace sampling {
namespace { namespace {
void CheckRandomWalkInputs( void CheckRandomWalkInputs(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob) { const std::vector<FloatArray> &prob) {
CHECK_INT(seeds, "seeds"); CHECK_INT(seeds, "seeds");
CHECK_INT(metapath, "metapath"); CHECK_INT(metapath, "metapath");
...@@ -36,20 +36,24 @@ void CheckRandomWalkInputs( ...@@ -36,20 +36,24 @@ void CheckRandomWalkInputs(
// CHECK_SAME_CONTEXT(seeds, metapath); // CHECK_SAME_CONTEXT(seeds, metapath);
if (hg->IsPinned()) { if (hg->IsPinned()) {
CHECK_EQ(seeds->ctx.device_type, kDGLCUDA) << "Expected seeds (" << seeds->ctx << ")" \ CHECK_EQ(seeds->ctx.device_type, kDGLCUDA)
<< " to be on the GPU when the graph is pinned."; << "Expected seeds (" << seeds->ctx << ")"
<< " to be on the GPU when the graph is pinned.";
} else if (hg->Context() != seeds->ctx) { } else if (hg->Context() != seeds->ctx) {
LOG(FATAL) << "Expected seeds (" << seeds->ctx << ")" << " to have the same " \ LOG(FATAL) << "Expected seeds (" << seeds->ctx << ")"
<< "context as graph (" << hg->Context() << ")."; << " to have the same "
<< "context as graph (" << hg->Context() << ").";
} }
for (uint64_t i = 0; i < prob.size(); ++i) { for (uint64_t i = 0; i < prob.size(); ++i) {
FloatArray p = prob[i]; FloatArray p = prob[i];
CHECK_EQ(hg->Context(), p->ctx) << "Expected prob (" << p->ctx << ")" << " to have the same " \ CHECK_EQ(hg->Context(), p->ctx)
<< "context as graph (" << hg->Context() << ")."; << "Expected prob (" << p->ctx << ")"
<< " to have the same "
<< "context as graph (" << hg->Context() << ").";
CHECK_FLOAT(p, "probability"); CHECK_FLOAT(p, "probability");
if (p.GetSize() != 0) { if (p.GetSize() != 0) {
CHECK_EQ(hg->IsPinned(), p.IsPinned()) CHECK_EQ(hg->IsPinned(), p.IsPinned())
<< "The prob array should have the same pinning status as the graph"; << "The prob array should have the same pinning status as the graph";
CHECK_NDIM(p, 1, "probability"); CHECK_NDIM(p, 1, "probability");
} }
} }
...@@ -58,9 +62,7 @@ void CheckRandomWalkInputs( ...@@ -58,9 +62,7 @@ void CheckRandomWalkInputs(
}; // namespace }; // namespace
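CheckRandomWalkInputs enforces the device contract before any kernel runs: seeds must sit on the GPU when the graph is pinned and otherwise share the graph's context, and every non-empty transition-probability array must match the graph's context and pinning status and be one-dimensional. A rough standalone sketch of that validation pattern; the `Device` and `Array` structs and the `CheckWalkInputs` helper below are invented stand-ins, not DGL types:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Invented stand-ins for a device context and a pinnable array.
struct Device {
  int type;
  int id;
  bool operator!=(const Device &o) const { return type != o.type || id != o.id; }
};
struct Array {
  Device ctx;
  bool pinned;
  int ndim;
  int64_t size;
};

constexpr int kCPU = 1, kGPU = 2;

void CheckWalkInputs(const Device &graph_ctx, bool graph_pinned,
                     const Array &seeds, const std::vector<Array> &prob) {
  if (graph_pinned) {
    if (seeds.ctx.type != kGPU)
      throw std::invalid_argument("seeds must be on the GPU when the graph is pinned");
  } else if (graph_ctx != seeds.ctx) {
    throw std::invalid_argument("seeds must share the graph's context");
  }
  for (const Array &p : prob) {
    if (graph_ctx != p.ctx)
      throw std::invalid_argument("prob must share the graph's context");
    if (p.size == 0) continue;  // an empty array means uniform transition
    if (p.pinned != graph_pinned)
      throw std::invalid_argument("prob must match the graph's pinning status");
    if (p.ndim != 1)
      throw std::invalid_argument("prob must be one-dimensional");
  }
}

int main() {
  Device cpu{kCPU, 0};
  Array seeds{cpu, /*pinned=*/false, /*ndim=*/1, /*size=*/8};
  Array uniform_prob{cpu, false, 0, 0};  // empty probability array: uniform transition
  CheckWalkInputs(cpu, /*graph_pinned=*/false, seeds, {uniform_prob});
  std::cout << "inputs ok\n";
  return 0;
}
```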
std::tuple<IdArray, IdArray, TypeArray> RandomWalk( std::tuple<IdArray, IdArray, TypeArray> RandomWalk(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob) { const std::vector<FloatArray> &prob) {
CheckRandomWalkInputs(hg, seeds, metapath, prob); CheckRandomWalkInputs(hg, seeds, metapath, prob);
...@@ -77,20 +79,19 @@ std::tuple<IdArray, IdArray, TypeArray> RandomWalk( ...@@ -77,20 +79,19 @@ std::tuple<IdArray, IdArray, TypeArray> RandomWalk(
} }
std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart( std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, double restart_prob) {
const TypeArray metapath,
const std::vector<FloatArray> &prob,
double restart_prob) {
CheckRandomWalkInputs(hg, seeds, metapath, prob); CheckRandomWalkInputs(hg, seeds, metapath, prob);
CHECK(restart_prob >= 0 && restart_prob < 1) << "restart probability must belong to [0, 1)"; CHECK(restart_prob >= 0 && restart_prob < 1)
<< "restart probability must belong to [0, 1)";
TypeArray vtypes; TypeArray vtypes;
std::pair<IdArray, IdArray> result; std::pair<IdArray, IdArray> result;
ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, "RandomWalkWithRestart", { ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, "RandomWalkWithRestart", {
ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, { ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath); vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
result = impl::RandomWalkWithRestart<XPU, IdxType>(hg, seeds, metapath, prob, restart_prob); result = impl::RandomWalkWithRestart<XPU, IdxType>(
hg, seeds, metapath, prob, restart_prob);
}); });
}); });
...@@ -98,40 +99,38 @@ std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart( ...@@ -98,40 +99,38 @@ std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(
} }
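All of the walk entry points dispatch the same way: `ATEN_XPU_SWITCH_CUDA` selects the device-specialized template from the seeds' context, and `ATEN_ID_TYPE_SWITCH` instantiates it for 32- or 64-bit node IDs. A minimal sketch of that two-level dispatch with plain switches instead of the DGL macros; the enums and the stub implementation are illustrative only:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class DeviceType { kCPU, kCUDA };
enum class IdType { kInt32, kInt64 };

// Device- and dtype-specialized implementation (a stub for illustration).
template <DeviceType XPU, typename IdxType>
int64_t RandomWalkImpl(int64_t num_seeds) {
  std::cout << "walk on " << (XPU == DeviceType::kCPU ? "CPU" : "CUDA")
            << " with " << sizeof(IdxType) * 8 << "-bit IDs\n";
  return num_seeds;
}

// Dispatcher: route to the right instantiation from runtime tags, mirroring
// the ATEN_XPU_SWITCH_CUDA + ATEN_ID_TYPE_SWITCH nesting.
int64_t RandomWalkDispatch(DeviceType dev, IdType dtype, int64_t num_seeds) {
  switch (dev) {
    case DeviceType::kCPU:
      return dtype == IdType::kInt32
                 ? RandomWalkImpl<DeviceType::kCPU, int32_t>(num_seeds)
                 : RandomWalkImpl<DeviceType::kCPU, int64_t>(num_seeds);
    case DeviceType::kCUDA:
      return dtype == IdType::kInt32
                 ? RandomWalkImpl<DeviceType::kCUDA, int32_t>(num_seeds)
                 : RandomWalkImpl<DeviceType::kCUDA, int64_t>(num_seeds);
  }
  throw std::invalid_argument("unsupported device");
}

int main() {
  RandomWalkDispatch(DeviceType::kCPU, IdType::kInt64, 4);
  return 0;
}
```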
std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart( std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, FloatArray restart_prob) {
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob) {
CheckRandomWalkInputs(hg, seeds, metapath, prob); CheckRandomWalkInputs(hg, seeds, metapath, prob);
// TODO(BarclayII): check the elements of restart probability // TODO(BarclayII): check the elements of restart probability
TypeArray vtypes; TypeArray vtypes;
std::pair<IdArray, IdArray> result; std::pair<IdArray, IdArray> result;
ATEN_XPU_SWITCH_CUDA(seeds->ctx.device_type, XPU, "RandomWalkWithStepwiseRestart", { ATEN_XPU_SWITCH_CUDA(
ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, { seeds->ctx.device_type, XPU, "RandomWalkWithStepwiseRestart", {
vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath); ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
result = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>( vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
hg, seeds, metapath, prob, restart_prob); result = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>(
}); hg, seeds, metapath, prob, restart_prob);
}); });
});
return std::make_tuple(result.first, result.second, vtypes); return std::make_tuple(result.first, result.second, vtypes);
} }
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray src, const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k) { const int64_t k) {
assert((src->ndim == 1) && (dst->ndim == 1) assert(
&& (src->shape[0] % num_samples_per_node == 0) (src->ndim == 1) && (dst->ndim == 1) &&
&& (src->shape[0] == dst->shape[0])); (src->shape[0] % num_samples_per_node == 0) &&
(src->shape[0] == dst->shape[0]));
std::tuple<IdArray, IdArray, IdArray> result; std::tuple<IdArray, IdArray, IdArray> result;
ATEN_XPU_SWITCH_CUDA((src->ctx).device_type, XPU, "SelectPinSageNeighbors", { ATEN_XPU_SWITCH_CUDA((src->ctx).device_type, XPU, "SelectPinSageNeighbors", {
ATEN_ID_TYPE_SWITCH(src->dtype, IdxType, { ATEN_ID_TYPE_SWITCH(src->dtype, IdxType, {
result = impl::SelectPinSageNeighbors<XPU, IdxType>(src, dst, num_samples_per_node, k); result = impl::SelectPinSageNeighbors<XPU, IdxType>(
src, dst, num_samples_per_node, k);
}); });
}); });
...@@ -141,91 +140,94 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( ...@@ -141,91 +140,94 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
}; // namespace sampling }; // namespace sampling
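SelectPinSageNeighbors reduces the flattened (src, dst) pairs produced by repeated walks, `num_samples_per_node` per seed, down to `k` neighbors per seed. The device kernels are not part of this diff; the sketch below shows one plausible PinSAGE-style reduction, keeping the `k` most frequently visited destinations per seed together with their visit counts, and is an assumption about the semantics rather than a copy of the DGL implementation:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // Flattened (src, dst) pairs from repeated walks: 4 walks per seed, 2 seeds.
  // -1 marks a walk that terminated before reaching a destination.
  std::vector<int64_t> src = {0, 0, 0, 0, 1, 1, 1, 1};
  std::vector<int64_t> dst = {5, 5, 7, -1, 9, 9, 9, 2};
  const size_t num_samples_per_node = 4, k = 2;

  for (size_t begin = 0; begin < dst.size(); begin += num_samples_per_node) {
    const int64_t seed = src[begin];  // every pair in this block shares one seed
    std::unordered_map<int64_t, int64_t> counts;
    for (size_t i = begin; i < begin + num_samples_per_node; ++i)
      if (dst[i] >= 0) ++counts[dst[i]];

    // Rank destinations by visit count and keep the top k, counts as weights.
    std::vector<std::pair<int64_t, int64_t>> ranked(counts.begin(), counts.end());
    std::sort(ranked.begin(), ranked.end(),
              [](const auto &a, const auto &b) { return a.second > b.second; });
    if (ranked.size() > k) ranked.resize(k);

    std::cout << "seed " << seed << ":";
    for (const auto &[node, cnt] : ranked) std::cout << ' ' << node << "(x" << cnt << ')';
    std::cout << '\n';
  }
  return 0;
}
```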
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk") DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0]; HeteroGraphRef hg = args[0];
IdArray seeds = args[1]; IdArray seeds = args[1];
TypeArray metapath = args[2]; TypeArray metapath = args[2];
List<Value> prob = args[3]; List<Value> prob = args[3];
const auto& prob_vec = ListValueToVector<FloatArray>(prob); const auto &prob_vec = ListValueToVector<FloatArray>(prob);
auto result = sampling::RandomWalk(hg.sptr(), seeds, metapath, prob_vec); auto result = sampling::RandomWalk(hg.sptr(), seeds, metapath, prob_vec);
List<Value> ret; List<Value> ret;
ret.push_back(Value(MakeValue(std::get<0>(result)))); ret.push_back(Value(MakeValue(std::get<0>(result))));
ret.push_back(Value(MakeValue(std::get<1>(result)))); ret.push_back(Value(MakeValue(std::get<1>(result))));
ret.push_back(Value(MakeValue(std::get<2>(result)))); ret.push_back(Value(MakeValue(std::get<2>(result))));
*rv = ret; *rv = ret;
}); });
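Every `_CAPI_` entry point follows the same packed-function convention: arguments arrive positionally in `DGLArgs`, are converted by index, and results are pushed into a `List<Value>` that becomes the return value. A tiny self-contained sketch of that registry-plus-positional-arguments pattern, using `std::any` and a `std::map` in place of the DGL runtime (names such as `PackedFunc` and `Register` here are local stand-ins):

```cpp
#include <any>
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-ins for DGLArgs / DGLRetValue: a positional argument pack and one result slot.
using Args = std::vector<std::any>;
using PackedFunc = std::function<void(const Args &, std::any *rv)>;

std::map<std::string, PackedFunc> &Registry() {
  static std::map<std::string, PackedFunc> reg;
  return reg;
}

struct Register {  // mimics DGL_REGISTER_GLOBAL(...).set_body(...)
  Register(const std::string &name, PackedFunc f) { Registry()[name] = std::move(f); }
};

// "C API" body: unpack arguments by position, compute, pack results into a list.
static Register reg_toy("sampling.toy._CAPI_SumAndCount",
    [](const Args &args, std::any *rv) {
      const auto &xs = std::any_cast<const std::vector<double> &>(args[0]);
      double scale = std::any_cast<double>(args[1]);
      double sum = 0;
      for (double x : xs) sum += x * scale;
      std::vector<std::any> ret;  // plays the role of List<Value>
      ret.push_back(sum);
      ret.push_back(static_cast<int64_t>(xs.size()));
      *rv = ret;
    });

int main() {
  std::any rv;
  Registry()["sampling.toy._CAPI_SumAndCount"]({std::vector<double>{1, 2, 3}, 2.0}, &rv);
  auto ret = std::any_cast<std::vector<std::any>>(rv);
  std::cout << std::any_cast<double>(ret[0]) << " over "
            << std::any_cast<int64_t>(ret[1]) << " values\n";
  return 0;
}
```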
DGL_REGISTER_GLOBAL("sampling.pinsage._CAPI_DGLSamplingSelectPinSageNeighbors") DGL_REGISTER_GLOBAL("sampling.pinsage._CAPI_DGLSamplingSelectPinSageNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
IdArray src = args[0]; IdArray src = args[0];
IdArray dst = args[1]; IdArray dst = args[1];
int64_t num_travelsals = static_cast<int64_t>(args[2]); int64_t num_travelsals = static_cast<int64_t>(args[2]);
int64_t k = static_cast<int64_t>(args[3]); int64_t k = static_cast<int64_t>(args[3]);
auto result = sampling::SelectPinSageNeighbors(src, dst, num_travelsals, k); auto result =
sampling::SelectPinSageNeighbors(src, dst, num_travelsals, k);
List<Value> ret;
ret.push_back(Value(MakeValue(std::get<0>(result)))); List<Value> ret;
ret.push_back(Value(MakeValue(std::get<1>(result)))); ret.push_back(Value(MakeValue(std::get<0>(result))));
ret.push_back(Value(MakeValue(std::get<2>(result)))); ret.push_back(Value(MakeValue(std::get<1>(result))));
*rv = ret; ret.push_back(Value(MakeValue(std::get<2>(result))));
}); *rv = ret;
});
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart") DGL_REGISTER_GLOBAL(
.set_body([] (DGLArgs args, DGLRetValue *rv) { "sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart")
HeteroGraphRef hg = args[0]; .set_body([](DGLArgs args, DGLRetValue *rv) {
IdArray seeds = args[1]; HeteroGraphRef hg = args[0];
TypeArray metapath = args[2]; IdArray seeds = args[1];
List<Value> prob = args[3]; TypeArray metapath = args[2];
double restart_prob = args[4]; List<Value> prob = args[3];
double restart_prob = args[4];
const auto& prob_vec = ListValueToVector<FloatArray>(prob);
const auto &prob_vec = ListValueToVector<FloatArray>(prob);
auto result = sampling::RandomWalkWithRestart(
hg.sptr(), seeds, metapath, prob_vec, restart_prob); auto result = sampling::RandomWalkWithRestart(
List<Value> ret; hg.sptr(), seeds, metapath, prob_vec, restart_prob);
ret.push_back(Value(MakeValue(std::get<0>(result)))); List<Value> ret;
ret.push_back(Value(MakeValue(std::get<1>(result)))); ret.push_back(Value(MakeValue(std::get<0>(result))));
ret.push_back(Value(MakeValue(std::get<2>(result)))); ret.push_back(Value(MakeValue(std::get<1>(result))));
*rv = ret; ret.push_back(Value(MakeValue(std::get<2>(result))));
}); *rv = ret;
});
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwiseRestart") DGL_REGISTER_GLOBAL(
.set_body([] (DGLArgs args, DGLRetValue *rv) { "sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithStepwiseRestart")
HeteroGraphRef hg = args[0]; .set_body([](DGLArgs args, DGLRetValue *rv) {
IdArray seeds = args[1]; HeteroGraphRef hg = args[0];
TypeArray metapath = args[2]; IdArray seeds = args[1];
List<Value> prob = args[3]; TypeArray metapath = args[2];
FloatArray restart_prob = args[4]; List<Value> prob = args[3];
FloatArray restart_prob = args[4];
const auto& prob_vec = ListValueToVector<FloatArray>(prob);
const auto &prob_vec = ListValueToVector<FloatArray>(prob);
auto result = sampling::RandomWalkWithStepwiseRestart(
hg.sptr(), seeds, metapath, prob_vec, restart_prob); auto result = sampling::RandomWalkWithStepwiseRestart(
List<Value> ret; hg.sptr(), seeds, metapath, prob_vec, restart_prob);
ret.push_back(Value(MakeValue(std::get<0>(result)))); List<Value> ret;
ret.push_back(Value(MakeValue(std::get<1>(result)))); ret.push_back(Value(MakeValue(std::get<0>(result))));
ret.push_back(Value(MakeValue(std::get<2>(result)))); ret.push_back(Value(MakeValue(std::get<1>(result))));
*rv = ret; ret.push_back(Value(MakeValue(std::get<2>(result))));
}); *rv = ret;
});
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingPackTraces") DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingPackTraces")
.set_body([] (DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
IdArray vids = args[0]; IdArray vids = args[0];
TypeArray vtypes = args[1]; TypeArray vtypes = args[1];
IdArray concat_vids, concat_vtypes, lengths, offsets; IdArray concat_vids, concat_vtypes, lengths, offsets;
std::tie(concat_vids, lengths, offsets) = Pack(vids, -1); std::tie(concat_vids, lengths, offsets) = Pack(vids, -1);
std::tie(concat_vtypes, std::ignore) = ConcatSlices(vtypes, lengths); std::tie(concat_vtypes, std::ignore) = ConcatSlices(vtypes, lengths);
List<Value> ret; List<Value> ret;
ret.push_back(Value(MakeValue(concat_vids))); ret.push_back(Value(MakeValue(concat_vids)));
ret.push_back(Value(MakeValue(concat_vtypes))); ret.push_back(Value(MakeValue(concat_vtypes)));
ret.push_back(Value(MakeValue(lengths))); ret.push_back(Value(MakeValue(lengths)));
ret.push_back(Value(MakeValue(offsets))); ret.push_back(Value(MakeValue(offsets)));
*rv = ret; *rv = ret;
}); });
}; // namespace dgl }; // namespace dgl
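`_CAPI_DGLSamplingPackTraces` strips the `-1` padding from the fixed-width trace matrix: `Pack` yields the concatenated valid node IDs plus per-trace lengths and offsets, and `ConcatSlices` applies those lengths to the node-type rows. A small standalone sketch of the packing step on plain vectors; `PackPadded` is a toy helper, not the DGL `Pack` signature:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Pack a (num_traces x width) row-major matrix, dropping `pad` entries.
// Fills concatenated values, per-row lengths, and per-row start offsets.
void PackPadded(const std::vector<int64_t> &mat, int64_t width, int64_t pad,
                std::vector<int64_t> *flat, std::vector<int64_t> *lengths,
                std::vector<int64_t> *offsets) {
  const int64_t rows = static_cast<int64_t>(mat.size()) / width;
  for (int64_t r = 0; r < rows; ++r) {
    offsets->push_back(static_cast<int64_t>(flat->size()));
    int64_t len = 0;
    for (int64_t c = 0; c < width; ++c) {
      const int64_t v = mat[r * width + c];
      if (v == pad) break;  // padding only appears at the tail of a trace
      flat->push_back(v);
      ++len;
    }
    lengths->push_back(len);
  }
}

int main() {
  // Two traces of width 4; the second terminated early and is padded with -1.
  std::vector<int64_t> traces = {0, 3, 5, 7,
                                 1, 4, -1, -1};
  std::vector<int64_t> flat, lengths, offsets;
  PackPadded(traces, 4, -1, &flat, &lengths, &offsets);

  for (int64_t v : flat) std::cout << v << ' ';
  std::cout << "| lengths: ";
  for (int64_t v : lengths) std::cout << v << ' ';
  std::cout << "| offsets: ";
  for (int64_t v : offsets) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}
```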
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2018 by Contributors
* \file graph/sampler/generic_randomwalk_cpu.h * \file graph/sampler/generic_randomwalk_cpu.h
* \brief DGL sampler - templated implementation definition of random walks on CPU * \brief DGL sampler - templated implementation definition of random walks on
* CPU
*/ */
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_ #ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_ #define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_CPU_H_
#include <dgl/base_heterograph.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/parallel_for.h> #include <dgl/runtime/parallel_for.h>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include "randomwalks_impl.h" #include "randomwalks_impl.h"
namespace dgl { namespace dgl {
...@@ -27,24 +30,24 @@ namespace { ...@@ -27,24 +30,24 @@ namespace {
/*! /*!
* \brief Generic Random Walk. * \brief Generic Random Walk.
* \param seeds A 1D array of seed nodes, with the type the source type of the first * \param seeds A 1D array of seed nodes, with the type the source type of the
* edge type in the metapath. * first edge type in the metapath. \param max_num_steps The maximum number of
* \param max_num_steps The maximum number of steps of a random walk path. * steps of a random walk path. \param step The random walk step function with
* \param step The random walk step function with type \c StepFunc. * type \c StepFunc. \param max_nodes Throws an error if one of the values in \c
* \param max_nodes Throws an error if one of the values in \c seeds exceeds this argument. * seeds exceeds this argument. \return A 2D array of shape (len(seeds),
* \return A 2D array of shape (len(seeds), max_num_steps + 1) with node IDs. * max_num_steps + 1) with node IDs. \note The graph itself should be bounded in
* \note The graph itself should be bounded in the closure of \c step. * the closure of \c step.
*/ */
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> GenericRandomWalk( std::pair<IdArray, IdArray> GenericRandomWalk(
const IdArray seeds, const IdArray seeds, int64_t max_num_steps, StepFunc<IdxType> step,
int64_t max_num_steps,
StepFunc<IdxType> step,
int64_t max_nodes) { int64_t max_nodes) {
int64_t num_seeds = seeds->shape[0]; int64_t num_seeds = seeds->shape[0];
int64_t trace_length = max_num_steps + 1; int64_t trace_length = max_num_steps + 1;
IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, seeds->ctx); IdArray traces =
IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, seeds->ctx); IdArray::Empty({num_seeds, trace_length}, seeds->dtype, seeds->ctx);
IdArray eids =
IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, seeds->ctx);
const IdxType *seed_data = seeds.Ptr<IdxType>(); const IdxType *seed_data = seeds.Ptr<IdxType>();
IdxType *traces_data = traces.Ptr<IdxType>(); IdxType *traces_data = traces.Ptr<IdxType>();
...@@ -56,14 +59,14 @@ std::pair<IdArray, IdArray> GenericRandomWalk( ...@@ -56,14 +59,14 @@ std::pair<IdArray, IdArray> GenericRandomWalk(
dgl_id_t curr = seed_data[seed_id]; dgl_id_t curr = seed_data[seed_id];
traces_data[seed_id * trace_length] = curr; traces_data[seed_id * trace_length] = curr;
CHECK_LT(curr, max_nodes) << "Seed node ID exceeds the maximum number of nodes."; CHECK_LT(curr, max_nodes)
<< "Seed node ID exceeds the maximum number of nodes.";
for (i = 0; i < max_num_steps; ++i) { for (i = 0; i < max_num_steps; ++i) {
const auto &succ = step(traces_data + seed_id * trace_length, curr, i); const auto &succ = step(traces_data + seed_id * trace_length, curr, i);
traces_data[seed_id * trace_length + i + 1] = curr = std::get<0>(succ); traces_data[seed_id * trace_length + i + 1] = curr = std::get<0>(succ);
eids_data[seed_id * max_num_steps + i] = std::get<1>(succ); eids_data[seed_id * max_num_steps + i] = std::get<1>(succ);
if (std::get<2>(succ)) if (std::get<2>(succ)) break;
break;
} }
for (; i < max_num_steps; ++i) { for (; i < max_num_steps; ++i) {
......
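GenericRandomWalk is the engine behind all of these variants: it allocates a `(num_seeds, max_num_steps + 1)` trace matrix and a `(num_seeds, max_num_steps)` edge-ID matrix, then repeatedly calls a `StepFunc` that returns the next node, the traversed edge ID, and a termination flag, leaving `-1` padding after early stops. A compact standalone version of that loop over a toy adjacency list, with `std::mt19937` instead of DGL's `RandomEngine` and plain vectors instead of `IdArray`:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <random>
#include <tuple>
#include <vector>

// StepFunc: (trace so far, current node, step index) -> (next node, edge id, terminate?)
using StepFunc =
    std::function<std::tuple<int64_t, int64_t, bool>(const int64_t *, int64_t, int64_t)>;

int main() {
  // Toy graph: out-neighbors and the corresponding edge IDs per node.
  std::vector<std::vector<int64_t>> nbrs = {{1, 2}, {2, 3}, {3}, {0}};
  std::vector<std::vector<int64_t>> eids = {{0, 1}, {2, 3}, {4}, {5}};

  std::mt19937 rng(42);
  StepFunc step = [&](const int64_t *, int64_t curr, int64_t) {
    if (nbrs[curr].empty())  // dead end: report termination
      return std::make_tuple(int64_t{-1}, int64_t{-1}, true);
    std::uniform_int_distribution<size_t> pick(0, nbrs[curr].size() - 1);
    size_t j = pick(rng);
    return std::make_tuple(nbrs[curr][j], eids[curr][j], false);
  };

  const std::vector<int64_t> seeds = {0, 3};
  const int64_t max_num_steps = 4, trace_length = max_num_steps + 1;
  std::vector<int64_t> traces(seeds.size() * trace_length, -1);
  std::vector<int64_t> edge_ids(seeds.size() * max_num_steps, -1);

  for (size_t s = 0; s < seeds.size(); ++s) {
    int64_t curr = seeds[s];
    traces[s * trace_length] = curr;
    for (int64_t i = 0; i < max_num_steps; ++i) {
      auto [next, eid, stop] = step(&traces[s * trace_length], curr, i);
      traces[s * trace_length + i + 1] = curr = next;
      edge_ids[s * max_num_steps + i] = eid;
      if (stop) break;  // remaining slots keep the -1 padding
    }
  }

  for (size_t s = 0; s < seeds.size(); ++s) {
    for (int64_t i = 0; i < trace_length; ++i)
      std::cout << traces[s * trace_length + i] << ' ';
    std::cout << '\n';
  }
  return 0;
}
```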
...@@ -7,12 +7,13 @@ ...@@ -7,12 +7,13 @@
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_ #ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_ #define DGL_GRAPH_SAMPLING_RANDOMWALKS_RANDOMWALKS_IMPL_H_
#include <dgl/base_heterograph.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <vector> #include <dgl/base_heterograph.h>
#include <utility>
#include <tuple>
#include <functional> #include <functional>
#include <tuple>
#include <utility>
#include <vector>
namespace dgl { namespace dgl {
...@@ -26,100 +27,87 @@ namespace impl { ...@@ -26,100 +27,87 @@ namespace impl {
/*! /*!
* \brief Random walk step function * \brief Random walk step function
*/ */
template<typename IdxType> template <typename IdxType>
using StepFunc = std::function< using StepFunc = std::function<
// ID Edge ID terminate? // ID Edge ID terminate?
std::tuple<dgl_id_t, dgl_id_t, bool>( std::tuple<dgl_id_t, dgl_id_t, bool>(
IdxType *, // node IDs generated so far IdxType *, // node IDs generated so far
dgl_id_t, // last node ID dgl_id_t, // last node ID
int64_t)>; // # of steps int64_t)>; // # of steps
/*! /*!
* \brief Get the node types traversed by the metapath. * \brief Get the node types traversed by the metapath.
* \return A 1D array of shape (len(metapath) + 1,) with node type IDs. * \return A 1D array of shape (len(metapath) + 1,) with node type IDs.
*/ */
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath( TypeArray GetNodeTypesFromMetapath(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const TypeArray metapath);
const TypeArray metapath);
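The comment above only fixes the output shape of GetNodeTypesFromMetapath: len(metapath) + 1 node type IDs. One natural reading, sketched here, is the source type of the first hop followed by the destination type of every hop; the user/item schema and the type tables below are hypothetical:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical schema: node types 0 = user, 1 = item;
  // edge types 0 = "buys" (user -> item), 1 = "bought-by" (item -> user).
  std::vector<int64_t> src_type_of_etype = {0, 1};
  std::vector<int64_t> dst_type_of_etype = {1, 0};

  // Metapath "buys, bought-by, buys" visits user -> item -> user -> item.
  std::vector<int64_t> metapath = {0, 1, 0};

  std::vector<int64_t> vtypes;
  vtypes.push_back(src_type_of_etype[metapath.front()]);
  for (int64_t et : metapath) vtypes.push_back(dst_type_of_etype[et]);

  for (int64_t t : vtypes) std::cout << (t == 0 ? "user " : "item ");
  std::cout << '\n';  // prints: user item user item
  return 0;
}
```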
/*! /*!
* \brief Metapath-based random walk. * \brief Metapath-based random walk.
* \param hg The heterograph. * \param hg The heterograph.
* \param seeds A 1D array of seed nodes, with the type the source type of the first * \param seeds A 1D array of seed nodes, with the type the source type of the
* edge type in the metapath. * first edge type in the metapath. \param metapath A 1D array of edge types
* \param metapath A 1D array of edge types representing the metapath. * representing the metapath. \param prob A vector of 1D float arrays,
* \param prob A vector of 1D float arrays, indicating the transition probability of * indicating the transition probability of each edge by edge type. An empty
* each edge by edge type. An empty float array assumes uniform transition. * float array assumes uniform transition. \return A 2D array of shape
* \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The * (len(seeds), len(metapath) + 1) with node IDs. The paths that terminated
* paths that terminated early are padded with -1. * early are padded with -1. A 2D array of shape (len(seeds), len(metapath))
* A 2D array of shape (len(seeds), len(metapath)) with edge IDs. The * with edge IDs. The paths that terminated early are padded with -1. \note
* paths that terminated early are padded with -1. * This function should be called together with GetNodeTypesFromMetapath to
* \note This function should be called together with GetNodeTypesFromMetapath to
* determine the node type of each node in the random walk traces. * determine the node type of each node in the random walk traces.
*/ */
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> RandomWalk( std::pair<IdArray, IdArray> RandomWalk(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob); const std::vector<FloatArray> &prob);
/*! /*!
* \brief Metapath-based random walk with restart probability. * \brief Metapath-based random walk with restart probability.
* \param hg The heterograph. * \param hg The heterograph.
* \param seeds A 1D array of seed nodes, with the type the source type of the first * \param seeds A 1D array of seed nodes, with the type the source type of the
* edge type in the metapath. * first edge type in the metapath. \param metapath A 1D array of edge types
* \param metapath A 1D array of edge types representing the metapath. * representing the metapath. \param prob A vector of 1D float arrays,
* \param prob A vector of 1D float arrays, indicating the transition probability of * indicating the transition probability of each edge by edge type. An empty
* each edge by edge type. An empty float array assumes uniform transition. * float array assumes uniform transition. \param restart_prob Restart
* \param restart_prob Restart probability * probability \return A 2D array of shape (len(seeds), len(metapath) + 1) with
* \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The * node IDs. The paths that terminated early are padded with -1. A 2D array of
* paths that terminated early are padded with -1. * shape (len(seeds), len(metapath)) with edge IDs. The paths that terminated
* A 2D array of shape (len(seeds), len(metapath)) with edge IDs. The * early are padded with -1. \note This function should be called together with
* paths that terminated early are padded with -1. * GetNodeTypesFromMetapath to determine the node type of each node in the
* \note This function should be called together with GetNodeTypesFromMetapath to * random walk traces.
* determine the node type of each node in the random walk traces.
*/ */
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> RandomWalkWithRestart( std::pair<IdArray, IdArray> RandomWalkWithRestart(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, double restart_prob);
const TypeArray metapath,
const std::vector<FloatArray> &prob,
double restart_prob);
/*! /*!
* \brief Metapath-based random walk with stepwise restart probability. Useful * \brief Metapath-based random walk with stepwise restart probability. Useful
* for PinSAGE-like models. * for PinSAGE-like models.
* \param hg The heterograph. * \param hg The heterograph.
* \param seeds A 1D array of seed nodes, with the type the source type of the first * \param seeds A 1D array of seed nodes, with the type the source type of the
* edge type in the metapath. * first edge type in the metapath. \param metapath A 1D array of edge types
* \param metapath A 1D array of edge types representing the metapath. * representing the metapath. \param prob A vector of 1D float arrays,
* \param prob A vector of 1D float arrays, indicating the transition probability of * indicating the transition probability of each edge by edge type. An empty
* each edge by edge type. An empty float array assumes uniform transition. * float array assumes uniform transition. \param restart_prob Restart
* \param restart_prob Restart probability array which has the same number of elements * probability array which has the same number of elements as \c metapath,
* as \c metapath, indicating the probability to terminate after transition. * indicating the probability to terminate after transition. \return A 2D array
* \return A 2D array of shape (len(seeds), len(metapath) + 1) with node IDs. The * of shape (len(seeds), len(metapath) + 1) with node IDs. The paths that
* paths that terminated early are padded with -1. * terminated early are padded with -1. A 2D array of shape (len(seeds),
* A 2D array of shape (len(seeds), len(metapath)) with edge IDs. The * len(metapath)) with edge IDs. The paths that terminated early are padded
* paths that terminated early are padded with -1. * with -1. \note This function should be called together with
* \note This function should be called together with GetNodeTypesFromMetapath to * GetNodeTypesFromMetapath to determine the node type of each node in the
* determine the node type of each node in the random walk traces. * random walk traces.
*/ */
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart( std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
const HeteroGraphPtr hg, const HeteroGraphPtr hg, const IdArray seeds, const TypeArray metapath,
const IdArray seeds, const std::vector<FloatArray> &prob, FloatArray restart_prob);
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob);
template<DGLDeviceType XPU, typename IdxType> template <DGLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors( std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray src, const IdArray src, const IdArray dst, const int64_t num_samples_per_node,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k); const int64_t k);
}; // namespace impl }; // namespace impl
......
...@@ -6,25 +6,27 @@ ...@@ -6,25 +6,27 @@
#ifndef DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_ #ifndef DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_
#define DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_ #define DGL_GRAPH_SERIALIZE_DGLGRAPH_DATA_H_
#include <dgl/graph.h>
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/graph.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dmlc/io.h> #include <dgl/packed_func_ext.h>
#include <dmlc/type_traits.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <dgl/packed_func_ext.h> #include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <algorithm>
#include <iostream> #include <iostream>
#include <memory>
#include <string> #include <string>
#include <vector>
#include <algorithm>
#include <utility> #include <utility>
#include <memory> #include <vector>
#include "../../c_api_common.h" #include "../../c_api_common.h"
using dgl::runtime::NDArray;
using dgl::ImmutableGraph; using dgl::ImmutableGraph;
using dgl::runtime::NDArray;
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
...@@ -39,9 +41,9 @@ class GraphDataObject : public runtime::Object { ...@@ -39,9 +41,9 @@ class GraphDataObject : public runtime::Object {
std::vector<NamedTensor> edge_tensors; std::vector<NamedTensor> edge_tensors;
static constexpr const char *_type_key = "graph_serialize.GraphData"; static constexpr const char *_type_key = "graph_serialize.GraphData";
void SetData(ImmutableGraphPtr gptr, void SetData(
Map<std::string, Value> node_tensors, ImmutableGraphPtr gptr, Map<std::string, Value> node_tensors,
Map<std::string, Value> edge_tensors); Map<std::string, Value> edge_tensors);
void Save(dmlc::Stream *fs) const; void Save(dmlc::Stream *fs) const;
...@@ -50,7 +52,6 @@ class GraphDataObject : public runtime::Object { ...@@ -50,7 +52,6 @@ class GraphDataObject : public runtime::Object {
DGL_DECLARE_OBJECT_TYPE_INFO(GraphDataObject, runtime::Object); DGL_DECLARE_OBJECT_TYPE_INFO(GraphDataObject, runtime::Object);
}; };
class GraphData : public runtime::ObjectRef { class GraphData : public runtime::ObjectRef {
public: public:
DGL_DEFINE_OBJECT_REF_METHODS(GraphData, runtime::ObjectRef, GraphDataObject); DGL_DEFINE_OBJECT_REF_METHODS(GraphData, runtime::ObjectRef, GraphDataObject);
...@@ -61,7 +62,6 @@ class GraphData : public runtime::ObjectRef { ...@@ -61,7 +62,6 @@ class GraphData : public runtime::ObjectRef {
} }
}; };
ImmutableGraphPtr ToImmutableGraph(GraphPtr g); ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
} // namespace serialize } // namespace serialize
......
...@@ -53,9 +53,10 @@ class StorageMetaDataObject : public runtime::Object { ...@@ -53,9 +53,10 @@ class StorageMetaDataObject : public runtime::Object {
static constexpr const char *_type_key = "graph_serialize.StorageMetaData"; static constexpr const char *_type_key = "graph_serialize.StorageMetaData";
void SetMetaData(dgl_id_t num_graph, std::vector<int64_t> nodes_num_list, void SetMetaData(
std::vector<int64_t> edges_num_list, dgl_id_t num_graph, std::vector<int64_t> nodes_num_list,
std::vector<NamedTensor> labels_list); std::vector<int64_t> edges_num_list,
std::vector<NamedTensor> labels_list);
void SetGraphData(std::vector<GraphData> gdata); void SetGraphData(std::vector<GraphData> gdata);
...@@ -72,8 +73,8 @@ class StorageMetaDataObject : public runtime::Object { ...@@ -72,8 +73,8 @@ class StorageMetaDataObject : public runtime::Object {
class StorageMetaData : public runtime::ObjectRef { class StorageMetaData : public runtime::ObjectRef {
public: public:
DGL_DEFINE_OBJECT_REF_METHODS(StorageMetaData, runtime::ObjectRef, DGL_DEFINE_OBJECT_REF_METHODS(
StorageMetaDataObject); StorageMetaData, runtime::ObjectRef, StorageMetaDataObject);
/*! \brief create a new StorageMetaData reference */ /*! \brief create a new StorageMetaData reference */
static StorageMetaData Create() { static StorageMetaData Create() {
...@@ -81,18 +82,18 @@ class StorageMetaData : public runtime::ObjectRef { ...@@ -81,18 +82,18 @@ class StorageMetaData : public runtime::ObjectRef {
} }
}; };
StorageMetaData LoadDGLGraphFiles(const std::string &filename, StorageMetaData LoadDGLGraphFiles(
std::vector<dgl_id_t> idx_list, const std::string &filename, std::vector<dgl_id_t> idx_list, bool onlyMeta);
bool onlyMeta);
StorageMetaData LoadDGLGraphs(const std::string &filename, StorageMetaData LoadDGLGraphs(
std::vector<dgl_id_t> idx_list, bool onlyMeta); const std::string &filename, std::vector<dgl_id_t> idx_list, bool onlyMeta);
bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data, bool SaveDGLGraphs(
std::vector<NamedTensor> labels_list); std::string filename, List<GraphData> graph_data,
std::vector<NamedTensor> labels_list);
std::vector<HeteroGraphData> LoadHeteroGraphs(const std::string &filename, std::vector<HeteroGraphData> LoadHeteroGraphs(
std::vector<dgl_id_t> idx_list); const std::string &filename, std::vector<dgl_id_t> idx_list);
ImmutableGraphPtr ToImmutableGraph(GraphPtr g); ImmutableGraphPtr ToImmutableGraph(GraphPtr g);
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <memory>
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../heterograph.h" #include "../heterograph.h"
...@@ -31,7 +31,6 @@ using namespace dgl::runtime; ...@@ -31,7 +31,6 @@ using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace serialize { namespace serialize {
typedef std::pair<std::string, NDArray> NamedTensor; typedef std::pair<std::string, NDArray> NamedTensor;
class HeteroGraphDataObject : public runtime::Object { class HeteroGraphDataObject : public runtime::Object {
public: public:
...@@ -42,14 +41,14 @@ class HeteroGraphDataObject : public runtime::Object { ...@@ -42,14 +41,14 @@ class HeteroGraphDataObject : public runtime::Object {
std::vector<std::string> ntype_names; std::vector<std::string> ntype_names;
static constexpr const char *_type_key = static constexpr const char *_type_key =
"heterograph_serialize.HeteroGraphData"; "heterograph_serialize.HeteroGraphData";
HeteroGraphDataObject() {} HeteroGraphDataObject() {}
HeteroGraphDataObject(HeteroGraphPtr gptr, HeteroGraphDataObject(
List<Map<std::string, Value>> ndata, HeteroGraphPtr gptr, List<Map<std::string, Value>> ndata,
List<Map<std::string, Value>> edata, List<Map<std::string, Value>> edata, List<Value> ntype_names,
List<Value> ntype_names, List<Value> etype_names) { List<Value> etype_names) {
this->gptr = std::dynamic_pointer_cast<HeteroGraph>(gptr); this->gptr = std::dynamic_pointer_cast<HeteroGraph>(gptr);
CHECK_NOTNULL(this->gptr); CHECK_NOTNULL(this->gptr);
for (auto nd_dict : ndata) { for (auto nd_dict : ndata) {
...@@ -95,17 +94,16 @@ class HeteroGraphDataObject : public runtime::Object { ...@@ -95,17 +94,16 @@ class HeteroGraphDataObject : public runtime::Object {
class HeteroGraphData : public runtime::ObjectRef { class HeteroGraphData : public runtime::ObjectRef {
public: public:
DGL_DEFINE_OBJECT_REF_METHODS(HeteroGraphData, runtime::ObjectRef, DGL_DEFINE_OBJECT_REF_METHODS(
HeteroGraphDataObject); HeteroGraphData, runtime::ObjectRef, HeteroGraphDataObject);
/*! \brief create a new GraphData reference */ /*! \brief create a new GraphData reference */
static HeteroGraphData Create(HeteroGraphPtr gptr, static HeteroGraphData Create(
List<Map<std::string, Value>> node_tensors, HeteroGraphPtr gptr, List<Map<std::string, Value>> node_tensors,
List<Map<std::string, Value>> edge_tensors, List<Map<std::string, Value>> edge_tensors, List<Value> ntype_names,
List<Value> ntype_names, List<Value> etype_names) {
List<Value> etype_names) {
return HeteroGraphData(std::make_shared<HeteroGraphDataObject>( return HeteroGraphData(std::make_shared<HeteroGraphDataObject>(
gptr, node_tensors, edge_tensors, ntype_names, etype_names)); gptr, node_tensors, edge_tensors, ntype_names, etype_names));
} }
/*! \brief create an empty GraphData reference */ /*! \brief create an empty GraphData reference */
......
...@@ -22,46 +22,46 @@ typedef std::pair<std::string, NDArray> NamedTensor; ...@@ -22,46 +22,46 @@ typedef std::pair<std::string, NDArray> NamedTensor;
constexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F; constexpr uint64_t kDGLSerialize_Tensors = 0xDD5A9FBE3FA2443F;
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict") DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_SaveNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
auto fs = std::unique_ptr<dmlc::Stream>( auto fs = std::unique_ptr<dmlc::Stream>(
dmlc::Stream::Create(filename.c_str(), "w")); dmlc::Stream::Create(filename.c_str(), "w"));
CHECK(fs) << "Filename is invalid"; CHECK(fs) << "Filename is invalid";
fs->Write(kDGLSerialize_Tensors); fs->Write(kDGLSerialize_Tensors);
bool empty_dict = args[2]; bool empty_dict = args[2];
Map<std::string, Value> nd_dict; Map<std::string, Value> nd_dict;
if (!empty_dict) { if (!empty_dict) {
nd_dict = args[1]; nd_dict = args[1];
} }
std::vector<NamedTensor> namedTensors; std::vector<NamedTensor> namedTensors;
fs->Write(static_cast<uint64_t>(nd_dict.size())); fs->Write(static_cast<uint64_t>(nd_dict.size()));
for (auto kv : nd_dict) { for (auto kv : nd_dict) {
NDArray ndarray = static_cast<NDArray>(kv.second->data); NDArray ndarray = static_cast<NDArray>(kv.second->data);
namedTensors.emplace_back(kv.first, ndarray); namedTensors.emplace_back(kv.first, ndarray);
} }
fs->Write(namedTensors); fs->Write(namedTensors);
*rv = true; *rv = true;
}); });
DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict") DGL_REGISTER_GLOBAL("data.tensor_serialize._CAPI_LoadNDArrayDict")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue *rv) {
std::string filename = args[0]; std::string filename = args[0];
auto fs = std::unique_ptr<dmlc::Stream>( auto fs = std::unique_ptr<dmlc::Stream>(
dmlc::Stream::Create(filename.c_str(), "r")); dmlc::Stream::Create(filename.c_str(), "r"));
CHECK(fs) << "Filename is invalid or file doesn't exists"; CHECK(fs) << "Filename is invalid or file doesn't exists";
uint64_t magincNum, num_elements; uint64_t magincNum, num_elements;
CHECK(fs->Read(&magincNum)) << "Invalid file"; CHECK(fs->Read(&magincNum)) << "Invalid file";
CHECK_EQ(magincNum, kDGLSerialize_Tensors) << "Invalid DGL tensor file"; CHECK_EQ(magincNum, kDGLSerialize_Tensors) << "Invalid DGL tensor file";
CHECK(fs->Read(&num_elements)) << "Invalid num of elements"; CHECK(fs->Read(&num_elements)) << "Invalid num of elements";
Map<std::string, Value> nd_dict; Map<std::string, Value> nd_dict;
std::vector<NamedTensor> namedTensors; std::vector<NamedTensor> namedTensors;
fs->Read(&namedTensors); fs->Read(&namedTensors);
for (auto kv : namedTensors) { for (auto kv : namedTensors) {
Value ndarray = Value(MakeValue(kv.second)); Value ndarray = Value(MakeValue(kv.second));
nd_dict.Set(kv.first, ndarray); nd_dict.Set(kv.first, ndarray);
} }
*rv = nd_dict; *rv = nd_dict;
}); });
} // namespace serialize } // namespace serialize
} // namespace dgl } // namespace dgl
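Both tensor-dict C APIs share a simple layout: a magic number, the number of entries, and then the vector of (name, array) pairs, with the loader validating the same header before reading the pairs back. A bare-bones standalone sketch of that layout using `std::fstream` and vectors of doubles instead of `dmlc::Stream` and `NDArray`; the exact byte layout below is illustrative, not the on-disk DGL format:

```cpp
#include <cstdint>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

constexpr uint64_t kMagic = 0xDD5A9FBE3FA2443F;  // same role as kDGLSerialize_Tensors
using NamedVector = std::pair<std::string, std::vector<double>>;

void SaveDict(const std::string &path, const std::vector<NamedVector> &dict) {
  std::ofstream fs(path, std::ios::binary);
  fs.write(reinterpret_cast<const char *>(&kMagic), sizeof(kMagic));
  uint64_t n = dict.size();
  fs.write(reinterpret_cast<const char *>(&n), sizeof(n));
  for (const auto &[name, data] : dict) {
    uint64_t name_len = name.size(), data_len = data.size();
    fs.write(reinterpret_cast<const char *>(&name_len), sizeof(name_len));
    fs.write(name.data(), name_len);
    fs.write(reinterpret_cast<const char *>(&data_len), sizeof(data_len));
    fs.write(reinterpret_cast<const char *>(data.data()), data_len * sizeof(double));
  }
}

std::vector<NamedVector> LoadDict(const std::string &path) {
  std::ifstream fs(path, std::ios::binary);
  uint64_t magic = 0, n = 0;
  fs.read(reinterpret_cast<char *>(&magic), sizeof(magic));
  if (magic != kMagic) throw std::runtime_error("not a tensor-dict file");
  fs.read(reinterpret_cast<char *>(&n), sizeof(n));
  std::vector<NamedVector> dict(n);
  for (auto &[name, data] : dict) {
    uint64_t name_len = 0, data_len = 0;
    fs.read(reinterpret_cast<char *>(&name_len), sizeof(name_len));
    name.resize(name_len);
    fs.read(name.data(), name_len);
    fs.read(reinterpret_cast<char *>(&data_len), sizeof(data_len));
    data.resize(data_len);
    fs.read(reinterpret_cast<char *>(data.data()), data_len * sizeof(double));
  }
  return dict;
}

int main() {
  SaveDict("toy_tensors.bin", {{"feat", {1.0, 2.0, 3.0}}, {"label", {0.0, 1.0}}});
  for (const auto &[name, data] : LoadDict("toy_tensors.bin"))
    std::cout << name << ": " << data.size() << " values\n";
  return 0;
}
```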