Commit c1a1c04e authored by wenjh

Merge nv_main(2.10) to main


Signed-off-by: wenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
@@ -45,7 +45,8 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
-NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
+NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
......
@@ -2,11 +2,44 @@
 #
 # See LICENSE for license information.
 
-set -xe
+function test_fail() {
+    RET=1
+    FAILED_CASES="$FAILED_CASES $1"
+    echo "Error: sub-test failed: $1"
+}
+
+RET=0
+FAILED_CASES=""
 
 : ${TE_PATH:=/opt/transformerengine}
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"
 
-NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
-SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
+export NVTE_JAX_UNITTEST_LEVEL="L1"
+
+# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
+export XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false"
+
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_dense.xml $TE_PATH/tests/jax/test_distributed_dense.py || test_fail "test_distributed_dense.py"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_helper.xml $TE_PATH/tests/jax/test_distributed_helper.py || test_fail "test_distributed_helper.py"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_layernorm.xml $TE_PATH/tests/jax/test_distributed_layernorm.py || test_fail "test_distributed_layernorm.py"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_mlp.xml $TE_PATH/tests/jax/test_distributed_layernorm_mlp.py || test_fail "test_distributed_layernorm_mlp.py"
+# XLA_FLAGS to WAR for test_distributed_softmax issue with NCCL
+# TODO(Kshitij): remove when NCCL issue is fixed
+XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_softmax.xml $TE_PATH/tests/jax/test_distributed_softmax.py || test_fail "test_distributed_softmax.py"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py"
+
+# TODO(Phuong): add this test back after it is verified
+# SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py"
+
+if [ $RET -ne 0 ]; then
+    echo "Error: some sub-tests failed: $FAILED_CASES"
+    exit 1
+fi
+echo "All tests passed"
+exit 0
@@ -8,4 +8,5 @@ set -xe
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"
 
-NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
+# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
+XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
@@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri...
 export FLASH_ATTN_CUDA_ARCHS=$sm_arch
 if [ $sm_arch -gt 90 ]
 then
-    FA_versions=(2.8.1)
+    FA_versions=(2.8.3)
 elif [ $sm_arch -eq 90 ]
 then
-    FA_versions=(2.7.3 2.8.1 3.0.0b1)
+    FA_versions=(2.7.3 2.8.3 3.0.0b1)
 fi
 for fa_version in "${FA_versions[@]}"
......
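Note on the version gate above: the SM number comes from the device compute capability (the `python3 -c` one-liner in the hunk header is truncated here, so the exact print expression is not shown). A minimal Python sketch of the same selection, assuming sm_arch = major * 10 + minor and a CUDA-capable device:

import torch  # requires a CUDA device to query

major, minor = torch.cuda.get_device_capability(0)
sm_arch = major * 10 + minor  # assumption: e.g., (9, 0) -> 90

if sm_arch > 90:
    fa_versions = ["2.8.3"]
elif sm_arch == 90:
    fa_versions = ["2.7.3", "2.8.3", "3.0.0b1"]
else:
    fa_versions = []  # the script above defines no list for sm < 90
print(sm_arch, fa_versions)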
@@ -10,6 +10,8 @@
 from importlib import metadata
 import os
+import shutil
+import subprocess
 import time
 from pathlib import Path
 from typing import List, Tuple
@@ -149,9 +151,64 @@ def setup_requirements() -> Tuple[List[str], List[str]]:
     return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]
 
 
+def git_check_submodules() -> None:
+    """
+    Attempt to check out git submodules automatically during setup.
+    This runs successfully only if the submodules are
+    either in the correct or uninitialized state.
+
+    Note to devs: With this, any updates to the submodules themselves, e.g. moving to a newer
+    commit, must be committed before build. This also ensures that stale submodules aren't
+    being silently used by developers.
+    """
+    # Provide an option to skip these checks for development.
+    if bool(int(os.getenv("NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD", "0"))):
+        return
+    # Require git executable.
+    if shutil.which("git") is None:
+        return
+    # Require a .gitmodules file.
+    if not (current_file_path / ".gitmodules").exists():
+        return
+    try:
+        submodules = subprocess.check_output(
+            ["git", "submodule", "status", "--recursive"],
+            cwd=str(current_file_path),
+            text=True,
+        ).splitlines()
+        for submodule in submodules:
+            # '-' start is for an uninitialized submodule.
+            # ' ' start is for a submodule on the correct commit.
+            assert submodule[0] in (
+                " ",
+                "-",
+            ), (
+                "Submodules are initialized incorrectly. If this is intended, set the "
+                "environment variable `NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD` to a "
+                "non-zero value to skip these checks during development. Otherwise, "
+                "run `git submodule update --init --recursive` to checkout the correct"
+                " submodule commits."
+            )
+        subprocess.check_call(
+            ["git", "submodule", "update", "--init", "--recursive"],
+            cwd=str(current_file_path),
+        )
+    except subprocess.CalledProcessError:
+        return
+
+
 if __name__ == "__main__":
     __version__ = te_version()
+    git_check_submodules()
 
     with open("README.rst", encoding="utf-8") as f:
         long_description = f.read()
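For context on the prefix check in `git_check_submodules`: `git submodule status` marks each line with a leading character, `-` for uninitialized, a space for checked out at the recorded commit, and `+` for checked out at a different commit. A standalone sketch of that classification, with made-up sample paths:

# Sample lines mimicking `git submodule status --recursive` output (hypothetical).
sample = [
    "-0123abc 3rdparty/example-uninitialized",
    " 0123abc 3rdparty/example-on-recorded-commit (v1.0)",
    "+0123abc 3rdparty/example-moved-ahead (v1.1)",
]
for line in sample:
    ok = line[0] in (" ", "-")  # same acceptance test as the assert above
    print("OK " if ok else "BAD", line)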
@@ -163,8 +220,11 @@ if __name__ == "__main__":
     ext_modules = []
     package_data = {}
     include_package_data = False
-    install_requires = ([f"transformer_engine_cu12=={__version__}"],)
+    install_requires = []
     extras_require = {
+        "core": [f"transformer_engine_cu12=={__version__}"],
+        "core_cu12": [f"transformer_engine_cu12=={__version__}"],
+        "core_cu13": [f"transformer_engine_cu13=={__version__}"],
         "pytorch": [f"transformer_engine_torch=={__version__}"],
         "jax": [f"transformer_engine_jax=={__version__}"],
     }
......
@@ -66,12 +66,6 @@ else()
     add_executable(test_operator ${test_hip_sources})
 endif()
 
-# Add profiling and debug flags for CUDA compilation
-set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo")  # Generate line info for device code
-set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g")  # Add debug symbols for host code
-set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v")  # Add info about registers usage
-# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
-
 # Find required packages
 find_package(OpenMP REQUIRED)
 if(USE_CUDA)
......
@@ -661,14 +661,9 @@ std::vector<std::vector<size_t>> tensor_dims = {
     {4096, 13312},
 };
 
-// Only GeLU activation tests are supported
+// Only the Identity activation is currently supported.
 std::vector<ActivationType> Activation_types = {
-    ActivationType::Identity,
-    ActivationType::GeLU,
-    ActivationType::SiLU,
-    ActivationType::ReLU,
-    ActivationType::QGeLU,
-    ActivationType::SReLU,
+    ActivationType::Identity
 };
 
 }  // namespace
......
@@ -128,8 +128,18 @@ void compute_ref_output(NormType norm_type,
         tmp = current * rsigma[i] * g;
       }
 
+      // Write output (scaled only for fp8 paths)
       output[i * H + j] = static_cast<OutputType>(tmp * scale);
-      current_max = fmaxf(current_max, fabsf(tmp));
+
+      // amax semantics:
+      //   - fp8_out (scale != 1): amax on pre-scale compute value 'tmp'
+      //   - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16)
+      if (scale != 1.f) {
+        current_max = fmaxf(current_max, fabsf(tmp));
+      } else {
+        OutputType out_t_val = static_cast<OutputType>(tmp);
+        current_max = fmaxf(current_max, fabsf(static_cast<compute_t>(out_t_val)));
+      }
     }
   }
......
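The new branch above splits the amax reduction: with an FP8 output (scale != 1) the amax is taken on the pre-scale compute value, while for non-FP8 outputs (scale == 1) it is taken after conversion to the output type. A small NumPy sketch of the two paths, using float16 as a stand-in for the output type (NumPy has no bf16):

import numpy as np

def ref_amax(tmp, scale, out_dtype=np.float16):
    # Sketch of the reference logic above, not the test code itself.
    if scale != 1.0:
        # fp8 path: amax of the unscaled compute value
        return np.abs(tmp).max()
    # non-fp8 path: amax after rounding to the output dtype
    return np.abs(tmp.astype(out_dtype).astype(np.float32)).max()

x = np.linspace(-3.0005, 3.0005, 7, dtype=np.float32)
print(ref_amax(x, scale=2.0), ref_amax(x, scale=1.0))  # the two values differ slightly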
@@ -8,7 +8,7 @@ from itertools import product
 
 import pytest
 import jax
-from jax.experimental.pjit import pjit, _UNSPECIFIED
+from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED
 
 from transformer_engine.jax.sharding import MeshResource
@@ -154,13 +154,15 @@ def compare_ops(
     grad_args = tuple(range(len(inputs)))
 
     target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
-    target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
-    target_fwd, target_grads = target_pjitter(*inputs, **kwargs)
-    target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text()
+    target_jitter = jax.jit(
+        target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings
+    )
+    target_fwd, target_grads = target_jitter(*inputs, **kwargs)
+    target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text()
 
     ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
-    ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
-    ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs)
+    ref_jitter = jax.jit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
+    ref_fwd, ref_grads = ref_jitter(*inputs, **kwargs)
 
     assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)
......
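Background for the pjit-to-jax.jit migration above: modern jax.jit accepts in_shardings/out_shardings directly and exposes the same .lower(...).compile().as_text() path, so jax.experimental.pjit is no longer needed. A minimal single-device sketch (mesh axis name and shapes are arbitrary):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()[:1]), ("x",))
replicated = NamedSharding(mesh, PartitionSpec())  # fully replicated for illustration

f = jax.jit(lambda a: a * 2, in_shardings=(replicated,), out_shardings=replicated)
x = jnp.ones((4, 4))
print(f(x).sharding)
print(f.lower(x).compile().as_text()[:120])  # HLO text, as captured in compare_ops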
@@ -18,6 +18,14 @@ do
     CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 &
 done
 
-CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS
+CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS | tee stdout_multi_process.txt
 wait
+
+RET=0
+if grep -q "FAILED" stdout_multi_process.txt; then
+    RET=1
+fi
+
+rm -f stdout_multi_process.txt
+exit "$RET"
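The launcher change above is needed because the worker ranks' output goes to /dev/null and their exit codes are not checked; only rank 0's pytest output is captured, so the script scans it for the FAILED marker. The same check, sketched in Python under the same assumed file name:

import sys

def scan_for_failures(path="stdout_multi_process.txt"):
    # Mirrors `grep -q "FAILED"`: any FAILED marker means a nonzero exit.
    with open(path, encoding="utf-8") as f:
        return 1 if "FAILED" in f.read() else 0

if __name__ == "__main__":
    sys.exit(scan_for_failures())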
@@ -40,13 +40,13 @@ from transformer_engine.jax.quantize import (
     QuantizerFactory,
     QuantizeLayout,
     noop_quantizer_set,
-    should_use_rht,
+    QuantizeMetaSet,
+    QuantizeMeta,
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
 from transformer_engine.jax.dense import dense, grouped_dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
-from transformer_engine.common import recipe
 
 GEMM_CASES = [
     (256, 256, 512),
@@ -606,7 +606,12 @@ class TestNorm:
     )
     @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
-    @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
+    @pytest.mark.parametrize(
+        "out_dtype",
+        [
+            jnp.float8_e4m3fn,
+        ],
+    )
     def test_norm_forward_with_block_scaling_fp8(
         self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
     ):
@@ -685,21 +690,14 @@ class TestQuantize:
     Purely quantization related tests that will always test on a wider set of types and shapes
     """
 
-    def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
-        """Temporary hack to skip unsupported FP4 cases until we implement them"""
+    def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
+        """Skip unsupported dtypes for the given scaling mode. For example, NVFP4 only supports the float4_e2m1 dtype, not float8 dtypes."""
         if q_dtype not in scaling_mode.get_compatible_q_dtypes():
             pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
             return
-        # HACK: FIXME TODO(jberchtold)
-        row = reduce(operator.mul, input_shape[flatten_axis:], 1)
-        col = reduce(operator.mul, input_shape[:flatten_axis], 1)
-        will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
-        if will_use_rht and (row % 64 != 0 or col % 128 != 0):
-            pytest.skip("Unfused RHT is not supported currently, skipping")
 
     def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
-        self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
 
         key = jax.random.PRNGKey(0)
@@ -780,22 +778,8 @@ class TestQuantize:
         assert_dequantized_scaled_tensor(scaled_tensor, x)
 
     def _should_use_precise_comparison(
-        self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+        self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
     ):
-        # TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical results between TE and JAX RHT+quant implementations. Currently for certain shapes the quantized fp4 data differs by a small amount on <0.5% of the values.
-        RHT_SLIGHT_MISMATCH_SHAPES = [
-            ((32, 256, 128), -1),
-            ((64, 32, 32, 256), -1),
-            ((8192, 2, 4096), -2),
-        ]
-        if (
-            should_use_rht(scaling_mode, q_layout=q_layout)
-            and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES
-        ):
-            # TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes
-            return False
         if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
             # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation
             return False
@@ -805,7 +789,7 @@ class TestQuantize:
     def test_quantize_bitwise(
         self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
     ):
-        self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
 
         key = jax.random.PRNGKey(0)
         input = jax.random.uniform(key, input_shape, in_dtype)
@@ -816,28 +800,20 @@
         jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
 
-        try:
-            te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
-        except AssertionError as e:
-            if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
-                error_message = e.args[0]
-                if "RHT requires input to be bfloat16" in error_message:
-                    # Successfully caught the expected error, early return from the test
-                    return
-            raise e
+        te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
 
         assert_bitwise_scaled_tensors(
             te_output,
             jax_output,
             precise_comparison=self._should_use_precise_comparison(
-                in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+                in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
             ),
         )
 
     def test_quantize_bitwise_jitted(
         self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
     ):
-        self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
 
         key = jax.random.PRNGKey(0)
         input = jax.random.uniform(key, input_shape, in_dtype)
@@ -851,21 +827,13 @@ class TestQuantize:
         jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
 
-        try:
-            te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
-        except AssertionError as e:
-            if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
-                error_message = e.args[0]
-                if "RHT requires input to be bfloat16" in error_message:
-                    # Successfully caught the expected error, early return from the test
-                    return
-            raise e
+        te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
 
         assert_bitwise_scaled_tensors(
             te_output,
             jax_output,
             precise_comparison=self._should_use_precise_comparison(
-                in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+                in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
             ),
         )
@@ -914,7 +882,7 @@ class TestStochasticRounding:
         for i in range(num_samples):
             iter_key = jax.random.fold_in(key, i)
             sr_rng_state = jax.random.randint(
-                iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
+                iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
             )
             quantizer = QuantizerFactory.create(
                 q_dtype=q_dtype,
@@ -985,12 +953,6 @@ class TestStochasticRounding:
     def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
         """Tests that the mean absolute error of stochastic rounding is smaller than round-to-nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
-        # HACK: FIXME TODO(jberchtold)
-        row = reduce(operator.mul, input_shape[flatten_axis:], 1)
-        col = reduce(operator.mul, input_shape[:flatten_axis], 1)
-        will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
-        if will_use_rht and (row % 64 != 0 or col % 128 != 0):
-            pytest.skip("Unfused RHT is not supported currently, skipping")
 
         key = jax.random.PRNGKey(0)
         inputs = jax.random.uniform(key, input_shape, in_dtype)
@@ -1007,6 +969,97 @@
         assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)
 
+
+@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
+@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
+@pytest_parametrize_wrapper(
+    "scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
+)
+class TestRandomizedHadamardTransform:
+
+    @pytest_parametrize_wrapper(
+        "q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
+    )
+    @pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
+    def test_rht_quantize_bitwise_jitted(
+        self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+    ):
+        key = jax.random.PRNGKey(0)
+        inputs = jax.random.uniform(key, input_shape, in_dtype)
+
+        te_quantizer, jax_quantizer = QuantizerFactory.create(
+            n_quantizers=2,
+            q_dtype=q_dtype,
+            scaling_mode=scaling_mode,
+            q_layout=q_layout,
+            use_rht=True,
+        )
+
+        jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
+        te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))
+
+        jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)
+        te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)
+
+        assert_bitwise_scaled_tensors(te_output, jax_output)
+
+    def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
+        if data_layout[0] == "T":
+            a = jnp.swapaxes(a, -1, -2)
+        if data_layout[1] == "T":
+            b = jnp.swapaxes(b, -1, -2)
+        return jnp.dot(a, b)
+
+    def _generate_gemm_input(self, m, n, k, data_layout):
+        key = jax.random.PRNGKey(0)
+        subkeys = jax.random.split(key, 2)
+        x = jax.random.uniform(
+            subkeys[0],
+            (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
+            dtype=jnp.bfloat16,
+        ) / jnp.sqrt(k)
+        w = jax.random.uniform(
+            subkeys[1],
+            (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
+            dtype=jnp.bfloat16,
+        ) / jnp.sqrt(n)
+        lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
+        rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
+        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
+        return (x, w, contracting_dims)
+
+    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
+    # We do not test NN and TT layouts here as they do not have both inputs using RHT, due to RHT currently supporting only the colwise layout
+    @pytest_parametrize_wrapper("data_layout", ["TN", "NT"])
+    @pytest_parametrize_wrapper("with_jax_gemm", [True, False])
+    def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm):
+        key = jax.random.PRNGKey(0)
+        lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode
+
+        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
+
+        lhs_quantizer = QuantizerFactory.create(
+            scaling_mode=lhs_scaling_mode,
+            q_dtype=jnp.float4_e2m1fn,
+            use_rht=True,
+        )
+        rhs_quantizer = QuantizerFactory.create(
+            scaling_mode=rhs_scaling_mode,
+            q_dtype=jnp.float4_e2m1fn,
+            use_rht=True,
+        )
+
+        with use_jax_gemm(enabled=with_jax_gemm):
+            primitive_out = tex.gemm(
+                x,
+                w,
+                contracting_dims=contracting_dims,
+                lhs_quantizer=lhs_quantizer,
+                rhs_quantizer=rhs_quantizer,
+            )
+        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
+
+        assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)
+
+
 @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@@ -1406,7 +1459,12 @@ class TestDense:
         value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
         value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
 
-        quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
+        quantizer_set = QuantizerFactory.create_set(
+            fp8_recipe=recipe,
+            quantize_meta_set=QuantizeMetaSet(
+                x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
+            ),
+        )
 
         n_iterations = 3 if recipe.delayed() else 1
         with use_jax_gemm(enabled=with_jax_gemm):
@@ -1465,7 +1523,12 @@ class TestFusedDense:
         gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
 
-        quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
+        quantizer_set = QuantizerFactory.create_set(
+            fp8_recipe=recipe,
+            quantize_meta_set=QuantizeMetaSet(
+                x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
+            ),
+        )
 
         if norm_type == "layernorm":
             beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
@@ -1554,6 +1617,9 @@ class TestFusedDense:
         quantizer_sets = QuantizerFactory.create_set(
             n_quantizer_sets=2,
             fp8_recipe=recipe,
+            quantize_meta_set=QuantizeMetaSet(
+                x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta()
+            ),
         )
 
         if norm_type == "layernorm":
......
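On test_sr_nvfp4 above: the property under test is that averaging many stochastically rounded samples tracks the input more closely than round-to-nearest, since stochastic rounding is unbiased. A toy NumPy sketch on a uniform grid (a stand-in for the FP4 value lattice, not the real format):

import numpy as np

rng = np.random.default_rng(0)
step = 0.25  # coarse uniform grid standing in for FP4 spacing
x = rng.uniform(0.0, 1.0, size=4096).astype(np.float32)

round_nearest = np.round(x / step) * step

def stochastic_round(v):
    lo = np.floor(v / step) * step
    p_up = (v - lo) / step  # round up with probability = fractional remainder
    return lo + step * (rng.uniform(size=v.shape) < p_up)

sr_mean = np.mean([stochastic_round(x) for _ in range(64)], axis=0)
print(np.abs(round_nearest - x).mean())  # roughly step/4
print(np.abs(sr_mean - x).mean())        # smaller: the SR error averages out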
@@ -134,9 +134,12 @@ class TestDistributedLayernorm:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
-            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            g_named_sharding = NamedSharding(mesh, g_pspec)
+            b_named_sharding = NamedSharding(mesh, b_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            gamma_ = jax.device_put(gamma, g_named_sharding)
+            beta_ = jax.device_put(beta, b_named_sharding)
 
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -148,8 +151,11 @@
                         grad_args=(0, 1, 2),
                         metric_fwd_dtype=q_dtype,
                         metric_bwd_dtype=q_dtype,
-                        in_shardings=(x_pspec, g_pspec, b_pspec),
-                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
+                        in_shardings=(x_named_sharding, g_named_sharding, b_named_sharding),
+                        out_shardings=(
+                            None,
+                            (x_named_sharding, g_named_sharding, b_named_sharding),
+                        ),
                     )
                 except AssertionError as err:
                     # Layernorm should still produce the correct numerical result with
@@ -210,8 +216,10 @@ class TestDistributedLayernorm:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            g_named_sharding = NamedSharding(mesh, g_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            gamma_ = jax.device_put(gamma, g_named_sharding)
 
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -223,8 +231,8 @@
                         grad_args=(0, 1),
                         metric_fwd_dtype=q_dtype,
                         metric_bwd_dtype=q_dtype,
-                        in_shardings=(x_pspec, g_pspec),
-                        out_shardings=(None, (x_pspec, g_pspec)),
+                        in_shardings=(x_named_sharding, g_named_sharding),
+                        out_shardings=(None, (x_named_sharding, g_named_sharding)),
                     )
                 except AssertionError as err:
                     # RmsNorm should still produce the correct numerical result with
......
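The sharding refactor above builds each NamedSharding(mesh, pspec) once and reuses the same object for both jax.device_put and the jit sharding arguments; unlike a bare PartitionSpec, a NamedSharding carries its mesh with it, so it is valid outside an ambient mesh context. A tiny sketch of the pattern:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()[:1]).reshape(1)
mesh = Mesh(devices, ("dp",))

x_pspec = PartitionSpec()  # replicated here; the tests shard over mesh axes
x_named_sharding = NamedSharding(mesh, x_pspec)

x_ = jax.device_put(jnp.ones((8, 4)), x_named_sharding)
print(x_.sharding)  # the same object can also be passed to jax.jit in_shardings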
@@ -389,6 +389,7 @@ class TestDistributedLayernormMLP:
             intermediate_dim=INTERMEDIATE,
             activations=activation_type,
             use_bias=use_bias,
+            return_layernorm_output=True,
         )
         params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
         mlp_out_single, ln_out_single = ln_mlp_single.apply(
@@ -417,6 +418,7 @@
             dot_1_input_axes=DOT_1_INPUT_AXES,
             dot_2_input_axes=DOT_2_INPUT_AXES,
             name="mlp",
+            return_layernorm_output=True,
         )
         params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
         mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
......
@@ -103,8 +103,10 @@ class TestDistributedSoftmax:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            mask_named_sharding = NamedSharding(mesh, mask_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            mask_ = jax.device_put(mask, mask_named_sharding)
 
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -116,8 +118,8 @@
                         grad_args=(0,),
                         metric_fwd_dtype=dtype,
                         metric_bwd_dtype=dtype,
-                        in_shardings=(x_pspec, mask_pspec),
-                        out_shardings=(None, (x_pspec,)),
+                        in_shardings=(x_named_sharding, mask_named_sharding),
+                        out_shardings=(None, x_named_sharding),
                     )
                 except AssertionError as err:
                     # Softmax should still produce the correct numerical result with
......
@@ -378,14 +378,14 @@ class FusedAttnRunner:
             pytest.skip(
                 "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
             )
 
+        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
         if (
-            get_device_compute_capability(0) == 100
+            get_device_compute_capability(0) >= 100
             and self.dropout_prob == 0.1
            and self.attn_bias_type is not AttnBiasType.NO_BIAS
         ):
             pytest.skip(
-                "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
+                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
             )
 
         # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
         # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest

import flax
import jax
import jax.numpy as jnp
import numpy as np
from utils import assert_allclose

from transformer_engine.common.recipe import (
    DelayedScaling,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import (
    get_quantize_config,
    is_scaling_mode_supported,
    ScalingMode,
    update_collections,
    TensorSource,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)


class TestHelper(unittest.TestCase):

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


class TestFP8Functions(unittest.TestCase):

    def _check_default_state(self):
        self.assertFalse(get_quantize_config().is_fp8_enabled())

    def _compare_delay_scaling(self, test):
        self.assertEqual(get_quantize_config().MARGIN, test.margin)
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
        self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo)

    def _compare_current_scaling(self, test):
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source),
                ScalingMode.CURRENT_TENSOR_SCALING,
            )

    def _compare_mxfp8_scaling(self, test):
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
        for tensor_source in TensorSource:
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
            )

    def _compare_nvfp4_scaling(self, test):
        self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0])
        self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1])
        for tensor_source in TensorSource:
            target_scaling_mode = (
                ScalingMode.NVFP4_2D_SCALING
                if tensor_source == TensorSource.KERNEL
                else ScalingMode.NVFP4_1D_SCALING
            )
            self.assertEqual(
                get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
            )

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_delayed_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()):
            self._check_default_state()
        self._check_default_state()

        ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(ds)
        self._check_default_state()

        ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_delay_scaling(ds)
        self._check_default_state()

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_autocast_current_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()):
            self._check_default_state()
        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
        with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)
        self._check_default_state()

        cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
        with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_current_scaling(cs)
        self._check_default_state()

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
    def test_autocast_mxfp8_block_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()):
            self._check_default_state()
        self._check_default_state()

        bs = MXFP8BlockScaling()
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_mxfp8_scaling(bs)
        self._check_default_state()

    @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
    def test_autocast_nvfp4_block_scaling(self):
        self._check_default_state()

        with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()):
            self._check_default_state()
        self._check_default_state()

        bs = NVFP4BlockScaling()
        with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()):
            self.assertTrue(get_quantize_config().is_fp8_enabled())
            self._compare_nvfp4_scaling(bs)
        self._check_default_state()
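A reference point for the _format2dtypes assertions in the file above: the recipe format maps to a (forward, backward) dtype pair, with HYBRID using e4m3 in the forward pass and e5m2 for gradients. A sketch of the mapping as the assertions imply it (inferred behavior, not the helper's actual source):

import jax.numpy as jnp

def format2dtypes_sketch(fp8_format: str):
    # Assumed behavior, inferred from the FWD_DTYPE/BWD_DTYPE checks above.
    if fp8_format == "E4M3":
        return (jnp.float8_e4m3fn, jnp.float8_e4m3fn)
    if fp8_format == "HYBRID":
        return (jnp.float8_e4m3fn, jnp.float8_e5m2)
    raise ValueError(f"unsupported format: {fp8_format}")

print(format2dtypes_sketch("HYBRID"))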
@@ -23,7 +23,8 @@ from utils import EncoderLayer as RefEncoderLayer
 from transformer_engine.common import recipe
 from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
 from transformer_engine.jax.quantize import (
-    get_quantize_config,
+    get_global_quantize_recipe,
+    get_quantize_config_with_recipe,
     ScalingMode,
     is_fp8_available,
     update_collections,
@@ -358,7 +359,7 @@ class BaseRunner:
         ref_params, test_params = self._sync_params(ref_params, test_params)
 
-        if get_quantize_config().is_fp8_enabled():
+        if get_quantize_config_with_recipe(get_global_quantize_recipe()).is_fp8_enabled():
             for _ in range(4):
                 _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                     inputs,
@@ -368,14 +369,24 @@ class BaseRunner:
                     test_layer,
                 )
                 if (
-                    get_quantize_config().get_scaling_mode(TensorSource.X)
+                    get_quantize_config_with_recipe(get_global_quantize_recipe()).get_scaling_mode(
+                        TensorSource.X
+                    )
                     == ScalingMode.DELAYED_TENSOR_SCALING
                 ):
                     _, updated_quantize_meta = flax.core.pop(
-                        updated_state[0], get_quantize_config().COLLECTION_NAME
+                        updated_state[0],
+                        get_quantize_config_with_recipe(
+                            get_global_quantize_recipe()
+                        ).COLLECTION_NAME,
                     )
                     test_others = update_collections(
-                        {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
+                        {
+                            get_quantize_config_with_recipe(
+                                get_global_quantize_recipe()
+                            ).COLLECTION_NAME: updated_quantize_meta
+                        },
+                        test_others,
                    )
                     del updated_quantize_meta
                     del updated_state
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import unittest
from functools import partial
from abc import ABC, abstractmethod

import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from utils import assert_allclose, pytest_parametrize_wrapper

from transformer_engine.common.recipe import (
    Recipe,
    DelayedScaling,
    MXFP8BlockScaling,
    Float8CurrentScaling,
    NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import autocast
from transformer_engine.jax.quantize import (
    get_global_quantize_recipe,
    get_quantize_config_with_recipe,
    get_supported_quantization_recipes,
    is_scaling_mode_supported,
    ScalingMode,
    update_collections,
    TensorSource,
    QuantizeLayout,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
from transformer_engine.jax.flax.module import TransformerEngineBase
from transformer_engine.jax import flax as te_flax
import transformer_engine.jax as te

is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)

SUPPORTED_RECIPES = get_supported_quantization_recipes()


def quantizer_check_vjp(outer_quantizer_set, assertion_func, x):
    """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""

    # Define a function with a custom VJP (vector-Jacobian product)
    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def quantizer_check(inner_quantizer_set, assertion_func, x):
        return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)[0]

    def quantizer_check_fwd(inner_quantizer_set, assertion_func, x):
        assertion_func(inner_quantizer_set.x, TensorSource.X)
        assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL)
        assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD)
        return x, (inner_quantizer_set,)

    def quantizer_check_bwd(assertion_func, ctx, g):
        (inner_quantizer_set,) = ctx
        return (inner_quantizer_set, g)

    quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd)
    return quantizer_check(outer_quantizer_set, assertion_func, x)
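The helper above leans on jax.custom_vjp flattening the forward residuals into a pytree and rebuilding them before the backward rule runs; that round trip is exactly what the quantizer-set test exercises. A minimal self-contained example of the same boundary:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def double(x):
    return 2.0 * x

def double_fwd(x):
    # The residual tuple is flattened to a pytree here and reconstructed
    # before double_bwd runs, the boundary quantizer_check_vjp exercises.
    return 2.0 * x, (x,)

def double_bwd(res, g):
    (x,) = res  # residuals arrive reassembled from their flat representation
    del x
    return (2.0 * g,)

double.defvjp(double_fwd, double_bwd)
print(jax.grad(double)(jnp.float32(3.0)))  # 2.0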
class TestModule(TransformerEngineBase):
    """A simple module to test quantizer creation and reconstruction across VJP boundaries."""

    # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None
    assertion_func: callable
    direct_recipe: Recipe

    @nn.compact
    def __call__(self, x):
        quantizer_set = self.generate_quantizer_set(fp8_recipe=self.direct_recipe)
        return quantizer_check_vjp(quantizer_set, self.assertion_func, x)


class TestHelper(unittest.TestCase):

    @unittest.skipIf(not is_fp8_supported, reason=reason)
    def test_update_collections(self):
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)


def assert_fp8_format(quantizer, tensor_source, fp8_format):
    if fp8_format == FP8Format.HYBRID:
        if tensor_source == TensorSource.DGRAD:
            assert quantizer.q_dtype == jnp.float8_e5m2
        else:
            assert quantizer.q_dtype == jnp.float8_e4m3fn
    elif fp8_format == FP8Format.E4M3:
        assert quantizer.q_dtype == jnp.float8_e4m3fn
    else:
        raise ValueError(f"Unsupported FP8 format: {fp8_format}")


class RecipeAssertionBase(ABC):
    """Base class for defining recipe assertions."""

    @abstractmethod
    def assert_context(self, ref_recipe, quantize_config):
        """Asserts that the quantize_config matches the expected properties from the reference recipe when the recipe is used with an autocast context.

        Args:
            ref_recipe: The reference quantization recipe.
            quantize_config: The quantization configuration to be checked.
        """
        pass

    @abstractmethod
    def assert_quantizers(self, ref_recipe, quantizer, tensor_source):
        """Asserts that the quantizer matches the expected properties from the reference recipe. The quantizers are created in a small test Flax module TestModule and passed through a VJP boundary to ensure correct reconstruction.

        Args:
            ref_recipe: The reference quantization recipe.
            quantizer: The quantizer to be checked.
            tensor_source: The source of the tensor (e.g., KERNEL, X, DGRAD).
        """
        pass


class DelayedScalingRecipeAssertion(RecipeAssertionBase):

    def assert_context(self, ref_recipe, quantize_config):
        assert quantize_config.MARGIN == ref_recipe.margin
        assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
        assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
        assert quantize_config.AMAX_HISTORY_LEN == ref_recipe.amax_history_len
        assert quantize_config.AMAX_COMPUTE_ALGO.value == ref_recipe.amax_compute_algo
        for tensor_source in TensorSource:
            assert (
                quantize_config.get_scaling_mode(tensor_source)
                == ScalingMode.DELAYED_TENSOR_SCALING
            )

    def assert_quantizers(self, ref_recipe: DelayedScaling, quantizer, tensor_source):
        assert quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
        assert quantizer.margin == ref_recipe.margin
        assert quantizer.amax_compute_algo.value == ref_recipe.amax_compute_algo
        assert quantizer.amax_history.shape == (ref_recipe.amax_history_len,)
        assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)


class CurrentScalingRecipeAssertion(RecipeAssertionBase):

    def assert_context(self, ref_recipe, quantize_config):
        assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
        assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
        for tensor_source in TensorSource:
            assert (
                quantize_config.get_scaling_mode(tensor_source)
                == ScalingMode.CURRENT_TENSOR_SCALING
            )

    def assert_quantizers(self, ref_recipe: Float8CurrentScaling, quantizer, tensor_source):
        assert quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
        assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)


class MXFP8RecipeAssertion(RecipeAssertionBase):

    def assert_context(self, ref_recipe, quantize_config):
        assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0]
        assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1]
        for tensor_source in TensorSource:
            assert quantize_config.get_scaling_mode(tensor_source) == ScalingMode.MXFP8_1D_SCALING

    def assert_quantizers(self, ref_recipe: MXFP8BlockScaling, quantizer, tensor_source):
        assert quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
        assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format)


class NVFP4RecipeAssertion(RecipeAssertionBase):

    def assert_context(self, ref_recipe, quantize_config):
        assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[0]
        assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[1]
        for tensor_source in TensorSource:
            target_scaling_mode = (
                ScalingMode.NVFP4_2D_SCALING
                if (not ref_recipe.disable_2d_quantization) and tensor_source == TensorSource.KERNEL
                else ScalingMode.NVFP4_1D_SCALING
            )
            assert quantize_config.get_scaling_mode(tensor_source) == target_scaling_mode
        assert quantize_config.DISABLE_STOCHASTIC_ROUNDING == ref_recipe.disable_stochastic_rounding
        assert quantize_config.DISABLE_RHT == ref_recipe.disable_rht
        assert quantize_config.DISABLE_2D_QUANTIZATION == ref_recipe.disable_2d_quantization

    def assert_quantizers(self, ref_recipe: NVFP4BlockScaling, quantizer, tensor_source):
        if tensor_source == TensorSource.KERNEL and not ref_recipe.disable_2d_quantization:
            assert quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING
        else:
            assert quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
        if ref_recipe.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD:
            assert quantizer.stochastic_rounding_rng_state is None
        else:
            assert quantizer.stochastic_rounding_rng_state is not None
        expected_rht = (
            quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING
            and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE}
            and not ref_recipe.disable_rht
        )
        assert quantizer.use_rht == expected_rht
class TestFP8Functions(unittest.TestCase):
def _check_default_state(self):
self.assertEqual(get_global_quantize_recipe(), None)
def _test_recipe(self, quantization_recipe: Recipe, cls: RecipeAssertionBase):
"""Tests a quantization recipe by verifying its behavior in both autocast and direct application contexts."""
assert_context_func = cls().assert_context
assert_quantizer_func = partial(cls().assert_quantizers, quantization_recipe)
self._test_recipe_autocast(quantization_recipe, assert_context_func, assert_quantizer_func)
self._test_recipe_direct(quantization_recipe, assert_quantizer_func)
def _test_recipe_autocast(
self, quantization_recipe, assert_context_func, assert_quantizer_func
):
"""Tests a quantization recipe within an autocast context by verifying the quantize config and quantizers in a test module."""
self._check_default_state()
with autocast(enabled=False, recipe=quantization_recipe, mesh_resource=MeshResource()):
self._check_default_state()
with autocast(enabled=True, recipe=quantization_recipe, mesh_resource=MeshResource()):
quantize_config = self._get_global_quantize_config()
assert_context_func(quantization_recipe, quantize_config)
self._test_quantizer_in_model(assert_quantizer_func)
self._check_default_state()
def _test_recipe_direct(self, quantization_recipe, assert_quantizer_func):
"""Tests a quantization recipe by directly passing it to a test module and verifying the quantizers."""
self._check_default_state()
self._test_quantizer_in_model(assert_quantizer_func, direct_recipe=quantization_recipe)
self._check_default_state()
def _test_quantizer_in_model(self, assert_quantizer_func, direct_recipe=None):
"""Tests that the quantizers created in a test module match the expected properties by passing them through a VJP boundary.
Args:
assert_quantizer_func: A function that asserts the properties of the quantizers. The function signature is (quantizer: Quantizer, tensor_source: TensorSource) -> None.
direct_recipe: An optional quantization recipe to be passed directly to the test module. This is an alternative API to using autocast contexts.
"""
x = jnp.ones((), dtype=jnp.float32)
test_module = TestModule(assertion_func=assert_quantizer_func, direct_recipe=direct_recipe)
param_key, sr_key = jax.random.split(jax.random.PRNGKey(0))
rngs = {"params": param_key, "sr_rng": sr_key}
variables = test_module.init(rngs, x)
jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs)
def _get_global_quantize_config(self):
quantization_recipe = get_global_quantize_recipe()
assert quantization_recipe is not None, "No global quantization recipe set"
quantize_config = get_quantize_config_with_recipe(quantization_recipe)
assert (
quantize_config.is_fp8_enabled()
), "Quantization not enabled in global quantize config"
return quantize_config
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_delayed_scaling(self):
self._test_recipe(
quantization_recipe=DelayedScaling(),
cls=DelayedScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=DelayedScaling(
margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1
),
cls=DelayedScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=DelayedScaling(
margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1
),
cls=DelayedScalingRecipeAssertion,
)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_autocast_current_scaling(self):
self._test_recipe(
quantization_recipe=Float8CurrentScaling(),
cls=CurrentScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3),
cls=CurrentScalingRecipeAssertion,
)
self._test_recipe(
quantization_recipe=Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID),
cls=CurrentScalingRecipeAssertion,
)

    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_autocast_mxfp8_block_scaling(self):
self._test_recipe(
quantization_recipe=MXFP8BlockScaling(),
cls=MXFP8RecipeAssertion,
)

    @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
def test_autocast_nvfp4_block_scaling(self):
self._test_recipe(
quantization_recipe=NVFP4BlockScaling(),
cls=NVFP4RecipeAssertion,
)
self._test_recipe(
quantization_recipe=NVFP4BlockScaling(
disable_stochastic_rounding=True,
disable_rht=True,
disable_2d_quantization=True,
),
cls=NVFP4RecipeAssertion,
)


class TestJaxprAndHlo:
"""Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""

    def _generate_jaxpr_for_layernorm_mlp_fwd_bwd(self, quantization_recipe, ln_mlp_kwargs=None):
"""Generates the jaxpr for a forward and backward pass of LayerNormMLP under the given quantization recipe."""
ln_mlp_kwargs = ln_mlp_kwargs or {}
with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
model = te_flax.LayerNormMLP(
layernorm_type="rmsnorm",
return_layernorm_output=False,
intermediate_dropout_rate=0.0,
dtype=jnp.bfloat16,
**ln_mlp_kwargs,
)
var_collect = model.init(
jax.random.PRNGKey(0),
jnp.ones((128, 128), dtype=jnp.bfloat16),
)
def loss_fn(x, rngs):
return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0])
x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
return jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)

    @pytest_parametrize_wrapper(
"quantization_recipe",
[
quantization_recipe
for quantization_recipe in SUPPORTED_RECIPES
if isinstance(quantization_recipe, NVFP4BlockScaling)
],
)
    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
        """Tests that layernorm_mlp reuses the amax already computed by the layernorm and the activation instead of recomputing it during quantization."""
jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe)
rht_amax_eqns = [
eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
]
assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"

        def assert_param(index, tensor_name, expected_value: bool):
            produce_regular_amax = rht_amax_eqns[index].params["produce_regular_amax"]
            if expected_value:
                assert produce_regular_amax, (
                    f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
                    " reuse of amax, as this tensor has no previous operation to fuse with"
                )
            else:
                assert not produce_regular_amax, (
                    f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
                    " reuse of amax"
                )

assert_param(0, "fwd ln+q", False)
assert_param(1, "fwd act+q", False)
        # No previous op precedes the incoming dgrad in the backward pass, so its amax is not reused
assert_param(2, "bwd dgrad", True)
assert_param(3, "bwd dact+q", False)

    @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper(
"quantization_checkpoint_name",
[None, "quantization", "some_arbitrary_user_checkpoint_name"],
)
def test_recipe_supports_quantization_checkpointing(
self, quantization_recipe, quantization_checkpoint_name
):
"""Tests that all supported quantization recipes correctly use checkpoint_name."""
kwargs = {
"quantization_checkpoint_name": quantization_checkpoint_name,
}
jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe, kwargs)
checkpoint_name_eqns = [
eqn
for eqn in jaxpr.jaxpr.eqns
if eqn.primitive.name == "name" and eqn.params["name"] == quantization_checkpoint_name
]
if quantization_checkpoint_name is None:
assert len(checkpoint_name_eqns) == 0, (
"Expected 0 checkpoint_name eqns when quantization_checkpoint_name is None, got"
f" {len(checkpoint_name_eqns)}"
)
return
        # 12 checkpointed values:
        # - Fwd pass:
        #   - Input RMSNorm+Q -> 3 possible output tensors used in the backward
        #   - First kernel Q -> 3 possible output tensors used in the backward
        #   - Activation+Q -> 3 possible output tensors used in the backward
        #   - Second kernel Q -> 3 possible output tensors used in the backward
expected_checkpoint_eqn_count = 12
assert len(checkpoint_name_eqns) == expected_checkpoint_eqn_count, (
f"Expected {expected_checkpoint_eqn_count} checkpoint_name eqns when"
f" quantization_checkpoint_name is set, got {len(checkpoint_name_eqns)}"
)
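
As a quick primer on the jaxpr inspection used throughout TestJaxprAndHlo, the sketch below uses plain JAX only (no Transformer Engine); the function f and the tag "my_ckpt" are illustrative names. jax.make_jaxpr returns a ClosedJaxpr whose jaxpr.eqns can be filtered by primitive.name, and a value tagged with jax.ad_checkpoint.checkpoint_name shows up as a "name" primitive carrying the tag in params["name"], which is the same pattern test_recipe_supports_quantization_checkpointing asserts on.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def f(x):
    # Tag an intermediate value; it appears in the jaxpr as a "name" primitive.
    y = checkpoint_name(jnp.sin(x), "my_ckpt")
    return jnp.sum(y * y)

jaxpr = jax.make_jaxpr(jax.value_and_grad(f))(jnp.ones((4,), jnp.float32))
tagged = [
    eqn
    for eqn in jaxpr.jaxpr.eqns
    if eqn.primitive.name == "name" and eqn.params["name"] == "my_ckpt"
]
assert len(tagged) >= 1, "expected the tagged value to be visible in the jaxpr"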
...@@ -364,9 +364,9 @@ class MlpBlock(nn.Module):
     transpose_batch_sequence: bool
     intermediate_dim: int = 2048
-    activations: Sequence[Union[str, Callable]] = ("relu",)
+    activations: Sequence[Union[str, Callable]] = ("gelu",)
     kernel_init: Initializer = None
-    intermediate_dropout_rate: float = 0.1
+    intermediate_dropout_rate: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     use_bias: bool = False
     dtype: Any = jnp.float32
...@@ -1035,14 +1035,14 @@ class EncoderLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
-    intermediate_dropout: float = 0.1
+    intermediate_dropout: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     transpose_batch_sequence: bool = True
     float32_attention_logits: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
     mlp_dim: int = 2048
-    mlp_activations: Sequence[str] = ("relu",)
+    mlp_activations: Sequence[str] = ("gelu",)
     use_bias: bool = False
     dtype: Any = jnp.float32
     apply_residual_connection_post_layernorm: bool = False
...@@ -1199,14 +1199,14 @@ class DecoderLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
-    intermediate_dropout: float = 0.1
+    intermediate_dropout: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     transpose_batch_sequence: bool = True
     float32_attention_logits: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
     mlp_dim: int = 2048
-    mlp_activations: Sequence[str] = ("relu",)
+    mlp_activations: Sequence[str] = ("gelu",)
     use_bias: bool = False
     dtype: Any = jnp.float32
     apply_residual_connection_post_layernorm: bool = False
...
...@@ -248,6 +248,7 @@ def run_dpa_with_cp(
         attn_mask_type=config.attn_mask_type,
         window_size=config.window_size,
         softmax_type=config.softmax_type,
+        return_max_logit=config.return_max_logit,
     ).cuda()
     if config.softmax_type != "vanilla":
         core_attn.softmax_offset.requires_grad = True
...@@ -308,6 +309,7 @@ def run_dpa_with_cp(
         fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group)
     else:
         fp8_context = nullcontext()
+    max_logit = None
     with fp8_context:
         # q, k, v, out in FP8; dout in F16
         out = core_attn(
...@@ -322,6 +324,8 @@ def run_dpa_with_cp(
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             fp8_output=fp8_mha,
         )
+        if config.return_max_logit:
+            out, max_logit = out
         if fp8_bwd and fp8_mha:
             dout_fp8 = dout_quantizer(dout)
             out.backward(dout_fp8)
...@@ -400,6 +404,7 @@ def run_dpa_with_cp(
         fp8_context = nullcontext()
     # run attention
+    max_logit_ = None
     with fp8_context:
         # q, k, v, out in FP8; dout in F16
         out_ = core_attn(
...@@ -414,6 +419,8 @@ def run_dpa_with_cp(
             cu_seqlens_kv_padded=cu_seqlens_kv_padded,
             fp8_output=fp8_mha,
         )
+        if config.return_max_logit:
+            out_, max_logit_ = out_
         if fp8_bwd and fp8_mha:
             dout_fp8_ = dout_quantizer(dout_)
             out_.backward(dout_fp8_)
...@@ -495,15 +502,15 @@ def run_dpa_with_cp(
     )
     atol, rtol, rmse_tol = get_tols(config, dtype)
-    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
+    tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
-    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
+    tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
-    names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
+    names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
     names_cp = [x + "_cp" for x in names]
     names_no_cp = [x + "_no_cp" for x in names]
     is_fp8 = dtype == "fp8"
     for i, t in enumerate(tensors_no_cp):
         if t is not None:
-            if "softmax_offset" not in names[i]:
+            if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
                 if qkv_format == "bshd":
                     compare_and_assert(
                         t[:, 0],
...
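
For readers skimming the hunk above: return_max_logit changes the attention module's return contract, and the call sites unpack it as sketched below. This is a minimal illustration reusing this test's names (core_attn, config); the exact semantics of the max_logit tensor are whatever the attention module produces and are not specified here.

max_logit = None
out = core_attn(q, k, v)  # an (out, max_logit) tuple when return_max_logit=True
if config.return_max_logit:
    out, max_logit = out
# max_logit stays None otherwise, so the comparison loop can skip it.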