Unverified Commit 58d98e03 authored by Ramon Zhou, committed by GitHub

[Graphbolt] Utilize pre-allocation in sampling (#6132)

parent f0d8ca1e
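The change replaces the return-a-tensor contract of the pick functions with a write-into-a-pre-allocated-buffer contract: callers first compute how many neighbors will be picked (via the NumPick* helpers), allocate one output tensor of exactly that size, and then hand the pick functions a raw pointer to fill. A minimal sketch of the new calling convention (`PickInto` is a hypothetical stand-in for the templated Pick overloads in the diff below, not the real API):

```cpp
// Sketch only: illustrates the pre-allocation contract, not the real API.
#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <numeric>

// Hypothetical stand-in for Pick(): writes `fanout` neighbor IDs starting
// at `offset` into a caller-provided buffer instead of returning a tensor.
template <typename PickedType>
void PickInto(int64_t offset, int64_t fanout, PickedType* picked_data_ptr) {
  std::iota(picked_data_ptr, picked_data_ptr + fanout, offset);
}

int main() {
  const int64_t offset = 100, fanout = 4;
  // The caller pre-allocates exactly the number of elements Pick will write.
  torch::Tensor picked = torch::empty({fanout}, torch::kInt64);
  PickInto(offset, fanout, picked.data_ptr<int64_t>());
  std::cout << picked << std::endl;  // [100, 101, 102, 103]
}
```

This removes one intermediate tensor allocation per node (and per edge type), and lets the sampling loop dispatch on the output dtype once instead of inside every pick call.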
@@ -353,28 +353,22 @@ int64_t NumPickByEtype(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <SamplerType S>
-torch::Tensor Pick(
-    int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
-    const torch::TensorOptions& options,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
-
-template <>
-torch::Tensor Pick<SamplerType::NEIGHBOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::NEIGHBOR> args);
+    SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
 
-template <>
-torch::Tensor Pick<SamplerType::LABOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
 
 /**
  * @brief Picks a specified number of neighbors for a node per edge type,
@@ -400,22 +394,25 @@ torch::Tensor Pick<SamplerType::LABOR>(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <SamplerType S>
-torch::Tensor PickByEtype(
+template <SamplerType S, typename PickedType>
+void PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args);
+    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
+    PickedType* picked_data_ptr);
 
-template <bool NonUniform, bool Replace, typename T = float>
-torch::Tensor LaborPick(
+template <
+    bool NonUniform, bool Replace, typename ProbsType = float,
+    typename PickedType>
+void LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);
 }  // namespace sampling
 }  // namespace graphbolt
...
@@ -10,6 +10,7 @@
 #include <cmath>
 #include <limits>
+#include <numeric>
 #include <tuple>
 #include <vector>
@@ -187,10 +188,11 @@ auto GetNumPickFn(
  * equal to the number of edges in the graph.
  * @param args Contains sampling algorithm specific arguments.
  *
- * @return A lambda function: (int64_t offset, int64_t num_neighbors) ->
- * torch::Tensor, which takes offset (the starting edge ID of the given node)
- * and num_neighbors (number of neighbors) as params and returns a tensor of
- * picked neighbors.
+ * @return A lambda function: (int64_t offset, int64_t num_neighbors,
+ * PickedType* picked_data_ptr) -> void, which takes offset (the starting
+ * edge ID of the given node) and num_neighbors (number of neighbors) as
+ * params and puts the picked neighbors at the address specified by
+ * picked_data_ptr.
  */
 template <SamplerType S>
 auto GetPickFn(
@@ -199,17 +201,18 @@ auto GetPickFn(
     const torch::optional<torch::Tensor>& type_per_edge,
     const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
   return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args](
-             int64_t offset, int64_t num_neighbors) {
-    // If fanouts.size() > 1, perform sampling for each edge type of each node;
-    // otherwise just sample once for each node with no regard of edge types.
+             int64_t offset, int64_t num_neighbors, auto picked_data_ptr) {
+    // If fanouts.size() > 1, perform sampling for each edge type of each
+    // node; otherwise just sample once for each node with no regard of edge
+    // types.
     if (fanouts.size() > 1) {
       return PickByEtype(
           offset, num_neighbors, fanouts, replace, options,
-          type_per_edge.value(), probs_or_mask, args);
+          type_per_edge.value(), probs_or_mask, args, picked_data_ptr);
     } else {
       return Pick(
           offset, num_neighbors, fanouts[0], replace, options, probs_or_mask,
-          args);
+          args, picked_data_ptr);
     }
   };
 }
@@ -257,17 +260,23 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
             continue;
           }
+          // Pre-allocate tensors for each node. Because the pick
+          // functions are modified, this part of the code needed
+          // refactoring to adapt to the API change. It's temporary
+          // since the whole process will be rewritten soon.
+          int64_t allocate_size = num_pick_fn(offset, num_neighbors);
           picked_neighbors_cur_thread[i - begin] =
-              pick_fn(offset, num_neighbors);
-
-          // This number should be the same as the result of num_pick_fn.
-          num_picked_neighbors_per_node[i + 1] =
-              picked_neighbors_cur_thread[i - begin].size(0);
-          TORCH_CHECK(
-              *num_picked_neighbors_per_node[i + 1].data_ptr<int64_t>() ==
-                  num_pick_fn(offset, num_neighbors),
-              "Return value of num_pick_fn doesn't match the actual "
-              "picked number.");
+              torch::empty({allocate_size}, indptr_options);
+          torch::Tensor& picked_tensor =
+              picked_neighbors_cur_thread[i - begin];
+          AT_DISPATCH_INTEGRAL_TYPES(
+              picked_tensor.scalar_type(), "CallPick", ([&] {
+                pick_fn(
+                    offset, num_neighbors,
+                    picked_tensor.data_ptr<scalar_t>());
+              }));
+          num_picked_neighbors_per_node[i + 1] = allocate_size;
         }
         picked_neighbors_per_thread[thread_id] =
             torch::cat(picked_neighbors_cur_thread);
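The hunk above is the consumer side of the new contract: a count pass (num_pick_fn) sizes the buffer, torch::empty allocates it, and pick_fn fills it in place under AT_DISPATCH_INTEGRAL_TYPES. A reduced, self-contained sketch of that count/allocate/fill pattern, with plain lambdas standing in for num_pick_fn and pick_fn:

```cpp
// Sketch only: the count/allocate/fill pattern used by SampleNeighborsImpl,
// with plain lambdas standing in for num_pick_fn / pick_fn.
#include <torch/torch.h>

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const int64_t fanout = 2;
  // Stand-in for num_pick_fn: how many neighbors *will* be picked.
  auto count_fn = [&](int64_t /*offset*/, int64_t num_neighbors) {
    return std::min<int64_t>(num_neighbors, fanout);
  };
  // Stand-in for pick_fn: fill a pre-allocated buffer, return nothing.
  auto fill_fn = [&](int64_t offset, int64_t num_neighbors, auto* out) {
    std::iota(out, out + std::min<int64_t>(num_neighbors, fanout), offset);
  };

  const std::vector<int64_t> offsets = {0, 3, 7};  // per-node CSC offsets
  const std::vector<int64_t> degrees = {3, 4, 1};  // per-node degrees
  std::vector<torch::Tensor> picked(offsets.size());
  for (size_t i = 0; i < offsets.size(); ++i) {
    // Pass 1: size the output, so pass 2 never reallocates.
    const int64_t allocate_size = count_fn(offsets[i], degrees[i]);
    picked[i] = torch::empty({allocate_size}, torch::kInt64);
    // The real code dispatches on the scalar type here; we fix int64_t.
    fill_fn(offsets[i], degrees[i], picked[i].data_ptr<int64_t>());
  }
  std::cout << torch::cat(picked) << std::endl;  // [0, 1, 3, 4, 7]
}
```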
@@ -431,106 +440,101 @@ int64_t NumPickByEtype(
  * without replacement. If True, a value can be selected multiple times.
  * Otherwise, each value can be selected only once.
  * @param options Tensor options specifying the desired data type of the result.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-inline torch::Tensor UniformPick(
+template <typename PickedType>
+inline void UniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
-    const torch::TensorOptions& options) {
-  torch::Tensor picked_neighbors;
+    const torch::TensorOptions& options, PickedType* picked_data_ptr) {
   if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
-    picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
+    std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
   } else if (replace) {
-    picked_neighbors =
-        torch::randint(offset, offset + num_neighbors, {fanout}, options);
+    std::memcpy(
+        picked_data_ptr,
+        torch::randint(offset, offset + num_neighbors, {fanout}, options)
+            .data_ptr<PickedType>(),
+        fanout * sizeof(PickedType));
   } else {
-    picked_neighbors = torch::empty({fanout}, options);
-    AT_DISPATCH_INTEGRAL_TYPES(
-        picked_neighbors.scalar_type(), "UniformPick", ([&] {
-          scalar_t* picked_neighbors_data =
-              picked_neighbors.data_ptr<scalar_t>();
-          // ... same three sampling strategies as below, operating on
-          // scalar_t / picked_neighbors_data ...
-        }));
+    // We use different sampling strategies for different sampling cases.
+    if (fanout >= num_neighbors / 10) {
+      // [Algorithm]
+      // This algorithm is conceptually related to the Fisher-Yates
+      // shuffle.
+      //
+      // [Complexity Analysis]
+      // This algorithm's memory complexity is O(num_neighbors), but
+      // it generates fewer random numbers (O(fanout)).
+      //
+      // (Compare) The reservoir algorithm is one of the most classical
+      // sampling algorithms. Both the reservoir algorithm and our
+      // algorithm offer distinct advantages, so we compare them to
+      // illustrate the trade-offs.
+      // The reservoir algorithm is memory-efficient (O(fanout)) but
+      // creates many random numbers (O(num_neighbors)), which is
+      // costly.
+      //
+      // [Practical Consideration]
+      // Use this algorithm when `fanout >= num_neighbors / 10` to
+      // reduce computation.
+      // In this scenario, memory complexity is not a concern due to
+      // the small size of both `fanout` and `num_neighbors`, and it
+      // is cheap to allocate a small amount of memory, so the
+      // algorithm performs well in this case.
+      std::vector<PickedType> seq(num_neighbors);
+      // Fill seq with [offset, offset + num_neighbors).
+      std::iota(seq.begin(), seq.end(), offset);
+      for (int64_t i = 0; i < fanout; ++i) {
+        auto j = RandomEngine::ThreadLocal()->RandInt(i, num_neighbors);
+        std::swap(seq[i], seq[j]);
+      }
+      // Save the randomly sampled fanout elements to the output buffer.
+      std::copy(seq.begin(), seq.begin() + fanout, picked_data_ptr);
+    } else if (fanout < 64) {
+      // [Algorithm]
+      // Use linear search to verify uniqueness.
+      //
+      // [Complexity Analysis]
+      // Since the set of numbers is small (up to 64), it is more
+      // cost-effective for the CPU to use this algorithm.
+      auto begin = picked_data_ptr;
+      auto end = picked_data_ptr + fanout;
+      while (begin != end) {
+        // Put the new random number in the last position.
+        *begin = RandomEngine::ThreadLocal()->RandInt(
+            offset, offset + num_neighbors);
+        // Check that the new value does not already exist in the
+        // current range [picked_data_ptr, begin). Otherwise draw a
+        // new value until the range contains only unique elements.
+        auto it = std::find(picked_data_ptr, begin, *begin);
+        if (it == begin) ++begin;
+      }
+    } else {
+      // [Algorithm]
+      // Use a hash set to verify uniqueness. In the best scenario, the
+      // time complexity is O(fanout), assuming no conflicts occur.
+      //
+      // [Complexity Analysis]
+      // Let K = (fanout / num_neighbors); the expected number of extra
+      // sampling steps is roughly K^2 / (1-K) * num_neighbors, which
+      // means in the worst-case scenario, the time complexity is
+      // O(num_neighbors^2).
+      //
+      // [Practical Consideration]
+      // In practice, we set the threshold K to 1/10. This trade-off is
+      // due to the slower performance of std::unordered_set, which
+      // would otherwise increase the sampling cost. By doing so, we
+      // achieve a balance between theoretical efficiency and practical
+      // performance.
+      std::unordered_set<PickedType> picked_set;
+      while (static_cast<int64_t>(picked_set.size()) < fanout) {
+        picked_set.insert(RandomEngine::ThreadLocal()->RandInt(
+            offset, offset + num_neighbors));
+      }
+      std::copy(picked_set.begin(), picked_set.end(), picked_data_ptr);
+    }
   }
-  return picked_neighbors;
 }
 
 /**
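The rewritten UniformPick keeps the three without-replacement strategies but now writes straight into picked_data_ptr. For reference, here is a condensed, dependency-free sketch of the same strategy selection; the 1/10 and 64 thresholds match the code above, while std::mt19937 replaces the library's RandomEngine, so this is illustrative rather than the exact implementation:

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <random>
#include <unordered_set>
#include <vector>

// Sample `fanout` distinct values from [offset, offset + n) without
// replacement, mirroring the strategy selection in UniformPick.
// Assumes 0 < fanout < n (the caller handles fanout >= n separately).
std::vector<int64_t> SampleWithoutReplacement(
    int64_t offset, int64_t n, int64_t fanout, std::mt19937& rng) {
  std::vector<int64_t> out(fanout);
  auto rand_int = [&](int64_t lo, int64_t hi) {  // uniform in [lo, hi)
    return std::uniform_int_distribution<int64_t>(lo, hi - 1)(rng);
  };
  if (fanout >= n / 10) {
    // Partial Fisher-Yates: O(n) memory, but only O(fanout) random draws.
    std::vector<int64_t> seq(n);
    std::iota(seq.begin(), seq.end(), offset);
    for (int64_t i = 0; i < fanout; ++i) {
      std::swap(seq[i], seq[rand_int(i, n)]);
    }
    std::copy(seq.begin(), seq.begin() + fanout, out.begin());
  } else if (fanout < 64) {
    // Linear search: cheap uniqueness check for tiny result sets.
    auto begin = out.begin();
    while (begin != out.end()) {
      *begin = rand_int(offset, offset + n);
      if (std::find(out.begin(), begin, *begin) == begin) ++begin;
    }
  } else {
    // Hash set: expected O(fanout) when collisions are rare (K <= 1/10).
    std::unordered_set<int64_t> picked;
    while (static_cast<int64_t>(picked.size()) < fanout) {
      picked.insert(rand_int(offset, offset + n));
    }
    std::copy(picked.begin(), picked.end(), out.begin());
  }
  return out;
}
```

With the K = fanout / num_neighbors threshold held at 1/10, the hash-set branch wastes few draws: for num_neighbors = 1000 and fanout = 100, the expected overhead is roughly K^2 / (1-K) * num_neighbors = (0.01 / 0.9) * 1000, about 11 extra samples.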
@@ -565,59 +569,65 @@ inline torch::Tensor UniformPick(
  * probabilities associated with each neighboring edge of a node in the original
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-inline torch::Tensor NonUniformPick(
+template <typename PickedType>
+inline void NonUniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
-    const torch::optional<torch::Tensor>& probs_or_mask) {
-  torch::Tensor picked_neighbors;
+    const torch::optional<torch::Tensor>& probs_or_mask,
+    PickedType* picked_data_ptr) {
   auto local_probs =
       probs_or_mask.value().slice(0, offset, offset + num_neighbors);
   auto positive_probs_indices = local_probs.nonzero().squeeze(1);
   auto num_positive_probs = positive_probs_indices.size(0);
-  if (num_positive_probs == 0) return torch::tensor({}, options);
+  if (num_positive_probs == 0) return;
   if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
-    picked_neighbors = torch::arange(offset, offset + num_neighbors, options);
-    picked_neighbors =
-        torch::index_select(picked_neighbors, 0, positive_probs_indices);
+    std::memcpy(
+        picked_data_ptr,
+        (positive_probs_indices + offset).data_ptr<PickedType>(),
+        num_positive_probs * sizeof(PickedType));
  } else {
     if (!replace) fanout = std::min(fanout, num_positive_probs);
-    picked_neighbors =
-        torch::multinomial(local_probs, fanout, replace) + offset;
+    std::memcpy(
+        picked_data_ptr,
+        (torch::multinomial(local_probs, fanout, replace) + offset)
+            .data_ptr<PickedType>(),
+        fanout * sizeof(PickedType));
   }
-  return picked_neighbors;
 }
 
-template <>
-torch::Tensor Pick<SamplerType::NEIGHBOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::NEIGHBOR> args) {
+    SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
   if (probs_or_mask.has_value()) {
-    return NonUniformPick(
-        offset, num_neighbors, fanout, replace, options, probs_or_mask);
+    NonUniformPick(
+        offset, num_neighbors, fanout, replace, options, probs_or_mask,
+        picked_data_ptr);
   } else {
-    return UniformPick(offset, num_neighbors, fanout, replace, options);
+    UniformPick(
+        offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   }
 }
 
-template <SamplerType S>
-torch::Tensor PickByEtype(
+template <SamplerType S, typename PickedType>
+void PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
-    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args) {
-  std::vector<torch::Tensor> picked_neighbors(
-      fanouts.size(), torch::tensor({}, options));
+    const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
+    PickedType* picked_data_ptr) {
   int64_t etype_begin = offset;
   int64_t etype_end = offset;
   AT_DISPATCH_INTEGRAL_TYPES(
       type_per_edge.scalar_type(), "PickByEtype", ([&] {
         const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
         const auto end = offset + num_neighbors;
+        int64_t pick_offset = 0;
         while (etype_begin < end) {
           scalar_t etype = type_per_edge_data[etype_begin];
           TORCH_CHECK(
@@ -628,53 +638,58 @@ torch::Tensor PickByEtype(
               type_per_edge_data + etype_begin, type_per_edge_data + end,
               etype);
           etype_end = etype_end_it - type_per_edge_data;
+          int64_t picked_count = NumPick(
+              fanout, replace, probs_or_mask, etype_begin,
+              etype_end - etype_begin);
           // Do sampling for one etype.
           if (fanout != 0) {
-            picked_neighbors[etype] = Pick<S>(
+            Pick(
                 etype_begin, etype_end - etype_begin, fanout, replace, options,
-                probs_or_mask, args);
+                probs_or_mask, args, picked_data_ptr + pick_offset);
+            pick_offset += picked_count;
           }
           etype_begin = etype_end;
         }
       }));
-  return torch::cat(picked_neighbors, 0);
 }
 
-template <>
-torch::Tensor Pick<SamplerType::LABOR>(
+template <typename PickedType>
+void Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args) {
-  if (fanout == 0) return torch::tensor({}, options);
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
+  if (fanout == 0) return;
   if (probs_or_mask.has_value()) {
     if (fanout < 0) {
-      return NonUniformPick(
-          offset, num_neighbors, fanout, replace, options, probs_or_mask);
+      NonUniformPick(
+          offset, num_neighbors, fanout, replace, options, probs_or_mask,
+          picked_data_ptr);
+    } else {
+      AT_DISPATCH_FLOATING_TYPES(
+          probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
+            if (replace) {
+              LaborPick<true, true, scalar_t>(
+                  offset, num_neighbors, fanout, options, probs_or_mask, args,
+                  picked_data_ptr);
+            } else {
+              LaborPick<true, false, scalar_t>(
+                  offset, num_neighbors, fanout, options, probs_or_mask, args,
+                  picked_data_ptr);
+            }
+          }));
     }
-    torch::Tensor picked_neighbors;
-    AT_DISPATCH_FLOATING_TYPES(
-        probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
-          if (replace) {
-            picked_neighbors = LaborPick<true, true, scalar_t>(
-                offset, num_neighbors, fanout, options, probs_or_mask, args);
-          } else {
-            picked_neighbors = LaborPick<true, false, scalar_t>(
-                offset, num_neighbors, fanout, options, probs_or_mask, args);
-          }
-        }));
-    return picked_neighbors;
   } else if (fanout < 0) {
-    return UniformPick(offset, num_neighbors, fanout, replace, options);
+    UniformPick(
+        offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   } else if (replace) {
-    return LaborPick<false, true>(
+    LaborPick<false, true>(
        offset, num_neighbors, fanout, options,
-        /* probs_or_mask= */ torch::nullopt, args);
+        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   } else {  // replace = false
-    return LaborPick<false, false>(
+    LaborPick<false, false>(
         offset, num_neighbors, fanout, options,
-        /* probs_or_mask= */ torch::nullopt, args);
+        /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   }
 }
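PickByEtype depends on two invariants that are easy to miss in the diff: type_per_edge must be sorted within each node's neighbor range (so std::upper_bound can find each etype's run of edges), and the per-etype NumPick result advances pick_offset so consecutive etypes write to disjoint slices of the shared output buffer. A toy illustration of that segmentation, using plain vectors instead of tensors and simplifying the per-etype "pick" to taking the first min(fanout, run) edges:

```cpp
// Sketch only: etype segmentation over a sorted type_per_edge range.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Edge types of one node's neighbors, sorted as PickByEtype requires.
  const std::vector<int64_t> type_per_edge = {0, 0, 0, 1, 2, 2};
  const std::vector<int64_t> fanouts = {2, 2, 1};  // per-etype fanouts
  std::vector<int64_t> picked;  // stands in for picked_data_ptr

  int64_t etype_begin = 0;
  const int64_t end = static_cast<int64_t>(type_per_edge.size());
  while (etype_begin < end) {
    const int64_t etype = type_per_edge[etype_begin];
    // Find where this etype's run of edges ends.
    auto etype_end_it = std::upper_bound(
        type_per_edge.begin() + etype_begin, type_per_edge.end(), etype);
    const int64_t etype_end = etype_end_it - type_per_edge.begin();
    // "Pick" for this etype: simply the first min(fanout, run) edge IDs.
    const int64_t run = etype_end - etype_begin;
    const int64_t take = std::min(fanouts[etype], run);
    for (int64_t i = 0; i < take; ++i) picked.push_back(etype_begin + i);
    etype_begin = etype_end;
  }
  for (auto e : picked) std::cout << e << ' ';  // 0 1 3 4
  std::cout << '\n';
}
```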
@@ -704,25 +719,28 @@ inline void safe_divide(T& a, U b) {
  * graph. It must be a 1D floating-point tensor with the number of elements
  * equal to the number of edges in the graph.
  * @param args Contains labor specific arguments.
- *
- * @return A tensor containing the picked neighbors.
+ * @param picked_data_ptr The destination address where the picked neighbors
+ * should be put. Enough memory space should be allocated in advance.
  */
-template <bool NonUniform, bool Replace, typename T>
-inline torch::Tensor LaborPick(
+template <
+    bool NonUniform, bool Replace, typename ProbsType, typename PickedType>
+inline void LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
-    SamplerArgs<SamplerType::LABOR> args) {
+    SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
   fanout = Replace ? fanout : std::min(fanout, num_neighbors);
   if (!NonUniform && !Replace && fanout >= num_neighbors) {
-    return torch::arange(offset, offset + num_neighbors, options);
+    std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
+    return;
   }
   torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
   // Assuming max_degree of a vertex is <= 4 billion.
   auto heap_data = reinterpret_cast<std::pair<float, uint32_t>*>(
       heap_tensor.data_ptr<int32_t>());
-  const T* local_probs_data =
-      NonUniform ? probs_or_mask.value().data_ptr<T>() + offset : nullptr;
+  const ProbsType* local_probs_data =
+      NonUniform ? probs_or_mask.value().data_ptr<ProbsType>() + offset
+                 : nullptr;
   AT_DISPATCH_INTEGRAL_TYPES(
       args.indices.scalar_type(), "LaborPickMain", ([&] {
         const scalar_t* local_indices_data =
@@ -835,21 +853,15 @@ inline torch::Tensor LaborPick(
         }
       }));
   int64_t num_sampled = 0;
-  torch::Tensor picked_neighbors = torch::empty({fanout}, options);
-  AT_DISPATCH_INTEGRAL_TYPES(
-      picked_neighbors.scalar_type(), "LaborPickOutput", ([&] {
-        scalar_t* picked_neighbors_data = picked_neighbors.data_ptr<scalar_t>();
-        for (int64_t i = 0; i < fanout; ++i) {
-          const auto [rnd, j] = heap_data[i];
-          if (!NonUniform || rnd < std::numeric_limits<float>::infinity()) {
-            picked_neighbors_data[num_sampled++] = offset + j;
-          }
-        }
-      }));
+  for (int64_t i = 0; i < fanout; ++i) {
+    const auto [rnd, j] = heap_data[i];
+    if (!NonUniform || rnd < std::numeric_limits<float>::infinity()) {
+      picked_data_ptr[num_sampled++] = offset + j;
+    }
+  }
   TORCH_CHECK(
       !Replace || num_sampled == fanout || num_sampled == 0,
       "Sampling with replacement should sample exactly fanout neighbors or 0!");
-  return picked_neighbors.narrow(0, 0, num_sampled);
 }
 }  // namespace sampling
...
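One detail worth noting in LaborPick: the scratch heap is allocated as an int32 tensor and reinterpreted as (float key, uint32_t local-index) pairs, which is why the code comments that max_degree is assumed to fit in roughly 4 billion (the uint32_t range). A standalone sketch of that packing trick, with a static_assert guarding the size assumption (the key values below are made up for illustration):

```cpp
// Sketch only: packing (float, uint32_t) pairs into an int32 tensor.
#include <torch/torch.h>

#include <cstdint>
#include <iostream>
#include <utility>

int main() {
  const int64_t fanout = 3;
  // One (float, uint32_t) pair occupies exactly two int32 slots.
  static_assert(
      sizeof(std::pair<float, uint32_t>) == 2 * sizeof(int32_t),
      "pair packing assumption");
  torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
  auto heap_data = reinterpret_cast<std::pair<float, uint32_t>*>(
      heap_tensor.data_ptr<int32_t>());
  for (int64_t j = 0; j < fanout; ++j) {
    // (random key, local neighbor index); LaborPick keeps the fanout
    // smallest keys here and later emits offset + index for finite keys.
    heap_data[j] = {1.0f / (j + 1), static_cast<uint32_t>(j)};
  }
  for (int64_t i = 0; i < fanout; ++i) {
    std::cout << heap_data[i].first << " -> " << heap_data[i].second << '\n';
  }
}
```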