Unverified Commit 44f0b5fe authored by lixiaobai's avatar lixiaobai Committed by GitHub
Browse files

[PinSAGE sampler] Adjust the APIs for PinSAGESampler (#3529)



* Feat: support API "randomwalk_topk" in library

* Feat: use the new API "randomwalk_topk" for PinSAGESampler

* Minor

* Minor

* Refactor: modify the code as required by the checker

* Minor

* Minor

* Minor

* Minor

* Fix: checking errors in RandomWalkTopk

* Refactor: modified the docstring for randomwalk_topk

* change randomwalk_topk to internal

* fix

* rename

* Minor for pinsage.py
Co-authored-by: default avatarQuan Gan <coin2028@hotmail.com>
parent da53275a
"""PinSAGE sampler & related functions and classes"""
import numpy as np
from .._ffi.function import _init_api
from .. import backend as F
from .. import convert
from .. import transform
from .randomwalks import random_walk
from .neighbor import select_topk
from ..base import EID
from .. import utils
def _select_pinsage_neighbors(src, dst, num_samples_per_node, k):
    """Pick the PinSAGE neighbors out of flattened random walk traces.

    Thin wrapper around the C API ``_CAPI_DGLSamplingSelectPinSageNeighbors``:
    the backend tensors are converted to DGL NDArrays, handed to the C
    implementation together with ``num_samples_per_node`` and ``k``, and the
    three returned NDArrays are converted back to backend tensors.

    This fuses ``to_simple()``, ``select_topk()`` and counting the number of
    occurrences into a single call.

    Returns
    -------
    tuple
        ``(src, dst, counts)`` as backend tensors.
    """
    picked_src, picked_dst, visit_counts = _CAPI_DGLSamplingSelectPinSageNeighbors(
        F.to_dgl_nd(src), F.to_dgl_nd(dst), num_samples_per_node, k)
    return (
        F.from_dgl_nd(picked_src),
        F.from_dgl_nd(picked_dst),
        F.from_dgl_nd(visit_counts),
    )
class RandomWalkNeighborSampler(object):
"""PinSage-like neighbor sampler extended to any heterogeneous graphs.
......@@ -109,20 +120,13 @@ class RandomWalkNeighborSampler(object):
src = F.reshape(paths[:, self.metapath_hops::self.metapath_hops], (-1,))
dst = F.repeat(paths[:, 0], self.num_traversals, 0)
src_mask = (src != -1)
src = F.boolean_mask(src, src_mask)
dst = F.boolean_mask(dst, src_mask)
# count the number of visits and pick the K-most frequent neighbors for each node
src, dst, counts = _select_pinsage_neighbors(
src, dst, (self.num_random_walks * self.num_traversals), self.num_neighbors)
neighbor_graph = convert.heterograph(
{(self.ntype, '_E', self.ntype): (src, dst)},
{self.ntype: self.G.number_of_nodes(self.ntype)}
)
neighbor_graph = transform.to_simple(neighbor_graph, return_counts=self.weight_column)
counts = neighbor_graph.edata[self.weight_column]
neighbor_graph = select_topk(neighbor_graph, self.num_neighbors, self.weight_column)
selected_counts = F.gather_row(counts, neighbor_graph.edata[EID])
neighbor_graph.edata[self.weight_column] = selected_counts
neighbor_graph.edata[self.weight_column] = counts
return neighbor_graph
......@@ -215,3 +219,5 @@ class PinSAGESampler(RandomWalkNeighborSampler):
super().__init__(G, num_traversals,
termination_prob, num_random_walks, num_neighbors,
metapath=[fw_etype, bw_etype], weight_column=weight_column)
_init_api('dgl.sampling.pinsage', __name__)
......@@ -6,8 +6,10 @@
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include <dgl/runtime/device_api.h>
#include <vector>
#include <utility>
#include <algorithm>
#include "randomwalks_impl.h"
#include "randomwalks_cpu.h"
#include "metapath_randomwalk.h"
......@@ -35,6 +37,75 @@ std::pair<IdArray, IdArray> RandomWalk(
return MetapathBasedRandomWalk<XPU, IdxType>(hg, seeds, metapath, prob, terminate);
}
template<DLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k) {
CHECK(src->ctx.device_type == kDLCPU) << "IdArray needs be on CPU!";
int64_t len = src->shape[0] / num_samples_per_node;
IdxType* src_data = src.Ptr<IdxType>();
const IdxType* dst_data = dst.Ptr<IdxType>();
std::vector<IdxType> res_src_vec, res_dst_vec, res_cnt_vec;
for (int64_t i = 0; i < len; ++i) {
int64_t start_idx = (i * num_samples_per_node);
int64_t end_idx = (start_idx + num_samples_per_node);
IdxType dst_node = dst_data[start_idx];
std::sort(src_data + start_idx, src_data + end_idx);
int64_t cnt = 0;
std::vector<std::pair<IdxType, IdxType>> vec;
for (int64_t j = start_idx; j < end_idx; ++j) {
if ((j != start_idx) && (src_data[j] != src_data[j-1])) {
if (src_data[j-1] != -1) {
vec.emplace_back(std::make_pair(cnt, src_data[j-1]));
}
cnt = 0;
}
++cnt;
}
// add last count
if (src_data[end_idx-1] != -1) {
vec.emplace_back(std::make_pair(cnt, src_data[end_idx-1]));
}
std::sort(vec.begin(), vec.end(),
std::greater<std::pair<IdxType, IdxType>>());
int64_t len = std::min(vec.size(), static_cast<size_t>(k));
for (int64_t j = 0; j < len; ++j) {
auto pair_item = vec[j];
res_src_vec.emplace_back(pair_item.second);
res_dst_vec.emplace_back(dst_node);
res_cnt_vec.emplace_back(pair_item.first);
}
}
IdArray res_src = IdArray::Empty({static_cast<int64_t>(res_src_vec.size())},
src->dtype, src->ctx);
IdArray res_dst = IdArray::Empty({static_cast<int64_t>(res_dst_vec.size())},
dst->dtype, dst->ctx);
IdArray res_cnt = IdArray::Empty({static_cast<int64_t>(res_cnt_vec.size())},
src->dtype, src->ctx);
// copy data from vector to NDArray
auto device = runtime::DeviceAPI::Get(src->ctx);
device->CopyDataFromTo(static_cast<IdxType*>(res_src_vec.data()), 0,
res_src.Ptr<IdxType>(), 0,
sizeof(IdxType) * res_src_vec.size(),
DGLContext{kDLCPU, 0}, res_src->ctx,
res_src->dtype, 0);
device->CopyDataFromTo(static_cast<IdxType*>(res_dst_vec.data()), 0,
res_dst.Ptr<IdxType>(), 0,
sizeof(IdxType) * res_dst_vec.size(),
DGLContext{kDLCPU, 0}, res_dst->ctx,
res_dst->dtype, 0);
device->CopyDataFromTo(static_cast<IdxType*>(res_cnt_vec.data()), 0,
res_cnt.Ptr<IdxType>(), 0,
sizeof(IdxType) * res_cnt_vec.size(),
DGLContext{kDLCPU, 0}, res_cnt->ctx,
res_cnt->dtype, 0);
return std::make_tuple(res_src, res_dst, res_cnt);
}
template
std::pair<IdArray, IdArray> RandomWalk<kDLCPU, int32_t>(
const HeteroGraphPtr hg,
......@@ -48,6 +119,19 @@ std::pair<IdArray, IdArray> RandomWalk<kDLCPU, int64_t>(
const TypeArray metapath,
const std::vector<FloatArray> &prob);
// Explicit instantiation of SelectPinSageNeighbors for CPU with 32-bit IDs.
template
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLCPU, int32_t>(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k);
// Explicit instantiation of SelectPinSageNeighbors for CPU with 64-bit IDs.
template
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors<kDLCPU, int64_t>(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k);
}; // namespace impl
}; // namespace sampling
......
......@@ -104,6 +104,25 @@ std::tuple<IdArray, IdArray, TypeArray> RandomWalkWithStepwiseRestart(
return std::make_tuple(result.first, result.second, vtypes);
}
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k) {
assert((src->ndim == 1) && (dst->ndim == 1)
&& (src->shape[0] % num_samples_per_node == 0)
&& (src->shape[0] == dst->shape[0]));
std::tuple<IdArray, IdArray, IdArray> result;
ATEN_XPU_SWITCH((src->ctx).device_type, XPU, "SelectPinSageNeighbors", {
ATEN_ID_TYPE_SWITCH(src->dtype, IdxType, {
result = impl::SelectPinSageNeighbors<XPU, IdxType>(src, dst, num_samples_per_node, k);
});
});
return result;
}
}; // namespace sampling
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk")
......@@ -123,6 +142,22 @@ DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalk")
*rv = ret;
});
// Registers the packed function behind
// "sampling.pinsage._CAPI_DGLSamplingSelectPinSageNeighbors".
// Arguments: (src, dst, num_samples_per_node, k).
// Returns: a list holding the three arrays of the result tuple, in order.
DGL_REGISTER_GLOBAL("sampling.pinsage._CAPI_DGLSamplingSelectPinSageNeighbors")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
    IdArray src_ids = args[0];
    IdArray dst_ids = args[1];
    int64_t samples_per_node = static_cast<int64_t>(args[2]);
    int64_t top_k = static_cast<int64_t>(args[3]);

    auto selected = sampling::SelectPinSageNeighbors(
        src_ids, dst_ids, samples_per_node, top_k);

    // Repack the tuple as a List so it can be returned through DGLRetValue.
    List<Value> packed;
    packed.push_back(Value(MakeValue(std::get<0>(selected))));
    packed.push_back(Value(MakeValue(std::get<1>(selected))));
    packed.push_back(Value(MakeValue(std::get<2>(selected))));
    *rv = packed;
  });
DGL_REGISTER_GLOBAL("sampling.randomwalks._CAPI_DGLSamplingRandomWalkWithRestart")
.set_body([] (DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef hg = args[0];
......
......@@ -115,6 +115,13 @@ std::pair<IdArray, IdArray> RandomWalkWithStepwiseRestart(
const std::vector<FloatArray> &prob,
FloatArray restart_prob);
/*!
 * \brief Select, per destination node, the most frequently occurring source
 *        nodes.
 *
 * \tparam XPU Device type the implementation is instantiated for.
 * \tparam IdxType Integer type of the node IDs stored in \c src and \c dst.
 * \param src Source node IDs.
 * \param dst Destination node IDs.
 * \param num_samples_per_node Number of consecutive entries that belong to one
 *        destination node.
 * \param k Maximum number of source nodes kept per destination node.
 * \return A tuple of three arrays; the dispatcher's callers unpack it as
 *         (src, dst, counts).
 */
template<DLDeviceType XPU, typename IdxType>
std::tuple<IdArray, IdArray, IdArray> SelectPinSageNeighbors(
const IdArray src,
const IdArray dst,
const int64_t num_samples_per_node,
const int64_t k);
}; // namespace impl
}; // namespace sampling
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment