Commit 4c0c31dd authored by zhouxiang's avatar zhouxiang
Browse files

支持24.04dtk编译

parent 28569504
...@@ -802,10 +802,10 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand ...@@ -802,10 +802,10 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
cublasLtMatrixLayout_t Ddesc, cublasLtMatrixLayout_t Ddesc,
cudaStream_t stream) cudaStream_t stream)
{ {
#if (CUBLAS_VERSION) <= 11601 //#if (CUBLAS_VERSION) <= 11601
FT_CHECK_WITH_INFO(false, "CUBLAS version too low."); FT_CHECK_WITH_INFO(false, "CUBLAS version too low.");
return {false, cublasLtMatmulAlgo_t{}}; return {false, cublasLtMatmulAlgo_t{}};
#else /*#else
size_t returnSize; size_t returnSize;
int32_t pointer_mode; int32_t pointer_mode;
cublasLtMatmulDescGetAttribute( cublasLtMatmulDescGetAttribute(
...@@ -893,7 +893,7 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand ...@@ -893,7 +893,7 @@ std::pair<bool, cublasLtMatmulAlgo_t> cublasMMWrapper::findBestAlgo(cublasLtHand
} }
return {best_time != INFINITY, result.algo}; return {best_time != INFINITY, result.algo};
#endif #endif*/
} }
cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc) cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrixLayout_t Mdesc)
...@@ -901,6 +901,9 @@ cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrix ...@@ -901,6 +901,9 @@ cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrix
size_t returnSize; size_t returnSize;
MatrixLayout m_layout; MatrixLayout m_layout;
FT_CHECK_WITH_INFO(false, "cublasLtMatrixLayoutGetAttribute is not support.");
/*
cublasLtMatrixLayoutGetAttribute( cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &std::get<0>(m_layout), sizeof(std::get<0>(m_layout)), &returnSize); Mdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &std::get<0>(m_layout), sizeof(std::get<0>(m_layout)), &returnSize);
cublasLtMatrixLayoutGetAttribute( cublasLtMatrixLayoutGetAttribute(
...@@ -909,10 +912,10 @@ cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrix ...@@ -909,10 +912,10 @@ cublasMMWrapper::MatrixLayout cublasMMWrapper::createMatrixLayout(cublasLtMatrix
Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, &std::get<2>(m_layout), sizeof(std::get<2>(m_layout)), &returnSize); Mdesc, CUBLASLT_MATRIX_LAYOUT_ROWS, &std::get<2>(m_layout), sizeof(std::get<2>(m_layout)), &returnSize);
cublasLtMatrixLayoutGetAttribute( cublasLtMatrixLayoutGetAttribute(
Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, &std::get<3>(m_layout), sizeof(std::get<3>(m_layout)), &returnSize); Mdesc, CUBLASLT_MATRIX_LAYOUT_COLS, &std::get<3>(m_layout), sizeof(std::get<3>(m_layout)), &returnSize);
*/
return m_layout; return m_layout;
} }
/*
cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t lightHandle, cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatmulDesc_t computeDesc,
const void* alpha, const void* alpha,
...@@ -969,7 +972,7 @@ cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t ...@@ -969,7 +972,7 @@ cublasStatus_t cublasMMWrapper::cublasLtMatmulWrapper(cublasLtHandle_t
workspaceSizeInBytes, workspaceSizeInBytes,
stream); stream);
} }
*/
void cublasMMWrapper::_Int8Gemm(const int m, void cublasMMWrapper::_Int8Gemm(const int m,
const int n, const int n,
const int k, const int k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment