Unverified Commit a536772e authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix cuda 11.1 crashing bug (#3265)

parent 2613f7f0
......@@ -413,7 +413,7 @@ void CusparseCsrmm2Hetero(
* \brief Determine whether cusparse SpMM function is applicable.
*/
template <int bits, typename IdType>
inline bool cusparse_available() {
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
if (std::is_same<IdType, int>::value)
if (bits > 16)
......@@ -422,7 +422,8 @@ inline bool cusparse_available() {
#else
if (bits == 16)
return false; // cusparse's SpMM on fp16 is slow, temporally disabled.
return true;
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
return !more_nnz_than_matrix_size;
#endif
}
......@@ -444,7 +445,9 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
bool use_efeat = op != "copy_lhs";
if (reduce == "sum") {
if (op == "copy_lhs" && cusparse_available<bits, IdType>()) { // cusparse
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {
// cusparse
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i];
......@@ -456,7 +459,8 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
static_cast<DType*>(out->data),
x_length);
});
} else if (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>()) { // cusparse
} else if (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(more_nnz)) {
// cusparse
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i];
......@@ -524,8 +528,9 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
bool use_legacy_cusparsemm =
(CUDART_VERSION < 11000) &&
((op == "copy_lhs" && cusparse_available<bits, IdType>()) ||
(op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>()));
// legacy cuSPARSE does not care about NNZ, hence the argument "false".
((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) ||
(op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
// Create temporary output buffer to store non-transposed output
if (use_legacy_cusparsemm) {
for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
......@@ -568,8 +573,9 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const dgl_type_t dst_id = out_ntids[etype];
CSRMatrix csr = vec_csr[etype];
if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
/* Call SpMM for each relation type */
if (op == "copy_lhs" && cusparse_available<bits, IdType>()) { // cusparse
if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) { // cusparse
/* If CUDA is less than 11.0, put the output in trans_out for later transposition */
DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
static_cast<DType*>(vec_out[dst_id]->data);
......@@ -580,7 +586,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
out,
x_length, thr_entry->stream);
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<bits, IdType>()) { // cusparse
cusparse_available<bits, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(vec_efeat[etype], csr.data);
......
......@@ -1521,6 +1521,18 @@ def test_level2(idtype):
g.nodes['game'].data.clear()
@parametrize_dtype
@unittest.skipIf(F._default_context_str == 'cpu', reason="Need gpu for this test")
def test_more_nnz(idtype):
g = dgl.graph(([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]), idtype=idtype, device=F.ctx())
g.ndata['x'] = F.copy_to(F.ones((2, 5)), ctx=F.ctx())
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
y = g.ndata['y']
ans = np.zeros((2, 5))
ans[1] = 5
ans = F.copy_to(F.tensor(ans, dtype=F.dtype(y)), ctx=F.ctx())
assert F.array_equal(y, ans)
@parametrize_dtype
def test_updates(idtype):
def msg_func(edges):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment