// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file expand_indptr.cc
 * @brief ExpandIndptr operators.
 */
#include <graphbolt/cuda_ops.h>
#include <torch/autograd.h>

#include "macro.h"
#include "utils.h"

namespace graphbolt {
namespace ops {

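/**
 * @brief Expands indptr into per-edge ids: index i (or node_ids[i], when
 * provided) is repeated indptr[i + 1] - indptr[i] times, i.e. a CSR-to-COO
 * style row expansion. E.g., indptr = [0, 2, 3, 5] yields [0, 0, 1, 2, 2].
 */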
torch::Tensor ExpandIndptr(
    torch::Tensor indptr, torch::ScalarType dtype,
    torch::optional<torch::Tensor> node_ids,
    torch::optional<int64_t> output_size) {
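  // When indptr (and node_ids, if provided) resides on the GPU, dispatch to
  // the CUDA implementation.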
  if (utils::is_on_gpu(indptr) &&
      (!node_ids.has_value() || utils::is_on_gpu(node_ids.value()))) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "ExpandIndptr", {
      return ExpandIndptrImpl(indptr, dtype, node_ids, output_size);
    });
  }
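  // CPU path without node_ids: repeat each index i by its degree,
  // indptr[i + 1] - indptr[i].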
  if (!node_ids.has_value()) {
    return torch::repeat_interleave(indptr.diff(), output_size).to(dtype);
  }
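  // CPU path with node_ids: repeat node_ids[i] by its degree instead.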
  return node_ids.value().to(dtype).repeat_interleave(
      indptr.diff(), 0, output_size);
}

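// Register the CPU kernel for the graphbolt::expand_indptr operator.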
TORCH_LIBRARY_IMPL(graphbolt, CPU, m) {
  m.impl("expand_indptr", &ExpandIndptr);
}

#ifdef GRAPHBOLT_USE_CUDA
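// Register the CUDA kernel when GraphBolt is built with CUDA support.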
TORCH_LIBRARY_IMPL(graphbolt, CUDA, m) {
  m.impl("expand_indptr", &ExpandIndptrImpl);
}
#endif

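// The operator does not implement a backward pass; route Autograd dispatch
// through the not-implemented fallback.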
TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {
  m.impl("expand_indptr", torch::autograd::autogradNotImplementedFallback());
}

}  // namespace ops
}  // namespace graphbolt