"cacheflow/vscode:/vscode.git/clone" did not exist on "3363c27d1954fd63741020dbf5e48bc9b9735cd7"
Unverified Commit 2d72c11f authored by vasunvidia, committed by GitHub

Add support for fp8 GEMM BIAS AUX GELU fusion (#116)



* Add support for fp8 GEMM BIAS AUX GELU fusion
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix Lint error
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix Lint error
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent 5992e03d
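
For context, the fused epilogue computes GELU(A @ B + bias) in a single GEMM pass and additionally writes the pre-GELU intermediate to an auxiliary buffer, so the backward pass can apply dGELU to it instead of recomputing the GEMM. A minimal unfused PyTorch reference of the same computation (a sketch only, not the TE kernel; assumes the tanh GELU approximation used by cuBLASLt's GELU epilogue):

import torch

def gemm_bias_gelu_reference(a, b, bias):
    # pre_gelu corresponds to the "aux" buffer (pre_gelu_out below);
    # the fused cuBLASLt kernel produces both tensors in one pass.
    pre_gelu = a @ b + bias
    out = torch.nn.functional.gelu(pre_gelu, approximate="tanh")
    return out, pre_gelu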
@@ -77,9 +77,10 @@ void cublas_gemm(const Tensor *inputA,
   // check consistency of arguments:
   // if fp8 is desired, context cannot be null
-  // fp8 + gelu fusion is unavailable right now.
-  if (use_fp8) {
-    NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
+  // fp8 + gelu fusion + fp8 aux is unavailable right now.
+  if (use_fp8 && gelu) {
+    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
+               "fp8 Aux output for gemm + gelu fusion not supported!");
   }
   if (is_fp8_dtype(outputD->data.dtype)) {
     NVTE_CHECK(!accumulate,

@@ -182,6 +183,10 @@ void cublas_gemm(const Tensor *inputA,
     NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
                                                      CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
                                                      &ld_gelumat, sizeof(ld_gelumat)));
+    const cudaDataType_t aux_type = get_cuda_dtype(outputPreGelu->data.dtype);
+    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
+                                                     CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE,
+                                                     &aux_type, sizeof(aux_type)));
   } else if (bias) {
     if (grad) {
       // grad output is always input B
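
The hunks above replace the old blanket rejection of fp8 + GELU with a narrower guard: the fusion is now permitted so long as the aux (pre-GELU) output is not itself fp8, and the aux data type is passed to cuBLASLt explicitly via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE rather than left implicit. A hedged Python mirror of the new guard (hypothetical helper, for readability only):

def check_gelu_aux_dtype(use_fp8: bool, gelu: bool, aux_is_fp8: bool) -> None:
    # Mirrors the NVTE_CHECK above: fp8 GEMM + GELU fusion is allowed,
    # but the pre-GELU aux buffer must use a non-fp8 dtype.
    if use_fp8 and gelu and aux_is_fp8:
        raise ValueError("fp8 Aux output for gemm + gelu fusion not supported!")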
@@ -20,6 +20,7 @@ def fp8_gemm(
     B_dtype: tex.DType,
     out_dtype: torch.dtype,
     workspace: torch.Tensor,
+    gelu: bool = False,
     accumulate: bool = False,
     out: Optional[torch.Tensor] = None,
     out_index = None,

@@ -44,10 +45,15 @@ def fp8_gemm(
             device="cuda",
         )
         return_output = True
+    # Use bfloat16 as default bias_dtype
+    bias_dtype = torch.bfloat16 if bias is None else bias.dtype
+    if gelu:
+        gelu_input = torch.empty_like(out, dtype=bias_dtype)
+    else:
+        gelu_input = empty_tensor
+    bias_dtype = TE_DType[bias_dtype]

     out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
-    # Use bfloat16 as default bias_dtype
-    bias_dtype = tex.DType.kBFloat16 if bias is None else TE_DType[bias.dtype]

     _ = torch.ops.tex_ts.te_gemm_ts(
         A,

@@ -66,7 +72,7 @@ def fp8_gemm(
         empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index],
         bias if use_bias else empty_tensor,
         bias_dtype,
-        empty_tensor, # this is pre_gelu_out
+        gelu_input, # this is pre_gelu_out
         False, # grad
         workspace,
         workspace.shape[0],

@@ -75,7 +81,11 @@ def fp8_gemm(
     )

     if return_output:
+        if gelu:
+            return out, gelu_input
         return out
+    if gelu:
+        return gelu_input
     return None
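
On the Python side, the new gelu flag changes what fp8_gemm returns: the pre-GELU tensor is handed back alongside (or instead of) the output, depending on whether fp8_gemm allocated out itself. A small standalone sketch of the four return cases (hypothetical helper mirroring the branches in the diff):

def _gemm_returns(out, gelu_input, return_output: bool, gelu: bool):
    # return_output is True when fp8_gemm allocated `out` internally;
    # with gelu=True the pre-GELU tensor is also surfaced for backward.
    if return_output:
        if gelu:
            return out, gelu_input
        return out
    if gelu:
        return gelu_input
    return None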