Unverified commit ab4fd3cf, authored by Reese Wang and committed by GitHub
Browse files

Remove xla_ignore_channel_id check and ignore Scan loop warning in unit test (#1540)



Remove xla_ignore_channel_id check and ignore Scan loop warning in unit test
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent f3a009da
...@@ -24,3 +24,4 @@ filterwarnings= ...@@ -24,3 +24,4 @@ filterwarnings=
ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning
ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
ignore:The host_callback APIs are deprecated .*:DeprecationWarning ignore:The host_callback APIs are deprecated .*:DeprecationWarning
ignore:Scan loop is disabled for fused ring attention.*:UserWarning
...@@ -38,7 +38,6 @@ from .misc import ( ...@@ -38,7 +38,6 @@ from .misc import (
get_padded_spec, get_padded_spec,
get_cudnn_version, get_cudnn_version,
is_ffi_enabled, is_ffi_enabled,
get_xla_flag,
) )
from ..sharding import ( from ..sharding import (
global_mesh_resource, global_mesh_resource,
...@@ -1607,14 +1606,7 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1607,14 +1606,7 @@ class _FusedAttnCPWithP2PHelper:
def use_scanloop(): def use_scanloop():
"""Returns true if the implementation will use a scan loop for iteration.""" """Returns true if the implementation will use a scan loop for iteration."""
use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1"))) use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
return use_scan
# nvbug(4675071): Disable the HLO verifier for channel ID checks.
# A WAR was added to XLA: https://github.com/openxla/xla/pull/16779
def truthy(val):
return val.lower() in ["1", "true"]
x = use_scan and get_xla_flag("--xla_ignore_channel_id", default=True, cast=truthy)
return x
def check_supported(self): def check_supported(self):
"""Checks if the context parallel implementation is supported by the given arguments.""" """Checks if the context parallel implementation is supported by the given arguments."""
...@@ -1659,8 +1651,7 @@ class _FusedAttnCPWithP2PHelper: ...@@ -1659,8 +1651,7 @@ class _FusedAttnCPWithP2PHelper:
if not self.use_scanloop(): if not self.use_scanloop():
warnings.warn( warnings.warn(
"Scan loop is disabled for fused ring attention. To enable set" "Scan loop is disabled for fused ring attention. To enable set"
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment and" " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
" add --xla_experimental_ignore_channel_id=true to XLA_FLAGS."
) )
def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment