Unverified Commit b1153db9 authored by Ramon Zhou, committed by GitHub

[Graphbolt] Rewrite sampling process to eliminate torch::cat (#6152)
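
This restructures SampleNeighborsImpl into two passes over the seed nodes so that
the per-thread buffers and the final torch::cat concatenation can be dropped. The
first parallel pass only counts how many neighbors each node will pick, a
cumulative sum of those counts becomes the subgraph indptr and the total output
length, the output tensors are allocated once, and a second parallel pass writes
each node's picks directly at its prefix-sum offset. To support this, Pick,
PickByEtype, UniformPick, NonUniformPick and LaborPick now return the number of
elements they actually wrote, which is cross-checked against the precomputed
counts.

A minimal sketch of the two-pass pattern follows; the helper names (CountPicks,
WritePicks, SampleWithoutCat) are illustrative and not the actual GraphBolt API:

#include <ATen/Parallel.h>
#include <torch/torch.h>

#include <cstdint>
#include <functional>

// Pass 1 counts, pass 2 writes in place: no per-thread buffers, no torch::cat.
torch::Tensor SampleWithoutCat(
    int64_t num_nodes,
    const std::function<int64_t(int64_t)>& CountPicks,
    const std::function<int64_t(int64_t, int64_t*)>& WritePicks) {
  // Pass 1: per-node pick counts; index 0 stays 0 so cumsum yields offsets.
  torch::Tensor counts = torch::zeros({num_nodes + 1}, torch::kInt64);
  auto counts_ptr = counts.data_ptr<int64_t>();
  torch::parallel_for(0, num_nodes, 64, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) counts_ptr[i + 1] = CountPicks(i);
  });
  // Prefix sum: per-node output offsets plus the total output length.
  torch::Tensor indptr = torch::cumsum(counts, 0);
  auto indptr_ptr = indptr.data_ptr<int64_t>();
  torch::Tensor picked = torch::empty({indptr_ptr[num_nodes]}, torch::kInt64);
  auto picked_ptr = picked.data_ptr<int64_t>();
  // Pass 2: each node writes its picks at its own offset, so the results are
  // already contiguous and no concatenation step is needed.
  torch::parallel_for(0, num_nodes, 64, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      const int64_t written = WritePicks(i, picked_ptr + indptr_ptr[i]);
      TORCH_CHECK(written == counts_ptr[i + 1], "pick count mismatch");
    }
  });
  return picked;
}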

parent 64df37f7
@@ -357,14 +357,14 @@ int64_t NumPickByEtype(
* should be put. Enough memory space should be allocated in advance.
*/
template <typename PickedType>
void Pick(
int64_t Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr);
template <typename PickedType>
void Pick(
int64_t Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
@@ -398,7 +398,7 @@ void Pick(
* should be put. Enough memory space should be allocated in advance.
*/
template <SamplerType S, typename PickedType>
void PickByEtype(
int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
@@ -408,7 +408,7 @@ void PickByEtype(
template <
bool NonUniform, bool Replace, typename ProbsType = float,
typename PickedType>
void LaborPick(
int64_t LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
@@ -222,28 +222,30 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
const torch::Tensor& nodes, bool return_eids, NumPickFn num_pick_fn,
PickFn pick_fn) const {
const int64_t num_nodes = nodes.size(0);
const int64_t num_threads = torch::get_num_threads();
std::vector<torch::Tensor> picked_neighbors_per_thread(num_threads);
const auto indptr_options = indptr_.options();
torch::Tensor num_picked_neighbors_per_node =
torch::zeros({num_nodes + 1}, indptr_.options());
torch::empty({num_nodes + 1}, indptr_options);
// Calculate GrainSize for parallel_for.
// Set the default grain size to 64.
const int64_t grain_size = 64;
torch::Tensor picked_eids;
torch::Tensor subgraph_indptr;
torch::Tensor subgraph_indices;
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
AT_DISPATCH_INTEGRAL_TYPES(
indptr_.scalar_type(), "parallel_for", ([&] {
indptr_.scalar_type(), "SampleNeighborsImpl", ([&] {
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<scalar_t>();
num_picked_neighbors_data_ptr[0] = 0;
const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
// Step 1. Calculate pick number of each node.
torch::parallel_for(
0, num_nodes, grain_size, [&](scalar_t begin, scalar_t end) {
const auto indptr_options = indptr_.options();
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
// Get current thread id.
auto thread_id = torch::get_thread_num();
int64_t local_grain_size = end - begin;
std::vector<torch::Tensor> picked_neighbors_cur_thread(
local_grain_size);
const auto nodes_data_ptr = nodes.data_ptr<int64_t>();
for (scalar_t i = begin; i < end; ++i) {
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
@@ -252,49 +254,82 @@ c10::intrusive_ptr<SampledSubgraph> CSCSamplingGraph::SampleNeighborsImpl(
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
if (num_neighbors == 0) {
// To avoid crashing during concatenation in the master
// thread, initializing with empty tensors.
picked_neighbors_cur_thread[i - begin] =
torch::tensor({}, indptr_options);
continue;
}
num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0 ? 0 : num_pick_fn(offset, num_neighbors);
}
});
// Step 2. Calculate prefix sum to get total length and offsets of each
// node. It's also the indptr of the generated subgraph.
subgraph_indptr = torch::cumsum(num_picked_neighbors_per_node, 0);
// Step 3. Allocate the tensor for picked neighbors.
const auto total_length =
subgraph_indptr.data_ptr<scalar_t>()[num_nodes];
picked_eids = torch::empty({total_length}, indptr_options);
subgraph_indices = torch::empty({total_length}, indices_.options());
if (type_per_edge_.has_value()) {
subgraph_type_per_edge =
torch::empty({total_length}, type_per_edge_.value().options());
}
// Pre-allocate tensors for each node. Because the pick
// functions are modified, this part of code needed refactoring
// to adapt to the change of APIs. It's temporary since the
// whole process will be rewritten soon.
int64_t allocate_size = num_pick_fn(offset, num_neighbors);
picked_neighbors_cur_thread[i - begin] =
torch::empty({allocate_size}, indptr_options);
torch::Tensor& picked_tensor =
picked_neighbors_cur_thread[i - begin];
AT_DISPATCH_INTEGRAL_TYPES(
picked_tensor.scalar_type(), "CallPick", ([&] {
pick_fn(
offset, num_neighbors,
picked_tensor.data_ptr<scalar_t>());
}));
num_picked_neighbors_per_node[i + 1] = allocate_size;
// Step 4. Pick neighbors for each node.
auto picked_eids_data_ptr = picked_eids.data_ptr<scalar_t>();
auto subgraph_indptr_data_ptr = subgraph_indptr.data_ptr<scalar_t>();
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto picked_number = num_picked_neighbors_data_ptr[i + 1];
const auto picked_offset = subgraph_indptr_data_ptr[i];
if (picked_number > 0) {
auto actual_picked_count = pick_fn(
offset, num_neighbors,
picked_eids_data_ptr + picked_offset);
TORCH_CHECK(
actual_picked_count == picked_number,
"Actual picked count doesn't match the calculated pick "
"number.");
// Step 5. Calculate other attributes and return the subgraph.
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_indices.scalar_type(),
"IndexSelectSubgraphIndices", ([&] {
auto subgraph_indices_data_ptr =
subgraph_indices.data_ptr<scalar_t>();
auto indices_data_ptr = indices_.data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_indices_data_ptr[i] =
indices_data_ptr[picked_eids_data_ptr[i]];
}
}));
if (type_per_edge_.has_value()) {
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_type_per_edge.value().scalar_type(),
"IndexSelectTypePerEdge", ([&] {
auto subgraph_type_per_edge_data_ptr =
subgraph_type_per_edge.value()
.data_ptr<scalar_t>();
auto type_per_edge_data_ptr =
type_per_edge_.value().data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_type_per_edge_data_ptr[i] =
type_per_edge_data_ptr[picked_eids_data_ptr[i]];
}
}));
}
}
}
picked_neighbors_per_thread[thread_id] =
torch::cat(picked_neighbors_cur_thread);
}); // End of parallel_for.
});
}));
torch::Tensor subgraph_indptr =
torch::cumsum(num_picked_neighbors_per_node, 0);
torch::Tensor picked_eids = torch::cat(picked_neighbors_per_thread);
torch::Tensor subgraph_indices =
torch::index_select(indices_, 0, picked_eids);
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
if (type_per_edge_.has_value()) {
subgraph_type_per_edge =
torch::index_select(type_per_edge_.value(), 0, picked_eids);
}
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
return c10::make_intrusive<SampledSubgraph>(
subgraph_indptr, subgraph_indices, nodes, torch::nullopt,
subgraph_reverse_edge_ids, subgraph_type_per_edge);
@@ -383,12 +418,17 @@ int64_t NumPick(
int64_t fanout, bool replace,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors) {
int64_t num_valid_neighbors =
probs_or_mask.has_value()
? *torch::count_nonzero(
probs_or_mask.value().slice(0, offset, offset + num_neighbors))
.data_ptr<int64_t>()
: num_neighbors;
int64_t num_valid_neighbors = num_neighbors;
if (probs_or_mask.has_value()) {
// Subtract the count of zeros in probs_or_mask.
AT_DISPATCH_ALL_TYPES(
probs_or_mask.value().scalar_type(), "CountZero", ([&] {
scalar_t* probs_data_ptr = probs_or_mask.value().data_ptr<scalar_t>();
num_valid_neighbors -= std::count(
probs_data_ptr + offset, probs_data_ptr + offset + num_neighbors,
0);
}));
}
if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors;
return replace ? fanout : std::min(fanout, num_valid_neighbors);
}
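
The rewritten NumPick counts zeros in the probs_or_mask slice with std::count
instead of materializing a count via torch::count_nonzero. A standalone sketch of
the counting idea, assuming a plain float vector and a hypothetical helper name
(not the GraphBolt API):

#include <algorithm>
#include <cstdint>
#include <vector>

// Valid (pickable) neighbors = slice length minus the zero entries in the
// probability/mask slice; an entry equal to 0 means the edge cannot be picked.
int64_t CountValidNeighbors(
    const std::vector<float>& probs_or_mask, int64_t offset,
    int64_t num_neighbors) {
  const int64_t num_zeros = std::count(
      probs_or_mask.begin() + offset,
      probs_or_mask.begin() + offset + num_neighbors, 0.0f);
  return num_neighbors - num_zeros;
}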
@@ -444,17 +484,19 @@ int64_t NumPickByEtype(
* should be put. Enough memory space should be allocated in advance.
*/
template <typename PickedType>
inline void UniformPick(
inline int64_t UniformPick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options, PickedType* picked_data_ptr) {
if ((fanout == -1) || (num_neighbors <= fanout && !replace)) {
std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
return num_neighbors;
} else if (replace) {
std::memcpy(
picked_data_ptr,
torch::randint(offset, offset + num_neighbors, {fanout}, options)
.data_ptr<PickedType>(),
fanout * sizeof(PickedType));
return fanout;
} else {
    // We use different sampling strategies for different sampling cases.
if (fanout >= num_neighbors / 10) {
@@ -490,6 +532,7 @@ inline void UniformPick(
}
// Save the randomly sampled fanout elements to the output tensor.
std::copy(seq.begin(), seq.begin() + fanout, picked_data_ptr);
return fanout;
} else if (fanout < 64) {
// [Algorithm]
// Use linear search to verify uniqueness.
@@ -510,6 +553,7 @@ inline void UniformPick(
auto it = std::find(picked_data_ptr, begin, *begin);
if (it == begin) ++begin;
}
return fanout;
} else {
// [Algorithm]
// Use hash-set to verify uniqueness. In the best scenario, the
@@ -533,6 +577,7 @@ inline void UniformPick(
offset, offset + num_neighbors));
}
std::copy(picked_set.begin(), picked_set.end(), picked_data_ptr);
return picked_set.size();
}
}
}
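
The hash-set branch of UniformPick keeps inserting uniformly drawn edge indices
into an unordered_set until fanout unique picks are collected. A standalone
sketch of that strategy, using a hypothetical helper name and std::mt19937 in
place of the library's RNG:

#include <cstdint>
#include <random>
#include <unordered_set>
#include <vector>

// Draw `fanout` distinct indices uniformly from [offset, offset + num_neighbors).
// Assumes fanout <= num_neighbors, as the caller clamps it when not replacing.
std::vector<int64_t> SampleUniqueWithHashSet(
    int64_t offset, int64_t num_neighbors, int64_t fanout, std::mt19937& rng) {
  std::uniform_int_distribution<int64_t> dist(
      offset, offset + num_neighbors - 1);
  std::unordered_set<int64_t> picked;
  while (static_cast<int64_t>(picked.size()) < fanout) picked.insert(dist(rng));
  return std::vector<int64_t>(picked.begin(), picked.end());
}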
@@ -573,7 +618,7 @@ inline void UniformPick(
* should be put. Enough memory space should be allocated in advance.
*/
template <typename PickedType>
inline void NonUniformPick(
inline int64_t NonUniformPick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
@@ -582,12 +627,13 @@ inline void NonUniformPick(
probs_or_mask.value().slice(0, offset, offset + num_neighbors);
auto positive_probs_indices = local_probs.nonzero().squeeze(1);
auto num_positive_probs = positive_probs_indices.size(0);
if (num_positive_probs == 0) return;
if (num_positive_probs == 0) return 0;
if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) {
std::memcpy(
picked_data_ptr,
(positive_probs_indices + offset).data_ptr<PickedType>(),
num_positive_probs * sizeof(PickedType));
return num_positive_probs;
} else {
if (!replace) fanout = std::min(fanout, num_positive_probs);
std::memcpy(
@@ -595,27 +641,28 @@ inline void NonUniformPick(
(torch::multinomial(local_probs, fanout, replace) + offset)
.data_ptr<PickedType>(),
fanout * sizeof(PickedType));
return fanout;
}
}
template <typename PickedType>
void Pick(
int64_t Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::NEIGHBOR> args, PickedType* picked_data_ptr) {
if (probs_or_mask.has_value()) {
NonUniformPick(
return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask,
picked_data_ptr);
} else {
UniformPick(
return UniformPick(
offset, num_neighbors, fanout, replace, options, picked_data_ptr);
}
}
template <SamplerType S, typename PickedType>
void PickByEtype(
int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
bool replace, const torch::TensorOptions& options,
const torch::Tensor& type_per_edge,
@@ -623,11 +670,11 @@ void PickByEtype(
PickedType* picked_data_ptr) {
int64_t etype_begin = offset;
int64_t etype_end = offset;
int64_t pick_offset = 0;
AT_DISPATCH_INTEGRAL_TYPES(
type_per_edge.scalar_type(), "PickByEtype", ([&] {
const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
const auto end = offset + num_neighbors;
int64_t pick_offset = 0;
while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
TORCH_CHECK(
@@ -638,12 +685,9 @@ void PickByEtype(
type_per_edge_data + etype_begin, type_per_edge_data + end,
etype);
etype_end = etype_end_it - type_per_edge_data;
int64_t picked_count = NumPick(
fanout, replace, probs_or_mask, etype_begin,
etype_end - etype_begin);
// Do sampling for one etype.
if (fanout != 0) {
Pick(
int64_t picked_count = Pick(
etype_begin, etype_end - etype_begin, fanout, replace, options,
probs_or_mask, args, picked_data_ptr + pick_offset);
pick_offset += picked_count;
@@ -651,43 +695,46 @@ void PickByEtype(
etype_begin = etype_end;
}
}));
return pick_offset;
}
template <typename PickedType>
void Pick(
int64_t Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr) {
if (fanout == 0) return;
if (fanout == 0) return 0;
if (probs_or_mask.has_value()) {
if (fanout < 0) {
NonUniformPick(
return NonUniformPick(
offset, num_neighbors, fanout, replace, options, probs_or_mask,
picked_data_ptr);
} else {
int64_t picked_count;
AT_DISPATCH_FLOATING_TYPES(
probs_or_mask.value().scalar_type(), "LaborPickFloatType", ([&] {
if (replace) {
LaborPick<true, true, scalar_t>(
picked_count = LaborPick<true, true, scalar_t>(
offset, num_neighbors, fanout, options, probs_or_mask, args,
picked_data_ptr);
} else {
LaborPick<true, false, scalar_t>(
picked_count = LaborPick<true, false, scalar_t>(
offset, num_neighbors, fanout, options, probs_or_mask, args,
picked_data_ptr);
}
}));
return picked_count;
}
} else if (fanout < 0) {
UniformPick(
return UniformPick(
offset, num_neighbors, fanout, replace, options, picked_data_ptr);
} else if (replace) {
LaborPick<false, true>(
return LaborPick<false, true>(
offset, num_neighbors, fanout, options,
/* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
} else { // replace = false
LaborPick<false, false>(
return LaborPick<false, false>(
offset, num_neighbors, fanout, options,
/* probs_or_mask= */ torch::nullopt, args, picked_data_ptr);
}
@@ -724,7 +771,7 @@ inline void safe_divide(T& a, U b) {
*/
template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType>
inline void LaborPick(
inline int64_t LaborPick(
int64_t offset, int64_t num_neighbors, int64_t fanout,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
@@ -732,7 +779,7 @@ inline void LaborPick(
fanout = Replace ? fanout : std::min(fanout, num_neighbors);
if (!NonUniform && !Replace && fanout >= num_neighbors) {
std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset);
return;
return num_neighbors;
}
torch::Tensor heap_tensor = torch::empty({fanout * 2}, torch::kInt32);
// Assuming max_degree of a vertex is <= 4 billion.
@@ -862,6 +909,7 @@ inline void LaborPick(
TORCH_CHECK(
!Replace || num_sampled == fanout || num_sampled == 0,
"Sampling with replacement should sample exactly fanout neighbors or 0!");
return num_sampled;
}
} // namespace sampling