Unverified commit 258d0842 authored by Jan Bielak, committed by GitHub

Do not use normalization forward + amax fusion if cuDNN backend is requested (#2174)



* Do not use norm fwd + amax fusion if cudnn backend is requested
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Read environment variable directly to avoid include error
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 59130cc9
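For orientation before the hunks: the change keys off the NVTE_NORM_FWD_USE_CUDNN environment variable, which the extension code now reads directly via transformer_engine::getenv<bool> rather than pulling in the backend-selection header. A minimal, self-contained sketch of that gate follows; the helper below only approximates the assumed getenv<bool> semantics and is illustrative, not the TransformerEngine source.

// Minimal sketch: reading a boolean flag like NVTE_NORM_FWD_USE_CUDNN
// directly from the environment. Semantics are assumed here: the flag
// counts as set unless the variable is unset, empty, or "0".
#include <cstdlib>
#include <cstring>

static bool getenv_bool(const char* name) {
  const char* v = std::getenv(name);
  return v != nullptr && v[0] != '\0' && std::strcmp(v, "0") != 0;
}

int main() {
  // When the user requests the cuDNN backend, the norm forward + amax
  // fusion for FP8 current scaling is skipped (see the hunks below).
  const bool use_cudnn = getenv_bool("NVTE_NORM_FWD_USE_CUDNN");
  return use_cudnn ? 0 : 1;
}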
...
@@ -66,7 +66,8 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
   bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
   if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
-    cudnn_backend = false;  // cuDNN does not currently support amax output for non quantized output
+    NVTE_CHECK(!cudnn_backend,
+               "cuDNN does not currently support amax output for non quantized output");
   }
   bool gamma_in_weight_dtype = false;
...
...
@@ -52,7 +52,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
   bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
   if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
-    cudnn_backend = false;  // cuDNN does not currently support amax output for non quantized output
+    NVTE_CHECK(!cudnn_backend,
+               "cuDNN does not currently support amax output for non quantized output");
   }
   bool training =
...
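Both hunks above make the same behavioral change: instead of silently dropping the cuDNN backend when a non-FP8 output also requests an amax, the code now fails loudly. A compressed sketch of that pattern, using a stand-in macro for TE's NVTE_CHECK (which reports the given message when the condition does not hold):

#include <stdexcept>

// Stand-in for NVTE_CHECK(cond, msg): error out with msg when cond is false.
#define CHECK_OR_THROW(cond, msg) \
  do {                            \
    if (!(cond)) throw std::runtime_error(msg); \
  } while (0)

// Hypothetical reduction of the backend gate shown in the hunks above.
bool gate_cudnn_backend(bool cudnn_requested, bool output_is_fp8, bool wants_amax) {
  if (!output_is_fp8 && wants_amax) {
    // Before: cudnn_requested was silently reset to false here.
    // After: an explicit error so the unsupported request is visible.
    CHECK_OR_THROW(!cudnn_requested,
                   "cuDNN does not currently support amax output for non quantized output");
  }
  return cudnn_requested;
}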
...
@@ -110,7 +110,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   TensorWrapper unquantized_out_cu;
   py::object unquantized_out;
   if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
       auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
       std::tie(unquantized_out_cu, unquantized_out) =
           my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
...
@@ -145,7 +146,8 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
       auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
       my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
     } else {
...
@@ -290,7 +292,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   TensorWrapper unquantized_out_cu;
   py::object unquantized_out;
   if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
       auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
       std::tie(unquantized_out_cu, unquantized_out) =
           my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype);
...
@@ -325,7 +328,8 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
-    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) &&
+        !transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
       auto my_quantizer_cs = dynamic_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
       my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu);
     } else {
...
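The four extension-side hunks apply the same guard: when NVTE_NORM_FWD_USE_CUDNN is set, the current-scaling detour through create_hp_tensor_with_amax / quantize_with_amax is skipped and control falls through to the generic quantization branch. A control-flow sketch of that logic; the function names are hypothetical stand-ins for the calls in the diff, and getenv_bool follows the assumed semantics from the first sketch:

#include <cstdio>
#include <cstdlib>
#include <cstring>

static bool getenv_bool(const char* name) {  // assumed getenv<bool> semantics
  const char* v = std::getenv(name);
  return v != nullptr && v[0] != '\0' && std::strcmp(v, "0") != 0;
}

// Hypothetical stand-ins for the two branches visible in the hunks above.
static void hp_tensor_plus_amax_path() { std::puts("unfused: hp output + amax, quantized afterwards"); }
static void generic_quantize_path()    { std::puts("generic: quantize without the amax fusion"); }

void norm_fwd_sketch(bool force_unfused_kernel, bool fp8_current_scaling) {
  const bool use_cudnn = getenv_bool("NVTE_NORM_FWD_USE_CUDNN");
  if (force_unfused_kernel) {
    if (fp8_current_scaling && !use_cudnn) {
      hp_tensor_plus_amax_path();  // create_hp_tensor_with_amax(...) branch
    } else {
      generic_quantize_path();     // also taken when NVTE_NORM_FWD_USE_CUDNN is set
    }
  }
}

int main() {
  norm_fwd_sketch(/*force_unfused_kernel=*/true, /*fp8_current_scaling=*/true);
}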