Commit 887a4fca (unverified), authored by Kshitij Lakhani, committed by GitHub

[JAX] Unset NVTE_FUSED_RING_ATTENTION_USE_SCAN by default (#2503)



* Unset NVTE_FUSED_RING_ATTENTION_USE_SCAN by default
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Add TODO
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Change the warning check in P2P helper to warn against using scan loop
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 50352325
@@ -2350,7 +2350,8 @@ class _FusedAttnCPWithP2PHelper:
     @staticmethod
     def use_scanloop():
         """Returns true if the implementation will use a scan loop for iteration."""
-        use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
+        # TODO(KshitijLakhani): Reset default to 1, once the extra kv permute op issue is resolved
+        use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "0")))
         return use_scan
     def check_supported(self):
@@ -2395,13 +2396,15 @@ class _FusedAttnCPWithP2PHelper:
                 f"{header} only supports VANILLA_SOFTMAX, got: {self.config.softmax_type}"
             )
-        # We want to encourage use of scan loop to minimize unrolling and ensure more
-        # predictable scheduling from XLA. The unrolled flavor will be supported but
-        # not the prefered implementation.
-        if not self.use_scanloop():
+        # TODO(KshitijLakhani): Flip the condition to check for disabled scan loop and warn
+        # against using unrolled loops once the scan issue is resolved.
+        # We want to discourage the use of scan loop as additional kv permute op observed.
+        # The scan loop flavor will be supported but not the prefered implementation until
+        # a resolution for the additional kv permute op, which degrades perf, is found.
+        if self.use_scanloop():
             warnings.warn(
-                "Scan loop is disabled for fused ring attention. To enable set"
-                " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
+                "Scan loop is enabled for fused ring attention. To disable set"
+                " NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 in your environment"
             )
         # If using scanloop, idx in scan_kv_block() will be a traced device value, but
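
To illustrate the behavior after this change, here is a minimal standalone sketch of the environment-variable toggle and the flipped warning check. It is not the library code: `ENV_VAR`, `warn_if_scan_enabled` and `demo` are hypothetical names introduced for illustration, and only the `os.getenv(..., "0")` parsing and the warning text are taken from the diff above.

```python
import os
import warnings

# Name of the toggle, as it appears in the diff.
ENV_VAR = "NVTE_FUSED_RING_ATTENTION_USE_SCAN"


def use_scanloop() -> bool:
    """Same parsing as the diff: an unset variable now defaults to "0" (unrolled loop)."""
    return bool(int(os.getenv(ENV_VAR, "0")))


def warn_if_scan_enabled() -> bool:
    """Hypothetical standalone version of the warning check in check_supported()."""
    enabled = use_scanloop()
    if enabled:
        # After this commit the warning fires when the scan loop is ON, not OFF.
        warnings.warn(
            "Scan loop is enabled for fused ring attention. To disable set"
            f" {ENV_VAR}=0 in your environment"
        )
    return enabled


def demo() -> dict:
    """Return {env value: (scan enabled, number of warnings)} for unset, "0" and "1"."""
    results = {}
    for value in (None, "0", "1"):
        if value is None:
            os.environ.pop(ENV_VAR, None)  # simulate an unset variable
        else:
            os.environ[ENV_VAR] = value
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            enabled = warn_if_scan_enabled()
        results[value] = (enabled, len(caught))
    os.environ.pop(ENV_VAR, None)  # leave the environment clean
    return results


if __name__ == "__main__":
    print(demo())
```

Under these assumptions, only the explicit setting "1" enables the scan loop and triggers the warning. Because the value goes through `int()`, non-numeric settings such as "true" or an empty string raise `ValueError` rather than being treated as enabled or disabled.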