Unverified commit 14198f20, authored by vasunvidia, committed by GitHub
Browse files

Increase number of FP8 tensors per GEMM (#22)



* Increase number of FP8 tensors per GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable FP8 output tensor for fp8_gemm
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [BERT FP8] Initial TE review comments
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Temporary fix for cuda graph non convergence
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Address review comments-2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Review comments-3
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Change for New API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Remove unnecessary clone for D_scale, D_amax
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Avoid Roll for AMAX history size = 1
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Update onnx_te_gemm API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix Lint errors
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent 58f19082
...@@ -55,7 +55,10 @@ void cublas_gemm(const Tensor *inputA, ...@@ -55,7 +55,10 @@ void cublas_gemm(const Tensor *inputA,
void *A_scale_inverse = inputA->scale_inv.dptr; void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr; void *B = inputB->data.dptr;
void *B_scale_inverse = inputB->scale_inv.dptr; void *B_scale_inverse = inputB->scale_inv.dptr;
void *C = outputD->data.dptr;
void *D = outputD->data.dptr; void *D = outputD->data.dptr;
void *D_scale = outputD->scale.dptr;
void *D_amax = outputD->amax.dptr;
void *bias_ptr = inputBias->data.dptr; void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr; const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr; void *pre_gelu_out = outputPreGelu->data.dptr;
...@@ -78,6 +81,10 @@ void cublas_gemm(const Tensor *inputA, ...@@ -78,6 +81,10 @@ void cublas_gemm(const Tensor *inputA,
if (use_fp8) { if (use_fp8) {
NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!"); NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
} }
if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate,
"Accumulation mode not supported with FP8 GEMM output!");
}
float one = 1.0; float one = 1.0;
float zero = 0.0; float zero = 0.0;
...@@ -87,7 +94,7 @@ void cublas_gemm(const Tensor *inputA, ...@@ -87,7 +94,7 @@ void cublas_gemm(const Tensor *inputA,
NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); NVTE_CHECK_CUBLAS(cublasLtCreate(&handle));
cublasLtMatmulDesc_t operationDesc = nullptr; cublasLtMatmulDesc_t operationDesc = nullptr;
cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Ddesc = nullptr; cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
cublasLtMatmulPreference_t preference = nullptr; cublasLtMatmulPreference_t preference = nullptr;
int returnedResults = 0; int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {}; cublasLtMatmulHeuristicResult_t heuristicResult = {};
...@@ -135,11 +142,29 @@ void cublas_gemm(const Tensor *inputA, ...@@ -135,11 +142,29 @@ void cublas_gemm(const Tensor *inputA,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse, &B_scale_inverse,
sizeof(B_scale_inverse))); sizeof(B_scale_inverse)));
if (is_fp8_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8 output
C = nullptr;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
&D_scale,
sizeof(D_scale)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_AMAX_D_POINTER,
&D_amax,
sizeof(D_amax)));
// For FP8 output, cuBLAS requires C_type to be same as bias_type
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd));
} else {
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
}
if (bias) { if (bias) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type))); &bias_type, sizeof(bias_type)));
} }
} else {
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
} }
if (bias && gelu) { if (bias && gelu) {
...@@ -190,7 +215,7 @@ void cublas_gemm(const Tensor *inputA, ...@@ -190,7 +215,7 @@ void cublas_gemm(const Tensor *inputA,
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize, sizeof(workspaceSize))); &workspaceSize, sizeof(workspaceSize)));
NVTE_CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc, NVTE_CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc,
Ddesc, preference, 1, &heuristicResult, Ddesc, preference, 1, &heuristicResult,
&returnedResults)); &returnedResults));
...@@ -205,8 +230,8 @@ void cublas_gemm(const Tensor *inputA, ...@@ -205,8 +230,8 @@ void cublas_gemm(const Tensor *inputA,
B, /* B */ B, /* B */
Bdesc, Bdesc,
static_cast<const void*>(&beta), /* beta */ static_cast<const void*>(&beta), /* beta */
D, /* C */ C, /* C */
Ddesc, Cdesc,
D, /* D */ D, /* D */
Ddesc, Ddesc,
&heuristicResult.algo, /* algo */ &heuristicResult.algo, /* algo */
...@@ -217,6 +242,7 @@ void cublas_gemm(const Tensor *inputA, ...@@ -217,6 +242,7 @@ void cublas_gemm(const Tensor *inputA,
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
......
...@@ -22,14 +22,19 @@ def fp8_gemm( ...@@ -22,14 +22,19 @@ def fp8_gemm(
workspace: torch.Tensor, workspace: torch.Tensor,
accumulate: bool = False, accumulate: bool = False,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
out_index = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
use_bias: bool = False, use_bias: bool = False,
fp32_output: bool = False, fp32_output: bool = False,
use_split_accumulator: bool = False, use_split_accumulator: bool = False,
D_dtype: tex.DType = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""TN layout GEMM with fp8 inputs.""" """TN layout GEMM with fp8 inputs."""
empty_tensor = torch.Tensor() empty_tensor = torch.Tensor()
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_index is not None
return_output = False return_output = False
if out is None: if out is None:
...@@ -42,6 +47,9 @@ def fp8_gemm( ...@@ -42,6 +47,9 @@ def fp8_gemm(
return_output = True return_output = True
out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype] out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]
# Use bfloat16 as default bias_dtype
bias_dtype = tex.DType.kBFloat16 if bias is None else TE_DType[bias.dtype]
out_dtype = D_dtype if D_dtype is not None else out_dtype
_ = torch.ops.tex_ts.te_gemm_ts( _ = torch.ops.tex_ts.te_gemm_ts(
A, A,
...@@ -55,8 +63,11 @@ def fp8_gemm( ...@@ -55,8 +63,11 @@ def fp8_gemm(
B_dtype, B_dtype,
False, # transb False, # transb
out, out,
empty_tensor if out_index is None else fp8_meta_tensor.scale[out_index],
out_dtype, out_dtype,
empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index],
bias if use_bias else empty_tensor, bias if use_bias else empty_tensor,
bias_dtype,
empty_tensor, # this is pre_gelu_out empty_tensor, # this is pre_gelu_out
False, # grad False, # grad
workspace, workspace,
...@@ -95,6 +106,7 @@ def gemm( ...@@ -95,6 +106,7 @@ def gemm(
input_dtype = TE_DType[dtype] input_dtype = TE_DType[dtype]
output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype
bias_dtype = output_dtype if bias is None else TE_DType[bias.dtype]
return_output = False return_output = False
if out is None: if out is None:
...@@ -132,8 +144,11 @@ def gemm( ...@@ -132,8 +144,11 @@ def gemm(
input_dtype, input_dtype,
transb, transb,
out, out,
empty_tensor, # out_scale
output_dtype, output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias, grad_bias if grad else bias,
bias_dtype,
gelu_input, gelu_input,
grad, grad,
workspace, workspace,
......
...@@ -48,15 +48,19 @@ class FP8TensorMeta { ...@@ -48,15 +48,19 @@ class FP8TensorMeta {
enum FP8FwdTensors { enum FP8FwdTensors {
GEMM1_INPUT = 0, GEMM1_INPUT = 0,
GEMM1_WEIGHT = 1, GEMM1_WEIGHT = 1,
GEMM2_INPUT = 2, GEMM1_OUTPUT = 2,
GEMM2_WEIGHT = 3 GEMM2_INPUT = 3,
GEMM2_WEIGHT = 4,
GEMM2_OUTPUT = 5
}; };
// Used as named indices on the `scale`, `scale_inv`, // Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class. // and `amax` tensors in the `FP8TensorMeta` class.
enum FP8BwdTensors { enum FP8BwdTensors {
GRAD_OUTPUT1 = 0, GRAD_OUTPUT1 = 0,
GRAD_OUTPUT2 = 1 GRAD_INPUT1 = 1,
GRAD_OUTPUT2 = 2,
GRAD_INPUT2 = 3
}; };
......
...@@ -16,8 +16,11 @@ void te_gemm(at::Tensor A, ...@@ -16,8 +16,11 @@ void te_gemm(at::Tensor A,
transformer_engine::DType B_type, transformer_engine::DType B_type,
bool transb, bool transb,
at::Tensor D, at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type, transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias, at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, at::Tensor pre_gelu_out,
bool grad, bool grad,
at::Tensor workspace, at::Tensor workspace,
...@@ -39,9 +42,10 @@ void te_gemm(at::Tensor A, ...@@ -39,9 +42,10 @@ void te_gemm(at::Tensor A,
auto te_D = makeTransformerEngineTensor(D.data_ptr(), auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)), {static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))}, static_cast<size_t>(D.size(1))},
D_type); D_type, D_amax.data_ptr(),
D_scale.data_ptr(), nullptr);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))}, auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
GetTransformerEngineDType(bias.scalar_type())); bias_type);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))} ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
...@@ -869,10 +873,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -869,10 +873,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors") py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors")
.value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
.value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
.value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT)
.value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
.value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT); .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT)
.value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT);
py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors") py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors")
.value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2); .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2)
.value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2);
} }
...@@ -16,8 +16,11 @@ void te_gemm(at::Tensor A, ...@@ -16,8 +16,11 @@ void te_gemm(at::Tensor A,
transformer_engine::DType B_type, transformer_engine::DType B_type,
bool transb, bool transb,
at::Tensor D, at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type, transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias, at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, at::Tensor pre_gelu_out,
bool grad, bool grad,
at::Tensor workspace, at::Tensor workspace,
......
...@@ -73,8 +73,11 @@ at::Tensor te_gemm_ts(at::Tensor A, ...@@ -73,8 +73,11 @@ at::Tensor te_gemm_ts(at::Tensor A,
int64_t B_type, int64_t B_type,
int64_t transb, int64_t transb,
at::Tensor D, at::Tensor D,
at::Tensor D_scale,
int64_t D_type, int64_t D_type,
at::Tensor D_amax,
at::Tensor bias, at::Tensor bias,
int64_t bias_type,
at::Tensor pre_gelu_out, at::Tensor pre_gelu_out,
int64_t grad, int64_t grad,
at::Tensor workspace, at::Tensor workspace,
...@@ -87,6 +90,7 @@ at::Tensor te_gemm_ts(at::Tensor A, ...@@ -87,6 +90,7 @@ at::Tensor te_gemm_ts(at::Tensor A,
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb); bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad); bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize); size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate); bool accumulate_arg = static_cast<bool>(accumulate);
...@@ -107,8 +111,11 @@ at::Tensor te_gemm_ts(at::Tensor A, ...@@ -107,8 +111,11 @@ at::Tensor te_gemm_ts(at::Tensor A,
B_type_arg, B_type_arg,
transb_arg, transb_arg,
D, D,
D_scale,
D_type_arg, D_type_arg,
D_amax,
bias, bias,
bias_type_arg,
pre_gelu_out, pre_gelu_out,
grad_arg, grad_arg,
workspace, workspace,
......
...@@ -299,6 +299,7 @@ def get_fp8_group() -> Union[dist_group_type, None]: ...@@ -299,6 +299,7 @@ def get_fp8_group() -> Union[dist_group_type, None]:
def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
"""Update amax history and set next amax to zero.""" """Update amax history and set next amax to zero."""
if amax_history.shape[0] > 1:
amax_history = torch.roll(amax_history, -1, 0) amax_history = torch.roll(amax_history, -1, 0)
amax_history[0].fill_(0.0) amax_history[0].fill_(0.0)
return amax_history return amax_history
......
...@@ -158,8 +158,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -158,8 +158,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def set_meta_tensor(self, fwd: bool) -> None: def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
num_fp8_tensors = ( num_fp8_tensors = (
self.fp8_meta["num_gemms"] * 2 if fwd else self.fp8_meta["num_gemms"] self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2
) )
self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
......
...@@ -99,7 +99,7 @@ def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): ...@@ -99,7 +99,7 @@ def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
@symbolic_helper.parse_args("v", "fs", "i", "i", "i", @symbolic_helper.parse_args("v", "fs", "i", "i", "i",
"v", "fs", "i", "i", "i", "v", "fs", "i", "i", "i",
"v", "i", "v", "v", "i", "v", "fs", "i", "fs", "v", "i", "v", "i",
"v", "i", "i", "i") "v", "i", "i", "i")
def onnx_te_gemm( def onnx_te_gemm(
g, g,
...@@ -114,8 +114,11 @@ def onnx_te_gemm( ...@@ -114,8 +114,11 @@ def onnx_te_gemm(
input_type, input_type,
trans_input, trans_input,
out, out,
out_scale,
out_type, out_type,
out_amax,
bias, bias,
bias_type,
pre_gelu_out, pre_gelu_out,
grad, grad,
workspace, workspace,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment