Unverified Commit 81429b80 authored by Kirthi Shankar Sivamani, committed by GitHub

deprecate qk layer scaling and fp32 softmax args (#90)



* deprecate qk layer scaling and fp32 softmax args
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* apply QK layer scaling for fp16 training
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 39631f76
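In practice, QK layer scaling is no longer a user-facing knob after this commit: the unfused attention path enables it on its own whenever a `layer_number` is provided and the inputs are fp16, and the softmax stays in fp32 in that case. A minimal usage sketch (not part of the diff), assuming Transformer Engine is importable as `te`, a CUDA device is available, and the sizes are purely illustrative:

import torch
import transformer_engine.pytorch as te

# QK layer scaling is now decided internally: it kicks in for fp16 inputs
# whenever a layer_number is known, with no extra constructor arguments.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    layer_number=3,               # used as the per-layer scaling factor for fp16
    params_dtype=torch.float16,
).to(device="cuda")

# Input follows the [seq, batch, hidden] layout used by TransformerLayer.
hidden = torch.randn(128, 4, 1024, dtype=torch.float16, device="cuda")
out = layer(hidden, attention_mask=None)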
@@ -756,16 +756,10 @@ def test_export_layernorm_mlp(
     (torch.float16, True, "padding"),   # calls ScaledMaskedSoftmax
     (torch.float16, False, "padding"),  # calls ScaledSoftmax
 ])
-@pytest.mark.parametrize("attention_softmax_in_fp32",
-                         [True, False])
-@pytest.mark.parametrize("apply_query_key_layer_scaling",
-                         [True, False])
 def test_export_core_attention(
     precision: torch.dtype,
     use_mask: bool,
     attn_mask_type: str,
-    attention_softmax_in_fp32: bool,
-    apply_query_key_layer_scaling: bool,
 ):
     # Set dimensions (these are arbitrary).
     kv_channels = 64
@@ -784,11 +778,9 @@ def test_export_core_attention(
         input_names.append("attention_mask")
         inp = (query_layer, key_layer, value_layer, attention_mask)

-    sm_prec_str = "_sm-fp32" if attention_softmax_in_fp32 else "_sm-fp16"
-    qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else ""
     mask_str = get_attn_mask_str(use_mask, attn_mask_type)
     high_prec_str = dtype2str(precision)
-    fname = f"te.core_attention{mask_str}{qk_scaling_str}{sm_prec_str}{high_prec_str}.onnx"
+    fname = f"te.core_attention{mask_str}{high_prec_str}.onnx"

     if attn_mask_type is None:
         attn_mask_type = 'causal'
@@ -798,8 +790,6 @@ def test_export_core_attention(
         kv_channels=kv_channels,
         attention_dropout=0.5,
         attn_mask_type=attn_mask_type,
-        attention_softmax_in_fp32=attention_softmax_in_fp32,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
     ).to(device='cuda')
     do_export(model,
               inp,
@@ -911,7 +901,6 @@ def test_export_multihead_attention(
 ])
 @pytest.mark.parametrize("precision", [torch.float32, torch.float16])
 @pytest.mark.parametrize("fuse_qkv_params", [False, True])
-@pytest.mark.parametrize("apply_query_key_layer_scaling", [True, False])
 @pytest.mark.parametrize("zero_centered_gamma", [False, True])
 def test_export_transformer_layer(
     use_fp8: bool,
@@ -920,7 +909,6 @@ def test_export_transformer_layer(
     output_layernorm: bool,
     precision: torch.dtype,
     fuse_qkv_params: bool,
-    apply_query_key_layer_scaling: bool,
     zero_centered_gamma: bool
 ):
     # Skip FP8 tests on non-hopper devices
@@ -946,10 +934,9 @@ def test_export_transformer_layer(
     fp8_str = "_fp8" if use_fp8 else ""
     fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
-    qk_scaling_str = "_qk-scaling" if apply_query_key_layer_scaling else ""
     high_prec_str = dtype2str(precision)
     attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
-    fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{qk_scaling_str}{high_prec_str}.onnx"
+    fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"
     model = te.TransformerLayer(
         hidden_size,
@@ -959,7 +946,6 @@ def test_export_transformer_layer(
         output_layernorm=output_layernorm,
         params_dtype=precision,
         fuse_qkv_params=fuse_qkv_params,
-        apply_query_key_layer_scaling=apply_query_key_layer_scaling,
         zero_centered_gamma=zero_centered_gamma).to(device='cuda')
     do_export(model, inp, fname, use_fp8)
     if not use_fp8:
......
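With the deprecated keywords dropped from the tests above, the attention module under test is built from structural arguments only. A rough sketch of the post-change construction, using the public te.DotProductAttention class as a stand-in for the module the export test builds (an assumption; its trimmed signature appears in the attention-module diff further down), with illustrative values:

import torch
import transformer_engine.pytorch as te

# Only structural arguments remain; passing attention_softmax_in_fp32 or
# apply_query_key_layer_scaling to this constructor now raises a TypeError.
attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.5,
    attn_mask_type="causal",
).to(device="cuda")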
@@ -4,7 +4,7 @@
 """Fused scaled masked softmax functions"""
 import os
-from typing import Callable, Tuple, Union
+from typing import Callable, Tuple, Union, Optional
 import torch
 from torch import nn
 import torch._C._onnx as _C_onnx
@@ -198,15 +198,13 @@ class FusedScaleMaskSoftmax(nn.Module):
     attn_mask_type: attention mask type (pad or causal)
     mask_func: mask function to be applied.
     softmax_in_fp32: if true, softmax is performed at fp32 precision.
-    scale: scaling factor used in input tensor scaling.
     """

     def __init__(
         self,
         attn_mask_type: str,
         mask_func: Callable,
-        softmax_in_fp32: bool,
-        scale: float,
+        softmax_in_fp32: bool = True,
     ) -> None:
         super().__init__()
         self.attn_mask_type = attn_mask_type
@@ -215,13 +213,13 @@ class FusedScaleMaskSoftmax(nn.Module):
         )
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
-        self.scale = scale
-
-        assert (
-            self.scale is None or softmax_in_fp32
-        ), "softmax should be in fp32 when scaled"
-
-    def forward(self, inp: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        inp: torch.Tensor,
+        mask: torch.Tensor,
+        scale: Optional[float] = None,
+    ) -> torch.Tensor:
         """FusedScaleMaskSoftmax fprop"""
         # [b, np, sq, sk]
         assert inp.dim() == 4
@@ -229,9 +227,13 @@ class FusedScaleMaskSoftmax(nn.Module):
         self.input_in_bf16 = inp.dtype == torch.bfloat16
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16

+        assert (
+            scale is None or self.softmax_in_fp32
+        ), "softmax should be in fp32 when scaled"
+
         if self.is_kernel_available(*inp.size()):
-            return self.forward_fused_softmax(inp, mask)
-        return self.forward_torch_softmax(inp, mask)
+            return self.forward_fused_softmax(inp, mask, scale)
+        return self.forward_torch_softmax(inp, mask, scale)

     def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
         """Check FusedScaleMaskSoftmax kernel availability based on size"""
@@ -256,11 +258,11 @@ class FusedScaleMaskSoftmax(nn.Module):
         return False

     def forward_fused_softmax(
-        self, inp: torch.Tensor, mask: torch.Tensor
+        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
     ) -> torch.Tensor:
         """Fused masked softmax kernel"""
         b, np, sq, sk = inp.size()
-        scale = self.scale if self.scale is not None else 1.0
+        scale = 1.0 if scale is None else scale

         if self.attn_mask_type == "causal":
             assert sq == sk, "causal mask is only for self attention"
@@ -275,14 +277,14 @@ class FusedScaleMaskSoftmax(nn.Module):
             return ScaledSoftmax.apply(inp, scale)

     def forward_torch_softmax(
-        self, inp: torch.Tensor, mask: torch.Tensor
+        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
     ) -> torch.Tensor:
         """Framework softmax"""
         if self.input_in_float16 and self.softmax_in_fp32:
             inp = inp.float()

-        if self.scale is not None:
-            inp = inp * self.scale
+        if scale is not None:
+            inp = inp * scale

         if self.attn_mask_type == "causal":
             mask = _get_default_causal_mask(inp.size()[2])
......
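With this refactor, the softmax scale moves from construction time to call time: FusedScaleMaskSoftmax is built without a scale, and the optional factor (the layer number, when QK layer scaling is active) is supplied per forward call. A small sketch of the new call pattern; attention_mask_func, scores, mask, and layer_number are placeholders, with tensors assumed to follow the [b, np, sq, sk] layout used in the module:

# Construction no longer fixes a scale; softmax_in_fp32 now defaults to True.
softmax = FusedScaleMaskSoftmax(
    attn_mask_type="causal",
    mask_func=attention_mask_func,
    softmax_in_fp32=True,
)

# No scaling (scale defaults to None) ...
probs = softmax(scores, mask)
# ... or per-call QK layer scaling, e.g. by the layer number.
probs = softmax(scores, mask, scale=float(layer_number))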
@@ -6,6 +6,7 @@
 import os
 import re
 import math
+import warnings
 from contextlib import nullcontext
 from typing import Any, Callable, Optional, Tuple, Union
@@ -42,6 +43,7 @@ from transformer_engine.pytorch.distributed import (
 )

 _flash_attn_version = re.search("Version: (.*)", os.popen("pip show flash_attn").read()).group(1)
+warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")

 __all__ = ["DotProductAttention", "TransformerLayer"]
@@ -82,10 +84,8 @@ class UnfusedDotProductAttention(torch.nn.Module):
         norm_factor: float,
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
-        layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
+        layer_number: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -95,12 +95,11 @@ class UnfusedDotProductAttention(torch.nn.Module):
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
+        self.layer_number = layer_number

         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             attn_mask_type,
             attention_mask_func,
-            attention_softmax_in_fp32,
-            layer_number if apply_query_key_layer_scaling else None,
         )

         # Dropout. Note that for a single iteration, this layer will generate
@@ -117,6 +116,7 @@ class UnfusedDotProductAttention(torch.nn.Module):
     ) -> torch.Tensor:
         """core attention fprop"""
         batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
+        apply_qk_layer_scaling = self.layer_number is not None and key_layer.dtype == torch.float16

         # [b, np, sq, sk]
         output_size = (
@@ -142,20 +142,25 @@ class UnfusedDotProductAttention(torch.nn.Module):
             device=torch.cuda.current_device(),
         )

+        scale = self.norm_factor
+        if apply_qk_layer_scaling:
+            scale *= self.layer_number

         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
             matmul_result,
             query_layer.transpose(0, 1),  # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0,
-            alpha=(1.0 / self.norm_factor),
+            alpha=(1.0 / scale),
         )

         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)

         # attention scores and attention mask [b, np, sq, sk]
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
+        softmax_scale = self.layer_number if apply_qk_layer_scaling else None
+        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, softmax_scale)

         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
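The two-step scaling in the hunk above is arithmetically neutral: BMM1 divides the logits by norm_factor * layer_number, and the softmax (running in fp32) multiplies them back by layer_number, so the effective scale is still 1 / norm_factor. The benefit is that the fp16 BMM1 output is layer_number times smaller, which buys headroom against overflow. A self-contained numeric sketch of that equivalence (fp32 on CPU, illustrative sizes):

import math
import torch

layer_number, head_dim = 4, 64
norm_factor = math.sqrt(head_dim)

q = torch.randn(2, 128, head_dim)
k = torch.randn(2, 128, head_dim)
scores = torch.bmm(q, k.transpose(1, 2))

plain = torch.softmax(scores / norm_factor, dim=-1)
# Divide by an extra layer_number in BMM1, then multiply it back before the softmax.
qk_scaled = torch.softmax((scores / (norm_factor * layer_number)) * layer_number, dim=-1)

assert torch.allclose(plain, qk_scaled, atol=1e-6)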
@@ -205,9 +210,6 @@ class FlashAttention(torch.nn.Module):
         norm_factor: float,
         attention_dropout: float = 0.0,
         attention_dropout_ctx: Optional[Callable] = nullcontext,
-        layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
     ) -> None:
         super().__init__()
@@ -226,9 +228,6 @@ class FlashAttention(torch.nn.Module):
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
         self.attention_dropout = attention_dropout
-        self.layer_number = layer_number
-        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = attention_softmax_in_fp32

     def forward(
         self,
@@ -309,6 +308,9 @@ class DotProductAttention(torch.nn.Module):
                       dropout probability for the dropout op during multi-head attention.
     attn_mask_type: {'causal', 'padding'}, default = `causal`
                    type of attention mask passed into softmax operation.
+    layer_number: int, default = `None`
+                 layer number of the current `DotProductAttention` when multiple such modules
+                 are concatenated, for instance in consecutive transformer blocks.

     Parallelism parameters
     ----------------------
@@ -325,25 +327,15 @@ class DotProductAttention(torch.nn.Module):
         num_attention_heads: int,
         kv_channels: int,
         attention_dropout: float = 0.0,
-        layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
         sequence_parallel: bool = False,
         tp_size: int = 1,
         get_rng_state_tracker: Optional[Callable] = None,
         tp_group: Optional[dist_group_type] = None,
+        layer_number: Optional[int] = None,
     ) -> None:
         super().__init__()

-        if layer_number is None:
-            apply_query_key_layer_scaling = False
-        else:
-            layer_number = max(1, layer_number)
-        if apply_query_key_layer_scaling:
-            attention_softmax_in_fp32 = True
-
         tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
         self.tp_group = tp_group
         self.get_rng_state_tracker = get_rng_state_tracker
@@ -360,9 +352,6 @@ class DotProductAttention(torch.nn.Module):
             attention_dropout_ctx = get_rng_state_tracker().fork

         norm_factor = math.sqrt(self.hidden_size_per_attention_head)
-        norm_factor_flash_attn = norm_factor
-        if apply_query_key_layer_scaling:
-            norm_factor *= layer_number

         self.use_flash_attention = (
             int(os.getenv("NVTE_FLASH_ATTN", "1"))
@@ -373,17 +362,15 @@ class DotProductAttention(torch.nn.Module):
         attn_kwargs = {
             "attention_dropout": attention_dropout,
             "attention_dropout_ctx": attention_dropout_ctx,
-            "layer_number": layer_number,
-            "apply_query_key_layer_scaling": apply_query_key_layer_scaling,
-            "attention_softmax_in_fp32": attention_softmax_in_fp32,
             "attn_mask_type": attn_mask_type,
         }

         if self.use_flash_attention:
-            self.flash_attention = FlashAttention(norm_factor_flash_attn, **attn_kwargs)
+            self.flash_attention = FlashAttention(norm_factor, **attn_kwargs)
         # Instantiating both types since use of flash-attn
         # might be ruled out due to forward inputs.
-        self.unfused_attention = UnfusedDotProductAttention(norm_factor, **attn_kwargs)
+        self.unfused_attention = UnfusedDotProductAttention(
+            norm_factor, **attn_kwargs, layer_number=layer_number)

     def _checkpointed_attention_forward(
         self,
@@ -487,8 +474,6 @@ class MultiHeadAttention(torch.nn.Module):
         init_method: Callable,
         output_layer_init_method: Callable,
         layer_number: Optional[int] = None,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
         attn_mask_type: str = "causal",
         tp_group: Optional[dist_group_type] = None,
         tp_size: int = 1,
@@ -607,14 +592,12 @@ class MultiHeadAttention(torch.nn.Module):
             num_attention_heads,
             kv_channels,
             attention_dropout,
-            layer_number=layer_number,
-            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
-            attention_softmax_in_fp32=attention_softmax_in_fp32,
             tp_size=tp_size,
             get_rng_state_tracker=get_rng_state_tracker,
             attn_mask_type=attn_mask_type,
             sequence_parallel=sequence_parallel,
             tp_group=tp_group,
+            layer_number=layer_number,
         )

         # Linear
@@ -835,6 +818,11 @@ class TransformerLayer(torch.nn.Module):
     TransformerLayer is made up of an attention block and a feedforward network (MLP).
     This standard layer is based on the paper "Attention Is All You Need".

+    .. warning::
+
+        Arguments :attr:`attention_softmax_in_fp32` and :attr:`apply_query_key_layer_scaling`
+        are deprecated and will be fully removed in future releases.
+
     .. note::

         Argument :attr:`attention_mask` will be ignored in the `forward` call when
@@ -870,16 +858,10 @@ class TransformerLayer(torch.nn.Module):
     layer_number: int, default = `None`
                  layer number of the current `TransformerLayer` when multiple such modules are
                  concatenated to form a transformer block.
-    apply_query_key_layer_scaling: bool, default = `False`
-                                  apply query-key layer scaling during BMM1
-                                  by a factor of `layer_number`
     output_layernorm: bool, default = `False`
                      if set to `True`, layer normalization is applied on the output side,
                      after the final dropout-add. default behavior is to apply layer
                      normalization on the input side, before the QKV transformation.
-    attention_softmax_in_fp32: bool, default = `True`
-                              if set to `False`, softmax is executed in
-                              the dtype of activation tensors.
     layer_type: {'encoder', 'decoder'}, default = `encoder`
                if set to `decoder`, an additional cross-attn block is added after self-attn.
                This can be used for structures like `T5` Transformer in conjunction with the
@@ -964,8 +946,8 @@ class TransformerLayer(torch.nn.Module):
         params_dtype: torch.dtype = torch.float32,
         get_rng_state_tracker: Optional[Callable] = None,
         fuse_wgrad_accumulation: bool = False,
-        apply_query_key_layer_scaling: bool = False,
-        attention_softmax_in_fp32: bool = True,
+        apply_query_key_layer_scaling: bool = False,  # pylint: disable=unused-argument
+        attention_softmax_in_fp32: bool = True,  # pylint: disable=unused-argument
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         sequence_parallel: bool = False,
@@ -980,6 +962,12 @@ class TransformerLayer(torch.nn.Module):
     ) -> None:
         super().__init__()

+        warnings.warn(
+            "Arguments `attention_softmax_in_fp32` and `apply_query_key_layer_scaling` "
+            "are deprecated and will be fully removed in future releases.",
+            category=DeprecationWarning,
+        )
+
         bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
         self.layer_number = layer_number
         self.output_layernorm = output_layernorm
@@ -1026,8 +1014,6 @@ class TransformerLayer(torch.nn.Module):
         )

         common_attention_kwargs = {
             "layer_number": layer_number,
-            "apply_query_key_layer_scaling": apply_query_key_layer_scaling,
-            "attention_softmax_in_fp32": attention_softmax_in_fp32,
             "tp_group": tp_group,
             "tp_size": tp_size,
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
......
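The deprecated keywords are still accepted by TransformerLayer.__init__ so existing call sites keep working, but they are ignored and the constructor now emits a DeprecationWarning pointing at them; the module-level filterwarnings("module", ...) call keeps that to one warning per module rather than letting Python's default filter swallow it. A hedged sketch of observing the warning, assuming Transformer Engine and a CUDA device are available and the sizes are illustrative:

import warnings
import torch
import transformer_engine.pytorch as te

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        apply_query_key_layer_scaling=True,   # deprecated: parsed but ignored
        attention_softmax_in_fp32=False,      # deprecated: parsed but ignored
        params_dtype=torch.float16,
    ).to(device="cuda")

# The constructor raised a DeprecationWarning about the two arguments.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)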