"...source/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "91b05e2ec78e44856d90f4258f91d56807227bac"
Unverified Commit 6b140f28 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Hetero CPU sampling bug fix. (#7369)

parent 0d9a09df
...@@ -557,7 +557,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -557,7 +557,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
// it equals to `num_seeds`. // it equals to `num_seeds`.
const int64_t num_rows = etype_id_to_num_picked_offset[num_etypes]; const int64_t num_rows = etype_id_to_num_picked_offset[num_etypes];
torch::Tensor num_picked_neighbors_per_node = torch::Tensor num_picked_neighbors_per_node =
torch::empty({num_rows}, indptr_options); // Need to use zeros because all nodes don't have all etypes.
torch::zeros({num_rows}, indptr_options);
AT_DISPATCH_INDEX_TYPES( AT_DISPATCH_INDEX_TYPES(
indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] { indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
...@@ -571,14 +572,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -571,14 +572,6 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
num_picked_neighbors_data_ptr[0] = 0; num_picked_neighbors_data_ptr[0] = 0;
const auto seeds_data_ptr = seeds.data_ptr<seeds_t>(); const auto seeds_data_ptr = seeds.data_ptr<seeds_t>();
// Initialize the empty spots in `num_picked_neighbors_per_node`.
if (hetero_with_seed_offsets) {
for (auto i = 0; i < num_etypes; ++i) {
num_picked_neighbors_data_ptr
[etype_id_to_num_picked_offset[i]] = 0;
}
}
// Step 1. Calculate pick number of each node. // Step 1. Calculate pick number of each node.
torch::parallel_for( torch::parallel_for(
0, num_seeds, grain_size, [&](int64_t begin, int64_t end) { 0, num_seeds, grain_size, [&](int64_t begin, int64_t end) {
...@@ -612,40 +605,36 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -612,40 +605,36 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
} }
}); });
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr = num_picked_neighbors_per_node.cumsum(
0, indptr_.scalar_type());
auto subgraph_indptr_data_ptr =
subgraph_indptr.data_ptr<indptr_t>();
if (hetero_with_seed_offsets) { if (hetero_with_seed_offsets) {
torch::Tensor num_picked_offset_tensor = torch::Tensor num_picked_offset_tensor =
torch::zeros({num_etypes + 1}, indptr_options); torch::empty({num_etypes + 1}, indptr_options);
const auto num_picked_offset_data_ptr =
num_picked_offset_tensor.data_ptr<indptr_t>();
std::copy(
etype_id_to_num_picked_offset.begin(),
etype_id_to_num_picked_offset.end(),
num_picked_offset_data_ptr);
torch::Tensor substract_offset = torch::Tensor substract_offset =
torch::zeros({num_etypes}, indptr_options); torch::empty({num_etypes}, indptr_options);
const auto substract_offset_data_ptr = const auto substract_offset_data_ptr =
substract_offset.data_ptr<indptr_t>(); substract_offset.data_ptr<indptr_t>();
const auto num_picked_offset_data_ptr =
num_picked_offset_tensor.data_ptr<indptr_t>();
for (auto i = 0; i < num_etypes; ++i) { for (auto i = 0; i < num_etypes; ++i) {
num_picked_offset_data_ptr[i + 1] = // Collect the total pick number subtract offsets.
etype_id_to_num_picked_offset[i + 1]; substract_offset_data_ptr[i] = subgraph_indptr_data_ptr
// Collect the total pick number for each edge type. [etype_id_to_num_picked_offset[i]];
if (i + 1 < num_etypes)
substract_offset_data_ptr[i + 1] =
num_picked_neighbors_data_ptr
[etype_id_to_num_picked_offset[i]];
num_picked_neighbors_data_ptr
[etype_id_to_num_picked_offset[i]] = 0;
} }
substract_offset =
substract_offset.cumsum(0, indptr_.scalar_type());
subgraph_indptr_substract = ops::ExpandIndptr( subgraph_indptr_substract = ops::ExpandIndptr(
num_picked_offset_tensor, indptr_.scalar_type(), num_picked_offset_tensor, indptr_.scalar_type(),
substract_offset); substract_offset);
} }
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr = num_picked_neighbors_per_node.cumsum(
0, indptr_.scalar_type());
auto subgraph_indptr_data_ptr =
subgraph_indptr.data_ptr<indptr_t>();
// When doing non-temporal hetero sampling, we generate an // When doing non-temporal hetero sampling, we generate an
// edge_offsets tensor. // edge_offsets tensor.
if (hetero_with_seed_offsets) { if (hetero_with_seed_offsets) {
...@@ -1277,11 +1266,6 @@ void NumPickByEtype( ...@@ -1277,11 +1266,6 @@ void NumPickByEtype(
NumPick( NumPick(
fanouts[etype], replace, probs_or_mask, etype_begin, fanouts[etype], replace, probs_or_mask, etype_begin,
etype_end - etype_begin, num_picked_ptr + offset); etype_end - etype_begin, num_picked_ptr + offset);
// Use the skipped position of each edge type in the
// num_picked_tensor to sum up the total pick number for each edge
// type.
num_picked_ptr[etype_id_to_num_picked_offset[etype] - 1] +=
num_picked_ptr[offset];
} else { } else {
PickedNumType picked_count = 0; PickedNumType picked_count = 0;
NumPick( NumPick(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment