Unverified Commit 81429b80 authored by Kirthi Shankar Sivamani, committed by GitHub

deprecate qk layer scaling and fp32 softmax args (#90)



* deprecate qk layer scaling and fp32 softmax args
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* apply QK layer scaling for fp16 training
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 39631f76
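For context before the diff: query-key layer scaling damps the QK^T matmul (BMM1) by an extra factor of `layer_number` so fp16 scores cannot overflow, then multiplies that factor back in during an fp32 softmax, so the net scaling is still the usual 1/sqrt(d). A minimal numeric sketch of the idea (illustrative only, with hypothetical names; not TE's code):

```python
import math
import torch

def qk_layer_scaled_scores(q, k, layer_number):
    """Sketch of QK layer scaling: damp BMM1, re-scale inside an fp32 softmax."""
    norm_factor = math.sqrt(q.size(-1))
    # BMM1 with an extra 1/layer_number keeps fp16 scores inside fp16 range.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (norm_factor * layer_number)
    # Softmax in fp32: multiply the damping back out before normalizing,
    # recovering the standard QK^T / sqrt(d) scaling overall.
    return torch.softmax(scores.float() * layer_number, dim=-1).to(q.dtype)
```

Multiplying the damping back out in fp16 would reintroduce the very overflow the trick avoids, which is why the softmax must run in fp32 whenever a scale is applied (the assert preserved in the diff below).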
@@ -756,16 +756,10 @@ def test_export_layernorm_mlp(
     (torch.float16, True, "padding"),   # calls ScaledMaskedSoftmax
     (torch.float16, False, "padding"),  # calls ScaledSoftmax
 ])
-@pytest.mark.parametrize("attention_softmax_in_fp32",
-                         [True, False])
-@pytest.mark.parametrize("apply_query_key_layer_scaling",
-                         [True, False])
 def test_export_core_attention(
     precision: torch.dtype,
     use_mask: bool,
     attn_mask_type: str,
-    attention_softmax_in_fp32: bool,
-    apply_query_key_layer_scaling: bool,
 ):
     # Set dimensions (these are arbitrary).
     kv_channels = 64
@@ -784,11 +778,9 @@ def test_export_core_attention(
         input_names.append("attention_mask")
         inp = (query_layer, key_layer, value_layer, attention_mask)

-    sm_prec_str = "_sm-fp32" if attention_softmax_in_fp32 else "_sm-fp16"
-    qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else ""
     mask_str = get_attn_mask_str(use_mask, attn_mask_type)
     high_prec_str = dtype2str(precision)
-    fname = f"te.core_attention{mask_str}{qk_scaling_str}{sm_prec_str}{high_prec_str}.onnx"
+    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"

     if attn_mask_type is None:
         attn_mask_type = 'causal'
@@ -798,8 +790,6 @@ def test_export_core_attention(
         kv_channels=kv_channels,
         attention_dropout=0.5,
         attn_mask_type=attn_mask_type,
-        attention_softmax_in_fp32=attention_softmax_in_fp32,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
     ).to(device='cuda')
     do_export(model,
               inp,
@@ -911,7 +901,6 @@ def test_export_multihead_attention(
 ])
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16])
 @pytest.mark.parametrize("fuse_qkv_params", [False, True])
-@pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False])
 @pytest.mark.parametrize("zero_centered_gamma", [False, True])
 def test_export_transformer_layer(
     use_fp8: bool,
@@ -920,7 +909,6 @@ def test_export_transformer_layer(
     output_layernorm: bool,
     precision: torch.dtype,
     fuse_qkv_params: bool,
-    apply_query_key_layer_scaling: bool,
     zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
@@ -946,10 +934,9 @@ def test_export_transformer_layer(
     fp8_str = "_fp8" if use_fp8 else ""
     fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
-    qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else ""
     high_prec_str = dtype2str(precision)
     attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
-    fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{qk_scaling_str}{high_prec_str}.onnx"
+    fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"

     model = te.TransformerLayer(
         hidden_size,
@@ -959,7 +946,6 @@ def test_export_transformer_layer(
         output_layernorm=output_layernorm,
         params_dtype=precision,
         fuse_qkv_params=fuse_qkv_params,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
         zero_centered_gamma=zero_centered_gamma).to(device='cuda')
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
@@ -4,7 +4,7 @@
 """Fused scaled masked softmax functions"""
 import os
-from typing import Callable, Tuple, Union
+from typing import Callable, Tuple, Union, Optional
 import torch
 from torch import nn
 import torch._C._onnx as _C_onnx
@@ -198,15 +198,13 @@ class FusedScaleMaskSoftmax(nn.Module):
     attn_mask_type: attention mask type (pad or causal)
     mask_func: mask function to be applied.
     softmax_in_fp32: if true, softmax is performed at fp32 precision.
-    scale: scaling factor used in input tensor scaling.
     """

     def __init__(
         self,
         attn_mask_type: str,
         mask_func: Callable,
-        softmax_in_fp32: bool,
-        scale: float,
+        softmax_in_fp32: bool = True,
     ) -> None:
         super().__init__()
         self.attn_mask_type = attn_mask_type
@@ -215,13 +213,13 @@ class FusedScaleMaskSoftmax(nn.Module):
         )
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
-        self.scale = scale
-        assert (
-            self.scale is None or softmax_in_fp32
-        ), "softmax should be in fp32 when scaled"

-    def forward(self, inp: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        inp: torch.Tensor,
+        mask: torch.Tensor,
+        scale: Optional[float] = None,
+    ) -> torch.Tensor:
         """FusedScaleMaskSoftmax fprop"""
         # [b, np, sq, sk]
         assert inp.dim() == 4
@@ -229,9 +227,13 @@ class FusedScaleMaskSoftmax(nn.Module):
         self.input_in_bf16 = inp.dtype == torch.bfloat16
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16

+        assert (
+            scale is None or self.softmax_in_fp32
+        ), "softmax should be in fp32 when scaled"
+
         if self.is_kernel_available(*inp.size()):
-            return self.forward_fused_softmax(inp, mask)
-        return self.forward_torch_softmax(inp, mask)
+            return self.forward_fused_softmax(inp, mask, scale)
+        return self.forward_torch_softmax(inp, mask, scale)

     def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
         """Check FusedScaleMaskSoftmax kernel availability based on size"""
@@ -256,11 +258,11 @@ class FusedScaleMaskSoftmax(nn.Module):
         return False

     def forward_fused_softmax(
-        self, inp: torch.Tensor, mask: torch.Tensor
+        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
     ) -> torch.Tensor:
         """Fused masked softmax kernel"""
         b, np, sq, sk = inp.size()
-        scale = self.scale if self.scale is not None else 1.0
+        scale = 1.0 if scale is None else scale

         if self.attn_mask_type == "causal":
             assert sq == sk, "causal mask is only for self attention"
@@ -275,14 +277,14 @@ class FusedScaleMaskSoftmax(nn.Module):
         return ScaledSoftmax.apply(inp, scale)

     def forward_torch_softmax(
-        self, inp: torch.Tensor, mask: torch.Tensor
+        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
     ) -> torch.Tensor:
         """Framework softmax"""
         if self.input_in_float16 and self.softmax_in_fp32:
             inp = inp.float()

-        if self.scale is not None:
-            inp = inp * self.scale
+        if scale is not None:
+            inp = inp * scale

         if self.attn_mask_type == "causal":
             mask = _get_default_causal_mask(inp.size()[2])
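With the softmax changes above, the scale moves from construction time to call time: `FusedScaleMaskSoftmax` is built without a baked-in scale and receives an optional `scale` on each forward. A usage sketch under the new signature (the import path and mask function are assumptions; shapes are illustrative):

```python
import torch
from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax

def attention_mask_func(scores, mask):
    # Stand-in for the caller's mask function: mask out positions with a
    # large negative value before the softmax.
    return scores.masked_fill(mask, -10000.0)

# No scale at construction; softmax_in_fp32 now defaults to True.
softmax = FusedScaleMaskSoftmax("causal", attention_mask_func)

scores = torch.randn(2, 8, 128, 128, dtype=torch.float16, device="cuda")
probs = softmax(scores, None)              # unscaled
probs_scaled = softmax(scores, None, 4.0)  # scaled; requires fp32 softmax
```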
@@ -6,6 +6,7 @@
 import os
 import re
 import math
+import warnings
 from contextlib import nullcontext
 from typing import Any, Callable, Optional, Tuple, Union
@@ -42,6 +43,7 @@ from transformer_engine.pytorch.distributed import (
 )

 _flash_attn_version = re.search("Version: (.*)", os.popen("pip show flash_attn").read()).group(1)
+warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")

 __all__ = ["DotProductAttention", "TransformerLayer"]
@@ -82,10 +84,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
         norm_factor: float,
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
-        layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
+        layer_number: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -95,12 +95,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
+        self.layer_number = layer_number

         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             attn_mask_type,
             attention_mask_func,
-            attention_softmax_in_fp32,
-            layer_number if apply_query_key_layer_scaling else None,
         )

         # Dropout. Note that for a single iteration, this layer will generate
@@ -117,6 +116,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
     ) -> torch.Tensor:
         """core attention fprop"""
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
+        apply_qk_layer_scaling = self.layer_number is not None and key_layer.dtype == torch.float16

         # [b, np, sq, sk]
         output_size = (
@@ -142,20 +142,25 @@ class UnfusedDotProductAttention(torch.nn.Module):
             device=torch.cuda.current_device(),
         )

+        scale = self.norm_factor
+        if apply_qk_layer_scaling:
+            scale *= self.layer_number
+
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
             matmul_result,
             query_layer.transpose(0, 1),  # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0,
-            alpha=(1.0 / self.norm_factor),
+            alpha=(1.0 / scale),
         )

         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

         # attention scores and attention mask [b, np, sq, sk]
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
+        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
+        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, softmax_scale)

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
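Note the gate above: `apply_qk_layer_scaling` is true only when `layer_number` is set and the keys are fp16. Restricting the damping to fp16 makes sense because fp16's 5-bit exponent tops out at 65504, which a raw QK^T product can easily exceed, while bf16 shares fp32's 8-bit exponent range. A quick check of the two ranges:

```python
import torch

# fp16 overflows above 65504; bf16 has fp32's dynamic range, so the
# unscaled QK^T product stays finite and needs no damping.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38
```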
@@ -205,9 +210,6 @@ class FlashAttention(torch.nn.Module):
         norm_factor: float,
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
-        layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
     ) -> None:
         super().__init__()
@@ -226,9 +228,6 @@ class FlashAttention(torch.nn.Module):
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_dropout = attention_dropout
-        self.layer_number = layer_number
-        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = attention_softmax_in_fp32

     def forward(
         self,
@@ -309,6 +308,9 @@ class DotProductAttention(torch.nn.Module):
                       dropout probability for the dropout op during multi-head attention.
     attn_mask_type: {'causal', 'padding'}, default = `causal`
                    type of attention mask passed into softmax operation.
+    layer_number: int, default = `None`
+                 layer number of the current `DotProductAttention` when multiple such modules
+                 are concatenated, for instance in consecutive transformer blocks.

     Parallelism parameters
     ----------------------
@@ -325,25 +327,15 @@ class DotProductAttention(torch.nn.Module):
         num_attention_heads: int,
         kv_channels: int,
         attention_dropout: float = 0.0,
-        layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
         sequence_parallel: bool = False,
         tp_size: int = 1,
         get_rng_state_tracker: Optional[Callable] = None,
         tp_group: Optional[dist_group_type] = None,
+        layer_number: Optional[int] = None,
     ) -> None:
         super().__init__()

-        if layer_number is None:
-            apply_query_key_layer_scaling = False
-        else:
-            layer_number = max(1, layer_number)
-        if apply_query_key_layer_scaling:
-            attention_softmax_in_fp32 = True
-
         tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
         self.tp_group = tp_group
         self.get_rng_state_tracker = get_rng_state_tracker
@@ -360,9 +352,6 @@ class DotProductAttention(torch.nn.Module):
             attention_dropout_ctx = get_rng_state_tracker().fork

         norm_factor = math.sqrt(self.hidden_size_per_attention_head)
-        norm_factor_flash_attn = norm_factor
-        if apply_query_key_layer_scaling:
-            norm_factor *= layer_number

         self.use_flash_attention = (
             int(os.getenv("NVTE_FLASH_ATTN", "1"))
@@ -373,17 +362,15 @@ class DotProductAttention(torch.nn.Module):
         attn_kwargs = {
             "attention_dropout": attention_dropout,
             "attention_dropout_ctx": attention_dropout_ctx,
-            "layer_number": layer_number,
-            "apply_query_key_layer_scaling": apply_query_key_layer_scaling,
-            "attention_softmax_in_fp32": attention_softmax_in_fp32,
             "attn_mask_type": attn_mask_type,
         }

         if self.use_flash_attention:
-            self.flash_attention = FlashAttention(norm_factor_flash_attn, **attn_kwargs)
+            self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
         # Instantiating both types since use of flash-attn
         # might be ruled out due to forward inputs.
-        self.unfused_attention = UnfusedDotProductAttention(norm_factor, **attn_kwargs)
+        self.unfused_attention = UnfusedDotProductAttention(
+            norm_factor, **attn_kwargs, layer_number=layer_number)

     def _checkpointed_attention_forward(
         self,
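From the caller's side, `layer_number` is now the only knob: `DotProductAttention` enables QK layer scaling automatically for fp16 inputs in the unfused path and leaves flash attention untouched. A construction sketch (argument values are illustrative):

```python
import transformer_engine.pytorch as te

# layer_number replaces the deprecated apply_query_key_layer_scaling /
# attention_softmax_in_fp32 pair; scaling kicks in only for fp16 inputs.
attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.1,
    layer_number=3,
)
```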
@@ -487,8 +474,6 @@ class MultiHeadAttention(torch.nn.Module):
         init_method: Callable,
         output_layer_init_method: Callable,
         layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,
@@ -607,14 +592,12 @@ class MultiHeadAttention(torch.nn.Module):
             num_attention_heads,
             kv_channels,
             attention_dropout,
-            layer_number=layer_number,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32,
             tp_size=tp_size,
             get_rng_state_tracker=get_rng_state_tracker,
             attn_mask_type=attn_mask_type,
             sequence_parallel=sequence_parallel,
             tp_group=tp_group,
+            layer_number=layer_number,
         )

         # Linear
@@ -835,6 +818,11 @@ class TransformerLayer(torch.nn.Module):
     TransformerLayer is made up of an attention block and a feedforward network (MLP).
     This standard layer is based on the paper "Attention Is All You Need".

+    .. warning::
+
+        Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
+        are deprecated and will be fully removed in future releases.
+
     .. note::

         Argument :attr:`attention_mask` will be ignored in the `forward` call when
@@ -870,16 +858,10 @@ class TransformerLayer(torch.nn.Module):
     layer_number: int, default = `None`
                  layer number of the current `TransformerLayer` when multiple such modules are
                  concatenated to form a transformer block.
-    apply_query_key_layer_scaling: bool, default = `False`
-                                  apply query-key layer scaling during BMM1
-                                  by a factor of `layer_number`
     output_layernorm: bool, default = `False`
                      if set to `True`, layer normalization is applied on the output side,
                      after the final dropout-add. default behavior is to apply layer
                      normalization on the input side, before the QKV transformation.
-    attention_softmax_in_fp32: bool, default = `True`
-                              if set to `False`, softmax is executed in
-                              the dtype of activation tensors.
     layer_type: {'encoder', 'decoder'}, default = `encoder`
                if set to `decoder`, an additional cross-attn block is added after self-attn.
                This can be used for structures like `T5` Transformer in conjunction with the
@@ -964,8 +946,8 @@ class TransformerLayer(torch.nn.Module):
         params_dtype: torch.dtype = torch.float32,
         get_rng_state_tracker: Optional[Callable] = None,
         fuse_wgrad_accumulation: bool = False,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
+        apply_query_key_layer_scaling: bool = False,  # pylint: disable=unused-argument
+        attention_softmax_in_fp32: bool = True,  # pylint: disable=unused-argument
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         sequence_parallel: bool = False,
@@ -980,6 +962,12 @@ class TransformerLayer(torch.nn.Module):
     ) -> None:
         super().__init__()

+        warnings.warn(
+            "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling` "
+            "are deprecated and will be fully removed in future releases.",
+            category=DeprecationWarning,
+        )
+
         bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
         self.layer_number = layer_number
         self.output_layernorm = output_layernorm
@@ -1026,8 +1014,6 @@ class TransformerLayer(torch.nn.Module):
         )

         common_attention_kwargs = {
             "layer_number": layer_number,
-            "apply_query_key_layer_scaling": apply_query_key_layer_scaling,
-            "attention_softmax_in_fp32": attention_softmax_in_fp32,
             "tp_group": tp_group,
             "tp_size": tp_size,
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
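Since the two arguments survive in `TransformerLayer.__init__` as unused parameters and the module registers a `DeprecationWarning` filter, existing call sites keep working and simply see a warning. A sketch of the expected behavior (sizes are illustrative; assumes a CUDA-capable Transformer Engine install):

```python
import warnings
import transformer_engine.pytorch as te

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        apply_query_key_layer_scaling=True,  # deprecated, now a no-op
        attention_softmax_in_fp32=True,      # deprecated, now a no-op
    )

# The constructor emits a DeprecationWarning for the ignored arguments.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```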