Unverified Commit b3224ce8 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt] Switch to using `AT_DISPATCH_INDEX_TYPES` for graph (#6912)

parent 61504ec5
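The change is mechanical: every dispatch over graph index tensors moves from AT_DISPATCH_INTEGRAL_TYPES, which instantiates the lambda for all integral dtypes (int8, uint8, int16, int32, int64) and names the element type scalar_t, to AT_DISPATCH_INDEX_TYPES, which only instantiates int32 and int64 and names the element type index_t. A minimal sketch of the new pattern (hypothetical SumIndices helper, not part of this PR):

#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Sums an index tensor using the narrower dispatch macro. The macro expands
// to a switch over torch::kInt32 and torch::kInt64 only, so fewer template
// instantiations are compiled than with AT_DISPATCH_INTEGRAL_TYPES.
int64_t SumIndices(const torch::Tensor& indices) {
  int64_t total = 0;
  AT_DISPATCH_INDEX_TYPES(
      indices.scalar_type(), "SumIndices", ([&] {
        const index_t* data = indices.data_ptr<index_t>();
        for (int64_t i = 0; i < indices.numel(); ++i) total += data[i];
      }));
  return total;
}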
@@ -164,16 +164,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   auto max_in_degree = torch::empty(
       1,
       c10::TensorOptions().dtype(in_degree.scalar_type()).pinned_memory(true));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       indptr.scalar_type(), "SampleNeighborsInDegree", ([&] {
         size_t tmp_storage_size = 0;
         cub::DeviceReduce::Max(
-            nullptr, tmp_storage_size, in_degree.data_ptr<scalar_t>(),
-            max_in_degree.data_ptr<scalar_t>(), num_rows, stream);
+            nullptr, tmp_storage_size, in_degree.data_ptr<index_t>(),
+            max_in_degree.data_ptr<index_t>(), num_rows, stream);
         auto tmp_storage = allocator.AllocateStorage<char>(tmp_storage_size);
         cub::DeviceReduce::Max(
-            tmp_storage.get(), tmp_storage_size, in_degree.data_ptr<scalar_t>(),
-            max_in_degree.data_ptr<scalar_t>(), num_rows, stream);
+            tmp_storage.get(), tmp_storage_size, in_degree.data_ptr<index_t>(),
+            max_in_degree.data_ptr<index_t>(), num_rows, stream);
       }));
   auto coo_rows = CSRToCOO(sub_indptr, indices.scalar_type());
   const auto num_edges = coo_rows.size(0);
@@ -184,9 +184,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   torch::Tensor output_indices;
   torch::optional<torch::Tensor> output_type_per_edge;
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
-        using indptr_t = scalar_t;
+        using indptr_t = index_t;
         thrust::counting_iterator<int64_t> iota(0);
         auto sampled_degree = thrust::make_transform_iterator(
             iota, MinInDegreeFanout<indptr_t>{
@@ -234,9 +234,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
           allocator.AllocateStorage<edge_id_t>(num_edges);
       auto sorted_edge_id_segments =
           allocator.AllocateStorage<edge_id_t>(num_edges);
-      AT_DISPATCH_INTEGRAL_TYPES(
+      AT_DISPATCH_INDEX_TYPES(
           indices.scalar_type(), "SampleNeighborsIndices", ([&] {
-            using indices_t = scalar_t;
+            using indices_t = index_t;
             auto probs_or_mask_scalar_type = torch::kFloat32;
             if (probs_or_mask.has_value()) {
               probs_or_mask_scalar_type =
@@ -347,9 +347,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
          picked_eids.options().dtype(indices.scalar_type()));
      // Compute: output_indices = indices.gather(0, picked_eids);
-      AT_DISPATCH_INTEGRAL_TYPES(
+      AT_DISPATCH_INDEX_TYPES(
          indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
-            using indices_t = scalar_t;
+            using indices_t = index_t;
            const auto exec_policy =
                thrust::cuda::par_nosync(allocator).on(stream);
            thrust::gather(
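On the GPU path, the in-degree reduction keeps the same two-pass CUB idiom: the first cub::DeviceReduce::Max call with a null scratch pointer only queries the required temporary-storage size, and the second call performs the reduction. A standalone sketch of that pattern under the new macro (hypothetical MaxInDegree helper using a plain byte tensor for scratch space instead of GraphBolt's allocator; must be compiled as CUDA):

#include <ATen/Dispatch.h>
#include <c10/cuda/CUDAStream.h>
#include <cub/cub.cuh>
#include <torch/torch.h>

// Returns a 1-element CUDA tensor holding max(in_degree), dispatching on the
// int32/int64 dtype of the input.
torch::Tensor MaxInDegree(const torch::Tensor& in_degree) {
  auto stream = c10::cuda::getCurrentCUDAStream();
  auto result = torch::empty(1, in_degree.options());
  AT_DISPATCH_INDEX_TYPES(
      in_degree.scalar_type(), "MaxInDegree", ([&] {
        size_t tmp_bytes = 0;  // First pass: only query the scratch size.
        cub::DeviceReduce::Max(
            nullptr, tmp_bytes, in_degree.data_ptr<index_t>(),
            result.data_ptr<index_t>(), in_degree.numel(), stream);
        auto tmp = torch::empty(
            static_cast<int64_t>(tmp_bytes),
            in_degree.options().dtype(torch::kByte));
        // Second pass: run the reduction with the allocated scratch buffer.
        cub::DeviceReduce::Max(
            tmp.data_ptr<uint8_t>(), tmp_bytes, in_degree.data_ptr<index_t>(),
            result.data_ptr<index_t>(), in_degree.numel(), stream);
      }));
  return result;
}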
@@ -293,14 +293,14 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
   std::vector<torch::Tensor> edge_ids_arr(num_seeds);
   std::vector<torch::Tensor> type_per_edge_arr(num_seeds);
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       indptr_.scalar_type(), "InSubgraph", ([&] {
         torch::parallel_for(
             0, num_seeds, kDefaultGrainSize, [&](size_t start, size_t end) {
               for (size_t i = start; i < end; ++i) {
-                const auto node_id = nodes[i].item<scalar_t>();
-                const auto start_idx = indptr_[node_id].item<scalar_t>();
-                const auto end_idx = indptr_[node_id + 1].item<scalar_t>();
+                const auto node_id = nodes[i].item<index_t>();
+                const auto start_idx = indptr_[node_id].item<index_t>();
+                const auto end_idx = indptr_[node_id + 1].item<index_t>();
                 indptr[i + 1] = end_idx - start_idx;
                 original_column_node_ids[i] = node_id;
                 indices_arr[i] = indices_.slice(0, start_idx, end_idx);
@@ -490,12 +490,12 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
   torch::Tensor subgraph_indices;
   torch::optional<torch::Tensor> subgraph_type_per_edge = torch::nullopt;
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       indptr_.scalar_type(), "SampleNeighborsImplWrappedWithIndptr", ([&] {
-        using indptr_t = scalar_t;
-        AT_DISPATCH_INTEGRAL_TYPES(
+        using indptr_t = index_t;
+        AT_DISPATCH_INDEX_TYPES(
            nodes.scalar_type(), "SampleNeighborsImplWrappedWithNodes", ([&] {
-              using nodes_t = scalar_t;
+              using nodes_t = index_t;
              const auto indptr_data = indptr_.data_ptr<indptr_t>();
              auto num_picked_neighbors_data_ptr =
                  num_picked_neighbors_per_node.data_ptr<indptr_t>();
@@ -563,13 +563,13 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
                   // Step 5. Calculate other attributes and return the
                   // subgraph.
-                  AT_DISPATCH_INTEGRAL_TYPES(
+                  AT_DISPATCH_INDEX_TYPES(
                       subgraph_indices.scalar_type(),
                       "IndexSelectSubgraphIndices", ([&] {
                         auto subgraph_indices_data_ptr =
-                            subgraph_indices.data_ptr<scalar_t>();
+                            subgraph_indices.data_ptr<index_t>();
                         auto indices_data_ptr =
-                            indices_.data_ptr<scalar_t>();
+                            indices_.data_ptr<index_t>();
                         for (auto i = picked_offset;
                              i < picked_offset + picked_number; ++i) {
                           subgraph_indices_data_ptr[i] =
@@ -1394,10 +1394,10 @@ inline int64_t LaborPick(
   if (NonUniform && probs_or_mask.value().size(0) <= num_neighbors) {
     local_probs_data -= offset;
   }
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       args.indices.scalar_type(), "LaborPickMain", ([&] {
-        const scalar_t* local_indices_data =
-            args.indices.data_ptr<scalar_t>() + offset;
+        const index_t* local_indices_data =
+            args.indices.data_ptr<index_t>() + offset;
         if constexpr (Replace) {
           // [Algorithm] @mfbalin
           // Use a max-heap to get rid of the big random numbers and filter the
@@ -1431,7 +1431,7 @@ inline int64_t LaborPick(
     auto heap_end = heap_data;
     const auto init_count = (num_neighbors + fanout - 1) / num_neighbors;
     auto sample_neighbor_i_with_index_t_jth_time =
-        [&](scalar_t t, int64_t j, uint32_t i) {
+        [&](index_t t, int64_t j, uint32_t i) {
          auto rnd = labor::jth_sorted_uniform_random(
              args.random_seed, t, args.num_nodes, j, remaining_data[i],
              fanout - j);  // r_t
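The CPU sampling path now nests two index dispatches, one on the indptr dtype and one on the seed-node dtype, aliasing index_t to a distinct name at each level so both concrete types stay usable inside the inner lambda. A minimal sketch of that nesting (hypothetical NodeDegrees helper, not from this PR; assumes the output tensor shares indptr's dtype):

#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Writes indptr[v + 1] - indptr[v] (the degree of each seed node in the CSC
// structure), dispatching independently on the indptr and seed-node dtypes.
void NodeDegrees(
    const torch::Tensor& indptr, const torch::Tensor& nodes,
    torch::Tensor& out_degree) {
  AT_DISPATCH_INDEX_TYPES(
      indptr.scalar_type(), "NodeDegreesIndptr", ([&] {
        using indptr_t = index_t;  // Keep the outer type visible below.
        AT_DISPATCH_INDEX_TYPES(
            nodes.scalar_type(), "NodeDegreesNodes", ([&] {
              using nodes_t = index_t;  // Inner index_t shadows the outer one.
              const indptr_t* indptr_data = indptr.data_ptr<indptr_t>();
              const nodes_t* nodes_data = nodes.data_ptr<nodes_t>();
              indptr_t* degree_data = out_degree.data_ptr<indptr_t>();
              for (int64_t i = 0; i < nodes.numel(); ++i) {
                const auto v = nodes_data[i];
                degree_data[i] = indptr_data[v + 1] - indptr_data[v];
              }
            }));
      }));
}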