Unverified Commit dd762a1e authored by lixiaobai, committed by GitHub

[PinSAGESampler] support PinSAGE sampler on GPU (#3567)



* Feat: support API "randomwalk_topk" in library

* Feat: use the new API "randomwalk_topk" for PinSAGESampler

* Minor

* Minor

* Refactor: modified code as the checker required

* Minor

* Minor

* Minor

* Minor

* Fix: checking errors in RandomWalkTopk

* Refactor: modified the docstring for randomwalk_topk

* change randomwalk_topk to internal

* fix

* rename

* Minor for pinsage.py

* Feat: support randomwalk and SelectPinSageNeighbors on GPU

Port the RandomWalk algorithm and SelectPinSageNeighbors to the GPU.

* Feat: support GPU on python APIs

* Feat: remove perf print information in FrequencyHashmap

* Fix: modified the code format

Modified the code format as task_lint.sh suggested

* Feat: let test script support PinSAGESampler on GPU

Let the test script support PinSAGESampler on GPU,
with a minor fix for "restart_prob".

* Minor

* Minor

* Minor

* Refactor: use the atomic operations from the array module

* Minor: change the long lines

* Refactor: modified the get_node_types for gpu

* Feat: update the contributor date

* Perf: remove unnecessary stream sync

* Feat: support other random walk variants

Non-uniform choice is still not supported.

* Fix: add CUDA switch for random walk
Co-authored-by: Quan Gan <coin2028@hotmail.com>
parent 78e0dae6
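A minimal usage sketch of the new GPU path, assuming a CUDA-enabled DGL build with the PyTorch backend; it mirrors the updated test_pinsage_sampling below, and the 'weights' edge-data key is the sampler's default weight_column:

import torch
import dgl

# Same bipartite item/user graph as in the updated test, moved to the GPU.
g = dgl.heterograph({
    ('item', 'bought-by', 'user'): ([0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]),
    ('user', 'bought', 'item'): ([0, 1, 0, 1, 2, 3, 2, 3], [0, 0, 1, 1, 2, 2, 3, 3])})
g = g.to('cuda')

# Same constructor arguments as on the CPU path (see the test changes below).
sampler = dgl.sampling.PinSAGESampler(g, 'item', 'user', 4, 0.5, 3, 2)
seeds = torch.tensor([0, 2], device='cuda')
frontier = sampler(seeds)  # item-to-item graph; visit counts in edata['weights']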
......@@ -247,6 +247,7 @@ macro(dgl_config_cuda out_variable)
src/runtime/cuda/*.cu
src/geometry/cuda/*.cu
src/graph/transform/cuda/*.cu
src/graph/sampling/randomwalks/*.cu
)
# NVCC flags
......
......@@ -41,7 +41,7 @@ class RandomWalkNeighborSampler(object):
Parameters
----------
G : DGLGraph
- The graph. It must be on CPU.
+ The graph.
num_traversals : int
The maximum number of metapath-based traversals for a single random walk.
......@@ -71,7 +71,6 @@ class RandomWalkNeighborSampler(object):
"""
def __init__(self, G, num_traversals, termination_prob,
num_random_walks, num_neighbors, metapath=None, weight_column='weights'):
- assert G.device == F.cpu(), "Graph must be on CPU."
self.G = G
self.weight_column = weight_column
self.num_random_walks = num_random_walks
......@@ -93,7 +92,8 @@ class RandomWalkNeighborSampler(object):
self.full_metapath = metapath * num_traversals
restart_prob = np.zeros(self.metapath_hops * num_traversals)
restart_prob[self.metapath_hops::self.metapath_hops] = termination_prob
- self.restart_prob = F.zerocopy_from_numpy(restart_prob)
+ restart_prob = F.tensor(restart_prob, dtype=F.float32)
+ self.restart_prob = F.copy_to(restart_prob, G.device)
# pylint: disable=no-member
def __call__(self, seed_nodes):
......
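For reference, a short sketch (with hypothetical values metapath_hops = 2, num_traversals = 3, termination_prob = 0.5) of the per-step restart vector the hunk above builds; the change here is only that the resulting float32 tensor is now copied to G.device so the GPU kernel can read it:

import numpy as np

metapath_hops, num_traversals, termination_prob = 2, 3, 0.5
restart_prob = np.zeros(metapath_hops * num_traversals)
restart_prob[metapath_hops::metapath_hops] = termination_prob
# restart_prob == [0.0, 0.0, 0.5, 0.0, 0.5, 0.0]
# i.e. the walk may only terminate at metapath boundaries.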
......@@ -33,11 +33,11 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
Parameters
----------
g : DGLGraph
- The graph. Must be on CPU.
+ The graph.
nodes : Tensor
Node ID tensor from which the random walk traces starts.
- The tensor must be on CPU, and must have the same dtype as the ID type
+ The tensor must have the same dtype as the ID type
of the graph.
metapath : list[str or tuple of str], optional
Metapath, specified as a list of edge types.
......@@ -85,10 +85,6 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
``(length + 1)``.
The type IDs match the ones in the original graph ``g``.
- Notes
- -----
- The returned tensors are on CPU.
Examples
--------
The following creates a homogeneous graph:
......@@ -160,7 +156,6 @@ def random_walk(g, nodes, *, metapath=None, length=None, prob=None, restart_prob
[ 2, 0, 1, 1, 3, 2, 2],
[ 0, 1, 1, 3, 0, 0, 0]]), tensor([0, 0, 1, 0, 0, 1, 0]))
"""
- assert g.device == F.cpu(), "Graph must be on CPU."
n_etypes = len(g.canonical_etypes)
n_ntypes = len(g.ntypes)
......
/*!
* Copyright (c) 2021 by Contributors
* \file graph/sampling/frequency_hashmap.cu
* \brief frequency hashmap - used to select top-k frequency edges of each node
*/
#include <cub/cub.cuh>
#include <algorithm>
#include <tuple>
#include <utility>
#include "../../../runtime/cuda/cuda_common.h"
#include "../../../array/cuda/atomic.cuh"
#include "frequency_hashmap.cuh"
namespace dgl {
namespace sampling {
namespace impl {
namespace {
int64_t _table_size(const int64_t num, const int64_t scale) {
/**
* Calculate the number of buckets in the hashtable. To guarantee we can
* fill the hashtable in the worst case, we must use a number of buckets which
* is a power of two.
* https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
*/
const int64_t next_pow2 = 1 << static_cast<int64_t>(1 + std::log2(num >> 1));
return next_pow2 << scale;
}
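// For example, num = 1000 with the default scale of 3 gives
// 1 << (int64_t)(1 + log2(500)) = 512, then 512 << 3 = 4096 buckets;
// the result stays a power of two, as quadratic probing requires.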
template<typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _init_edge_table(void *edge_hashmap, int64_t edges_len) {
using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
auto edge_hashmap_t = static_cast<EdgeItem*>(edge_hashmap);
int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
int64_t last_idx = start_idx + TILE_SIZE;
#pragma unroll(4)
for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
if (idx < edges_len) {
EdgeItem *edge = (edge_hashmap_t + idx);
edge->src = static_cast<IdxType>(-1);
edge->cnt = static_cast<IdxType>(0);
}
}
}
template<typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _count_frequency(const IdxType *src_data,
const int64_t num_edges, const int64_t num_edges_per_node,
IdxType *edge_blocks_prefix, bool *is_first_position,
DeviceEdgeHashmap<IdxType> device_edge_hashmap) {
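// Each edge is inserted into the per-dst hashmap; the first insertion of an
// edge is flagged in is_first_position and counted. The per-block totals in
// edge_blocks_prefix are prefix-summed later (step 2.1 of Topk) to form the
// offsets used when compacting the unique edges.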
int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
int64_t last_idx = start_idx + TILE_SIZE;
IdxType count = 0;
for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
if (idx < num_edges) {
IdxType src = src_data[idx];
if (src == static_cast<IdxType>(-1)) {
continue;
}
IdxType dst_idx = (idx / num_edges_per_node);
if (device_edge_hashmap.InsertEdge(src, dst_idx) == 0) {
is_first_position[idx] = true;
++count;
}
}
}
using BlockReduce = typename cub::BlockReduce<IdxType, BLOCK_SIZE>;
__shared__ typename BlockReduce::TempStorage temp_space;
count = BlockReduce(temp_space).Sum(count);
if (threadIdx.x == 0) {
edge_blocks_prefix[blockIdx.x] = count;
if (blockIdx.x == 0) {
edge_blocks_prefix[gridDim.x] = 0;
}
}
}
/**
* This structure is used with cub's block-level prefix scan in order to
* keep a running sum as items are iteratively processed.
*/
template <typename T>
struct BlockPrefixCallbackOp {
T _running_total;
__device__ BlockPrefixCallbackOp(const T running_total)
: _running_total(running_total) {}
__device__ T operator()(const T block_aggregate) {
const T old_prefix = _running_total;
_running_total += block_aggregate;
return old_prefix;
}
};
template<typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _compact_frequency(const IdxType *src_data, const IdxType *dst_data,
const int64_t num_edges, const int64_t num_edges_per_node,
const IdxType *edge_blocks_prefix, const bool *is_first_position,
IdxType *num_unique_each_node,
IdxType *unique_src_edges, Idx64Type *unique_frequency,
DeviceEdgeHashmap<IdxType> device_edge_hashmap) {
int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
int64_t last_idx = start_idx + TILE_SIZE;
const IdxType block_offset = edge_blocks_prefix[blockIdx.x];
using BlockScan = typename cub::BlockScan<IdxType, BLOCK_SIZE>;
__shared__ typename BlockScan::TempStorage temp_space;
BlockPrefixCallbackOp<IdxType> prefix_op(0);
for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
IdxType flag = 0;
if (idx < num_edges) {
IdxType src = src_data[idx];
IdxType dst_idx = (idx / num_edges_per_node);
if (idx % num_edges_per_node == 0) {
num_unique_each_node[dst_idx] = device_edge_hashmap.GetDstCount(dst_idx);
}
if (is_first_position[idx] == true) {
flag = 1;
}
BlockScan(temp_space).ExclusiveSum(flag, flag, prefix_op);
__syncthreads();
if (is_first_position[idx] == true) {
const IdxType pos = (block_offset + flag);
unique_src_edges[pos] = src;
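// For 32-bit indices, pack (num_dst_nodes - dst_idx) into the high 32 bits
// and the edge count into the low 32 bits. A single descending radix sort on
// this key (step 3.2 of Topk) then keeps edges grouped by dst node (in
// ascending dst order) while sorting by frequency within each group.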
if (sizeof(IdxType) != sizeof(Idx64Type)
&& sizeof(IdxType) == 4) {  // IdxType is a 32-bit type
unique_frequency[pos] = (
(static_cast<Idx64Type>(num_edges / num_edges_per_node - dst_idx) << 32)
| device_edge_hashmap.GetEdgeCount(src, dst_idx));
} else {
unique_frequency[pos] = device_edge_hashmap.GetEdgeCount(src, dst_idx);
}
}
}
}
}
template<typename IdxType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _get_pick_num(IdxType *num_unique_each_node,
const int64_t num_pick, const int64_t num_dst_nodes) {
int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
int64_t last_idx = start_idx + TILE_SIZE;
#pragma unroll(4)
for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
if (idx < num_dst_nodes) {
IdxType &num_unique = num_unique_each_node[idx];
num_unique = min(num_unique, static_cast<IdxType>(num_pick));
}
}
}
template<typename IdxType, typename Idx64Type, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _pick_data(const Idx64Type *unique_frequency, const IdxType *unique_src_edges,
const IdxType *unique_input_offsets, const IdxType *dst_data,
const int64_t num_edges_per_node, const int64_t num_dst_nodes,
const int64_t num_edges,
const IdxType *unique_output_offsets,
IdxType *output_src, IdxType *output_dst, IdxType *output_frequency) {
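// For each dst node, copy its top entries from the frequency-sorted unique_*
// arrays (starting at unique_input_offsets[idx]) into the compact output
// arrays (starting at unique_output_offsets[idx]); the output range length
// was already clamped to num_pick by _get_pick_num.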
int64_t start_idx = (blockIdx.x * TILE_SIZE) + threadIdx.x;
int64_t last_idx = start_idx + TILE_SIZE;
for (int64_t idx = start_idx; idx < last_idx; idx += BLOCK_SIZE) {
if (idx < num_dst_nodes) {
const int64_t dst_pos = (idx * num_edges_per_node);
assert(dst_pos < num_edges);
const IdxType dst = dst_data[dst_pos];
const IdxType last_output_offset = unique_output_offsets[idx + 1];
assert((last_output_offset - unique_output_offsets[idx]) <=
(unique_input_offsets[idx + 1] - unique_input_offsets[idx]));
for (IdxType output_idx = unique_output_offsets[idx], input_idx = unique_input_offsets[idx];
output_idx < last_output_offset; ++output_idx, ++input_idx) {
output_src[output_idx] = unique_src_edges[input_idx];
output_dst[output_idx] = dst;
output_frequency[output_idx] = static_cast<IdxType>(unique_frequency[input_idx]);
}
}
}
}
} // namespace
// return the old cnt of this edge
template<typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::InsertEdge(
const IdxType &src, const IdxType &dst_idx) {
IdxType start_off = dst_idx * _num_items_each_dst;
IdxType pos = EdgeHash(src);
IdxType delta = 1;
IdxType old_cnt = static_cast<IdxType>(-1);
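// Probe with strides 1, 2, 3, ... (triangular/quadratic probing). Because the
// per-dst table size is a power of two (see _table_size), this sequence can
// reach every bucket, so the loop terminates as long as the table is not full.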
while (true) {
IdxType old_src = dgl::aten::cuda::AtomicCAS(
&_edge_hashmap[start_off + pos].src, static_cast<IdxType>(-1), src);
if (old_src == static_cast<IdxType>(-1) || old_src == src) {
// first insert
old_cnt = dgl::aten::cuda::AtomicAdd(
&_edge_hashmap[start_off + pos].cnt, static_cast<IdxType>(1));
if (old_src == static_cast<IdxType>(-1)) {
assert(dst_idx < _num_dst);
dgl::aten::cuda::AtomicAdd(&_dst_unique_edges[dst_idx], static_cast<IdxType>(1));
}
break;
}
pos = EdgeHash(pos + delta);
delta += 1;
}
return old_cnt;
}
template<typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::GetDstCount(const IdxType &dst_idx) {
return _dst_unique_edges[dst_idx];
}
template<typename IdxType>
inline __device__ IdxType DeviceEdgeHashmap<IdxType>::GetEdgeCount(
const IdxType &src, const IdxType &dst_idx) {
IdxType start_off = dst_idx * _num_items_each_dst;
IdxType pos = EdgeHash(src);
IdxType delta = 1;
while (_edge_hashmap[start_off + pos].src != src) {
pos = EdgeHash(pos + delta);
delta += 1;
}
return _edge_hashmap[start_off + pos].cnt;
}
template <typename IdxType>
FrequencyHashmap<IdxType>::FrequencyHashmap(
int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx,
cudaStream_t stream, int64_t edge_table_scale) {
_ctx = ctx;
_stream = stream;
num_items_each_dst = _table_size(num_items_each_dst, edge_table_scale);
auto device = dgl::runtime::DeviceAPI::Get(_ctx);
auto dst_unique_edges = static_cast<IdxType*>(
device->AllocWorkspace(_ctx, (num_dst) * sizeof(IdxType)));
auto edge_hashmap = static_cast<EdgeItem*>(
device->AllocWorkspace(_ctx, (num_dst * num_items_each_dst) * sizeof(EdgeItem)));
constexpr int BLOCK_SIZE = 256;
constexpr int TILE_SIZE = BLOCK_SIZE * 8;
dim3 block(BLOCK_SIZE);
dim3 grid((num_dst * num_items_each_dst + TILE_SIZE - 1) / TILE_SIZE);
CUDA_CALL(cudaMemset(dst_unique_edges, 0, (num_dst) * sizeof(IdxType)));
_init_edge_table<IdxType, BLOCK_SIZE, TILE_SIZE><<<grid, block, 0, _stream>>>(
edge_hashmap, (num_dst * num_items_each_dst));
_device_edge_hashmap = new DeviceEdgeHashmap<IdxType>(
num_dst, num_items_each_dst, dst_unique_edges, edge_hashmap);
_dst_unique_edges = dst_unique_edges;
_edge_hashmap = edge_hashmap;
}
template <typename IdxType>
FrequencyHashmap<IdxType>::~FrequencyHashmap() {
auto device = dgl::runtime::DeviceAPI::Get(_ctx);
delete _device_edge_hashmap;
_device_edge_hashmap = nullptr;
device->FreeWorkspace(_ctx, _dst_unique_edges);
_dst_unique_edges = nullptr;
device->FreeWorkspace(_ctx, _edge_hashmap);
_edge_hashmap = nullptr;
}
template <typename IdxType>
std::tuple<IdArray, IdArray, IdArray> FrequencyHashmap<IdxType>::Topk(
const IdxType *src_data, const IdxType *dst_data, DLDataType dtype,
const int64_t num_edges, const int64_t num_edges_per_node,
const int64_t num_pick) {
using Idx64Type = int64_t;
const int64_t num_dst_nodes = (num_edges / num_edges_per_node);
constexpr int BLOCK_SIZE = 256;
// XXX: an empirically tuned value; best performance observed on a GV100
constexpr int TILE_SIZE = BLOCK_SIZE * 32;
const dim3 block(BLOCK_SIZE);
const dim3 edges_grid((num_edges + TILE_SIZE - 1) / TILE_SIZE);
auto device = dgl::runtime::DeviceAPI::Get(_ctx);
const IdxType num_edge_blocks = static_cast<IdxType>(edges_grid.x);
IdxType num_unique_edges = 0;
// marks whether this edge position is the first insertion of that edge into _edge_hashmap
bool *is_first_position = static_cast<bool*>(
device->AllocWorkspace(_ctx, sizeof(bool) * (num_edges)));
CUDA_CALL(cudaMemset(is_first_position, 0, sizeof(bool) * (num_edges)));
// double space to use ExclusiveSum
auto edge_blocks_prefix_data = static_cast<IdxType*>(
device->AllocWorkspace(_ctx, 2 * sizeof(IdxType) * (num_edge_blocks + 1)));
IdxType *edge_blocks_prefix = edge_blocks_prefix_data;
IdxType *edge_blocks_prefix_alternate = (edge_blocks_prefix_data + (num_edge_blocks + 1));
// triple space to use ExclusiveSum and unique_output_offsets
auto num_unique_each_node_data = static_cast<IdxType*>(
device->AllocWorkspace(_ctx, 3 * sizeof(IdxType) * (num_dst_nodes + 1)));
IdxType *num_unique_each_node = num_unique_each_node_data;
IdxType *num_unique_each_node_alternate = (num_unique_each_node_data + (num_dst_nodes + 1));
IdxType *unique_output_offsets = (num_unique_each_node_data + 2 * (num_dst_nodes + 1));
// 1. Scan all edges and count the unique edges, in total and per dst node
_count_frequency<IdxType, BLOCK_SIZE, TILE_SIZE><<<edges_grid, block, 0, _stream>>>(
src_data, num_edges, num_edges_per_node,
edge_blocks_prefix, is_first_position, *_device_edge_hashmap);
// 2. Compact the unique edges frequency
// 2.1 ExclusiveSum the edge_blocks_prefix
void *d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
edge_blocks_prefix, edge_blocks_prefix_alternate, num_edge_blocks + 1));
d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
edge_blocks_prefix, edge_blocks_prefix_alternate, num_edge_blocks + 1));
device->FreeWorkspace(_ctx, d_temp_storage);
std::swap(edge_blocks_prefix, edge_blocks_prefix_alternate);
device->CopyDataFromTo(&edge_blocks_prefix[num_edge_blocks], 0, &num_unique_edges, 0,
sizeof(num_unique_edges),
_ctx, DGLContext{kDLCPU, 0},
dtype, _stream);
device->StreamSync(_ctx, _stream);
// 2.2 Allocate the data of unique edges and frequency
// double space to use SegmentedRadixSort
auto unique_src_edges_data = static_cast<IdxType*>(
device->AllocWorkspace(_ctx, 2 * sizeof(IdxType) * (num_unique_edges)));
IdxType *unique_src_edges = unique_src_edges_data;
IdxType *unique_src_edges_alternate = unique_src_edges_data + num_unique_edges;
// double space to use SegmentedRadixSort
auto unique_frequency_data = static_cast<Idx64Type*>(
device->AllocWorkspace(_ctx, 2 * sizeof(Idx64Type) * (num_unique_edges)));
Idx64Type *unique_frequency = unique_frequency_data;
Idx64Type *unique_frequency_alternate = unique_frequency_data + num_unique_edges;
// 2.3 Compact the unique edges and their frequency
_compact_frequency<IdxType, Idx64Type, BLOCK_SIZE, TILE_SIZE><<<edges_grid, block, 0, _stream>>>(
src_data, dst_data, num_edges, num_edges_per_node,
edge_blocks_prefix, is_first_position, num_unique_each_node,
unique_src_edges, unique_frequency, *_device_edge_hashmap);
// 3. SegmentedRadixSort the unique edges and unique_frequency
// 3.1 ExclusiveSum the num_unique_each_node
d_temp_storage = nullptr;
temp_storage_bytes = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
num_unique_each_node, num_unique_each_node_alternate, num_dst_nodes + 1));
d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
num_unique_each_node, num_unique_each_node_alternate, num_dst_nodes + 1));
device->FreeWorkspace(_ctx, d_temp_storage);
// 3.2 SegmentedRadixSort the unique_src_edges and unique_frequency
// Create a set of DoubleBuffers to wrap pairs of device pointers
cub::DoubleBuffer<Idx64Type> d_unique_frequency(unique_frequency, unique_frequency_alternate);
cub::DoubleBuffer<IdxType> d_unique_src_edges(unique_src_edges, unique_src_edges_alternate);
// Determine temporary device storage requirements
d_temp_storage = nullptr;
temp_storage_bytes = 0;
// DeviceRadixSort is faster than DeviceSegmentedRadixSort,
// especially when num_dst_nodes is large (around 10,000)
if (dtype.bits == 32) {
CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
d_unique_frequency, d_unique_src_edges, num_unique_edges));
} else {
CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
d_unique_frequency, d_unique_src_edges, num_unique_edges, num_dst_nodes,
num_unique_each_node_alternate, num_unique_each_node_alternate + 1));
}
d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
if (dtype.bits == 32) {
CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
d_unique_frequency, d_unique_src_edges, num_unique_edges));
} else {
CUDA_CALL(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
d_unique_frequency, d_unique_src_edges, num_unique_edges, num_dst_nodes,
num_unique_each_node_alternate, num_unique_each_node_alternate + 1));
}
device->FreeWorkspace(_ctx, d_temp_storage);
// 4. Get the final pick number for each dst node
// 4.1 Clamp num_unique_each_node to min(num_pick, num_unique_each_node)
constexpr int NODE_TILE_SIZE = BLOCK_SIZE * 2;
const dim3 nodes_grid((num_dst_nodes + NODE_TILE_SIZE - 1) / NODE_TILE_SIZE);
_get_pick_num<IdxType, BLOCK_SIZE, NODE_TILE_SIZE><<<nodes_grid, block, 0, _stream>>>(
num_unique_each_node, num_pick, num_dst_nodes);
// 4.2 ExclusiveSum the clamped num_unique_each_node into unique_output_offsets
d_temp_storage = nullptr;
temp_storage_bytes = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
num_unique_each_node, unique_output_offsets, num_dst_nodes + 1));
d_temp_storage = device->AllocWorkspace(_ctx, temp_storage_bytes);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
num_unique_each_node, unique_output_offsets, num_dst_nodes + 1));
device->FreeWorkspace(_ctx, d_temp_storage);
// 5. Pick the data into the result arrays
IdxType num_output = 0;
device->CopyDataFromTo(&unique_output_offsets[num_dst_nodes], 0, &num_output, 0,
sizeof(num_output),
_ctx, DGLContext{kDLCPU, 0},
dtype, _stream);
device->StreamSync(_ctx, _stream);
IdArray res_src = IdArray::Empty({static_cast<int64_t>(num_output)},
dtype, _ctx);
IdArray res_dst = IdArray::Empty({static_cast<int64_t>(num_output)},
dtype, _ctx);
IdArray res_cnt = IdArray::Empty({static_cast<int64_t>(num_output)},
dtype, _ctx);
_pick_data<IdxType, Idx64Type, BLOCK_SIZE, NODE_TILE_SIZE><<<nodes_grid, block, 0, _stream>>>(
d_unique_frequency.Current(), d_unique_src_edges.Current(), num_unique_each_node_alternate,
dst_data, num_edges_per_node, num_dst_nodes, num_edges,
unique_output_offsets,
res_src.Ptr<IdxType>(), res_dst.Ptr<IdxType>(), res_cnt.Ptr<IdxType>());
device->FreeWorkspace(_ctx, is_first_position);
device->FreeWorkspace(_ctx, edge_blocks_prefix_data);
device->FreeWorkspace(_ctx, num_unique_each_node_data);
device->FreeWorkspace(_ctx, unique_src_edges_data);
device->FreeWorkspace(_ctx, unique_frequency_data);
return std::make_tuple(res_src, res_dst, res_cnt);
}
template
class FrequencyHashmap<int64_t>;
template
class FrequencyHashmap<int32_t>;
}; // namespace impl
}; // namespace sampling
}; // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file graph/sampling/frequency_hashmap.cuh
* \brief frequency hashmap - used to select top-k frequency edges of each node
*/
#ifndef DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_CUH_
#define DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_CUH_
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
namespace dgl {
namespace sampling {
namespace impl {
template<typename IdxType>
class DeviceEdgeHashmap {
public:
struct EdgeItem {
IdxType src;
IdxType cnt;
};
DeviceEdgeHashmap() = delete;
DeviceEdgeHashmap(int64_t num_dst, int64_t num_items_each_dst,
IdxType* dst_unique_edges, EdgeItem *edge_hashmap):
_num_dst(num_dst), _num_items_each_dst(num_items_each_dst),
_dst_unique_edges(dst_unique_edges), _edge_hashmap(edge_hashmap) {};
// return the old cnt of this edge
inline __device__ IdxType InsertEdge(const IdxType &src, const IdxType &dst_idx);
inline __device__ IdxType GetDstCount(const IdxType &dst_idx);
inline __device__ IdxType GetEdgeCount(const IdxType &src, const IdxType &dst_idx);
private:
int64_t _num_dst;
int64_t _num_items_each_dst;
IdxType *_dst_unique_edges;
EdgeItem *_edge_hashmap;
inline __device__ IdxType EdgeHash(const IdxType &id) const {
return id % _num_items_each_dst;
};
};
template<typename IdxType>
class FrequencyHashmap {
public:
static constexpr int64_t kDefaultEdgeTableScale = 3;
FrequencyHashmap() = delete;
FrequencyHashmap(int64_t num_dst, int64_t num_items_each_dst, DGLContext ctx, cudaStream_t stream,
int64_t edge_table_scale = kDefaultEdgeTableScale);
~FrequencyHashmap();
using EdgeItem = typename DeviceEdgeHashmap<IdxType>::EdgeItem;
std::tuple<IdArray, IdArray, IdArray> Topk(
const IdxType *src_data, const IdxType *dst_data, DLDataType dtype,
const int64_t num_edges, const int64_t num_edges_per_node,
const int64_t num_pick);
private:
DGLContext _ctx;
cudaStream_t _stream;
DeviceEdgeHashmap<IdxType> *_device_edge_hashmap;
IdxType *_dst_unique_edges;
EdgeItem *_edge_hashmap;
};
}; // namespace impl
}; // namespace sampling
}; // namespace dgl
#endif // DGL_GRAPH_SAMPLING_RANDOMWALKS_FREQUENCY_CUH_
/*!
* Copyright (c) 2021 by Contributors
* \file graph/sampling/get_node_types_gpu.cu
* \brief DGL sampler
*/
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/device_api.h>
#include <cuda_runtime.h>
#include <utility>
#include "randomwalks_impl.h"
namespace dgl {
using namespace dgl::runtime;
using namespace dgl::aten;
namespace sampling {
namespace impl {
template<DLDeviceType XPU, typename IdxType>
TypeArray GetNodeTypesFromMetapath(
const HeteroGraphPtr hg,
const TypeArray metapath) {
uint64_t num_etypes = metapath->shape[0];
auto cpu_ctx = DGLContext{kDLCPU, 0};
auto metapath_ctx = metapath->ctx;
// use default stream
cudaStream_t stream = 0;
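// The metapath is tiny (one entry per hop), so stage it through the host:
// copy it to CPU memory, derive the node-type sequence there, and copy the
// result back to the metapath's device.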
TypeArray h_result = TypeArray::Empty(
{metapath->shape[0] + 1}, metapath->dtype, cpu_ctx);
auto h_result_data = h_result.Ptr<IdxType>();
auto h_metapath = metapath.CopyTo(cpu_ctx, stream);
DeviceAPI::Get(metapath_ctx)->StreamSync(metapath_ctx, stream);
const IdxType *h_metapath_data = h_metapath.Ptr<IdxType>();
dgl_type_t curr_type = hg->GetEndpointTypes(h_metapath_data[0]).first;
h_result_data[0] = curr_type;
for (uint64_t i = 0; i < num_etypes; ++i) {
auto src_dst_type = hg->GetEndpointTypes(h_metapath_data[i]);
dgl_type_t srctype = src_dst_type.first;
dgl_type_t dsttype = src_dst_type.second;
if (srctype != curr_type) {
LOG(FATAL) << "source of edge type #" << i <<
" does not match destination of edge type #" << i - 1;
}
curr_type = dsttype;
h_result_data[i + 1] = dsttype;
}
auto result = h_result.CopyTo(metapath->ctx, stream);
DeviceAPI::Get(metapath_ctx)->StreamSync(metapath_ctx, stream);
return result;
}
template
TypeArray GetNodeTypesFromMetapath<kDLGPU, int32_t>(
const HeteroGraphPtr hg,
const TypeArray metapath);
template
TypeArray GetNodeTypesFromMetapath<kDLGPU, int64_t>(
const HeteroGraphPtr hg,
const TypeArray metapath);
}; // namespace impl
}; // namespace sampling
}; // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file graph/sampling/randomwalk_gpu.cu
* \brief DGL sampler
*/
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/device_api.h>
#include <dgl/random.h>
#include <curand_kernel.h>
#include <vector>
#include <utility>
#include <tuple>
#include "frequency_hashmap.cuh"
namespace dgl {
using namespace dgl::runtime;
using namespace dgl::aten;
namespace sampling {
namespace impl {
namespace {
template<typename IdType>
struct GraphKernelData {
const IdType *in_ptr;
const IdType *in_cols;
const IdType *data;
};
template<typename IdType, typename FloatType, int BLOCK_SIZE, int TILE_SIZE>
__global__ void _RandomWalkKernel(
const uint64_t rand_seed, const IdType *seed_data, const int64_t num_seeds,
const IdType* metapath_data, const uint64_t max_num_steps,
const GraphKernelData<IdType>* graphs,
const FloatType* restart_prob_data,
const int64_t restart_prob_size,
IdType *out_traces_data,
IdType *out_eids_data) {
assert(BLOCK_SIZE == blockDim.x);
int64_t idx = blockIdx.x * TILE_SIZE + threadIdx.x;
int64_t last_idx = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_seeds);
int64_t trace_length = (max_num_steps + 1);
curandState rng;
// reference:
// https://docs.nvidia.com/cuda/curand/device-api-overview.html#performance-notes
curand_init(rand_seed + idx, 0, 0, &rng);
while (idx < last_idx) {
IdType curr = seed_data[idx];
IdType *traces_data_ptr = &out_traces_data[idx * trace_length];
IdType *eids_data_ptr = &out_eids_data[idx * max_num_steps];
*(traces_data_ptr++) = curr;
int64_t step_idx;
for (step_idx = 0; step_idx < max_num_steps; ++step_idx) {
IdType metapath_id = metapath_data[step_idx];
const GraphKernelData<IdType> &graph = graphs[metapath_id];
const int64_t in_row_start = graph.in_ptr[curr];
const int64_t deg = graph.in_ptr[curr + 1] - graph.in_ptr[curr];
if (deg == 0) {  // no outgoing edges of this type; stop the walk
break;
}
const int64_t num = curand(&rng) % deg;
IdType pick = graph.in_cols[in_row_start + num];
IdType eid = (graph.data? graph.data[in_row_start + num] : in_row_start + num);
*traces_data_ptr = pick;
*eids_data_ptr = eid;
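// restart_prob_size > 1: one probability per step (stepwise restart);
// == 1: a single probability applied after every step;
// == 0: no early termination (plain random walk).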
if ((restart_prob_size > 1) && (curand_uniform(&rng) < restart_prob_data[step_idx])) {
break;
} else if ((restart_prob_size == 1) && (curand_uniform(&rng) < restart_prob_data[0])) {
break;
}
++traces_data_ptr; ++eids_data_ptr;
curr = pick;
}
for (; step_idx < max_num_steps; ++step_idx) {
*(traces_data_ptr++) = -1;
*(eids_data_ptr++) = -1;
}
idx += BLOCK_SIZE;
}
}
} // namespace
// random walk with uniform neighbor choice
template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkUniform(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
FloatArray restart_prob) {
const int64_t max_num_steps = metapath->shape[0];
const IdType *metapath_data = static_cast<IdType *>(metapath->data);
int64_t num_etypes = hg->NumEdgeTypes();
CHECK(seeds->ctx.device_type == kDLGPU) << "seeds should be on GPU.";
CHECK(metapath->ctx.device_type == kDLGPU) << "metapath should be on GPU.";
const IdType *seed_data = static_cast<const IdType*>(seeds->data);
CHECK(seeds->ndim == 1) << "seeds must be a 1D tensor.";
const int64_t num_seeds = seeds->shape[0];
int64_t trace_length = max_num_steps + 1;
IdArray traces = IdArray::Empty({num_seeds, trace_length}, seeds->dtype, seeds->ctx);
IdArray eids = IdArray::Empty({num_seeds, max_num_steps}, seeds->dtype, seeds->ctx);
IdType *traces_data = traces.Ptr<IdType>();
IdType *eids_data = eids.Ptr<IdType>();
GraphKernelData<IdType> h_graphs[num_etypes];
DGLContext ctx;
for (int64_t etype = 0; etype < num_etypes; ++etype) {
const CSRMatrix &csr = hg->GetCSRMatrix(etype);
ctx = csr.indptr->ctx;
CHECK(ctx.device_type == kDLGPU) << "graph should be on GPU.";
h_graphs[etype].in_ptr = static_cast<const IdType*>(csr.indptr->data);
h_graphs[etype].in_cols = static_cast<const IdType*>(csr.indices->data);
h_graphs[etype].data = (CSRHasData(csr) ? static_cast<const IdType*>(csr.data->data) : nullptr);
}
// use default stream
cudaStream_t stream = 0;
auto device = DeviceAPI::Get(ctx);
auto d_graphs = static_cast<GraphKernelData<IdType>*>(
device->AllocWorkspace(ctx, (num_etypes) * sizeof(GraphKernelData<IdType>)));
auto d_metapath_data = metapath_data;
// copy graph metadata pointers to GPU
device->CopyDataFromTo(h_graphs, 0, d_graphs, 0,
(num_etypes) * sizeof(GraphKernelData<IdType>),
DGLContext{kDLCPU, 0},
ctx,
hg->GetCSRMatrix(0).indptr->dtype,
stream);
constexpr int BLOCK_SIZE = 256;
constexpr int TILE_SIZE = BLOCK_SIZE * 4;
dim3 block(BLOCK_SIZE);
dim3 grid((num_seeds + TILE_SIZE - 1) / TILE_SIZE);
const uint64_t random_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
ATEN_FLOAT_TYPE_SWITCH(restart_prob->dtype, FloatType, "random walk GPU kernel", {
CHECK(restart_prob->ctx.device_type == kDLGPU) << "restart_prob should be on GPU.";
CHECK(restart_prob->ndim == 1) << "restart_prob must be a 1D tensor.";
const FloatType *restart_prob_data = restart_prob.Ptr<FloatType>();
const int64_t restart_prob_size = restart_prob->shape[0];
_RandomWalkKernel<IdType, FloatType, BLOCK_SIZE, TILE_SIZE> <<<grid, block, 0, stream>>>(
random_seed,
seed_data,
num_seeds,
d_metapath_data,
max_num_steps,
d_graphs,
restart_prob_data,
restart_prob_size,
traces_data,
eids_data);
});
device->FreeWorkspace(ctx, d_graphs);
return std::make_pair(traces, eids);
}
template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalk(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob) {
// non-uniform choice is not supported yet
for (const auto &etype_prob : prob) {
if (!IsNullArray(etype_prob)) {
LOG(FATAL) << "Non-uniform choice is not supported in GPU.";
}
}
auto restart_prob = NDArray::Empty(
{0}, DLDataType{kDLFloat, 32, 1}, DGLContext{XPU, 0});
return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
}
template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithRestart(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
double restart_prob) {
// non-uniform choice is not supported yet
for (const auto &etype_prob : prob) {
if (!IsNullArray(etype_prob)) {
LOG(FATAL) << "Non-uniform choice is not supported in GPU.";
}
}
auto restart_prob_array = NDArray::Empty(
{1}, DLDataType{kDLFloat, 64, 1}, seeds->ctx);
auto device_ctx = restart_prob_array->ctx;
auto device = dgl::runtime::DeviceAPI::Get(device_ctx);
// use default stream
cudaStream_t stream = 0;
device->CopyDataFromTo(
&restart_prob, 0, restart_prob_array.Ptr<double>(), 0,
sizeof(double),
DGLContext{kDLCPU, 0}, device_ctx,
restart_prob_array->dtype, stream);
device->StreamSync(device_ctx, stream);
return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob_array);
}
template<DLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob) {
// non-uniform choice is not supported yet
for (const auto &etype_prob : prob) {
if (!IsNullArray(etype_prob)) {
LOG(FATAL) << "Non-uniform choice is not supported in GPU.";
}
}
return RandomWalkUniform<XPU, IdType>(hg, seeds, metapath, restart_prob);
}
template<DLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k) {
CHECK(src->ctx.device_type == kDLGPU) <<
"IdArray must be on GPU!";
const IdxType* src_data = src.Ptr<IdxType>();
const IdxType* dst_data = dst.Ptr<IdxType>();
const int64_t num_dst_nodes = (dst->shape[0] / num_samples_per_node);
auto ctx = src->ctx;
// use default stream
cudaStream_t stream = 0;
auto frequency_hashmap = FrequencyHashmap<IdxType>(num_dst_nodes,
num_samples_per_node, ctx, stream);
auto ret = frequency_hashmap.Topk(src_data, dst_data, src->dtype,
src->shape[0], num_samples_per_node, k);
return ret;
}
template
std::pair<IdArray, IdArray> RandomWalk<kDLGPU, int32_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob);
template
std::pair<IdArray, IdArray> RandomWalk<kDLGPU, int64_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob);
template
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int32_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
double restart_prob);
template
std::pair<IdArray, IdArray> RandomWalkWithRestart<kDLGPU, int64_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
double restart_prob);
template
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int32_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob);
template
std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart<kDLGPU, int64_t>(
const HeteroGraphPtr hg,
const IdArray seeds,
const TypeArray metapath,
const std::vector<FloatArray> &prob,
FloatArray restart_prob);
template
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLGPU, int32_t>(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k);
template
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLGPU, int64_t>(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k);
}; // namespace impl
}; // namespace sampling
}; // namespace dgl
......@@ -51,7 +51,7 @@ std::tuple<IdArray, IdArray, TypeArray> RandomWalk(
TypeArray vtypes;
std::pair<IdArray, IdArray> result;
- ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalk", {
+ ATEN_XPU_SWITCH_CUDA(hg->Context().device_type, XPU, "RandomWalk", {
ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
result = impl::RandomWalk<XPU, IdxType>(hg, seeds, metapath, prob);
......@@ -72,7 +72,7 @@ std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithRestart(
TypeArray vtypes;
std::pair<IdArray, IdArray> result;
- ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalkWithRestart", {
+ ATEN_XPU_SWITCH_CUDA(hg->Context().device_type, XPU, "RandomWalkWithRestart", {
ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
result = impl::RandomWalkWithRestart<XPU, IdxType>(hg, seeds, metapath, prob, restart_prob);
......@@ -93,7 +93,7 @@ std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(
TypeArray vtypes;
std::pair<IdArray, IdArray> result;
- ATEN_XPU_SWITCH(hg->Context().device_type, XPU, "RandomWalkWithStepwiseRestart", {
+ ATEN_XPU_SWITCH_CUDA(hg->Context().device_type, XPU, "RandomWalkWithStepwiseRestart", {
ATEN_ID_TYPE_SWITCH(seeds->dtype, IdxType, {
vtypes = impl::GetNodeTypesFromMetapath<XPU, IdxType>(hg, metapath);
result = impl::RandomWalkWithStepwiseRestart<XPU, IdxType>(
......@@ -114,7 +114,7 @@ std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
&& (src->shape[0] == dst->shape[0]));
std::tuple<IdArray, IdArray, IdArray> result;
- ATEN_XPU_SWITCH((src->ctx).device_type, XPU, "SelectPinSageNeighbors", {
+ ATEN_XPU_SWITCH_CUDA((src->ctx).device_type, XPU, "SelectPinSageNeighbors", {
ATEN_ID_TYPE_SWITCH(src->dtype, IdxType, {
result = impl::SelectPinSageNeighbors<XPU, IdxType>(src, dst, num_samples_per_node, k);
});
......
......@@ -146,10 +146,10 @@ def test_pack_traces():
assert F.array_equal(result[2], F.tensor([2, 7], dtype=F.int64))
assert F.array_equal(result[3], F.tensor([0, 2], dtype=F.int64))
- @unittest.skipIf(F._default_context_str == 'gpu', reason="GPU not implemented")
def test_pinsage_sampling():
def _test_sampler(g, sampler, ntype):
- neighbor_g = sampler(F.tensor([0, 2], dtype=F.int64))
+ seeds = F.copy_to(F.tensor([0, 2], dtype=F.int64), F.ctx())
+ neighbor_g = sampler(seeds)
assert neighbor_g.ntypes == [ntype]
u, v = neighbor_g.all_edges(form='uv', order='eid')
uv = list(zip(F.asnumpy(u).tolist(), F.asnumpy(v).tolist()))
......@@ -159,6 +159,7 @@ def test_pinsage_sampling():
g = dgl.heterograph({
('item', 'bought-by', 'user'): ([0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]),
('user', 'bought', 'item'): ([0, 1, 0, 1, 2, 3, 2, 3], [0, 0, 1, 1, 2, 2, 3, 3])})
+ g = g.to(F.ctx())
sampler = dgl.sampling.PinSAGESampler(g, 'item', 'user', 4, 0.5, 3, 2)
_test_sampler(g, sampler, 'item')
sampler = dgl.sampling.RandomWalkNeighborSampler(g, 4, 0.5, 3, 2, ['bought-by', 'bought'])
......@@ -168,12 +169,14 @@ def test_pinsage_sampling():
_test_sampler(g, sampler, 'item')
g = dgl.graph(([0, 0, 1, 1, 2, 2, 3, 3],
[0, 1, 0, 1, 2, 3, 2, 3]))
+ g = g.to(F.ctx())
sampler = dgl.sampling.RandomWalkNeighborSampler(g, 4, 0.5, 3, 2)
_test_sampler(g, sampler, g.ntypes[0])
g = dgl.heterograph({
('A', 'AB', 'B'): ([0, 2], [1, 3]),
('B', 'BC', 'C'): ([1, 3], [2, 1]),
('C', 'CA', 'A'): ([2, 1], [0, 2])})
+ g = g.to(F.ctx())
sampler = dgl.sampling.RandomWalkNeighborSampler(g, 4, 0.5, 3, 2, ['AB', 'BC', 'CA'])
_test_sampler(g, sampler, 'A')
......