Unverified Commit 4ceb3d4c authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Distributed Current Scaling (#1699)



* Update test_helper.py and add QuantizeConfig class for CurrentScaling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* WIP distributed current scaling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Distributed Current Scaling (debugging). Distributed implementation with replicated scale_inv works for layernorm_mlp but feels like a hack

Has different per-device scale_inv values, but jax.debug.print only shows one of them. Since we're telling JAX/XLA that this scale is replicated, I think it assumes all the values are equal. However, it doesn't actually check this, so it seems we are able to get away with per-device scales for current scaling. I am not sure how stable this will be; it may randomly fail if we or the user change the partitioning at all, or if XLA decides to actually act on the assumption that all these scale_invs are the same.
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Implement distributed current scaling by computing a global amax and scale before quantization
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Add encoder and mnist tests for current scaling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Add primitive prefix to shardy unique_vars to prevent factor conflicts when performing unfused primitives for current scaling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Remove scale_shape primitive arg that is no longer used
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Format
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix expected result on multiprocessing encoder test
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Lint fix
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update multiprocessing current scaling tolerances
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Uncomment test case that was disabled for testing
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Remove commented out debug line
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 643fb0a0
...@@ -37,5 +37,7 @@ def get_fp8_recipe_from_name_string(name: str): ...@@ -37,5 +37,7 @@ def get_fp8_recipe_from_name_string(name: str):
return recipe.DelayedScaling() return recipe.DelayedScaling()
case "MXFP8BlockScaling": case "MXFP8BlockScaling":
return recipe.MXFP8BlockScaling() return recipe.MXFP8BlockScaling()
case "Float8CurrentScaling":
return recipe.Float8CurrentScaling()
case _: case _:
raise ValueError(f"Invalid fp8_recipe, got {name}") raise ValueError(f"Invalid fp8_recipe, got {name}")
...@@ -8,9 +8,11 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} ...@@ -8,9 +8,11 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
TEST_CASES=( TEST_CASES=(
"test_te_bf16" "test_te_bf16"
"test_te_delayed_scaling_fp8" "test_te_delayed_scaling_fp8"
"test_te_current_scaling_fp8"
"test_te_mxfp8" "test_te_mxfp8"
"test_te_bf16_shardy" "test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy" "test_te_delayed_scaling_fp8_shardy"
"test_te_current_scaling_fp8_shardy"
) )
echo echo
......
...@@ -441,6 +441,14 @@ class TestEncoder(unittest.TestCase): ...@@ -441,6 +441,14 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73 assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self): def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8""" """Test Transformer Engine with MXFP8"""
...@@ -467,6 +475,15 @@ class TestEncoder(unittest.TestCase): ...@@ -467,6 +475,15 @@ class TestEncoder(unittest.TestCase):
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
if __name__ == "__main__": if __name__ == "__main__":
train_and_evaluate(encoder_parser(None)) train_and_evaluate(encoder_parser(None))
...@@ -611,6 +611,14 @@ class TestEncoder(unittest.TestCase): ...@@ -611,6 +611,14 @@ class TestEncoder(unittest.TestCase):
result = self.exec(True, "DelayedScaling") result = self.exec(True, "DelayedScaling")
assert result[0] < 0.505 and result[1] > 0.754 assert result[0] < 0.505 and result[1] > 0.754
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling")
assert result[0] < 0.507 and result[1] > 0.753
@unittest.skipIf( @unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
) )
...@@ -631,10 +639,18 @@ class TestEncoder(unittest.TestCase): ...@@ -631,10 +639,18 @@ class TestEncoder(unittest.TestCase):
def test_te_delayed_scaling_fp8_shardy(self): def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8""" """Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling", enable_shardy=True) result = self.exec(True, "DelayedScaling", enable_shardy=True)
assert result[0] < 0.505 and result[1] > 0.754 assert result[0] < 0.505 and result[1] > 0.753
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX. # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
)
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
assert result[0] < 0.507 and result[1] > 0.753
if __name__ == "__main__": if __name__ == "__main__":
train_and_evaluate(encoder_parser(None)) train_and_evaluate(encoder_parser(None))
...@@ -348,6 +348,14 @@ class TestEncoder(unittest.TestCase): ...@@ -348,6 +348,14 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79 assert actual[0] < 0.455 and actual[1] > 0.79
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self): def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8""" """Test Transformer Engine with MXFP8"""
......
...@@ -350,6 +350,14 @@ class TestMNIST(unittest.TestCase): ...@@ -350,6 +350,14 @@ class TestMNIST(unittest.TestCase):
actual = train_and_evaluate(self.args) actual = train_and_evaluate(self.args)
self.verify(actual) self.verify(actual)
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
self.verify(actual)
if __name__ == "__main__": if __name__ == "__main__":
train_and_evaluate(mnist_parser(None)) train_and_evaluate(mnist_parser(None))
...@@ -34,6 +34,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) ...@@ -34,6 +34,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = []
if is_fp8_supported: if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported: if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
...@@ -76,6 +77,8 @@ class TestDistributedLayernorm: ...@@ -76,6 +77,8 @@ class TestDistributedLayernorm:
other_bytes = 0 other_bytes = 0
if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes: if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
other_bytes = 384 # required for small scale shapes that require padding other_bytes = 384 # required for small scale shapes that require padding
if fp8_recipe == recipe.Float8CurrentScaling():
allreduce_total_bytes += 4 # 1 * FP32 for the amax reduction
return generate_collectives_count( return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
) )
......
...@@ -41,6 +41,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) ...@@ -41,6 +41,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = []
if is_fp8_supported: if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported: if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
...@@ -217,37 +218,10 @@ class TestDistributedLayernormMLP: ...@@ -217,37 +218,10 @@ class TestDistributedLayernormMLP:
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
) )
else: else:
is_gated = len(activation_type) > 1
rtol = None
atol = None
if is_gated:
if dtype == jnp.bfloat16:
if i == 2:
rtol = 800
atol = 9e-2
if i == 4:
atol = 300
rtol = 1e-1
if dtype == jnp.float16:
if i == 1: # gamma
rtol = 200
atol = 1e-2
if i == 2:
rtol = 2000
atol = 7e-2
if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling(): # bias_1
# Accumulating dbias across a large tensor introduces a larger difference
rtol = 200
atol = 4e-2
if i == 4 and fp8_recipe == recipe.DelayedScaling():
rtol = 2200
atol = 9e-2
assert_allclose( assert_allclose(
multi_grads[i], multi_grads[i],
single_grads[i], single_grads[i],
dtype=dtype, dtype=dtype,
rtol=rtol,
atol=atol,
err_msg=f"multi_grads[{i}] is not close", err_msg=f"multi_grads[{i}] is not close",
) )
......
...@@ -10,47 +10,22 @@ import jax.numpy as jnp ...@@ -10,47 +10,22 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo from transformer_engine.jax.quantize import (
QuantizeConfig,
is_fp8_available,
ScalingMode,
update_collections,
)
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
class TestQuantizeConfig(unittest.TestCase): class TestHelper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_initialize(self):
margin = 5.0
fp8_format = FP8Format.E4M3
amax_history_len = 10
QuantizeConfig.initialize(
margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
)
self.assertEqual(
QuantizeConfig.MARGIN,
margin,
f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
f" but got {QuantizeConfig.MARGIN}.",
)
self.assertEqual(
QuantizeConfig.FP8_FORMAT,
fp8_format,
f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {QuantizeConfig.FP8_FORMAT}.",
)
self.assertEqual(
QuantizeConfig.AMAX_HISTORY_LEN,
amax_history_len,
f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
)
QuantizeConfig.finalize()
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_collections(self): def test_update_collections(self):
...@@ -61,12 +36,12 @@ class TestQuantizeConfig(unittest.TestCase): ...@@ -61,12 +36,12 @@ class TestQuantizeConfig(unittest.TestCase):
"test1": original_val, "test1": original_val,
"test2": original_val, "test2": original_val,
} }
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state) updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val) self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val) self.assertEqual(updated_state["test2"], original_val)
original_state = flax.core.frozen_dict.FrozenDict(original_state) original_state = flax.core.frozen_dict.FrozenDict(original_state)
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state) updated_state = update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val) self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val) self.assertEqual(updated_state["test2"], original_val)
...@@ -82,8 +57,18 @@ class TestFP8Functions(unittest.TestCase): ...@@ -82,8 +57,18 @@ class TestFP8Functions(unittest.TestCase):
self.assertTrue(ref.amax_history_len == test.amax_history_len) self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(QuantizeConfig.MARGIN, test.margin)
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self): def test_fp8_autocast_delayed_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state() self._check_defult_state()
...@@ -107,6 +92,56 @@ class TestFP8Functions(unittest.TestCase): ...@@ -107,6 +92,56 @@ class TestFP8Functions(unittest.TestCase):
self._check_defult_state() self._check_defult_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(Float8CurrentScaling())
self._check_defult_state()
cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_defult_state()
cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_current_scaling(cs)
self._check_defult_state()
@unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
def test_fp8_autocast_mxfp8_scaling(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(MXFP8BlockScaling())
self._check_defult_state()
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_defult_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=bs):
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_defult_state()
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self): def test_fp8_autocast_with_sharding_resource(self):
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
......
...@@ -89,8 +89,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -89,8 +89,7 @@ class ActLuPrimitive(BasePrimitive):
6, 6,
7, 7,
8, 8,
9, ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -105,13 +104,12 @@ class ActLuPrimitive(BasePrimitive): ...@@ -105,13 +104,12 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
): ):
""" """
te_act_lu_p abstract te_act_lu_p abstract
""" """
del act_enum, scale_shapes del act_enum
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
...@@ -121,8 +119,8 @@ class ActLuPrimitive(BasePrimitive): ...@@ -121,8 +119,8 @@ class ActLuPrimitive(BasePrimitive):
) )
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, ( assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused activation and quantization. Please" "Current tensor scaling is not yet supported for fused activation and quantization."
" do activation in higher-precision then quantize with current tensor scaling." " Please do activation in higher-precision then quantize with current tensor scaling."
) )
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
...@@ -156,13 +154,12 @@ class ActLuPrimitive(BasePrimitive): ...@@ -156,13 +154,12 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
): ):
""" """
te_gated_act_lu_p lowering rules te_gated_act_lu_p lowering rules
""" """
del out_dtype, scale_dtype, scale_shapes, act_len, is_outer del out_dtype, scale_dtype, act_len, is_outer
x_aval, scale_aval = ctx.avals_in x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
...@@ -182,7 +179,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -182,7 +179,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
): ):
""" """
...@@ -201,7 +197,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -201,7 +197,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False, is_outer=False,
) )
) )
...@@ -230,7 +225,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -230,7 +225,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
): ):
""" """
...@@ -253,7 +247,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -253,7 +247,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
), ),
out_bdims, out_bdims,
) )
...@@ -266,7 +259,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -266,7 +259,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
mesh, mesh,
arg_infos, arg_infos,
...@@ -277,7 +269,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -277,7 +269,6 @@ class ActLuPrimitive(BasePrimitive):
result_infos, result_infos,
act_enum, act_enum,
scale_dtype, scale_dtype,
scale_shapes,
act_len, act_len,
is_outer, is_outer,
) # Unused. ) # Unused.
...@@ -331,7 +322,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -331,7 +322,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
mesh, mesh,
arg_infos, arg_infos,
...@@ -392,7 +382,6 @@ class ActLuPrimitive(BasePrimitive): ...@@ -392,7 +382,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True, is_outer=True,
) )
) )
...@@ -420,17 +409,16 @@ class ActLuPrimitive(BasePrimitive): ...@@ -420,17 +409,16 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
mesh, mesh,
value_types, value_types,
result_types, result_types,
): ):
del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
x_rank = len(value_types[0].shape) x_rank = len(value_types[0].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank - 1, unique_var="i", flatten_axis=-2 x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2
) )
x_axes = scale_rules.input_spec + (f"x{x_rank-1}",) x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
out = (*x_axes[:-2], x_axes[-1]) out = (*x_axes[:-2], x_axes[-1])
...@@ -472,8 +460,8 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -472,8 +460,8 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
name = "te_dact_dbias_quantize_ffi" name = "te_dact_dbias_quantize_ffi"
multiple_results = True multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -487,7 +475,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -487,7 +475,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
...@@ -496,7 +483,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -496,7 +483,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
""" """
te_dact_dbias_quantize_p abstract te_dact_dbias_quantize_p abstract
""" """
del act_enum, scale_shapes del act_enum
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype assert x_aval.dtype == dz_dtype
...@@ -523,10 +510,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -523,10 +510,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x: if is_2x:
if scaling_mode in ( if ScalingMode(scaling_mode).is_tensor_scaling():
ScalingMode.DELAYED_TENSOR_SCALING.value,
ScalingMode.CURRENT_TENSOR_SCALING.value,
):
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else: else:
colwise_out_shape = out_shape colwise_out_shape = out_shape
...@@ -589,7 +573,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -589,7 +573,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
...@@ -598,7 +581,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -598,7 +581,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
""" """
te_dact_dbias_quantize_p lowering rules te_dact_dbias_quantize_p lowering rules
""" """
del out_dtype, scale_dtype, scale_shapes, act_len, is_outer del out_dtype, scale_dtype, act_len, is_outer
dz_aval, x_aval, scale_aval = ctx.avals_in dz_aval, x_aval, scale_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype assert x_aval.dtype == dz_aval.dtype
...@@ -623,7 +606,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -623,7 +606,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
...@@ -643,7 +625,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -643,7 +625,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
...@@ -672,7 +653,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -672,7 +653,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
...@@ -704,7 +684,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -704,7 +684,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
...@@ -718,7 +697,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -718,7 +697,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
...@@ -728,7 +706,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -728,7 +706,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
result_infos, result_infos,
): ):
del out_dtype, result_infos, act_enum del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, act_len, is_outer del scale_dtype, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2]) scale_spec = get_padded_spec(arg_infos[2])
...@@ -792,7 +770,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -792,7 +770,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
...@@ -865,7 +842,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -865,7 +842,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_enum, act_enum=act_enum,
act_len=act_len, act_len=act_len,
...@@ -892,7 +868,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -892,7 +868,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
act_enum, act_enum,
act_len, act_len,
...@@ -901,11 +876,11 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -901,11 +876,11 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
value_types, value_types,
result_types, result_types,
): ):
del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
x_rank = len(value_types[1].shape) x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="i", flatten_axis=-2 x_rank, unique_var="DActLuDbiasQuantizePrimitive_i", flatten_axis=-2
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
out = x_axes out = x_axes
...@@ -1038,7 +1013,6 @@ def act_lu( ...@@ -1038,7 +1013,6 @@ def act_lu(
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True, is_outer=True,
) )
out = out.reshape(output_shape) out = out.reshape(output_shape)
...@@ -1072,8 +1046,6 @@ def act_lu( ...@@ -1072,8 +1046,6 @@ def act_lu(
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
# output does not have act axis
scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
is_outer=True, is_outer=True,
) )
...@@ -1166,7 +1138,6 @@ def quantize_dact_dbias( ...@@ -1166,7 +1138,6 @@ def quantize_dact_dbias(
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused is_2x=False, # unused
scale_dtype=jnp.float32, # unused scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused
is_dbias=False, is_dbias=False,
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
...@@ -1203,8 +1174,6 @@ def quantize_dact_dbias( ...@@ -1203,8 +1174,6 @@ def quantize_dact_dbias(
) )
return out, dbias return out, dbias
out_shape = x.shape
( (
rowwise_casted_output, rowwise_casted_output,
colwise_casted_output, colwise_casted_output,
...@@ -1220,8 +1189,6 @@ def quantize_dact_dbias( ...@@ -1220,8 +1189,6 @@ def quantize_dact_dbias(
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
# output has act axis
scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_type_id, act_enum=act_type_id,
act_len=act_len, act_len=act_len,
......
...@@ -177,10 +177,7 @@ def _jax_gemm_tensor_scaling_fp8( ...@@ -177,10 +177,7 @@ def _jax_gemm_tensor_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]] lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
): ):
"""FP8 GEMM for XLA pattern match""" """FP8 GEMM for XLA pattern match"""
assert rhs.scaling_mode in ( assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T": if lhs.data_layout == "T":
...@@ -272,10 +269,7 @@ def _jax_gemm( ...@@ -272,10 +269,7 @@ def _jax_gemm(
def _jax_gemm_fp8_impl(lhs, rhs): def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode in ( if lhs.scaling_mode.is_tensor_scaling():
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
......
...@@ -98,7 +98,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -98,7 +98,7 @@ class NormFwdPrimitive(BasePrimitive):
name = "te_norm_forward_ffi" name = "te_norm_forward_ffi"
multiple_results = True multiple_results = True
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12) impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -116,13 +116,11 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -116,13 +116,11 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
): ):
""" """
LayerNorm fwd inner primitive abstract LayerNorm fwd inner primitive abstract
""" """
del scale_shapes
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -238,13 +236,12 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -238,13 +236,12 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
): ):
""" """
LayerNorm fwd lowering rules LayerNorm fwd lowering rules
""" """
del out_dtype, scale_dtype, scale_shapes, is_outer del out_dtype, scale_dtype, is_outer
x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -287,7 +284,6 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -287,7 +284,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
): ):
""" """
...@@ -316,7 +312,6 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -316,7 +312,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False, is_outer=False,
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
...@@ -352,7 +347,6 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -352,7 +347,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
): ):
""" """
...@@ -386,7 +380,6 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -386,7 +380,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
), ),
out_bdims, out_bdims,
) )
...@@ -400,14 +393,13 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -400,14 +393,13 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del zero_centered_gamma, epsilon, out_dtype, result_infos del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer del scale_dtype, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-1], None) out_spec = (*x_spec[:-1], None)
...@@ -459,7 +451,6 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -459,7 +451,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
mesh, mesh,
arg_infos, arg_infos,
...@@ -544,7 +535,6 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -544,7 +535,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
is_2x=is_2x, is_2x=is_2x,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True, is_outer=True,
) )
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
...@@ -573,7 +563,6 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -573,7 +563,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
mesh, mesh,
value_types, value_types,
...@@ -584,14 +573,13 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -584,14 +573,13 @@ class NormFwdPrimitive(BasePrimitive):
epsilon, epsilon,
out_dtype, out_dtype,
scale_dtype, scale_dtype,
scale_shapes,
is_outer, is_outer,
mesh, mesh,
result_types, result_types,
) )
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="i", flatten_axis=-1 len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
...@@ -931,7 +919,6 @@ def layernorm_fwd( ...@@ -931,7 +919,6 @@ def layernorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((1,), (1,)),
is_outer=True, is_outer=True,
) )
return output, mu, rsigma return output, mu, rsigma
...@@ -983,16 +970,12 @@ def layernorm_fwd( ...@@ -983,16 +970,12 @@ def layernorm_fwd(
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
is_outer=True, is_outer=True,
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode in ( if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
...@@ -1139,7 +1122,6 @@ def rmsnorm_fwd( ...@@ -1139,7 +1122,6 @@ def rmsnorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True, is_outer=True,
) )
return output, rsigma return output, rsigma
...@@ -1188,16 +1170,12 @@ def rmsnorm_fwd( ...@@ -1188,16 +1170,12 @@ def rmsnorm_fwd(
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
is_outer=True, is_outer=True,
) )
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode in ( if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
......
...@@ -27,7 +27,13 @@ from .misc import ( ...@@ -27,7 +27,13 @@ from .misc import (
) )
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
compute_scale_from_amax,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"): if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports from jax import ffi # pylint: disable=ungrouped-imports
...@@ -53,8 +59,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -53,8 +59,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
6, 6,
7, 7,
8, 8,
9, ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -68,14 +73,12 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -68,14 +73,12 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout, q_layout,
flatten_axis, flatten_axis,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
is_outer, is_outer,
): ):
""" """
te_dbias_quantize_p abstract te_dbias_quantize_p abstract
""" """
del scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape out_shape = x_aval.shape
...@@ -94,10 +97,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -94,10 +97,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode in ( if ScalingMode(scaling_mode).is_tensor_scaling():
ScalingMode.DELAYED_TENSOR_SCALING.value,
ScalingMode.CURRENT_TENSOR_SCALING.value,
):
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else: else:
colwise_out_shape = out_shape colwise_out_shape = out_shape
...@@ -169,14 +169,13 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -169,14 +169,13 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout, q_layout,
flatten_axis, flatten_axis,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
is_outer, is_outer,
): ):
""" """
te_dbias_quantize_p lowering rules te_dbias_quantize_p lowering rules
""" """
del out_dtype, scale_dtype, scale_shapes, is_outer del out_dtype, scale_dtype, is_outer
x_aval, scale_aval = ctx.avals_in x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
...@@ -199,7 +198,6 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -199,7 +198,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout, q_layout,
flatten_axis, flatten_axis,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
is_outer, is_outer,
): ):
...@@ -224,7 +222,6 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -224,7 +222,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
is_outer=False, is_outer=False,
) )
...@@ -257,7 +254,6 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -257,7 +254,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout, q_layout,
flatten_axis, flatten_axis,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
is_outer, is_outer,
): ):
...@@ -281,7 +277,6 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -281,7 +277,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
), ),
out_bdims, out_bdims,
...@@ -294,18 +289,13 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -294,18 +289,13 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout, q_layout,
flatten_axis, flatten_axis,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
is_outer, is_outer,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused. del (out_dtype, result_infos, scale_dtype, is_outer) # Unused.
assert (
scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
), "Current tensor scaling is not yet supported for multi-GPU partitioning."
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[1])
...@@ -315,7 +305,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -315,7 +305,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding", desc="DBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
colwise_out_spec = x_spec colwise_out_spec = x_spec
...@@ -371,7 +361,6 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -371,7 +361,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout, q_layout,
flatten_axis, flatten_axis,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
is_outer, is_outer,
mesh, mesh,
...@@ -380,10 +369,6 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -380,10 +369,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
): ):
del result_infos, is_outer del result_infos, is_outer
assert (
scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
), "Current tensor scaling is not yet supported for multi-GPU partitioning."
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding( out_sharding = NamedSharding(
...@@ -392,7 +377,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -392,7 +377,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding", desc="DBiasQuantizePrimitive.out_sharding",
) )
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
colwise_out_spec = x_spec colwise_out_spec = x_spec
...@@ -458,7 +443,6 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -458,7 +443,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout, q_layout=q_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
is_outer=True, is_outer=True,
) )
...@@ -491,17 +475,18 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -491,17 +475,18 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout, q_layout,
flatten_axis, flatten_axis,
scale_dtype, scale_dtype,
scale_shapes,
is_dbias, is_dbias,
is_outer, is_outer,
mesh, mesh,
value_types, value_types,
result_types, result_types,
): ):
del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types del out_dtype, scale_dtype, is_outer, mesh, result_types
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis len(value_types[0].shape),
unique_var="DBiasQuantizePrimitive_i",
flatten_axis=flatten_axis,
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
...@@ -509,7 +494,7 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -509,7 +494,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out = x_axes out = x_axes
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
else: else:
colwise_out = x_axes colwise_out = x_axes
...@@ -625,6 +610,13 @@ def _quantize_dbias_impl( ...@@ -625,6 +610,13 @@ def _quantize_dbias_impl(
return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None return x, None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
# until the tensor is dequantized (e.g. in the GEMM).
amax = jnp.amax(jnp.abs(x), keepdims=True)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale scale = quantizer.scale
...@@ -643,7 +635,6 @@ def _quantize_dbias_impl( ...@@ -643,7 +635,6 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout.value, q_layout=quantizer.q_layout.value,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
is_dbias=is_dbias, is_dbias=is_dbias,
is_outer=True, is_outer=True,
) )
......
...@@ -162,17 +162,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T ...@@ -162,17 +162,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
} }
if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
nvte_compute_amax(input_tensor.data(), // input data
output_tensor.data(), // output data (for amax)
stream);
QuantizationConfigWrapper quant_config;
/** defaults for now, TODO(Jeremy) move to parameter */
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
nvte_compute_scale_from_amax(output_tensor.data(), quant_config, stream);
output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1}); output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
} }
......
...@@ -182,6 +182,8 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: ...@@ -182,6 +182,8 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
return ScalingMode.DELAYED_TENSOR_SCALING return ScalingMode.DELAYED_TENSOR_SCALING
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
return ScalingMode.MXFP8_1D_SCALING return ScalingMode.MXFP8_1D_SCALING
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
return ScalingMode.CURRENT_TENSOR_SCALING
raise ValueError("Invalid fp8_recipe!") raise ValueError("Invalid fp8_recipe!")
...@@ -240,7 +242,7 @@ class QuantizeConfig: ...@@ -240,7 +242,7 @@ class QuantizeConfig:
fp8_recipe: The FP8 recipe to use for initialization fp8_recipe: The FP8 recipe to use for initialization
""" """
cls.INITIALIZED = True cls.INITIALIZED = True
cls.MARGIN = fp8_recipe.margin cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
cls.FP8_FORMAT = fp8_recipe.fp8_format cls.FP8_FORMAT = fp8_recipe.fp8_format
cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)
...@@ -309,6 +311,30 @@ class DelayedScalingQuantizeConfig: ...@@ -309,6 +311,30 @@ class DelayedScalingQuantizeConfig:
QuantizeConfig.finalize() QuantizeConfig.finalize()
class CurrentScalingQuantizeConfig:
    """Configuration class for the current scaling FP8 recipe.

    Provides the initialization and finalization hooks specific to the
    current (just-in-time, per-tensor) scaling FP8 quantization mode.
    """

    @staticmethod
    def initialize(fp8_recipe: recipe.Recipe) -> None:
        """Initialize the current scaling FP8 configuration.

        Args:
            fp8_recipe: The FP8 recipe to use for initialization
        """
        QuantizeConfig.initialize(fp8_recipe)
        # Current scaling computes the scale from the live amax on every
        # quantization, so no amax history needs to be tracked.
        QuantizeConfig.AMAX_HISTORY_LEN = 0

    @staticmethod
    def finalize() -> None:
        """Reset the current scaling configuration."""
        QuantizeConfig.finalize()
class BlockScalingQuantizeConfig: class BlockScalingQuantizeConfig:
"""Configuration class for block scaling FP8 recipe. """Configuration class for block scaling FP8 recipe.
...@@ -385,6 +411,8 @@ def fp8_autocast( ...@@ -385,6 +411,8 @@ def fp8_autocast(
Config = DelayedScalingQuantizeConfig Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig Config = BlockScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
Config = CurrentScalingQuantizeConfig
try: try:
with global_shard_guard(mesh_resource): with global_shard_guard(mesh_resource):
......
...@@ -32,9 +32,31 @@ __all__ = [ ...@@ -32,9 +32,31 @@ __all__ = [
"BlockScaleQuantizer", "BlockScaleQuantizer",
"QuantizerFactory", "QuantizerFactory",
"noop_quantizer_set", "noop_quantizer_set",
"compute_scale_from_amax",
] ]
def compute_scale_from_amax(
    amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
    """Compute the quantization scale factor from an amax value.

    Maps the observed dynamic range ``amax`` onto the representable range of
    ``q_dtype``, reduced by ``2 ** QuantizeConfig.MARGIN`` as a safety margin.
    Where ``amax`` is non-positive or non-finite, the fallback ``scale`` is
    returned instead, so an unusable amax never yields an inf/NaN scale.

    Args:
        amax: Maximum absolute value of the tensor to be quantized.
        q_dtype: Quantization data type (e.g. an FP8 dtype).
        scale: Fallback scale used where ``amax`` is zero, negative, or
            non-finite. Defaults to ones.

    Returns:
        The computed scale value, with the fallback substituted wherever
        ``amax`` cannot produce a valid scale.
    """
    fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
    if scale is None:
        scale = jnp.ones((1,))
    sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
    # Guard against amax == 0 (division yields inf) and non-finite amax:
    # keep the provided/previous scale in those positions.
    sf = jnp.where(amax > 0.0, sf, scale)
    sf = jnp.where(jnp.isfinite(amax), sf, scale)
    return sf
@register_pytree_node_class @register_pytree_node_class
@dataclass @dataclass
class Quantizer(ABC): class Quantizer(ABC):
...@@ -377,18 +399,12 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): ...@@ -377,18 +399,12 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
Updated scale value Updated scale value
""" """
# 2. Calculate the current scale # 2. Calculate the current scale
fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
amax = jnp.max(amax_history, axis=-1, keepdims=True) amax = jnp.max(amax_history, axis=-1, keepdims=True)
else: else:
amax = amax_history[0:1] amax = amax_history[0:1]
sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN) return compute_scale_from_amax(amax, q_dtype, scale=scale)
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = scale.at[0].set(sf[0])
return scale
@staticmethod @staticmethod
@jax.jit @jax.jit
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment