"tests/vscode:/vscode.git/clone" did not exist on "9c29c93114e0b57367258538da8005de27db5b8f"
Unverified commit f04b094c authored by Paweł Gadziński, committed by GitHub

[PyTorch] ONNX test fix + export for FP8 attention (#2598)



* jit bug fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 2104e4c1
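At a glance, the change makes an export flow like the following work (a hedged sketch, not code from this commit; the te.onnx_export context manager, tensor shapes, and argument values are assumptions based on the test suite):

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Unfused dot-product attention module, as exercised by the export test.
    model = te.attention.DotProductAttention(
        num_attention_heads=1,
        kv_channels=64,
        qkv_format="sbhd",
        attn_mask_type="no_mask",
    ).to(device="cuda")

    # sbhd layout: (seq_len, batch, heads, kv_channels)
    q = torch.randn(64, 4, 1, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # fp8_dpa=True requests FP8 attention; in ONNX export mode the unfused
    # backend now emulates it with quantize/dequantize ops that have ONNX
    # translations.
    fp8_recipe = recipe.DelayedScaling(fp8_dpa=True)
    with te.onnx_export(True), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        torch.onnx.export(model, (q, k, v), "te.core_attention_fp8_dpa.onnx")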
@@ -6,4 +6,5 @@
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"
-python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
+# NVTE_UnfusedDPA_Emulate_FP8=1 enables FP8 attention emulation when no native backend is available
+NVTE_UnfusedDPA_Emulate_FP8=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
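Outside of ONNX export mode the emulation still has to be opted into explicitly. A hedged equivalent of the CI script's opt-in from Python (the test path is assumed relative to the repo root; the env var must be set before the attention forward runs):

    import os

    # Opt in to emulated FP8 attention in UnfusedDotProductAttention.
    os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"

    import pytest

    # Run the same test file the CI script invokes.
    pytest.main(["--tb=auto", "tests/pytorch/test_onnx_export.py"])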
@@ -713,6 +713,14 @@ def test_export_layernorm_mlp_activation(seed_default_rng, activation):
     _test_export_layernorm_mlp(activation=activation)


+# Quantization recipes with fp8_dpa=True for attention emulation export test
+dpa_quantization_recipes = [None]  # None = no quantization
+if fp8_available:
+    dpa_quantization_recipes.append(recipe.DelayedScaling(fp8_dpa=True))
+    dpa_quantization_recipes.append(recipe.Float8CurrentScaling(fp8_dpa=True))
+
+
+@pytest.mark.parametrize("fp8_recipe", dpa_quantization_recipes)
 @pytest.mark.parametrize(
     "precision, use_mask, attn_mask_type",
     [
@@ -730,6 +738,7 @@ def test_export_core_attention(
     precision: torch.dtype,
     use_mask: bool,
     attn_mask_type: str,
+    fp8_recipe: recipe.Recipe,
 ):
     # Set dimensions (these are arbitrary).
     seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
@@ -749,22 +758,25 @@ def test_export_core_attention(
     mask_str = get_attn_mask_str(use_mask, attn_mask_type)
     high_prec_str = dtype2str(precision)
-    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
+    fp8_str = "_fp8_dpa" if fp8_recipe is not None else ""
+    fname = f"te.core_attention{fp8_str}{mask_str}{high_prec_str}.onnx"
+    is_fp8 = fp8_recipe is not None

     model = te.attention.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
-        attention_dropout=0.5,
         qkv_format=qkv_format,
         attn_mask_type=attn_mask_type,
     ).to(device="cuda")

-    do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
-    te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
+    do_export(model, inp, fname, input_names=input_names, fp8_recipe=fp8_recipe)
+    te_outputs = te_infer(model, inp, is_fp8=is_fp8, fp8_recipe=fp8_recipe)
     serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
     if precision in (torch.bfloat16,):
         return
+    atol = 5e-1 if is_fp8 else 1e-2
     validate_result(
-        fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
+        fname, inp, model, is_fp8=True, atol=atol, input_names=input_names, te_outputs=te_outputs
     )
...
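The relaxed atol = 5e-1 for the FP8 cases reflects quantize-dequantize roundtrip error, not a looser pass criterion for the exporter itself. A hedged illustration of the magnitude of that error (assumes PyTorch >= 2.1 for the float8_e4m3fn dtype; 448 is the E4M3 max normal):

    import torch

    x = torch.randn(1024)
    scale = x.abs().max() / 448.0             # map the max magnitude onto E4M3's range
    xq = (x / scale).to(torch.float8_e4m3fn)  # quantize
    xdq = xq.to(torch.float32) * scale        # dequantize
    # With 3 mantissa bits the relative error is up to ~6%, so the largest
    # activations can be off by a few tenths, hence atol on the order of 5e-1.
    print((x - xdq).abs().max())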
@@ -164,6 +164,11 @@ class FP8EmulationFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
         # pylint: disable=missing-function-docstring
+        if is_in_onnx_export_mode():
+            return FP8EmulationFunc.onnx_forward(
+                tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout
+            )
         if quantizer_name == "QKV_quantizer":
             query_layer, key_layer, value_layer = [
                 x.contiguous() for x in [tensor1, tensor2, tensor3]
@@ -202,6 +207,47 @@ class FP8EmulationFunc(torch.autograd.Function):
         tensors = grad1, grad2, grad3
         return tensors[0], tensors[1], tensors[2], None, None, None

+    @staticmethod
+    def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None):
+        """
+        ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations.
+        """
+        # pylint: disable=unused-argument
+        is_qkv_quantizer = quantizer_name == "QKV_quantizer"
+        assert isinstance(
+            quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+        ), "ONNX FP8 emulation path supports only Float8 quantizers."
+
+        if is_qkv_quantizer:
+            # Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3.
+            orig_dtype = tensor1.dtype
+            shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
+            numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()]
+
+            # Flatten and concatenate
+            combined = torch.cat(
+                [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
+            )
+
+            # Quantize + dequantize combined tensor using quantizer's ONNX methods
+            combined_fp8 = quantizer.onnx_quantize(combined)
+            out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype)
+
+            # Split back
+            out1 = out[: numels[0]].reshape(shapes[0])
+            out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
+            out3 = out[numels[0] + numels[1] :].reshape(shapes[2])
+            return out1, out2, out3
+
+        if quantizer_name in ["S_quantizer", "O_quantizer"]:
+            # Emulate FP8 on single tensor using quantizer's ONNX methods
+            orig_dtype = tensor1.dtype
+            t_fp8 = quantizer.onnx_quantize(tensor1)
+            out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype)
+            return out, tensor2, tensor3
+
+        # Pass-through
+        return tensor1, tensor2, tensor3
+
+
 class UnfusedDotProductAttention(torch.nn.Module):
     """Parallel attention w/o QKV and Proj Gemms
...
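The QKV_quantizer branch quantizes Q, K, and V through one concatenated tensor so all three share a single FP8 scale, mirroring the fused path's combine_and_quantize behavior. A standalone toy version of that pattern (hedged; the real scale handling lives inside the quantizer's onnx_quantize/onnx_dequantize, this uses a naive max-based scale):

    import torch

    def joint_fp8_roundtrip(q, k, v, fp8_dtype=torch.float8_e4m3fn):
        """Toy stand-in for the quantizer: one shared scale for Q, K, and V."""
        flat = torch.cat([q.reshape(-1), k.reshape(-1), v.reshape(-1)])
        scale = flat.abs().max().clamp(min=1e-12) / 448.0  # E4M3 max normal
        deq = (flat / scale).to(fp8_dtype).to(flat.dtype) * scale
        n_q, n_k = q.numel(), k.numel()
        return (
            deq[:n_q].reshape(q.shape),
            deq[n_q : n_q + n_k].reshape(k.shape),
            deq[n_q + n_k :].reshape(v.shape),
        )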
@@ -1552,7 +1552,9 @@ class DotProductAttention(TransformerEngineBaseModule):
         )

         if use_unfused_attention:
-            allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
+            allow_emulation = (
+                os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
+            )
             if checkpoint_core_attention:
                 return self._checkpointed_attention_forward(
                     self.unfused_attention,
...
@@ -479,7 +479,9 @@ def get_attention_backend(
         logger.debug("Disabling FlashAttention 3 for FP8 training")
         use_flash_attention_3 = False
     if use_unfused_attention:
-        allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1"
+        allow_emulation = (
+            os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" or is_in_onnx_export_mode()
+        )
         if not allow_emulation:
             logger.debug("Disabling UnfusedDotProductAttention for FP8 attention")
             use_unfused_attention = False
...
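The two hunks above change the same gate in the forward path and in backend selection: ONNX export mode now implies the emulation opt-in, so exporting needs no env var. A hedged sketch of how the predicate flips (module path and context-manager signature are assumptions):

    from transformer_engine.pytorch.export import is_in_onnx_export_mode, onnx_export

    print(is_in_onnx_export_mode())      # False in a normal run: env var opt-in still required
    with onnx_export(True):
        print(is_in_onnx_export_mode())  # True: emulation is implicitly allowed here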
@@ -46,17 +46,35 @@ if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"
 # Decorator to disable Torch Dynamo
 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
-no_torch_dynamo = lambda recursive=True: lambda func: func
 if torch.__version__ >= "2":
     import torch._dynamo

-    if torch.__version__ >= "2.1":
-        no_torch_dynamo = lambda recursive=True: lambda f: (
-            f if is_in_onnx_export_mode() else torch._dynamo.disable(f, recursive=recursive)
-        )
-    else:
-        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
-        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
+    def no_torch_dynamo(recursive=True):
+        """Decorator to disable Torch Dynamo, except during ONNX export."""
+
+        def decorator(f):
+            # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
+            disabled_f = (
+                torch._dynamo.disable(f, recursive=recursive)
+                if torch.__version__ >= "2.1"
+                else torch._dynamo.disable(f)
+            )
+
+            @wraps(f)
+            def wrapper(*args, **kwargs):
+                if is_in_onnx_export_mode():
+                    return f(*args, **kwargs)
+                return disabled_f(*args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
+else:
+    # Fallback for PyTorch < 2.0: no-op decorator
+    def no_torch_dynamo(recursive=True):  # pylint: disable=unused-argument
+        """No-op decorator for PyTorch < 2.0."""
+        return lambda func: func


 def set_jit_fusion_options() -> None:
...
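For reference, a hedged usage sketch of the rewritten decorator (the decorated function is hypothetical; the import path assumes jit.py's module location). Call sites keep the @no_torch_dynamo(...) form, but the export check now happens per call rather than once at decoration time:

    import torch
    from transformer_engine.pytorch.jit import no_torch_dynamo

    @no_torch_dynamo(recursive=True)
    def bias_gelu_fused(x, bias):
        # Hypothetical function: stands in for the fused ops jit.py decorates.
        return torch.nn.functional.gelu(x + bias, approximate="tanh")

    # Under torch.onnx.export (is_in_onnx_export_mode() is True) the wrapper calls
    # the original Python function so the exporter can trace it; in normal runs it
    # calls the Dynamo-disabled version instead.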