Unverified Commit 06a38cc0 authored by Paweł Gadziński, committed by GitHub

[PyTorch] ONNX export of FP8 Current Scaling (#2068)



* Compute amax in the normalization forward pass for current scaling in untuned kernels
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* apply Tim's suggestions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent a5c79876
@@ -10,7 +10,7 @@
     "\n",
     "<b>Note:</b>\n",
     "\n",
-    "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n",
+    "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n",
     "\n",
     "</div>\n",
     "\n",
...
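Following the note above, exporting under the newly supported recipe could look like the sketch below. This is a minimal sketch, not part of the diff: it assumes TE's `te.onnx_export` context manager and PyTorch's dynamo-based `torch.onnx.export`, whose exact signatures may differ across versions.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Assumed usage: a single TE layer exported under FP8 current scaling.
model = te.Linear(128, 128).eval().cuda()
inp = torch.randn(16, 128, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
    with te.onnx_export(enabled=True):  # assumed TE export context
        onnx_program = torch.onnx.export(model, (inp,), dynamo=True)
onnx_program.save("te_linear_fp8_cs.onnx")
```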
@@ -65,6 +65,7 @@ if mxfp8_available:
     fp8_recipes.append(recipe.MXFP8BlockScaling())
 if fp8_available:
     fp8_recipes.append(recipe.DelayedScaling())
+    fp8_recipes.append(recipe.Float8CurrentScaling())
 fp8_recipes.append(None)

 supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -81,11 +82,11 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
     ],
     outputs=[PyCustomOpDef.dt_uint8],
 )
-def trt_fp8_quantize(t, scale):
+def trt_fp8_quantize(t, scale_inv):
     """FP8 quantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale).cuda(),
+        scale=1 / torch.from_numpy(scale_inv).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )
@@ -101,11 +102,11 @@ def trt_fp8_quantize(t, scale):
     ],
     outputs=[PyCustomOpDef.dt_float],
 )
-def trt_fp8_dequantize(t, scale):
+def trt_fp8_dequantize(t, scale_inv):
     """FP8 dequantization extension for ONNX Runtime."""
     x = torch.from_numpy(t).cuda()
     q = te.tensor.float8_tensor.Float8Quantizer(
-        scale=1 / torch.from_numpy(scale).cuda(),
+        scale=1 / torch.from_numpy(scale_inv).cuda(),
         amax=torch.zeros([1]).cuda(),
         fp8_dtype=tex.DType.kFloat8E4M3,
     )
@@ -593,7 +594,9 @@ def _test_export_layernorm_linear(
         fname,
         inp,
         model,
-        atol=1e-3,
+        # For current scaling we use Float8Quantizer in tests + amax computed by hand,
+        # which has slightly different numerics than Float8CurrentScalingQuantizer.
+        atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
         is_fp8=fp8_recipe is not None,
         te_outputs=te_outputs,
     )
@@ -1150,6 +1153,11 @@ def test_trt_integration(fp8_recipe: recipe.Recipe):
         ffn_hidden_size=128,
         num_attention_heads=4,
     ).eval()
+    if type(fp8_recipe) == recipe.Float8CurrentScaling:
+        # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
+        model = te.LayerNormMLP(128, 128)
+
     inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)
     with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
...
@@ -112,7 +112,9 @@ schema = defs.OpSchema(
     doc="TRT FP8 Quantize Linear used for inference.",
     inputs=[
         defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
-        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
+        defs.OpSchema.FormalParameter(
+            "scale_inv", "tensor(float)", "Inverse scale factor for quantization"
+        ),
     ],
     outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
 )
@@ -126,11 +128,10 @@ TRT_FP8QuantizeLinear = onnxscript.values.Op(
 @torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
-def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
+def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
     """Dequantize from Float8Tensor used for inference."""
-    scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
     quantizer = Float8Quantizer(
-        scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
+        1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
     )
     quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
     return quantizer_tensor.dequantize()
@@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor:
 def onnx_dequantize_fp8_symbolic(
-    tensor: onnxscript.onnx_types.TensorType, scale: float
+    tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
 ) -> onnxscript.onnx_types.TensorType:
     """Symbolic dequantize from Float8Tensor used for inference."""
-    scale_inv = op.Constant(value_float=1 / scale)
     return TRT_FP8DequantizeLinear(tensor, scale_inv)
@@ -157,7 +157,9 @@ schema = defs.OpSchema(
     doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
     inputs=[
         defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
-        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
+        defs.OpSchema.FormalParameter(
+            "scale_inv", "tensor(float)", "Inverse scale factor for dequantization"
+        ),
     ],
     outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
 )
@@ -166,6 +168,43 @@ TRT_FP8DequantizeLinear = onnxscript.values.Op(
     opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
 )


+# ONNX FP8 Current Scaling Quantization
+@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[])
+def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize to FP8 with current scaling; returns (uint8 data, scale_inv)."""
+    if tensor.dtype != torch.float32:
+        tensor = tensor.to(torch.float32)
+    amax = tensor.abs().max()
+    eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device)
+    amax = torch.maximum(amax, eps)
+    fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device)
+    scale = fp8_max / amax
+    q = torch.ops.tex.fp8_quantize(tensor, scale)
+    scale_inv = 1 / scale
+    return q, scale_inv
+
+
+@onnx_cs_quantize_fp8_op.register_fake
+def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones(
+        1, dtype=torch.float32, device=tensor.device
+    )
+
+
+def onnx_quantize_fp8_cs_symbolic(
+    tensor: onnxscript.onnx_types.TensorType,
+):
+    """Symbolic quantize with current scaling; computes scale_inv from tensor."""
+    # scale_inv = max(abs(tensor)) / 448
+    amax = op.ReduceMax(op.Abs(tensor), keepdims=0)
+    eps = op.Constant(value_float=1.0e-12)
+    amax = op.Max(amax, eps)
+    scale_inv = op.Div(amax, op.Constant(value_float=448.0))
+    q = TRT_FP8QuantizeLinear(tensor, scale_inv)
+    return q, scale_inv
+
+
 # ONNX MXFP8 Quantization
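The added hunk above is the heart of the feature: the scale is derived from the tensor's own amax (`scale = 448 / amax` for E4M3) rather than from delayed-scaling history, and the symbolic function emits the same arithmetic as ONNX ops (ReduceMax → Max(eps) → Div by 448) so `scale_inv` is computed inside the exported graph. A standalone plain-PyTorch sketch of that round trip, using `torch.float8_e4m3fn` purely for illustration (not the TE ops from this diff):

```python
import torch

def fp8_cs_roundtrip(x: torch.Tensor) -> torch.Tensor:
    """Illustrative current-scaling quantize/dequantize round trip."""
    amax = x.abs().amax().float().clamp_min(1e-12)   # "current" amax, guarded against zero
    scale = 448.0 / amax                             # 448 = max finite E4M3 magnitude
    q = (x.float() * scale).to(torch.float8_e4m3fn)  # quantize
    return q.float() / scale                         # dequantize (multiply by scale_inv)

x = torch.randn(16, 128)
print((x - fp8_cs_roundtrip(x)).abs().max())  # small, amax-dependent error
```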
@@ -356,6 +395,7 @@ te_translation_table = {
     torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
     torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
     torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
+    torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
     torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
     torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
     torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
...
@@ -177,7 +177,7 @@ class Float8Quantizer(Quantizer):
     def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
         """Function using primitives with ONNX defined translations."""
-        out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item())
+        out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv)
         out = out.to(tensor.dtype)
         return out
@@ -350,15 +350,25 @@ class Float8CurrentScalingQuantizer(Quantizer):
     def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Function using primitives with ONNX defined translations."""
-        raise NotImplementedError(
-            "Float8CurrentScalingQuantizer does not support ONNX quantization yet."
-        )
+        if tensor.dtype != torch.float32:
+            tensor = tensor.to(torch.float32)
+        data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor)
+        return Float8Tensor(
+            shape=data.shape,
+            dtype=torch.float32,
+            data=data,
+            fp8_scale_inv=scale_inv,
+            fp8_dtype=self.dtype,
+            requires_grad=False,
+            data_transpose=None,
+            quantizer=self,
+        )

     def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
         """Function using primitives with ONNX defined translations."""
-        raise NotImplementedError(
-            "Float8CurrentScalingQuantizer does not support ONNX dequantization yet."
-        )
+        out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv)
+        out = out.to(tensor.dtype)
+        return out

     def _canonicalized_amax_reduction_group(self) -> dist_group_type:
         """Get process group for amax reduction"""
...