Unverified commit b1153db9 authored by Ramon Zhou, committed by GitHub

[Graphbolt] Rewrite sampling process to eliminate torch::cat (#6152)

parent 64df37f7
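
The rewrite replaces the old per-thread torch::cat pattern with a three-step count / prefix-sum / fill scheme: first compute how many neighbors each seed node will contribute, prefix-sum those counts into per-node write offsets, then let every node write into a single preallocated buffer. Below is a minimal standalone sketch of that scheme, with toy CountFor/FillAt standing in for num_pick_fn/pick_fn; it illustrates the pattern only and is not the actual GraphBolt code.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Toy stand-ins: CountFor plays the role of num_pick_fn (how many picks a
// unit produces) and FillAt the role of pick_fn (write them out).
int64_t CountFor(int64_t unit) { return unit % 3; }

void FillAt(int64_t unit, int64_t* out, int64_t n) {
  for (int64_t j = 0; j < n; ++j) out[j] = unit * 100 + j;
}

int main() {
  const int64_t num_units = 5;
  // Step 1: count per unit, shifted by one slot so offsets[0] stays 0.
  std::vector<int64_t> offsets(num_units + 1, 0);
  for (int64_t i = 0; i < num_units; ++i) offsets[i + 1] = CountFor(i);
  // Step 2: prefix sum turns counts into write offsets; the last entry is
  // the total output length (this doubles as the subgraph indptr).
  std::partial_sum(offsets.begin(), offsets.end(), offsets.begin());
  // Step 3: one allocation; each unit writes into its own disjoint range,
  // which is why this loop is safe to run under torch::parallel_for.
  std::vector<int64_t> output(offsets[num_units]);
  for (int64_t i = 0; i < num_units; ++i)
    FillAt(i, output.data() + offsets[i], offsets[i + 1] - offsets[i]);
  for (int64_t v : output) std::cout << v << ' ';  // prints: 100 200 201 400
  std::cout << '\n';
}

Because every unit owns a disjoint output range, the fill step parallelizes without locks, and the final concatenation disappears entirely.
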
@@ -357,14 +357,14 @@ int64_t NumPickByEtype(
  * should be put. Enough memory space should be allocated in advance.
  */
 template <typename PickedType>
-void Pick(
+int64_t Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
     SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
 
 template <typename PickedType>
-void Pick(
+int64_t Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
@@ -398,7 +398,7 @@ void Pick(
  * should be put. Enough memory space should be allocated in advance.
  */
 template <SamplerType S, typename PickedType>
-void PickByEtype(
+int64_t PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
@@ -408,7 +408,7 @@ void PickByEtype(
 template <
     bool NonUniform, bool Replace, typename ProbsType = float,
     typename PickedType>
-void LaborPick(
+int64_t LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
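
All pick routines in this header now return int64_t instead of void: the number of entries actually written. That lets a caller pack the results of several picks densely into one shared buffer. A hedged sketch of the calling convention (PickSome is a hypothetical stand-in for the Pick family, not a real GraphBolt function):

#include <cstdint>
#include <vector>

// Toy stand-in: writes its picks into `out` and returns how many entries
// it wrote, mirroring the new int64_t return of the Pick family.
int64_t PickSome(int64_t offset, int64_t* out) {
  out[0] = offset;  // toy: always picks exactly one neighbor
  return 1;
}

// With the count returned, picks chain into one preallocated buffer with
// no gaps and no per-call tensors to concatenate later. `buf` must be
// preallocated large enough for the worst case.
void PackPicks(const std::vector<int64_t>& offsets, std::vector<int64_t>& buf) {
  int64_t write_pos = 0;
  for (int64_t off : offsets)
    write_pos += PickSome(off, buf.data() + write_pos);
  buf.resize(write_pos);  // shrink to the packed length
}
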
...
@@ -222,28 +222,30 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
     const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
     PickFn pick_fn) const {
   const int64_t num_nodes = nodes.size(0);
-  const int64_t num_threads = torch::get_num_threads();
-  std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
+  const auto indptr_options = indptr_.options();
   torch::Tensor num_picked_neighbors_per_node =
-      torch::zeros({num_nodes + 1}, indptr_.options());
+      torch::empty({num_nodes + 1}, indptr_options);
   // Calculate GrainSize for parallel_for.
   // Set the default grain size to 64.
   const int64_t grain_size = 64;
+  torch::Tensor picked_eids;
+  torch::Tensor subgraph_indptr;
+  torch::Tensor subgraph_indices;
+  torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
   AT_DISPATCH_INTEGRAL_TYPES(
-      indptr_.scalar_type(), "parallel_for", ([&] {
+      indptr_.scalar_type(), "SampleNeighborsImpl", ([&] {
+        const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
+        auto num_picked_neighbors_data_ptr =
+            num_picked_neighbors_per_node.data_ptr<scalar_t>();
+        num_picked_neighbors_data_ptr[0] = 0;
+        const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
+        // Step 1. Calculate pick number of each node.
         torch::parallel_for(
-            0, num_nodes, grain_size, [&](scalar_t begin, scalar_t end) {
-              const auto indptr_options = indptr_.options();
-              const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
-              // Get current thread id.
-              auto thread_id = torch::get_thread_num();
-              int64_t local_grain_size = end - begin;
-              std::vector<torch::Tensor> picked_neighbors_cur_thread(
-                  local_grain_size);
-              const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
-              for (scalar_t i = begin; i < end; ++i) {
+            0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
+              for (int64_t i = begin; i < end; ++i) {
                 const auto nid = nodes_data_ptr[i];
                 TORCH_CHECK(
                     nid >= 0 && nid < NumNodes(),
@@ -252,49 +254,82 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
                 const auto offset = indptr_data[nid];
                 const auto num_neighbors = indptr_data[nid + 1] - offset;
-                if (num_neighbors == 0) {
-                  // To avoid crashing during concatenation in the master
-                  // thread, initializing with empty tensors.
-                  picked_neighbors_cur_thread[i - begin] =
-                      torch::tensor({}, indptr_options);
-                  continue;
-                }
-                // Pre-allocate tensors for each node. Because the pick
-                // functions are modified, this part of code needed refactoring
-                // to adapt to the change of APIs. It's temporary since the
-                // whole process will be rewritten soon.
-                int64_t allocate_size = num_pick_fn(offset, num_neighbors);
-                picked_neighbors_cur_thread[i - begin] =
-                    torch::empty({allocate_size}, indptr_options);
-                torch::Tensor& picked_tensor =
-                    picked_neighbors_cur_thread[i - begin];
-                AT_DISPATCH_INTEGRAL_TYPES(
-                    picked_tensor.scalar_type(), "CallPick", ([&] {
-                      pick_fn(
-                          offset, num_neighbors,
-                          picked_tensor.data_ptr<scalar_t>());
-                    }));
-                num_picked_neighbors_per_node[i + 1] = allocate_size;
+                num_picked_neighbors_data_ptr[i + 1] =
+                    num_neighbors == 0 ? 0 : num_pick_fn(offset, num_neighbors);
               }
-              picked_neighbors_per_thread[thread_id] =
-                  torch::cat(picked_neighbors_cur_thread);
-            });  // End of parallel_for.
+            });
+
+        // Step 2. Calculate prefix sum to get total length and offsets of each
+        // node. It's also the indptr of the generated subgraph.
+        subgraph_indptr = torch::cumsum(num_picked_neighbors_per_node, 0);
+
+        // Step 3. Allocate the tensor for picked neighbors.
+        const auto total_length =
+            subgraph_indptr.data_ptr<scalar_t>()[num_nodes];
+        picked_eids = torch::empty({total_length}, indptr_options);
+        subgraph_indices = torch::empty({total_length}, indices_.options());
+        if (type_per_edge_.has_value()) {
+          subgraph_type_per_edge =
+              torch::empty({total_length}, type_per_edge_.value().options());
+        }
+
+        // Step 4. Pick neighbors for each node.
+        auto picked_eids_data_ptr = picked_eids.data_ptr<scalar_t>();
+        auto subgraph_indptr_data_ptr = subgraph_indptr.data_ptr<scalar_t>();
+        torch::parallel_for(
+            0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
+              for (int64_t i = begin; i < end; ++i) {
+                const auto nid = nodes_data_ptr[i];
+                const auto offset = indptr_data[nid];
+                const auto num_neighbors = indptr_data[nid + 1] - offset;
+                const auto picked_number = num_picked_neighbors_data_ptr[i + 1];
+                const auto picked_offset = subgraph_indptr_data_ptr[i];
+                if (picked_number > 0) {
+                  auto actual_picked_count = pick_fn(
+                      offset, num_neighbors,
+                      picked_eids_data_ptr + picked_offset);
+                  TORCH_CHECK(
+                      actual_picked_count == picked_number,
+                      "Actual picked count doesn't match the calculated pick "
+                      "number.");
+
+                  // Step 5. Calculate other attributes and return the subgraph.
+                  AT_DISPATCH_INTEGRAL_TYPES(
+                      subgraph_indices.scalar_type(),
+                      "IndexSelectSubgraphIndices", ([&] {
+                        auto subgraph_indices_data_ptr =
+                            subgraph_indices.data_ptr<scalar_t>();
+                        auto indices_data_ptr = indices_.data_ptr<scalar_t>();
+                        for (auto i = picked_offset;
+                             i < picked_offset + picked_number; ++i) {
+                          subgraph_indices_data_ptr[i] =
+                              indices_data_ptr[picked_eids_data_ptr[i]];
+                        }
+                      }));
+                  if (type_per_edge_.has_value()) {
+                    AT_DISPATCH_INTEGRAL_TYPES(
+                        subgraph_type_per_edge.value().scalar_type(),
+                        "IndexSelectTypePerEdge", ([&] {
+                          auto subgraph_type_per_edge_data_ptr =
+                              subgraph_type_per_edge.value()
+                                  .data_ptr<scalar_t>();
+                          auto type_per_edge_data_ptr =
+                              type_per_edge_.value().data_ptr<scalar_t>();
+                          for (auto i = picked_offset;
+                               i < picked_offset + picked_number; ++i) {
+                            subgraph_type_per_edge_data_ptr[i] =
+                                type_per_edge_data_ptr[picked_eids_data_ptr[i]];
+                          }
+                        }));
+                  }
+                }
+              }
+            });
       }));
-  torch::Tensor subgraph_indptr =
-      torch::cumsum(num_picked_neighbors_per_node, 0);
-  torch::Tensor picked_eids = torch::cat(picked_neighbors_per_thread);
-  torch::Tensor subgraph_indices =
-      torch::index_select(indices_, 0, picked_eids);
-  torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
-  if (type_per_edge_.has_value()) {
-    subgraph_type_per_edge =
-        torch::index_select(type_per_edge_.value(), 0, picked_eids);
-  }
   torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
   if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
   return c10::make_intrusive<SampledSubgraph>(
       subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
       subgraph_reverse_edge_ids, subgraph_type_per_edge);
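
Note that both torch::cat calls and the trailing torch::index_select are gone: Step 5 gathers subgraph_indices (and edge types) with a plain loop inside the same parallel_for, so no extra pass over picked_eids and no temporary tensor is needed. A simplified sketch of that gather, with illustrative names rather than the real member fields:

#include <cstdint>
#include <vector>

// Simplified gather equivalent to torch::index_select(indices_, 0,
// picked_eids): copy the endpoint of every picked edge.
void GatherIndices(
    const std::vector<int64_t>& graph_indices,  // stands in for indices_
    const std::vector<int64_t>& picked_eids,
    std::vector<int64_t>& subgraph_indices) {
  subgraph_indices.resize(picked_eids.size());
  for (size_t i = 0; i < picked_eids.size(); ++i)
    subgraph_indices[i] = graph_indices[picked_eids[i]];
}
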
@@ -383,12 +418,17 @@ int64_t NumPick(
     int64_t fanout, bool replace,
     const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
     int64_t num_neighbors) {
-  int64_t num_valid_neighbors =
-      probs_or_mask.has_value()
-          ? *torch::count_nonzero(
-                 probs_or_mask.value().slice(0, offset, offset + num_neighbors))
-                 .data_ptr<int64_t>()
-          : num_neighbors;
+  int64_t num_valid_neighbors = num_neighbors;
+  if (probs_or_mask.has_value()) {
+    // Subtract the count of zeros in probs_or_mask.
+    AT_DISPATCH_ALL_TYPES(
+        probs_or_mask.value().scalar_type(), "CountZero", ([&] {
+          scalar_t* probs_data_ptr = probs_or_mask.value().data_ptr<scalar_t>();
+          num_valid_neighbors -= std::count(
+              probs_data_ptr + offset, probs_data_ptr + offset + num_neighbors,
+              0);
+        }));
+  }
   if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors;
   return replace ? fanout : std::min(fanout, num_valid_neighbors);
 }
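
The new NumPick counts invalid entries directly on the raw buffer with std::count instead of materializing a tensor slice and calling torch::count_nonzero. A toy equivalent for a contiguous float mask (contiguity is assumed, as the raw-pointer arithmetic in the diff requires):

#include <algorithm>
#include <cstdint>

// Valid neighbors in a window = window size minus the zeros in the window.
int64_t NumValid(const float* probs, int64_t offset, int64_t num_neighbors) {
  return num_neighbors -
         std::count(probs + offset, probs + offset + num_neighbors, 0.f);
}
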
@@ -444,17 +484,19 @@ int64_t NumPickByEtype(
  * should be put. Enough memory space should be allocated in advance.
  */
 template <typename PickedType>
-inline void UniformPick(
+inline int64_t UniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options, PickedType* picked_data_ptr) {
   if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
     std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
+    return num_neighbors;
   } else if (replace) {
     std::memcpy(
         picked_data_ptr,
         torch::randint(offset, offset + num_neighbors, {fanout}, options)
             .data_ptr<PickedType>(),
         fanout * sizeof(PickedType));
+    return fanout;
   } else {
     // We use different sampling strategies for different sampling case.
     if (fanout >= num_neighbors / 10) {
@@ -490,6 +532,7 @@ inline void UniformPick(
       }
       // Save the randomly sampled fanout elements to the output tensor.
       std::copy(seq.begin(), seq.begin() + fanout, picked_data_ptr);
+      return fanout;
     } else if (fanout < 64) {
       // [Algorithm]
       // Use linear search to verify uniqueness.
@@ -510,6 +553,7 @@ inline void UniformPick(
         auto it = std::find(picked_data_ptr, begin, *begin);
         if (it == begin) ++begin;
       }
+      return fanout;
     } else {
       // [Algorithm]
       // Use hash-set to verify uniqueness. In the best scenario, the
@@ -533,6 +577,7 @@ inline void UniformPick(
             offset, offset + num_neighbors));
       }
       std::copy(picked_set.begin(), picked_set.end(), picked_data_ptr);
+      return picked_set.size();
     }
   }
 }
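
For reference, the fanout >= 64 branch above rejects duplicates with a hash set. A minimal standalone version of that strategy (the RNG is illustrative; the real code draws candidates with torch::randint, and fanout <= n must hold for the loop to terminate):

#include <cstdint>
#include <random>
#include <unordered_set>
#include <vector>

// Sample `fanout` distinct values from [offset, offset + n) by rejecting
// duplicates into a hash set. Assumes fanout <= n.
std::vector<int64_t> SampleDistinct(int64_t offset, int64_t n, int64_t fanout) {
  std::mt19937_64 rng{std::random_device{}()};
  std::uniform_int_distribution<int64_t> dist(offset, offset + n - 1);
  std::unordered_set<int64_t> picked;
  while (static_cast<int64_t>(picked.size()) < fanout) picked.insert(dist(rng));
  return {picked.begin(), picked.end()};
}
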
@@ -573,7 +618,7 @@ inline void UniformPick(
  * should be put. Enough memory space should be allocated in advance.
  */
 template <typename PickedType>
-inline void NonUniformPick(
+inline int64_t NonUniformPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
@@ -582,12 +627,13 @@ inline void NonUniformPick(
       probs_or_mask.value().slice(0, offset, offset + num_neighbors);
   auto positive_probs_indices = local_probs.nonzero().squeeze(1);
   auto num_positive_probs = positive_probs_indices.size(0);
-  if (num_positive_probs == 0) return;
+  if (num_positive_probs == 0) return 0;
   if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
     std::memcpy(
         picked_data_ptr,
         (positive_probs_indices + offset).data_ptr<PickedType>(),
         num_positive_probs * sizeof(PickedType));
+    return num_positive_probs;
   } else {
     if (!replace) fanout = std::min(fanout, num_positive_probs);
     std::memcpy(
@@ -595,27 +641,28 @@ inline void NonUniformPick(
         (torch::multinomial(local_probs, fanout, replace) + offset)
             .data_ptr<PickedType>(),
         fanout * sizeof(PickedType));
+    return fanout;
   }
 }
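
The weighted branch of NonUniformPick delegates to torch::multinomial on the probability slice and shifts the sampled positions by offset to obtain global edge ids. A hedged LibTorch sketch that mirrors the shape of this logic, not the exact code:

#include <torch/torch.h>

// Sample `fanout` positions from a probability slice, then shift by
// `offset` to turn local positions into global edge ids.
torch::Tensor WeightedPick(
    const torch::Tensor& probs, int64_t offset, int64_t num_neighbors,
    int64_t fanout, bool replace) {
  auto local_probs = probs.slice(0, offset, offset + num_neighbors);
  return torch::multinomial(local_probs, fanout, replace) + offset;
}
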
 
 template <typename PickedType>
-void Pick(
+int64_t Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
     SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
   if (probs_or_mask.has_value()) {
-    NonUniformPick(
+    return NonUniformPick(
         offset, num_neighbors, fanout, replace, options, probs_or_mask,
         picked_data_ptr);
   } else {
-    UniformPick(
+    return UniformPick(
         offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   }
 }
 
 template <SamplerType S, typename PickedType>
-void PickByEtype(
+int64_t PickByEtype(
     int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
     bool replace, const torch::TensorOptions& options,
     const torch::Tensor& type_per_edge,
@@ -623,11 +670,11 @@ void PickByEtype(
     PickedType* picked_data_ptr) {
   int64_t etype_begin = offset;
   int64_t etype_end = offset;
+  int64_t pick_offset = 0;
   AT_DISPATCH_INTEGRAL_TYPES(
       type_per_edge.scalar_type(), "PickByEtype", ([&] {
         const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
         const auto end = offset + num_neighbors;
-        int64_t pick_offset = 0;
         while (etype_begin < end) {
           scalar_t etype = type_per_edge_data[etype_begin];
           TORCH_CHECK(
@@ -638,12 +685,9 @@ void PickByEtype(
               type_per_edge_data + etype_begin, type_per_edge_data + end,
               etype);
           etype_end = etype_end_it - type_per_edge_data;
-          int64_t picked_count = NumPick(
-              fanout, replace, probs_or_mask, etype_begin,
-              etype_end - etype_begin);
           // Do sampling for one etype.
           if (fanout != 0) {
-            Pick(
+            int64_t picked_count = Pick(
                 etype_begin, etype_end - etype_begin, fanout, replace, options,
                 probs_or_mask, args, picked_data_ptr + pick_offset);
             pick_offset += picked_count;
@@ -651,43 +695,46 @@ void PickByEtype(
           etype_begin = etype_end;
         }
       }));
+  return pick_offset;
 }
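
PickByEtype walks consecutive runs of equal edge types in the sorted type_per_edge window and now advances pick_offset by each Pick call's return value, so per-etype results pack densely. A toy sketch of that run walk (the actual run-boundary search call sits outside this hunk's context, so the std::find_if here is an assumption):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Walk runs of equal, sorted edge types, advancing a write offset by each
  // run's picked count. Here the "pick" takes the whole run; the real code
  // calls Pick per etype and uses its return value.
  std::vector<int64_t> type_per_edge = {0, 0, 1, 1, 1, 3};
  int64_t begin = 0, pick_offset = 0;
  const int64_t end = static_cast<int64_t>(type_per_edge.size());
  while (begin < end) {
    const int64_t etype = type_per_edge[begin];
    auto run_end = std::find_if(
        type_per_edge.begin() + begin, type_per_edge.end(),
        [etype](int64_t t) { return t != etype; });
    const int64_t run_len = run_end - (type_per_edge.begin() + begin);
    std::cout << "etype " << etype << ": write " << run_len
              << " picks at offset " << pick_offset << '\n';
    pick_offset += run_len;  // return value of Pick in the real code
    begin += run_len;
  }
}
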
 
 template <typename PickedType>
-void Pick(
+int64_t Pick(
     int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
     SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
-  if (fanout == 0) return;
+  if (fanout == 0) return 0;
   if (probs_or_mask.has_value()) {
     if (fanout < 0) {
-      NonUniformPick(
+      return NonUniformPick(
           offset, num_neighbors, fanout, replace, options, probs_or_mask,
           picked_data_ptr);
     } else {
+      int64_t picked_count;
       AT_DISPATCH_FLOATING_TYPES(
           probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
             if (replace) {
-              LaborPick<true, true, scalar_t>(
+              picked_count = LaborPick<true, true, scalar_t>(
                   offset, num_neighbors, fanout, options, probs_or_mask, args,
                   picked_data_ptr);
             } else {
-              LaborPick<true, false, scalar_t>(
+              picked_count = LaborPick<true, false, scalar_t>(
                   offset, num_neighbors, fanout, options, probs_or_mask, args,
                   picked_data_ptr);
             }
           }));
+      return picked_count;
     }
   } else if (fanout < 0) {
-    UniformPick(
+    return UniformPick(
         offset, num_neighbors, fanout, replace, options, picked_data_ptr);
   } else if (replace) {
-    LaborPick<false, true>(
+    return LaborPick<false, true>(
         offset, num_neighbors, fanout, options,
         /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   } else {  // replace = false
-    LaborPick<false, false>(
+    return LaborPick<false, false>(
         offset, num_neighbors, fanout, options,
         /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
   }
@@ -724,7 +771,7 @@ inline void safe_divide(T& a, U b) {
  */
 template <
     bool NonUniform, bool Replace, typename ProbsType, typename PickedType>
-inline void LaborPick(
+inline int64_t LaborPick(
     int64_t offset, int64_t num_neighbors, int64_t fanout,
     const torch::TensorOptions& options,
     const torch::optional<torch::Tensor>& probs_or_mask,
@@ -732,7 +779,7 @@ inline void LaborPick(
   fanout = Replace ? fanout : std::min(fanout, num_neighbors);
   if (!NonUniform && !Replace && fanout >= num_neighbors) {
     std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
-    return;
+    return num_neighbors;
   }
   torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
   // Assuming max_degree of a vertex is <= 4 billion.
@@ -862,6 +909,7 @@ inline void LaborPick(
   TORCH_CHECK(
       !Replace || num_sampled == fanout || num_sampled == 0,
       "Sampling with replacement should sample exactly fanout neighbors or 0!");
+  return num_sampled;
 }
 
 }  // namespace sampling
...