/**
 *  Copyright (c) 2023 by Contributors
 *  Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file expand_indptr.cc
 * @brief ExpandIndptr operators.
 */
#include <graphbolt/cuda_ops.h>

#include <torch/autograd.h>

#include "./macro.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Expands an indptr (CSR offsets) tensor into per-element values.
 *
 * Element i of the input contributes indptr[i + 1] - indptr[i] copies of a
 * value to the output: the value is i itself when node_ids is absent, and
 * node_ids[i] otherwise. The result is cast to dtype.
 *
 * @param indptr       Offsets tensor; consecutive differences give the repeat
 *                     counts.
 * @param dtype        Scalar type of the returned tensor.
 * @param node_ids     Optional values to repeat; when absent, behaves as if
 *                     node_ids were arange(len(indptr) - 1).
 * @param output_size  Optional total output length. NOTE(review): presumably
 *                     equals indptr[-1]; passing it lets repeat_interleave
 *                     skip computing the sum of repeats — confirm with the op
 *                     schema.
 * @return The expanded tensor, cast to dtype.
 */
torch::Tensor ExpandIndptr(
    torch::Tensor indptr, torch::ScalarType dtype,
    torch::optional<torch::Tensor> node_ids,
    torch::optional<int64_t> output_size) {
  // Dispatch to the CUDA kernel only when every tensor argument is on GPU.
  if (utils::is_on_gpu(indptr) &&
      (!node_ids.has_value() || utils::is_on_gpu(node_ids.value()))) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "ExpandIndptr", {
      return ExpandIndptrImpl(indptr, dtype, node_ids, output_size);
    });
  }
  if (!node_ids.has_value()) {
    // Single-tensor repeat_interleave repeats arange(len(repeats)) according
    // to the repeat counts, i.e. [0, 0, ..., 1, 1, ...].
    return torch::repeat_interleave(indptr.diff(), output_size).to(dtype);
  }
  return node_ids.value().to(dtype).repeat_interleave(
      indptr.diff(), 0, output_size);
}

// Register the CPU kernel for graphbolt::expand_indptr; ExpandIndptr itself
// falls through to the torch CPU path when inputs are not on GPU.
TORCH_LIBRARY_IMPL(graphbolt, CPU, m) {
  m.impl("expand_indptr", &ExpandIndptr);
}

#ifdef GRAPHBOLT_USE_CUDA
// For the CUDA dispatch key, register the CUDA implementation directly,
// bypassing the device-check wrapper in ExpandIndptr.
TORCH_LIBRARY_IMPL(graphbolt, CUDA, m) {
  m.impl("expand_indptr", &ExpandIndptrImpl);
}
#endif

// No backward is defined for expand_indptr: the autograd-not-implemented
// fallback lets the op run under autograd and errors only if a gradient is
// actually requested through it.
TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {
  m.impl("expand_indptr", torch::autograd::autogradNotImplementedFallback());
}

}  // namespace ops
}  // namespace graphbolt