"""Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
"""Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
@@ -30,19 +31,77 @@ class BasePrimitive(metaclass=ABCMeta):
...
@@ -30,19 +31,77 @@ class BasePrimitive(metaclass=ABCMeta):
name=None
name=None
_is_enabled=True
# Default list of primitives to disable for all recipes
_default_disable_names=[]
@classmethod
@classmethod
defenabled(cls):
defenabled(cls):
"""
"""
A custom call is marked as disabled if the `cls.__name__` does not fully match the
Determines if a custom call is enabled based on a state variable and environment variables.
`NVTE_JAX_CUSTOM_CALLS_RE` pattern.
Checks `NVTE_JAX_CUSTOM_CALLS` (key/value format) first, then falls back to the deprecated `NVTE_JAX_CUSTOM_CALLS_RE` (regex pattern),
This uses the Python class name of the primitive definitions that inherit from BasePrimitive.
and finally to the internal state `_is_enabled` if neither is set.
By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`.
Environment Variables:
1. `NVTE_JAX_CUSTOM_CALLS`: Preferred key/value format to enable/disable specific primitives or a single value 'true' or 'false' to enable/disable all primitives.
- Example 1 (global enable): 'true' enables all primitives.
- Example 2 (global disable): 'false' disables all primitives.
- Example 3 (specific settings): 'DBiasQuantizePrimitive=false,GemmPrimitive=true' disables DBiasQuantizePrimitive and enables GemmPrimitive, leaving others at their default state.
Note that the default state is set at class level based on _default_disable_names.
2. `NVTE_JAX_CUSTOM_CALLS_RE`: Deprecated regex pattern to match primitive names.
- Example: 'DBiasQuantizePrimitive' or '^(?!DBiasQuantizePrimitive$).+$' to enable/disable DBiasQuantizePrimitive.
- A deprecation warning is raised if used; it will be removed in future releases.
Behavior:
1. Checks if `NVTE_JAX_CUSTOM_CALLS` is set and parses key/value pairs or single true/false value.
2. If not set, checks `NVTE_JAX_CUSTOM_CALLS_RE` (with deprecation warning) for regex matching.
3. If neither is set, falls back to the internal state `_is_enabled`.
"""Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
"""Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE."""
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
n_groups:
n_groups:
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set.
**kwargs: Additional arguments for quantizer initialization
**kwargs: Additional arguments for quantizer initialization
Returns:
Returns:
A single quantizer set or tuple of quantizer sets
A single quantizer set or tuple of quantizer sets
"""
"""
assertscaling_modeisNoneorfp8_recipeisNone,(
"Cannot specify both scaling_mode and fp8_recipe when creating a quantizer set. Scaling"
" mode can be specified directly via the scaling_mode parameter or indirectly via"
" recipe. Recipe is preferred as it will support additional recipes in future where"
" scaling mode differs between x, kernel, and grad in the quantizer set."
)
iffp8_recipeisnotNone:
# TODO(jberchtold): once recipe and scaling mode are decoupled update this logic