Unverified Commit 887a4fca authored by Kshitij Lakhani's avatar Kshitij Lakhani Committed by GitHub
Browse files

[JAX] Unset NVTE_FUSED_RING_ATTENTION_USE_SCAN by default (#2503)



* Unset NVTE_FUSED_RING_ATTENTION_USE_SCAN by default
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Add TODO
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Change the warning check in P2P helper to warn against using scan loop
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 50352325
......@@ -2350,7 +2350,8 @@ class _FusedAttnCPWithP2PHelper:
@staticmethod
def use_scanloop():
"""Returns true if the implementation will use a scan loop for iteration."""
use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
# TODO(KshitijLakhani): Reset default to 1, once the extra kv permute op issue is resolved
use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "0")))
return use_scan
def check_supported(self):
......@@ -2395,13 +2396,15 @@ class _FusedAttnCPWithP2PHelper:
f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
)
# We want to encourage use of scan loop to minimize unrolling and ensure more
# predictable scheduling from XLA. The unrolled flavor will be supported but
# not the prefered implementation.
if not self.use_scanloop():
# TODO(KshitijLakhani): Flip the condition to check for disabled scan loop and warn
# against using unrolled loops once the scan issue is resolved.
# We want to discourage the use of scan loop as additional kv permute op observed.
# The scan loop flavor will be supported but not the prefered implementation until
# a resolution for the additional kv permute op, which degrades perf, is found.
if self.use_scanloop():
warnings.warn(
"Scan loop is disabled for fused ring attention. To enable set"
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
"Scan loop is enabled for fused ring attention. To disable set"
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 in your environment"
)
# If using scanloop, idx in scan_kv_block() will be a traced device value, but
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment