"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "7444f56819f679f68eee8bf915bbb74be3da0e40"
Unverified Commit 2bc4df22 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] replace item<> with raw pointer (#6601)

parent 33e80452
...@@ -330,54 +330,65 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -330,54 +330,65 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt; torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
AT_DISPATCH_INTEGRAL_TYPES( AT_DISPATCH_INTEGRAL_TYPES(
indptr_.scalar_type(), "SampleNeighborsImpl", ([&] { indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>(); using indptr_t = scalar_t;
AT_DISPATCH_INTEGRAL_TYPES(
nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] {
using nodes_t = scalar_t;
const auto indptr_data = indptr_.data_ptr<indptr_t>();
auto num_picked_neighbors_data_ptr = auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<scalar_t>(); num_picked_neighbors_per_node.data_ptr<indptr_t>();
num_picked_neighbors_data_ptr[0] = 0; num_picked_neighbors_data_ptr[0] = 0;
const auto nodes_data_ptr = nodes.data_ptr<nodes_t>();
// Step 1. Calculate pick number of each node. // Step 1. Calculate pick number of each node.
torch::parallel_for( torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) { 0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes[i].item<int64_t>(); const auto nid = nodes_data_ptr[i];
TORCH_CHECK( TORCH_CHECK(
nid >= 0 && nid < NumNodes(), nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the " "The seed nodes' IDs should fall within the range of "
"the "
"graph's node IDs."); "graph's node IDs.");
const auto offset = indptr_data[nid]; const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset; const auto num_neighbors = indptr_data[nid + 1] - offset;
num_picked_neighbors_data_ptr[i + 1] = num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0 ? 0 : num_pick_fn(offset, num_neighbors); num_neighbors == 0
? 0
: num_pick_fn(offset, num_neighbors);
} }
}); });
// Step 2. Calculate prefix sum to get total length and offsets of each // Step 2. Calculate prefix sum to get total length and offsets of
// node. It's also the indptr of the generated subgraph. // each node. It's also the indptr of the generated subgraph.
subgraph_indptr = subgraph_indptr = num_picked_neighbors_per_node.cumsum(
num_picked_neighbors_per_node.cumsum(0, indptr_.scalar_type()); 0, indptr_.scalar_type());
// Step 3. Allocate the tensor for picked neighbors. // Step 3. Allocate the tensor for picked neighbors.
const auto total_length = const auto total_length =
subgraph_indptr.data_ptr<scalar_t>()[num_nodes]; subgraph_indptr.data_ptr<indptr_t>()[num_nodes];
picked_eids = torch::empty({total_length}, indptr_options); picked_eids = torch::empty({total_length}, indptr_options);
subgraph_indices = torch::empty({total_length}, indices_.options()); subgraph_indices =
torch::empty({total_length}, indices_.options());
if (type_per_edge_.has_value()) { if (type_per_edge_.has_value()) {
subgraph_type_per_edge = subgraph_type_per_edge = torch::empty(
torch::empty({total_length}, type_per_edge_.value().options()); {total_length}, type_per_edge_.value().options());
} }
// Step 4. Pick neighbors for each node. // Step 4. Pick neighbors for each node.
auto picked_eids_data_ptr = picked_eids.data_ptr<scalar_t>(); auto picked_eids_data_ptr = picked_eids.data_ptr<indptr_t>();
auto subgraph_indptr_data_ptr = subgraph_indptr.data_ptr<scalar_t>(); auto subgraph_indptr_data_ptr =
subgraph_indptr.data_ptr<indptr_t>();
torch::parallel_for( torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) { 0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes[i].item<int64_t>(); const auto nid = nodes_data_ptr[i];
const auto offset = indptr_data[nid]; const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset; const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto picked_number = num_picked_neighbors_data_ptr[i + 1]; const auto picked_number =
num_picked_neighbors_data_ptr[i + 1];
const auto picked_offset = subgraph_indptr_data_ptr[i]; const auto picked_offset = subgraph_indptr_data_ptr[i];
if (picked_number > 0) { if (picked_number > 0) {
auto actual_picked_count = pick_fn( auto actual_picked_count = pick_fn(
...@@ -385,16 +396,19 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -385,16 +396,19 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
picked_eids_data_ptr + picked_offset); picked_eids_data_ptr + picked_offset);
TORCH_CHECK( TORCH_CHECK(
actual_picked_count == picked_number, actual_picked_count == picked_number,
"Actual picked count doesn't match the calculated pick " "Actual picked count doesn't match the calculated "
"pick "
"number."); "number.");
// Step 5. Calculate other attributes and return the subgraph. // Step 5. Calculate other attributes and return the
// subgraph.
AT_DISPATCH_INTEGRAL_TYPES( AT_DISPATCH_INTEGRAL_TYPES(
subgraph_indices.scalar_type(), subgraph_indices.scalar_type(),
"IndexSelectSubgraphIndices", ([&] { "IndexSelectSubgraphIndices", ([&] {
auto subgraph_indices_data_ptr = auto subgraph_indices_data_ptr =
subgraph_indices.data_ptr<scalar_t>(); subgraph_indices.data_ptr<scalar_t>();
auto indices_data_ptr = indices_.data_ptr<scalar_t>(); auto indices_data_ptr =
indices_.data_ptr<scalar_t>();
for (auto i = picked_offset; for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) { i < picked_offset + picked_number; ++i) {
subgraph_indices_data_ptr[i] = subgraph_indices_data_ptr[i] =
...@@ -413,7 +427,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -413,7 +427,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
for (auto i = picked_offset; for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) { i < picked_offset + picked_number; ++i) {
subgraph_type_per_edge_data_ptr[i] = subgraph_type_per_edge_data_ptr[i] =
type_per_edge_data_ptr[picked_eids_data_ptr[i]]; type_per_edge_data_ptr
[picked_eids_data_ptr[i]];
} }
})); }));
} }
...@@ -421,6 +436,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -421,6 +436,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
} }
}); });
})); }));
}));
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt; torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids); if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment