Unverified Commit 6b140f28 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Hetero CPU sampling bug fix. (#7369)

parent 0d9a09df
......@@ -557,7 +557,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
// it equals to `num_seeds`.
const int64_t num_rows = etype_id_to_num_picked_offset[num_etypes];
torch::Tensor num_picked_neighbors_per_node =
torch::empty({num_rows}, indptr_options);
// Need to use zeros because all nodes don't have all etypes.
torch::zeros({num_rows}, indptr_options);
AT_DISPATCH_INDEX_TYPES(
indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
......@@ -571,14 +572,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
num_picked_neighbors_data_ptr[0] = 0;
const auto seeds_data_ptr = seeds.data_ptr<seeds_t>();
// Initialize the empty spots in `num_picked_neighbors_per_node`.
if (hetero_with_seed_offsets) {
for (auto i = 0; i < num_etypes; ++i) {
num_picked_neighbors_data_ptr
[etype_id_to_num_picked_offset[i]] = 0;
}
}
// Step 1. Calculate pick number of each node.
torch::parallel_for(
0, num_seeds, grain_size, [&](int64_t begin, int64_t end) {
......@@ -612,40 +605,36 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
}
});
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr = num_picked_neighbors_per_node.cumsum(
0, indptr_.scalar_type());
auto subgraph_indptr_data_ptr =
subgraph_indptr.data_ptr<indptr_t>();
if (hetero_with_seed_offsets) {
torch::Tensor num_picked_offset_tensor =
torch::zeros({num_etypes + 1}, indptr_options);
torch::empty({num_etypes + 1}, indptr_options);
const auto num_picked_offset_data_ptr =
num_picked_offset_tensor.data_ptr<indptr_t>();
std::copy(
etype_id_to_num_picked_offset.begin(),
etype_id_to_num_picked_offset.end(),
num_picked_offset_data_ptr);
torch::Tensor substract_offset =
torch::zeros({num_etypes}, indptr_options);
torch::empty({num_etypes}, indptr_options);
const auto substract_offset_data_ptr =
substract_offset.data_ptr<indptr_t>();
const auto num_picked_offset_data_ptr =
num_picked_offset_tensor.data_ptr<indptr_t>();
for (auto i = 0; i < num_etypes; ++i) {
num_picked_offset_data_ptr[i + 1] =
etype_id_to_num_picked_offset[i + 1];
// Collect the total pick number for each edge type.
if (i + 1 < num_etypes)
substract_offset_data_ptr[i + 1] =
num_picked_neighbors_data_ptr
[etype_id_to_num_picked_offset[i]];
num_picked_neighbors_data_ptr
[etype_id_to_num_picked_offset[i]] = 0;
// Collect the total pick number subtract offsets.
substract_offset_data_ptr[i] = subgraph_indptr_data_ptr
[etype_id_to_num_picked_offset[i]];
}
substract_offset =
substract_offset.cumsum(0, indptr_.scalar_type());
subgraph_indptr_substract = ops::ExpandIndptr(
num_picked_offset_tensor, indptr_.scalar_type(),
substract_offset);
}
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr = num_picked_neighbors_per_node.cumsum(
0, indptr_.scalar_type());
auto subgraph_indptr_data_ptr =
subgraph_indptr.data_ptr<indptr_t>();
// When doing non-temporal hetero sampling, we generate an
// edge_offsets tensor.
if (hetero_with_seed_offsets) {
......@@ -1277,11 +1266,6 @@ void NumPickByEtype(
NumPick(
fanouts[etype], replace, probs_or_mask, etype_begin,
etype_end - etype_begin, num_picked_ptr + offset);
// Use the skipped position of each edge type in the
// num_picked_tensor to sum up the total pick number for each edge
// type.
num_picked_ptr[etype_id_to_num_picked_offset[etype] - 1] +=
num_picked_ptr[offset];
} else {
PickedNumType picked_count = 0;
NumPick(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment