@@ -532,6 +534,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive):
...
@@ -532,6 +534,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
classAmaxScope(Enum):
"""
Amax Scope Enum
"""
LOCAL=1
TPSP=2
FSDP=3
classAmaxCalculationPrimitive(BasePrimitive):
"""
Amax Calculation Primitive with custom_partitioning