Unverified commit e80fbd7e, authored by Tim Moon, committed by GitHub

[PyTorch] Use consistent API for fused norm kernels (#1560)



* Do not suppress MXFP8 norm in Python wrapper func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support FP8 current scaling in tex norm functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use single envvar to enable cuDNN MXFP8 norm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug compilation error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix full-tile requirement for MXFP8 norm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused imports
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing imports
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent dd4c17dc
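For context, the old per-feature switch NVTE_CUDNN_MXFP8_NORM is replaced by a single forward-norm envvar, NVTE_NORM_FWD_USE_CUDNN. A minimal sketch of opting in (both variable names come from this diff; the snippet itself is illustrative):

```python
# Illustrative sketch: opt in to the cuDNN fused norm kernels.
import os

# NVTE_NORM_FWD_USE_CUDNN replaces the old NVTE_CUDNN_MXFP8_NORM flag;
# the C++ dispatch below reads it when the norm is invoked.
os.environ["NVTE_NORM_FWD_USE_CUDNN"] = "1"
```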
......
@@ -26,7 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
-NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
......
......
@@ -6,12 +6,13 @@
+#include "common/util/system.h"
#include "extensions.h"
#include "pybind.h"

namespace transformer_engine::pytorch {

std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
                                                        py::handle quantizer) {
  std::vector<size_t> shape_vec;
-  for (int i = 0; i < shape.ndim; i++) {
+  for (size_t i = 0; i < shape.ndim; i++) {
    size_t t = shape.data[i];
    shape_vec.push_back(t);
  }
......
@@ -74,6 +75,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
                                      float eps, py::object out, py::handle quantizer,
                                      DType out_dtype, const int sm_margin,
                                      const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch::detail;
  using namespace transformer_engine::pytorch;
  using namespace transformer_engine;
......
@@ -107,14 +109,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
  }

  // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = false;
-  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
-    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
-      // TE only supports MXFP8 norm with cuDNN backend
-      force_unfused_kernel = true;
-    } else if (N % 128 != 0 || H % 128 != 0) {
-      // cuDNN norm requires full tile for MXFP8
-      force_unfused_kernel = true;
-    }
-  }
+  bool force_unfused_kernel = true;
+  if (quantizer.is_none()) {
+    // No need for separate quantization step if output is unquantized
+    force_unfused_kernel = false;
+  } else if (IsFloat8Quantizers(quantizer.ptr())) {
+    // Always use fused kernel for FP8 delayed scaling
+    force_unfused_kernel = false;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // cuDNN MXFP8 kernel requires full tile
+      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
+    }
+  }

  TensorWrapper unquantized_out_cu;
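The dispatch above is effectively a small decision table. A hedged Python restatement (the predicate helpers are hypothetical; only the branching mirrors the C++ code):

```python
# Illustrative restatement of the C++ dispatch above; not a real TE API.
def use_fused_norm_kernel(quantizer, N, H, cudnn_enabled):
    """Return True if the norm kernel may write quantized output directly."""
    if quantizer is None:
        return True  # unquantized output: no separate quantization step
    if is_fp8_delayed_scaling(quantizer):  # hypothetical predicate
        return True  # fused kernel always supports FP8 delayed scaling
    if is_mxfp8(quantizer):  # hypothetical predicate
        # cuDNN backend must be enabled, and it only handles full 128x128 tiles
        return cudnn_enabled and N % 128 == 0 and H % 128 == 0
    return False  # e.g. FP8 current scaling: norm first, then quantize
```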
......
@@ -145,6 +150,29 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      // my_quantizer here has to be a Float8CurrentScalingQuantizer
+      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      // check if we need to do amax reduction (depending on model parallel configs)
+      if (my_quantizer_cs->with_amax_reduction) {
+        c10::intrusive_ptr<dist_group_type> process_group_ptr =
+            my_quantizer_cs->amax_reduction_group;
+        // construct torch tensor from NVTEBasicTensor without reallocating memory
+        at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
+        std::vector<at::Tensor> tensors = {amax_tensor_torch};
+        // allreduce amax tensor
+        c10d::AllreduceOptions allreduce_opts;
+        allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+        process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+      }
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      // set amax ptr to null in the output TensorWrapper to avoid atomic amax updates in the kernel
+      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    }
    nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                       at::cuda::getCurrentCUDAStream());
  }
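For orientation, the added current-scaling path runs: amax of the unquantized output, an optional MAX all-reduce across the model-parallel group, scale computation, then quantization with atomic amax updates disabled. A rough PyTorch-side sketch of the same flow (assumes FP8 E4M3 with max representable value 448; this is not TE's actual Python API):

```python
# Sketch of the amax -> allreduce -> scale -> quantize sequence above.
import torch
import torch.distributed as dist

def quantize_current_scaling(out_hp, amax, process_group=None):
    amax.copy_(out_hp.abs().amax())  # per-tensor amax of the norm output
    if process_group is not None:
        # mirrors the c10d allreduce with ReduceOp::MAX in the C++ code
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=process_group)
    scale = 448.0 / amax.clamp_min(1e-10)  # guard against amax == 0
    # The real kernel nulls the amax pointer so quantization does not
    # re-accumulate amax atomically; at this point amax is already final.
    return (out_hp.float() * scale).to(torch.float8_e4m3fn), scale
```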
......
@@ -196,6 +224,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
                                    py::object out, py::handle quantizer,
                                    transformer_engine::DType out_dtype, const int sm_margin,
                                    const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch::detail;
  using namespace transformer_engine::pytorch;
  using namespace transformer_engine;
......
@@ -223,14 +252,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
  }

  // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = false;
-  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
-    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
-      // TE only supports MXFP8 norm with cuDNN backend
-      force_unfused_kernel = true;
-    } else if (N % 128 != 0 || H % 128 != 0) {
-      // cuDNN norm requires full tile for MXFP8
-      force_unfused_kernel = true;
-    }
-  }
+  bool force_unfused_kernel = true;
+  if (quantizer.is_none()) {
+    // No need for separate quantization step if output is unquantized
+    force_unfused_kernel = false;
+  } else if (IsFloat8Quantizers(quantizer.ptr())) {
+    // Always use fused kernel for FP8 delayed scaling
+    force_unfused_kernel = false;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // cuDNN MXFP8 kernel requires full tile
+      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
+    }
+  }

  TensorWrapper unquantized_out_cu;
......
@@ -261,6 +293,29 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      // my_quantizer here has to be a Float8CurrentScalingQuantizer
+      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      // check if we need to do amax reduction (depending on model parallel configs)
+      if (my_quantizer_cs->with_amax_reduction) {
+        c10::intrusive_ptr<dist_group_type> process_group_ptr =
+            my_quantizer_cs->amax_reduction_group;
+        // construct torch tensor from NVTEBasicTensor without reallocating memory
+        at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
+        std::vector<at::Tensor> tensors = {amax_tensor_torch};
+        // allreduce amax tensor
+        c10d::AllreduceOptions allreduce_opts;
+        allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+        process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+      }
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      // set amax ptr to null in the output TensorWrapper to avoid atomic amax updates in the kernel
+      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    }
    nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                       at::cuda::getCurrentCUDAStream());
  }
......
......
@@ -4,7 +4,6 @@

"""Internal function used by multiple modules."""

-import os
from typing import Any, List, Optional, Tuple, Union, Callable
from dataclasses import dataclass
from functools import reduce
......
@@ -16,9 +15,6 @@ from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..tensor.float8_tensor import Float8Tensor
-from ..tensor.mxfp8_tensor import MXFP8Quantizer
-
-_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))

def _get_normalization_func(normalization: str, forward: bool):
......
@@ -86,26 +82,16 @@ def apply_normalization(
    inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)

-    split_mxfp8_cast = False
-    if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer):
-        split_mxfp8_cast = True
-    output = normalization_func(
+    return normalization_func(
        *inputs,
        eps,
-        None if split_mxfp8_cast else ln_out,
-        None if split_mxfp8_cast else output_quantizer,
+        ln_out,
+        output_quantizer,
        TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype,
        fwd_ln_sm_margin,
        zero_centered_gamma,
    )
-    return (
-        (output_quantizer.quantize(output[0], out=ln_out), *output[1:])
-        if split_mxfp8_cast
-        else output
-    )


class _NoopCatFunc(torch.autograd.Function):
    """Concatenate tensors, doing a no-op if possible
......
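Net effect of this hunk: the wrapper no longer splits the MXFP8 case into an unquantized norm plus a Python-side cast; the quantizer is always forwarded and the C++ dispatch shown earlier decides between the fused and unfused paths. A hedged before/after restatement (`norm` stands in for the tex norm function; the extra return values are abbreviated):

```python
# Before (removed above): Python split the cast out by hand for MXFP8.
out_hp, *extras = norm(*inputs, eps, None, None, dtype, margin, zcg)
out = output_quantizer.quantize(out_hp, out=ln_out)  # separate cast

# After: the quantizer is always forwarded; fused vs. unfused is decided
# by the C++ dispatch in this diff.
out, *extras = norm(*inputs, eps, ln_out, output_quantizer, dtype, margin, zcg)
```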
......
@@ -55,9 +55,9 @@ from ..tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)
-from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..cpp_extensions import (
    general_gemm,
......
@@ -160,11 +160,6 @@ class _LayerNormLinear(torch.autograd.Function):
        # Configure quantizer for normalization output
        with_quantized_norm = fp8 and not return_layernorm_output
-        # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer
-        # so we need to set with_quantized_norm to False
-        if isinstance(input_quantizer, Float8CurrentScalingQuantizer):
-            with_quantized_norm = False

        if with_quantized_norm:
            if with_input_all_gather:
                input_quantizer.set_usage(rowwise=True, columnwise=False)
......
@@ -212,8 +212,6 @@ class _LayerNormMLP(torch.autograd.Function):
        # for return_layernorm_output: layernorm output = High precision, then cast to FP8
        # high precision layernorm output and output of the linear are returned
        with_quantized_norm = fp8 and not return_layernorm_output
-        if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
-            with_quantized_norm = False

        tp_world_size = get_distributed_world_size(tp_group)
        ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output
......