Unverified commit beed55b9 authored by Kshitij Lakhani, committed by GitHub

[JAX] Set BSHD as default in Unfused DPA, DPA and MHA API calls (#2392)



* Make BSHD default for Unfused DPA, DPA and MHA in TE JAX
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Remove explicit transpose_batch set for BSHD for DPA in JAX quickstart
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Add warnings in DPA and MHA alerting users that the default changed from SBHD to BSHD
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Narrow the conditions under which the changed-default warnings for transpose_batch_sequence are triggered
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b14f417a
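What the new default means at a call site, sketched below. This is modeled on the quickstart snippet edited in this commit; the import path and the constructor arguments not visible in the diff (`head_dim`, `num_attention_heads`) are assumptions rather than the confirmed API.

```python
# A minimal sketch of the behavior change, modeled on the quickstart snippet
# in the diff below. The import path and the constructor arguments not shown
# in that snippet (head_dim, num_attention_heads) are assumptions.
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DotProductAttention  # assumed path

b, s, h, d = 2, 128, 8, 64
# BSHD layout: [batch, seq_len, num_heads, head_dim] -- the new default.
q = k = v = jnp.zeros((b, s, h, d), dtype=jnp.bfloat16)

# Leaving transpose_batch_sequence unset now warns once and behaves as BSHD;
# before this commit the default was True, i.e. SBHD ([seq_len, batch, ...]).
attention = DotProductAttention(
    head_dim=d,                 # assumed argument name
    num_attention_heads=h,      # assumed argument name
    num_gqa_groups=h,           # no GQA, as in the quickstart
    attn_mask_type="causal",
)
variables = attention.init(jax.random.PRNGKey(0), q, k, v, deterministic=True)

# Callers that still feed SBHD tensors must now opt in explicitly:
sbhd_attention = DotProductAttention(
    head_dim=d,
    num_attention_heads=h,
    attn_mask_type="causal",
    transpose_batch_sequence=True,  # inputs are [seq_len, batch, ...]
)
```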
@@ -368,7 +368,6 @@
 "            num_gqa_groups=self.num_attention_heads, # No GQA\n",
 "            attention_dropout=self.attention_dropout,\n",
 "            attn_mask_type='causal',\n",
-"            transpose_batch_sequence=False, # Input format is [batch, seq_len, ...]\n",
 "        )\n",
 "        x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
 "        # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
@@ -628,7 +627,6 @@
 "            num_gqa_groups=self.num_attention_heads,\n",
 "            attention_dropout=self.attention_dropout,\n",
 "            attn_mask_type='causal',\n",
-"            transpose_batch_sequence=False, # Input format is [batch, seq_len, ...]\n",
 "        )\n",
 "        x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
 "        # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
@@ -124,7 +124,7 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
     dtype: DType = jnp.float32
     float32_logits: bool = False
     scale_factor: Optional[float] = None
-    transpose_batch_sequence: bool = True
+    transpose_batch_sequence: bool = False
     window_size: Optional[Tuple[int, int]] = None
     softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
@@ -544,9 +544,10 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
         to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
         need to apply scale on query, which is to set :attr:`scale_factor=1.`.
-    transpose_batch_sequence: bool, default = True
+    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
+    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
         Indicate whether the input tensors were switched axis of batch
-        and sequence length dimension. if set to True, the input tensors
+        and sequence length dimension. If set to True, the input tensors
         should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
     window_size: Optional[Tuple[int, int]], default = None
         Sliding window size. The default value is no sliding window.
@@ -586,7 +587,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     float32_logits: bool = False
     qkv_layout: str = "bshd_bshd_bshd"
     scale_factor: Optional[float] = None
-    transpose_batch_sequence: bool = True
+    transpose_batch_sequence: bool | None = None
     window_size: Optional[Tuple[int, int]] = None
     max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
@@ -595,6 +596,17 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     context_checkpoint_name: str = "context"
     softmax_type: str = "vanilla"

+    def __post_init__(self):
+        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
+        # None implies that the user is relying on defaults, hence warn the user and set the new defaults
+        if self.transpose_batch_sequence is None:
+            warnings.warn(
+                "transpose_batch_sequence defaults to False in DotProductAttention starting"
+                " TransformerEngine v2.10"
+            )
+            self.transpose_batch_sequence = False
+        super().__post_init__()
+
     @nn.compact
     def __call__(
         self,
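The `__post_init__` hook above implements a common sentinel-default migration pattern: default the field to `None`, warn when the caller relied on the implicit default, then pin the new value. A self-contained illustration with a plain dataclass (illustrative names, not TE code):

```python
import warnings
from dataclasses import dataclass
from typing import Optional

@dataclass
class Attention:
    """Illustrative stand-in for the TE modules in this diff, not TE code."""

    # None means "the caller did not choose a value".
    transpose_batch_sequence: Optional[bool] = None

    def __post_init__(self):
        if self.transpose_batch_sequence is None:
            warnings.warn(
                "transpose_batch_sequence defaults to False starting v2.10"
            )
            self.transpose_batch_sequence = False  # new BSHD default

Attention()                               # warns, then behaves as BSHD
Attention(transpose_batch_sequence=True)  # explicit SBHD, no warning
```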
@@ -1047,7 +1059,8 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
         If set to True, this module exposes a single fused
         parameter for query-key-value for self-attention and key-value for
         cross-attention.
-    transpose_batch_sequence: bool, default = True
+    TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
+    transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. if set to True, the input tensors
         should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
@@ -1100,7 +1113,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
     low_rank_adaptation_alpha: float = None
     dtype: DType = jnp.float32
     fuse_qkv_params: bool = True
-    transpose_batch_sequence: bool = True
+    transpose_batch_sequence: bool | None = None
     enable_sequence_parallel: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
@@ -1116,6 +1129,15 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
     fuse_qkv: Optional[bool] = None

     def __post_init__(self):
+        # Deal with changed defaults in API
+        # TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
+        # None implies that the user is relying on defaults, hence warn the user and set the new defaults
+        if self.transpose_batch_sequence is None:
+            warnings.warn(
+                "transpose_batch_sequence defaults to False in MultiHeadAttention starting"
+                " TransformerEngine v2.10"
+            )
+            self.transpose_batch_sequence = False
         # Deal with the deprecated parameters
         if self.num_heads is not None:
             self.num_attention_heads = self.num_heads
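Because the warning fires only when the argument is omitted, passing `transpose_batch_sequence` explicitly is the clean way to silence it. Codebases that accept the new default everywhere can instead filter the specific message; a sketch assuming the warning text added above:

```python
import warnings

# Suppress only the transition warnings introduced in this commit. The
# pattern matches the start of the message text used in __post_init__ above.
warnings.filterwarnings(
    "ignore",
    message=r"transpose_batch_sequence defaults to False",
)
```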