Commit 712d526a authored by yuguo
Browse files

Merge branch 'develop_v2.8' of...

Merge branch 'develop_v2.8' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine into develop_v2.8
parents 47077129 a26a0c30
...@@ -1053,7 +1053,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, ...@@ -1053,7 +1053,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
NVTE_ERROR("MOE nvte_grouped_gemm not surpport bias or gelu."); NVTE_ERROR("MOE nvte_grouped_gemm not surpport bias or gelu.");
} }
hipblaslt_goupedgemm(inputA, inputB, outputD, m, n, k, b, hipblaslt_groupedgemm(inputA, inputB, outputD, m, n, k, b,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
wspace->data.dptr, wspace->data.shape[0], wspace->data.dptr, wspace->data.shape[0],
......
...@@ -1429,7 +1429,7 @@ class d_userArgsManager { ...@@ -1429,7 +1429,7 @@ class d_userArgsManager {
static userArgsManager UAManager; static userArgsManager UAManager;
static d_userArgsManager d_UAManager; static d_userArgsManager d_UAManager;
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
std::vector<Tensor*>& outputD, std::vector<int64_t>& m, std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b, std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
hipblasOperation_t transa, hipblasOperation_t transb, void* workspace, hipblasOperation_t transa, hipblasOperation_t transb, void* workspace,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment