Unverified Commit 14198f20 authored by vasunvidia, committed by GitHub

Increase number of FP8 tensors per GEMM (#22)



* Increase number of FP8 tensors per GEMM
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Enable FP8 output tensor for fp8_gemm
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [BERT FP8] Initial TE review comments
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Temporary fix for CUDA graph non-convergence
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Address review comments-2
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Review comments-3
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Change for new API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Remove unnecessary clone for D_scale, D_amax
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Avoid roll for amax history size = 1
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Update onnx_te_gemm API
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Fix lint errors
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent 58f19082
@@ -55,7 +55,10 @@ void cublas_gemm(const Tensor *inputA,
void *A_scale_inverse = inputA->scale_inv.dptr;
void *B = inputB->data.dptr;
void *B_scale_inverse = inputB->scale_inv.dptr;
void *C = outputD->data.dptr;
void *D = outputD->data.dptr;
void *D_scale = outputD->scale.dptr;
void *D_amax = outputD->amax.dptr;
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
@@ -78,6 +81,10 @@ void cublas_gemm(const Tensor *inputA,
if (use_fp8) {
NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
}
if (is_fp8_dtype(outputD->data.dtype)) {
NVTE_CHECK(!accumulate,
"Accumulation mode not supported with FP8 GEMM output!");
}
float one = 1.0;
float zero = 0.0;
@@ -87,7 +94,7 @@ void cublas_gemm(const Tensor *inputA,
NVTE_CHECK_CUBLAS(cublasLtCreate(&handle));
cublasLtMatmulDesc_t operationDesc = nullptr;
cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Ddesc = nullptr;
cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
cublasLtMatmulPreference_t preference = nullptr;
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
@@ -135,11 +142,29 @@ void cublas_gemm(const Tensor *inputA,
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse,
sizeof(B_scale_inverse)));
if (is_fp8_dtype(outputD->data.dtype)) {
// Accumulation mode not supported for FP8 output
C = nullptr;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
&D_scale,
sizeof(D_scale)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_AMAX_D_POINTER,
&D_amax,
sizeof(D_amax)));
// For FP8 output, cuBLAS requires C_type to be the same as bias_type
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd));
} else {
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
}
if (bias) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type)));
}
} else {
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd));
}
if (bias && gelu) {
@@ -190,7 +215,7 @@ void cublas_gemm(const Tensor *inputA,
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize, sizeof(workspaceSize)));
NVTE_CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
NVTE_CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc,
Ddesc, preference, 1, &heuristicResult,
&returnedResults));
@@ -205,8 +230,8 @@ void cublas_gemm(const Tensor *inputA,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
C, /* C */
Cdesc,
D, /* D */
Ddesc,
&heuristicResult.algo, /* algo */
@@ -217,6 +242,7 @@ void cublas_gemm(const Tensor *inputA,
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceDestroy(preference));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Cdesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescDestroy(operationDesc));
......
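For readers of the hunks above: a conceptual sketch (not the library code) of what the accumulate flag controls in the cublasLtMatmul call, assuming the standard semantics D = alpha * (A @ B) + beta * C with beta chosen from the `one`/`zero` constants by `accumulate`. With an FP8 output D there is no separate C buffer (C is set to nullptr above), which is why accumulation is rejected for FP8 GEMM output.

import torch

def gemm_reference(A: torch.Tensor, B: torch.Tensor,
                   C: torch.Tensor = None, accumulate: bool = False) -> torch.Tensor:
    # beta = 1 accumulates into the existing C/D buffer, beta = 0 overwrites it.
    beta = 1.0 if accumulate else 0.0
    D = A @ B
    if accumulate:
        assert C is not None, "accumulation needs an existing output buffer"
        D = D + beta * C
    return D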
@@ -22,14 +22,19 @@ def fp8_gemm(
workspace: torch.Tensor,
accumulate: bool = False,
out: Optional[torch.Tensor] = None,
out_index = None,
fp8_meta_tensor: tex.FP8TensorMeta = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
fp32_output: bool = False,
use_split_accumulator: bool = False,
D_dtype: tex.DType = None,
) -> torch.Tensor:
"""TN layout GEMM with fp8 inputs."""
empty_tensor = torch.Tensor()
if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]:
assert fp8_meta_tensor is not None and out_index is not None
return_output = False
if out is None:
@@ -42,6 +47,9 @@ def fp8_gemm(
return_output = True
out_dtype = tex.DType.kFloat32 if fp32_output else TE_DType[out_dtype]
# Use bfloat16 as default bias_dtype
bias_dtype = tex.DType.kBFloat16 if bias is None else TE_DType[bias.dtype]
out_dtype = D_dtype if D_dtype is not None else out_dtype
_ = torch.ops.tex_ts.te_gemm_ts(
A,
@@ -55,8 +63,11 @@ def fp8_gemm(
B_dtype,
False, # transb
out,
empty_tensor if out_index is None else fp8_meta_tensor.scale[out_index],
out_dtype,
empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index],
bias if use_bias else empty_tensor,
bias_dtype,
empty_tensor, # this is pre_gelu_out
False, # grad
workspace,
@@ -95,6 +106,7 @@ def gemm(
input_dtype = TE_DType[dtype]
output_dtype = tex.DType.kFloat32 if fp32_output else input_dtype
bias_dtype = output_dtype if bias is None else TE_DType[bias.dtype]
return_output = False
if out is None:
@@ -132,8 +144,11 @@ def gemm(
input_dtype,
transb,
out,
empty_tensor, # out_scale
output_dtype,
empty_tensor, # out_amax
grad_bias if grad else bias,
bias_dtype,
gelu_input,
grad,
workspace,
......
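The Python-side rules for the new arguments are easiest to see in one place. Below is a hedged, standalone restatement of the argument checks introduced by this change and of the scale/amax selection that fp8_gemm forwards to torch.ops.tex_ts.te_gemm_ts; the helper name and the FP8_OUT_DTYPES tuple are illustrative, not part of the library, and the accumulate restriction itself is enforced on the cuBLAS side above.

import torch

FP8_OUT_DTYPES = ("kFloat8E4M3", "kFloat8E5M2")  # mirrors the tex.DType members checked above

def resolve_fp8_output_args(D_dtype, fp8_meta_tensor, out_index, accumulate):
    """Return the (out_scale, out_amax) tensors to forward to the GEMM extension."""
    empty_tensor = torch.Tensor()
    if D_dtype is not None and D_dtype in FP8_OUT_DTYPES:
        # An FP8 output needs somewhere to write its scale and amax ...
        assert fp8_meta_tensor is not None and out_index is not None
        # ... and cannot be accumulated into (rejected by the cuBLAS path above).
        assert not accumulate, "Accumulation mode not supported with FP8 GEMM output"
        return fp8_meta_tensor.scale[out_index], fp8_meta_tensor.amax_history[0][out_index]
    return empty_tensor, empty_tensor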
@@ -48,15 +48,19 @@ class FP8TensorMeta {
enum FP8FwdTensors {
GEMM1_INPUT = 0,
GEMM1_WEIGHT = 1,
GEMM2_INPUT = 2,
GEMM2_WEIGHT = 3
GEMM1_OUTPUT = 2,
GEMM2_INPUT = 3,
GEMM2_WEIGHT = 4,
GEMM2_OUTPUT = 5
};
// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8BwdTensors {
GRAD_OUTPUT1 = 0,
GRAD_OUTPUT2 = 1
GRAD_INPUT1 = 1,
GRAD_OUTPUT2 = 2,
GRAD_INPUT2 = 3
};
......
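The comment above calls these enums named indices into the per-tensor metadata. Here is a small self-contained sketch of that indexing, using stand-in tensors rather than a real FP8TensorMeta; the slot values are copied from the enums in this hunk.

import torch

# Forward slots after this change: three per GEMM (input, weight, output).
FWD = {"GEMM1_INPUT": 0, "GEMM1_WEIGHT": 1, "GEMM1_OUTPUT": 2,
       "GEMM2_INPUT": 3, "GEMM2_WEIGHT": 4, "GEMM2_OUTPUT": 5}
# Backward slots: two per GEMM (grad_output, grad_input).
BWD = {"GRAD_OUTPUT1": 0, "GRAD_INPUT1": 1, "GRAD_OUTPUT2": 2, "GRAD_INPUT2": 3}

scale = torch.ones(len(FWD))             # stand-in for FP8TensorMeta.scale
amax_history = torch.zeros(1, len(FWD))  # stand-in for a one-row amax history
out_scale = scale[FWD["GEMM1_OUTPUT"]]           # the output slot added here
out_amax = amax_history[0][FWD["GEMM1_OUTPUT"]]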
@@ -16,8 +16,11 @@ void te_gemm(at::Tensor A,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
@@ -39,9 +42,10 @@ void te_gemm(at::Tensor A,
auto te_D = makeTransformerEngineTensor(D.data_ptr(),
{static_cast<size_t>(D.size(0)),
static_cast<size_t>(D.size(1))},
D_type);
D_type, D_amax.data_ptr(),
D_scale.data_ptr(), nullptr);
auto te_bias = makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))},
GetTransformerEngineDType(bias.scalar_type()));
bias_type);
const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}
@@ -869,10 +873,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors")
.value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
.value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
.value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT)
.value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
.value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT);
.value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT)
.value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT);
py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors")
.value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2);
.value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1)
.value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2)
.value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2);
}
@@ -16,8 +16,11 @@ void te_gemm(at::Tensor A,
transformer_engine::DType B_type,
bool transb,
at::Tensor D,
at::Tensor D_scale,
transformer_engine::DType D_type,
at::Tensor D_amax,
at::Tensor bias,
transformer_engine::DType bias_type,
at::Tensor pre_gelu_out,
bool grad,
at::Tensor workspace,
......
@@ -73,8 +73,11 @@ at::Tensor te_gemm_ts(at::Tensor A,
int64_t B_type,
int64_t transb,
at::Tensor D,
at::Tensor D_scale,
int64_t D_type,
at::Tensor D_amax,
at::Tensor bias,
int64_t bias_type,
at::Tensor pre_gelu_out,
int64_t grad,
at::Tensor workspace,
@@ -87,6 +90,7 @@ at::Tensor te_gemm_ts(at::Tensor A,
transformer_engine::DType B_type_arg = reverse_map_dtype(B_type);
bool transb_arg = static_cast<bool>(transb);
transformer_engine::DType D_type_arg = reverse_map_dtype(D_type);
transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type);
bool grad_arg = static_cast<bool>(grad);
size_t workspaceSize_arg = static_cast<size_t>(workspaceSize);
bool accumulate_arg = static_cast<bool>(accumulate);
@@ -107,8 +111,11 @@ at::Tensor te_gemm_ts(at::Tensor A,
B_type_arg,
transb_arg,
D,
D_scale,
D_type_arg,
D_amax,
bias,
bias_type_arg,
pre_gelu_out,
grad_arg,
workspace,
......
@@ -299,6 +299,7 @@ def get_fp8_group() -> Union[dist_group_type, None]:
def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
"""Update amax history and set next amax to zero."""
if amax_history.shape[0] > 1:
amax_history = torch.roll(amax_history, -1, 0)
amax_history[0].fill_(0.0)
return amax_history
......
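A self-contained reproduction of the hunk above: rolling a single-row history changes nothing, so the copy torch.roll would make can be skipped when the amax history length is 1; longer histories are still rotated before the newest slot (row 0) is cleared.

import torch

def update_amax_history(amax_history: torch.Tensor) -> torch.Tensor:
    """Update amax history and set next amax to zero."""
    if amax_history.shape[0] > 1:
        amax_history = torch.roll(amax_history, -1, 0)
    amax_history[0].fill_(0.0)
    return amax_history

hist = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(update_amax_history(hist))
# tensor([[0., 0.],
#         [5., 6.],
#         [1., 2.]])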
@@ -158,8 +158,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd."""
fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
# 2 (grad_output and grad_input) for bwd
num_fp8_tensors = (
self.fp8_meta["num_gemms"] * 2 if fwd else self.fp8_meta["num_gemms"]
self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2
)
self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
......
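A worked instance of the bookkeeping above; the function is a standalone restatement of the expression in set_meta_tensor, not the module code, and the counts line up with the FP8FwdTensors/FP8BwdTensors enums earlier in this diff.

def num_fp8_tensors(num_gemms: int, fwd: bool) -> int:
    # Three forward meta slots per GEMM (input, weight, output),
    # two backward slots per GEMM (grad_output, grad_input).
    return num_gemms * 3 if fwd else num_gemms * 2

# A module with two GEMMs now needs:
assert num_fp8_tensors(2, fwd=True) == 6   # GEMM1_INPUT .. GEMM2_OUTPUT
assert num_fp8_tensors(2, fwd=False) == 4  # GRAD_OUTPUT1 .. GRAD_INPUT2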
@@ -99,7 +99,7 @@ def onnx_fp8_gelu(g, inputs, scale, amax, scale_inv, fp8_tensor, otype):
@symbolic_helper.parse_args("v", "fs", "i", "i", "i",
"v", "fs", "i", "i", "i",
"v", "i", "v", "v", "i",
"v", "fs", "i", "fs", "v", "i", "v", "i",
"v", "i", "i", "i")
def onnx_te_gemm(
g,
@@ -114,8 +114,11 @@ def onnx_te_gemm(
input_type,
trans_input,
out,
out_scale,
out_type,
out_amax,
bias,
bias_type,
pre_gelu_out,
grad,
workspace,
......