"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "75d53cc83966b4046e5a329ddf7baa6aa24f52e2"
Unverified Commit a2c5472a authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Labor dependent template specialization. (#7220)

parent 74c5e31d
...@@ -92,6 +92,31 @@ class continuous_seed { ...@@ -92,6 +92,31 @@ class continuous_seed {
#endif // __CUDA_ARCH__ #endif // __CUDA_ARCH__
}; };
class single_seed {
uint64_t seed_;
public:
/* implicit */ single_seed(const int64_t seed) : seed_(seed) {} // NOLINT
single_seed(torch::Tensor seed_arr)
: seed_(seed_arr.data_ptr<int64_t>()[0]) {}
#ifdef __CUDACC__
__device__ inline float uniform(const uint64_t id) const {
const uint64_t kCurandSeed = 999961; // Could be any random number.
curandStatePhilox4_32_10_t rng;
curand_init(kCurandSeed, seed_, id, &rng);
return curand_uniform(&rng);
}
#else
inline float uniform(const uint64_t id) const {
pcg32 ng0(seed_, id);
std::uniform_real_distribution<float> uni;
return uni(ng0);
}
#endif // __CUDA_ARCH__
};
} // namespace graphbolt } // namespace graphbolt
#endif // GRAPHBOLT_CONTINUOUS_SEED_H_ #endif // GRAPHBOLT_CONTINUOUS_SEED_H_
...@@ -17,7 +17,11 @@ ...@@ -17,7 +17,11 @@
namespace graphbolt { namespace graphbolt {
namespace sampling { namespace sampling {
enum SamplerType { NEIGHBOR, LABOR }; enum SamplerType { NEIGHBOR, LABOR, LABOR_DEPENDENT };
constexpr bool is_labor(SamplerType S) {
return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT;
}
template <SamplerType S> template <SamplerType S>
struct SamplerArgs; struct SamplerArgs;
...@@ -27,6 +31,13 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {}; ...@@ -27,6 +31,13 @@ struct SamplerArgs<SamplerType::NEIGHBOR> {};
template <> template <>
struct SamplerArgs<SamplerType::LABOR> { struct SamplerArgs<SamplerType::LABOR> {
const torch::Tensor& indices;
single_seed random_seed;
int64_t num_nodes;
};
template <>
struct SamplerArgs<SamplerType::LABOR_DEPENDENT> {
const torch::Tensor& indices; const torch::Tensor& indices;
continuous_seed random_seed; continuous_seed random_seed;
int64_t num_nodes; int64_t num_nodes;
...@@ -555,12 +566,12 @@ int64_t Pick( ...@@ -555,12 +566,12 @@ int64_t Pick(
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr); SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
template <typename PickedType> template <SamplerType S, typename PickedType>
int64_t Pick( std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr); PickedType* picked_data_ptr);
template <typename PickedType> template <typename PickedType>
int64_t TemporalPick( int64_t TemporalPick(
...@@ -619,13 +630,13 @@ int64_t TemporalPickByEtype( ...@@ -619,13 +630,13 @@ int64_t TemporalPickByEtype(
PickedType* picked_data_ptr); PickedType* picked_data_ptr);
template < template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType, bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
int StackSize = 1024> typename PickedType, int StackSize = 1024>
int64_t LaborPick( std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout, int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr); PickedType* picked_data_ptr);
} // namespace sampling } // namespace sampling
} // namespace graphbolt } // namespace graphbolt
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <limits> #include <limits>
#include <numeric> #include <numeric>
#include <tuple> #include <tuple>
#include <type_traits>
#include <vector> #include <vector>
#include "./macro.h" #include "./macro.h"
...@@ -660,26 +661,37 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -660,26 +661,37 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
} }
if (layer) { if (layer) {
SamplerArgs<SamplerType::LABOR> args = [&] { if (random_seed.has_value() && random_seed->numel() >= 2) {
if (random_seed.has_value()) { SamplerArgs<SamplerType::LABOR_DEPENDENT> args{
return SamplerArgs<SamplerType::LABOR>{ indices_,
indices_, {random_seed.value(), static_cast<float>(seed2_contribution)},
{random_seed.value(), static_cast<float>(seed2_contribution)}, NumNodes()};
NumNodes()}; return SampleNeighborsImpl(
} else { nodes.value(), return_eids,
return SamplerArgs<SamplerType::LABOR>{ GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
indices_, GetPickFn(
RandomEngine::ThreadLocal()->RandInt( fanouts, replace, indptr_.options(), type_per_edge_,
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()), probs_or_mask, args));
NumNodes()}; } else {
} auto args = [&] {
}(); if (random_seed.has_value() && random_seed->numel() == 1) {
return SampleNeighborsImpl( return SamplerArgs<SamplerType::LABOR>{
nodes.value(), return_eids, indices_, random_seed.value(), NumNodes()};
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask), } else {
GetPickFn( return SamplerArgs<SamplerType::LABOR>{
fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask, indices_,
args)); RandomEngine::ThreadLocal()->RandInt(
static_cast<int64_t>(0), std::numeric_limits<int64_t>::max()),
NumNodes()};
}
}();
return SampleNeighborsImpl(
nodes.value(), return_eids,
GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
GetPickFn(
fanouts, replace, indptr_.options(), type_per_edge_,
probs_or_mask, args));
}
} else { } else {
SamplerArgs<SamplerType::NEIGHBOR> args; SamplerArgs<SamplerType::NEIGHBOR> args;
return SampleNeighborsImpl( return SampleNeighborsImpl(
...@@ -1297,7 +1309,7 @@ int64_t TemporalPick( ...@@ -1297,7 +1309,7 @@ int64_t TemporalPick(
} }
return picked_indices.numel(); return picked_indices.numel();
} }
if constexpr (S == SamplerType::LABOR) { if constexpr (is_labor(S)) {
return Pick( return Pick(
offset, num_neighbors, fanout, replace, options, masked_prob, args, offset, num_neighbors, fanout, replace, options, masked_prob, args,
picked_data_ptr); picked_data_ptr);
...@@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype( ...@@ -1383,12 +1395,12 @@ int64_t TemporalPickByEtype(
return pick_offset; return pick_offset;
} }
template <typename PickedType> template <SamplerType S, typename PickedType>
int64_t Pick( std::enable_if_t<is_labor(S), int64_t> Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) { PickedType* picked_data_ptr) {
if (fanout == 0) return 0; if (fanout == 0) return 0;
if (probs_or_mask.has_value()) { if (probs_or_mask.has_value()) {
if (fanout < 0) { if (fanout < 0) {
...@@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) { ...@@ -1438,9 +1450,9 @@ inline T invcdf(T u, int64_t n, T rem) {
return rem * (one - std::pow(one - u, one / n)); return rem * (one - std::pow(one - u, one / n));
} }
template <typename T> template <typename T, typename seed_t>
inline T jth_sorted_uniform_random( inline T jth_sorted_uniform_random(
continuous_seed seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) { seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) {
const T u = seed.uniform(t + j * c); const T u = seed.uniform(t + j * c);
// https://mathematica.stackexchange.com/a/256707 // https://mathematica.stackexchange.com/a/256707
rem -= invcdf(u, n, rem); rem -= invcdf(u, n, rem);
...@@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random( ...@@ -1474,13 +1486,13 @@ inline T jth_sorted_uniform_random(
* should be put. Enough memory space should be allocated in advance. * should be put. Enough memory space should be allocated in advance.
*/ */
template < template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType, bool NonUniform, bool Replace, typename ProbsType, SamplerType S,
int StackSize> typename PickedType, int StackSize>
inline int64_t LaborPick( inline std::enable_if_t<is_labor(S), int64_t> LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout, int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask, const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) { PickedType* picked_data_ptr) {
fanout = Replace ? fanout : std::min(fanout, num_neighbors); fanout = Replace ? fanout : std::min(fanout, num_neighbors);
if (!NonUniform && !Replace && fanout >= num_neighbors) { if (!NonUniform && !Replace && fanout >= num_neighbors) {
std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset); std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
...@@ -1504,8 +1516,8 @@ inline int64_t LaborPick( ...@@ -1504,8 +1516,8 @@ inline int64_t LaborPick(
} }
AT_DISPATCH_INDEX_TYPES( AT_DISPATCH_INDEX_TYPES(
args.indices.scalar_type(), "LaborPickMain", ([&] { args.indices.scalar_type(), "LaborPickMain", ([&] {
const index_t* local_indices_data = const auto local_indices_data =
args.indices.data_ptr<index_t>() + offset; reinterpret_cast<index_t*>(args.indices.data_ptr()) + offset;
if constexpr (Replace) { if constexpr (Replace) {
// [Algorithm] @mfbalin // [Algorithm] @mfbalin
// Use a max-heap to get rid of the big random numbers and filter the // Use a max-heap to get rid of the big random numbers and filter the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment