Unverified commit 2bc4df22, authored by Rhett Ying and committed by GitHub
Browse files

[GraphBolt] replace item<> with raw pointer (#6601)

parent 33e80452
......@@ -330,96 +330,112 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
AT_DISPATCH_INTEGRAL_TYPES(
indptr_.scalar_type(), "SampleNeighborsImpl", ([&] {
const scalar_t* indptr_data = indptr_.data_ptr<scalar_t>();
auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<scalar_t>();
num_picked_neighbors_data_ptr[0] = 0;
indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
using indptr_t = scalar_t;
AT_DISPATCH_INTEGRAL_TYPES(
nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] {
using nodes_t = scalar_t;
const auto indptr_data = indptr_.data_ptr<indptr_t>();
auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<indptr_t>();
num_picked_neighbors_data_ptr[0] = 0;
const auto nodes_data_ptr = nodes.data_ptr<nodes_t>();
// Step 1. Calculate pick number of each node.
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes[i].item<int64_t>();
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of the "
"graph's node IDs.");
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
// Step 1. Calculate pick number of each node.
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
TORCH_CHECK(
nid >= 0 && nid < NumNodes(),
"The seed nodes' IDs should fall within the range of "
"the "
"graph's node IDs.");
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0 ? 0 : num_pick_fn(offset, num_neighbors);
}
});
num_picked_neighbors_data_ptr[i + 1] =
num_neighbors == 0
? 0
: num_pick_fn(offset, num_neighbors);
}
});
// Step 2. Calculate prefix sum to get total length and offsets of each
// node. It's also the indptr of the generated subgraph.
subgraph_indptr =
num_picked_neighbors_per_node.cumsum(0, indptr_.scalar_type());
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr = num_picked_neighbors_per_node.cumsum(
0, indptr_.scalar_type());
// Step 3. Allocate the tensor for picked neighbors.
const auto total_length =
subgraph_indptr.data_ptr<scalar_t>()[num_nodes];
picked_eids = torch::empty({total_length}, indptr_options);
subgraph_indices = torch::empty({total_length}, indices_.options());
if (type_per_edge_.has_value()) {
subgraph_type_per_edge =
torch::empty({total_length}, type_per_edge_.value().options());
}
// Step 3. Allocate the tensor for picked neighbors.
const auto total_length =
subgraph_indptr.data_ptr<indptr_t>()[num_nodes];
picked_eids = torch::empty({total_length}, indptr_options);
subgraph_indices =
torch::empty({total_length}, indices_.options());
if (type_per_edge_.has_value()) {
subgraph_type_per_edge = torch::empty(
{total_length}, type_per_edge_.value().options());
}
// Step 4. Pick neighbors for each node.
auto picked_eids_data_ptr = picked_eids.data_ptr<scalar_t>();
auto subgraph_indptr_data_ptr = subgraph_indptr.data_ptr<scalar_t>();
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes[i].item<int64_t>();
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto picked_number = num_picked_neighbors_data_ptr[i + 1];
const auto picked_offset = subgraph_indptr_data_ptr[i];
if (picked_number > 0) {
auto actual_picked_count = pick_fn(
offset, num_neighbors,
picked_eids_data_ptr + picked_offset);
TORCH_CHECK(
actual_picked_count == picked_number,
"Actual picked count doesn't match the calculated pick "
"number.");
// Step 4. Pick neighbors for each node.
auto picked_eids_data_ptr = picked_eids.data_ptr<indptr_t>();
auto subgraph_indptr_data_ptr =
subgraph_indptr.data_ptr<indptr_t>();
torch::parallel_for(
0, num_nodes, grain_size, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
const auto nid = nodes_data_ptr[i];
const auto offset = indptr_data[nid];
const auto num_neighbors = indptr_data[nid + 1] - offset;
const auto picked_number =
num_picked_neighbors_data_ptr[i + 1];
const auto picked_offset = subgraph_indptr_data_ptr[i];
if (picked_number > 0) {
auto actual_picked_count = pick_fn(
offset, num_neighbors,
picked_eids_data_ptr + picked_offset);
TORCH_CHECK(
actual_picked_count == picked_number,
"Actual picked count doesn't match the calculated "
"pick "
"number.");
// Step 5. Calculate other attributes and return the subgraph.
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_indices.scalar_type(),
"IndexSelectSubgraphIndices", ([&] {
auto subgraph_indices_data_ptr =
subgraph_indices.data_ptr<scalar_t>();
auto indices_data_ptr = indices_.data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_indices_data_ptr[i] =
indices_data_ptr[picked_eids_data_ptr[i]];
// Step 5. Calculate other attributes and return the
// subgraph.
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_indices.scalar_type(),
"IndexSelectSubgraphIndices", ([&] {
auto subgraph_indices_data_ptr =
subgraph_indices.data_ptr<scalar_t>();
auto indices_data_ptr =
indices_.data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_indices_data_ptr[i] =
indices_data_ptr[picked_eids_data_ptr[i]];
}
}));
if (type_per_edge_.has_value()) {
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_type_per_edge.value().scalar_type(),
"IndexSelectTypePerEdge", ([&] {
auto subgraph_type_per_edge_data_ptr =
subgraph_type_per_edge.value()
.data_ptr<scalar_t>();
auto type_per_edge_data_ptr =
type_per_edge_.value().data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_type_per_edge_data_ptr[i] =
type_per_edge_data_ptr
[picked_eids_data_ptr[i]];
}
}));
}
}));
if (type_per_edge_.has_value()) {
AT_DISPATCH_INTEGRAL_TYPES(
subgraph_type_per_edge.value().scalar_type(),
"IndexSelectTypePerEdge", ([&] {
auto subgraph_type_per_edge_data_ptr =
subgraph_type_per_edge.value()
.data_ptr<scalar_t>();
auto type_per_edge_data_ptr =
type_per_edge_.value().data_ptr<scalar_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_type_per_edge_data_ptr[i] =
type_per_edge_data_ptr[picked_eids_data_ptr[i]];
}
}));
}
}
}
});
}
}
});
}));
}));
torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment