Unverified commit 8a7ab3dd, authored by jberchtold-nvidia and committed by GitHub

[JAX] NVFP4 support in TE/JAX (#2254)


Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent e99be1b6
@@ -33,6 +33,13 @@ def is_mxfp8_supported():
     return gpu_arch >= 100
+@lru_cache
+def is_nvfp4_supported():
"""Return if FP8 has hardware supported"""
+    gpu_arch = get_device_compute_capability(0)
+    return gpu_arch >= 100
 def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False):
     """Checks whether most params are sharded across sharding axis.
@@ -98,7 +105,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=
     )
-def get_fp8_recipe_from_name_string(name: str):
+def get_quantization_recipe_from_name_string(name: str):
     """Query recipe from a given name string"""
     match name:
         case "DelayedScaling":
@@ -107,5 +114,7 @@ def get_fp8_recipe_from_name_string(name: str):
             return recipe.MXFP8BlockScaling()
         case "Float8CurrentScaling":
             return recipe.Float8CurrentScaling()
+        case "NVFP4BlockScaling":
+            return recipe.NVFP4BlockScaling()
         case _:
-            raise ValueError(f"Invalid fp8_recipe, got {name}")
+            raise ValueError(f"Invalid quantization_recipe, got {name}")
@@ -10,9 +10,11 @@ TEST_CASES=(
     "test_te_delayed_scaling_fp8"
     "test_te_current_scaling_fp8"
     "test_te_mxfp8"
+    "test_te_nvfp4"
     "test_te_bf16_shardy"
     "test_te_delayed_scaling_fp8_shardy"
     "test_te_current_scaling_fp8_shardy"
+    "test_te_nvfp4_shardy"
 )
 : ${TE_PATH:=/opt/transformerengine}
...
@@ -21,13 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding
 from common import (
     is_bf16_supported,
-    get_fp8_recipe_from_name_string,
+    get_quantization_recipe_from_name_string,
     assert_params_sufficiently_sharded,
 )
 import transformer_engine.jax as te
 import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
-from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
+from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
 DEVICE_DP_AXIS = "data"
@@ -36,6 +36,7 @@ NAMED_BROADCAST_AXIS = "my_broadcast_axis"
 NAMED_TP_AXIS = "my_tp_axis"
 PARAMS_KEY = "params"
 PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
+SR_KEY = "sr_rng"
 DROPOUT_KEY = "dropout"
 INPUT_KEY = "input_rng"
@@ -121,6 +122,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
     epoch_accuracy = []
     for perm in perms:
+        # Split and reassign to 'rngs' to ensure unique rng for each step
+        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
         batch_inputs = train_ds["sentence"][perm, ...]
         batch_masks = train_ds["mask"][perm, ...]
         batch_labels = train_ds["label"][perm, ...]
@@ -135,11 +138,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
     return state, avg_loss, avg_accuracy, var_collect
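The per-step RNG refresh added above keeps every training and evaluation step on its own random streams (dropout, data shuffling, and the new stochastic-rounding key) instead of reusing one key for the whole epoch. A minimal sketch of the idiom, with illustrative key names:

    import jax

    rngs = {"dropout": jax.random.PRNGKey(0), "sr_rng": jax.random.PRNGKey(1)}
    for step in range(3):
        # jax.random.split derives two fresh keys from each current key;
        # keeping index [1] gives this step a new, independent stream.
        rngs = {name: jax.random.split(key)[1] for name, key in rngs.items()}
        # ... call train_step(..., rngs) / eval_fn(..., rngs) with the refreshed keys ...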
-def eval_step(state, inputs, masks, labels, var_collect):
+def eval_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes loss and accuracy for a single batch."""
     def loss_fn(var_collect, disable_dropout=False):
-        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
+        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
         one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits
@@ -150,7 +153,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
     return loss, accuracy
-def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
+def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs):
     """Evaluation loop."""
     test_ds_size = len(test_ds["sentence"])
     num_steps = test_ds_size // batch_size
@@ -159,11 +162,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
     all_accuracy = []
     for batch_start in range(0, valid_size, batch_size):
+        # Split and reassign to 'rngs' to ensure unique rng for each step
+        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
         batch_end = batch_start + batch_size
         batch_inputs = test_ds["sentence"][batch_start:batch_end]
         batch_masks = test_ds["mask"][batch_start:batch_end]
         batch_labels = test_ds["label"][batch_start:batch_end]
-        loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
+        loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs)
         all_loss.append(loss)
         all_accuracy.append(accuracy)
@@ -223,7 +228,7 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
-    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
+    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
     func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
     assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
@@ -257,7 +262,7 @@ def train_and_evaluate(args):
     ), "Test batch size needs to be multiple of 32 for MXFP8"
     if args.use_fp8:
-        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
     else:
         fp8_recipe = None
@@ -275,7 +280,8 @@ def train_and_evaluate(args):
     rng = jax.random.PRNGKey(args.seed)
     rng, params_rng = jax.random.split(rng)
     rng, dropout_rng = jax.random.split(rng)
-    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
+    rng, sr_rng = jax.random.split(rng)
+    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
     input_shape = [args.batch_size, args.max_seq_len]
     mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
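The extra "sr_rng" stream threaded through init_rngs and the per-step rngs dicts is presumably the seed for stochastic rounding in the NVFP4 quantization path; the diff only shows the plumbing, so the consumer is not visible here. A small sketch of how such a named RNG stream flows through Flax init/apply, using a stand-in nn.Dense rather than the TE encoder:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    PARAMS_KEY, DROPOUT_KEY, SR_KEY = "params", "dropout", "sr_rng"

    rng = jax.random.PRNGKey(0)
    rng, params_rng = jax.random.split(rng)
    rng, dropout_rng = jax.random.split(rng)
    rng, sr_rng = jax.random.split(rng)

    model = nn.Dense(16)  # stand-in; the real scripts build the TE/Flax encoder
    x = jnp.ones((4, 8))
    variables = model.init(
        {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}, x
    )
    # At apply time only the non-parameter streams are passed; a module that calls
    # self.make_rng("sr_rng") (as a TE quantizer presumably would) draws from it.
    out = model.apply(variables, x, rngs={DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng})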
@@ -355,7 +361,14 @@ def train_and_evaluate(args):
         train_step, in_shardings=in_shardings, out_shardings=out_shardings
     )
-    in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
+    in_shardings = (
+        state_sharding,
+        inputs_sharding,
+        masks_sharding,
+        labels_sharding,
+        None,
+        None,
+    )
     out_shardings = (None, None)
     jit_eval_step = jax.jit(
         eval_step, in_shardings=in_shardings, out_shardings=out_shardings
@@ -367,22 +380,24 @@ def train_and_evaluate(args):
     if args.dry_run:
         labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
-        rngs = {DROPOUT_KEY: dropout_rng}
+        rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng_state}
         jit_train_step(state, inputs, masks, labels, var_collect, rngs)
         print("PASSED")
         return None
     for epoch in range(1, args.epochs + 1):
+        # Split and reassign to 'rng' to ensure unique rng for each step
         rng, input_rng = jax.random.split(rng)
         rng, dropout_rng = jax.random.split(rng)
-        rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
+        rng, sr_rng = jax.random.split(rng)
+        rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
         state, train_loss, train_accuracy, var_collect = train_epoch(
             state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
         )
         test_loss, test_accuracy = eval_model(
-            state, test_ds, args.test_batch_size, var_collect, jit_eval_step
+            state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs
        )
         print(
@@ -402,16 +417,16 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=128,
+        default=256,
         metavar="N",
-        help="input batch size for training (default: 128)",
+        help="input batch size for training (default: 256)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=128,
+        default=256,
         metavar="N",
-        help="input batch size for testing (default: 128)",
+        help="input batch size for testing (default: 256)",
     )
     parser.add_argument(
         "--max-seq-len",
@@ -466,8 +481,9 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
-    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
-    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
+    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
+    is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
     def setUp(self):
         """Run 5 epochs for testing"""
@@ -477,7 +493,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
@@ -485,7 +501,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.361 and actual[1] > 0.84
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
@@ -493,14 +509,22 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
+    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
+    def test_te_nvfp4(self):
+        """Test Transformer Engine with NVFP4"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "NVFP4BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.40 and actual[1] > 0.82
     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_with_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp(self):
@@ -509,7 +533,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_with_sp(self):
@@ -518,14 +542,23 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
+    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
+    def test_te_nvfp4_with_sp(self):
+        """Test Transformer Engine with NVFP4"""
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "NVFP4BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.40 and actual[1] > 0.82
     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
@@ -534,7 +567,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp_shardy(self):
@@ -544,24 +577,27 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.361 and actual[1] > 0.84
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
-    @unittest.skipIf(
-        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
-    )
     def test_te_mxfp8_shardy(self):
         """Test Transformer Engine with MXFP8"""
         self.args.enable_shardy = True
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
+    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
+    def test_te_nvfp4_shardy(self):
+        """Test Transformer Engine with NVFP4"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "NVFP4BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.40 and actual[1] > 0.82
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
-    @unittest.skipIf(
-        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
-    )
     def test_te_mxfp8_with_sp_shardy(self):
         """Test Transformer Engine with MXFP8 + SP"""
         self.args.enable_shardy = True
@@ -569,7 +605,17 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.39 and actual[1] > 0.83
+        assert actual[0] < 0.36 and actual[1] > 0.84
+    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
+    def test_te_nvfp4_with_sp_shardy(self):
+        """Test Transformer Engine with NVFP4"""
+        self.args.enable_shardy = True
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "NVFP4BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.40 and actual[1] > 0.82
 if __name__ == "__main__":
...
@@ -19,17 +19,18 @@ from flax.training import train_state
 from jax.experimental import mesh_utils
 from jax.sharding import PartitionSpec, NamedSharding
-from common import is_bf16_supported, get_fp8_recipe_from_name_string
+from common import is_bf16_supported, get_quantization_recipe_from_name_string
 import transformer_engine.jax as te
 import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
-from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
+from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
 DEVICE_DP_AXIS = "data"
 PARAMS_KEY = "params"
 PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
 DROPOUT_KEY = "dropout"
+SR_KEY = "sr_rng"
 INPUT_KEY = "input_rng"
@@ -97,6 +98,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
     epoch_accuracy = []
     for perm in perms:
+        # Split and reassign to 'rngs' to ensure unique rng for each step
+        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
         batch_inputs = train_ds["sentence"][perm, ...]
         batch_masks = train_ds["mask"][perm, ...]
         batch_labels = train_ds["label"][perm, ...]
@@ -111,11 +114,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
     return state, avg_loss, avg_accuracy, var_collect
-def eval_step(state, inputs, masks, labels, var_collect):
+def eval_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes loss and accuracy for a single batch."""
     def loss_fn(var_collect, disable_dropout=False):
-        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
+        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
         one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits
@@ -126,7 +129,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
     return loss, accuracy
-def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
+def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs):
     """Evaluation loop."""
     test_ds_size = len(test_ds["sentence"])
     num_steps = test_ds_size // batch_size
@@ -135,11 +138,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
     all_accuracy = []
     for batch_start in range(0, valid_size, batch_size):
+        # Split and reassign to 'rngs' to ensure unique rng for each step
+        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
         batch_end = batch_start + batch_size
         batch_inputs = test_ds["sentence"][batch_start:batch_end]
         batch_masks = test_ds["mask"][batch_start:batch_end]
         batch_labels = test_ds["label"][batch_start:batch_end]
-        loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
+        loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs)
         all_loss.append(loss)
         all_accuracy.append(accuracy)
@@ -199,7 +204,7 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
-    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
+    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
     func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
     assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
@@ -254,7 +259,7 @@ def train_and_evaluate(args):
     ), "Test batch size needs to be multiple of 32 for MXFP8"
     if args.use_fp8:
-        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
     else:
         fp8_recipe = None
@@ -270,6 +275,7 @@ def train_and_evaluate(args):
     rng = jax.random.PRNGKey(args.seed)
     rng, params_rng = jax.random.split(rng)
     rng, dropout_rng = jax.random.split(rng)
+    rng, sr_rng = jax.random.split(rng)
     init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
     input_shape = [args.batch_size, args.max_seq_len]
@@ -322,7 +328,14 @@ def train_and_evaluate(args):
         train_step, in_shardings=in_shardings, out_shardings=out_shardings
     )
-    in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
+    in_shardings = (
+        state_sharding,
+        inputs_sharding,
+        masks_sharding,
+        labels_sharding,
+        None,
+        None,
+    )
     out_shardings = (None, None)
     jit_eval_step = jax.jit(
         eval_step, in_shardings=in_shardings, out_shardings=out_shardings
@@ -334,22 +347,24 @@ def train_and_evaluate(args):
     if args.dry_run:
         labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
-        rngs = {DROPOUT_KEY: dropout_rng}
+        rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
         jit_train_step(state, inputs, masks, labels, var_collect, rngs)
         print("PASSED")
         return None
     for epoch in range(1, args.epochs + 1):
+        # Split and reassign to 'rng' to ensure unique rng for each step
         rng, input_rng = jax.random.split(rng)
         rng, dropout_rng = jax.random.split(rng)
-        rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
+        rng, sr_rng = jax.random.split(rng)
+        rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
         state, train_loss, train_accuracy, var_collect = train_epoch(
             state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
         )
         test_loss, test_accuracy = eval_model(
-            state, test_ds, args.test_batch_size, var_collect, jit_eval_step
+            state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs
         )
         print(
@@ -369,16 +384,16 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=256,
+        default=512,
         metavar="N",
-        help="input batch size for training (default: 256)",
+        help="input batch size for training (default: 512)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=256,
+        default=512,
         metavar="N",
-        help="input batch size for testing (default: 256)",
+        help="input batch size for testing (default: 512)",
     )
     parser.add_argument(
         "--max-seq-len",
@@ -430,8 +445,9 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
-    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
-    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
+    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
+    is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
     def setUp(self):
         """Run 5 epochs for testing"""
@@ -441,7 +457,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
+        assert actual[0] < 0.51 and actual[1] > 0.75
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
@@ -449,7 +465,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
+        assert actual[0] < 0.51 and actual[1] > 0.75
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8(self):
@@ -457,7 +473,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
+        assert actual[0] < 0.51 and actual[1] > 0.749
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
@@ -465,6 +481,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
+        assert actual[0] < 0.51 and actual[1] > 0.75
+    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
+    def test_te_nvfp4(self):
+        """Test Transformer Engine with NVFP4"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "NVFP4BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.52 and actual[1] > 0.74
     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
@@ -472,7 +496,7 @@ class TestEncoder(unittest.TestCase):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
+        assert actual[0] < 0.51 and actual[1] > 0.75
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
@@ -481,7 +505,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
+        assert actual[0] < 0.51 and actual[1] > 0.75
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8_shardy(self):
@@ -490,18 +514,24 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
+        assert actual[0] < 0.51 and actual[1] > 0.749
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
-    @unittest.skipIf(
-        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
-    )
     def test_te_mxfp8_shardy(self):
         """Test Transformer Engine with MXFP8"""
         self.args.enable_shardy = True
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.52 and actual[1] > 0.74
+        assert actual[0] < 0.51 and actual[1] > 0.75
+    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
+    def test_te_nvfp4_shardy(self):
+        """Test Transformer Engine with NVFP4"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "NVFP4BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.52 and actual[1] > 0.74
...
@@ -25,7 +25,8 @@ from common import (
     is_bf16_supported,
     is_fp8_supported,
     is_mxfp8_supported,
-    get_fp8_recipe_from_name_string,
+    is_nvfp4_supported,
+    get_quantization_recipe_from_name_string,
 )
 import transformer_engine.jax as te
 import transformer_engine.jax.cpp_extensions as tex
@@ -39,6 +40,7 @@ NAMED_BROADCAST_AXIS = "my_broadcast_axis"
 NAMED_TP_AXIS = "my_tp_axis"
 PARAMS_KEY = "params"
 PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
+SR_KEY = "sr_rng"
 DROPOUT_KEY = "dropout"
 INPUT_KEY = "input_rng"
@@ -175,6 +177,8 @@ def train_epoch(
     epoch_accuracy = []
     for perm in perms:
+        # Split and reassign to 'rngs' to ensure unique rng for each step
+        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
         batch_input = sentence[perm, ...]
         batch_mask = mask[perm, ...]
         batch_label = label[perm, ...]
@@ -200,11 +204,11 @@ def train_epoch(
     return state, avg_loss, avg_accuracy, var_collect
-def eval_step(state, inputs, masks, labels, var_collect):
+def eval_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes loss and accuracy for a single batch."""
     def loss_fn(var_collect, disable_dropout=False):
-        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
+        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
         one_hot = jax.nn.one_hot(labels, 2)
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits
@@ -216,7 +220,16 @@ def eval_step(state, inputs, masks, labels, var_collect):
 def eval_model(
-    state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec
+    state,
+    test_ds,
+    batch_size,
+    var_collect,
+    eval_fn,
+    mesh,
+    inputs_pspec,
+    masks_pspec,
+    labels_pspec,
+    rngs,
 ):
     """Evaluation loop."""
     global_input_shape, input_named_sharding, sentence = shard_array_wrapper(
@@ -233,7 +246,8 @@ def eval_model(
     all_accuracy = []
     for batch_input, batch_mask, batch_label in zip(sentence, mask, label):
+        # Split and reassign to 'rngs' to ensure unique rng for each step
+        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
         shard_input = jax.make_array_from_single_device_arrays(
             global_input_shape, input_named_sharding, [batch_input]
         )
@@ -244,7 +258,7 @@ def eval_model(
             global_label_shape, label_named_sharding, [batch_label]
         )
-        loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect)
+        loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect, rngs)
         all_loss.append(loss)
         all_accuracy.append(accuracy)
@@ -303,7 +317,7 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
-    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
+    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
     func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
     assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
@@ -372,7 +386,7 @@ def train_and_evaluate(args):
     ), "Test batch size needs to be multiple of 32 for MXFP8"
     if args.use_fp8:
-        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
     else:
         fp8_recipe = None
@@ -390,7 +404,8 @@ def train_and_evaluate(args):
     rng = jax.random.PRNGKey(args.seed)
     rng, params_rng = jax.random.split(rng)
     rng, dropout_rng = jax.random.split(rng)
-    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
+    rng, sr_rng = jax.random.split(rng)
+    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
     input_shape = [args.batch_size, args.max_seq_len]
     mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
@@ -444,7 +459,14 @@ def train_and_evaluate(args):
         train_step, in_shardings=in_shardings, out_shardings=out_shardings
     )
-    in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
+    in_shardings = (
+        state_sharding,
+        inputs_sharding,
+        masks_sharding,
+        labels_sharding,
+        None,
+        None,
+    )
     out_shardings = (None, None)
     jit_eval_step = jax.jit(
         eval_step, in_shardings=in_shardings, out_shardings=out_shardings
@@ -456,14 +478,16 @@ def train_and_evaluate(args):
     if args.dry_run:
         labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
-        rngs = {DROPOUT_KEY: dropout_rng}
+        rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng_state}
         jit_train_step(state, inputs, masks, labels, var_collect, rngs)
         print("PASSED")
     else:
         for epoch in range(1, args.epochs + 1):
+            # Split and reassign to 'rng' to ensure unique rng for each step
            rng, input_rng = jax.random.split(rng)
             rng, dropout_rng = jax.random.split(rng)
-            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
+            rng, sr_rng = jax.random.split(rng)
+            rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
             state, train_loss, train_accuracy, var_collect = train_epoch(
                 state,
@@ -488,6 +512,7 @@ def train_and_evaluate(args):
                 inputs_pspec,
                 masks_pspec,
                 labels_sharding.spec,
+                rngs,
             )
             if args.process_id == 0:
                 print(
@@ -508,16 +533,16 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=128,
+        default=256,
         metavar="N",
-        help="input batch size for training (default: 128)",
+        help="input batch size for training (default: 256)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=128,
+        default=256,
         metavar="N",
-        help="input batch size for testing (default: 128)",
+        help="input batch size for testing (default: 256)",
     )
     parser.add_argument(
         "--max-seq-len",
@@ -629,7 +654,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_current_scaling_fp8(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling")
-        assert result[0] < 0.43 and result[1] > 0.80
+        assert result[0] < 0.432 and result[1] > 0.80
     @unittest.skipIf(
         not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
@@ -639,6 +664,14 @@ class TestEncoder(unittest.TestCase):
         result = self.exec(True, "MXFP8BlockScaling")
         assert result[0] < 0.43 and result[1] > 0.80
+    @unittest.skipIf(
+        not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4"
+    )
+    def test_te_nvfp4(self):
+        """Test Transformer Engine with NVFP4"""
+        result = self.exec(True, "NVFP4BlockScaling")
+        assert result[0] < 0.451 and result[1] > 0.79
     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
@@ -659,19 +692,24 @@ class TestEncoder(unittest.TestCase):
     def test_te_current_scaling_fp8_shardy(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
-        assert result[0] < 0.43 and result[1] > 0.80
+        assert result[0] < 0.432 and result[1] > 0.80
     @unittest.skipIf(
         not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
     )
-    @unittest.skipIf(
-        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
-    )
     def test_te_mxfp8_shardy(self):
         """Test Transformer Engine with MXFP8"""
         result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
         assert result[0] < 0.43 and result[1] > 0.80
+    @unittest.skipIf(
+        not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4"
+    )
+    def test_te_nvfp4_shardy(self):
+        """Test Transformer Engine with NVFP4"""
+        result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
+        assert result[0] < 0.451 and result[1] > 0.79
 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))
@@ -16,14 +16,15 @@ from datasets import load_dataset
 from flax import linen as nn
 from flax.training import train_state
-from common import is_bf16_supported, get_fp8_recipe_from_name_string
+from common import is_bf16_supported, get_quantization_recipe_from_name_string
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
-from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
+from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
 PARAMS_KEY = "params"
 DROPOUT_KEY = "dropout"
+SR_KEY = "sr_rng"
 INPUT_KEY = "input_rng"
@@ -92,6 +93,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
     epoch_accuracy = []
     for perm in perms:
+        # Split and reassign to 'rngs' to ensure unique rng for each step
+        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
         batch_inputs = train_ds["sentence"][perm, ...]
         batch_masks = train_ds["mask"][perm, ...]
         batch_labels = train_ds["label"][perm, ...]
@@ -107,11 +110,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
 @jax.jit
-def eval_step(state, inputs, masks, labels, var_collect):
+def eval_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes loss and accuracy for a single batch."""
     def loss_fn(var_collect, disable_dropout=False):
-        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
+        logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
         one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits
@@ -122,7 +125,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
     return loss, accuracy
-def eval_model(state, test_ds, batch_size, var_collect):
+def eval_model(state, test_ds, batch_size, var_collect, rngs):
     """Evaluation loop."""
     test_ds_size = len(test_ds["sentence"])
     num_steps = test_ds_size // batch_size
@@ -131,11 +134,15 @@ def eval_model(state, test_ds, batch_size, var_collect):
     all_accuracy = []
     for batch_start in range(0, valid_size, batch_size):
+        # Split and reassign to 'rngs' to ensure unique rng for each step
+        rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
         batch_end = batch_start + batch_size
         batch_inputs = test_ds["sentence"][batch_start:batch_end]
         batch_masks = test_ds["mask"][batch_start:batch_end]
         batch_labels = test_ds["label"][batch_start:batch_end]
-        loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
+        loss, accuracy = eval_step(
+            state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
+        )
         all_loss.append(loss)
         all_accuracy.append(accuracy)
@@ -195,7 +202,7 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
-    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
+    rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
     func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
     assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
@@ -208,14 +215,15 @@ def train_and_evaluate(args):
     rng = jax.random.PRNGKey(args.seed)
     rng, params_rng = jax.random.split(rng)
     rng, dropout_rng = jax.random.split(rng)
-    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
+    rng, sr_rng = jax.random.split(rng)
+    init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
     input_shape = [args.batch_size, args.max_seq_len]
     mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
     label_shape = [args.batch_size]
     if args.use_fp8:
-        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+        fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
     else:
         fp8_recipe = None
@@ -238,21 +246,25 @@ def train_and_evaluate(args):
     if args.dry_run:
         labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
-        rngs = {DROPOUT_KEY: dropout_rng}
+        rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
         train_step(state, inputs, masks, labels, var_collect, rngs)
         print("PASSED")
         return None
     for epoch in range(1, args.epochs + 1):
+        # Split and reassign to 'rng' to ensure unique rng for each step
         rng, input_rng = jax.random.split(rng)
         rng, dropout_rng = jax.random.split(rng)
-        rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
+        rng, sr_rng = jax.random.split(rng)
+        rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
         state, train_loss, train_accuracy, var_collect = train_epoch(
             state, train_ds, args.batch_size, rngs, var_collect
         )
-        test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
+        test_loss, test_accuracy = eval_model(
+            state, test_ds, args.test_batch_size, var_collect, rngs
+        )
         print(
             f"Epoch: {epoch:>2} "
@@ -329,8 +341,9 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
-    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
-    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
+    is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
+    is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
     def setUp(self):
         """Run 3 epochs for testing"""
@@ -340,7 +353,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.45 and actual[1] > 0.79
+        assert actual[0] < 0.452 and actual[1] > 0.788
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
@@ -348,7 +361,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.79
+        assert actual[0] < 0.457 and actual[1] > 0.784
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8(self):
@@ -356,7 +369,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.79
+        assert actual[0] < 0.461 and actual[1] > 0.784
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
@@ -364,7 +377,15 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.79
+        assert actual[0] < 0.457 and actual[1] > 0.784
+    @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
+    def test_te_nvfp4(self):
+        """Test Transformer Engine with NVFP4"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "NVFP4BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.476 and actual[1] > 0.775
 if __name__ == "__main__":
...
...@@ -18,11 +18,11 @@ from flax.training import train_state ...@@ -18,11 +18,11 @@ from flax.training import train_state
import transformer_engine.jax as te import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DIR = str(Path(__file__).resolve().parents[1]) DIR = str(Path(__file__).resolve().parents[1])
sys.path.append(str(DIR)) sys.path.append(str(DIR))
from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string
IMAGE_H = 28 IMAGE_H = 28
IMAGE_W = 28 IMAGE_W = 28
...@@ -189,7 +189,7 @@ def train_and_evaluate(args): ...@@ -189,7 +189,7 @@ def train_and_evaluate(args):
label_shape = [args.batch_size] label_shape = [args.batch_size]
if args.use_fp8: if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else: else:
fp8_recipe = None fp8_recipe = None
...@@ -308,8 +308,8 @@ def mnist_parser(args): ...@@ -308,8 +308,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase): class TestMNIST(unittest.TestCase):
"""MNIST unittests""" """MNIST unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
import re
from typing import Callable, Sequence, Union, Optional from typing import Callable, Sequence, Union, Optional
import pytest import pytest
...@@ -17,7 +18,11 @@ from utils import ( ...@@ -17,7 +18,11 @@ from utils import (
) )
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode from transformer_engine.jax.quantize import (
is_fp8_available,
ScalingMode,
get_quantize_config_with_recipe,
)
from transformer_engine.jax import fp8_autocast from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp from transformer_engine.jax.layernorm_mlp import layernorm_mlp
...@@ -33,19 +38,20 @@ from transformer_engine.jax.sharding import ( ...@@ -33,19 +38,20 @@ from transformer_engine.jax.sharding import (
W_JOINED_AXES, W_JOINED_AXES,
) )
from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory from transformer_engine.jax.quantize import (
QuantizerFactory,
get_supported_quantization_recipes,
is_scaling_mode_supported,
)
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
SUPPORTED_RECIPES = [] SUPPORTED_RECIPES = get_supported_quantization_recipes()
if is_fp8_supported: SUPPORTED_RECIPES = [pytest.param(r, id=r.__class__.__name__) for r in SUPPORTED_RECIPES]
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in] INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in]
...@@ -141,6 +147,7 @@ class TestDistributedLayernormMLP: ...@@ -141,6 +147,7 @@ class TestDistributedLayernormMLP:
layernorm_type: str = "rmsnorm", layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
multi_gpus: bool = False, multi_gpus: bool = False,
quantization_recipe: recipe.Recipe = None,
) -> jnp.ndarray: ) -> jnp.ndarray:
if multi_gpus: if multi_gpus:
...@@ -154,7 +161,9 @@ class TestDistributedLayernormMLP: ...@@ -154,7 +161,9 @@ class TestDistributedLayernormMLP:
dot_1_input_axes = dot_2_input_axes = None dot_1_input_axes = dot_2_input_axes = None
kernel_1_axes = kernel_2_axes = None kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2, fp8_recipe=quantization_recipe
)
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2 # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean( return jnp.mean(
...@@ -182,7 +191,7 @@ class TestDistributedLayernormMLP: ...@@ -182,7 +191,7 @@ class TestDistributedLayernormMLP:
use_bias, use_bias,
input_shape, input_shape,
dtype, dtype,
fp8_recipe, quantization_recipe,
use_shardy, use_shardy,
with_jax_gemm, with_jax_gemm,
): ):
...@@ -202,7 +211,9 @@ class TestDistributedLayernormMLP: ...@@ -202,7 +211,9 @@ class TestDistributedLayernormMLP:
# Single GPU # Single GPU
with fp8_autocast( with fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource() enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
mesh_resource=MeshResource(),
): ):
single_jitter = jax.jit( single_jitter = jax.jit(
value_and_grad_func, value_and_grad_func,
...@@ -214,7 +225,9 @@ class TestDistributedLayernormMLP: ...@@ -214,7 +225,9 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast( with mesh, fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
mesh_resource=mesh_resource,
): ):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
...@@ -254,9 +267,15 @@ class TestDistributedLayernormMLP: ...@@ -254,9 +267,15 @@ class TestDistributedLayernormMLP:
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn fwd_test_type = bwd_test_type = dtype
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 if quantization_recipe is not None:
quantize_config = get_quantize_config_with_recipe(quantization_recipe)
fwd_test_type = quantize_config.FWD_DTYPE
bwd_test_type = quantize_config.BWD_DTYPE
if fwd_test_type == jnp.float16 and use_bias:
assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5)
else:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)): for i in range(len(inputs)):
...@@ -278,13 +297,12 @@ class TestDistributedLayernormMLP: ...@@ -278,13 +297,12 @@ class TestDistributedLayernormMLP:
err_msg=f"multi_grads[{i}] is not close", err_msg=f"multi_grads[{i}] is not close",
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad( def test_layernorm_mlp_grad(
self, self,
...@@ -293,27 +311,28 @@ class TestDistributedLayernormMLP: ...@@ -293,27 +311,28 @@ class TestDistributedLayernormMLP:
use_bias, use_bias,
input_shape, input_shape,
dtype, dtype,
fp8_recipe, quantization_recipe,
with_jax_gemm, with_jax_gemm,
): ):
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp_grad( self._test_layernorm_mlp_grad(
mesh_config, mesh_config,
activation_type, activation_type,
use_bias, use_bias,
input_shape, input_shape,
dtype, dtype,
fp8_recipe, quantization_recipe,
use_shardy=False, use_shardy=False,
with_jax_gemm=with_jax_gemm, with_jax_gemm=with_jax_gemm,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy( def test_layernorm_mlp_grad_shardy(
self, self,
...@@ -322,18 +341,18 @@ class TestDistributedLayernormMLP: ...@@ -322,18 +341,18 @@ class TestDistributedLayernormMLP:
use_bias, use_bias,
input_shape, input_shape,
dtype, dtype,
fp8_recipe, quantization_recipe,
with_jax_gemm, with_jax_gemm,
): ):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp_grad( self._test_layernorm_mlp_grad(
mesh_config, mesh_config,
activation_type, activation_type,
use_bias, use_bias,
input_shape, input_shape,
dtype, dtype,
fp8_recipe=fp8_recipe, quantization_recipe=quantization_recipe,
use_shardy=True, use_shardy=True,
with_jax_gemm=with_jax_gemm, with_jax_gemm=with_jax_gemm,
) )
...@@ -346,7 +365,7 @@ class TestDistributedLayernormMLP: ...@@ -346,7 +365,7 @@ class TestDistributedLayernormMLP:
input_shape, input_shape,
dtype, dtype,
use_fp8, use_fp8,
fp8_recipe, quantization_recipe,
use_shardy, use_shardy,
with_jax_gemm, with_jax_gemm,
): ):
...@@ -355,14 +374,16 @@ class TestDistributedLayernormMLP: ...@@ -355,14 +374,16 @@ class TestDistributedLayernormMLP:
layernorm_type = "rmsnorm" layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0) rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2) subkeys = jax.random.split(rng, 3)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {"params": subkeys[1]} init_rngs = {"params": subkeys[1], "sr_rng": subkeys[2]}
with use_jax_gemm(enabled=with_jax_gemm): with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs # Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): with fp8_autocast(
enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=MeshResource()
):
ln_mlp_single = LayerNormMLP( ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
...@@ -371,7 +392,7 @@ class TestDistributedLayernormMLP: ...@@ -371,7 +392,7 @@ class TestDistributedLayernormMLP:
) )
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply( mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True params_single, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
) )
# Multi GPUs # Multi GPUs
...@@ -379,7 +400,7 @@ class TestDistributedLayernormMLP: ...@@ -379,7 +400,7 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast( with mesh, fp8_autocast(
enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=mesh_resource
): ):
ln_mlp_sharded = LayerNormMLP( ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
...@@ -399,7 +420,7 @@ class TestDistributedLayernormMLP: ...@@ -399,7 +420,7 @@ class TestDistributedLayernormMLP:
) )
params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True params_sharded, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
) )
# Make sure params values are the same # Make sure params values are the same
...@@ -411,8 +432,8 @@ class TestDistributedLayernormMLP: ...@@ -411,8 +432,8 @@ class TestDistributedLayernormMLP:
rtol = None rtol = None
l40_tolerance_update = ( l40_tolerance_update = (
get_min_device_compute_capability() == 89 get_min_device_compute_capability() == 89
and fp8_recipe == recipe.DelayedScaling()
and use_fp8 and use_fp8
and quantization_recipe.delayed()
and dtype == jnp.float16 and dtype == jnp.float16
and activation_type == ("gelu",) and activation_type == ("gelu",)
) )
...@@ -430,8 +451,8 @@ class TestDistributedLayernormMLP: ...@@ -430,8 +451,8 @@ class TestDistributedLayernormMLP:
# within tolerance to the float32 ground truth. # within tolerance to the float32 ground truth.
jax_triton_gemm_precision_tolerance_update = ( jax_triton_gemm_precision_tolerance_update = (
with_jax_gemm with_jax_gemm
and fp8_recipe is not None and quantization_recipe is not None
and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()) and (quantization_recipe.delayed() or quantization_recipe.float8_current_scaling())
and dtype in (jnp.bfloat16, jnp.float16) and dtype in (jnp.bfloat16, jnp.float16)
and activation_type == ("gelu", "linear"), and activation_type == ("gelu", "linear"),
) )
...@@ -457,22 +478,30 @@ class TestDistributedLayernormMLP: ...@@ -457,22 +478,30 @@ class TestDistributedLayernormMLP:
input_shape, input_shape,
dtype, dtype,
use_fp8=False, use_fp8=False,
fp8_recipe=None, quantization_recipe=None,
use_shardy=False, use_shardy=False,
with_jax_gemm=with_jax_gemm, with_jax_gemm=with_jax_gemm,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8( def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
quantization_recipe,
with_jax_gemm,
): ):
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp( self._test_layernorm_mlp(
mesh_config, mesh_config,
activation_type, activation_type,
...@@ -480,7 +509,7 @@ class TestDistributedLayernormMLP: ...@@ -480,7 +509,7 @@ class TestDistributedLayernormMLP:
input_shape, input_shape,
dtype, dtype,
use_fp8=True, use_fp8=True,
fp8_recipe=fp8_recipe, quantization_recipe=quantization_recipe,
use_shardy=False, use_shardy=False,
with_jax_gemm=with_jax_gemm, with_jax_gemm=with_jax_gemm,
) )
...@@ -501,24 +530,30 @@ class TestDistributedLayernormMLP: ...@@ -501,24 +530,30 @@ class TestDistributedLayernormMLP:
input_shape, input_shape,
dtype, dtype,
use_fp8=False, use_fp8=False,
fp8_recipe=None, quantization_recipe=None,
use_shardy=True, use_shardy=True,
with_jax_gemm=with_jax_gemm, with_jax_gemm=with_jax_gemm,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8_shardy( def test_layernorm_mlp_layer_fp8_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
quantization_recipe,
with_jax_gemm,
): ):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp( self._test_layernorm_mlp(
mesh_config, mesh_config,
activation_type, activation_type,
...@@ -526,7 +561,7 @@ class TestDistributedLayernormMLP: ...@@ -526,7 +561,7 @@ class TestDistributedLayernormMLP:
input_shape, input_shape,
dtype, dtype,
use_fp8=True, use_fp8=True,
fp8_recipe=fp8_recipe, quantization_recipe=quantization_recipe,
use_shardy=True, use_shardy=True,
with_jax_gemm=with_jax_gemm, with_jax_gemm=with_jax_gemm,
) )
...@@ -10,20 +10,27 @@ import jax.numpy as jnp ...@@ -10,20 +10,27 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling from transformer_engine.common.recipe import (
DelayedScaling,
MXFP8BlockScaling,
Float8CurrentScaling,
NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
get_quantize_config, get_quantize_config,
is_fp8_available, is_scaling_mode_supported,
ScalingMode, ScalingMode,
update_collections, update_collections,
TensorSource, TensorSource,
) )
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
class TestHelper(unittest.TestCase): class TestHelper(unittest.TestCase):
...@@ -52,14 +59,16 @@ class TestFP8Functions(unittest.TestCase): ...@@ -52,14 +59,16 @@ class TestFP8Functions(unittest.TestCase):
def _check_default_state(self): def _check_default_state(self):
self.assertFalse(get_quantize_config().is_fp8_enabled()) self.assertFalse(get_quantize_config().is_fp8_enabled())
def _compare_delay_scaling(self, ref, test): def _compare_delay_scaling(self, test):
self.assertTrue(ref.margin == test.margin) self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertTrue(ref.fp8_format == test.fp8_format) self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertTrue(ref.amax_history_len == test.amax_history_len) self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo)
def _compare_current_scaling(self, test): def _compare_current_scaling(self, test):
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
for tensor_source in TensorSource: for tensor_source in TensorSource:
self.assertEqual( self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), get_quantize_config().get_scaling_mode(tensor_source),
...@@ -67,13 +76,26 @@ class TestFP8Functions(unittest.TestCase): ...@@ -67,13 +76,26 @@ class TestFP8Functions(unittest.TestCase):
) )
def _compare_mxfp8_scaling(self, test): def _compare_mxfp8_scaling(self, test):
self.assertEqual(get_quantize_config().MARGIN, test.margin) self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
for tensor_source in TensorSource: for tensor_source in TensorSource:
self.assertEqual( self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
) )
def _compare_nvfp4_scaling(self, test):
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1])
for tensor_source in TensorSource:
target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING
if tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING
)
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
)
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_delayed_scaling(self): def test_fp8_autocast_delayed_scaling(self):
self._check_default_state() self._check_default_state()
...@@ -86,14 +108,14 @@ class TestFP8Functions(unittest.TestCase): ...@@ -86,14 +108,14 @@ class TestFP8Functions(unittest.TestCase):
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds) self._compare_delay_scaling(ds)
self._check_default_state() self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds) self._compare_delay_scaling(ds)
self._check_default_state() self._check_default_state()
...@@ -133,16 +155,27 @@ class TestFP8Functions(unittest.TestCase): ...@@ -133,16 +155,27 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state() self._check_default_state()
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) bs = MXFP8BlockScaling()
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs) self._compare_mxfp8_scaling(bs)
self._check_default_state() self._check_default_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
def test_fp8_autocast_nvfp4_block_scaling(self):
self._check_default_state()
with fp8_autocast(
enabled=False, fp8_recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()
):
self._check_default_state()
self._check_default_state()
bs = NVFP4BlockScaling()
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled()) self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs) self._compare_nvfp4_scaling(bs)
self._check_default_state() self._check_default_state()
...@@ -1544,6 +1544,12 @@ def dtype_tols( ...@@ -1544,6 +1544,12 @@ def dtype_tols(
rtol = eps_relaxed rtol = eps_relaxed
if atol is None: if atol is None:
atol = max(ulp, eps_relaxed) atol = max(ulp, eps_relaxed)
# Manually set tols for nvfp4
if dtype == jnp.float4_e2m1fn:
atol = 0.05
rtol = 0.1
return {"rtol": rtol, "atol": atol} return {"rtol": rtol, "atol": atol}
......
...@@ -34,7 +34,7 @@ load_framework_extension("jax") ...@@ -34,7 +34,7 @@ load_framework_extension("jax")
from . import flax from . import flax
from . import quantize from . import quantize
from .quantize import fp8_autocast, update_collections, get_delayed_scaling from .quantize import fp8_autocast, update_collections
from .quantize import NVTE_FP8_COLLECTION_NAME from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource from .sharding import MeshResource
...@@ -47,7 +47,6 @@ __all__ = [ ...@@ -47,7 +47,6 @@ __all__ = [
"NVTE_FP8_COLLECTION_NAME", "NVTE_FP8_COLLECTION_NAME",
"fp8_autocast", "fp8_autocast",
"update_collections", "update_collections",
"get_delayed_scaling",
"MeshResource", "MeshResource",
"flax", "flax",
"quantize", "quantize",
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Python interface for c++ extensions""" """Python interface for c++ extensions"""
from .activation import * from .activation import *
from .amax import *
from .attention import * from .attention import *
from .normalization import * from .normalization import *
from .quantization import * from .quantization import *
......
...@@ -1314,7 +1314,10 @@ def act_lu( ...@@ -1314,7 +1314,10 @@ def act_lu(
) )
return out return out
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if (
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
or quantizer.scaling_mode.is_nvfp4_scaling
):
# Current scaling does not support fused operations. Perform the activation in higher precision then quantize after. # Current scaling does not support fused operations. Perform the activation in higher precision then quantize after.
out = act_lu( out = act_lu(
x=x, x=x,
...@@ -1488,7 +1491,10 @@ def quantize_dact_dbias( ...@@ -1488,7 +1491,10 @@ def quantize_dact_dbias(
if war_output is not None: if war_output is not None:
return war_output return war_output
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if (
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
or quantizer.scaling_mode.is_nvfp4_scaling
):
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after. # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = dact_lu( out = dact_lu(
dz=dz, dz=dz,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for amax calculation"""
from enum import Enum
import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
from .base import BasePrimitive, register_primitive
from .misc import (
get_padded_spec,
NamedSharding,
)
from ..sharding import (
global_mesh_resource,
lax_paral_op,
)
from ..quantize import (
get_wgrad_sign_vector,
get_sign_from_vector,
)
__all__ = ["AmaxScope", "calculate_amax", "calculate_post_rht_amax"]
class AmaxScope(Enum):
"""
Amax Scope Enum
"""
LOCAL = 1
TPSP = 2
FSDP = 3
def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh):
"""Reduce the amax based on its scope"""
gmesh = global_mesh_resource()
sequence_dim = 0 if transpose_batch_sequence else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource:
return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if self is AmaxScope.FSDP:
return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
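# NOTE (reviewer sketch, not part of this PR): the TPSP/FSDP all-reduce above is valid
# because amax is a max-reduction, so the max of per-shard amaxes equals the global amax.
# Minimal single-process illustration, assuming plain JAX (no mesh or TE kernels involved):
def _amax_reduction_sketch():
    x = jnp.arange(-8.0, 8.0).reshape(4, 4)  # pretend each row lives on a different device
    per_shard_amax = jnp.amax(jnp.abs(x), axis=1)  # the LOCAL amax computed on each shard
    # a pmax across shards reduces to a plain max here and recovers the global amax
    assert jnp.amax(per_shard_amax) == jnp.amax(jnp.abs(x))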
class AmaxCalculationPrimitive(BasePrimitive):
"""
Amax Calculation Primitive with custom_partitioning
"""
name = "jax_local_amax"
multiple_results = False
impl_static_args = (
1,
2,
) # amax_scope, transpose_batch_sequence
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
*,
amax_scope,
transpose_batch_sequence,
):
"""
amax calculation abstract
"""
del amax_scope, transpose_batch_sequence
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
return out_aval
@staticmethod
def impl(
x,
amax_scope,
transpose_batch_sequence,
):
"""
amax calculation implementation
"""
del amax_scope, transpose_batch_sequence
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation infer_sharding_from_operands
"""
del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculationPrimitive.out_sharding",
)
return amax_sharding
@staticmethod
def partition(
amax_scope,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation partition
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculation.amax_sharding",
)
def sharded_impl(x):
amax = AmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
)
return amax
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
"""
amax calculation shardy_sharding_rule
"""
del amax_scope, transpose_batch_sequence, mesh, result_types
prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",)
return SdyShardingRule((input_spec,), (output_spec,))
register_primitive(AmaxCalculationPrimitive, outer_only=True)
class RHTAmaxCalculationPrimitive(BasePrimitive):
"""
Amax Calculation Primitive with custom_partitioning for calculating regular and post-Random Hadamard Transform (RHT) amax using TE's fused kernels.
"""
name = "te_rht_amax_ffi"
multiple_results = True
impl_static_args = (
1, # amax_scope
2, # transpose_batch_sequence
3, # rht_matrix_random_sign_mask_t
4, # produce_regular_amax
5, # flatten_axis
)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
*,
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
):
"""
amax calculation abstract
"""
del (
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
)
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.bfloat16], f"RHT requires input to be bfloat16, but got {dtype}"
amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
post_rht_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
return amax_aval, post_rht_amax_aval
@staticmethod
def lowering(
ctx,
x,
*,
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
):
"""
te_rht_amax_ffi lowering rules
"""
del amax_scope, transpose_batch_sequence
(x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
flatten_axis = flatten_axis if flatten_axis >= 0 else flatten_axis + len(x_aval.shape)
assert 0 < flatten_axis < len(x_aval.shape), "Flatten axis out of bounds!"
return ffi.ffi_lowering(
RHTAmaxCalculationPrimitive.name,
)(
ctx,
x,
rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
produce_regular_amax=produce_regular_amax,
flatten_axis=flatten_axis,
)
@staticmethod
def impl(
x,
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
):
"""
amax calculation implementation
"""
assert RHTAmaxCalculationPrimitive.inner_primitive is not None
(
amax,
post_rht_amax,
) = RHTAmaxCalculationPrimitive.inner_primitive.bind(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
produce_regular_amax=produce_regular_amax,
flatten_axis=flatten_axis,
)
return amax, post_rht_amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation infer_sharding_from_operands
"""
del (
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
arg_infos,
result_infos,
) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="RHTAmaxCalculationPrimitive.out_sharding",
)
return amax_sharding, amax_sharding
@staticmethod
def partition(
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation partition
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="RHTAmaxCalculationPrimitive.amax_sharding",
)
out_shardings = (amax_sharding, amax_sharding)
def sharded_impl(x):
amax, post_rht_amax = RHTAmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
produce_regular_amax=produce_regular_amax,
flatten_axis=flatten_axis,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
)
post_rht_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
post_rht_amax, x_spec, transpose_batch_sequence, mesh
)
return amax, post_rht_amax
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
mesh,
value_types,
result_types,
):
"""
amax calculation shardy_sharding_rule
"""
del (
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
mesh,
result_types,
)
prefix = "RHTAmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_amax_spec = (f"{prefix}_amax",)
output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",)
return SdyShardingRule((input_spec,), (output_amax_spec, output_post_rht_amax_spec))
register_primitive(RHTAmaxCalculationPrimitive)
def calculate_amax(x: jnp.ndarray, amax_scope: AmaxScope, transpose_batch_sequence: bool):
"""
Compute the maximum absolute value (amax) of the input tensor.
"""
assert AmaxCalculationPrimitive.outer_primitive is not None
return AmaxCalculationPrimitive.outer_primitive.bind(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
def calculate_post_rht_amax(
x: jnp.ndarray,
amax_scope: AmaxScope,
transpose_batch_sequence: bool,
produce_regular_amax: bool,
flatten_axis: int,
):
"""Compute the post-Random Hadamard Transform (RHT) amax of the input tensor, and optionally the regular amax.
Args:
x: Input tensor.
amax_scope: The scope for amax reduction (local, TPSP, or FSDP).
transpose_batch_sequence: Whether the input tensor has its batch and sequence dimensions transposed.
produce_regular_amax: Whether to compute and return the regular amax alongside the post-RHT amax.
flatten_axis: The axis at which to flatten the input tensor before applying RHT.
Returns:
A tuple containing:
- The regular amax if `produce_regular_amax` is True, otherwise None.
- The post-RHT amax.
"""
amax, post_rht_amax = RHTAmaxCalculationPrimitive.outer_primitive.bind(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
rht_matrix_random_sign_mask_t=get_sign_from_vector(get_wgrad_sign_vector()),
produce_regular_amax=produce_regular_amax,
flatten_axis=flatten_axis,
)
if produce_regular_amax:
return amax, post_rht_amax
return None, post_rht_amax
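# NOTE (reviewer sketch, not part of this PR): hypothetical usage of the helpers above,
# assuming a TE build that registers the te_rht_amax_ffi kernel on a supported GPU and that
# the module is importable as transformer_engine.jax.cpp_extensions.amax:
#
#     from transformer_engine.jax.cpp_extensions.amax import (
#         AmaxScope, calculate_amax, calculate_post_rht_amax,
#     )
#
#     x = jnp.ones((8, 128, 256), dtype=jnp.bfloat16)
#     amax = calculate_amax(x, amax_scope=AmaxScope.LOCAL, transpose_batch_sequence=False)
#     reg_amax, rht_amax = calculate_post_rht_amax(
#         x, amax_scope=AmaxScope.LOCAL, transpose_batch_sequence=False,
#         produce_regular_amax=True, flatten_axis=2,
#     )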
...@@ -32,6 +32,7 @@ from ..quantize import ( ...@@ -32,6 +32,7 @@ from ..quantize import (
AbstractBaseTensor, AbstractBaseTensor,
NoScaleTensor, NoScaleTensor,
ScaledTensor, ScaledTensor,
ScaledTensor1x,
ScaledTensor2x, ScaledTensor2x,
GroupedScaledTensor1x, GroupedScaledTensor1x,
ScalingMode, ScalingMode,
...@@ -43,6 +44,7 @@ from ..quantize import ( ...@@ -43,6 +44,7 @@ from ..quantize import (
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
should_use_rht,
) )
from .misc import get_padded_spec, is_all_reduce_in_float32 from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import ( from ..sharding import (
...@@ -138,6 +140,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ...@@ -138,6 +140,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
need_lhs_colwise = lhs_is_transposed and ( need_lhs_colwise = lhs_is_transposed and (
lhs_quantizer.scaling_mode.is_1d_block_scaling() lhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported() or not is_fp8_gemm_with_all_layouts_supported()
or lhs_quantizer.scaling_mode.is_nvfp4_scaling
) )
flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims)
lhs_q = lhs_quantizer.quantize( lhs_q = lhs_quantizer.quantize(
...@@ -153,6 +156,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ...@@ -153,6 +156,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
need_rhs_colwise = not rhs_is_transposed and ( need_rhs_colwise = not rhs_is_transposed and (
rhs_quantizer.scaling_mode.is_1d_block_scaling() rhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported() or not is_fp8_gemm_with_all_layouts_supported()
or rhs_quantizer.scaling_mode.is_nvfp4_scaling
) )
flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1
rhs_q = rhs_quantizer.quantize( rhs_q = rhs_quantizer.quantize(
...@@ -165,9 +169,27 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ...@@ -165,9 +169,27 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x)
def uses_rht(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and should_use_rht(
q.scaling_mode, is_colwise=q.is_colwise
)
# TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class
assert uses_rht(lhs_q) == uses_rht(rhs_q), (
"With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise"
" quantized as well. This is to ensure the RHT is applied to both and will cancel out in"
" the GEMM."
)
return lhs_q, rhs_q return lhs_q, rhs_q
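# NOTE (reviewer sketch, not part of this PR): why the RHT must be applied to both operands
# or neither. The RHT is (up to the random sign flip, which is also orthogonal) an orthonormal
# Hadamard transform H with H @ H.T == I, so transforming both GEMM operands cancels out in
# the contraction:
def _rht_cancellation_sketch():
    from functools import reduce
    h2 = jnp.array([[1.0, 1.0], [1.0, -1.0]])
    h = reduce(jnp.kron, [h2] * 4) / jnp.sqrt(16.0)  # 16x16 orthonormal Hadamard matrix
    x = jnp.arange(16.0)
    w = jnp.linspace(-1.0, 1.0, 16)
    # (x @ H) . (w @ H) == x . w, so a transform applied to both sides is invisible to the GEMM
    assert jnp.allclose(jnp.dot(x @ h, w @ h), jnp.dot(x, w), atol=1e-4)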
def _get_nvfp4_tensor_scale_inv(amax):
DATA_DTYPE_MAX = jnp.finfo(jnp.float4_e2m1fn.dtype).max.astype(jnp.float32)
SCALE_DTYPE_MAX = jnp.finfo(jnp.float8_e4m3fn.dtype).max.astype(jnp.float32)
return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)
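# NOTE (reviewer sketch, not part of this PR): worked example of the two-level NVFP4 scale.
# FP4 E2M1 has max magnitude 6.0 and the FP8 E4M3 block scales max out at 448.0, so the
# per-tensor scale_inv above is amax / (6.0 * 448.0); the GEMM alpha used further below is the
# product of the LHS and RHS per-tensor scale_invs.
def _nvfp4_scale_inv_sketch():
    amax = jnp.float32(84.0)
    scale_inv = amax / (6.0 * 448.0)  # == _get_nvfp4_tensor_scale_inv(amax) == 0.03125
    alpha = scale_inv * scale_inv     # per-tensor rescale applied as the GEMM alpha
    return scale_inv, alpha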
def collective_gemm_bootstrap( def collective_gemm_bootstrap(
num_total_devices, num_total_devices,
num_devices_per_process, num_devices_per_process,
...@@ -345,7 +367,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -345,7 +367,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi" name = "te_gemm_ffi"
multiple_results = True multiple_results = True
impl_static_args = 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -357,6 +379,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -357,6 +379,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv, rhs_scale_inv,
bias, bias,
gelu_input, gelu_input,
alpha,
beta,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
scaling_mode, scaling_mode,
...@@ -404,7 +428,9 @@ class GemmPrimitive(BasePrimitive): ...@@ -404,7 +428,9 @@ class GemmPrimitive(BasePrimitive):
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
if scaling_mode != ScalingMode.NO_SCALING: if scaling_mode != ScalingMode.NO_SCALING:
assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( assert scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes(
lhs.dtype, rhs.dtype
), (
"cuBLAS GEMM quantized operands have incompatible data types: " "cuBLAS GEMM quantized operands have incompatible data types: "
f"{lhs.dtype} x {rhs.dtype}." f"{lhs.dtype} x {rhs.dtype}."
) )
...@@ -484,6 +510,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -484,6 +510,8 @@ class GemmPrimitive(BasePrimitive):
f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." f"expected {pre_gelu_dtype} but found {gelu_input.dtype}."
) )
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
assert alpha.size == 1 and alpha.dtype == jnp.float32
assert beta.size == 1 and beta.dtype == jnp.float32
# Declare cuBLAS workspace # Declare cuBLAS workspace
workspace_size = get_cublas_workspace_size_bytes() workspace_size = get_cublas_workspace_size_bytes()
...@@ -510,6 +538,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -510,6 +538,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv, rhs_scale_inv,
bias, bias,
gelu_input, gelu_input,
alpha,
beta,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
scaling_mode, scaling_mode,
...@@ -530,7 +560,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -530,7 +560,7 @@ class GemmPrimitive(BasePrimitive):
(lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
) )
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
kwargs = { kwargs = {
"scaling_mode": int(scaling_mode.value), "scaling_mode": int(scaling_mode.value),
"lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
...@@ -563,6 +593,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -563,6 +593,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv, rhs_scale_inv,
bias, bias,
gelu_input, gelu_input,
alpha,
beta,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
scaling_mode, scaling_mode,
...@@ -626,6 +658,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -626,6 +658,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv, rhs_scale_inv,
bias, bias,
gelu_input, gelu_input,
alpha,
beta,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
...@@ -675,6 +709,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -675,6 +709,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv, rhs_scale_inv,
bias, bias,
gelu_input, gelu_input,
alpha,
beta,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
scaling_mode, scaling_mode,
...@@ -694,6 +730,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -694,6 +730,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv, rhs_scale_inv,
bias, bias,
gelu_input, gelu_input,
alpha,
beta,
out_dtype, out_dtype,
contracting_dims, contracting_dims,
scaling_mode, scaling_mode,
...@@ -1001,6 +1039,9 @@ class GemmPrimitive(BasePrimitive): ...@@ -1001,6 +1039,9 @@ class GemmPrimitive(BasePrimitive):
gelu_input_specs = (None,) gelu_input_specs = (None,)
arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),)
# Alpha, beta
arg_shardings += (none_sharding, none_sharding)
# Assemble output shardings # Assemble output shardings
out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))]
...@@ -1014,7 +1055,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -1014,7 +1055,7 @@ class GemmPrimitive(BasePrimitive):
pre_gelu_specs = (None,) pre_gelu_specs = (None,)
out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)))
def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta):
# We should not fuse bias in the output reduction case # We should not fuse bias in the output reduction case
sharded_fuse_bias = fuse_bias and reduce_spec is None sharded_fuse_bias = fuse_bias and reduce_spec is None
outputs = GemmPrimitive.impl( outputs = GemmPrimitive.impl(
...@@ -1024,6 +1065,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -1024,6 +1065,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv, rhs_scale_inv,
bias, bias,
gelu_input, gelu_input,
alpha,
beta,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
...@@ -1114,8 +1157,10 @@ class GemmPrimitive(BasePrimitive): ...@@ -1114,8 +1157,10 @@ class GemmPrimitive(BasePrimitive):
rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
out_spec = (*lhs_non_cspec, *rhs_non_cspec) out_spec = (*lhs_non_cspec, *rhs_non_cspec)
bias_spec = rhs_non_cspec if fuse_bias else ("…4",) bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
dbias_spec = bias_spec if grad else ("…5") gelu_spec = out_spec if fuse_gelu else ("…5",)
gelu_spec = out_spec if fuse_gelu else ("…6",) alpha_spec = ("_6",)
beta_spec = ("_7",)
dbias_spec = bias_spec if grad else ("…8")
return SdyShardingRule( return SdyShardingRule(
operand_mappings=( operand_mappings=(
...@@ -1125,6 +1170,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -1125,6 +1170,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_specs, rhs_scale_specs,
bias_spec, bias_spec,
gelu_spec, gelu_spec,
alpha_spec,
beta_spec,
), ),
result_mappings=( result_mappings=(
out_spec, out_spec,
...@@ -1178,6 +1225,7 @@ def _te_gemm( ...@@ -1178,6 +1225,7 @@ def _te_gemm(
# Quantize operands (if necessary) # Quantize operands (if necessary)
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
lhs_amax = rhs_amax = None
# Extract GEMM custom op inputs from quantized operands # Extract GEMM custom op inputs from quantized operands
if isinstance(lhs_q, ScaledTensor): if isinstance(lhs_q, ScaledTensor):
assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, (
...@@ -1192,6 +1240,7 @@ def _te_gemm( ...@@ -1192,6 +1240,7 @@ def _te_gemm(
lhs_scale_inv = lhs_q.scale_inv lhs_scale_inv = lhs_q.scale_inv
if lhs_q.data_layout == "T": if lhs_q.data_layout == "T":
lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
lhs_amax = lhs_q.amax
if isinstance(rhs_q, ScaledTensor): if isinstance(rhs_q, ScaledTensor):
assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
...@@ -1201,7 +1250,11 @@ def _te_gemm( ...@@ -1201,7 +1250,11 @@ def _te_gemm(
if isinstance(rhs_q, ScaledTensor2x): if isinstance(rhs_q, ScaledTensor2x):
# Choose the quantization of the contracting dimension(s) # Choose the quantization of the contracting dimension(s)
rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor()
assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( assert (
rhs_q.scaling_mode == lhs_q.scaling_mode
or rhs_q.scaling_mode.is_nvfp4_scaling
and lhs_q.scaling_mode.is_nvfp4_scaling
), (
"cuBLAS GEMM quantized operands have mismatched scaling types, " "cuBLAS GEMM quantized operands have mismatched scaling types, "
f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}."
) )
...@@ -1209,6 +1262,15 @@ def _te_gemm( ...@@ -1209,6 +1262,15 @@ def _te_gemm(
rhs_scale_inv = rhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv
if rhs_q.data_layout == "T": if rhs_q.data_layout == "T":
rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
rhs_amax = rhs_q.amax
alpha = jnp.ones((1,), jnp.float32)
beta = jnp.zeros((1,), jnp.float32)
if scaling_mode.is_nvfp4_scaling:
assert lhs_amax is not None and rhs_amax is not None
lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax)
rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax)
alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv
# Dummy empties for bias and gelu # Dummy empties for bias and gelu
out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
...@@ -1224,6 +1286,8 @@ def _te_gemm( ...@@ -1224,6 +1286,8 @@ def _te_gemm(
rhs_scale_inv, rhs_scale_inv,
bias, bias,
gelu_input, gelu_input,
alpha,
beta,
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=(lhs_cdims, rhs_cdims), contracting_dims=(lhs_cdims, rhs_cdims),
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
...@@ -1514,15 +1578,17 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): ...@@ -1514,15 +1578,17 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
@partial(jax.jit, static_argnums=(2,)) @partial(jax.jit, static_argnums=(2,))
def _jax_gemm_mxfp8_1d( def _jax_scaled_matmul(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]] lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
): ):
""" """
    JAX GEMM for MXFP8 via scaled_matmul    JAX GEMM for MXFP8 and NVFP4 via jax.nn.scaled_matmul
""" """
assert ( assert rhs.scaling_mode in (
rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING ScalingMode.MXFP8_1D_SCALING,
), "rhs does not have MXFP8 1D scaling mode" ScalingMode.NVFP4_1D_SCALING,
ScalingMode.NVFP4_2D_SCALING,
), f"rhs does not have MXFP8 or NVFP4 scaling mode, got rhs.scaling_mode={rhs.scaling_mode}"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
...@@ -1537,21 +1603,48 @@ def _jax_gemm_mxfp8_1d( ...@@ -1537,21 +1603,48 @@ def _jax_gemm_mxfp8_1d(
f" {rhs.is_colwise}" f" {rhs.is_colwise}"
) )
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
out_dtype = lhs.dq_dtype
assert (
lhs.data_layout == "N" and rhs.data_layout == "N"
), f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}"
else:
if lhs.data_layout == "T":
lhs_contract = transpose_dims(
lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis
)
if rhs.data_layout == "T":
rhs_contract = transpose_dims(
rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis
)
out_dtype = jnp.float32
# Reshape + Transpose (if needed) # Reshape + Transpose (if needed)
# [..., M, K] -> [1, reduce(..., M), K] # [..., M, K] -> [1, reduce(..., M), K]
# [..., K, M] -> [1, reduce(..., M), K] # [..., K, M] -> [1, reduce(..., M), K]
lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch)) lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch), lhs.data_layout == "T")
rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch)) rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch), rhs.data_layout == "T")
lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch)) lhs_scale_3d = _shape_normalization(
rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch)) lhs.scale_inv, (lhs_contract, lhs_batch), lhs.data_layout == "T"
)
rhs_scale_3d = _shape_normalization(
rhs.scale_inv, (rhs_contract, rhs_batch), rhs.data_layout == "T"
)
# JAX scaled_matmul only supports NT now (TN-gemm) # JAX scaled_matmul only supports NT now (TN-gemm)
# * Expected shape: # * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K) # * lhs_data (B, M, K) * rhs_data (B, N, K)
# * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block) # * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
out_3d = jax.nn.scaled_matmul( out_3d = jax.nn.scaled_matmul(
lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype
) )
if lhs.scaling_mode.is_nvfp4_scaling:
assert lhs.amax is not None and rhs.amax is not None
lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax)
rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax)
alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv
out_3d = (out_3d * alpha).astype(lhs.dq_dtype)
# Reshape [1, reduce(..., M), N] -> [..., M, N] # Reshape [1, reduce(..., M), N] -> [..., M, N]
lhs_remain_shape = tuple( lhs_remain_shape = tuple(
lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract
...@@ -1560,6 +1653,7 @@ def _jax_gemm_mxfp8_1d( ...@@ -1560,6 +1653,7 @@ def _jax_gemm_mxfp8_1d(
rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract
) )
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape) out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out return out
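For orientation, a small usage sketch of the jax.nn.scaled_matmul shape contract assumed above (NT layout, block scales along K, as described in the comments in this hunk). The dtypes, block size of 32, and availability of hardware/backend support for MXFP8 are assumptions here:

    import jax
    import jax.numpy as jnp

    B, M, N, K, BLOCK = 1, 128, 128, 128, 32
    lhs = jnp.zeros((B, M, K), jnp.float8_e4m3fn)   # (B, M, K) row-wise data
    rhs = jnp.zeros((B, N, K), jnp.float8_e4m3fn)   # NT layout: (B, N, K)
    lhs_scale = jnp.ones((B, M, K // BLOCK), jnp.float8_e8m0fnu)
    rhs_scale = jnp.ones((B, N, K // BLOCK), jnp.float8_e8m0fnu)
    out = jax.nn.scaled_matmul(lhs, rhs, lhs_scale, rhs_scale,
                               preferred_element_type=jnp.bfloat16)  # (B, M, N)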
...@@ -1575,7 +1669,7 @@ def _jax_gemm( ...@@ -1575,7 +1669,7 @@ def _jax_gemm(
""" """
dim_nums = (contracting_dims, ((), ())) dim_nums = (contracting_dims, ((), ()))
def _jax_gemm_fp8_impl(lhs, rhs): def _jax_gemm_impl(lhs, rhs):
if lhs.scaling_mode.is_tensor_scaling(): if lhs.scaling_mode.is_tensor_scaling():
assert ( assert (
rhs.scaling_mode == lhs.scaling_mode rhs.scaling_mode == lhs.scaling_mode
...@@ -1587,15 +1681,15 @@ def _jax_gemm( ...@@ -1587,15 +1681,15 @@ def _jax_gemm(
) )
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: if lhs.scaling_mode.is_1d_block_scaling:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) return _jax_scaled_matmul(lhs, rhs, dim_nums)
raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}") raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor):
return _jax_gemm_fp8_impl(lhs_q, rhs_q) return _jax_gemm_impl(lhs_q, rhs_q)
if ( if (
isinstance(lhs, jnp.ndarray) isinstance(lhs, jnp.ndarray)
......
...@@ -6,8 +6,6 @@ ...@@ -6,8 +6,6 @@
import os import os
import functools import functools
from typing import Tuple from typing import Tuple
from importlib.metadata import version as get_pkg_version
from packaging.version import Version as PkgVersion
import numpy as np import numpy as np
...@@ -75,7 +73,8 @@ def jax_dtype_to_te_dtype(jax_dtype): ...@@ -75,7 +73,8 @@ def jax_dtype_to_te_dtype(jax_dtype):
jnp.int64.dtype: TEDType.kInt64, jnp.int64.dtype: TEDType.kInt64,
jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3, jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2, jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
jnp.uint8.dtype: TEDType.kByte, jnp.float8_e8m0fnu.dtype: TEDType.kFloat8E8M0,
jnp.float4_e2m1fn.dtype: TEDType.kFloat4E2M1,
} }
if jax_dtype not in converter: if jax_dtype not in converter:
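As a quick sanity check of the two dtypes added to the converter above: both are exposed by recent JAX / ml_dtypes releases, although the exact version in which casting to them works on a given backend is an assumption here:

    import jax.numpy as jnp

    x = jnp.asarray([0.3, 1.7, -4.2], jnp.float32)
    # FP4 E2M1 payload dtype used for NVFP4 data
    x_fp4 = x.astype(jnp.float4_e2m1fn)
    # FP8 E8M0 (exponent-only) dtype used for MXFP8 block scales
    s_e8m0 = jnp.asarray([1.0, 2.0], jnp.float32).astype(jnp.float8_e8m0fnu)
    print(x_fp4.astype(jnp.float32), s_e8m0.astype(jnp.float32))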
...@@ -151,16 +150,6 @@ def get_cudnn_version() -> Tuple[int, int, int]: ...@@ -151,16 +150,6 @@ def get_cudnn_version() -> Tuple[int, int, int]:
return (major, minor, patch) return (major, minor, patch)
@functools.lru_cache(maxsize=None)
def jax_version_meet_requirement(version: str):
"""
Helper function checking if required JAX version is available
"""
jax_version = PkgVersion(get_pkg_version("jax"))
jax_version_required = PkgVersion(version)
return jax_version >= jax_version_required
def get_xla_flag(flag: str, default=None, cast=str): def get_xla_flag(flag: str, default=None, cast=str):
""" """
Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.
......
...@@ -28,7 +28,10 @@ from .misc import ( ...@@ -28,7 +28,10 @@ from .misc import (
get_cudnn_version, get_cudnn_version,
) )
from .quantization import _quantize_dbias_impl, AmaxScope from .quantization import _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp_tpsp,
)
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
...@@ -1031,7 +1034,10 @@ def layernorm_fwd( ...@@ -1031,7 +1034,10 @@ def layernorm_fwd(
) )
return out, mu, rsigma return out, mu, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if (
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
or quantizer.scaling_mode.is_nvfp4_scaling
):
# Current scaling does not support fused operations. Perform norm in higher precision then quantize after. # Current scaling does not support fused operations. Perform norm in higher precision then quantize after.
out, mu, rsigma = layernorm_fwd( out, mu, rsigma = layernorm_fwd(
x=x, x=x,
...@@ -1276,7 +1282,10 @@ def rmsnorm_fwd( ...@@ -1276,7 +1282,10 @@ def rmsnorm_fwd(
) )
return out, rsigma return out, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if (
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
or quantizer.scaling_mode.is_nvfp4_scaling
):
# Current scaling does not support fused operations. Perform norm in higher precision then quantize after. # Current scaling does not support fused operations. Perform norm in higher precision then quantize after.
out, rsigma = rmsnorm_fwd( out, rsigma = rmsnorm_fwd(
x=x, x=x,
......
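The unfused path taken above for current scaling and NVFP4 — normalize in higher precision, then quantize the result — can be illustrated with plain JAX and a simple per-tensor current-scaling quantizer. The quantization math below is a generic sketch under that assumption, not the TE/JAX quantizer implementation:

    import jax
    import jax.numpy as jnp

    def rmsnorm_then_quantize(x, gamma, eps=1e-6):
        # RMSNorm in float32; no fused quantized norm kernel is assumed to exist.
        x32 = x.astype(jnp.float32)
        rsigma = jax.lax.rsqrt(jnp.mean(x32 * x32, axis=-1, keepdims=True) + eps)
        out = x32 * rsigma * gamma.astype(jnp.float32)
        # Quantize afterwards with a per-tensor (current) scale.
        amax = jnp.max(jnp.abs(out))
        scale = 448.0 / amax          # 448 = max of float8_e4m3fn
        q = (out * scale).astype(jnp.float8_e4m3fn)
        return q, 1.0 / scale, rsigma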
...@@ -85,7 +85,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); ...@@ -85,7 +85,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode, JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout); QuantizeLayout q_layout);
...@@ -138,6 +138,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); ...@@ -138,6 +138,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
// Amax
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);
// Cudnn helpers // Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
......