# TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical results between the TE and JAX RHT+quant implementations. Currently, for certain shapes, the quantized fp4 data differs by a small amount on <0.5% of the values.
# With NVFP4 scaling, TE kernels internally use bfloat16, so using a different input dtype can lead to small numerical differences compared to the JAX implementation.
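# A hedged sketch of one way to sidestep the dtype mismatch described above:
# round-trip the reference input through bfloat16 before quantizing so the
# JAX path sees the same values the TE kernel computes on internally. The
# helper name is hypothetical and not part of the TE/JAX API.
import jax.numpy as jnp

def _match_te_compute_dtype(x):
    # Cast to bfloat16 and back so both implementations quantize identical values.
    return jnp.asarray(x, dtype=jnp.bfloat16).astype(x.dtype)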
"""Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
"""Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries."""
# Define a function with a custom VJP (vector-Jacobian product)
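# A hedged, self-contained sketch of the pattern: a toy quantizer pytree is
# stashed in the custom-VJP residuals, flattened by JAX, and reconstructed
# before the backward rule runs. `ToyQuantizer` and `passthrough` are
# illustrative stand-ins, not the real TE/JAX types.
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToyQuantizer(NamedTuple):
    # NamedTuples are pytrees, so `scale` becomes a leaf that is flattened
    # and restored across the VJP boundary.
    scale: jax.Array

@jax.custom_vjp
def passthrough(x, quantizer):
    # Identity on x; the quantizer only rides along so it can be inspected
    # again on the backward pass.
    return x

def _passthrough_fwd(x, quantizer):
    # Store the quantizer in the residuals; JAX flattens it to leaves here.
    return x, quantizer

def _passthrough_bwd(quantizer, grad):
    # `quantizer` is the reconstructed pytree; a test would assert its
    # properties against the reference recipe at this point.
    assert isinstance(quantizer, ToyQuantizer)
    return grad, ToyQuantizer(scale=jnp.zeros_like(quantizer.scale))

passthrough.defvjp(_passthrough_fwd, _passthrough_bwd)

# Taking a gradient forces the quantizer through the flatten/unflatten path.
_q = ToyQuantizer(scale=jnp.array(1.0))
jax.grad(lambda v: passthrough(v, _q).sum())(jnp.ones((4,)))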
"""Asserts that the quantizer matches the expected properties from the reference recipe. The quantizers are created in a small test Flax module TestModule and passed through a VJP boundary to ensure correct reconstruction.
Args:
    ref_recipe: The reference quantization recipe.
    quantizer: The quantizer to be checked.
    tensor_source: The source of the tensor (e.g., KERNEL, X, DGRAD).
"""
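# A minimal sketch of an assertion helper with the signature documented
# above. The attribute names (`q_dtype`, `scaling_mode`) are hypothetical
# stand-ins for whatever the real quantizer and recipe actually expose.
def assert_quantizer_matches(ref_recipe, quantizer, tensor_source):
    assert quantizer is not None, f"no quantizer produced for {tensor_source}"
    assert quantizer.q_dtype == ref_recipe.q_dtype, (
        f"{tensor_source}: expected dtype {ref_recipe.q_dtype}, got {quantizer.q_dtype}"
    )
    assert quantizer.scaling_mode == ref_recipe.scaling_mode, (
        f"{tensor_source}: scaling mode mismatch"
    )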
"""Tests that the quantizers created in a test module match the expected properties by passing them through a VJP boundary.
Args:
    assert_quantizer_func: A function that asserts the properties of the quantizers. The function signature is (quantizer: Quantizer, tensor_source: TensorSource) -> None.
    direct_recipe: An optional quantization recipe passed directly to the test module, as an alternative to configuring quantization through an autocast context.
"""
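# A hedged sketch of how such a test might drive the callback: one call per
# tensor source over whatever quantizer set the module's VJP produced. The
# `TensorSource` enum and dict-shaped quantizer set below are stand-ins for
# the real test fixtures.
from enum import Enum, auto

class TensorSource(Enum):
    X = auto()
    KERNEL = auto()
    DGRAD = auto()

def check_quantizer_set(quantizer_set, assert_quantizer_func):
    # Invoke the user-supplied assertion once per tensor source so a single
    # callback validates the whole set captured across the VJP boundary.
    for source in TensorSource:
        assert_quantizer_func(quantizer_set[source], source)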