"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "c07ece5ca490a90b2b19c33ab7da2d21e015d7bd"
Unverified Commit 2d72c11f authored by vasunvidia, committed by GitHub

Add support for fp8 GEMM BIAS AUX GELU fusion (#116)

* Add support for fp8 GEMM BIAS AUX GELU fusion

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix lint error

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix lint error

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent 5992e03d
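
What the fusion computes, sketched in plain PyTorch (an illustration for review, not the fused cuBLASLt kernel): the epilogue applies bias and GELU in one pass and additionally materializes the pre-GELU tensor, the "AUX" output, so the backward pass can reuse it. The function name here is a stand-in, and the tanh-approximate GELU is an assumption matching cuBLASLt's GELU epilogue.

import torch

# Reference semantics of GEMM + BIAS + GELU with AUX output.
# pre_gelu is the extra ("AUX") tensor this commit lets callers request.
def gemm_bias_gelu_reference(A, B, bias):
    pre_gelu = A @ B + bias                                  # AUX output
    out = torch.nn.functional.gelu(pre_gelu, approximate="tanh")
    return out, pre_gelu

out, gelu_input = gemm_bias_gelu_reference(
    torch.randn(16, 32), torch.randn(32, 64), torch.randn(64))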
@@ -77,9 +77,10 @@ void cublas_gemm(const Tensor *inputA,
   // check consistency of arguments:
   // if fp8 is desired, context cannot be null
-  // fp8 + gelu fusion is unavailable right now.
-  if (use_fp8) {
-    NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
+  // fp8 + gelu fusion + fp8 aux is unavailable right now.
+  if (use_fp8 && gelu) {
+    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
+               "fp8 Aux output for gemm + gelu fusion not supported!");
   }
   if (is_fp8_dtype(outputD->data.dtype)) {
     NVTE_CHECK(!accumulate,
@@ -182,6 +183,10 @@ void cublas_gemm(const Tensor *inputA,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                      CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                                                      &ld_gelumat, sizeof(ld_gelumat)));
+    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
+                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE,
+                                                     &aux_type, sizeof(aux_type)));
   } else if (bias) {
     if (grad) {
       // grad output is always input B
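
The new CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE attribute lets the AUX (pre-GELU) tensor be written in the dtype of outputPreGelu rather than the default, while the check added above still rejects fp8 aux. A hedged PyTorch sketch of that dtype rule; write_aux is a hypothetical helper, and the float8 dtype names assume a PyTorch build that has them:

import torch

# Sketch of the aux-dtype rule: the pre-GELU tensor is stored in the
# dtype requested for outputPreGelu, except fp8, which the new
# NVTE_CHECK above rejects.
def write_aux(pre_gelu: torch.Tensor, aux_dtype: torch.dtype) -> torch.Tensor:
    assert aux_dtype not in (torch.float8_e4m3fn, torch.float8_e5m2), \
        "fp8 Aux output for gemm + gelu fusion not supported!"
    return pre_gelu.to(aux_dtype)

aux = write_aux(torch.randn(8, 8), torch.bfloat16)  # bf16/fp16/fp32 pass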
@@ -20,6 +20,7 @@ def fp8_gemm(
     B_dtype: tex.DType,
     out_dtype: torch.dtype,
     workspace: torch.Tensor,
+    gelu: bool = False,
     accumulate: bool = False,
     out: Optional[torch.Tensor] = None,
     out_index = None,
@@ -44,10 +45,15 @@ def fp8_gemm(
             device="cuda",
         )
         return_output = True
+    # Use bfloat16 as default bias_dtype
+    bias_dtype = torch.bfloat16 if bias is None else bias.dtype
+    if gelu:
+        gelu_input = torch.empty_like(out, dtype=bias_dtype)
+    else:
+        gelu_input = empty_tensor
+    bias_dtype = TE_DType[bias_dtype]
     out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
-    # Use bfloat16 as default bias_dtype
-    bias_dtype = tex.DType.kBFloat16 if bias is None else TE_DType[bias.dtype]
     _ = torch.ops.tex_ts.te_gemm_ts(
         A,
@@ -66,7 +72,7 @@ def fp8_gemm(
         empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index],
         bias if use_bias else empty_tensor,
         bias_dtype,
-        empty_tensor, # this is pre_gelu_out
+        gelu_input, # this is pre_gelu_out
         False, # grad
         workspace,
         workspace.shape[0],
@@ -75,7 +81,11 @@ def fp8_gemm(
     )
     if return_output:
+        if gelu:
+            return out, gelu_input
         return out
+    if gelu:
+        return gelu_input
     return None
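
Caller-side view of the return contract implemented above, as a small runnable sketch; returns is a hypothetical stand-in mirroring the branching in the diff, and the tensors are placeholders:

import torch

# Hypothetical stand-in for fp8_gemm's return branches: return_output
# is True when fp8_gemm allocated `out` itself.
def returns(return_output, gelu, out, gelu_input):
    if return_output:
        if gelu:
            return out, gelu_input
        return out
    if gelu:
        return gelu_input
    return None

out, gelu_input = torch.zeros(2, 2), torch.zeros(2, 2)
res = returns(True, True, out, gelu_input)
assert res[0] is out and res[1] is gelu_input          # gelu=True, out allocated inside
assert returns(False, True, out, gelu_input) is gelu_input  # caller supplied out=
assert returns(False, False, out, gelu_input) is None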