Unverified Commit 0cd1cd8e authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Fix incorrectly skipped test_quantize_dbias tests (#1808)



Fix incorrectly skipped test_quantize_dbias tests
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 9c436d53
......@@ -532,7 +532,7 @@ QUANTIZE_OUTPUT_DTYPES = {
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
((32, 64), -1),
((2, 64, 32), -1),
((2, 64, 32), -2),
((64, 2, 32), -2),
((32, 256, 128), -1),
((32, 256, 128), -2),
((64, 32, 32, 256), -1),
......@@ -544,7 +544,7 @@ QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
"L0": [
((32, 64), -1),
((2, 64, 32), -1),
((2, 64, 32), -2),
((64, 2, 32), -2),
],
"L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}
......@@ -577,9 +577,6 @@ class TestQuantize:
q_dtype=q_dtype,
q_layout=q_layout,
)
# Adding dimension to test if padding is done correctly when flatten 3D to 2D
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
......@@ -593,8 +590,6 @@ class TestQuantize:
):
key = jax.random.PRNGKey(0)
if flatten_axis == -2:
input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
input = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
......@@ -625,12 +620,6 @@ class TestFusedQuantize:
):
pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis <= 0:
pytest.skip(
f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
" must be at least one axis on either side of the flatten_axis split."
)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment