"src/vscode:/vscode.git/clone" did not exist on "37a44bb2839c1af18940b6cf38f5639c9c279caf"
// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file expand_indptr.cc
 * @brief ExpandIndptr operators.
 */
#include <graphbolt/cuda_ops.h>

#include "macro.h"
#include "utils.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Expands `indptr` into a flat tensor in which entry i is repeated
 * `indptr[i + 1] - indptr[i]` times.
 *
 * When `indptr` is on the GPU and `node_ids` is either absent or also on the
 * GPU, the call is forwarded to ExpandIndptrImpl through the CUDA-only
 * dispatch macro. Otherwise the result is built with torch operators.
 *
 * @param indptr Offsets tensor; consecutive differences (`indptr.diff()`) are
 * used as the repeat counts.
 * @param dtype Scalar type the result is converted to.
 * @param node_ids Optional values to repeat. When absent,
 * `torch::repeat_interleave` is called with the counts alone, which per the
 * PyTorch documentation yields the positions 0, 1, 2, ... repeated.
 * @param output_size Optional total length of the result, passed through to
 * `repeat_interleave` unchanged.
 *
 * @return The expanded tensor, converted to `dtype` on the torch path; on the
 * GPU path, whatever ExpandIndptrImpl returns.
 */
torch::Tensor ExpandIndptr(
    torch::Tensor indptr, torch::ScalarType dtype,
    torch::optional<torch::Tensor> node_ids,
    torch::optional<int64_t> output_size) {
  if (utils::is_on_gpu(indptr) &&
      (!node_ids.has_value() || utils::is_on_gpu(node_ids.value()))) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "ExpandIndptr", {
      return ExpandIndptrImpl(indptr, dtype, node_ids, output_size);
    });
  }

  // Fallback path: both branches below repeat by the same per-entry counts.
  const torch::Tensor repeats = indptr.diff();
  if (node_ids.has_value()) {
    // Convert first, then repeat along dimension 0.
    return node_ids.value().to(dtype).repeat_interleave(
        repeats, 0, output_size);
  }
  // No explicit ids: repeat the implicit positions and convert afterwards.
  return torch::repeat_interleave(repeats, output_size).to(dtype);
}

}  // namespace ops
}  // namespace graphbolt