Unverified Commit 77fa1e59 authored by Zhongbo Zhu, committed by GitHub

[PyTorch] Enabling Per-Tensor Current Scaling Recipe (#1471)



* check in per-tensor current scaling full recipe
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

set up basics of current scaling quantizer at the Python level
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

add test case for current scaling dequantize
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

finish linear layer fwd/bwd test, determined error with bf16
Signed-off-by: zhongboz <zhongboz@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: zhongboz <zhongboz@nvidia.com>

achieved zero tolerance for Linear by specifying the gemm use_split_accumulator config
Signed-off-by: zhongboz <zhongboz@nvidia.com>

enable layernormlinear with current scaling, pass bitwise test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

refactor test case code
Signed-off-by: zhongboz <zhongboz@nvidia.com>

make current scaling quantizers distributed, pass distributed linear & layernormlinear tests
Signed-off-by: zhongboz <zhongboz@nvidia.com>

bug fix: use cached fp8 recipe in backward
Signed-off-by: zhongboz <zhongboz@nvidia.com>

fix layernorm_mlp with current scaling, fix activation_helper with current scaling
Signed-off-by: zhongboz <zhongboz@nvidia.com>

support detailed numerical settings from recipe to quantization kernel
Signed-off-by: zhongboz <zhongboz@nvidia.com>

resolve MR comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

recipe naming
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve MR comments, remove IS_CURRENT_SCALING template from kernels
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve MR comments, make current scaling C++ test cases
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* add current scaling to test_numerics.py, skip act recomp and grouped linear
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add benchmark for quantizer
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add benchmarks for linear layer
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix, typo
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve more MR comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* avoid potential race condition by not using from_blob to construct amax tensor in C++
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve more comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Debug linter warnings and license check
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Debug import error in FP8 tensor test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug compilation error with CUDA 12.1 for Turing
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve MR comments, fix activation cast fusion
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve comments, add NVTEQuantizationParams for compute scale
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove is_current_scaling check entirely from common folder
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* remove benchmarks, will contribute in another repo
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* adjust current scaling default recipe config
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adjust comments in test
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Remove current scaling mode from core lib
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Refactor current-scaling-specific logic in core C++ lib

Move amax and scale update functions out of the casting functions and into a dedicated current-scaling source file. Add a general API for accessing the quantization config object.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add missing header in C++ tests
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Disable test config with FP8 transpose on Blackwell
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix compilation error in C++ test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 2a95efd3
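For orientation before the diff: the recipe this PR enables is selected like any other FP8 recipe via fp8_autocast. A minimal usage sketch, assuming the recipe class is exported as transformer_engine.common.recipe.Float8CurrentScaling (check the release you are on):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

# Per-tensor current scaling: amax is computed from the tensor being cast,
# so no amax history window or calibration warm-up is involved.
fp8_recipe = Float8CurrentScaling()

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)
y.sum().backward()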
transformer_engine/pytorch/module/layernorm_linear.py

@@ -14,6 +14,7 @@ from torch.nn import init
 import transformer_engine_torch as tex
+from transformer_engine.common.recipe import Recipe
 from .base import (
     get_workspace,
     get_ub,
@@ -55,8 +56,8 @@ from ..tensor.quantized_tensor import (
     restore_from_saved,
 )
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
 from ..cpp_extensions import (
     general_gemm,
 )
@@ -159,6 +160,11 @@ class _LayerNormLinear(torch.autograd.Function):
         # Configure quantizer for normalization output
         with_quantized_norm = fp8 and not return_layernorm_output
+        # For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused
+        # with the quantizer, so we need to set with_quantized_norm to False.
+        if isinstance(input_quantizer, Float8CurrentScalingQuantizer):
+            with_quantized_norm = False
         if with_quantized_norm:
             if with_input_all_gather:
                 input_quantizer.set_usage(rowwise=True, columnwise=False)
@@ -210,6 +216,10 @@ class _LayerNormLinear(torch.autograd.Function):
                 with_quantized_all_gather = False
             if fp8:
                 input_quantizer.set_usage(rowwise=True, columnwise=False)
+                # ln_out here has two possibilities:
+                # 1. in FP8 low precision: the cast was fused into the layernorm kernel
+                # 2. in high precision: it needs to be cast and then gathered in FP8
+                # either way, the output ln_out_total will be a full tensor in FP8
             ln_out_total, _ = gather_along_first_dim(
                 ln_out,
                 tp_group,
@@ -290,6 +300,12 @@ class _LayerNormLinear(torch.autograd.Function):
             ln_out_total = ub_obj.get_buffer(input_quantizer)

         nvtx_range_push(f"{nvtx_label}.gemm")
+        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
+        if fp8:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if hasattr(recipe, "fp8_gemm_fprop"):
+                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
         out, *_, rs_out = general_gemm(
             weightmat,
             ln_out_total,
@@ -297,7 +313,7 @@ class _LayerNormLinear(torch.autograd.Function):
             quantization_params=output_quantizer,
             out_dtype=activation_dtype,
             bias=bias,
-            use_split_accumulator=_2X_ACC_FPROP,
+            use_split_accumulator=fprop_gemm_use_split_accumulator,
             ub=ub_obj,
             ub_type=ub_type,
             extra_output=rs_out,
@@ -359,6 +375,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.weight = weight
         ctx.activation_dtype = activation_dtype
         ctx.fp8 = fp8
+        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         ctx.cpu_offloading = cpu_offloading
         ctx.is_first_microbatch = is_first_microbatch
@@ -431,11 +448,12 @@ class _LayerNormLinear(torch.autograd.Function):
                     ctx.ub_bulk_wgrad,
                 ]
             )
-            and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+            and (ctx.fp8_recipe is not None)
         ):
-            raise NotImplementedError(
-                "Comm+GEMM overlap is only supported with FP8 delayed scaling"
-            )
+            if not ctx.fp8_recipe.delayed():
+                raise NotImplementedError(
+                    "Comm+GEMM overlap is only supported with FP8 delayed scaling"
+                )

         saved_tensors = ctx.saved_tensors
         (  # pylint: disable=unbalanced-tuple-unpacking
@@ -572,6 +590,12 @@ class _LayerNormLinear(torch.autograd.Function):
                 ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

             nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+            dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_dgrad"):
+                    dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
             dgrad, *_ = general_gemm(
                 weight,
                 grad_output,
@@ -581,7 +605,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 quantization_params=ctx.grad_input_quantizer,
                 out=dgrad_bulk,
                 out_dtype=ctx.activation_dtype,
-                use_split_accumulator=_2X_ACC_DGRAD,
+                use_split_accumulator=dgrad_gemm_use_split_accumulator,
                 ub=ub_obj_dgrad,
                 ub_type=ub_type_dgrad,
                 extra_output=rs_out,
@@ -643,6 +667,14 @@ class _LayerNormLinear(torch.autograd.Function):
                 # wgrad GEMM
                 # Note: Fuse with bgrad computation if needed
                 nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
+                wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
+                if ctx.fp8:
+                    recipe = ctx.fp8_recipe
+                    if hasattr(recipe, "fp8_gemm_wgrad"):
+                        wgrad_gemm_use_split_accumulator = (
+                            recipe.fp8_gemm_wgrad.use_split_accumulator
+                        )
                 wgrad, grad_bias_, *_, rs_out = general_gemm(
                     ln_out_total,
                     grad_output,
@@ -654,7 +686,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     ),
                     bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                     out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                    use_split_accumulator=_2X_ACC_WGRAD,
+                    use_split_accumulator=wgrad_gemm_use_split_accumulator,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     ub=ub_obj_wgrad,
                     ub_type=ub_type_wgrad,
@@ -1139,6 +1171,16 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
         self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

+    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
+        """Init scales and amaxes for fwd | bwd."""
+        super().set_meta_tensor(fwd, recipe)
+
+        # customize quantizers based on each recipe & layer configs
+        recipe = FP8GlobalStateManager.get_fp8_recipe()
+        if recipe.float8_current_scaling():
+            self._customize_quantizers_float8_current_scaling(fwd, recipe)
+        # elif other recipes (mxfp8, etc.)
+
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         warnings.warn(
@@ -1332,3 +1374,44 @@ class LayerNormLinear(TransformerEngineBaseModule):
             grad_output_quantizer,
             grad_input_quantizer,
         )
+
+    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + layernorm_linear."""
+        assert (
+            recipe.float8_current_scaling()
+        ), "current scaling recipe quantizer customization here"
+        if fwd:
+            # set configs about amax epsilon and power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_INPUT
+            ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_INPUT
+            ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
+            # also set weight quantizer with same amax_epsilon & power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_WEIGHT
+            ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_WEIGHT
+            ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
+            # parallel related
+            if self.sequence_parallel and self.parallel_mode == "column":
+                # set input_quantizer with amax reduction TP group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_size = self.tp_size
+        else:
+            # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
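Note on the quantizer knobs set above: force_pow_2_scales and amax_epsilon feed the per-tensor scale computation inside the quantization kernel. A rough Python sketch of the intended semantics (illustrative only; the real computation happens in the fused CUDA kernel, and the FP8 maximum depends on the format):

import torch

def compute_scale(
    x: torch.Tensor,
    fp8_max: float = 448.0,          # format maximum, e.g. 448.0 for E4M3
    amax_epsilon: float = 0.0,       # floor for tiny amax values
    force_pow_2_scales: bool = False,
) -> torch.Tensor:
    # Current scaling: amax comes from the tensor being cast, no history window.
    amax = x.abs().amax().float()
    amax = torch.clamp(amax, min=amax_epsilon)
    # Quantization multiplies by scale before casting to FP8
    # (the actual kernel additionally guards against amax == 0).
    scale = fp8_max / amax
    if force_pow_2_scales:
        # Round down to a power of two so the scaling is exactly invertible.
        scale = torch.pow(2.0, torch.floor(torch.log2(scale)))
    return scale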
transformer_engine/pytorch/module/layernorm_mlp.py

@@ -15,6 +15,7 @@ from torch.nn import init
 import transformer_engine_torch as tex
+from transformer_engine.common.recipe import Recipe
 from .base import (
     get_workspace,
     _ub_communicators,
@@ -59,7 +60,7 @@ from ..tensor.float8_tensor import Float8Tensor
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ._common import apply_normalization, _fix_gathered_fp8_transpose
 from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
+from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
 from ..tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
@@ -73,17 +74,53 @@ from ..cpp_extensions import (
 __all__ = ["LayerNormMLP"]

-def _act_func(activation: str):
-    funcs = {
-        "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
-        "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
+def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
+    if recipe is None:
+        # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
+        return {
+            "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
+            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
+            "geglu": (tex.geglu, tex.dgeglu, None),
+            "reglu": (tex.reglu, tex.dreglu, None),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
+            "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
+            "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
+        }
+    if recipe.delayed() or recipe.mxfp8():
+        # Delayed scaling and MXFP8, fusion supported list:
+        # [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
+        return {
+            "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu),
+            "relu": (tex.relu, tex.drelu, tex.dbias_drelu),
+            "geglu": (tex.geglu, tex.dgeglu, None),
+            "reglu": (tex.reglu, tex.dreglu, None),
+            "swiglu": (tex.swiglu, tex.dswiglu, None),
+            "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
+            "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
+            "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
+        }
+    # no activation fusion written yet
+    # Per-tensor current scaling: []
+    return {
+        "gelu": (tex.gelu, tex.dgelu, None),
+        "relu": (tex.relu, tex.drelu, None),
         "geglu": (tex.geglu, tex.dgeglu, None),
         "reglu": (tex.reglu, tex.dreglu, None),
         "swiglu": (tex.swiglu, tex.dswiglu, None),
-        "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu),
+        "qgelu": (tex.qgelu, tex.dqgelu, None),
         "qgeglu": (tex.qgeglu, tex.dqgeglu, None),
-        "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu),
+        "srelu": (tex.srelu, tex.dsrelu, None),
     }
+
+
+def _act_func(activation: str, recipe: Optional[Recipe] = None):
+    # based on each quantization mode, we have different kernel fusions supported:
+    # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
+    # Delayed scaling: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
+    # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
+    # Per-tensor current scaling: []
+    funcs = _get_act_func_supported_list(recipe)
     if activation not in funcs:
         raise NotImplementedError("Activation type " + activation + " is not supported!")
     return funcs[activation]
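The practical consequence of the lookup above: the third tuple element (the fused dbias + dactivation + quantize kernel) is now recipe dependent, and per-tensor current scaling always gets None, which forces the unfused fallback in the backward pass. Illustratively (the current scaling recipe object here is assumed, not part of this diff):

# bf16 path (no recipe): the fused dbias kernel is available
_, _, dbias_fused = _act_func("gelu", recipe=None)
assert dbias_fused is tex.dbias_dgelu

# per-tensor current scaling (hypothetical recipe object): no fusion yet
_, _, dbias_fused = _act_func("gelu", recipe=current_scaling_recipe)
assert dbias_fused is None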
@@ -161,7 +198,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 "Comm+GEMM overlap is only supported with FP8 delayed scaling"
             )

-        activation_func = _act_func(activation)[0]
+        activation_func = _act_func(
+            activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
+        )[0]

         device = inp.device
         # Cast for native AMP
@@ -175,6 +214,8 @@ class _LayerNormMLP(torch.autograd.Function):
         # for return_layernorm_output: layernorm output = High precision, then cast to FP8
         # high precision layernorm output and output of the linear are returned
         with_quantized_norm = fp8 and not return_layernorm_output
+        if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer):
+            with_quantized_norm = False

         tp_world_size = get_distributed_world_size(tp_group)
         ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output
@@ -220,6 +261,8 @@ class _LayerNormMLP(torch.autograd.Function):
             zero_centered_gamma,
         )

+        ln_out_return = ln_out if return_layernorm_output else None
+
         # Prepare GEMM input
         # Note: Cast to expected dtype and perform tensor-parallel communication
         ln_out_gathered = False
@@ -229,6 +272,10 @@ class _LayerNormMLP(torch.autograd.Function):
                 with_quantized_all_gather = False
             if fp8:
                 fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
+                # ln_out here has two possibilities:
+                # 1. in FP8 low precision: the cast was fused into the layernorm kernel
+                # 2. in high precision: it needs to be cast and then gathered in FP8
+                # either way, the output ln_out_total will be a full tensor in FP8
             ln_out_total, _ = gather_along_first_dim(
                 ln_out,
                 tp_group,
@@ -240,26 +287,19 @@ class _LayerNormMLP(torch.autograd.Function):
             if ub_overlap_ag:
                 ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False)
             else:
-                if fp8:
-                    if not isinstance(ln_out, QuantizedTensor):
-                        fc1_input_quantizer.set_usage(
-                            rowwise=True, columnwise=backwards_needs_fc1_input
-                        )
-                        ln_out = fc1_input_quantizer(ln_out)
-                    elif backwards_needs_fc1_input:
-                        ln_out.update_usage(rowwise_usage=True, columnwise_usage=True)
+                # here ln_out is in FP8 low precision, the cast was either done by fc1_input_quantizer
+                # or fused into the layernorm kernel
+                # ln_out_total represents the full fp8 tensor; in this case it is the same as ln_out
                 ln_out_total = ln_out
-
-        # If residual connection is after LN, we need `ln_out`
-        # tensor in higher precision, this comes at the cost
-        # of an extra fp8 cast.
-        ln_out_return = None
-        if return_layernorm_output:
-            ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
-            if fp8 and not with_quantized_all_gather:
-                ln_out_total = fc1_input_quantizer(ln_out_total)
-                if ln_out_gathered:
-                    rank = torch.distributed.get_rank(tp_group)
-                    slice_start = rank * ln_out.size(0)
-                    slice_end = (rank + 1) * ln_out.size(0)
-                    ln_out = ln_out_total[
-                        slice_start:slice_end, ...
-                    ]  # TODO(pgadzinski) - check this  # pylint: disable=fixme
-                else:
-                    ln_out = ln_out_total

         # Cast weights to expected dtype
         fc1_weight_final = fc1_weight
         fc2_weight_final = fc2_weight
@@ -459,6 +499,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects

+        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer
         ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer
         ctx.grad_input_quantizer = grad_input_quantizer
@@ -546,11 +587,12 @@ class _LayerNormMLP(torch.autograd.Function):
                     ctx.ub_bulk_wgrad,
                 ]
             )
-            and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+            and (ctx.fp8_recipe is not None)
         ):
-            raise NotImplementedError(
-                "Comm+GEMM overlap is only supported with FP8 delayed scaling"
-            )
+            if not ctx.fp8_recipe.delayed():
+                raise NotImplementedError(
+                    "Comm+GEMM overlap is only supported with FP8 delayed scaling"
+                )

         saved_tensors = ctx.saved_tensors
         (  # pylint: disable=unbalanced-tuple-unpacking
@@ -733,22 +775,36 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias)
                 if ctx.grad_fc1_output_quantizer is not None:
                     dact = ctx.grad_fc1_output_quantizer(dact)
-            elif _act_func(ctx.activation)[2] is not None and ctx.fp8:
+            elif (
+                _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None
+                and ctx.fp8
+            ):
                 # Fusion: gemm, bias + gelu + quantize
-                dbias_dact_quantize_func = _act_func(ctx.activation)[2]
+                dbias_dact_quantize_func = _act_func(
+                    ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
+                )[2]
                 fc1_bias_grad, dact = dbias_dact_quantize_func(
                     fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer
                 )  # quantize bgrad gelu fused
             else:
                 # Fusion: gemm + gelu,
                 if not fc2_dgrad_gemm_gelu_fusion:
-                    activation_func_bwd = _act_func(ctx.activation)[1]
+                    activation_func_bwd = _act_func(
+                        ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
+                    )[1]
                     dact = activation_func_bwd(
                         fc2_dgrad, fc1_out.to(ctx.activation_dtype), None
                     )  # activation in high precision
                 if ctx.fp8:
-                    fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer)
+                    # TODO zhongboz: per-tensor current scaling has no bgrad fusion for now
+                    if isinstance(ctx.grad_fc1_output_quantizer, Float8CurrentScalingQuantizer):
+                        fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0)
+                        dact = ctx.grad_fc1_output_quantizer(dact)
+                    else:
+                        fc1_bias_grad, dact = tex.bgrad_quantize(
+                            dact, ctx.grad_fc1_output_quantizer
+                        )
             else:
                 fuse_gemm_and_bias_fc1_wgrad = (
                     True  # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1
@@ -1286,6 +1342,15 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
         self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

+    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
+        """Init scales and amaxes for fwd | bwd."""
+        super().set_meta_tensor(fwd, recipe)
+
+        # customize quantizers based on each recipe & layer configs
+        if FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
+            self._customize_quantizers_float8_current_scaling(fwd, recipe)
+        # elif other recipes (mxfp8, etc.)
+
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         warnings.warn(
@@ -1494,3 +1559,76 @@ class LayerNormMLP(TransformerEngineBaseModule):
             grad_fc2_output_quantizer,
             grad_input_quantizer,
         )
+
+    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + layernorm_mlp."""
+        assert (
+            recipe.float8_current_scaling()
+        ), "current scaling recipe quantizer customization here"
+        if fwd:
+            # fc1_input_quantizer: set configs about amax epsilon and power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_INPUT
+            ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_INPUT
+            ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
+            # fc2_input_quantizer
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM2_INPUT
+            ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM2_INPUT
+            ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
+            # fc1_weight_quantizer: also set numerical configs about weight
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_WEIGHT
+            ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_WEIGHT
+            ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
+            # fc2_weight_quantizer
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM2_WEIGHT
+            ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM2_WEIGHT
+            ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
+            # parallel related
+            if self.sequence_parallel and self.set_parallel_mode:
+                # fc1_input_quantizer: customize input_quantizer with amax reduction TP group,
+                # column parallel + sequence parallel here
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_size = self.tp_size
+        else:
+            # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
+            # grad_fc1_output_quantizer: also set numerical configs
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_INPUT1
+            ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_INPUT1
+            ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
+            if self.sequence_parallel and self.set_parallel_mode:
+                # grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction
+                # TP group, row parallel + sequence parallel here
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].with_amax_reduction = True
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].amax_reduction_group = self.tp_group
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].amax_reduction_size = self.tp_size
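On the with_amax_reduction setting above: with sequence parallelism each rank quantizes a different shard of the same logical tensor, so the amax must be reduced over the TP group before the scale is derived, otherwise ranks would compute inconsistent scale_inv values. A minimal sketch of that reduction, assuming torch.distributed is initialized (the helper name is illustrative, not part of this diff):

import torch
import torch.distributed as dist

def reduced_amax(local_shard: torch.Tensor, tp_group: dist.ProcessGroup) -> torch.Tensor:
    # Each rank computes amax over its local shard, then takes a MAX
    # all-reduce so every rank derives the identical FP8 scale.
    amax = local_shard.abs().amax().float()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group)
    return amax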
transformer_engine/pytorch/module/linear.py

@@ -11,6 +11,7 @@ import torch
 import transformer_engine_torch as tex
+from transformer_engine.common.recipe import Recipe
 from .base import (
     get_workspace,
     get_ub,
@@ -228,6 +229,12 @@ class _Linear(torch.autograd.Function):
             inputmat_total = ub_obj.get_buffer(input_quantizer)

         nvtx_range_push(f"{nvtx_label}.gemm")
+        fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
+        if fp8:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if hasattr(recipe, "fp8_gemm_fprop"):
+                fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
         out, *_, rs_out = general_gemm(
             weightmat,
             inputmat_total,
@@ -235,7 +242,7 @@ class _Linear(torch.autograd.Function):
             quantization_params=output_quantizer,
             out_dtype=out_dtype,
             bias=bias,
-            use_split_accumulator=_2X_ACC_FPROP,
+            use_split_accumulator=fprop_gemm_use_split_accumulator,
             ub=ub_obj,
             ub_type=ub_type,
             extra_output=rs_out,
@@ -277,6 +284,7 @@ class _Linear(torch.autograd.Function):
         ctx.tensor_objects = tensor_objects

         ctx.activation_dtype = activation_dtype
+        ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
         ctx.fp8 = fp8
         ctx.input_quantizer = input_quantizer
         ctx.grad_output_quantizer = grad_output_quantizer
@@ -344,11 +352,12 @@ class _Linear(torch.autograd.Function):
                     ctx.ub_bulk_wgrad,
                 ]
             )
-            and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+            and (ctx.fp8_recipe is not None)
         ):
-            raise NotImplementedError(
-                "Comm+GEMM overlap is only supported with FP8 delayed scaling"
-            )
+            if not ctx.fp8_recipe.delayed():
+                raise NotImplementedError(
+                    "Comm+GEMM overlap is only supported with FP8 delayed scaling"
+                )

         saved_tensors = ctx.saved_tensors
         inputmat, weight_fp8, weight, bias = (  # pylint: disable=unbalanced-tuple-unpacking
@@ -483,6 +492,14 @@ class _Linear(torch.autograd.Function):
         # dgrad GEMM
         nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
+        dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
+        if ctx.fp8:
+            recipe = ctx.fp8_recipe
+            if hasattr(recipe, "fp8_gemm_dgrad"):
+                dgrad_gemm_use_split_accumulator = (
+                    recipe.fp8_gemm_dgrad.use_split_accumulator
+                )
         dgrad, *_, rs_out = general_gemm(
             weight_fp8,
             grad_output,
@@ -492,7 +509,7 @@ class _Linear(torch.autograd.Function):
             quantization_params=ctx.grad_input_quantizer,
             out=dgrad_bulk,
             out_dtype=ctx.activation_dtype,
-            use_split_accumulator=_2X_ACC_DGRAD,
+            use_split_accumulator=dgrad_gemm_use_split_accumulator,
             ub=ub_obj_dgrad,
             ub_type=ub_type_dgrad,
             extra_output=rs_out,
@@ -551,6 +568,14 @@ class _Linear(torch.autograd.Function):
             # wgrad GEMM
             # Note: Fuse with bgrad computation if needed
             nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
+            wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD
+            if ctx.fp8:
+                recipe = ctx.fp8_recipe
+                if hasattr(recipe, "fp8_gemm_wgrad"):
+                    wgrad_gemm_use_split_accumulator = (
+                        recipe.fp8_gemm_wgrad.use_split_accumulator
+                    )
             wgrad, grad_bias_, _, rs_out = general_gemm(
                 inputmat_total,
                 grad_output,
@@ -562,7 +587,7 @@ class _Linear(torch.autograd.Function):
                 ),
                 bias=(bias if (grad_bias is None and not ctx.fp8) else None),
                 out=main_grad if ctx.fuse_wgrad_accumulation else None,
-                use_split_accumulator=_2X_ACC_WGRAD,
+                use_split_accumulator=wgrad_gemm_use_split_accumulator,
                 accumulate=accumulate_wgrad_into_param_main_grad,
                 ub=ub_obj_wgrad,
                 ub_type=ub_type_wgrad,
@@ -955,6 +980,16 @@ class Linear(TransformerEngineBaseModule):
         else:
             self.gemm_bias_unfused_add = False

+    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
+        """Init scales and amaxes for fwd | bwd."""
+        super().set_meta_tensor(fwd, recipe)
+
+        # customize quantizers based on each recipe & layer configs
+        recipe = FP8GlobalStateManager.get_fp8_recipe()
+        if recipe.float8_current_scaling():
+            self._customize_quantizers_float8_current_scaling(fwd, recipe)
+        # elif other recipes (mxfp8, etc.)
+
     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)
@@ -1118,3 +1153,56 @@ class Linear(TransformerEngineBaseModule):
             grad_output_quantizer,
             grad_input_quantizer,
         )
+
+    def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
+        """Customize quantizers based on current scaling recipe + linear."""
+        assert (
+            recipe.float8_current_scaling()
+        ), "current scaling recipe quantizer customization here"
+        if fwd:
+            # set configs about amax epsilon and power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_INPUT
+            ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_INPUT
+            ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon
+            # also set weight quantizer with same amax_epsilon & power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_WEIGHT
+            ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+            self.quantizers["scaling_fwd"][
+                tex.FP8FwdTensors.GEMM1_WEIGHT
+            ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon
+            # parallel related
+            if self.sequence_parallel and self.parallel_mode == "column":
+                # customize input_quantizer with amax reduction TP group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].with_amax_reduction = True
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_group = self.tp_group
+                self.quantizers["scaling_fwd"][
+                    tex.FP8FwdTensors.GEMM1_INPUT
+                ].amax_reduction_size = self.tp_size
+        else:
+            # set grad_output_quantizer with amax epsilon and power_2_scale
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+            self.quantizers["scaling_bwd"][
+                tex.FP8BwdTensors.GRAD_OUTPUT1
+            ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
+            # parallel related
+            if self.sequence_parallel and self.parallel_mode == "row":
+                # customize grad_output_quantizer with amax reduction TP group
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].with_amax_reduction = True
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].amax_reduction_group = self.tp_group
+                self.quantizers["scaling_bwd"][
+                    tex.FP8BwdTensors.GRAD_OUTPUT1
+                ].amax_reduction_size = self.tp_size
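The use_split_accumulator plumbing added to the fprop/dgrad/wgrad GEMMs in all three modules follows one pattern: take the per-GEMM setting from the recipe when the recipe defines it (fp8_gemm_fprop, fp8_gemm_dgrad, fp8_gemm_wgrad), otherwise keep the compiled-in default. Factored out, the logic is roughly (the helper name is illustrative, not part of this diff):

def _resolve_split_accumulator(recipe, default: bool, attr: str) -> bool:
    # attr is one of "fp8_gemm_fprop", "fp8_gemm_dgrad", "fp8_gemm_wgrad";
    # recipes without per-GEMM settings simply lack the field, hence hasattr.
    if recipe is not None and hasattr(recipe, attr):
        return getattr(recipe, attr).use_split_accumulator
    return default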
transformer_engine/pytorch/tensor/float8_tensor.py

@@ -14,6 +14,7 @@ from transformer_engine_torch import DType as TE_DType
 from ..utils import devices_match, non_tn_fp8_gemm_supported
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
+from ..constants import dist_group_type

 aten = torch.ops.aten
@@ -166,6 +167,167 @@ class Float8Quantizer(Quantizer):
         )

+
+class Float8CurrentScalingQuantizer(Quantizer):
+    """Builder class for FP8 tensors with per-tensor current scaling
+
+    High-precision tensors (e.g. in FP32 or BF16) are quantized by
+    multiplying with a scaling factor and casting to FP8. The max-abs
+    value ("amax") in the tensor is computed directly by scanning the input
+    high-precision tensor, without the need for any history window.
+
+    Unlike delayed scaling, scale and amax tensors are not needed to initialize the
+    quantizer, because they are simply GPU buffers that will be filled by current
+    scaling quantization kernels, instead of using values taken from the delayed scaling
+    history window. Therefore, a device parameter is needed for tensor allocation.
+
+    Both Float8CurrentScalingQuantizer and Float8Quantizer produce Float8Tensor,
+    because they are both per-tensor scaling, i.e. one scaling factor per tensor.
+
+    """
+
+    """Scaling factor to multiply when quantizing to FP8"""
+    scale: torch.Tensor
+    """Max-abs value from last FP8 cast"""
+    amax: torch.Tensor
+    """FP8 datatype"""
+    dtype: TE_DType
+    """amax reduction options"""
+    with_amax_reduction: bool
+    amax_reduction_group: Optional[dist_group_type]
+    amax_reduction_size: Optional[int]
+    """Options about how to quantize the tensor"""
+    force_pow_2_scales: bool
+    amax_epsilon: float
+
+    def __init__(
+        self,
+        fp8_dtype: TE_DType,
+        device: torch.device,
+        *,
+        rowwise: bool = True,
+        columnwise: bool = True,
+        with_amax_reduction: bool = False,
+        amax_reduction_group: Optional[dist_group_type] = None,
+        amax_reduction_size: Optional[int] = 1,
+        force_pow_2_scales: bool = False,
+        amax_epsilon: float = 0.0,
+    ) -> None:
+        super().__init__(rowwise=rowwise, columnwise=columnwise)
+        self.scale = torch.empty(1, dtype=torch.float32, device=device)
+        self.amax = torch.empty(1, dtype=torch.float32, device=device)
+        self.dtype = fp8_dtype
+        self.with_amax_reduction = with_amax_reduction
+        self.amax_reduction_group = amax_reduction_group
+        self.amax_reduction_size = amax_reduction_size
+        self.force_pow_2_scales = force_pow_2_scales
+        self.amax_epsilon = amax_epsilon
+
+    def update_quantized(
+        self,
+        src: torch.Tensor,
+        dst: QuantizedTensor,
+        *,
+        noop_flag: Optional[torch.Tensor] = None,
+    ) -> QuantizedTensor:
+        if not isinstance(dst, Float8Tensor):
+            raise ValueError("Float8CurrentScalingQuantizer can only update Float8Tensor")
+
+        # Make sure input is in expected format
+        if not devices_match(src.device, dst.device):
+            src = src.to(device=dst.device)
+        if not src.is_contiguous():
+            src = src.contiguous()
+
+        # Launch cast kernel
+        tex.quantize(src, self, dst, noop_flag)
+
+        # Update FP8 dtype
+        dst._fp8_dtype = self.dtype
+        return dst
+
+    def make_empty(
+        self,
+        shape: Iterable[int],
+        *,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[torch.device] = None,
+        requires_grad: bool = False,
+    ) -> Float8Tensor:
+
+        # Canonicalize tensor attributes
+        if device is None:
+            device = torch.device("cuda")
+
+        # Allocate FP8 data
+        data = torch.empty(shape, dtype=torch.uint8, device=device)
+
+        # Allocate FP8 data transpose if needed
+        data_transpose = None
+        if self.columnwise_usage:
+            inner_dim = data.size(-1)
+            data_transpose = torch.empty(
+                inner_dim,
+                data.numel() // inner_dim,
+                dtype=torch.uint8,
+                device=device,
+            )
+
+        # Construct FP8 tensor
+        return Float8Tensor(
+            shape=shape,
+            dtype=dtype,
+            data=data,
+            fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device),
+            fp8_dtype=self.dtype,
+            requires_grad=requires_grad,
+            data_transpose=data_transpose,
+            quantizer=self,
+        )
+
+    def calibrate(self, tensor: torch.Tensor) -> None:
+        # current scaling does not need to calibrate
+        return
+
+    def create_tensor_from_data(
+        self,
+        data: torch.Tensor,
+        fake_dtype=torch.float32,
+        requires_grad: bool = False,
+        internal: bool = False,
+    ):
+        """
+        Create Float8Tensor from raw uint8 data; unlike delayed scaling,
+        self.scale doesn't mean anything here, so we simply create an empty scale_inv
+        """
+        assert data.dtype in [
+            torch.uint8,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2,
+            torch.float8_e5m2fnuz,
+        ]
+        if internal:
+            return Float8TensorBase(
+                data=data,
+                fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
+                fp8_dtype=self.dtype,
+                requires_grad=requires_grad,
+                data_transpose=None,
+                quantizer=self,
+            )
+        return Float8Tensor(
+            shape=data.shape,
+            dtype=fake_dtype,
+            data=data,
+            fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device),
+            fp8_dtype=self.dtype,
+            requires_grad=requires_grad,
+            data_transpose=None,
+            quantizer=self,
+        )
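For reference, the new quantizer is used like the existing Float8Quantizer, except it is constructed from an FP8 dtype and a device rather than from scale/amax buffers. A minimal sketch, assuming a CUDA build of Transformer Engine (Quantizer.__call__ dispatches to quantization, as elsewhere in this diff):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

quantizer = Float8CurrentScalingQuantizer(
    tex.DType.kFloat8E4M3,            # FP8 format for the quantized data
    device=torch.device("cuda"),
    force_pow_2_scales=False,
    amax_epsilon=0.0,
)
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
x_fp8 = quantizer(x)                  # amax and scale are computed from x itself
x_hp = x_fp8.dequantize()             # back to high precision for checking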
 class Float8Tensor(Float8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
@@ -192,7 +354,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         FP8 format.
     data_transpose: torch.Tensor, optional
         FP8 transpose data in a uint8 tensor
-    quantizer: Float8Quantizer, optional
+    quantizer: Float8Quantizer, Float8CurrentScalingQuantizer, optional
         Builder class for FP8 tensors

     """
@@ -229,10 +391,9 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         """
         if self._quantizer is not None:
             return self._quantizer
-        return Float8Quantizer(
-            scale=torch.reciprocal(self._scale_inv),
-            amax=torch.empty(1, dtype=torch.float32, device=self.device),
-            fp8_dtype=self._fp8_dtype,
+        # The quantizer for a Float8Tensor is now not necessarily a Float8Quantizer (delayed scaling)
+        raise ValueError(
+            "Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable"
         )

     def quantize_(