Unverified Commit 903e1f4f authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Fix ONNX exports (#437)



* Fix ONNX exports
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* docs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* review
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2da34d41
......@@ -763,156 +763,6 @@ def test_export_rmsnorm(
validate_result(
fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
@skip_FP8
@pytest.mark.parametrize("softmax_fn", [
softmax_defs.ScaledUpperTriangMaskedSoftmax,
softmax_defs.ScaledMaskedSoftmax,
softmax_defs.ScaledSoftmax,
te.softmax.FusedScaleMaskSoftmax,
])
# Softmax kernel only supports FP16 or BF16!
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_export_softmax(seed_default_rng, set_max_seq_len, softmax_fn, precision):
class Test_Softmax(nn.Module):
def __init__(self, softmax_fn, fake_bf16_io, mask_inp=False):
super().__init__()
self.softmax_fn = softmax_fn
self.scale = 8 # arbitrary value
self.mask_inp = mask_inp
self.fused_scaled_softmax = None
self.fake_bf16_io = fake_bf16_io
if self.softmax_fn == te.softmax.FusedScaleMaskSoftmax:
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)
def forward(self, inp, mask):
if self.fake_bf16_io:
inp = inp.type(torch.bfloat16)
if self.fused_scaled_softmax:
ret = self.fused_scaled_softmax(inp, mask, "causal", self.scale)
else:
if self.mask_inp:
ret = self.softmax_fn.apply(inp, mask, self.scale)
else:
ret = self.softmax_fn.apply(inp, self.scale)
if self.fake_bf16_io:
ret = ret.type(torch.float32)
return ret
fake_bf16_io = precision == "fake-torch.bfloat16"
precision = torch.bfloat16 if fake_bf16_io else precision
# Set dimensions (these are arbitrary).
batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
mask = None
input_names = ["input", "mask"]
inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
if softmax_fn == softmax_defs.ScaledUpperTriangMaskedSoftmax:
inp_shape = [batch_size, seq_len_q, seq_len_k]
kernel_str = "ScaledUpperTriangMaskedSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io)
elif softmax_fn == softmax_defs.ScaledMaskedSoftmax:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(1, 1, seq_len_q, seq_len_k, device="cuda", dtype=precision)
mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
kernel_str = "ScaledMaskedSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io, mask_inp=True)
elif softmax_fn == softmax_defs.ScaledSoftmax:
kernel_str = "ScaledSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io)
elif softmax_fn == te.softmax.FusedScaleMaskSoftmax:
kernel_str = "TorchSoftmax"
model = Test_Softmax(softmax_fn, fake_bf16_io)
input_tensor = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"{kernel_str}{high_prec_str}.onnx"
inp = (input_tensor, mask)
dynamic_axes = {}
if mask is not None:
dynamic_axes = {"mask": {2:"seq_len_q", 3:"seq_len_k"}}
do_export(model, inp, fname, input_names=input_names, dynamic_axes=dynamic_axes)
te_outputs = te_infer(model, inp, is_fp8=False)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
if fake_bf16_io or precision != torch.bfloat16:
atol = 5e-2 if fake_bf16_io else 1e-3
validate_result(fname, inp, model, atol=atol, input_names=input_names, te_outputs=te_outputs)
# Test dynamically generated softmax mask.
# Softmax kernel only supports FP16 or BF16!
@skip_FP8
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16, "fake-torch.bfloat16"])
def test_softmax_mask_fn(seed_default_rng, precision):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if fake_bf16_io else precision
class Test_Softmax(nn.Module):
def __init__(self, use_default_te_mask_fn: bool, fake_bf16_io: bool):
super().__init__()
self.scale = 1 # arbitrary value
self.fake_bf16_io = fake_bf16_io
if use_default_te_mask_fn:
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "0"
else:
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{seq_len_q}"
# Use NVTE_MASKED_SOFTMAX_FUSION to force TE to use forward_torch_softmax
# even when is_in_onnx_export_mode()==False.
os.environ["NVTE_MASKED_SOFTMAX_FUSION"] = "0"
self.fused_scaled_softmax = te.softmax.FusedScaleMaskSoftmax(
mask_func=te.utils.attention_mask_func,
softmax_in_fp32=True,
)
def forward(self, inp, mask):
if self.fake_bf16_io:
inp = inp.type(torch.bfloat16)
ret = self.fused_scaled_softmax(inp, mask, "causal", scale=self.scale)
if self.fake_bf16_io:
ret = ret.type(torch.float)
return ret
# Set dimensions (these are arbitrary).
mask = None
batch_size, n_heads, seq_len_q, seq_len_k = 64, 96, 32, 32
assert seq_len_q == seq_len_k # This is a causal (TRILU) mask
inp_shape = [batch_size, n_heads, seq_len_q, seq_len_k]
input_tensor = torch.randn(
*inp_shape, device="cuda", dtype=torch.float if fake_bf16_io else precision)
inp = (input_tensor, mask)
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
# Compare the outputs of TE when using the default softmax mask
# to the TE outputs produced when using the ONNX-compatible causal mask.
# This verifies that _get_onnx_export_causal_mask generates a correct mask.
model = Test_Softmax(use_default_te_mask_fn=True, fake_bf16_io=fake_bf16_io)
te_outputs_default_mask = te_infer(model, inp, is_fp8=True)
with te.onnx_export(True):
# ONNX export mode forces use of the ONNX-compatible causal mask.
model_onnx_mask = Test_Softmax(use_default_te_mask_fn=False, fake_bf16_io=fake_bf16_io)
te_outputs_onnx_mask = te_infer(model_onnx_mask, inp, is_fp8=True)
compare_outputs(te_outputs_default_mask, te_outputs_onnx_mask,
atol=0, rtol=0, max_errors_printed=10, allow_cnt_errors=0, fname="softmax masking")
# Compare the outputs of TE when using the default softmax mask
# to the ORT ONNX outputs produced when using the ONNX-compatible causal mask.
input_names = ["input", "mask"]
kernel_str = "FusedScaleMaskSoftmax"
fname = f"{kernel_str}{high_prec_str}.onnx"
do_export(model, inp, fname, input_names=input_names)
serialize_inputs_outputs(fname, inp, te_outputs=te_outputs_default_mask, input_names=input_names)
if fake_bf16_io or precision != torch.bfloat16:
atol = 1e-2 if fake_bf16_io else 1e-3
validate_result(
fname, inp, model_onnx_mask, atol=atol,
input_names=input_names, te_outputs=te_outputs_default_mask)
@pytest.mark.parametrize("scale_factor", [1])
@pytest.mark.parametrize("use_fp8", [False, True])
......@@ -1159,13 +1009,13 @@ def test_export_core_attention(
query_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
key_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
value_layer = torch.randn(qkv_size, dtype=precision, device="cuda")
input_names = ["query", "key", "value", "attention_mask", "attn_mask_type"]
input_names = ["query", "key", "value", "attention_mask"]
attention_mask = None
if use_mask:
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(qkv_size[1], qkv_size[2], qkv_size[0], qkv_size[0], device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (query_layer, key_layer, value_layer, attention_mask, attn_mask_type)
inp = (query_layer, key_layer, value_layer, attention_mask)
mask_str = get_attn_mask_str(use_mask, attn_mask_type)
high_prec_str = dtype2str(precision)
......@@ -1175,6 +1025,7 @@ def test_export_core_attention(
num_attention_heads=num_attention_heads,
kv_channels=kv_channels,
attention_dropout=0.5,
attn_mask_type=attn_mask_type,
).to(device='cuda')
do_export(model,
inp,
......@@ -1190,8 +1041,9 @@ def test_export_core_attention(
test_configs_multihead_attention = [
#"use_mask, attn_mask_type"
(False, "no_mask"), # calls ScaledUpperTriangMaskedSoftmax
(False, "causal"), # calls ScaledUpperTriangMaskedSoftmax
(True, "padding"), # calls ScaledMaskedSoftmax
(False, "padding"), # calls ScaledSoftmax
]
test_configs_attention_type = [
#"input_layernorm, attention_type, fuse_qkv_params"
......@@ -1265,6 +1117,7 @@ def test_export_multihead_attention(
model = te.MultiheadAttention(
*attention_args,
attn_mask_type=attn_mask_type,
params_dtype=precision,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
......@@ -1273,8 +1126,8 @@ def test_export_multihead_attention(
return_bias=True,
).to(device='cuda')
inp_context = (hidden_states_context, attention_mask, encoder_output, attn_mask_type)
input_names = ["hidden_states", "attention_mask", "encoder_output", "attn_mask_type"]
inp_context = (hidden_states_context, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
output_names=["attention_output", "attention_bias"]
do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
......@@ -1342,13 +1195,13 @@ def test_export_transformer_layer(
num_attention_heads = 4
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
input_names = ["input", "attention_mask", "self_attn_mask_type"]
input_names = ["input", "attention_mask"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (input_tensor, attention_mask, attn_mask_type)
inp = (input_tensor, attention_mask)
fp8_str = "_fp8" if use_fp8 else ""
fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
......@@ -1360,6 +1213,7 @@ def test_export_transformer_layer(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
......@@ -1541,16 +1395,17 @@ def test_export_gpt_generation(
hidden_size,
ffn_hidden_size,
num_attention_heads,
self_attn_mask_type=attn_mask_type,
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma).to(device='cuda')
# "Context phase": use full input sequence length
input_names = ["input", "attention_mask", "self_attn_mask_type"]
input_names = ["input"]
output_names = ["output"]
input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
inp = (input_tensor, None, attn_mask_type)
inp = (input_tensor,)
do_export(model, inp, fname, use_fp8,
input_names=input_names, output_names=output_names,
dynamic_axes={"input": {0: "seq", 1:"bs"},
......
......@@ -610,6 +610,7 @@ class _CombineKV(torch.autograd.Function):
tensors = split_tensor_along_dim(grad_outputs[0], ctx.dim, 2)
return tensors[0], tensors[1], None
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
BMM1 -> softmax + dropout -> BMM2
......@@ -1324,11 +1325,6 @@ class DotProductAttention(torch.nn.Module):
and set the environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0`. In order
to disable`flash-attn` entirely, set :attr:`NVTE_FLASH_ATTN=0`.
.. warning::
Argument :attr:`attn_mask_type` has been moved to the `forward` method and
is deprecated. It will be fully removed in future releases.
Parameters
----------
num_attention_heads : int
......@@ -1348,6 +1344,12 @@ class DotProductAttention(torch.nn.Module):
layer_number: int, default = `None`
layer number of the current `DotProductAttention` when multiple such modules
are concatenated, for instance in consecutive transformer blocks.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The init arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
Parallelism parameters
----------------------
......@@ -1374,7 +1376,7 @@ class DotProductAttention(torch.nn.Module):
kv_channels: int,
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0,
attn_mask_type: Optional[str] = None,
attn_mask_type: str = "causal",
sequence_parallel: bool = False,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
......@@ -1387,13 +1389,6 @@ class DotProductAttention(torch.nn.Module):
) -> None:
super().__init__()
if attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
self.attn_mask_type = attn_mask_type
self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
self.tp_group = tp_group
......@@ -1487,7 +1482,7 @@ class DotProductAttention(torch.nn.Module):
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
attn_mask_type: Optional[str] = None,
checkpoint_core_attention: bool = False,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
......@@ -1543,7 +1538,7 @@ class DotProductAttention(torch.nn.Module):
Value tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out softmax input when not using flash-attn.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None`
type of attention mask passed into softmax operation.
checkpoint_core_attention : bool, default = `False`
If true, forward activations for attention are recomputed
......@@ -1558,13 +1553,7 @@ class DotProductAttention(torch.nn.Module):
Whether to use the fast path to set output tensors to 0 or not.
"""
if self.attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
# Keep previous functionality for current users.
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
......@@ -1697,11 +1686,6 @@ class MultiheadAttention(torch.nn.Module):
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`attn_mask_type` is set to `"causal"`.
.. warning::
Argument :attr:`attn_mask_type` has been moved to the `forward` method and
is deprecated. It will be fully removed in future releases.
Parameters
----------
hidden_size : int
......@@ -1727,6 +1711,12 @@ class MultiheadAttention(torch.nn.Module):
layer_number: int, default = `None`
layer number of the current `TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The init arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
num_gqa_groups : int, default = `None`
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
......@@ -1817,7 +1807,7 @@ class MultiheadAttention(torch.nn.Module):
init_method: Optional[Callable] = None,
output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None,
attn_mask_type: Optional[str] = None,
attn_mask_type: str = "causal",
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
num_gqa_groups: Optional[int] = None,
......@@ -1843,13 +1833,6 @@ class MultiheadAttention(torch.nn.Module):
) -> None:
super().__init__()
if attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
self.attn_mask_type = attn_mask_type
self.layer_number = layer_number
self.input_layernorm = input_layernorm
......@@ -2034,7 +2017,7 @@ class MultiheadAttention(torch.nn.Module):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_output: Optional[torch.Tensor] = None,
attn_mask_type: str = "causal",
attn_mask_type: Optional[str] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[Any] = None,
......@@ -2057,7 +2040,7 @@ class MultiheadAttention(torch.nn.Module):
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `None`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
......@@ -2092,13 +2075,7 @@ class MultiheadAttention(torch.nn.Module):
"""
# hidden_states: [sq, b, h]
if self.attn_mask_type is not None:
warnings.warn(
"Argument :attr:`attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
# Keep previous functionality for current users.
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
if attn_mask_type == "padding" and attention_mask is not None:
......
......@@ -73,10 +73,9 @@ class TransformerLayer(torch.nn.Module):
Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
are deprecated and will be fully removed in future releases.
.. warning::
Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and
is deprecated. It will be fully removed in future releases.
.. note::
Argument :attr:`attention_mask` will be ignored in the `forward` call when
:attr:`self_attn_mask_type` is set to `"causal"`.
Parameters
----------
......@@ -127,6 +126,12 @@ class TransformerLayer(torch.nn.Module):
kv_channels: int, default = `None`
number of key-value channels. defaults to
:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation. Overridden by
:attr:`self_attn_mask_type` in the `forward` method. The forward
arg is useful for dynamically changing mask types, e.g. a different
mask for training and inference. The init arg is useful for cases
involving compilation/tracing, e.g. ONNX export.
zero_centered_gamma : bool, default = 'False'
if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
......@@ -212,7 +217,7 @@ class TransformerLayer(torch.nn.Module):
output_layer_init_method: Optional[Callable] = None,
layer_number: Optional[int] = None,
kv_channels: Optional[int] = None,
self_attn_mask_type: Optional[str] = None,
self_attn_mask_type: str = "causal",
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
params_dtype: Optional[torch.dtype] = None,
......@@ -239,13 +244,6 @@ class TransformerLayer(torch.nn.Module):
) -> None:
super().__init__()
if self_attn_mask_type is not None:
warnings.warn(
"Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
warnings.warn(
"Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling`"
"are deprecated and will be fully removed in future releases.",
......@@ -445,7 +443,7 @@ class TransformerLayer(torch.nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: str = "causal",
self_attn_mask_type: Optional[str] = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[torch.Tensor] = None,
is_first_microbatch: Optional[bool] = None,
......@@ -470,7 +468,7 @@ class TransformerLayer(torch.nn.Module):
Input tensor.
attention_mask : Optional[torch.Tensor], default = `None`
Boolean tensor used to mask out self-attention softmax input.
self_attn_mask_type: {'causal', 'padding'}, default = `causal`
self_attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal`
type of attention mask passed into softmax operation.
encoder_output : Optional[torch.Tensor], default = `None`
Output of the encoder block to be fed into the decoder block if using
......@@ -507,13 +505,7 @@ class TransformerLayer(torch.nn.Module):
Whether to set output tensors to 0 or not before use.
"""
if self.self_attn_mask_type is not None:
warnings.warn(
"Argument :attr:`self_attn_mask_type` has been moved to the `forward` method and"
"is deprecated. It will be fully removed in future releases.",
category=DeprecationWarning,
)
# Keep previous functionality for current users.
if self_attn_mask_type is None:
self_attn_mask_type = self.self_attn_mask_type
assert (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment