Unverified commit e80fbd7e authored by Tim Moon, committed by GitHub

[PyTorch] Use consistent API for fused norm kernels (#1560)



* Do not suppress MXFP8 norm in Python wrapper func
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support FP8 current scaling in tex norm functions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use single envvar to enable cuDNN MXFP8 norm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Debug compilation error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix full-tile requirement for MXFP8 norm kernels
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused imports
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add missing imports
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent dd4c17dc
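The commit message describes two user-visible behaviors: the fused norm wrappers now take the quantizer directly for every scaling mode, and cuDNN-backed MXFP8 norm kernels are gated by the existing NVTE_NORM_FWD_USE_CUDNN switch instead of a dedicated NVTE_CUDNN_MXFP8_NORM variable. A minimal sketch of how a user would opt in after this change; the module, tensor sizes, and the MXFP8BlockScaling recipe name are assumptions for illustration, not part of this diff:

```python
import os

# After this change, cuDNN MXFP8 norm kernels reuse the existing forward-norm
# backend switch (the old NVTE_CUDNN_MXFP8_NORM flag is removed). The C++ side
# reads it per call, so it only needs to be set before the forward pass runs.
os.environ["NVTE_NORM_FWD_USE_CUDNN"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Illustrative module and sizes; multiples of 128 satisfy the "full tile"
# requirement in the dispatch logic shown in the diff below.
layer = te.LayerNormLinear(1024, 1024, params_dtype=torch.bfloat16).cuda()
x = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16)

# MXFP8BlockScaling is assumed here as the MXFP8 recipe name in recent TE releases.
with te.fp8_autocast(enabled=True, fp8_recipe=recipe.MXFP8BlockScaling()):
    y = layer(x)
```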
@@ -26,7 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
-NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
+PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
...
@@ -6,12 +6,13 @@
 #include "common/util/system.h"
 #include "extensions.h"
+#include "pybind.h"
 
 namespace transformer_engine::pytorch {
 
 std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
                                                         py::handle quantizer) {
   std::vector<size_t> shape_vec;
-  for (int i = 0; i < shape.ndim; i++) {
+  for (size_t i = 0; i < shape.ndim; i++) {
     size_t t = shape.data[i];
     shape_vec.push_back(t);
   }
@@ -74,6 +75,7 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
                                       float eps, py::object out, py::handle quantizer,
                                       DType out_dtype, const int sm_margin,
                                       const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch::detail;
   using namespace transformer_engine::pytorch;
   using namespace transformer_engine;
@@ -107,14 +109,17 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
   }
 
   // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = false;
-  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
-    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
-      // TE only supports MXFP8 norm with cuDNN backend
-      force_unfused_kernel = true;
-    } else if (N % 128 != 0 || H % 128 != 0) {
-      // cuDNN norm requires full tile for MXFP8
-      force_unfused_kernel = true;
+  bool force_unfused_kernel = true;
+  if (quantizer.is_none()) {
+    // No need for separate quantization step if output is unquantized
+    force_unfused_kernel = false;
+  } else if (IsFloat8Quantizers(quantizer.ptr())) {
+    // Always use fused kernel for FP8 delayed scaling
+    force_unfused_kernel = false;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // cuDNN MXFP8 kernel requires full tile
+      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
     }
   }
   TensorWrapper unquantized_out_cu;
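For readers skimming the hunk above, here is a rough Python restatement of the new dispatch rule (the string labels stand in for the quantizer type checks in C++ and are illustrative only; the identical change is applied to rmsnorm_fwd further below):

```python
def use_fused_norm_kernel(quantizer_kind, n_rows, hidden_size, cudnn_fwd_norm):
    """Mirror of the force_unfused_kernel logic above (labels are illustrative)."""
    if quantizer_kind is None:
        return True   # unquantized output: fused kernel, nothing to cast afterwards
    if quantizer_kind == "fp8_delayed_scaling":
        return True   # fused kernel handles FP8 delayed scaling directly
    if quantizer_kind == "mxfp8" and cudnn_fwd_norm:
        # the cuDNN MXFP8 kernel only supports full 128x128 tiles
        return n_rows % 128 == 0 and hidden_size % 128 == 0
    # MXFP8 without the cuDNN backend, and FP8 current scaling, fall back to a
    # high-precision norm followed by a separate quantization step
    return False
```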
@@ -145,6 +150,29 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
 
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      // my_quantizer here has to be a Float8CurrentScalingQuantizer
+      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      // check if we need to do amax reduction (depending on model parallel configs)
+      if (my_quantizer_cs->with_amax_reduction) {
+        c10::intrusive_ptr<dist_group_type> process_group_ptr =
+            my_quantizer_cs->amax_reduction_group;
+        // construct torch tensor from NVTEBasicTensor without reallocating memory
+        at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
+        std::vector<at::Tensor> tensors = {amax_tensor_torch};
+        // allreduce amax tensor
+        c10d::AllreduceOptions allreduce_opts;
+        allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+        process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+      }
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
+      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    }
     nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                        at::cuda::getCurrentCUDAStream());
   }
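The block added above implements FP8 current scaling for the unfused path: compute the amax of the high-precision norm output, optionally all-reduce it across the quantizer's amax-reduction group, derive the scale, then cast. A standalone sketch of that flow in plain PyTorch follows (this is not the TE API; fp8_max, amax_group, amax_epsilon, and the float8_e4m3fn target are assumptions chosen for illustration; the same block is added to rmsnorm_fwd below):

```python
import torch
import torch.distributed as dist

def current_scaling_cast(x, fp8_max=448.0, amax_group=None, amax_epsilon=0.0):
    """Quantize x with current (just-in-time, per-tensor) scaling."""
    amax = x.abs().amax().float()
    if amax_group is not None and dist.is_initialized():
        # mirrors the amax all-reduce done for model-parallel configurations
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=amax_group)
    # floor the amax so the scale stays finite (amax_epsilon plays this role above)
    amax = torch.clamp(amax, min=max(amax_epsilon, torch.finfo(torch.float32).tiny))
    scale = fp8_max / amax
    return (x * scale).to(torch.float8_e4m3fn), scale
```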
@@ -196,6 +224,7 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
                                     py::object out, py::handle quantizer,
                                     transformer_engine::DType out_dtype, const int sm_margin,
                                     const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch::detail;
   using namespace transformer_engine::pytorch;
   using namespace transformer_engine;
@@ -223,14 +252,17 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
   }
 
   // Determine whether to avoid fused kernel
-  bool force_unfused_kernel = false;
-  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
-    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
-      // TE only supports MXFP8 norm with cuDNN backend
-      force_unfused_kernel = true;
-    } else if (N % 128 != 0 || H % 128 != 0) {
-      // cuDNN norm requires full tile for MXFP8
-      force_unfused_kernel = true;
+  bool force_unfused_kernel = true;
+  if (quantizer.is_none()) {
+    // No need for separate quantization step if output is unquantized
+    force_unfused_kernel = false;
+  } else if (IsFloat8Quantizers(quantizer.ptr())) {
+    // Always use fused kernel for FP8 delayed scaling
+    force_unfused_kernel = false;
+  } else if (IsMXFP8Quantizers(quantizer.ptr())) {
+    if (transformer_engine::getenv<bool>("NVTE_NORM_FWD_USE_CUDNN")) {
+      // cuDNN MXFP8 kernel requires full tile
+      force_unfused_kernel = N % 128 != 0 || H % 128 != 0;
     }
   }
   TensorWrapper unquantized_out_cu;
@@ -261,6 +293,29 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
 
   // Quantize output if using unfused kernel
   if (force_unfused_kernel) {
+    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
+      // my_quantizer here has to be a Float8CurrentScalingQuantizer
+      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      // check if we need to do amax reduction (depending on model parallel configs)
+      if (my_quantizer_cs->with_amax_reduction) {
+        c10::intrusive_ptr<dist_group_type> process_group_ptr =
+            my_quantizer_cs->amax_reduction_group;
+        // construct torch tensor from NVTEBasicTensor without reallocating memory
+        at::Tensor &amax_tensor_torch = my_quantizer_cs->amax;
+        std::vector<at::Tensor> tensors = {amax_tensor_torch};
+        // allreduce amax tensor
+        c10d::AllreduceOptions allreduce_opts;
+        allreduce_opts.reduceOp = c10d::ReduceOp::MAX;
+        process_group_ptr->allreduce(tensors, allreduce_opts)->wait();
+      }
+      QuantizationConfigWrapper quant_config;
+      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
+      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
+      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
+      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
+    }
     nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                        at::cuda::getCurrentCUDAStream());
   }
...
@@ -4,7 +4,6 @@
 
 """Internal function used by multiple modules."""
 
-import os
 from typing import Any, List, Optional, Tuple, Union, Callable
 from dataclasses import dataclass
 from functools import reduce
@@ -16,9 +15,6 @@ from .. import cpp_extensions as tex
 from ..constants import TE_DType
 from ..utils import get_default_init_method
 from ..tensor.float8_tensor import Float8Tensor
-from ..tensor.mxfp8_tensor import MXFP8Quantizer
-
-_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0")))
 
 
 def _get_normalization_func(normalization: str, forward: bool):
@@ -86,26 +82,16 @@ def apply_normalization(
 
     inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
 
-    split_mxfp8_cast = False
-    if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer):
-        split_mxfp8_cast = True
-
-    output = normalization_func(
+    return normalization_func(
         *inputs,
         eps,
-        None if split_mxfp8_cast else ln_out,
-        None if split_mxfp8_cast else output_quantizer,
+        ln_out,
+        output_quantizer,
         TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype,
         fwd_ln_sm_margin,
         zero_centered_gamma,
     )
-
-    return (
-        (output_quantizer.quantize(output[0], out=ln_out), *output[1:])
-        if split_mxfp8_cast
-        else output
-    )
 
 
 class _NoopCatFunc(torch.autograd.Function):
     """Concatenate tensors, doing a no-op if possible
...
@@ -55,9 +55,9 @@ from ..tensor.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
 from ..cpp_extensions import (
     general_gemm,
@@ -160,11 +160,6 @@ class _LayerNormLinear(torch.autograd.Function):
 
         # Configure quantizer for normalization output
         with_quantized_norm = fp8 and not return_layernorm_output
-        # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer
-        # so we need to set with_quantized_norm to False
-        if isinstance(input_quantizer, Float8CurrentScalingQuantizer):
-            with_quantized_norm = False
-
         if with_quantized_norm:
             if with_input_all_gather:
                 input_quantizer.set_usage(rowwise=True, columnwise=False)
...
@@ -212,8 +212,6 @@
         # for return_layernorm_output: layernorm output = High precision, then cast to FP8
         # high precision layernorm output and output of the linear are returned
         with_quantized_norm = fp8 and not return_layernorm_output
-        if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
-            with_quantized_norm = False
 
         tp_world_size = get_distributed_world_size(tp_group)
         ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output
...