Unverified Commit e65fc9ff authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Avoid using macros in DGL headers (#5087)

parent 255ad1b3
...@@ -67,9 +67,15 @@ torch::Tensor _CSRMask( ...@@ -67,9 +67,15 @@ torch::Tensor _CSRMask(
auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row); auto row = TorchTensorToDGLArray(sub_mat->COOPtr()->row);
auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col); auto col = TorchTensorToDGLArray(sub_mat->COOPtr()->col);
runtime::NDArray ret; runtime::NDArray ret;
ATEN_FLOAT_TYPE_SWITCH(val->dtype, DType, "Value Type", { if (val->dtype.bits == 32) {
ret = aten::CSRGetData<DType>(csr, row, col, val, 0.); ret = aten::CSRGetData<float>(csr, row, col, val, 0.);
}); } else if (val->dtype.bits == 64) {
ret = aten::CSRGetData<double>(csr, row, col, val, 0.);
} else {
TORCH_CHECK(
false, "Dtype of value for SpSpMM should be 32 or 64 bits but got: " +
std::to_string(val->dtype.bits));
}
return DGLArrayToTorchTensor(ret); return DGLArrayToTorchTensor(ret);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment