/**
 *  Copyright (c) 2022 by Contributors
 * @file utils.h
 * @brief DGL C++ sparse API utilities
 */
#ifndef DGL_SPARSE_UTILS_H_
#define DGL_SPARSE_UTILS_H_

czkkkkkk's avatar
czkkkkkk committed
9
10
// clang-format off
#include <sparse/dgl_headers.h>
11
#include <sparse/torch_headers.h>
czkkkkkk's avatar
czkkkkkk committed
12
13
// clang-format on

14
15
16
17
18
#include <sparse/sparse_matrix.h>

namespace dgl {
namespace sparse {

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
19
/**
 * @brief Pick a sparse format that at least one of two matrices already has.
 *
 * Formats are considered in the order COO, then CSR: the first one that
 * either @p A or @p B reports having is returned. If neither matrix reports
 * COO or CSR, CSC is returned without being checked.
 *
 * @param A First sparse matrix.
 * @param B Second sparse matrix.
 * @return The chosen SparseFormat.
 */
inline static SparseFormat FindAnyExistingFormat(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  if (A->HasCOO() || B->HasCOO()) {
    return SparseFormat::kCOO;
  }
  if (A->HasCSR() || B->HasCSR()) {
    return SparseFormat::kCSR;
  }
  // NOTE(review): CSC is not queried here; presumably a sparse matrix always
  // holds at least one format, so CSC must exist at this point -- confirm.
  return SparseFormat::kCSC;
}

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
36
/**
 * @brief Check whether two sparse matrices have the same value dtype and the
 * same shape, as required by elementwise operators.
 *
 * @param A First operand.
 * @param B Second operand.
 *
 * Fails via TORCH_CHECK (which throws c10::Error) when the value dtypes
 * differ, or when the matrices differ in either of their first two
 * dimensions.
 */
inline static void ElementwiseOpSanityCheck(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  TORCH_CHECK(
      A->value().dtype() == B->value().dtype(),
      "Elementwise operators"
      " do not support two sparse matrices with different dtypes.");
  TORCH_CHECK(
      A->shape()[0] == B->shape()[0] && A->shape()[1] == B->shape()[1],
      "Elementwise operators do not support two sparse matrices with different"
      " shapes.");
}

czkkkkkk's avatar
czkkkkkk committed
51
52
/**
 * @brief Convert a Torch tensor to a DGL array.
 *
 * The tensor is first made contiguous. `contiguous()` returns the tensor
 * itself when it already is contiguous, otherwise a compact copy. It is then
 * exported with `at::toDLPack` and imported into DGL with
 * `DLPackConvert::FromDLPack`.
 *
 * @param tensor The Torch tensor to convert.
 * @return A DGL NDArray built from the DLPack export of the contiguous tensor.
 */
inline static runtime::NDArray TorchTensorToDGLArray(torch::Tensor tensor) {
  // NOTE(review): contiguous() is presumably required because DGL kernels
  // expect compact storage rather than arbitrary strides -- confirm.
  return runtime::DLPackConvert::FromDLPack(at::toDLPack(tensor.contiguous()));
}

/**
 * @brief Convert a DGL array to a Torch tensor.
 *
 * The array is exported through `DLPackConvert::ToDLPack` and the resulting
 * DLPack handle is passed to `at::fromDLPack`, which takes ownership of it.
 *
 * @param array The DGL NDArray to convert.
 * @return The Torch tensor produced from the DLPack handle.
 */
inline static torch::Tensor DGLArrayToTorchTensor(runtime::NDArray array) {
  auto dlpack_handle = runtime::DLPackConvert::ToDLPack(array);
  return at::fromDLPack(dlpack_handle);
}

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
/**
 * @brief Convert an optional Torch tensor to a DGL array.
 *
 * @param tensor The tensor to convert, or an empty optional.
 * @return `aten::NullArray()` when @p tensor is empty; otherwise the result of
 * `TorchTensorToDGLArray` on the contained tensor.
 */
inline static runtime::NDArray OptionalTorchTensorToDGLArray(
    torch::optional<torch::Tensor> tensor) {
  return tensor.has_value() ? TorchTensorToDGLArray(tensor.value())
                            : aten::NullArray();
}

/**
 * @brief Convert a DGL array to an optional Torch tensor.
 *
 * @param array The DGL NDArray to convert.
 * @return An empty optional when `aten::IsNullArray(array)` holds; otherwise
 * an optional holding the result of `DGLArrayToTorchTensor(array)`.
 */
inline static torch::optional<torch::Tensor> DGLArrayToOptionalTorchTensor(
    runtime::NDArray array) {
  if (!aten::IsNullArray(array)) {
    return torch::make_optional<torch::Tensor>(DGLArrayToTorchTensor(array));
  }
  // A null DGL array maps to a disengaged optional.
  return {};
}

79
80
81
82
}  // namespace sparse
}  // namespace dgl

#endif  // DGL_SPARSE_UTILS_H_