"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "b9290e8b2d54a08fce612b1c320bd5873761789f"
Unverified commit d4a6f8a0, authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Refactor `Gather` operation. (#7269)

parent 62aca92d
@@ -149,6 +149,20 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(
  */
 torch::Tensor ExclusiveCumSum(torch::Tensor input);
 
+/**
+ * @brief Computes the gather operation on a given input and index tensor.
+ *
+ * @param input The input tensor.
+ * @param index The index tensor.
+ * @param dtype The optional output dtype. If not given, inferred from the
+ * input tensor.
+ *
+ * @return The result of the input.gather(0, index).to(dtype) operation.
+ */
+torch::Tensor Gather(
+    torch::Tensor input, torch::Tensor index,
+    torch::optional<torch::ScalarType> dtype = torch::nullopt);
+
 /**
  * @brief Select rows from input tensor according to index tensor.
  *
...
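For orientation, a minimal usage sketch of the declaration above (my illustration, not part of the patch; the tensors are hypothetical and assume a CUDA build of libtorch):

// Gather 32-bit values into a 64-bit output, mirroring the documented
// input.gather(0, index).to(dtype) semantics.
torch::Tensor input = torch::arange(
    0, 10, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA));
torch::Tensor index = torch::tensor(
    {7, 1, 3},
    torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA));
torch::Tensor out = graphbolt::ops::Gather(input, index, torch::kInt64);
// out now holds {7, 1, 3} as int64 on the GPU.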
/**
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/gather.cu
 * @brief Gather operators implementation on CUDA.
 */
#include <thrust/gather.h>

#include "./common.h"

namespace graphbolt {
namespace ops {

torch::Tensor Gather(
    torch::Tensor input, torch::Tensor index,
    torch::optional<torch::ScalarType> dtype) {
  // When no output dtype is requested, keep the input's dtype.
  if (!dtype.has_value()) dtype = input.scalar_type();
  auto output = torch::empty(index.sizes(), index.options().dtype(*dtype));
  // Triple dispatch over the index, input, and output element types, so the
  // Thrust call below is instantiated with concrete C++ types.
  AT_DISPATCH_INDEX_TYPES(
      index.scalar_type(), "GatherIndexType", ([&] {
        AT_DISPATCH_INTEGRAL_TYPES(
            input.scalar_type(), "GatherInputType", ([&] {
              using input_t = scalar_t;
              AT_DISPATCH_INTEGRAL_TYPES(*dtype, "GatherOutputType", ([&] {
                using output_t = scalar_t;
                // output[i] = input[index[i]], converting to output_t on
                // assignment.
                THRUST_CALL(
                    gather, index.data_ptr<index_t>(),
                    index.data_ptr<index_t>() + index.size(0),
                    input.data_ptr<input_t>(), output.data_ptr<output_t>());
              }));
            }));
      }));
  return output;
}

}  // namespace ops
}  // namespace graphbolt
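The nested dispatch above resolves three element types at once (index, input, and output) and then hands raw pointers to thrust::gather, which performs output[i] = input[index[i]] and converts element types on assignment. A self-contained sketch of that Thrust semantics alone (my illustration, independent of the patch):

#include <thrust/device_vector.h>
#include <thrust/gather.h>

#include <cstdint>
#include <vector>

int main() {
  std::vector<int32_t> h_input = {10, 20, 30, 40};
  std::vector<int64_t> h_map = {3, 0, 2};
  thrust::device_vector<int32_t> input(h_input.begin(), h_input.end());
  thrust::device_vector<int64_t> map(h_map.begin(), h_map.end());
  // The output element type is wider than the input's; the per-element
  // assignment inside thrust::gather performs the conversion, which is
  // what lets the wrapper above honor an output dtype different from the
  // input's.
  thrust::device_vector<int64_t> output(map.size());
  thrust::gather(map.begin(), map.end(), input.begin(), output.begin());
  // output is now {40, 10, 30}.
  return 0;
}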
@@ -500,44 +500,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
             }
           }
-          output_indices = torch::empty(
-              picked_eids.size(0),
-              picked_eids.options().dtype(indices.scalar_type()));
-          // Compute: output_indices = indices.gather(0, picked_eids);
-          AT_DISPATCH_INDEX_TYPES(
-              indices.scalar_type(), "SampleNeighborsOutputIndices", ([&] {
-                using indices_t = index_t;
-                THRUST_CALL(
-                    gather, picked_eids.data_ptr<indptr_t>(),
-                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
-                    indices.data_ptr<indices_t>(),
-                    output_indices.data_ptr<indices_t>());
-              }));
+          output_indices = Gather(indices, picked_eids);
       }));
-
-  auto index_type_per_edge_for_sampled_edges = [&] {
-    // The code behaves same as:
-    // output_type_per_edge = type_per_edge.gather(0, picked_eids);
-    // The reimplementation is required due to the torch equivalent does
-    // not work when type_per_edge is on pinned memory
-    auto types = type_per_edge.value();
-    auto output = torch::empty(
-        picked_eids.size(0), picked_eids.options().dtype(types.scalar_type()));
-    AT_DISPATCH_INDEX_TYPES(
-        indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
-          using indptr_t = index_t;
-          AT_DISPATCH_INTEGRAL_TYPES(
-              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
-                THRUST_CALL(
-                    gather, picked_eids.data_ptr<indptr_t>(),
-                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
-                    types.data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
-              }));
-        }));
-    return output;
-  };
 
   torch::optional<torch::Tensor> output_type_per_edge;
   torch::optional<torch::Tensor> edge_offsets;
   if (type_per_edge && seed_offsets) {
...

@@ -547,7 +512,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       // type_per_edge of sampled edges and determine the offsets of different
       // sampled etypes and convert to fused hetero indptr representation.
       if (fanouts.size() == 1) {
-        output_type_per_edge = index_type_per_edge_for_sampled_edges();
+        output_type_per_edge = Gather(*type_per_edge, picked_eids);
         torch::Tensor output_in_degree, sliced_output_indptr;
         sliced_output_indptr =
             output_indptr.slice(0, 0, output_indptr.size(0) - 1);
...

@@ -652,7 +617,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   output_indptr =
       output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
   if (type_per_edge)
-    output_type_per_edge = index_type_per_edge_for_sampled_edges();
+    output_type_per_edge = Gather(*type_per_edge, picked_eids);
 }
 torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
...
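The removed lambda's comment records why a plain type_per_edge.gather(0, picked_eids) was avoided: torch's gather requires both tensors on the same device, whereas the Thrust path only needs pointers the GPU can dereference, which includes pinned host memory accessed over UVA. A hedged sketch of that scenario (tensor names hypothetical, assuming a CUDA build of libtorch):

// type_per_edge may live in pinned host memory for zero-copy GPU access.
auto type_per_edge =
    torch::zeros({100}, torch::TensorOptions().dtype(torch::kUInt8))
        .pin_memory();
auto picked_eids = torch::tensor(
    {5, 42, 7},
    torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA));
// type_per_edge.gather(0, picked_eids) would fail with a device mismatch;
// the Thrust-backed helper reads the pinned input via UVA instead, and the
// output lands on picked_eids' device because it is allocated from
// index.options().
auto output_type_per_edge = graphbolt::ops::Gather(type_per_edge, picked_eids);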