Unverified commit 58d98e03, authored by Ramon Zhou, committed by GitHub

[Graphbolt] Utilize pre-allocation in sampling (#6132)

parent f0d8ca1e
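The direction of the change: pick functions no longer allocate and return a per-node tensor; instead a counting pass (num_pick_fn) computes each output size, the destination is allocated up front, and the pick pass fills it through a raw pointer. The commit itself still allocates one tensor per node (the SampleNeighborsImpl hunk below is marked temporary), but the pointer-based API is what makes a single up-front allocation possible. A minimal standalone sketch of the count/allocate/fill pattern; the names here are illustrative, not the GraphBolt API:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// How many neighbors a node of `degree` yields under a given fanout.
int64_t NumPickSketch(int64_t degree, int64_t fanout, bool replace) {
  if (fanout == -1) return degree;  // -1 means "take every neighbor"
  return replace ? fanout : std::min(fanout, degree);
}

std::vector<int64_t> SampleAllSketch(
    const std::vector<int64_t>& degrees, int64_t fanout, bool replace) {
  // Pass 1: count, so every output size is known before any picking.
  std::vector<int64_t> counts(degrees.size());
  for (size_t i = 0; i < degrees.size(); ++i)
    counts[i] = NumPickSketch(degrees[i], fanout, replace);
  // Allocate the whole destination once.
  std::vector<int64_t> picked(
      std::accumulate(counts.begin(), counts.end(), int64_t{0}));
  // Pass 2: each node fills exactly its own slice through a raw pointer.
  int64_t offset = 0;
  for (size_t i = 0; i < degrees.size(); ++i) {
    std::iota(picked.begin() + offset, picked.begin() + offset + counts[i],
              int64_t{0});  // stand-in for the real pick call
    offset += counts[i];
  }
  return picked;
}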
@@ -353,28 +353,22 @@ int64_t NumPickByEtype(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <SamplerType S>
-torch::Tensor Pick(
-    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
-    const torch::TensorOptions& options,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
-
-template <>
-torch::Tensor Pick<SamplerType::NEIGHBOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::NEIGHBOR> args);
+    SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
 
-template <>
-torch::Tensor Pick<SamplerType::LABOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
 
 /**
  * @brief Picks a specified number of neighbors for a node per edge type,
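The old header declared a primary template Pick<S> plus two explicit specializations. Once PickedType is added, specializing on the sampler type alone would become a partial specialization, which C++ does not allow for function templates; so the new header instead uses two overloads distinguished by their SamplerArgs tag type. A compilable illustration of that dispatch style (simplified stand-ins, not the real declarations):

#include <cstdint>
#include <iostream>

enum class SamplerType { NEIGHBOR, LABOR };
template <SamplerType S>
struct SamplerArgs {};  // empty tag; the real structs carry sampler state

// Overloads selected by the tag type; PickedType stays a free template param.
template <typename PickedType>
void Pick(SamplerArgs<SamplerType::NEIGHBOR>, PickedType* /*out*/) {
  std::cout << "neighbor overload\n";
}
template <typename PickedType>
void Pick(SamplerArgs<SamplerType::LABOR>, PickedType* /*out*/) {
  std::cout << "labor overload\n";
}

int main() {
  int64_t buf[1];
  Pick(SamplerArgs<SamplerType::NEIGHBOR>{}, buf);  // "neighbor overload"
  Pick(SamplerArgs<SamplerType::LABOR>{}, buf);     // "labor overload"
}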
@@ -400,22 +394,25 @@ torch::Tensor Pick<SamplerType::LABOR>(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <SamplerType S>
-torch::Tensor PickByEtype(
+template <SamplerType S, typename PickedType>
+void PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
+    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
+    PickedType* picked_data_ptr);
 
-template <bool NonUniform, bool Replace, typename T = float>
-torch::Tensor LaborPick(
+template <
+    bool NonUniform, bool Replace, typename ProbsType = float,
+    typename PickedType>
+void LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
 
 }  // namespace sampling
 }  // namespace graphbolt
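With the pointer-based signatures, a caller sizes the destination first and then hands Pick a typed pointer. A hypothetical call-site fragment, not verbatim from the repo (the NumPick argument order is taken from the PickByEtype hunk further down, and SampleNeighborsImpl below does the same dance):

int64_t count = NumPick(fanout, replace, probs_or_mask, offset, num_neighbors);
torch::Tensor out = torch::empty({count}, options);  // "allocated in advance"
AT_DISPATCH_INTEGRAL_TYPES(out.scalar_type(), "PickUsage", ([&] {
  Pick(offset, num_neighbors, fanout, replace, options, probs_or_mask, args,
       out.data_ptr<scalar_t>());
}));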
...
@@ -10,6 +10,7 @@
 #include <cmath>
 #include <limits>
+#include <numeric>
 #include <tuple>
 #include <vector>
@@ -187,10 +188,11 @@ auto GetNumPickFn(
  * equal to the number of edges in the graph.
  * @param args Contains sampling algorithm specific arguments.
  *
- * @return A lambda function: (int64_t offset, int64_t num_neighbors) ->
- * torch::Tensor, which takes offset (the starting edge ID of the given node)
- * and num_neighbors (number of neighbors) as params and returns a tensor of
- * picked neighbors.
+ * @return A lambda function: (int64_t offset, int64_t num_neighbors,
+ * PickedType* picked_data_ptr) -> void, which takes offset (the starting
+ * edge ID of the given node) and num_neighbors (number of neighbors) as
+ * params and puts the picked neighbors at the address specified by
+ * picked_data_ptr.
  */
 template <SamplerType S>
 auto GetPickFn(
@@ -199,17 +201,18 @@ auto GetPickFn(
     const torch::optional<torch::Tensor>& type_per_edge,
     const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
   return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args](
-             int64_t offset, int64_t num_neighbors) {
-    // If fanouts.size() > 1, perform sampling for each edge type of each node;
-    // otherwise just sample once for each node with no regard of edge types.
+             int64_t offset, int64_t num_neighbors, auto picked_data_ptr) {
+    // If fanouts.size() > 1, perform sampling for each edge type of each
+    // node; otherwise just sample once for each node with no regard of edge
+    // types.
     if (fanouts.size() > 1) {
       return PickByEtype(
           offset, num_neighbors, fanouts, replace, options,
-          type_per_edge.value(), probs_or_mask, args);
+          type_per_edge.value(), probs_or_mask, args, picked_data_ptr);
     } else {
       return Pick(
           offset, num_neighbors, fanouts[0], replace, options, probs_or_mask,
-          args);
+          args, picked_data_ptr);
     }
   };
 }
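Note the `auto picked_data_ptr` parameter: the returned closure is a C++14 generic lambda, so one GetPickFn result can be invoked with whichever pointer type the caller's AT_DISPATCH block produces. A standalone analog:

#include <cstdint>
#include <numeric>

// One factory, one closure, many pointer types: the generic lambda is
// instantiated per pointer type at each call site, matching AT_DISPATCH.
auto MakeFillFn(int64_t start) {
  return [start](int64_t n, auto* out) {  // C++14 generic lambda
    std::iota(out, out + n, start);
  };
}

int main() {
  int32_t buf32[4];
  int64_t buf64[4];
  auto fill = MakeFillFn(10);
  fill(4, buf32);  // instantiates the int32_t* version
  fill(4, buf64);  // instantiates the int64_t* version
}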
@@ -257,17 +260,23 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
           continue;
         }
-        picked_neighbors_cur_thread[i - begin] =
-            pick_fn(offset, num_neighbors);
-        // This number should be the same as the result of num_pick_fn.
-        num_picked_neighbors_per_node[i + 1] =
-            picked_neighbors_cur_thread[i - begin].size(0);
-        TORCH_CHECK(
-            *num_picked_neighbors_per_node[i + 1].data_ptr<int64_t>() ==
-                num_pick_fn(offset, num_neighbors),
-            "Return value of num_pick_fn doesn't match the actual "
-            "picked number.");
+        // Pre-allocate a tensor for each node. Because the pick
+        // functions now write into caller-provided memory, this code
+        // needed refactoring to adapt to the API change. It is
+        // temporary since the whole process will be rewritten soon.
+        int64_t allocate_size = num_pick_fn(offset, num_neighbors);
+        picked_neighbors_cur_thread[i - begin] =
+            torch::empty({allocate_size}, indptr_options);
+        torch::Tensor& picked_tensor =
+            picked_neighbors_cur_thread[i - begin];
+        AT_DISPATCH_INTEGRAL_TYPES(
+            picked_tensor.scalar_type(), "CallPick", ([&] {
+              pick_fn(
+                  offset, num_neighbors,
+                  picked_tensor.data_ptr<scalar_t>());
+            }));
+        num_picked_neighbors_per_node[i + 1] = allocate_size;
       }
       picked_neighbors_per_thread[thread_id] =
           torch::cat(picked_neighbors_cur_thread);
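AT_DISPATCH_INTEGRAL_TYPES is the standard libtorch idiom for turning a runtime dtype into a compile-time type: inside the lambda, scalar_t aliases the tensor's element type. A minimal sketch, assuming libtorch is available (include paths can vary by version):

#include <ATen/Dispatch.h>
#include <numeric>
#include <torch/torch.h>

int main() {
  torch::Tensor t = torch::empty({8}, torch::kInt64);
  AT_DISPATCH_INTEGRAL_TYPES(t.scalar_type(), "FillDemo", ([&] {
    scalar_t* data = t.data_ptr<scalar_t>();  // scalar_t is int64_t here
    std::iota(data, data + t.numel(), scalar_t{0});
  }));
}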
@@ -431,24 +440,22 @@ int64_t NumPickByEtype(
  * without replacement. If True, a value can be selected multiple times.
  * Otherwise, each value can be selected only once.
  * @param options Tensor options specifying the desired data type of the result.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-inline torch::Tensor UniformPick(
+template <typename PickedType>
+inline void UniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
-    const torch::TensorOptions& options) {
-  torch::Tensor picked_neighbors;
+    const torch::TensorOptions& options, PickedType* picked_data_ptr) {
   if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
-    picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
+    std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
   } else if (replace) {
-    picked_neighbors =
-        torch::randint(offset, offset + num_neighbors, {fanout}, options);
+    std::memcpy(
+        picked_data_ptr,
+        torch::randint(offset, offset + num_neighbors, {fanout}, options)
+            .data_ptr<PickedType>(),
+        fanout * sizeof(PickedType));
   } else {
-    picked_neighbors = torch::empty({fanout}, options);
-    AT_DISPATCH_INTEGRAL_TYPES(
-        picked_neighbors.scalar_type(), "UniformPick", ([&] {
-          scalar_t* picked_neighbors_data =
-              picked_neighbors.data_ptr<scalar_t>();
     // We use different sampling strategies for different sampling case.
     if (fanout >= num_neighbors / 10) {
       // [Algorithm]
@@ -474,7 +481,7 @@ inline torch::Tensor UniformPick(
       // to the small size of both `fanout` and `num_neighbors`. And it
       // is efficient to allocate a small amount of memory. So the
       // algorithm performence is great in this case.
-      std::vector<scalar_t> seq(num_neighbors);
+      std::vector<PickedType> seq(num_neighbors);
       // Assign the seq with [offset, offset + num_neighbors].
       std::iota(seq.begin(), seq.end(), offset);
       for (int64_t i = 0; i < fanout; ++i) {
@@ -482,7 +489,7 @@ inline torch::Tensor UniformPick(
         std::swap(seq[i], seq[j]);
       }
       // Save the randomly sampled fanout elements to the output tensor.
-      std::copy(seq.begin(), seq.begin() + fanout, picked_neighbors_data);
+      std::copy(seq.begin(), seq.begin() + fanout, picked_data_ptr);
     } else if (fanout < 64) {
       // [Algorithm]
       // Use linear search to verify uniqueness.
@@ -490,17 +497,17 @@ inline torch::Tensor UniformPick(
       // [Complexity Analysis]
       // Since the set of numbers is small (up to 64), so it is more
       // cost-effective for the CPU to use this algorithm.
-      auto begin = picked_neighbors_data;
-      auto end = picked_neighbors_data + fanout;
+      auto begin = picked_data_ptr;
+      auto end = picked_data_ptr + fanout;
       while (begin != end) {
         // Put the new random number in the last position.
         *begin = RandomEngine::ThreadLocal()->RandInt(
             offset, offset + num_neighbors);
         // Check if a new value doesn't exist in current
-        // range(picked_neighbors_data, begin). Otherwise get a new
+        // range(picked_data_ptr, begin). Otherwise get a new
         // value until we haven't unique range of elements.
-        auto it = std::find(picked_neighbors_data, begin, *begin);
+        auto it = std::find(picked_data_ptr, begin, *begin);
         if (it == begin) ++begin;
       }
     } else {
@@ -520,17 +527,14 @@ inline torch::Tensor UniformPick(
       // would otherwise increase the sampling cost. By doing so, we
       // achieve a balance between theoretical efficiency and practical
       // performance.
-      std::unordered_set<scalar_t> picked_set;
+      std::unordered_set<PickedType> picked_set;
       while (static_cast<int64_t>(picked_set.size()) < fanout) {
         picked_set.insert(RandomEngine::ThreadLocal()->RandInt(
             offset, offset + num_neighbors));
       }
-      std::copy(
-          picked_set.begin(), picked_set.end(), picked_neighbors_data);
+      std::copy(picked_set.begin(), picked_set.end(), picked_data_ptr);
     }
-        }));
   }
-  return picked_neighbors;
 }
 
 /**
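Each branch of UniformPick is a standard algorithm, so the strategy selection can be reproduced standalone. A sketch with std::mt19937_64 standing in for GraphBolt's RandomEngine; the fanout >= num_neighbors / 10 and fanout < 64 thresholds are the heuristics from the code above:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <unordered_set>
#include <vector>

// Sample `fanout` distinct values from [offset, offset + n) into a
// caller-allocated buffer. Assumes 0 < fanout <= n; the "take everything"
// and with-replacement branches are handled before reaching this point.
void UniformPickSketch(int64_t offset, int64_t n, int64_t fanout,
                       std::mt19937_64& rng, int64_t* out) {
  std::uniform_int_distribution<int64_t> dist(offset, offset + n - 1);
  if (fanout >= n / 10) {
    // Partial Fisher-Yates shuffle: O(n) memory, O(fanout) swaps.
    std::vector<int64_t> seq(n);
    std::iota(seq.begin(), seq.end(), offset);
    for (int64_t i = 0; i < fanout; ++i) {
      std::uniform_int_distribution<int64_t> pos(i, n - 1);
      std::swap(seq[i], seq[pos(rng)]);
    }
    std::copy(seq.begin(), seq.begin() + fanout, out);
  } else if (fanout < 64) {
    // Linear search for duplicates: cheapest when the output is tiny.
    int64_t* begin = out;
    int64_t* const end = out + fanout;
    while (begin != end) {
      *begin = dist(rng);
      if (std::find(out, begin, *begin) == begin) ++begin;
    }
  } else {
    // Hash-set deduplication for the remaining sparse case.
    std::unordered_set<int64_t> picked;
    while (static_cast<int64_t>(picked.size()) < fanout) {
      picked.insert(dist(rng));
    }
    std::copy(picked.begin(), picked.end(), out);
  }
}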
@@ -565,59 +569,65 @@ inline torch::Tensor UniformPick(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-inline torch::Tensor NonUniformPick(
+template <typename PickedType>
+inline void NonUniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
-    const torch::optional<torch::Tensor>& probs_or_mask) {
-  torch::Tensor picked_neighbors;
+    const torch::optional<torch::Tensor>& probs_or_mask,
+    PickedType* picked_data_ptr) {
   auto local_probs =
       probs_or_mask.value().slice(0, offset, offset + num_neighbors);
   auto positive_probs_indices = local_probs.nonzero().squeeze(1);
   auto num_positive_probs = positive_probs_indices.size(0);
-  if (num_positive_probs == 0) return torch::tensor({}, options);
+  if (num_positive_probs == 0) return;
   if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
-    picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
-    picked_neighbors =
-        torch::index_select(picked_neighbors, 0, positive_probs_indices);
+    std::memcpy(
+        picked_data_ptr,
+        (positive_probs_indices + offset).data_ptr<PickedType>(),
+        num_positive_probs * sizeof(PickedType));
  } else {
     if (!replace) fanout = std::min(fanout, num_positive_probs);
-    picked_neighbors =
-        torch::multinomial(local_probs, fanout, replace) + offset;
+    std::memcpy(
+        picked_data_ptr,
+        (torch::multinomial(local_probs, fanout, replace) + offset)
+            .data_ptr<PickedType>(),
+        fanout * sizeof(PickedType));
   }
-  return picked_neighbors;
 }
 
-template <>
-torch::Tensor Pick<SamplerType::NEIGHBOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::NEIGHBOR> args) {
+    SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
   if (probs_or_mask.has_value()) {
-    return NonUniformPick(
-        offset, num_neighbors, fanout, replace, options, probs_or_mask);
+    NonUniformPick(
+        offset, num_neighbors, fanout, replace, options, probs_or_mask,
+        picked_data_ptr);
   } else {
-    return UniformPick(offset, num_neighbors, fanout, replace, options);
+    UniformPick(
+        offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   }
 }
 
-template <SamplerType S>
-torch::Tensor PickByEtype(
+template <SamplerType S, typename PickedType>
+void PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
-  std::vector<torch::Tensor> picked_neighbors(
-      fanouts.size(), torch::tensor({}, options));
+    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
+    PickedType* picked_data_ptr) {
   int64_t etype_begin = offset;
   int64_t etype_end = offset;
   AT_DISPATCH_INTEGRAL_TYPES(
       type_per_edge.scalar_type(), "PickByEtype", ([&] {
         const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
         const auto end = offset + num_neighbors;
+        int64_t pick_offset = 0;
         while (etype_begin < end) {
           scalar_t etype = type_per_edge_data[etype_begin];
           TORCH_CHECK(
@@ -628,53 +638,58 @@ torch::Tensor PickByEtype(
               type_per_edge_data + etype_begin, type_per_edge_data + end,
               etype);
           etype_end = etype_end_it - type_per_edge_data;
+          int64_t picked_count = NumPick(
+              fanout, replace, probs_or_mask, etype_begin,
+              etype_end - etype_begin);
           // Do sampling for one etype.
           if (fanout != 0) {
-            picked_neighbors[etype] = Pick<S>(
+            Pick(
                 etype_begin, etype_end - etype_begin, fanout, replace, options,
-                probs_or_mask, args);
+                probs_or_mask, args, picked_data_ptr + pick_offset);
+            pick_offset += picked_count;
           }
           etype_begin = etype_end;
         }
       }));
-  return torch::cat(picked_neighbors, 0);
 }
 
-template <>
-torch::Tensor Pick<SamplerType::LABOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args) {
-  if (fanout == 0) return torch::tensor({}, options);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
+  if (fanout == 0) return;
   if (probs_or_mask.has_value()) {
     if (fanout < 0) {
-      return NonUniformPick(
-          offset, num_neighbors, fanout, replace, options, probs_or_mask);
-    }
-    torch::Tensor picked_neighbors;
-    AT_DISPATCH_FLOATING_TYPES(
-        probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
-          if (replace) {
-            picked_neighbors = LaborPick<true, true, scalar_t>(
-                offset, num_neighbors, fanout, options, probs_or_mask, args);
-          } else {
-            picked_neighbors = LaborPick<true, false, scalar_t>(
-                offset, num_neighbors, fanout, options, probs_or_mask, args);
-          }
-        }));
-    return picked_neighbors;
+      NonUniformPick(
+          offset, num_neighbors, fanout, replace, options, probs_or_mask,
+          picked_data_ptr);
+    } else {
+      AT_DISPATCH_FLOATING_TYPES(
+          probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
+            if (replace) {
+              LaborPick<true, true, scalar_t>(
+                  offset, num_neighbors, fanout, options, probs_or_mask, args,
+                  picked_data_ptr);
+            } else {
+              LaborPick<true, false, scalar_t>(
+                  offset, num_neighbors, fanout, options, probs_or_mask, args,
+                  picked_data_ptr);
+            }
+          }));
+    }
   } else if (fanout < 0) {
-    return UniformPick(offset, num_neighbors, fanout, replace, options);
+    UniformPick(
+        offset, num_neighbors, fanout, replace, options, picked_data_ptr);
  } else if (replace) {
-    return LaborPick<false, true>(
+    LaborPick<false, true>(
         offset, num_neighbors, fanout, options,
-        /* probs_or_mask= */ torch::nullopt, args);
+        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   } else {  // replace = false
-    return LaborPick<false, false>(
+    LaborPick<false, false>(
         offset, num_neighbors, fanout, options,
-        /* probs_or_mask= */ torch::nullopt, args);
+        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   }
 }
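PickByEtype above relies on type_per_edge being sorted within each node's neighborhood, so every edge type occupies one contiguous run that std::upper_bound can delimit, and the NumPick result tells it how far to advance the output pointer. The segmentation logic in isolation:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Edge types of one node's neighbors, sorted: runs of 0, 1, and 2.
  std::vector<int64_t> type_per_edge = {0, 0, 1, 1, 1, 2};
  const int64_t end = static_cast<int64_t>(type_per_edge.size());
  int64_t etype_begin = 0, pick_offset = 0;
  while (etype_begin < end) {
    const int64_t etype = type_per_edge[etype_begin];
    // First index past the current run of `etype`.
    const auto run_end_it = std::upper_bound(
        type_per_edge.begin() + etype_begin, type_per_edge.end(), etype);
    const int64_t etype_end = run_end_it - type_per_edge.begin();
    const int64_t picked_count = etype_end - etype_begin;  // NumPick stand-in
    std::printf("etype %lld: edges [%lld, %lld) -> output [%lld, %lld)\n",
                (long long)etype, (long long)etype_begin, (long long)etype_end,
                (long long)pick_offset,
                (long long)(pick_offset + picked_count));
    pick_offset += picked_count;
    etype_begin = etype_end;
  }
}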
@@ -704,25 +719,28 @@ inline void safe_divide(T& a, U b) {
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
  * @param args Contains labor specific arguments.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <bool NonUniform, bool Replace, typename T>
-inline torch::Tensor LaborPick(
+template <
+    bool NonUniform, bool Replace, typename ProbsType, typename PickedType>
+inline void LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args) {
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
   fanout = Replace ? fanout : std::min(fanout, num_neighbors);
   if (!NonUniform && !Replace && fanout >= num_neighbors) {
-    return torch::arange(offset, offset + num_neighbors, options);
+    std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
+    return;
   }
   torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
   // Assuming max_degree of a vertex is <= 4 billion.
   auto heap_data = reinterpret_cast<std::pair<float, uint32_t>*>(
       heap_tensor.data_ptr<int32_t>());
-  const T* local_probs_data =
-      NonUniform ? probs_or_mask.value().data_ptr<T>() + offset : nullptr;
+  const ProbsType* local_probs_data =
+      NonUniform ? probs_or_mask.value().data_ptr<ProbsType>() + offset
+                 : nullptr;
   AT_DISPATCH_INTEGRAL_TYPES(
       args.indices.scalar_type(), "LaborPickMain", ([&] {
         const scalar_t* local_indices_data =
@@ -835,21 +853,15 @@ inline torch::Tensor LaborPick(
         }
       }));
   int64_t num_sampled = 0;
-  torch::Tensor picked_neighbors = torch::empty({fanout}, options);
-  AT_DISPATCH_INTEGRAL_TYPES(
-      picked_neighbors.scalar_type(), "LaborPickOutput", ([&] {
-        scalar_t* picked_neighbors_data = picked_neighbors.data_ptr<scalar_t>();
-        for (int64_t i = 0; i < fanout; ++i) {
-          const auto [rnd, j] = heap_data[i];
-          if (!NonUniform || rnd < std::numeric_limits<float>::infinity()) {
-            picked_neighbors_data[num_sampled++] = offset + j;
-          }
-        }
-      }));
+  for (int64_t i = 0; i < fanout; ++i) {
+    const auto [rnd, j] = heap_data[i];
+    if (!NonUniform || rnd < std::numeric_limits<float>::infinity()) {
+      picked_data_ptr[num_sampled++] = offset + j;
+    }
+  }
   TORCH_CHECK(
       !Replace || num_sampled == fanout || num_sampled == 0,
       "Sampling with replacement should sample exactly fanout neighbors or 0!");
-  return picked_neighbors.narrow(0, 0, num_sampled);
 }
 
 }  // namespace sampling
 }  // namespace graphbolt
...
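The heap buffer in LaborPick holds (random_key, local_index) pairs, and the final loop keeps the fanout neighbors with the smallest keys, skipping keys left at infinity (zero-probability neighbors). A simplified standalone analog of that selection: real LABOR sampling derives each key from a seeded hash of the neighbor ID divided by its probability, so that nodes sampled in the same batch make correlated picks; this sketch draws fresh uniform keys instead:

#include <algorithm>
#include <cstdint>
#include <limits>
#include <random>
#include <utility>
#include <vector>

// Keep the `fanout` neighbors with the smallest random keys, using a max-heap
// over the keys so the current worst candidate is evicted in O(log fanout).
// Assumes fanout > 0; returns how many slots of `out` were written.
int64_t LaborSelectSketch(int64_t offset, int64_t num_neighbors, int64_t fanout,
                          std::mt19937_64& rng, int64_t* out) {
  std::uniform_real_distribution<float> uni(0.f, 1.f);
  std::vector<std::pair<float, uint32_t>> heap;
  heap.reserve(fanout);
  for (int64_t t = 0; t < num_neighbors; ++t) {
    const float key = uni(rng);  // LABOR would use seeded_hash(t) / prob[t]
    if (static_cast<int64_t>(heap.size()) < fanout) {
      heap.emplace_back(key, static_cast<uint32_t>(t));
      if (static_cast<int64_t>(heap.size()) == fanout)
        std::make_heap(heap.begin(), heap.end());
    } else if (key < heap.front().first) {
      std::pop_heap(heap.begin(), heap.end());  // largest key moves to back
      heap.back() = {key, static_cast<uint32_t>(t)};
      std::push_heap(heap.begin(), heap.end());
    }
  }
  int64_t num_sampled = 0;
  for (const auto& [key, j] : heap) {
    // Mirrors the infinity check above; all keys are finite in this sketch.
    if (key < std::numeric_limits<float>::infinity())
      out[num_sampled++] = offset + j;
  }
  return num_sampled;
}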