Unverified Commit a2c5472a authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Labor dependent template specialization. (#7220)

parent 74c5e31d
......@@ -92,6 +92,31 @@ class continuous_seed {
#endif // __CUDA_ARCH__
};
class single_seed {
uint64_t seed_;
public:
/* implicit */ single_seed(const int64_t seed) : seed_(seed) {} // NOLINT
single_seed(torch::Tensor seed_arr)
: seed_(seed_arr.data_ptr<int64_t>()[0]) {}
#ifdef __CUDACC__
__device__ inline float uniform(const uint64_t id) const {
const uint64_t kCurandSeed = 999961; // Could be any random number.
curandStatePhilox4_32_10_t rng;
curand_init(kCurandSeed, seed_, id, &rng);
return curand_uniform(&rng);
}
#else
inline float uniform(const uint64_t id) const {
pcg32 ng0(seed_, id);
std::uniform_real_distribution<float> uni;
return uni(ng0);
}
#endif // __CUDA_ARCH__
};
} // namespace graphbolt
#endif // GRAPHBOLT_CONTINUOUS_SEED_H_
......@@ -17,7 +17,11 @@
namespace graphbolt {
namespace sampling {
enum SamplerType { NEIGHBOR, LABOR };
enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT };
constexpr bool is_labor(SamplerType S) {
return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT;
}
template <SamplerType S>
struct SamplerArgs;
......@@ -27,6 +31,13 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {};
template <>
struct SamplerArgs<SamplerType::LABOR> {
const torch::Tensor& indices;
single_seed random_seed;
int64_t num_nodes;
};
template <>
struct SamplerArgs<SamplerType::LABOR_DEPENDENT> {
const torch::Tensor& indices;
continuous_seed random_seed;
int64_t num_nodes;
......@@ -555,12 +566,12 @@ int64_t Pick(
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
template <typename PickedType>
int64_t Pick(
template <SamplerType S, typename PickedType>
std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);
template <typename PickedType>
int64_t TemporalPick(
......@@ -619,13 +630,13 @@ int64_t TemporalPickByEtype(
PickedType* picked_data_ptr);
template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize = 1024>
int64_t LaborPick(
bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
typename PickedType, int StackSize = 1024>
std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);
} // namespace sampling
} // namespace graphbolt
......
......@@ -15,6 +15,7 @@
#include <limits>
#include <numeric>
#include <tuple>
#include <type_traits>
#include <vector>
#include "./macro.h"
......@@ -660,12 +661,22 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
if (layer) {
SamplerArgs<SamplerType::LABOR> args = [&] {
if (random_seed.has_value()) {
return SamplerArgs<SamplerType::LABOR>{
if (random_seed.has_value() && random_seed->numel() >= 2) {
SamplerArgs<SamplerType::LABOR_DEPENDENT> args{
indices_,
{random_seed.value(), static_cast<float>(seed2_contribution)},
NumNodes()};
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
} else {
auto args = [&] {
if (random_seed.has_value() && random_seed->numel() == 1) {
return SamplerArgs<SamplerType::LABOR>{
indices_, random_seed.value(), NumNodes()};
} else {
return SamplerArgs<SamplerType::LABOR>{
indices_,
......@@ -678,8 +689,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
args));
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
}
} else {
SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl(
......@@ -1297,7 +1309,7 @@ int64_t TemporalPick(
}
return picked_indices.numel();
}
if constexpr (S == SamplerType::LABOR) {
if constexpr (is_labor(S)) {
return Pick(
offset, num_neighbors, fanout, replace, options, masked_prob, args,
picked_data_ptr);
......@@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype(
return pick_offset;
}
template <typename PickedType>
int64_t Pick(
template <SamplerType S, typename PickedType>
std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
if (fanout == 0) return 0;
if (probs_or_mask.has_value()) {
if (fanout < 0) {
......@@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) {
return rem * (one - std::pow(one - u, one / n));
}
template <typename T>
template <typename T, typename seed_t>
inline T jth_sorted_uniform_random(
continuous_seed seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
const T u = seed.uniform(t + j * c);
// https://mathematica.stackexchange.com/a/256707
rem -= invcdf(u, n, rem);
......@@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random(
* should be put. Enough memory space should be allocated in advance.
*/
template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize>
inline int64_t LaborPick(
bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
typename PickedType, int StackSize>
inline std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr) {
fanout = Replace ? fanout : std::min(fanout, num_neighbors);
if (!NonUniform && !Replace && fanout >= num_neighbors) {
std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
......@@ -1504,8 +1516,8 @@ inline int64_t LaborPick(
}
AT_DISPATCH_INDEX_TYPES(
args.indices.scalar_type(), "LaborPickMain", ([&] {
const index_t* local_indices_data =
args.indices.data_ptr<index_t>() + offset;
const auto local_indices_data =
reinterpret_cast<index_t*>(args.indices.data_ptr()) + offset;
if constexpr (Replace) {
// [Algorithm] @mfbalin
// Use a max-heap to get rid of the big random numbers and filter the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment