Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
88964a82
Unverified
Commit
88964a82
authored
Aug 10, 2023
by
Chang Liu
Committed by
GitHub
Aug 10, 2023
Browse files
[Bugfix] Fix cusparseCreateCsr format for cuda12 (#6121)
parent
1e16e4ca
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
57 deletions
+39
-57
src/array/cuda/csr_mm.cu
src/array/cuda/csr_mm.cu
+39
-57
No files found.
src/array/cuda/csr_mm.cu
View file @
88964a82
...
...
@@ -57,17 +57,17 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
&
matA
,
A
.
num_rows
,
A
.
num_cols
,
nnzA
,
A
.
indptr
.
Ptr
<
IdType
>
(),
A
.
indices
.
Ptr
<
IdType
>
(),
// cusparseCreateCsr only accepts non-const pointers.
const_cast
<
DType
*>
(
A_weights
),
idtype
,
idtype
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
const_cast
<
DType
*>
(
A_weights
),
idtype
,
idtype
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
CUSPARSE_CALL
(
cusparseCreateCsr
(
&
matB
,
B
.
num_rows
,
B
.
num_cols
,
nnzB
,
B
.
indptr
.
Ptr
<
IdType
>
(),
B
.
indices
.
Ptr
<
IdType
>
(),
// cusparseCreateCsr only accepts non-const pointers.
const_cast
<
DType
*>
(
B_weights
),
idtype
,
idtype
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
const_cast
<
DType
*>
(
B_weights
),
idtype
,
idtype
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
CUSPARSE_CALL
(
cusparseCreateCsr
(
&
matC
,
A
.
num_rows
,
B
.
num_cols
,
0
,
nullptr
,
nullptr
,
nullptr
,
idtype
,
idtype
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
&
matC
,
A
.
num_rows
,
B
.
num_cols
,
0
,
dC_csrOffsets_data
,
nullptr
,
nullptr
,
idtype
,
idtype
,
CUSPARSE_INDEX_BASE_ZERO
,
dtype
));
// SpGEMM Computation
cusparseSpGEMMDescr_t
spgemmDesc
;
cusparseSpGEMMAlg_t
alg
=
CUSPARSE_SPGEMM_DEFAULT
;
...
...
@@ -77,15 +77,12 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
// ask bufferSize1 bytes for external memory
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
NULL
));
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
NULL
));
void
*
workspace1
=
(
device
->
AllocWorkspace
(
ctx
,
workspace_size1
));
// inspect the matrices A and B to understand the memory requirement
cusparseStatus_t
e
=
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
workspace1
);
cusparseStatus_t
e
=
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
workspace1
);
// CUSPARSE_SPGEMM_DEFAULT does not support getting num_prods > 2^31 - 1
// and throws an insufficient memory error within the workEstimation call
if
(
e
==
CUSPARSE_STATUS_INSUFFICIENT_RESOURCES
)
{
...
...
@@ -93,17 +90,13 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
alg
=
CUSPARSE_SPGEMM_ALG2
;
device
->
FreeWorkspace
(
ctx
,
workspace1
);
// rerun cusparseSpGEMM_workEstimation
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
NULL
));
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
NULL
));
workspace1
=
(
device
->
AllocWorkspace
(
ctx
,
workspace_size1
));
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
workspace1
));
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
workspace1
));
}
else
{
CHECK
(
e
==
CUSPARSE_STATUS_SUCCESS
)
<<
"CUSPARSE ERROR in SpGEMM: "
<<
e
;
}
...
...
@@ -122,22 +115,18 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
// switch to ALG2/ALG3 for medium & large problem size
if
(
alg
==
CUSPARSE_SPGEMM_DEFAULT
&&
num_prods
>
MEDIUM_NUM_PRODUCTS
)
{
// use ALG3 for very large problem
alg
=
num_prods
>
LARGE_NUM_PRODUCTS
?
CUSPARSE_SPGEMM_ALG3
:
CUSPARSE_SPGEMM_ALG2
;
alg
=
num_prods
>
LARGE_NUM_PRODUCTS
?
CUSPARSE_SPGEMM_ALG3
:
CUSPARSE_SPGEMM_ALG2
;
device
->
FreeWorkspace
(
ctx
,
workspace1
);
// rerun cusparseSpGEMM_workEstimation
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
NULL
));
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
NULL
));
workspace1
=
(
device
->
AllocWorkspace
(
ctx
,
workspace_size1
));
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
workspace1
));
CUSPARSE_CALL
(
cusparseSpGEMM_workEstimation
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size1
,
workspace1
));
}
else
if
(
alg
==
CUSPARSE_SPGEMM_ALG2
&&
num_prods
>
LARGE_NUM_PRODUCTS
)
{
// no need to rerun cusparseSpGEMM_workEstimation between ALG2 and ALG3
alg
=
CUSPARSE_SPGEMM_ALG3
;
...
...
@@ -147,41 +136,34 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
// estimate memory for ALG2/ALG3; note chunk_fraction is only used by ALG3
// reduce chunk_fraction if it crashes due to insufficient memory, but this trades off speed
float
chunk_fraction
=
num_prods
<
4
*
LARGE_NUM_PRODUCTS
?
0.15
:
0.05
;
CUSPARSE_CALL
(
cusparseSpGEMM_estimateMemory
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
chunk_fraction
,
&
workspace_size3
,
NULL
,
NULL
));
CUSPARSE_CALL
(
cusparseSpGEMM_estimateMemory
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
chunk_fraction
,
&
workspace_size3
,
NULL
,
NULL
));
void
*
workspace3
=
(
device
->
AllocWorkspace
(
ctx
,
workspace_size3
));
CUSPARSE_CALL
(
cusparseSpGEMM_estimateMemory
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
chunk_fraction
,
&
workspace_size3
,
CUSPARSE_CALL
(
cusparseSpGEMM_estimateMemory
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
chunk_fraction
,
&
workspace_size3
,
workspace3
,
&
workspace_size2
));
device
->
FreeWorkspace
(
ctx
,
workspace3
);
}
else
{
CUSPARSE_CALL
(
cusparseSpGEMM_compute
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size2
,
NULL
));
CUSPARSE_CALL
(
cusparseSpGEMM_compute
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size2
,
NULL
));
}
// ask bufferSize2 bytes for external memory
void
*
workspace2
=
device
->
AllocWorkspace
(
ctx
,
workspace_size2
);
// compute the intermediate product of A * B
CUSPARSE_CALL
(
cusparseSpGEMM_compute
(
thr_entry
->
cusparse_handle
,
transA
,
transB
,
&
alpha
,
matA
,
matB
,
&
beta
,
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size2
,
workspace2
));
matC
,
dtype
,
alg
,
spgemmDesc
,
&
workspace_size2
,
workspace2
));
// get matrix C non-zero entries C_nnz1
int64_t
C_num_rows1
,
C_num_cols1
,
C_nnz1
;
CUSPARSE_CALL
(
cusparseSpMatGetSize
(
matC
,
&
C_num_rows1
,
&
C_num_cols1
,
&
C_nnz1
));
IdArray
dC_columns
=
IdArray
::
Empty
({
C_nnz1
},
A
.
indptr
->
dtype
,
A
.
indptr
->
ctx
);
NDArray
dC_weights
=
NDArray
::
Empty
(
{
C_nnz1
},
A_weights_array
->
dtype
,
A
.
indptr
->
ctx
);
NDArray
dC_weights
=
NDArray
::
Empty
(
{
C_nnz1
},
A_weights_array
->
dtype
,
A
.
indptr
->
ctx
);
IdType
*
dC_columns_data
=
dC_columns
.
Ptr
<
IdType
>
();
DType
*
dC_weights_data
=
dC_weights
.
Ptr
<
DType
>
();
// update matC with the new pointers
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment