Unverified commit beed55b9, authored by Kshitij Lakhani, committed by GitHub
Browse files

[JAX] Set BSHD as default in Unfused DPA, DPA and MHA API calls (#2392)



* Make BSHD default for Unfused DPA, DPA and MHA in TE JAX
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Remove explicit transpose_batch set for BSHD for DPA in JAX quickstart
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Add warnings in DPA and MHA to warn users of changed defaults to BSHD instead of SBHD
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Minimize the scope of when to trigger warnings for changed defaults for transpose_batch_sequence
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b14f417a
......@@ -368,7 +368,6 @@
" num_gqa_groups=self.num_attention_heads, # No GQA\n",
" attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n",
" transpose_batch_sequence=False, # Input format is [batch, seq_len, ...]\n",
" )\n",
" x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
......@@ -628,7 +627,6 @@
" num_gqa_groups=self.num_attention_heads, \n",
" attention_dropout=self.attention_dropout,\n",
" attn_mask_type='causal',\n",
" transpose_batch_sequence=False, # Input format is [batch, seq_len, ...]\n",
" )\n",
" x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
" # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
......
......@@ -124,7 +124,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
dtype: DType = jnp.float32
float32_logits: bool = False
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
transpose_batch_sequence: bool = False
window_size: Optional[Tuple[int, int]] = None
softmax_type: AttnSoftmaxType = AttnSoftmaxType.VANILLA_SOFTMAX
......@@ -544,9 +544,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Scale factor to apply on query. When :attr:`None` is present, the scale factor is equal
to :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for model like T5X, which doesn't
need to apply scale on query, which is to set :attr:`scale_factor=1.`.
transpose_batch_sequence: bool, default = True
TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
and sequence length dimension. If set to True, the input tensors
should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...).
window_size: Optional[Tuple[int, int]], default = None
Sliding window size. The default value is no sliding window.
......@@ -586,7 +587,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
float32_logits: bool = False
qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
transpose_batch_sequence: bool | None = None
window_size: Optional[Tuple[int, int]] = None
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
......@@ -595,6 +596,17 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
context_checkpoint_name: str = "context"
softmax_type: str = "vanilla"
def __post_init__(self):
# TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
# None implies that the user is relying on defaults, hence warn the user and set the new defaults
if self.transpose_batch_sequence is None:
warnings.warn(
"transpose_batch_sequence defaults to False in DotProductAttention starting"
" TransformerEngine v2.10"
)
self.transpose_batch_sequence = False
super().__post_init__()
@nn.compact
def __call__(
self,
......@@ -1047,7 +1059,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
cross-attention.
transpose_batch_sequence: bool, default = True
TODO(KshitijLakhani): Reset this to bool only with default False arg in TransformerEngine v2.12
transpose_batch_sequence: bool | None, default = None (however, default is forced to False in post_init)
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. if set to True, the input tensors
should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden).
......@@ -1100,7 +1113,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
transpose_batch_sequence: bool | None = None
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
......@@ -1116,6 +1129,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
fuse_qkv: Optional[bool] = None
def __post_init__(self):
# Deal with changed defaults in API
# TODO(KshitijLakhani): Remove warning in TransformerEngine v2.12
# None implies that the user is relying on defaults, hence warn the user and set the new defaults
if self.transpose_batch_sequence is None:
warnings.warn(
"transpose_batch_sequence defaults to False in MultiHeadAttention starting"
" TransformerEngine v2.10"
)
self.transpose_batch_sequence = False
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment