Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
......@@ -12,7 +12,6 @@ import torch
from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.jit import (
......@@ -35,7 +34,7 @@ from transformer_engine.pytorch.constants import (
from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
......@@ -149,11 +148,21 @@ class TransformerLayer(torch.nn.Module):
distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`.
Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by
:attr:`window_size` in :meth:`forward` as well.
bottom_right_diagonal : Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `self_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = "no_mask"
type of attention mask passed into softmax operation for decoder.
enc_dec_window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention in decoder.
enc_dec_bottom_right_diagonal : Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the decoder.
If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
zero_centered_gamma : bool, default = False
if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to
......@@ -175,7 +184,7 @@ class TransformerLayer(torch.nn.Module):
if set to ``False``, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
Options are: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
activation_params : Optional[dict], default = None
Additional parameters for the activation function.
......@@ -302,7 +311,9 @@ class TransformerLayer(torch.nn.Module):
kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
enc_dec_attn_mask_type: str = "no_mask",
enc_dec_bottom_right_diagonal: Optional[bool] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
......@@ -344,8 +355,10 @@ class TransformerLayer(torch.nn.Module):
self.self_attn_mask_type = self_attn_mask_type
self.window_size = window_size
self.bottom_right_diagonal = bottom_right_diagonal
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = enc_dec_window_size
self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
......@@ -398,6 +411,7 @@ class TransformerLayer(torch.nn.Module):
self.softmax_type = softmax_type
self.name = name
TransformerEngineBaseModule._validate_name(self)
attention_args = (
hidden_size,
......@@ -446,7 +460,7 @@ class TransformerLayer(torch.nn.Module):
qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope,
name=name + ".self_attention" if name is not None else None,
name=self.name + ".self_attention" if self.name is not None else None,
)
if layer_type == "decoder":
......@@ -463,7 +477,7 @@ class TransformerLayer(torch.nn.Module):
qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope,
name=name + ".inter_attention" if name is not None else None,
name=self.name + ".inter_attention" if self.name is not None else None,
)
# LayerNorm -> activation(Linear + Bias) -> Linear
......@@ -499,7 +513,7 @@ class TransformerLayer(torch.nn.Module):
activation_params=activation_params,
normalization=normalization,
device=device,
name=name + ".layernorm_mlp" if name is not None else None,
name=self.name + ".layernorm_mlp" if self.name is not None else None,
)
self.hidden_dropout = hidden_dropout
......@@ -606,10 +620,12 @@ class TransformerLayer(torch.nn.Module):
attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
enc_dec_attn_mask_type: Optional[str] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None,
enc_dec_bottom_right_diagonal: Optional[bool] = None,
is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None,
......@@ -654,6 +670,11 @@ class TransformerLayer(torch.nn.Module):
causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in encoder.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `self_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using
:attr:`layer_type` = ``"decoder"``.
......@@ -670,6 +691,11 @@ class TransformerLayer(torch.nn.Module):
Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in decoder.
enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the decoder.
If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
......@@ -736,10 +762,35 @@ class TransformerLayer(torch.nn.Module):
self_attn_mask_type = self.self_attn_mask_type
if window_size is None:
window_size = self.window_size
window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None:
enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None:
enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = dpa_utils.check_set_window_size(
enc_dec_attn_mask_type, enc_dec_window_size
)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if self_attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
if enc_dec_bottom_right_diagonal is None:
enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
enc_dec_bottom_right_diagonal = True
assert (
self_attn_mask_type in AttnMaskTypes
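
The hunk above resolves an unset `bottom_right_diagonal` from the mask type. Below is a minimal sketch of the documented default behavior only; the helper name `resolve_bottom_right_diagonal` and its standalone form are illustrative and not part of the library, and explicit values are simply passed through in this sketch.

```python
from typing import Optional

def resolve_bottom_right_diagonal(value: Optional[bool], mask_type: str) -> bool:
    """Sketch of the documented default: False for top-left-aligned causal
    masks, True for all other mask types."""
    if value is not None:
        return value  # an explicit setting is kept as-is in this sketch
    # "causal" / "padding_causal" align the diagonal to the top-left corner
    return mask_type not in ("causal", "padding_causal")

assert resolve_bottom_right_diagonal(None, "causal") is False
assert resolve_bottom_right_diagonal(None, "padding_causal") is False
assert resolve_bottom_right_diagonal(None, "causal_bottom_right") is True
assert resolve_bottom_right_diagonal(None, "no_mask") is True
```
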
......@@ -768,9 +819,6 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# For AMP
if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
......@@ -781,6 +829,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type,
window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
......@@ -816,6 +865,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask=enc_dec_attn_mask,
attn_mask_type=enc_dec_attn_mask_type,
window_size=enc_dec_window_size,
bottom_right_diagonal=enc_dec_bottom_right_diagonal,
encoder_output=encoder_output,
inference_params=inference_params,
is_first_microbatch=is_first_microbatch,
......
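
Putting the new arguments together, here is a hedged usage sketch (sizes, window, device, and values are illustrative assumptions, not taken from this commit): it constructs a `TransformerLayer` with an encoder-side sliding window and the new constructor-level `bottom_right_diagonal`, then overrides the alignment per call via the argument added to `forward()`.

```python
import torch
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="causal_bottom_right",
    window_size=(128, 0),          # sliding window for local attention in the encoder
    bottom_right_diagonal=True,    # constructor-level default added by this commit
).cuda()

# [sequence, batch, hidden] layout, the layer's default input format
hidden_states = torch.randn(512, 2, 1024, device="cuda")

out = layer(hidden_states)                              # uses the constructor default
out = layer(hidden_states, bottom_right_diagonal=True)  # per-call override in forward()
```

The `enc_dec_bottom_right_diagonal` counterpart works the same way for the decoder's cross-attention when `layer_type="decoder"`.
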