Unverified Commit 68f60b89 authored by Neta Zmora, committed by GitHub

Add env. var. for efficient text-generation in inference (#214)



* Dynamically-generated causal attention mask (for ONNX export)

TE's default causal mask is square (seq_len, seq_len) and is
dynamically allocated for different sequence sizes. Dynamic
allocation and dictionary lookups are not supported by ONNX,
and the GPT generation phase uses rectangular masks.

This commit forces softmax to use `forward_torch_softmax` and
to dynamically generate an attention mask when exporting to ONNX.
The mask is generated without conditional control-flow, by building
a (k_seq_len, k_seq_len) mask and slicing it to (q_seq_len, k_seq_len).

An alternate implementation is to pre-allocate a mask of shape
(max_seq, max_seq) and to slice that. This solution is more performant
at the expense of space, but the problem is that TE doesn't have a
concept of max_seq.
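
For illustration, a minimal standalone sketch of this slicing scheme (the function name and sizes below are illustrative, not TE's API):

```python
import torch

def make_causal_mask(q_seq_len: int, k_seq_len: int) -> torch.Tensor:
    """Build a (k_seq_len, k_seq_len) triu(k=1) mask and keep its last q_seq_len rows.

    No data-dependent control flow is needed, so the trace exports cleanly to ONNX.
    """
    full = torch.triu(torch.ones(k_seq_len, k_seq_len), diagonal=1).bool()
    return full[k_seq_len - q_seq_len : k_seq_len, :k_seq_len]

ctx_mask = make_causal_mask(8, 8)   # context phase: square (8, 8) mask
gen_mask = make_causal_mask(1, 9)   # generation phase: rectangular (1, 9) mask
```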

* Add to test_export_softmax a test for te.softmax.FusedScaleMaskSoftmax.
* Add test_softmax_mask_fn to test that TE's default attention mask and
the new ONNX-compatible mask produce the same behavior (a standalone sketch
of this equivalence follows this list).
* Add test_export_gpt_generation to test that the ONNX model can correctly
handle inputs with different shapes and that the attention mask is adjusted
on-the-fly to different sequence lengths.
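
As a rough standalone sketch of the equivalence that test_softmax_mask_fn checks (the real test exercises TE's softmax path; the lengths here are illustrative):

```python
import torch

seq_len, max_seq_len = 16, 64

# TE's default causal mask: square triu(k=1) of shape (seq_len, seq_len).
default_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# ONNX-compatible mask: pre-allocated (max_seq_len, max_seq_len) triu(k=1),
# sliced to the same (q, k) shape; identical when q_seq_len == k_seq_len.
onnx_full = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
onnx_mask = onnx_full[:seq_len, :seq_len]

assert torch.equal(default_mask, onnx_mask)
```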

Misc:
* Add a PRNG seeding fixture for more stability in tests.
* Add dynamic shapes for ONNX input/output tests.
* Allow validate_result to compare ORT output to pre-computed TE outputs.
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Add NVTE_ONNX_KVCACHE_MAX_SEQ_LEN for efficient text-generation in inference

* Introduce an environment variable (NVTE_ONNX_KVCACHE_MAX_SEQ_LEN) to set the maximum sequence length.
In ONNX inference with KV-Cache optimizations for GPT text generation, the attention mask shape can be square (context-phase) or rectangular (generation-phase).
When exporting to ONNX and this variable is set, TE preallocates an upper triangular (k=1) matrix with a size as prescribed by the variable, and dynamically slices the mask for the required shape.
TE models can still be exported to ONNX when NVTE_ONNX_KVCACHE_MAX_SEQ_LEN is not configured, but the attention masking is then always square and not suited to efficient text generation.
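
A minimal usage sketch (the bound of 2048 is illustrative); since the variable is read when FusedScaleMaskSoftmax is constructed, it must be set before the model is built:

```python
import os

# Set before constructing the TE model: FusedScaleMaskSoftmax reads the variable
# in __init__ to pre-allocate the (max_seq, max_seq) triu(k=1) mask buffer.
os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = "2048"

# ... build the TE GPT model here, then run torch.onnx.export as usual.
```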

* Work around a torch.onnx.export bug that incorrectly folds
layer_norm(data, scale=add(gamma, 1)) into layer_norm(data, scale=gamma)
when LayerNorm is used with zero-centered gamma.
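
For context, a small PyTorch sketch of what zero-centered gamma means numerically and what the bad fold effectively computed (this mirrors the exported graph's math, not TE's kernels):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 1024)
gamma = torch.zeros(1024)   # zero-centered weight: the effective scale is gamma + 1
beta = torch.zeros(1024)

y_correct = F.layer_norm(x, (1024,), weight=gamma + 1.0, bias=beta)
# The buggy export folded add(gamma, 1) back to gamma, effectively computing:
y_folded = F.layer_norm(x, (1024,), weight=gamma, bias=beta)  # scales everything by 0 here
```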

* ONNX export tests
  * Add a fixture (seed_default_rng) to seed the PRNG
  * Add a fixture (set_max_seq_len) to set the max sequence length when exporting to ONNX for GPT text generation
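
A rough sketch of how such fixtures might look (the seed, default length, and use of monkeypatch are assumptions, not necessarily the test suite's exact code):

```python
import pytest
import torch

@pytest.fixture
def seed_default_rng():
    # Seed the PRNG so ORT-vs-TE comparisons are reproducible across runs.
    torch.manual_seed(1234)

@pytest.fixture
def set_max_seq_len(monkeypatch):
    # Opt the exported model into the pre-allocated ONNX causal mask.
    def _set(max_seq_len=128):
        monkeypatch.setenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", str(max_seq_len))
    return _set
```
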
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Fix linting errors
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Remove immutable default values from a couple of function signatures
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Add @skip_FP8 to test_export_gpt_generation
Signed-off-by: Neta Zmora <nzmora@nvidia.com>

* Update transformer_engine/pytorch/softmax.py
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI error for softmax export
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Neta Zmora <nzmora@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0d2021ef
@@ -9,8 +9,8 @@ import torch
 from torch import nn
 import torch._C._onnx as _C_onnx
 from torch.onnx import _type_utils
 import transformer_engine_extensions as tex
+from transformer_engine.pytorch.export import is_in_onnx_export_mode

 THREADS_PER_WARP = 32
 THREADS_PER_BLOCK = 128
@@ -25,6 +25,26 @@ def _get_default_causal_mask(sq: int) -> torch.Tensor:
     return _default_causal_mask[sq]


+def _get_onnx_export_causal_mask(
+    seq_q: int, seq_k: int, onnx_causal_mask: torch.Tensor
+) -> torch.Tensor:
+    """Return the causal upper triangular mask for softmax input, for ONNX export.
+    ONNX does not support dynamic control-flow and requires non-square masks when
+    using a KV-cache (seq_k's length is len(context)+len(generative) while seq_q's length is 1).
+    Argument `onnx_causal_mask` is a square triu (k=1) mask that is sliced to the correct
+    shape for GPT context and generation phases.
+    In the context phase the derived mask is a square triu of shape (seq_k, seq_k), and in
+    the generation phase the mask is rectangular with shape (1, seq_k).
+    """
+    assert len(onnx_causal_mask.size()) == 2
+    assert onnx_causal_mask.size(0) == onnx_causal_mask.size(1)
+    assert onnx_causal_mask.size(0) >= (seq_k - seq_q) >= 0
+    derived_mask = onnx_causal_mask[seq_k - seq_q : seq_k, :seq_k]
+    return derived_mask
+
+
 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
@@ -214,6 +234,17 @@ class FusedScaleMaskSoftmax(nn.Module):
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32

+        # Users exporting to ONNX can optimize the attention mask for GPT text generation.
+        self.kvcache_max_seq = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "-1"))
+        if self.kvcache_max_seq > 0:
+            self.register_buffer(
+                "onnx_causal_mask",
+                torch.triu(
+                    torch.ones(self.kvcache_max_seq, self.kvcache_max_seq, device="cuda"),
+                    diagonal=1,
+                ).bool(),
+                persistent=False)
+
     def forward(
         self,
         inp: torch.Tensor,
@@ -231,7 +262,7 @@ class FusedScaleMaskSoftmax(nn.Module):
             scale is None or self.softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

-        if self.is_kernel_available(*inp.size()):
+        if self.is_kernel_available(*inp.size()) and not is_in_onnx_export_mode():
             return self.forward_fused_softmax(inp, mask, scale)

         return self.forward_torch_softmax(inp, mask, scale)
@@ -287,7 +318,12 @@ class FusedScaleMaskSoftmax(nn.Module):
             inp = inp * scale

         if self.attn_mask_type == "causal":
-            mask = _get_default_causal_mask(inp.size()[2])
+            if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
+                seq_len_q, seq_len_k = inp.size(2), inp.size(3)
+                assert self.kvcache_max_seq >= seq_len_k
+                mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
+            else:
+                mask = _get_default_causal_mask(inp.size(2))

         mask_output = self.mask_func(inp, mask) if mask is not None else inp
         probs = torch.nn.Softmax(dim=-1)(mask_output)

@@ -234,7 +234,8 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
     if zero_centered_gamma:
         inputs_dtype = inputs.type().dtype()
-        one = g.op("Constant", value_t=torch.tensor([1.], dtype=inputs_dtype, device="cuda"))
+        shape = g.op("Shape", weight)
+        one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1], dtype=inputs_dtype))
         weight = g.op("Add", weight, one)

     axis = -len(normalized_shape)