OpenDAS / dgl · Commits · a536772e

Unverified commit a536772e
Authored Aug 18, 2021 by Quan (Andy) Gan; committed by GitHub on Aug 18, 2021
Parent: 2613f7f0

fix cuda 11.1 crashing bug (#3265)

Showing 2 changed files with 26 additions and 8 deletions (+26 / -8):

  src/array/cuda/spmm.cu              +14  -8
  tests/compute/test_heterograph.py   +12  -0
src/array/cuda/spmm.cu

@@ -413,7 +413,7 @@ void CusparseCsrmm2Hetero(
  * \brief Determine whether cusparse SpMM function is applicable.
  */
 template <int bits, typename IdType>
-inline bool cusparse_available() {
+inline bool cusparse_available(bool more_nnz_than_matrix_size) {
 #if CUDART_VERSION < 11000
   if (std::is_same<IdType, int>::value)
     if (bits > 16)
@@ -422,7 +422,8 @@ inline bool cusparse_available() {
 #else
   if (bits == 16)
     return false;  // cusparse's SpMM on fp16 is slow, temporally disabled.
-  return true;
+  // If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
+  return !more_nnz_than_matrix_size;
 #endif
 }
@@ -444,7 +445,9 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
   bool use_efeat = op != "copy_lhs";
   if (reduce == "sum") {
-    if (op == "copy_lhs" && cusparse_available<bits, IdType>()) {  // cusparse
+    bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
+    if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
       int64_t x_length = 1;
       for (int i = 1; i < ufeat->ndim; ++i)
         x_length *= ufeat->shape[i];
@@ -456,7 +459,8 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
             static_cast<DType*>(out->data),
             x_length);
       });
-    } else if (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>()) {  // cusparse
+    } else if (op == "mul" && is_scalar_efeat &&
+        cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
       int64_t x_length = 1;
       for (int i = 1; i < ufeat->ndim; ++i)
         x_length *= ufeat->shape[i];
@@ -524,8 +528,9 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
   bool use_legacy_cusparsemm =
       (CUDART_VERSION < 11000) &&
-      ((op == "copy_lhs" && cusparse_available<bits, IdType>()) ||
-       (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>()));
+      // legacy cuSPARSE does not care about NNZ, hence the argument "false".
+      ((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) ||
+       (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
   // Create temporary output buffer to store non-transposed output
   if (use_legacy_cusparsemm) {
     for (dgl_type_t ntype = 0; ntype < vec_out.size(); ++ntype) {
@@ -568,8 +573,9 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
     const dgl_type_t dst_id = out_ntids[etype];
     CSRMatrix csr = vec_csr[etype];
     if (reduce == "sum") {
+      bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
       /* Call SpMM for each relation type */
-      if (op == "copy_lhs" && cusparse_available<bits, IdType>()) {  // cusparse
+      if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
         /* If CUDA is less than 11.0, put the output in trans_out for later transposition */
         DType* out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
             static_cast<DType*>(vec_out[dst_id]->data);
@@ -580,7 +586,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
             out,
             x_length, thr_entry->stream);
       } else if (op == "mul" && is_scalar_efeat &&
-          cusparse_available<bits, IdType>()) {  // cusparse
+          cusparse_available<bits, IdType>(more_nnz)) {  // cusparse
         NDArray efeat = vec_efeat[etype];
         if (!IsNullArray(csr.data))
           efeat = _IndexSelect<DType, IdType>(vec_efeat[etype], csr.data);
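The guard added above compares the number of stored entries in the CSR (csr.indices->shape[0], i.e. its NNZ) against the dense matrix size num_rows * num_cols; a graph can only exceed that bound when it contains parallel (duplicate) edges. Below is a minimal sketch of the same condition expressed at the graph level, using a hypothetical helper that is not part of DGL's API; on CUDA >= 11.0, a graph for which it returns True now skips the cuSPARSE branch and takes DGL's own CUDA kernel path instead.

import dgl

def more_nnz_than_matrix_size(g):
    # Hypothetical helper: the NNZ of the graph's CSR equals its edge count,
    # while the dense adjacency matrix has num_src_nodes * num_dst_nodes cells.
    return g.num_edges() > g.num_src_nodes() * g.num_dst_nodes()

# Two nodes with five parallel edges from node 0 to node 1: 5 > 2 * 2.
g = dgl.graph(([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]))
print(more_nnz_than_matrix_size(g))  # True -> cuSPARSE SpMM is avoided on CUDA >= 11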
tests/compute/test_heterograph.py

@@ -1521,6 +1521,18 @@ def test_level2(idtype):
     g.nodes['game'].data.clear()
 
 
+@parametrize_dtype
+@unittest.skipIf(F._default_context_str == 'cpu', reason="Need gpu for this test")
+def test_more_nnz(idtype):
+    g = dgl.graph(([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]), idtype=idtype, device=F.ctx())
+    g.ndata['x'] = F.copy_to(F.ones((2, 5)), ctx=F.ctx())
+    g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))
+    y = g.ndata['y']
+    ans = np.zeros((2, 5))
+    ans[1] = 5
+    ans = F.copy_to(F.tensor(ans, dtype=F.dtype(y)), ctx=F.ctx())
+    assert F.array_equal(y, ans)
+
 @parametrize_dtype
 def test_updates(idtype):
     def msg_func(edges):
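The new test_more_nnz test constructs exactly the problematic case: five parallel edges between two nodes, so the CSR stores 5 entries while the 2 x 2 adjacency matrix has only 4 cells. As a rough standalone equivalent, assuming the PyTorch backend and a CUDA build of DGL (the test itself goes through DGL's backend shim F), the scenario looks like this:

import dgl
import dgl.function as fn
import torch

# Five parallel edges between the same two nodes; update_all with copy_u + sum
# dispatches to the CUDA SpMM path that this commit guards.
g = dgl.graph(([0, 0, 0, 0, 0], [1, 1, 1, 1, 1]), device='cuda')
g.ndata['x'] = torch.ones(2, 5, device='cuda')
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'y'))

# Expected result: node 0 receives nothing, node 1 sums five copies of ones.
print(g.ndata['y'])  # row 0 -> zeros, row 1 -> fives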