Unverified commit 2a293456, authored by Phuong Nguyen and committed by GitHub
Browse files

[JAX] Helper to disable TE custom calls + disable GemmPrimitive for non-MXFP8 recipes. (#1962)



* add manage_primitives() helper

* disable GEMM primitives for non-MXFP8 recipes

* implement the NVTE_JAX_CUSTOM_CALLS + deprecate NVTE_JAX_CUSTOM_CALLS_RE

* replace NVTE_JAX_CUSTOM_CALLS_RE with NVTE_JAX_CUSTOM_CALLS in TE tests and examples

* fix use_jax_gemm contextmanager
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 992ba01d
......@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
NVTE_JAX_CUSTOM_CALLS="false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -863,15 +863,6 @@ valid_fp8_gemm_operand_types = [
]
def _use_jax_fp8_gemm(enabled=False):
    """Toggle the regex filter that excludes ``GemmPrimitive`` from TE custom calls.

    Args:
        enabled: When true, ``NVTE_JAX_CUSTOM_CALLS_RE`` is set to a pattern that
            matches every primitive name except ``GemmPrimitive``. When false,
            the variable is removed from the environment if it is present.
    """
    import os

    env_key = "NVTE_JAX_CUSTOM_CALLS_RE"
    if not enabled:
        # Dropping the filter is a no-op when the variable was never set.
        os.environ.pop(env_key, None)
        return
    os.environ[env_key] = "^(?!GemmPrimitive$).+$"
class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T":
......
......@@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
@contextmanager
def use_jax_gemm(enabled=False):
    """Temporarily pin the ``GemmPrimitive`` setting in ``NVTE_JAX_CUSTOM_CALLS``.

    While the context is active the variable holds ``GemmPrimitive=false`` when
    ``enabled`` is true (TE's GEMM custom call is switched off, which --
    judging by this helper's name -- selects the JAX GEMM path) and
    ``GemmPrimitive=true`` otherwise.

    On exit the variable is put back exactly as it was found: restored to its
    previous value, or removed if it was not set. This happens for both values
    of ``enabled`` so the override never leaks into code that runs afterwards.

    Args:
        enabled: True to disable ``GemmPrimitive`` inside the context, False to
            force it on.
    """
    env_key = "NVTE_JAX_CUSTOM_CALLS"
    previous = os.environ.get(env_key)
    os.environ[env_key] = "GemmPrimitive=false" if enabled else "GemmPrimitive=true"
    try:
        yield
    finally:
        # Undo the override unconditionally; the variable is written on entry
        # regardless of `enabled`, so it must be restored regardless too.
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
......@@ -915,11 +915,11 @@ register_primitive(BaseDActLuDBiasQuantizePrimitive)
class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization.

    No change in functionality from the base primitive but named differently for use in
    more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.
    """
class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
    """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias.

    No change in functionality from the base primitive but named differently for use in
    more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.
    """
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
......
......@@ -4,6 +4,7 @@
"""JAX/TE base custom ops"""
import os
import re
import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
from packaging import version
......@@ -30,19 +31,77 @@ class BasePrimitive(metaclass=ABCMeta):
name = None
_is_enabled = True
# Default list of primitives to disable for all recipes
_default_disable_names = ["GemmPrimitive"]
@classmethod
def enabled(cls):
    """
    Determines if a custom call is enabled based on a state variable and environment variables.

    Checks `NVTE_JAX_CUSTOM_CALLS` (key/value format) first, then falls back to the deprecated
    `NVTE_JAX_CUSTOM_CALLS_RE` (regex pattern), and finally to the internal state `_is_enabled`
    if neither decides. Primitives are identified by the Python class name (`cls.__name__`).

    Environment Variables:
        1. `NVTE_JAX_CUSTOM_CALLS`: Preferred key/value format to enable/disable specific
           primitives, or a single value 'true' or 'false' to enable/disable all primitives.
           - Example 1 (global enable): 'true' enables all primitives.
           - Example 2 (global disable): 'false' disables all primitives.
           - Example 3 (specific settings): 'DBiasQuantizePrimitive=false,GemmPrimitive=true'
             disables DBiasQuantizePrimitive and enables GemmPrimitive, leaving others at
             their default state.
           Matching of 'true'/'false' is case-insensitive and ignores surrounding whitespace.
           Any value other than 'true' in a key=value pair disables that primitive. Entries
           without '=' are ignored. If a name is listed more than once, the last entry wins.
        2. `NVTE_JAX_CUSTOM_CALLS_RE`: Deprecated regex pattern to match primitive names.
           - Example: 'DBiasQuantizePrimitive' or '^(?!DBiasQuantizePrimitive$).+$' to
             enable/disable DBiasQuantizePrimitive.
           - A deprecation warning is raised if used; it will be removed in future releases.

    Behavior:
        1. If `NVTE_JAX_CUSTOM_CALLS` is set and is a global value or names this primitive,
           that setting is returned.
        2. Otherwise, if `NVTE_JAX_CUSTOM_CALLS_RE` is set, the primitive is enabled exactly
           when its class name fully matches the pattern (with deprecation warning).
        3. Otherwise, falls back to the internal state `_is_enabled`.

    Returns:
        bool: True if the primitive is enabled, False otherwise.
    """
    # Check new key/value environment variable first
    custom_calls_str = os.getenv("NVTE_JAX_CUSTOM_CALLS")
    if custom_calls_str is not None:
        custom_calls_str = custom_calls_str.strip()
        global_value = custom_calls_str.lower()
        if global_value in ("true", "false"):
            return global_value == "true"

        # Scan key=value pairs for this class only; keep scanning after a hit so
        # that a later duplicate entry overrides an earlier one.
        own_setting = None
        for pair in custom_calls_str.split(","):
            key, separator, value = pair.partition("=")
            if separator and key.strip() == cls.__name__:
                own_setting = value.strip().lower() == "true"
        if own_setting is not None:
            return own_setting

    # Check old regex environment variable (deprecated)
    pattern_str = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE")
    if pattern_str is not None:
        warnings.warn(
            "NVTE_JAX_CUSTOM_CALLS_RE is deprecated and will be removed in future releases. Use"
            " NVTE_JAX_CUSTOM_CALLS with key=value format instead (e.g.,"
            " 'DBiasQuantizePrimitive=false').",
            DeprecationWarning,
            # Attribute the warning to the caller rather than to this method.
            stacklevel=2,
        )
        return re.fullmatch(pattern_str, cls.__name__) is not None

    # If no environment variable decides, fall back to the internal state
    return cls._is_enabled
@classmethod
def set_enabled(cls, enabled: bool):
    """
    Sets the enabled state for this primitive.

    The value is stored in the class attribute `_is_enabled` of the class the method is
    called on, so setting it on a subclass does not alter its parent.

    Args:
        enabled: New value for `_is_enabled`.
    """
    cls._is_enabled = enabled
@staticmethod
@abstractmethod
......@@ -109,10 +168,19 @@ class BasePrimitive(metaclass=ABCMeta):
return "... -> ..."
# Registry to store all registered primitive classes
_primitive_registry = {}
def register_primitive(cls):
"""
register jax primitive
Register a JAX primitive and add it to the internal registry.
"""
_primitive_registry[cls.__name__] = cls
# Set default disabled state at class level based on _default_disable_names
if cls.__name__ in BasePrimitive._default_disable_names:
cls.set_enabled(False)
def name_of_wrapper_p():
return cls.name + "_wrapper"
......@@ -145,3 +213,48 @@ def register_primitive(cls):
for _name, _value in transformer_engine_jax.registrations().items():
ffi.register_ffi_target(_name, _value, platform="CUDA")
def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False):
    """
    Helper function to manage primitive states by name without modifying environment variables.

    Allows enabling specific primitives, disabling specific primitives, or disabling all
    primitives. This helper is used in the QuantizeConfig.initialize() methods.

    Args:
        enable_names: List of strings, each representing the name of a primitive class to
            enable. A single string is treated as one name. Defaults to None.
        disable_names: List of strings, each representing the name of a primitive class to
            disable. A single string is treated as one name. Defaults to None.
        disable_all_first: Boolean, if True, disables all registered primitives before
            applying enable/disable lists. Defaults to False.

    Raises:
        ValueError: If any requested name is not a registered primitive class. All names
            are checked before any state is changed, so a failed call changes nothing.

    Note:
        1. If `disable_all_first` is True, all primitives are disabled first, then
           `enable_names` is applied.
        2. Conflicts (a primitive in both enable and disable lists) are resolved by applying
           disable last.
    """

    def _as_name_set(names):
        # Falsy input means "no names"; a lone string is one name, not an
        # iterable of characters.
        if not names:
            return set()
        if isinstance(names, str):
            return {names}
        return set(names)

    def _lookup(name):
        # Return the registered primitive class for `name`, or None.
        candidate = _primitive_registry.get(name)
        if isinstance(candidate, type) and issubclass(candidate, BasePrimitive):
            return candidate
        return None

    enable_set = _as_name_set(enable_names)
    disable_set = _as_name_set(disable_names)

    # Validate everything up front so an unknown name cannot leave the
    # primitives in a partially-updated state.
    unknown = sorted(
        (name for name in enable_set | disable_set if _lookup(name) is None), key=str
    )
    if unknown:
        raise ValueError(
            "Primitive not found in registry: " + ", ".join(str(name) for name in unknown)
        )

    if disable_all_first:
        for registered in _primitive_registry.values():
            if (
                isinstance(registered, type)
                and issubclass(registered, BasePrimitive)
                and registered is not BasePrimitive
            ):
                registered.set_enabled(False)

    # Apply enables
    for name in enable_set:
        _lookup(name).set_enabled(True)

    # Apply disables (overrides enables if there's a conflict)
    for name in disable_set:
        _lookup(name).set_enabled(False)
......@@ -519,11 +519,11 @@ register_primitive(BaseDBiasQuantizePrimitive)
class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for DBias quantization.

    No change in functionality from the base primitive but named differently for use in
    more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.
    """
class QuantizePrimitive(BaseDBiasQuantizePrimitive):
    """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias.

    No change in functionality from the base primitive but named differently for use in
    more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.
    """
def _jax_quantize(
......
......@@ -352,6 +352,9 @@ class BlockScalingQuantizeConfig:
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
# Use TE GEMM instead of JAX GEMM for better performance
tex.base.manage_primitives(enable_names=["GemmPrimitive"])
@staticmethod
def finalize() -> None:
"""Reset the block scaling configuration."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment