Unverified Commit 2bc4df22 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] replace item<> with raw pointer (#6601)

parent 33e80452
...@@ -330,96 +330,112 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( ...@@ -330,96 +330,112 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt; torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
AT_DISPATCH_INTEGRAL_TYPES( AT_DISPATCH_INTEGRAL_TYPES(
indptr_.scalar_type(), "SampleNeighborsImpl", ([&] { indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>(); using indptr_t = scalar_t;
auto num_picked_neighbors_data_ptr = AT_DISPATCH_INTEGRAL_TYPES(
num_picked_neighbors_per_node.data_ptr<scalar_t>(); nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] {
num_picked_neighbors_data_ptr[0] = 0; using nodes_t = scalar_t;
const auto indptr_data = indptr_.data_ptr<indptr_t>();
auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<indptr_t>();
num_picked_neighbors_data_ptr[0] = 0;
const auto nodes_data_ptr = nodes.data_ptr<nodes_t>();
// Step 1. Calculate pick number of each node. // Step 1. Calculate pick number of each node.
torch::parallel_for( torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) { 0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes[i].item<int64_t>(); const auto nid = nodes_data_ptr[i];
TORCH_CHECK( TORCH_CHECK(
nid >= 0 && nid < NumNodes(), nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the " "The seed nodes' IDs should fall within the range of "
"graph's node IDs."); "the "
const auto offset = indptr_data[nid]; "graph's node IDs.");
const auto num_neighbors = indptr_data[nid + 1] - offset; const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
num_picked_neighbors_data_ptr[i + 1] = num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0 ? 0 : num_pick_fn(offset, num_neighbors); num_neighbors == 0
} ? 0
}); : num_pick_fn(offset, num_neighbors);
}
});
// Step 2. Calculate prefix sum to get total length and offsets of each // Step 2. Calculate prefix sum to get total length and offsets of
// node. It's also the indptr of the generated subgraph. // each node. It's also the indptr of the generated subgraph.
subgraph_indptr = subgraph_indptr = num_picked_neighbors_per_node.cumsum(
num_picked_neighbors_per_node.cumsum(0, indptr_.scalar_type()); 0, indptr_.scalar_type());
// Step 3. Allocate the tensor for picked neighbors. // Step 3. Allocate the tensor for picked neighbors.
const auto total_length = const auto total_length =
subgraph_indptr.data_ptr<scalar_t>()[num_nodes]; subgraph_indptr.data_ptr<indptr_t>()[num_nodes];
picked_eids = torch::empty({total_length}, indptr_options); picked_eids = torch::empty({total_length}, indptr_options);
subgraph_indices = torch::empty({total_length}, indices_.options()); subgraph_indices =
if (type_per_edge_.has_value()) { torch::empty({total_length}, indices_.options());
subgraph_type_per_edge = if (type_per_edge_.has_value()) {
torch::empty({total_length}, type_per_edge_.value().options()); subgraph_type_per_edge = torch::empty(
} {total_length}, type_per_edge_.value().options());
}
// Step 4. Pick neighbors for each node. // Step 4. Pick neighbors for each node.
auto picked_eids_data_ptr = picked_eids.data_ptr<scalar_t>(); auto picked_eids_data_ptr = picked_eids.data_ptr<indptr_t>();
auto subgraph_indptr_data_ptr = subgraph_indptr.data_ptr<scalar_t>(); auto subgraph_indptr_data_ptr =
torch::parallel_for( subgraph_indptr.data_ptr<indptr_t>();
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) { torch::parallel_for(
for (int64_t i = begin; i < end; ++i) { 0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
const auto nid = nodes[i].item<int64_t>(); for (int64_t i = begin; i < end; ++i) {
const auto offset = indptr_data[nid]; const auto nid = nodes_data_ptr[i];
const auto num_neighbors = indptr_data[nid + 1] - offset; const auto offset = indptr_data[nid];
const auto picked_number = num_picked_neighbors_data_ptr[i + 1]; const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto picked_offset = subgraph_indptr_data_ptr[i]; const auto picked_number =
if (picked_number > 0) { num_picked_neighbors_data_ptr[i + 1];
auto actual_picked_count = pick_fn( const auto picked_offset = subgraph_indptr_data_ptr[i];
offset, num_neighbors, if (picked_number > 0) {
picked_eids_data_ptr + picked_offset); auto actual_picked_count = pick_fn(
TORCH_CHECK( offset, num_neighbors,
actual_picked_count == picked_number, picked_eids_data_ptr + picked_offset);
"Actual picked count doesn't match the calculated pick " TORCH_CHECK(
"number."); actual_picked_count == picked_number,
"Actual picked count doesn't match the calculated "
"pick "
"number.");
// Step 5. Calculate other attributes and return the subgraph. // Step 5. Calculate other attributes and return the
AT_DISPATCH_INTEGRAL_TYPES( // subgraph.
subgraph_indices.scalar_type(), AT_DISPATCH_INTEGRAL_TYPES(
"IndexSelectSubgraphIndices", ([&] { subgraph_indices.scalar_type(),
auto subgraph_indices_data_ptr = "IndexSelectSubgraphIndices", ([&] {
subgraph_indices.data_ptr<scalar_t>(); auto subgraph_indices_data_ptr =
auto indices_data_ptr = indices_.data_ptr<scalar_t>(); subgraph_indices.data_ptr<scalar_t>();
for (auto i = picked_offset; auto indices_data_ptr =
i < picked_offset + picked_number; ++i) { indices_.data_ptr<scalar_t>();
subgraph_indices_data_ptr[i] = for (auto i = picked_offset;
indices_data_ptr[picked_eids_data_ptr[i]]; i < picked_offset + picked_number; ++i) {
subgraph_indices_data_ptr[i] =
indices_data_ptr[picked_eids_data_ptr[i]];
}
}));
if (type_per_edge_.has_value()) {
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_type_per_edge.value().scalar_type(),
"IndexSelectTypePerEdge", ([&] {
auto subgraph_type_per_edge_data_ptr =
subgraph_type_per_edge.value()
.data_ptr<scalar_t>();
auto type_per_edge_data_ptr =
type_per_edge_.value().data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_type_per_edge_data_ptr[i] =
type_per_edge_data_ptr
[picked_eids_data_ptr[i]];
}
}));
} }
})); }
if (type_per_edge_.has_value()) { }
AT_DISPATCH_INTEGRAL_TYPES( });
subgraph_type_per_edge.value().scalar_type(), }));
"IndexSelectTypePerEdge", ([&] {
auto subgraph_type_per_edge_data_ptr =
subgraph_type_per_edge.value()
.data_ptr<scalar_t>();
auto type_per_edge_data_ptr =
type_per_edge_.value().data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_type_per_edge_data_ptr[i] =
type_per_edge_data_ptr[picked_eids_data_ptr[i]];
}
}));
}
}
}
});
})); }));
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt; torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment