/**
 *  Copyright (c) 2022 by Contributors
 * @file utils.h
 * @brief DGL C++ sparse API utilities
 */
#ifndef DGL_SPARSE_UTILS_H_
#define DGL_SPARSE_UTILS_H_

// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <ATen/DLConvertor.h>
#include <sparse/sparse_matrix.h>
#include <torch/custom_class.h>
#include <torch/script.h>

namespace dgl {
namespace sparse {

/**
 * @brief Pick a sparse format that already exists on at least one of two
 * sparse matrices.
 *
 * Formats are tried in the priority order COO, then CSR. The first format
 * that either @p A or @p B already holds is returned. If neither matrix holds
 * COO or CSR, CSC is returned without checking whether either matrix has it.
 *
 * @param A The first sparse matrix.
 * @param B The second sparse matrix.
 * @return The selected sparse format.
 */
inline static SparseFormat FindAnyExistingFormat(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  if (A->HasCOO() || B->HasCOO()) {
    return SparseFormat::kCOO;
  }
  if (A->HasCSR() || B->HasCSR()) {
    return SparseFormat::kCSR;
  }
  // Fallback: reached only when neither matrix holds COO or CSR.
  return SparseFormat::kCSC;
}

/**
 * @brief Validate that two sparse matrices can be combined by an elementwise
 * operator.
 *
 * Fails a TORCH_CHECK if the value dtypes of @p A and @p B differ. After that,
 * it fails another TORCH_CHECK if the first two entries of their shapes
 * differ. The dtype check runs first, so the shape check is only reached
 * when the dtypes match.
 *
 * @param A The left-hand sparse matrix.
 * @param B The right-hand sparse matrix.
 */
inline static void ElementwiseOpSanityCheck(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  const bool dtypes_match = A->value().dtype() == B->value().dtype();
  TORCH_CHECK(
      dtypes_match,
      "Elementwise operators"
      " do not support two sparse matrices with different dtypes.");

  const auto& lhs_shape = A->shape();
  const auto& rhs_shape = B->shape();
  const bool shapes_match =
      lhs_shape[0] == rhs_shape[0] && lhs_shape[1] == rhs_shape[1];
  TORCH_CHECK(
      shapes_match,
      "Elementwise operators do not support two sparse matrices with different"
      " shapes.");
}

/**
 * @brief Convert a Torch tensor to a DGL array through DLPack.
 *
 * First, `tensor.contiguous()` produces a contiguous version of @p tensor.
 * That version is exported with `at::toDLPack`, and the resulting handle is
 * imported with `runtime::DLPackConvert::FromDLPack`.
 *
 * @param tensor The Torch tensor to convert.
 * @return The DGL NDArray built from the DLPack handle.
 */
inline static runtime::NDArray TorchTensorToDGLArray(torch::Tensor tensor) {
  auto contiguous_tensor = tensor.contiguous();
  // NOTE(review): presumably FromDLPack takes ownership of this handle and
  // releases it through the DLPack deleter — confirm in DLPackConvert.
  auto dlpack_handle = at::toDLPack(contiguous_tensor);
  return runtime::DLPackConvert::FromDLPack(dlpack_handle);
}

/**
 * @brief Convert a DGL array to a Torch tensor through DLPack.
 *
 * @p array is exported with `runtime::DLPackConvert::ToDLPack`, and the
 * resulting handle is imported with `at::fromDLPack`.
 *
 * @param array The DGL NDArray to convert.
 * @return The Torch tensor produced from the DLPack handle.
 */
inline static torch::Tensor DGLArrayToTorchTensor(runtime::NDArray array) {
  auto dlpack_handle = runtime::DLPackConvert::ToDLPack(array);
  return at::fromDLPack(dlpack_handle);
}

/**
 * @brief Convert an optional Torch tensor to a DGL array.
 *
 * @param tensor A Torch tensor that may be absent.
 * @return The result of TorchTensorToDGLArray on the contained tensor when
 * @p tensor holds a value; otherwise `aten::NullArray()`.
 */
inline static runtime::NDArray OptionalTorchTensorToDGLArray(
    torch::optional<torch::Tensor> tensor) {
  if (tensor.has_value()) {
    return TorchTensorToDGLArray(tensor.value());
  }
  // An absent tensor is represented on the DGL side by a null array.
  return aten::NullArray();
}

/**
 * @brief Convert a DGL array to an optional Torch tensor.
 *
 * @param array The DGL NDArray to convert.
 * @return An empty optional when `aten::IsNullArray(array)` is true;
 * otherwise an optional holding the result of DGLArrayToTorchTensor.
 */
inline static torch::optional<torch::Tensor> DGLArrayToOptionalTorchTensor(
    runtime::NDArray array) {
  torch::optional<torch::Tensor> result;
  if (!aten::IsNullArray(array)) {
    result = DGLArrayToTorchTensor(array);
  }
  return result;
}

}  // namespace sparse
}  // namespace dgl

#endif  // DGL_SPARSE_UTILS_H_