Unverified commit a536772e, authored by Quan (Andy) Gan, committed by GitHub

fix cuda 11.1 crashing bug (#3265)

parent 2613f7f0
@@ -413,7 +413,7 @@ void CusparseCsrmm2Hetero(
  * \brief Determine whether cusparse SpMM function is applicable.
  */
 template <int bits, typename IdType>
-inline bool cusparse_available() {
+inline bool cusparse_available(bool more_nnz_than_matrix_size) {
 #if CUDART_VERSION < 11000
   if (std::is_same<IdType, int>::value)
     if (bits > 16)
@@ -422,7 +422,8 @@ inline bool cusparse_available() {
 #else
   if (bits == 16)
     return false;  // cusparse's SpMM on fp16 is slow, temporally disabled.
-  return true;
+  // If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
+  return !more_nnz_than_matrix_size;
 #endif
 }
@@ -444,7 +445,9 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
   bool use_efeat = op != "copy_lhs";
   if (reduce == "sum") {
-    if (op == "copy_lhs" && cusparse_available<bits, IdType>()) {  // cusparse
+    bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
+    if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {
+      // cusparse
       int64_t x_length = 1;
       for (int i = 1; i < ufeat->ndim; ++i)
         x_length *= ufeat->shape[i];
@@ -456,7 +459,8 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
             static_cast<DType*>(out->data),
             x_length);
       });
-    } else if (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>()) {  // cusparse
+    } else if (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(more_nnz)) {
+      // cusparse
       int64_t x_length = 1;
       for (int i = 1; i < ufeat->ndim; ++i)
         x_length *= ufeat->shape[i];
@@ -524,8 +528,9 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
   bool use_legacy_cusparsemm =
       (CUDART_VERSION < 11000) &&
-      ((op == "copy_lhs" && cusparse_available<bits, IdType>()) ||
-       (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>()));
+      // legacy cuSPARSE does not care about NNZ, hence the argument "false".
+      ((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) ||
+       (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
   // Create temporary output buffer to store non-transposed output
   if (use_legacy_cusparsemm) {
     for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
@@ -568,8 +573,9 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
     const dgl_type_t dst_id = out_ntids[etype];
     CSRMatrix csr = vec_csr[etype];
     if (reduce == "sum") {
+      bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
       /* Call SpMM for each relation type */
-      if (op == "copy_lhs" && cusparse_available<bits, IdType>()) {  // cusparse
+      if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
         /* If CUDA is less than 11.0, put the output in trans_out for later transposition */
         DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
             static_cast<DType*>(vec_out[dst_id]->data);
@@ -580,7 +586,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
               out,
               x_length, thr_entry->stream);
       } else if (op == "mul" && is_scalar_efeat &&
-                 cusparse_available<bits, IdType>()) {  // cusparse
+                 cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
         NDArray efeat = vec_efeat[etype];
         if (!IsNullArray(csr.data))
           efeat = _IndexSelect<DType, IdType>(vec_efeat[etype], csr.data);
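For reference, the new guard trips exactly when the graph has more edges (CSR non-zeros) than entries in its dense adjacency, which happens with parallel edges between a small set of nodes. The following is a minimal sketch, not part of the patch, written against DGL's public Python API with PyTorch (it assumes a CUDA build of both); it reproduces that situation and the sum aggregation the fix keeps on the non-cuSPARSE path.

# Minimal sketch (illustrative only): a multigraph whose CSR has more non-zeros
# than num_rows * num_cols, the case the new `more_nnz_than_matrix_size` guard
# routes away from cuSPARSE 11.1.
import dgl
import dgl.function as fn
import torch

# Two nodes joined by five parallel edges: nnz = 5 > 2 * 2 = num_rows * num_cols.
g = dgl.graph(([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]), device='cuda')
assert g.num_edges() > g.num_nodes() * g.num_nodes()

g.ndata['x'] = torch.ones(2, 5, device='cuda')
# copy_u + sum triggers SpMMCsr; with the fix it falls back to DGL's own kernel.
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
print(g.ndata['y'])  # row 0 stays zero, row 1 sums to 5.0 in every column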
@@ -1521,6 +1521,18 @@ def test_level2(idtype):
     g.nodes['game'].data.clear()
 
 
+@parametrize_dtype
+@unittest.skipIf(F._default_context_str == 'cpu', reason="Need gpu for this test")
+def test_more_nnz(idtype):
+    g = dgl.graph(([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]), idtype=idtype, device=F.ctx())
+    g.ndata['x'] = F.copy_to(F.ones((2, 5)), ctx=F.ctx())
+    g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
+    y = g.ndata['y']
+    ans = np.zeros((2, 5))
+    ans[1] = 5
+    ans = F.copy_to(F.tensor(ans, dtype=F.dtype(y)), ctx=F.ctx())
+    assert F.array_equal(y, ans)
+
 @parametrize_dtype
 def test_updates(idtype):
     def msg_func(edges):
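The added test covers the same corner case on GPU: five parallel edges between two nodes give nnz = 5 > 4 = num_rows * num_cols, so the sum reduction must take the non-cuSPARSE path, and summing five copies of the all-ones feature should leave node 0 at zero and put 5 in every column of node 1, which is what the assertion checks.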