Commit d81f8119 authored by wenjh
Browse files

Adapt to changes of hipblaslt


Signed-off-by: wenjh <wenjh@sugon.com>
parent 3f800f01
......@@ -1076,17 +1076,6 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
hipblasLtMatmulPreference_t preference = nullptr;
hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
hipblasLtMatmulFlags_t matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
if (tensorwise_int8) {
if (D_type == HIP_R_16BF) {
matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
} else if (D_type == HIP_R_32F) {
matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_FP32;
} else {
NVTE_CHECK(false, "tensorwise_int8 only surpport D_type bf16 or fp32!");
}
}
int64_t ld_gelumat = (int64_t)ldd;
// default to tf32 except for e5m2 inputs where the config is not supported
......@@ -1099,11 +1088,7 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
transb == HIPBLAS_OP_N ? n : k, ldb));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
if (tensorwise_int8) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F, matmul_flag));
} else {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
......@@ -1450,16 +1435,6 @@ void hipblaslt_batchgemm_tensorwise_int8(const Tensor *inputA,
hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
hipblasLtMatmulPreference_t preference = nullptr;
hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
hipblasLtMatmulFlags_t matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
if (tensorwise_int8) {
if (D_type == HIP_R_16BF) {
matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
} else if (D_type == HIP_R_32F) {
matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_FP32;
} else {
NVTE_CHECK(false, "tensorwise_int8 only surpport D_type bf16 or fp32!");
}
}
int64_t ld_gelumat = (int64_t) ldd;
......@@ -1491,7 +1466,7 @@ void hipblaslt_batchgemm_tensorwise_int8(const Tensor *inputA,
hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t));
hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(int64_t));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F, matmul_flag));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
} else {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment