Unverified Commit e19b8281 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Fixes for CI failures with the latest JAX (#1469)



* fixes L1 test

* fix test_multigpu_encoder

* fixes for other multi-encoder tests

* jax.extend.ffi to jax.ffi

* initialization with float32

* add init_dtype as an optional arg to all modules

* update use_scan query from xla flags

* relax threshold for test_encoder fp8

* relax the tols

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 24e4f955
......@@ -239,7 +239,7 @@ def get_params_sharding(sharding_rules, abs_var_collect, mesh):
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
......@@ -447,7 +447,7 @@ class TestEncoder(unittest.TestCase):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self):
......@@ -462,7 +462,7 @@ class TestEncoder(unittest.TestCase):
self.args.enable_sp = True
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.455 and actual[1] > 0.785
if __name__ == "__main__":
......
......@@ -218,7 +218,7 @@ def get_params_sharding(sharding_rules, abs_var_collect, mesh):
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
......
......@@ -320,7 +320,7 @@ def get_params_sharding(sharding_rules, abs_var_collect, mesh):
)
params_axes_sharding = flax.core.unfreeze(params_axes_sharding)
params_sharding = jax.tree_util.tree_map(
lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY]
lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY]
)
params_sharding = {**params_sharding, **params_axes_sharding}
return params_sharding
......@@ -587,7 +587,7 @@ class TestEncoder(unittest.TestCase):
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
result = self.exec(True)
assert result[0] < 0.45 and result[1] > 0.79
assert result[0] < 0.455 and result[1] > 0.79
if __name__ == "__main__":
......
......@@ -334,7 +334,7 @@ class TestEncoder(unittest.TestCase):
"""Test Transformer Engine with FP8"""
self.args.use_fp8 = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.455 and actual[1] > 0.79
if __name__ == "__main__":
......
......@@ -6,10 +6,4 @@ set -xe
: ${TE_PATH:=/opt/transformerengine}
# Skip ring attention tests since they need fixed environment vars
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn'
# Test ring attention with and without scan loop
NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
......@@ -2,6 +2,8 @@
#
# See LICENSE for license information.
import os
import pytest
import jax
import jax.numpy as jnp
import numpy as np
......@@ -11,7 +13,7 @@ from distributed_test_base import (
generate_context_parallel_configs,
generate_collectives_count,
)
from transformer_engine.jax import fp8_autocast
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
from transformer_engine.jax.attention import (
is_fused_attn_kernel_available,
AttnBiasType,
......@@ -22,10 +24,7 @@ from transformer_engine.jax.attention import (
inverse_reorder_causal_load_balancing,
CPStrategy,
)
from transformer_engine.jax.sharding import MeshResource
import pytest
from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat
DTYPES = [jnp.bfloat16]
......@@ -355,6 +354,10 @@ class TestDistributedContextParallelSelfAttn:
CPStrategy.ALL_GATHER,
)
@pytest.mark.parametrize(
"use_scan",
[pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
)
def test_context_parallel_ring_attn(
self,
device_count,
......@@ -367,8 +370,14 @@ class TestDistributedContextParallelSelfAttn:
dtype,
qkv_layout,
load_balanced,
use_scan,
):
return self.impl_test_context_parallel_attn(
if use_scan:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
else:
os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
......@@ -381,6 +390,7 @@ class TestDistributedContextParallelSelfAttn:
load_balanced,
CPStrategy.RING,
)
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
class TestReorderCausalLoadBalancing:
......
......@@ -11,7 +11,7 @@ import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from jax import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import NVTE_Activation_Type
......
......@@ -15,7 +15,7 @@ from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from jax import ffi
from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor
......@@ -1602,9 +1602,7 @@ class _FusedAttnCPWithP2PHelper:
def truthy(val):
return val.lower() in ["1", "true"]
x = use_scan and get_xla_flag(
"--xla_experimental_ignore_channel_id", default=False, cast=truthy
)
x = use_scan and get_xla_flag("--xla_ignore_channel_id", default=True, cast=truthy)
return x
def check_supported(self):
......
......@@ -5,9 +5,8 @@
from dataclasses import dataclass
from enum import IntEnum
import jax
from jax.interpreters import mlir
import jax.extend as jex
from transformer_engine import transformer_engine_jax
from .misc import is_ffi_enabled
......@@ -30,11 +29,11 @@ class CustomCallAPIVersion(IntEnum):
for _name, _value in transformer_engine_jax.registrations().items():
if _name.endswith("_ffi"):
if is_ffi_enabled():
jex.ffi.register_ffi_target(
jax.ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
jex.ffi.register_ffi_target(
jax.ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)
......
......@@ -13,7 +13,7 @@ from jax import dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from jax import ffi
from transformer_engine import transformer_engine_jax
......
......@@ -9,7 +9,7 @@ import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from jax import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
......
......@@ -12,7 +12,7 @@ import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from jax import ffi
from transformer_engine import transformer_engine_jax
......
......@@ -11,7 +11,7 @@ import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from jax import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
......
......@@ -150,8 +150,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
del self.scale_factor
if self.float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
query = query.astype(self.dtype)
key = key.astype(self.dtype)
h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths.
......@@ -989,6 +989,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", self.weight_dtype
)
self.kernel_init = _kernel_init.astype(self.dtype)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
......@@ -1281,7 +1282,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
f"expected query shape {expected_shape} instead got {query.shape}."
)
cur_index = cache_index.value
cur_index = cache_index.value.astype(jnp.int32)
one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype)
one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape)
key = cached_key.value + key * one_hot_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment