@@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
...
@@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
# VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards.