Unverified Commit 4ceb3d4c authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub

[JAX] Distributed Current Scaling (#1699)



* Update test_helper.py and add QuantizeConfig class for CurrentScaling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* WIP distributed current scaling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Distributed Current Scaling (debugging). A distributed implementation with a replicated scale_inv works for layernorm_mlp, but it feels like a hack.

The scale_inv values differ per device, yet jax.debug.print only shows one of them. Because we tell JAX/XLA that this scale is replicated, it presumably assumes all the values are equal. It does not actually check this, so we can get away with per-device scales for current scaling, but this is fragile: it may fail if we or the user change the partitioning at all, or if XLA ever acts on the assumption that all these scale_invs are the same.
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Implement distributed current scaling by computing a global amax and scale before quantization (see the sketch below)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
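A minimal JAX sketch of this approach (illustrative names, not the actual TE/JAX internals), assuming FP8 E4M3 with a finite max of 448 and a margin of 0. Under `jit`/GSPMD, `jnp.max` over a sharded array lowers to a cross-device all-reduce, so every device computes the same scale:

```python
import jax.numpy as jnp

def global_current_scale(x, fp8_max=448.0, margin=0.0):
    # Global amax: when x is sharded under jit, jnp.max inserts the
    # cross-device all-reduce, so all devices agree on the result.
    amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
    scale = (fp8_max / amax) / (2.0**margin)
    # Fall back to 1.0 when amax is zero or non-finite.
    return jnp.where((amax > 0.0) & jnp.isfinite(amax), scale, 1.0)
```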

* Add encoder and mnist tests for current scaling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add a primitive-name prefix to Shardy unique_vars to prevent factor conflicts when running unfused primitives for current scaling (see the sketch below)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
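An illustrative sketch using names from this diff: each primitive now passes its own prefixed factor name, so Shardy factors emitted by different unfused primitives in the same computation can no longer collide.

```python
from transformer_engine.jax.quantize import ScalingMode

# Before this change every primitive used unique_var="i"; prefixing with the
# primitive name keeps the Shardy factor names unique per primitive.
rules = ScalingMode.CURRENT_TENSOR_SCALING.get_shardy_sharding_rules(
    2, unique_var="DBiasQuantizePrimitive_i", flatten_axis=-1
)
```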

* Remove scale_shape primitive arg that is no longer used
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Format
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix expected result on multiprocessing encoder test
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint fix
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update multiprocessing current scaling tolerances
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Uncomment test case that was disabled for testing
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Remove commented out debug line
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 643fb0a0
......@@ -37,5 +37,7 @@ def get_fp8_recipe_from_name_string(name: str):
return recipe.DelayedScaling()
case "MXFP8BlockScaling":
return recipe.MXFP8BlockScaling()
case "Float8CurrentScaling":
return recipe.Float8CurrentScaling()
case _:
raise ValueError(f"Invalid fp8_recipe, got {name}")
......@@ -8,9 +8,11 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
TEST_CASES=(
"test_te_bf16"
"test_te_delayed_scaling_fp8"
"test_te_current_scaling_fp8"
"test_te_mxfp8"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
"test_te_current_scaling_fp8_shardy"
)
echo
......
......@@ -441,6 +441,14 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
......@@ -467,6 +475,15 @@ class TestEncoder(unittest.TestCase):
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -611,6 +611,14 @@ class TestEncoder(unittest.TestCase):
result = self.exec(True, "DelayedScaling")
assert result[0] < 0.505 and result[1] > 0.754
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling")
assert result[0] < 0.507 and result[1] > 0.753
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
)
......@@ -631,10 +639,18 @@ class TestEncoder(unittest.TestCase):
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling", enable_shardy=True)
assert result[0] < 0.505 and result[1] > 0.754
assert result[0] < 0.505 and result[1] > 0.753
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
)
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
assert result[0] < 0.507 and result[1] > 0.753
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -348,6 +348,14 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
......
......@@ -350,6 +350,14 @@ class TestMNIST(unittest.TestCase):
actual = train_and_evaluate(self.args)
self.verify(actual)
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
self.verify(actual)
if __name__ == "__main__":
train_and_evaluate(mnist_parser(None))
......@@ -34,6 +34,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
......@@ -76,6 +77,8 @@ class TestDistributedLayernorm:
other_bytes = 0
if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
other_bytes = 384 # required for small scale shapes that require padding
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += 4 # 1 * FP32 for the amax reduction
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
)
......
......@@ -41,6 +41,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
......@@ -217,37 +218,10 @@ class TestDistributedLayernormMLP:
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
)
else:
is_gated = len(activation_type) > 1
rtol = None
atol = None
if is_gated:
if dtype == jnp.bfloat16:
if i == 2:
rtol = 800
atol = 9e-2
if i == 4:
atol = 300
rtol = 1e-1
if dtype == jnp.float16:
if i == 1: # gamma
rtol = 200
atol = 1e-2
if i == 2:
rtol = 2000
atol = 7e-2
if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling(): # bias_1
# Accumulating dbias across a large tensor introduces a larger difference
rtol = 200
atol = 4e-2
if i == 4 and fp8_recipe == recipe.DelayedScaling():
rtol = 2200
atol = 9e-2
assert_allclose(
multi_grads[i],
single_grads[i],
dtype=dtype,
rtol=rtol,
atol=atol,
err_msg=f"multi_grads[{i}] is not close",
)
......
......@@ -10,47 +10,22 @@ import jax.numpy as jnp
import numpy as np
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.quantize import (
QuantizeConfig,
is_fp8_available,
ScalingMode,
update_collections,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
class TestQuantizeConfig(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_initialize(self):
margin = 5.0
fp8_format = FP8Format.E4M3
amax_history_len = 10
QuantizeConfig.initialize(
margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
)
self.assertEqual(
QuantizeConfig.MARGIN,
margin,
f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
f" but got {QuantizeConfig.MARGIN}.",
)
self.assertEqual(
QuantizeConfig.FP8_FORMAT,
fp8_format,
f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {QuantizeConfig.FP8_FORMAT}.",
)
self.assertEqual(
QuantizeConfig.AMAX_HISTORY_LEN,
amax_history_len,
f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
)
QuantizeConfig.finalize()
class TestHelper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_collections(self):
......@@ -61,12 +36,12 @@ class TestQuantizeConfig(unittest.TestCase):
"test1": original_val,
"test2": original_val,
}
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)
original_state = flax.core.frozen_dict.FrozenDict(original_state)
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)
......@@ -82,8 +57,18 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self):
def test_fp8_autocast_delayed_scaling(self):
QuantizeConfig.finalize() # Ensure this test is not affected by previous tests.
self._check_defult_state()
......@@ -107,6 +92,56 @@ class TestFP8Functions(unittest.TestCase):
self._check_defult_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_current_scaling(self):
QuantizeConfig.finalize() # Ensure this test is not affected by previous tests.
self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(Float8CurrentScaling())
self._check_defult_state()
cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_defult_state()
cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_defult_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
QuantizeConfig.finalize() # Ensure this test is not affected by previous tests.
self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(MXFP8BlockScaling())
self._check_defult_state()
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_defult_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_defult_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self):
QuantizeConfig.finalize() # Ensure this test is not affected by previous tests.
......
......@@ -89,8 +89,7 @@ class ActLuPrimitive(BasePrimitive):
6,
7,
8,
9,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
inner_primitive = None
outer_primitive = None
......@@ -105,13 +104,12 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
te_act_lu_p abstract
"""
del act_enum, scale_shapes
del act_enum
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
......@@ -121,8 +119,8 @@ class ActLuPrimitive(BasePrimitive):
)
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused activation and quantization. Please"
" do activation in higher-precision then quantize with current tensor scaling."
"Current tensor scaling is not yet supported for fused activation and quantization."
" Please do activation in higher-precision then quantize with current tensor scaling."
)
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
......@@ -156,13 +154,12 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
te_act_lu_p lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
del out_dtype, scale_dtype, act_len, is_outer
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
......@@ -182,7 +179,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
......@@ -201,7 +197,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False,
)
)
......@@ -230,7 +225,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
......@@ -253,7 +247,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
),
out_bdims,
)
......@@ -266,7 +259,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
......@@ -277,7 +269,6 @@ class ActLuPrimitive(BasePrimitive):
result_infos,
act_enum,
scale_dtype,
scale_shapes,
act_len,
is_outer,
) # Unused.
......@@ -331,7 +322,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
......@@ -392,7 +382,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True,
)
)
......@@ -420,17 +409,16 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
x_rank = len(value_types[0].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank - 1, unique_var="i", flatten_axis=-2
x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2
)
x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
out = (*x_axes[:-2], x_axes[-1])
......@@ -472,8 +460,8 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
name = "te_dact_dbias_quantize_ffi"
multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
inner_primitive = None
outer_primitive = None
......@@ -487,7 +475,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -496,7 +483,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
te_dact_dbias_quantize_p abstract
"""
del act_enum, scale_shapes
del act_enum
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype
......@@ -523,10 +510,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x:
if scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING.value,
ScalingMode.CURRENT_TENSOR_SCALING.value,
):
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else:
colwise_out_shape = out_shape
......@@ -589,7 +573,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -598,7 +581,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
te_dact_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
del out_dtype, scale_dtype, act_len, is_outer
dz_aval, x_aval, scale_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
......@@ -623,7 +606,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -643,7 +625,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
......@@ -672,7 +653,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -704,7 +684,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
......@@ -718,7 +697,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -728,7 +706,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
result_infos,
):
del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, act_len, is_outer
del scale_dtype, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
......@@ -792,7 +770,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -865,7 +842,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
......@@ -892,7 +868,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -901,11 +876,11 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
value_types,
result_types,
):
del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="i", flatten_axis=-2
x_rank, unique_var="DActLuDbiasQuantizePrimitive_i", flatten_axis=-2
)
x_axes = scale_rules.input_spec
out = x_axes
......@@ -1038,7 +1013,6 @@ def act_lu(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True,
)
out = out.reshape(output_shape)
......@@ -1072,8 +1046,6 @@ def act_lu(
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
# output does not have act axis
scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
is_outer=True,
)
......@@ -1166,7 +1138,6 @@ def quantize_dact_dbias(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
......@@ -1203,8 +1174,6 @@ def quantize_dact_dbias(
)
return out, dbias
out_shape = x.shape
(
rowwise_casted_output,
colwise_casted_output,
......@@ -1220,8 +1189,6 @@ def quantize_dact_dbias(
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
# output has act axis
scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
is_dbias=is_dbias,
act_enum=act_type_id,
act_len=act_len,
......
......@@ -177,10 +177,7 @@ def _jax_gemm_tensor_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM for XLA pattern match"""
assert rhs.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
), "rhs does not have delayed tensor scaling mode"
assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T":
......@@ -272,10 +269,7 @@ def _jax_gemm(
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
if lhs.scaling_mode.is_tensor_scaling():
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
......
......@@ -98,7 +98,7 @@ class NormFwdPrimitive(BasePrimitive):
name = "te_norm_forward_ffi"
multiple_results = True
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12)
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11)
inner_primitive = None
outer_primitive = None
......@@ -116,13 +116,11 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
LayerNorm fwd inner primitive abstract
"""
del scale_shapes
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -238,13 +236,12 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
LayerNorm fwd lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, is_outer
del out_dtype, scale_dtype, is_outer
x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -287,7 +284,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
......@@ -316,7 +312,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -352,7 +347,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
......@@ -386,7 +380,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
),
out_bdims,
)
......@@ -400,14 +393,13 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer
del scale_dtype, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-1], None)
......@@ -459,7 +451,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
......@@ -544,7 +535,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True,
)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
......@@ -573,7 +563,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
value_types,
......@@ -584,14 +573,13 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scale_dtype,
scale_shapes,
is_outer,
mesh,
result_types,
)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="i", flatten_axis=-1
len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1
)
x_axes = scale_rules.input_spec
......@@ -931,7 +919,6 @@ def layernorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((1,), (1,)),
is_outer=True,
)
return output, mu, rsigma
......@@ -983,16 +970,12 @@ def layernorm_fwd(
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
is_outer=True,
)
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -1139,7 +1122,6 @@ def rmsnorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True,
)
return output, rsigma
......@@ -1188,16 +1170,12 @@ def rmsnorm_fwd(
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
is_outer=True,
)
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......
......@@ -27,7 +27,13 @@ from .misc import (
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
compute_scale_from_amax,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
......@@ -53,8 +59,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
6,
7,
8,
9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
inner_primitive = None
outer_primitive = None
......@@ -68,14 +73,12 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
"""
te_dbias_quantize_p abstract
"""
del scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape
......@@ -94,10 +97,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING.value,
ScalingMode.CURRENT_TENSOR_SCALING.value,
):
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
......@@ -169,14 +169,13 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
"""
te_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, is_outer
del out_dtype, scale_dtype, is_outer
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32
......@@ -199,7 +198,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
......@@ -224,7 +222,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
is_outer=False,
)
......@@ -257,7 +254,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
......@@ -281,7 +277,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
),
out_bdims,
......@@ -294,18 +289,13 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
arg_infos,
result_infos,
):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused.
assert (
scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
), "Current tensor scaling is not yet supported for multi-GPU partitioning."
del (out_dtype, result_infos, scale_dtype, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
......@@ -315,7 +305,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
......@@ -371,7 +361,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
......@@ -380,10 +369,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
):
del result_infos, is_outer
assert (
scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
), "Current tensor scaling is not yet supported for multi-GPU partitioning."
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
......@@ -392,7 +377,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
......@@ -458,7 +443,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
is_outer=True,
)
......@@ -491,17 +475,18 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types
del out_dtype, scale_dtype, is_outer, mesh, result_types
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis
len(value_types[0].shape),
unique_var="DBiasQuantizePrimitive_i",
flatten_axis=flatten_axis,
)
x_axes = scale_rules.input_spec
......@@ -509,7 +494,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out = x_axes
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
else:
colwise_out = x_axes
......@@ -625,6 +610,13 @@ def _quantize_dbias_impl(
return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation, which uses a per-device local amax and scale
# and persists them until the tensor is dequantized (e.g. in the GEMM).
amax = jnp.amax(jnp.abs(x), keepdims=True)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
......@@ -643,7 +635,6 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout.value,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
is_dbias=is_dbias,
is_outer=True,
)
......
......@@ -162,17 +162,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
}
if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
nvte_compute_amax(input_tensor.data(), // input data
output_tensor.data(), // output data (for amax)
stream);
QuantizationConfigWrapper quant_config;
/** defaults for now, TODO(Jeremy) move to parameter */
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
nvte_compute_scale_from_amax(output_tensor.data(), quant_config, stream);
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
}
......
......@@ -182,6 +182,8 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.MXFP8_1D_SCALING
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return ScalingMode.CURRENT_TENSOR_SCALING
raise ValueError("Invalid fp8_recipe!")
......@@ -240,7 +242,7 @@ class QuantizeConfig:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin
cls.MARGIN = fp8_recipe.margin if hasattr(fp8_recipe, "margin") else 0.0
cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
......@@ -309,6 +311,30 @@ class DelayedScalingQuantizeConfig:
QuantizeConfig.finalize()
class CurrentScalingQuantizeConfig:
"""Configuration class for current scaling FP8 recipe.
This class provides specific initialization and finalization for current scaling
FP8 quantization mode.
"""
@staticmethod
def initialize(fp8_recipe: recipe.Recipe) -> None:
"""Initialize current scaling FP8 configuration.
Args:
fp8_recipe: The FP8 recipe to use for initialization
"""
cls = QuantizeConfig
cls.initialize(fp8_recipe)
cls.AMAX_HISTORY_LEN = 0
@staticmethod
def finalize() -> None:
"""Reset the current scaling configuration."""
QuantizeConfig.finalize()
class BlockScalingQuantizeConfig:
"""Configuration class for block scaling FP8 recipe.
......@@ -385,6 +411,8 @@ def fp8_autocast(
Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
Config = CurrentScalingQuantizeConfig
try:
with global_shard_guard(mesh_resource):
......
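A hedged usage sketch of the new recipe path (illustrative, not part of the diff); inside the context, `QuantizeConfig` is initialized via `CurrentScalingQuantizeConfig`:

```python
from transformer_engine.common import recipe
from transformer_engine.jax import fp8_autocast

with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
    # QuantizeConfig.SCALING_MODE is ScalingMode.CURRENT_TENSOR_SCALING here,
    # and AMAX_HISTORY_LEN is 0 since current scaling keeps no amax history.
    ...
```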
......@@ -32,9 +32,31 @@ __all__ = [
"BlockScaleQuantizer",
"QuantizerFactory",
"noop_quantizer_set",
"compute_scale_from_amax",
]
def compute_scale_from_amax(
amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
"""Compute scale from amax value.
Args:
amax: Maximum absolute value of the tensor
q_dtype: Quantization data type
scale: Optional fallback scale (defaults to ones), returned where amax is zero or non-finite
Returns:
Scale value
"""
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if scale is None:
scale = jnp.ones((1,))
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
return sf
@register_pytree_node_class
@dataclass
class Quantizer(ABC):
......@@ -377,18 +399,12 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Updated scale value
"""
# 2. Calculate the current scale
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True)
else:
amax = amax_history[0:1]
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = scale.at[0].set(sf[0])
return scale
return compute_scale_from_amax(amax, q_dtype, scale=scale)
@staticmethod
@jax.jit
......
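A worked example of `compute_scale_from_amax` as defined above, assuming `QuantizeConfig.MARGIN` is 0 and using JAX's `float8_e4m3fn` (finite max 448):

```python
import jax.numpy as jnp

amax = jnp.array([3.5], dtype=jnp.float32)
scale = compute_scale_from_amax(amax, jnp.float8_e4m3fn)
print(scale)  # [128.]  (448 / 3.5); a zero or non-finite amax would
              # return the fallback scale of ones instead.
```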