Unverified commit 811e35a6 authored by Mufei Li, committed by GitHub

[Sparse] Change all CHECK to TORCH_CHECK (#5082)



* Update

* CI

* Update dgl_sparse/src/utils.h
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 14429df6
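For context: the glog-style CHECK macros stream their error message with operator<<, while PyTorch's TORCH_CHECK takes the condition first and the message parts as additional arguments, throwing c10::Error when the condition fails. A minimal sketch of the two styles side by side (the helper function and its parameters are hypothetical, for illustration only):

#include <torch/torch.h>

// Hypothetical helper showing the message style before and after this change.
void ValidateRows(const torch::Tensor& value, int64_t expected_rows) {
  // Old glog style, removed by this commit:
  //   CHECK_EQ(value.size(0), expected_rows) << "unexpected first dimension";
  // New PyTorch style: condition first, message pieces as trailing arguments.
  TORCH_CHECK(
      value.size(0) == expected_rows, "Expected the first dimension to be ",
      expected_rows, ", but got ", value.size(0), ".");
}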
@@ -40,14 +40,4 @@
 #undef DLOG
 #undef LOG_IF
-// For Pytorch version later than 1.12, redefine CHECK_* to TORCH_CHECK_*.
-#if !(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 12)
-#define CHECK_EQ(val1, val2) TORCH_CHECK_EQ(val1, val2)
-#define CHECK_NE(val1, val2) TORCH_CHECK_NE(val1, val2)
-#define CHECK_LE(val1, val2) TORCH_CHECK_LE(val1, val2)
-#define CHECK_LT(val1, val2) TORCH_CHECK_LT(val1, val2)
-#define CHECK_GE(val1, val2) TORCH_CHECK_GE(val1, val2)
-#define CHECK_GT(val1, val2) TORCH_CHECK_GT(val1, val2)
-#endif
 #endif  // SPARSE_DGL_HEADERS_H_
@@ -17,7 +17,7 @@ namespace sparse {
 std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo) {
   auto row = DGLArrayToTorchTensor(dgl_coo.row);
   auto col = DGLArrayToTorchTensor(dgl_coo.col);
-  CHECK(aten::IsNullArray(dgl_coo.data));
+  TORCH_CHECK(aten::IsNullArray(dgl_coo.data));
   return std::make_shared<COO>(
       COO{dgl_coo.num_rows, dgl_coo.num_cols, row, col, dgl_coo.row_sorted,
           dgl_coo.col_sorted});
@@ -20,37 +20,38 @@ SparseMatrix::SparseMatrix(
     const std::shared_ptr<CSR>& csc, torch::Tensor value,
     const std::vector<int64_t>& shape)
     : coo_(coo), csr_(csr), csc_(csc), value_(value), shape_(shape) {
-  CHECK(coo != nullptr || csr != nullptr || csc != nullptr)
-      << "At least one of CSR/COO/CSC is provided to construct a "
-         "SparseMatrix";
-  CHECK_EQ(shape.size(), 2)
-      << "The shape of a sparse matrix should be 2-dimensional";
+  TORCH_CHECK(
+      coo != nullptr || csr != nullptr || csc != nullptr, "At least ",
+      "one of CSR/COO/CSC is required to construct a SparseMatrix.");
+  TORCH_CHECK(
+      shape.size() == 2, "The shape of a sparse matrix should be ",
+      "2-dimensional.");
   // NOTE: Currently all the tensors of a SparseMatrix should be on the same
   // device. Should we allow the graph structure and values to be on different
   // devices?
   if (coo != nullptr) {
-    CHECK_EQ(coo->row.dim(), 1);
-    CHECK_EQ(coo->col.dim(), 1);
-    CHECK_EQ(coo->row.size(0), coo->col.size(0));
-    CHECK_EQ(coo->row.size(0), value.size(0));
-    CHECK_EQ(coo->row.device(), value.device());
-    CHECK_EQ(coo->col.device(), value.device());
+    TORCH_CHECK(coo->row.dim() == 1);
+    TORCH_CHECK(coo->col.dim() == 1);
+    TORCH_CHECK(coo->row.size(0) == coo->col.size(0));
+    TORCH_CHECK(coo->row.size(0) == value.size(0));
+    TORCH_CHECK(coo->row.device() == value.device());
+    TORCH_CHECK(coo->col.device() == value.device());
   }
   if (csr != nullptr) {
-    CHECK_EQ(csr->indptr.dim(), 1);
-    CHECK_EQ(csr->indices.dim(), 1);
-    CHECK_EQ(csr->indptr.size(0), shape[0] + 1);
-    CHECK_EQ(csr->indices.size(0), value.size(0));
-    CHECK_EQ(csr->indptr.device(), value.device());
-    CHECK_EQ(csr->indices.device(), value.device());
+    TORCH_CHECK(csr->indptr.dim() == 1);
+    TORCH_CHECK(csr->indices.dim() == 1);
+    TORCH_CHECK(csr->indptr.size(0) == shape[0] + 1);
+    TORCH_CHECK(csr->indices.size(0) == value.size(0));
+    TORCH_CHECK(csr->indptr.device() == value.device());
+    TORCH_CHECK(csr->indices.device() == value.device());
   }
   if (csc != nullptr) {
-    CHECK_EQ(csc->indptr.dim(), 1);
-    CHECK_EQ(csc->indices.dim(), 1);
-    CHECK_EQ(csc->indptr.size(0), shape[1] + 1);
-    CHECK_EQ(csc->indices.size(0), value.size(0));
-    CHECK_EQ(csc->indptr.device(), value.device());
-    CHECK_EQ(csc->indices.device(), value.device());
+    TORCH_CHECK(csc->indptr.dim() == 1);
+    TORCH_CHECK(csc->indices.dim() == 1);
+    TORCH_CHECK(csc->indptr.size(0) == shape[1] + 1);
+    TORCH_CHECK(csc->indices.size(0) == value.size(0));
+    TORCH_CHECK(csc->indptr.device() == value.device());
+    TORCH_CHECK(csc->indices.device() == value.device());
   }
 }
@@ -187,11 +188,12 @@ c10::intrusive_ptr<SparseMatrix> CreateFromCSC(
 c10::intrusive_ptr<SparseMatrix> CreateValLike(
     const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
-  CHECK_EQ(mat->value().size(0), value.size(0))
-      << "The first dimension of the old values and the new values must be the "
-         "same.";
-  CHECK_EQ(mat->value().device(), value.device())
-      << "The device of the old values and the new values must be the same.";
+  TORCH_CHECK(
+      mat->value().size(0) == value.size(0), "The first dimension of ",
+      "the old values and the new values must be the same.");
+  TORCH_CHECK(
+      mat->value().device() == value.device(), "The device of the ",
+      "old values and the new values must be the same.");
   auto shape = mat->shape();
   if (mat->HasCOO()) {
     return SparseMatrix::FromCOO(mat->COOPtr(), value, shape);
@@ -89,7 +89,7 @@ variable_list SpSpMMAutoGrad::forward(
   auto csr = ret_mat->CSRPtr();
   auto val = ret_mat->value();
-  CHECK(!csr->value_indices.has_value());
+  TORCH_CHECK(!csr->value_indices.has_value());
   return {csr->indptr, csr->indices, val};
 }
@@ -40,15 +40,14 @@ inline static SparseFormat FindAnyExistingFormat(
 inline static void ElementwiseOpSanityCheck(
     const c10::intrusive_ptr<SparseMatrix>& A,
     const c10::intrusive_ptr<SparseMatrix>& B) {
-  CHECK(A->value().dtype() == B->value().dtype())
-      << "Elementwise operators do not support two sparse matrices with "
-         "different dtypes. ("
-      << A->value().dtype() << " vs " << B->value().dtype() << ")";
-  CHECK(A->shape()[0] == B->shape()[0] && A->shape()[1] == B->shape()[1])
-      << "Elementwise operator do not support two sparse matrices with "
-         "different shapes. (["
-      << A->shape()[0] << ", " << A->shape()[1] << "] vs [" << B->shape()[0]
-      << ", " << B->shape()[1] << "])";
+  TORCH_CHECK(
+      A->value().dtype() == B->value().dtype(),
+      "Elementwise operators"
+      " do not support two sparse matrices with different dtypes.");
+  TORCH_CHECK(
+      A->shape()[0] == B->shape()[0] && A->shape()[1] == B->shape()[1],
+      "Elementwise operators do not support two sparse matrices with different"
+      " shapes.");
 }

 /** @brief Convert a Torch tensor to a DGL array. */