insubgraph.cu 1.28 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/insubgraph.cu
 * @brief InSubgraph operator implementation on CUDA.
 */

#include <graphbolt/cuda_ops.h>
#include <graphbolt/cuda_sampling_ops.h>

#include "./common.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Selects the CSC neighborhoods of `nodes` on the GPU and records,
 * for every selected edge, its position in the original `indices` tensor.
 *
 * @param indptr CSC index pointer of the full graph.
 * @param indices CSC indices of the full graph.
 * @param nodes Seed nodes whose CSC neighborhoods (incoming edges, per the
 * CSC layout) are selected.
 * @param type_per_edge Optional per-edge tensor aligned with `indices`; when
 * present it is sliced with the same `indptr`/`nodes` as `indices`.
 *
 * @return A FusedSampledSubgraph holding the selected indptr and indices,
 * `nodes` as the column node IDs, no row node IDs, the original edge IDs of
 * the selected edges, and the sliced `type_per_edge` (if given).
 */
c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes,
    torch::optional<torch::Tensor> type_per_edge) {
  auto [output_indptr, output_indices] =
      IndexSelectCSCImpl(indptr, indices, nodes);
  torch::optional<torch::Tensor> output_type_per_edge;
  if (type_per_edge) {
    // NOTE(review): this repeats the indptr slicing already done by the call
    // above; an IndexSelectCSCImpl overload that reuses the computed indptr
    // would avoid the duplicate work, if one is available.
    output_type_per_edge =
        std::get<1>(IndexSelectCSCImpl(indptr, type_per_edge.value(), nodes));
  }
  // Seed (row) index of every selected edge. It is requested with the dtype
  // of `indices`, so it may be int32.
  auto rows = CSRToCOO(output_indptr, indices.scalar_type());
  // Start offset of each seed's neighborhood in the original `indices`
  // (presumably indptr[nodes]; confirm against SliceCSCIndptr).
  auto sliced_indptr = std::get<1>(SliceCSCIndptr(indptr, nodes));
  // Per-seed shift that maps a position in `output_indices` back to the
  // corresponding position in `indices`: original start - output start.
  // Only the first num_seeds entries of each indptr-like tensor are needed.
  const int64_t num_seeds = output_indptr.size(0) - 1;
  auto id_shift = sliced_indptr.narrow(0, 0, num_seeds) -
                  output_indptr.narrow(0, 0, num_seeds);
  // index_select (unlike gather, which requires int64 indices) accepts both
  // int32 and int64 indices, so this works whatever dtype `rows` has.
  auto edge_ids =
      torch::arange(output_indices.size(0), output_indptr.options()) +
      id_shift.index_select(0, rows);

  return c10::make_intrusive<sampling::FusedSampledSubgraph>(
      output_indptr, output_indices, nodes, torch::nullopt, edge_ids,
      output_type_per_edge);
}

}  // namespace ops
}  // namespace graphbolt