Unverified Commit a0f44354 authored by Neta Zmora, committed by GitHub
Browse files

Improve softmax ONNX export tests (#370)



* Add dynamically shaped input mask in test_export_softmax
* Fix test_softmax_mask_fn - use the environment variable `NVTE_ONNX_KVCACHE_MAX_SEQ_LEN` to control whether the test uses the default mask generation function or dynamic TRILU mask slicing.
* Change core_attention ONNX export test: use "no_mask" as attn mask type when testing `te.attention.DotProductAttention` w/o masking.
* Use ORT CUDA backend by default.
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
parent ecd4f808
...@@ -256,7 +256,7 @@ def validate_result( ...@@ -256,7 +256,7 @@ def validate_result(
print("registered custom FP8 Q/DQ ops!") print("registered custom FP8 Q/DQ ops!")
"""Create an ONNX Runtime session for validation.""" """Create an ONNX Runtime session for validation."""
kwargs = {} kwargs = {"providers": ['CUDAExecutionProvider', 'CPUExecutionProvider']}
if is_fp8: if is_fp8:
sess_options = ort.SessionOptions() sess_options = ort.SessionOptions()
load_custom_ops(sess_options) load_custom_ops(sess_options)
...@@ -807,17 +807,17 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision ...@@ -807,17 +807,17 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
precision = torch.bfloat16 if fake_bf16_io else precision precision = torch.bfloat16 if fake_bf16_io else precision
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
in_features, hidden_size = 64, 256 batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
mask = None mask = None
input_names = ["input", "mask"] input_names = ["input", "mask"]
inp_shape = [hidden_size, in_features, in_features, in_features] inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax: if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax:
inp_shape = [hidden_size, in_features, in_features] inp_shape = [batch_size, seq_len_q, seq_len_k]
kernel_str = "ScaledUpperTriangMaskedSoftmax" kernel_str = "ScaledUpperTriangMaskedSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io) model = Test_Softmax(softmax_fn, fake_bf16_io)
elif softmax_fn == softmax_defs.ScaledMaskedSoftmax: elif softmax_fn == softmax_defs.ScaledMaskedSoftmax:
# Generate a random mask with 50% probability for 0 or 1. # Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(hidden_size, 1, in_features, in_features, device="cuda", dtype=precision) probs = 0.5 * torch.ones(1, 1, seq_len_q, seq_len_k, device="cuda", dtype=precision)
mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
kernel_str = "ScaledMaskedSoftmax" kernel_str = "ScaledMaskedSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True) model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True)
...@@ -832,8 +832,10 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision ...@@ -832,8 +832,10 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"{kernel_str}{high_prec_str}.onnx" fname = f"{kernel_str}{high_prec_str}.onnx"
inp = (input_tensor, mask) inp = (input_tensor, mask)
dynamic_axes = {}
do_export(model, inp, fname, input_names=input_names) if mask is not None:
dynamic_axes = {"mask": {2:"seq_len_q", 3:"seq_len_k"}}
do_export(model, inp, fname, input_names=input_names, dynamic_axes=dynamic_axes)
te_outputs = te_infer(model, inp, is_fp8=False) te_outputs = te_infer(model, inp, is_fp8=False)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if fake_bf16_io or precision != torch.bfloat16: if fake_bf16_io or precision != torch.bfloat16:
...@@ -845,16 +847,22 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision ...@@ -845,16 +847,22 @@ def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision
# Softmax kernel only supports FP16 or BF16! # Softmax kernel only supports FP16 or BF16!
@skip_FP8 @skip_FP8
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"]) @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision): def test_softmax_mask_fn(seed_default_rng, precision):
fake_bf16_io = precision == "fake-torch.bfloat16" fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode # reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if fake_bf16_io else precision precision = torch.bfloat16 if fake_bf16_io else precision
class Test_Softmax(nn.Module): class Test_Softmax(nn.Module):
def __init__(self, use_onnx_mask_fn: bool, fake_bf16_io: bool): def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool):
super().__init__() super().__init__()
self.scale=1 # arbitrary value self.scale = 1 # arbitrary value
self.fake_bf16_io=fake_bf16_io self.fake_bf16_io = fake_bf16_io
if use_default_te_mask_fn:
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "0"
else:
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{seq_len_q}"
# Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax # Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax
# even when is_in_onnx_export_mode()==False. # even when is_in_onnx_export_mode()==False.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0" os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
...@@ -873,10 +881,10 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision): ...@@ -873,10 +881,10 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
return ret return ret
# Set dimensions (these are arbitrary). # Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
mask = None mask = None
inp_shape = [hidden_size, in_features, in_features, in_features] batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
assert seq_len_q == seq_len_k # This is a causal (TRILU) mask
inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
input_tensor = torch.randn( input_tensor = torch.randn(
*inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision) *inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision)
inp = (input_tensor, mask) inp = (input_tensor, mask)
...@@ -884,11 +892,12 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision): ...@@ -884,11 +892,12 @@ def test_softmax_mask_fn(seed_default_rng, set_max_seq_len, precision):
# Compare the outputs of TE when using the default softmax mask # Compare the outputs of TE when using the default softmax mask
# to the TE outputs produced when using the ONNX-compatible causal mask. # to the TE outputs produced when using the ONNX-compatible causal mask.
model = Test_Softmax(use_onnx_mask_fn=False, fake_bf16_io=fake_bf16_io) # This verifies that _get_onnx_export_causal_mask generates a correct mask.
model = Test_Softmax(use_default_te_mask_fn=True, fake_bf16_io=fake_bf16_io)
te_outputs_default_mask = te_infer(model, inp, is_fp8=True) te_outputs_default_mask = te_infer(model, inp, is_fp8=True)
with te.onnx_export(True): with te.onnx_export(True):
# ONNX export mode forces use of the ONNX-compatible causal mask. # ONNX export mode forces use of the ONNX-compatible causal mask.
model_onnx_mask = Test_Softmax(use_onnx_mask_fn=True, fake_bf16_io=fake_bf16_io) model_onnx_mask = Test_Softmax(use_default_te_mask_fn=False, fake_bf16_io=fake_bf16_io)
te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True) te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True)
compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask, compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask,
atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking") atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking")
...@@ -1129,14 +1138,14 @@ def test_export_layernorm_mlp( ...@@ -1129,14 +1138,14 @@ def test_export_layernorm_mlp(
@skip_FP8 @skip_FP8
@pytest.mark.parametrize( @pytest.mark.parametrize(
"precision, use_mask, attn_mask_type", [ "precision, use_mask, attn_mask_type", [
(torch.float32, False, None), # calls forward_torch_softmax (torch.float32, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.float32, True, None), # calls forward_torch_softmax (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.float16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.float16, True, "padding"), # calls ScaledMaskedSoftmax (torch.float16, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.float16, False, "padding"), # calls ScaledSoftmax (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
(torch.bfloat16, False, "causal"), # calls ScaledUpperTriangMaskedSoftmax (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
(torch.bfloat16, True, "padding"), # calls ScaledMaskedSoftmax (torch.bfloat16, True, "padding"), # calls forward_torch_softmax (apply user mask)
(torch.bfloat16, False, "padding"), # calls ScaledSoftmax (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
]) ])
def test_export_core_attention( def test_export_core_attention(
seed_default_rng, seed_default_rng,
...@@ -1164,10 +1173,6 @@ def test_export_core_attention( ...@@ -1164,10 +1173,6 @@ def test_export_core_attention(
high_prec_str = dtype2str(precision) high_prec_str = dtype2str(precision)
fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"
if attn_mask_type is None:
attn_mask_type = 'causal'
input_names = ["query", "key", "value"]
inp = (query_layer, key_layer, value_layer)
model = te.attention.DotProductAttention( model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
kv_channels=kv_channels, kv_channels=kv_channels,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment