Commit 9df0c4a3 authored by wenjh
Browse files

Merge branch 'nv_main'

parents 0d874a4e f122b07d
...@@ -12,7 +12,6 @@ import torch ...@@ -12,7 +12,6 @@ import torch
from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.jit import ( from transformer_engine.pytorch.jit import (
...@@ -35,7 +34,7 @@ from transformer_engine.pytorch.constants import ( ...@@ -35,7 +34,7 @@ from transformer_engine.pytorch.constants import (
from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.distributed import get_distributed_world_size
from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
...@@ -149,11 +148,21 @@ class TransformerLayer(torch.nn.Module): ...@@ -149,11 +148,21 @@ class TransformerLayer(torch.nn.Module):
distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`. distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`.
Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by
:attr:`window_size` in :meth:`forward` as well. :attr:`window_size` in :meth:`forward` as well.
bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `self_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'}, enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
default = "no_mask" default = "no_mask"
type of attention mask passed into softmax operation for decoder. type of attention mask passed into softmax operation for decoder.
enc_dec_window_size : Optional[Tuple[int, int]], default = None enc_dec_window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention in decoder. sliding window size for local attention in decoder.
enc_dec_bottom_right_diagonal: Optional[bool], default = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the decoder.
If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
zero_centered_gamma : bool, default = False zero_centered_gamma : bool, default = False
if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to the LayerNorm formula changes to
...@@ -175,7 +184,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -175,7 +184,7 @@ class TransformerLayer(torch.nn.Module):
if set to ``False``, the transformer layer will not learn any additive biases. if set to ``False``, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu' activation : str, default = 'gelu'
Type of activation used in MLP block. Type of activation used in MLP block.
Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, Options are: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
activation_params : Optional[dict], default = None activation_params : Optional[dict], default = None
Additional parameters for the activation function. Additional parameters for the activation function.
...@@ -302,7 +311,9 @@ class TransformerLayer(torch.nn.Module): ...@@ -302,7 +311,9 @@ class TransformerLayer(torch.nn.Module):
kv_channels: Optional[int] = None, kv_channels: Optional[int] = None,
self_attn_mask_type: str = "causal", self_attn_mask_type: str = "causal",
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
enc_dec_attn_mask_type: str = "no_mask", enc_dec_attn_mask_type: str = "no_mask",
enc_dec_bottom_right_diagonal: Optional[bool] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None,
tp_group: Optional[dist_group_type] = None, tp_group: Optional[dist_group_type] = None,
tp_size: int = 1, tp_size: int = 1,
...@@ -344,8 +355,10 @@ class TransformerLayer(torch.nn.Module): ...@@ -344,8 +355,10 @@ class TransformerLayer(torch.nn.Module):
self.self_attn_mask_type = self_attn_mask_type self.self_attn_mask_type = self_attn_mask_type
self.window_size = window_size self.window_size = window_size
self.bottom_right_diagonal = bottom_right_diagonal
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = enc_dec_window_size self.enc_dec_window_size = enc_dec_window_size
self.enc_dec_bottom_right_diagonal = enc_dec_bottom_right_diagonal
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
...@@ -398,6 +411,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -398,6 +411,7 @@ class TransformerLayer(torch.nn.Module):
self.softmax_type = softmax_type self.softmax_type = softmax_type
self.name = name self.name = name
TransformerEngineBaseModule._validate_name(self)
attention_args = ( attention_args = (
hidden_size, hidden_size,
...@@ -446,7 +460,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -446,7 +460,7 @@ class TransformerLayer(torch.nn.Module):
qk_norm_type=qk_norm_type, qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps, qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope, qk_norm_before_rope=qk_norm_before_rope,
name=name + ".self_attention" if name is not None else None, name=self.name + ".self_attention" if self.name is not None else None,
) )
if layer_type == "decoder": if layer_type == "decoder":
...@@ -463,7 +477,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -463,7 +477,7 @@ class TransformerLayer(torch.nn.Module):
qk_norm_type=qk_norm_type, qk_norm_type=qk_norm_type,
qk_norm_eps=qk_norm_eps, qk_norm_eps=qk_norm_eps,
qk_norm_before_rope=qk_norm_before_rope, qk_norm_before_rope=qk_norm_before_rope,
name=name + ".inter_attention" if name is not None else None, name=self.name + ".inter_attention" if self.name is not None else None,
) )
# LayerNorm -> activation(Linear + Bias) -> Linear # LayerNorm -> activation(Linear + Bias) -> Linear
...@@ -499,7 +513,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -499,7 +513,7 @@ class TransformerLayer(torch.nn.Module):
activation_params=activation_params, activation_params=activation_params,
normalization=normalization, normalization=normalization,
device=device, device=device,
name=name + ".layernorm_mlp" if name is not None else None, name=self.name + ".layernorm_mlp" if self.name is not None else None,
) )
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
...@@ -606,10 +620,12 @@ class TransformerLayer(torch.nn.Module): ...@@ -606,10 +620,12 @@ class TransformerLayer(torch.nn.Module):
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
self_attn_mask_type: Optional[str] = None, self_attn_mask_type: Optional[str] = None,
window_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None,
bottom_right_diagonal: Optional[bool] = None,
encoder_output: Optional[torch.Tensor] = None, encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, enc_dec_attn_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
enc_dec_attn_mask_type: Optional[str] = None, enc_dec_attn_mask_type: Optional[str] = None,
enc_dec_window_size: Optional[Tuple[int, int]] = None, enc_dec_window_size: Optional[Tuple[int, int]] = None,
enc_dec_bottom_right_diagonal: Optional[bool] = None,
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
checkpoint_core_attention: bool = False, checkpoint_core_attention: bool = False,
inference_params: Optional[InferenceParams] = None, inference_params: Optional[InferenceParams] = None,
...@@ -654,6 +670,11 @@ class TransformerLayer(torch.nn.Module): ...@@ -654,6 +670,11 @@ class TransformerLayer(torch.nn.Module):
causal masks are aligned to the bottom right corner. causal masks are aligned to the bottom right corner.
window_size: Optional[Tuple[int, int]], default = None window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in encoder. Sliding window size for local attention in encoder.
bottom_right_diagonal: Optional[bool] = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the encoder.
If `None`, it will be set to `False` for `self_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
encoder_output : Optional[torch.Tensor], default = None encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using Output of the encoder block to be fed into the decoder block if using
:attr:`layer_type` = ``"decoder"``. :attr:`layer_type` = ``"decoder"``.
...@@ -670,6 +691,11 @@ class TransformerLayer(torch.nn.Module): ...@@ -670,6 +691,11 @@ class TransformerLayer(torch.nn.Module):
Type of attention mask passed into softmax operation for decoder. Type of attention mask passed into softmax operation for decoder.
enc_dec_window_size: Optional[Tuple[int, int]], default = None enc_dec_window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in decoder. Sliding window size for local attention in decoder.
enc_dec_bottom_right_diagonal: Optional[bool] = `None`
Align sliding window and ALiBi diagonal to the top left (`False`)
or bottom right (`True`) corner of the softmax matrix in the decoder.
If `None`, it will be set to `False` for `enc_dec_attn_mask_type` =
{`causal`, `padding_causal`} and `True` for other mask types.
is_first_microbatch : {True, False, None}, default = None is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split pipeline parallelism a minibatch of data is further split
...@@ -736,10 +762,35 @@ class TransformerLayer(torch.nn.Module): ...@@ -736,10 +762,35 @@ class TransformerLayer(torch.nn.Module):
self_attn_mask_type = self.self_attn_mask_type self_attn_mask_type = self.self_attn_mask_type
if window_size is None: if window_size is None:
window_size = self.window_size window_size = self.window_size
window_size = dpa_utils.check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None: if enc_dec_attn_mask_type is None:
enc_dec_attn_mask_type = self.enc_dec_attn_mask_type enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None: if enc_dec_window_size is None:
enc_dec_window_size = self.enc_dec_window_size enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = dpa_utils.check_set_window_size(
enc_dec_attn_mask_type, enc_dec_window_size
)
if bottom_right_diagonal is None:
bottom_right_diagonal = self.bottom_right_diagonal
if self_attn_mask_type in {"causal", "padding_causal"}:
bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
bottom_right_diagonal = True
if enc_dec_bottom_right_diagonal is None:
enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
"causal_bottom_right",
"padding_causal_bottom_right",
}:
enc_dec_bottom_right_diagonal = True
assert ( assert (
self_attn_mask_type in AttnMaskTypes self_attn_mask_type in AttnMaskTypes
...@@ -768,9 +819,6 @@ class TransformerLayer(torch.nn.Module): ...@@ -768,9 +819,6 @@ class TransformerLayer(torch.nn.Module):
enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask))
), "Encoder-decoder attention mask must be boolean tensor(s)" ), "Encoder-decoder attention mask must be boolean tensor(s)"
if TEDebugState.debug_enabled:
TransformerEngineBaseModule._validate_name(self)
# For AMP # For AMP
if torch.is_autocast_enabled(): if torch.is_autocast_enabled():
hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype()) hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())
...@@ -781,6 +829,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -781,6 +829,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask=attention_mask, attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type, attn_mask_type=self_attn_mask_type,
window_size=window_size, window_size=window_size,
bottom_right_diagonal=bottom_right_diagonal,
inference_params=inference_params, inference_params=inference_params,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
...@@ -816,6 +865,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -816,6 +865,7 @@ class TransformerLayer(torch.nn.Module):
attention_mask=enc_dec_attn_mask, attention_mask=enc_dec_attn_mask,
attn_mask_type=enc_dec_attn_mask_type, attn_mask_type=enc_dec_attn_mask_type,
window_size=enc_dec_window_size, window_size=enc_dec_window_size,
bottom_right_diagonal=enc_dec_bottom_right_diagonal,
encoder_output=encoder_output, encoder_output=encoder_output,
inference_params=inference_params, inference_params=inference_params,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment