Commit 5978f1d7 (unverified) authored by Kshitij Lakhani, committed by GitHub

[JAX] Default to fused attention in JAX DPA (#2363)



* Default to fused attention in JAX DPA
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Consolidate documentation for DPA in JAX
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

* Correctly update the documentation for defaults in JAX DPA
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>

---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent 26aad6b0
@@ -407,10 +407,10 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
 Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment
 variable:
-* Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default).
-* Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention
-  kernel is not available on the system, a warning will be issued, and the module will
-  automatically fall back to the unfused backend.
+* Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention.
+* Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused
+  attention kernel is not available on the system, a warning will be issued, and the module
+  will automatically fall back to the unfused backend.
 .. note::
     The DotProductAttention default setting enables non-deterministic kernels for reduced
@@ -602,7 +602,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
 else:
     assert bias is not None
-enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0"))
+# Use fused attn (if kernel check below passes) by default
+enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))
 sequence_dim = 0 if self.transpose_batch_sequence else 1
 seqlen_q = query.shape[sequence_dim]
......
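The selection logic described in the docstring and changed by this commit can be sketched roughly as follows. This is a minimal illustration, not the module's actual code: the function name `select_attention_backend` and the `fused_kernel_available` flag are hypothetical stand-ins for the kernel check that the diff mentions but does not show. Only the `os.getenv("NVTE_FUSED_ATTN", "1")` line is taken from the change itself.

```python
import os
import warnings


def select_attention_backend(fused_kernel_available: bool) -> str:
    """Return "fused" or "unfused" (hypothetical helper, for illustration only)."""
    # Same parsing as in the diff: an unset variable now means "1" (fused).
    enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1"))

    if not enable_fused_attn:
        # NVTE_FUSED_ATTN=0 explicitly requests the unfused backend.
        return "unfused"

    if not fused_kernel_available:
        # Stand-in for the real cuDNN kernel check: warn and fall back.
        warnings.warn(
            "Fused attention kernel is not available; "
            "falling back to unfused attention."
        )
        return "unfused"

    return "fused"
```

Under this sketch, a caller who wants the pre-change behavior would set `NVTE_FUSED_ATTN=0` in the environment before the module is called, since the variable is read with `os.getenv` at that point.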