Unverified Commit b3224ce8 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt] Switch to using `AT_DISPATCH_INDEX_TYPES` for graph (#6912)

parent 61504ec5
......@@ -164,16 +164,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto max_in_degree = torch::empty(
1,
c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
AT_DISPATCH_INTEGRAL_TYPES(
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsInDegree", ([&] {
size_t tmp_storage_size = 0;
cub::DeviceReduce::Max(
nullptr, tmp_storage_size, in_degree.data_ptr<scalar_t>(),
max_in_degree.data_ptr<scalar_t>(), num_rows, stream);
nullptr, tmp_storage_size, in_degree.data_ptr<index_t>(),
max_in_degree.data_ptr<index_t>(), num_rows, stream);
auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
cub::DeviceReduce::Max(
tmp_storage.get(), tmp_storage_size, in_degree.data_ptr<scalar_t>(),
max_in_degree.data_ptr<scalar_t>(), num_rows, stream);
tmp_storage.get(), tmp_storage_size, in_degree.data_ptr<index_t>(),
max_in_degree.data_ptr<index_t>(), num_rows, stream);
}));
auto coo_rows = CSRToCOO(sub_indptr, indices.scalar_type());
const auto num_edges = coo_rows.size(0);
......@@ -184,9 +184,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor output_indices;
torch::optional<torch::Tensor> output_type_per_edge;
AT_DISPATCH_INTEGRAL_TYPES(
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
using indptr_t = scalar_t;
using indptr_t = index_t;
thrust::counting_iterator<int64_t> iota(0);
auto sampled_degree = thrust::make_transform_iterator(
iota, MinInDegreeFanout<indptr_t>{
......@@ -234,9 +234,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
allocator.AllocateStorage<edge_id_t>(num_edges);
auto sorted_edge_id_segments =
allocator.AllocateStorage<edge_id_t>(num_edges);
AT_DISPATCH_INTEGRAL_TYPES(
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "SampleNeighborsIndices", ([&] {
using indices_t = scalar_t;
using indices_t = index_t;
auto probs_or_mask_scalar_type = torch::kFloat32;
if (probs_or_mask.has_value()) {
probs_or_mask_scalar_type =
......@@ -347,9 +347,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
picked_eids.options().dtype(indices.scalar_type()));
// Compute: output_indices = indices.gather(0, picked_eids);
AT_DISPATCH_INTEGRAL_TYPES(
AT_DISPATCH_INDEX_TYPES(
indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
using indices_t = scalar_t;
using indices_t = index_t;
const auto exec_policy =
thrust::cuda::par_nosync(allocator).on(stream);
thrust::gather(
......
......@@ -293,14 +293,14 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
std::vector<torch::Tensor> edge_ids_arr(num_seeds);
std::vector<torch::Tensor> type_per_edge_arr(num_seeds);
AT_DISPATCH_INTEGRAL_TYPES(
AT_DISPATCH_INDEX_TYPES(
indptr_.scalar_type(), "InSubgraph", ([&] {
torch::parallel_for(
0, num_seeds, kDefaultGrainSize, [&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
const auto node_id = nodes[i].item<scalar_t>();
const auto start_idx = indptr_[node_id].item<scalar_t>();
const auto end_idx = indptr_[node_id + 1].item<scalar_t>();
const auto node_id = nodes[i].item<index_t>();
const auto start_idx = indptr_[node_id].item<index_t>();
const auto end_idx = indptr_[node_id + 1].item<index_t>();
indptr[i + 1] = end_idx - start_idx;
original_column_node_ids[i] = node_id;
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
......@@ -490,12 +490,12 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
torch::Tensor subgraph_indices;
torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
AT_DISPATCH_INTEGRAL_TYPES(
AT_DISPATCH_INDEX_TYPES(
indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
using indptr_t = scalar_t;
AT_DISPATCH_INTEGRAL_TYPES(
using indptr_t = index_t;
AT_DISPATCH_INDEX_TYPES(
nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] {
using nodes_t = scalar_t;
using nodes_t = index_t;
const auto indptr_data = indptr_.data_ptr<indptr_t>();
auto num_picked_neighbors_data_ptr =
num_picked_neighbors_per_node.data_ptr<indptr_t>();
......@@ -563,13 +563,13 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
// Step 5. Calculate other attributes and return the
// subgraph.
AT_DISPATCH_INTEGRAL_TYPES(
AT_DISPATCH_INDEX_TYPES(
subgraph_indices.scalar_type(),
"IndexSelectSubgraphIndices", ([&] {
auto subgraph_indices_data_ptr =
subgraph_indices.data_ptr<scalar_t>();
subgraph_indices.data_ptr<index_t>();
auto indices_data_ptr =
indices_.data_ptr<scalar_t>();
indices_.data_ptr<index_t>();
for (auto i = picked_offset;
i < picked_offset + picked_number; ++i) {
subgraph_indices_data_ptr[i] =
......@@ -1394,10 +1394,10 @@ inline int64_t LaborPick(
if (NonUniform && probs_or_mask.value().size(0) <= num_neighbors) {
local_probs_data -= offset;
}
AT_DISPATCH_INTEGRAL_TYPES(
AT_DISPATCH_INDEX_TYPES(
args.indices.scalar_type(), "LaborPickMain", ([&] {
const scalar_t* local_indices_data =
args.indices.data_ptr<scalar_t>() + offset;
const index_t* local_indices_data =
args.indices.data_ptr<index_t>() + offset;
if constexpr (Replace) {
// [Algorithm] @mfbalin
// Use a max-heap to get rid of the big random numbers and filter the
......@@ -1431,7 +1431,7 @@ inline int64_t LaborPick(
auto heap_end = heap_data;
const auto init_count = (num_neighbors + fanout - 1) / num_neighbors;
auto sample_neighbor_i_with_index_t_jth_time =
[&](scalar_t t, int64_t j, uint32_t i) {
[&](index_t t, int64_t j, uint32_t i) {
auto rnd = labor::jth_sorted_uniform_random(
args.random_seed, t, args.num_nodes, j, remaining_data[i],
fanout - j); // r_t
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment