Unverified Commit 8a7ab3dd authored by jberchtold-nvidia and committed by GitHub

[JAX] NVFP4 support in TE/JAX (#2254)


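This change adds NVFP4 (the NVFP4BlockScaling recipe and NVFP4_1D_SCALING mode) to the TE/JAX examples and tests, renames the FP8-specific helpers (get_fp8_recipe_from_name_string to get_quantization_recipe_from_name_string, is_fp8_available to is_scaling_mode_supported), and threads a stochastic-rounding RNG key ("sr_rng") through the encoder examples. A minimal usage sketch, for reference only; it assumes a device with compute capability 10.0+ and uses only APIs that appear in this diff:

    from transformer_engine.common import recipe
    from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode

    supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
    if supported:
        # equivalent to get_quantization_recipe_from_name_string("NVFP4BlockScaling") in the examples
        nvfp4_recipe = recipe.NVFP4BlockScaling()
    else:
        print(f"NVFP4 unavailable: {reason}")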
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent e99be1b6
......@@ -33,6 +33,13 @@ def is_mxfp8_supported():
return gpu_arch >= 100
@lru_cache
def is_nvfp4_supported():
"""Return if FP8 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 100
def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False):
"""Checks whether most params are sharded across sharding axis.
......@@ -98,7 +105,7 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=
)
def get_fp8_recipe_from_name_string(name: str):
def get_quantization_recipe_from_name_string(name: str):
"""Query recipe from a given name string"""
match name:
case "DelayedScaling":
......@@ -107,5 +114,7 @@ def get_fp8_recipe_from_name_string(name: str):
return recipe.MXFP8BlockScaling()
case "Float8CurrentScaling":
return recipe.Float8CurrentScaling()
case "NVFP4BlockScaling":
return recipe.NVFP4BlockScaling()
case _:
raise ValueError(f"Invalid fp8_recipe, got {name}")
raise ValueError(f"Invalid quantization_recipe, got {name}")
......@@ -10,9 +10,11 @@ TEST_CASES=(
"test_te_delayed_scaling_fp8"
"test_te_current_scaling_fp8"
"test_te_mxfp8"
"test_te_nvfp4"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
"test_te_current_scaling_fp8_shardy"
"test_te_nvfp4_shardy"
)
: ${TE_PATH:=/opt/transformerengine}
......
......@@ -21,13 +21,13 @@ from jax.sharding import PartitionSpec, NamedSharding
from common import (
is_bf16_supported,
get_fp8_recipe_from_name_string,
get_quantization_recipe_from_name_string,
assert_params_sufficiently_sharded,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DEVICE_DP_AXIS = "data"
......@@ -36,6 +36,7 @@ NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
SR_KEY = "sr_rng"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
......@@ -121,6 +122,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
epoch_accuracy = []
for perm in perms:
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
......@@ -135,11 +138,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
def eval_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -150,7 +153,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
......@@ -159,11 +162,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -223,7 +228,7 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
......@@ -257,7 +262,7 @@ def train_and_evaluate(args):
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
......@@ -275,7 +280,8 @@ def train_and_evaluate(args):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
......@@ -355,7 +361,14 @@ def train_and_evaluate(args):
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (None, None)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
......@@ -367,22 +380,24 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
# Split and reassign to 'rng' to ensure unique rng for each step
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs
)
print(
......@@ -402,16 +417,16 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for training (default: 128)",
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for testing (default: 128)",
help="input batch size for testing (default: 256)",
)
parser.add_argument(
"--max-seq-len",
......@@ -466,8 +481,9 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def setUp(self):
"""Run 5 epochs for testing"""
......@@ -477,7 +493,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -485,7 +501,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.361 and actual[1] > 0.84
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -493,14 +509,22 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
......@@ -509,7 +533,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
......@@ -518,14 +542,23 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_with_sp(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
......@@ -534,7 +567,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
......@@ -544,24 +577,27 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.361 and actual[1] > 0.84
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_with_sp_shardy(self):
"""Test Transformer Engine with MXFP8 + SP"""
self.args.enable_shardy = True
......@@ -569,7 +605,17 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.39 and actual[1] > 0.83
assert actual[0] < 0.36 and actual[1] > 0.84
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_with_sp_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.40 and actual[1] > 0.82
if __name__ == "__main__":
......
......@@ -19,17 +19,18 @@ from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, get_fp8_recipe_from_name_string
from common import is_bf16_supported, get_quantization_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
DROPOUT_KEY = "dropout"
SR_KEY = "sr_rng"
INPUT_KEY = "input_rng"
......@@ -97,6 +98,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
epoch_accuracy = []
for perm in perms:
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
......@@ -111,11 +114,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
def eval_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -126,7 +129,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
......@@ -135,11 +138,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -199,7 +204,7 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
......@@ -254,7 +259,7 @@ def train_and_evaluate(args):
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
......@@ -270,6 +275,7 @@ def train_and_evaluate(args):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rng, sr_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
input_shape = [args.batch_size, args.max_seq_len]
......@@ -322,7 +328,14 @@ def train_and_evaluate(args):
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (None, None)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
......@@ -334,22 +347,24 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
# Split and reassign to 'rng' to ensure unique rng for each step
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect, jit_train_step
)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, jit_eval_step
state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs
)
print(
......@@ -369,16 +384,16 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=256,
default=512,
metavar="N",
help="input batch size for training (default: 256)",
help="input batch size for training (default: 512)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=256,
default=512,
metavar="N",
help="input batch size for testing (default: 256)",
help="input batch size for testing (default: 512)",
)
parser.add_argument(
"--max-seq-len",
......@@ -430,8 +445,9 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def setUp(self):
"""Run 5 epochs for testing"""
......@@ -441,7 +457,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -449,7 +465,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
......@@ -457,7 +473,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.749
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -465,6 +481,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
......@@ -472,7 +496,7 @@ class TestEncoder(unittest.TestCase):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
......@@ -481,7 +505,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self):
......@@ -490,18 +514,24 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
assert actual[0] < 0.51 and actual[1] > 0.749
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.51 and actual[1] > 0.75
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.52 and actual[1] > 0.74
......
......@@ -25,7 +25,8 @@ from common import (
is_bf16_supported,
is_fp8_supported,
is_mxfp8_supported,
get_fp8_recipe_from_name_string,
is_nvfp4_supported,
get_quantization_recipe_from_name_string,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
......@@ -39,6 +40,7 @@ NAMED_BROADCAST_AXIS = "my_broadcast_axis"
NAMED_TP_AXIS = "my_tp_axis"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
SR_KEY = "sr_rng"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
......@@ -175,6 +177,8 @@ def train_epoch(
epoch_accuracy = []
for perm in perms:
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_input = sentence[perm, ...]
batch_mask = mask[perm, ...]
batch_label = label[perm, ...]
......@@ -200,11 +204,11 @@ def train_epoch(
return state, avg_loss, avg_accuracy, var_collect
def eval_step(state, inputs, masks, labels, var_collect):
def eval_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -216,7 +220,16 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(
state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec
state,
test_ds,
batch_size,
var_collect,
eval_fn,
mesh,
inputs_pspec,
masks_pspec,
labels_pspec,
rngs,
):
"""Evaluation loop."""
global_input_shape, input_named_sharding, sentence = shard_array_wrapper(
......@@ -233,7 +246,8 @@ def eval_model(
all_accuracy = []
for batch_input, batch_mask, batch_label in zip(sentence, mask, label):
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
shard_input = jax.make_array_from_single_device_arrays(
global_input_shape, input_named_sharding, [batch_input]
)
......@@ -244,7 +258,7 @@ def eval_model(
global_label_shape, label_named_sharding, [batch_label]
)
loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect)
loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect, rngs)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -303,7 +317,7 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
......@@ -372,7 +386,7 @@ def train_and_evaluate(args):
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
......@@ -390,7 +404,8 @@ def train_and_evaluate(args):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
......@@ -444,7 +459,14 @@ def train_and_evaluate(args):
train_step, in_shardings=in_shardings, out_shardings=out_shardings
)
in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None)
in_shardings = (
state_sharding,
inputs_sharding,
masks_sharding,
labels_sharding,
None,
None,
)
out_shardings = (None, None)
jit_eval_step = jax.jit(
eval_step, in_shardings=in_shardings, out_shardings=out_shardings
......@@ -456,14 +478,16 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
jit_train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
else:
for epoch in range(1, args.epochs + 1):
# Split and reassign to 'rng' to ensure unique rng for each step
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state,
......@@ -488,6 +512,7 @@ def train_and_evaluate(args):
inputs_pspec,
masks_pspec,
labels_sharding.spec,
rngs,
)
if args.process_id == 0:
print(
......@@ -508,16 +533,16 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for training (default: 128)",
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for testing (default: 128)",
help="input batch size for testing (default: 256)",
)
parser.add_argument(
"--max-seq-len",
......@@ -629,7 +654,7 @@ class TestEncoder(unittest.TestCase):
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling")
assert result[0] < 0.43 and result[1] > 0.80
assert result[0] < 0.432 and result[1] > 0.80
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
......@@ -639,6 +664,14 @@ class TestEncoder(unittest.TestCase):
result = self.exec(True, "MXFP8BlockScaling")
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4"
)
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling")
assert result[0] < 0.451 and result[1] > 0.79
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
......@@ -659,19 +692,24 @@ class TestEncoder(unittest.TestCase):
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
assert result[0] < 0.43 and result[1] > 0.80
assert result[0] < 0.432 and result[1] > 0.80
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4"
)
def test_te_nvfp4_shardy(self):
"""Test Transformer Engine with NVFP4"""
result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
assert result[0] < 0.451 and result[1] > 0.79
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -16,14 +16,15 @@ from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state
from common import is_bf16_supported, get_fp8_recipe_from_name_string
from common import is_bf16_supported, get_quantization_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
SR_KEY = "sr_rng"
INPUT_KEY = "input_rng"
......@@ -92,6 +93,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
epoch_accuracy = []
for perm in perms:
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_inputs = train_ds["sentence"][perm, ...]
batch_masks = train_ds["mask"][perm, ...]
batch_labels = train_ds["label"][perm, ...]
......@@ -107,11 +110,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
@jax.jit
def eval_step(state, inputs, masks, labels, var_collect):
def eval_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes loss and accuracy for a single batch."""
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -122,7 +125,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
return loss, accuracy
def eval_model(state, test_ds, batch_size, var_collect):
def eval_model(state, test_ds, batch_size, var_collect, rngs):
"""Evaluation loop."""
test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
......@@ -131,11 +134,15 @@ def eval_model(state, test_ds, batch_size, var_collect):
all_accuracy = []
for batch_start in range(0, valid_size, batch_size):
# Split and reassign to 'rngs' to ensure unique rng for each step
rngs = {key: jax.random.split(rngs[key])[1] for key in rngs}
batch_end = batch_start + batch_size
batch_inputs = test_ds["sentence"][batch_start:batch_end]
batch_masks = test_ds["mask"][batch_start:batch_end]
batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
loss, accuracy = eval_step(
state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
)
all_loss.append(loss)
all_accuracy.append(accuracy)
......@@ -195,7 +202,7 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)}
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
......@@ -208,14 +215,15 @@ def train_and_evaluate(args):
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
input_shape = [args.batch_size, args.max_seq_len]
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
......@@ -238,21 +246,25 @@ def train_and_evaluate(args):
if args.dry_run:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
rngs = {DROPOUT_KEY: dropout_rng}
rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
train_step(state, inputs, masks, labels, var_collect, rngs)
print("PASSED")
return None
for epoch in range(1, args.epochs + 1):
# Split and reassign to 'rng' to ensure unique rng for each step
rng, input_rng = jax.random.split(rng)
rng, dropout_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
rng, sr_rng = jax.random.split(rng)
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
state, train_ds, args.batch_size, rngs, var_collect
)
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
test_loss, test_accuracy = eval_model(
state, test_ds, args.test_batch_size, var_collect, rngs
)
print(
f"Epoch: {epoch:>2} "
......@@ -329,8 +341,9 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
def setUp(self):
"""Run 3 epochs for testing"""
......@@ -340,7 +353,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.452 and actual[1] > 0.788
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -348,7 +361,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
assert actual[0] < 0.457 and actual[1] > 0.784
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
......@@ -356,7 +369,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
assert actual[0] < 0.461 and actual[1] > 0.784
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -364,7 +377,15 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
assert actual[0] < 0.457 and actual[1] > 0.784
@unittest.skipIf(not is_nvfp4_supported, nvfp4_reason)
def test_te_nvfp4(self):
"""Test Transformer Engine with NVFP4"""
self.args.use_fp8 = True
self.args.fp8_recipe = "NVFP4BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.476 and actual[1] > 0.775
if __name__ == "__main__":
......
......@@ -18,11 +18,11 @@ from flax.training import train_state
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode
DIR = str(Path(__file__).resolve().parents[1])
sys.path.append(str(DIR))
from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string
from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string
IMAGE_H = 28
IMAGE_W = 28
......@@ -189,7 +189,7 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
......@@ -308,8 +308,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......
......@@ -40,11 +40,13 @@ from transformer_engine.jax.quantize import (
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
should_use_rht,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.common import recipe
GEMM_CASES = [
(256, 256, 512),
......@@ -56,16 +58,23 @@ GEMM_CASES = [
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]
is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
supported_scaling_modes = []
# TODO(Phuong): remove unnecessary pytest skips
is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported(
ScalingMode.DELAYED_TENSOR_SCALING
)
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported(
ScalingMode.MXFP8_1D_SCALING
)
is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported(
ScalingMode.NVFP4_1D_SCALING
)
""" Find supported scaling modes"""
if is_fp8_supported:
supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
if is_mxfp8_supported:
supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)
supported_scaling_modes = helper.get_supported_scaling_modes()
non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling]
supported_recipes = helper.get_supported_quantization_recipes()
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]
def is_shape_supported_by_mxfp8(input_shape):
......@@ -83,12 +92,13 @@ def assert_bitwise_scaled_tensors(
a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
if not precise_comparison:
if not precise_comparison and not a.scaling_mode.is_nvfp4_scaling:
assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
return
assert a.scaling_mode == b.scaling_mode
assert a.scale_inv.dtype == b.scale_inv.dtype
assert a.data_layout == b.data_layout
if a.scaling_mode.is_tensor_scaling():
# Assert in dq_dtype as some unfused codepaths have an intermediate cast
# to an input dtype which reduces precision compared to everything in fp32
......@@ -96,6 +106,16 @@ def assert_bitwise_scaled_tensors(
elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
# Compare MXFP8 scales as uint8
assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
elif a.scaling_mode.is_nvfp4_scaling:
assert_allclose(a.amax, b.amax)
assert_allclose(a.scale_inv, b.scale_inv)
if not precise_comparison:
mismatch = a.data != b.data
mismatch_fraction = jnp.mean(mismatch.astype(jnp.float32))
assert (
mismatch_fraction < 0.05
), f"Mismatch fraction {mismatch_fraction} is too high"
return
else:
raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
assert_allclose(a.data, b.data)
......@@ -603,10 +623,24 @@ class TestNorm:
)
QUANTIZE_OUTPUT_DTYPES = {
QUANTIZE_OUTPUT_FP8_DTYPES = {
"L0": [jnp.float8_e4m3fn],
"L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
QUANTIZE_OUTPUT_DTYPES = {
test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn]
for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
QUANTIZE_QDTYPE_AND_SCALING_MODES = {
test_level: [
(q_dtype, scaling_mode)
for q_dtype, scaling_mode in zip(
QUANTIZE_OUTPUT_FP8_DTYPES[test_level], supported_scaling_modes
)
if q_dtype in scaling_mode.get_compatible_q_dtypes()
]
for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
((32, 64), -1),
......@@ -615,8 +649,7 @@ ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
((32, 256, 128), -1),
((32, 256, 128), -2),
((64, 32, 32, 256), -1),
((64, 32, 32, 256), -2),
((64, 32, 32, 256), -3),
((8192, 2, 4096), -2),
]
QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
......@@ -636,18 +669,38 @@ QUANTIZATION_INPUT_DTYPE = {
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
"q_layout",
[
QuantizeLayout.ROWWISE,
QuantizeLayout.COLWISE,
QuantizeLayout.ROWWISE_COLWISE,
],
)
class TestQuantize:
"""
Purely quantization related tests that will always test on a wider set of types and shapes
"""
def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Temporary hack to skip unsupported FP4 cases until we implement them"""
if q_dtype not in scaling_mode.get_compatible_q_dtypes():
pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
return
# HACK: FIXME TODO(jberchtold)
row = reduce(operator.mul, input_shape[flatten_axis:], 1)
col = reduce(operator.mul, input_shape[:flatten_axis], 1)
will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
if will_use_rht and (row % 64 != 0 or col % 128 != 0):
pytest.skip("Unfused RHT is not supported currently, skipping")
def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
key = jax.random.PRNGKey(0)
# Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
......@@ -657,6 +710,68 @@ class TestQuantize:
q_layout=q_layout,
)
if scaling_mode.is_nvfp4_scaling:
if in_dtype != jnp.bfloat16:
pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently")
return
q_func = _jax_quantize
# For NVFP4 scaling, the maximum possible error for a single value can be high between the
# dequantized and original tensors. To ensure quantization and dequantization operate correctly
# without requiring a very high tolerance for all values, we instead test that quantizing the
# dequantized tensor is bitwise identical to the original quantized tensor.
x = jax.random.uniform(key, input_shape, in_dtype) * 10
q1 = q_func(x, quantizer=quantizer, flatten_axis=flatten_axis)
dq_rowwise = None
dq_colwise = None
if isinstance(q1, ScaledTensor1x):
dq = q1.dequantize()
if q1.is_colwise:
dq_colwise = dq
else:
dq_rowwise = dq
elif isinstance(q1, ScaledTensor2x):
dq_rowwise = q1.rowwise_tensor.dequantize()
dq_colwise = q1.colwise_tensor.dequantize()
else:
raise ValueError(f"Unsupported output type {type(q1)}")
# We only compare Q-DQ for the same quantization layout. If we, for example, QDQ rowwise and
# then re-quantize colwise, the error will be larger and may not be bitwise identical to the
# original colwise quantization.
if dq_rowwise is not None:
assert (
dq_rowwise.shape == x.shape
), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}"
q2_rowwise = q_func(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis)
q2_rowwise = (
q2_rowwise
if isinstance(q2_rowwise, ScaledTensor1x)
else q2_rowwise.rowwise_tensor
)
q1_rowwise = q1 if isinstance(q1, ScaledTensor1x) else q1.rowwise_tensor
assert_bitwise_scaled_tensors(q1_rowwise, q2_rowwise)
if dq_colwise is not None:
# Since this is for NVFP4, we are assuming colwise has T layout and we do a transpose here to get back to original shape
flatten_axis = flatten_axis + len(input_shape) if flatten_axis < 0 else flatten_axis
colwise_flatten_axis = len(input_shape) - flatten_axis
dq_colwise = jnp.transpose(
dq_colwise,
(*range(colwise_flatten_axis, dq_colwise.ndim), *range(colwise_flatten_axis)),
)
assert (
dq_colwise.shape == x.shape
), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}"
q2_colwise = q_func(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis)
q2_colwise = (
q2_colwise
if isinstance(q2_colwise, ScaledTensor1x)
else q2_colwise.colwise_tensor
)
q1_colwise = q1 if isinstance(q1, ScaledTensor1x) else q1.colwise_tensor
assert_bitwise_scaled_tensors(q1_colwise, q2_colwise)
assert (
dq_rowwise is not None or dq_colwise is not None
), "At least one of rowwise or colwise dq must be not None"
return
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
x = jax.random.uniform(key, input_shape, in_dtype)
......@@ -664,9 +779,33 @@ class TestQuantize:
scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
assert_dequantized_scaled_tensor(scaled_tensor, x)
def _should_use_precise_comparison(
self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
):
# TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical
# results between TE and JAX RHT+quant implementations. Currently, for certain shapes, the
# quantized fp4 data differs by a small amount on <0.5% of the values.
RHT_SLIGHT_MISMATCH_SHAPES = [
((32, 256, 128), -1),
((64, 32, 32, 256), -1),
((8192, 2, 4096), -2),
]
if (
should_use_rht(scaling_mode, q_layout=q_layout)
and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES
):
# TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes
return False
if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
# With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation
return False
return True
def test_quantize_bitwise(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
......@@ -677,15 +816,202 @@ class TestQuantize:
jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
try:
te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
assert_bitwise_scaled_tensors(te_output, jax_output)
except AssertionError as e:
if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
error_message = e.args[0]
if "RHT requires input to be bfloat16" in error_message:
# Successfully caught the expected error, early return from the test
return
raise e
assert_bitwise_scaled_tensors(
te_output,
jax_output,
precise_comparison=self._should_use_precise_comparison(
in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
),
)
def test_quantize_bitwise_jitted(
self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
):
self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
key = jax.random.PRNGKey(0)
input = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
)
jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))
jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
try:
te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
except AssertionError as e:
if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
error_message = e.args[0]
if "RHT requires input to be bfloat16" in error_message:
# Successfully caught the expected error, early return from the test
return
raise e
assert_bitwise_scaled_tensors(
te_output,
jax_output,
precise_comparison=self._should_use_precise_comparison(
in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
),
)
@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper(
"scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling]
)
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestStochasticRounding:
def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]:
"""Dequantizes a ScaledTensor back to it's original jnp.ndarray form. This always returns an array of jnp.ndarrays, for ScaledTensor2x there will be two tensors, for ScaledTensor1x there will be one tensor."""
if isinstance(scaled_tensor, ScaledTensor1x):
dq = scaled_tensor.dequantize()
if scaled_tensor.data_layout == "T":
dq = jnp.transpose(
dq,
(
*range(scaled_tensor.flatten_axis, dq.ndim),
*range(scaled_tensor.flatten_axis),
),
)
return [dq]
elif isinstance(scaled_tensor, ScaledTensor2x):
[rowwise_dq] = self._dequantize(scaled_tensor.rowwise_tensor)
[colwise_dq] = self._dequantize(scaled_tensor.colwise_tensor)
return [rowwise_dq, colwise_dq]
raise ValueError(
"Unsupported ScaledTensor type, expected ScaledTensor but received"
f" {type(scaled_tensor)}"
)
def _sample_sr_qdq(
self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> list[jnp.ndarray]:
"""Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors."""
dq_tensors = []
key = jax.random.PRNGKey(0)
for i in range(num_samples):
iter_key = jax.random.fold_in(key, i)
sr_rng_state = jax.random.randint(
iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
)
quantizer = QuantizerFactory.create(
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
stochastic_rounding_rng_state=sr_rng_state,
)
q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
iter_dq = self._dequantize(q_output)
dq_tensors.extend(iter_dq)
avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0)
assert avg_sr_tensor.shape == inputs.shape, (
f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
f" {inputs.shape}"
)
sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))
dq_var = jnp.var(jnp.stack(dq_tensors))
assert (
dq_var > 0
), "Variance of dequantized tensors is zero, stochastic rounding may not be working"
return dq_tensors
def _round_nearest(
self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> jnp.ndarray:
"""Quantizes and dequantizes the input tensor with round nearest quantization."""
quantizer = QuantizerFactory.create(
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
stochastic_rounding_rng_state=None,
)
q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
rn_dq = self._dequantize(q_output)[0]
return rn_dq
def _test_sr(
self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> float:
"""Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples."""
dq_tensors = self._sample_sr_qdq(
num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0)
assert avg_sr_tensor.shape == inputs.shape, (
f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
f" {inputs.shape}"
)
round_nearest_tensor = self._round_nearest(
q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))
rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs))
assert sr_mae < rn_mae, (
f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than"
f" round nearest ({rn_mae})"
)
return sr_mae
def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
# HACK: FIXME TODO(jberchtold)
row = reduce(operator.mul, input_shape[flatten_axis:], 1)
col = reduce(operator.mul, input_shape[:flatten_axis], 1)
will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
if will_use_rht and (row % 64 != 0 or col % 128 != 0):
pytest.skip("Unfused RHT is not supported currently, skipping")
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
NUM_SAMPLES = 10
te_mean_error = self._test_sr(
NUM_SAMPLES, tex.quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
jax_mean_error = self._test_sr(
NUM_SAMPLES, _jax_quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)
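These tests rely on stochastic rounding being unbiased: averaging many SR quantize-dequantize samples converges to the input, whereas round-nearest keeps a fixed error. A toy illustration of that intuition on a unit-spaced grid (plain JAX, not the NVFP4 format or the TE quantizers):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = 0.3  # a value between the representable points 0.0 and 1.0
rn = jnp.round(x)  # round-nearest always yields 0.0, so its error stays at 0.3
# stochastic rounding rounds up with probability 0.3, so the sample mean approaches x
sr_samples = (jax.random.uniform(key, (10_000,)) < x).astype(jnp.float32)
print(abs(sr_samples.mean() - x), abs(rn - x))  # roughly 0.0 vs 0.3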
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("flatten_axis", [-1])
@pytest_parametrize_wrapper("with_group_sizes", [True, False])
@pytest_parametrize_wrapper(
......@@ -724,7 +1050,6 @@ class TestGroupedQuantize:
q_layout=q_layout,
n_groups=n_groups,
)
scaled_tensor = tex.grouped_quantize(
x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer
)
......@@ -736,9 +1061,8 @@ class TestGroupedQuantize:
class TestFusedQuantize:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("out_dtype,scaling_mode", QUANTIZE_QDTYPE_AND_SCALING_MODES)
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
)
......@@ -860,7 +1184,7 @@ class TestFusedQuantize:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
......@@ -886,7 +1210,7 @@ class TestFusedQuantize:
@pytest_parametrize_wrapper(
"input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
......@@ -919,6 +1243,11 @@ valid_fp8_gemm_operand_types = [
(jnp.float8_e4m3fn, jnp.float8_e5m2),
]
supported_nvfp4_scaling_mode_pairs = [
(ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_1D_SCALING),
(ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING),
]
class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
......@@ -960,7 +1289,7 @@ class TestDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
......@@ -994,6 +1323,40 @@ class TestDense:
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
# TODO(Phuong): add bitwise test
@pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("scaling_mode_pair", supported_nvfp4_scaling_mode_pairs)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
@pytest_parametrize_wrapper("with_jax_gemm", [True, False])
def test_gemm_nvfp4(self, m, n, k, scaling_mode_pair, data_layout, with_jax_gemm):
x_uses_rht = scaling_mode_pair[0] == ScalingMode.NVFP4_1D_SCALING and data_layout[0] == "T"
w_uses_rht = scaling_mode_pair[1] == ScalingMode.NVFP4_1D_SCALING and data_layout[1] == "N"
if x_uses_rht != w_uses_rht:
# TODO(jberchtold): Ideally avoid a skip here and rewrite test setup to ensure both or neither use RHT
pytest.skip("RHT must be used for both or neither operand, skipping")
lhs_scaling_mode, rhs_scaling_mode = scaling_mode_pair
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
lhs_quantizer = QuantizerFactory.create(
scaling_mode=lhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
)
rhs_quantizer = QuantizerFactory.create(
scaling_mode=rhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
)
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
def test_dense_grad_bf16(self, m, n, k):
data_layout = "NN"
......@@ -1019,11 +1382,10 @@ class TestDense:
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("m,n,k", [(64, 128, 128)])
@pytest_parametrize_wrapper("recipe", supported_recipes)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm):
data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
......@@ -1044,14 +1406,9 @@ class TestDense:
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
......@@ -1062,10 +1419,10 @@ class TestDense:
x, w, bias, data_layout
)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype)
@pytest.fixture(name="random_inputs")
......@@ -1087,11 +1444,11 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
@pytest_parametrize_wrapper("recipe", supported_recipes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm):
"""
Test layernorm_dense VJP Rule
"""
......@@ -1108,12 +1465,7 @@ class TestFusedDense:
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe)
if norm_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
......@@ -1148,7 +1500,7 @@ class TestFusedDense:
x, w, gamma, beta
)
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
prim_out, (
......@@ -1158,22 +1510,22 @@ class TestFusedDense:
prim_beta_grad,
) = value_n_grad_prim_func(x, w, gamma, beta)
assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype)
assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype)
if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("m,n,k", [(64, 128, 128)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("recipe", supported_recipes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
self, m, n, k, activation_type, recipe, norm_type, use_bias, with_jax_gemm
):
"""
Test layernorm_mlp VJP Rule
......@@ -1201,10 +1553,7 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2,
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
fp8_recipe=recipe,
)
if norm_type == "layernorm":
......@@ -1251,7 +1600,7 @@ class TestFusedDense:
value_n_grad_prim_func = value_and_grad(prim_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
n_iterations = 3 if recipe.delayed() else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
prim_out, (
......@@ -1272,18 +1621,16 @@ class TestFusedDense:
ref_bias_2_grad,
) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
if use_bias:
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
fwd_dtype = quantizer_sets[0].x.q_dtype
bwd_dtype = quantizer_sets[0].dgrad.q_dtype
assert_allclose(prim_out, ref_out, dtype=fwd_dtype)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=bwd_dtype)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=bwd_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=bwd_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=bwd_dtype)
if use_bias:
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype)
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype)
# E5M2 * E5M2 is not supported
......@@ -1388,7 +1735,7 @@ class TestGroupedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["NN"])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
fwd_dtype, bwd_dtype = fwd_bwd_dtype
......@@ -1469,7 +1816,7 @@ class TestGroupedDense:
"fwd_bwd_dtype",
[(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
fwd_dtype, bwd_dtype = fwd_bwd_dtype
dtype = jnp.bfloat16
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import re
from typing import Callable, Sequence, Union, Optional
import pytest
......@@ -17,7 +18,11 @@ from utils import (
)
from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax.quantize import (
is_fp8_available,
ScalingMode,
get_quantize_config_with_recipe,
)
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
......@@ -33,19 +38,20 @@ from transformer_engine.jax.sharding import (
W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory
from transformer_engine.jax.quantize import (
QuantizerFactory,
get_supported_quantization_recipes,
is_scaling_mode_supported,
)
from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
SUPPORTED_RECIPES = get_supported_quantization_recipes()
SUPPORTED_RECIPES = [pytest.param(r, id=r.__class__.__name__) for r in SUPPORTED_RECIPES]
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in]
......@@ -141,6 +147,7 @@ class TestDistributedLayernormMLP:
layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
multi_gpus: bool = False,
quantization_recipe: recipe.Recipe = None,
) -> jnp.ndarray:
if multi_gpus:
......@@ -154,7 +161,9 @@ class TestDistributedLayernormMLP:
dot_1_input_axes = dot_2_input_axes = None
kernel_1_axes = kernel_2_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2, fp8_recipe=quantization_recipe
)
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean(
......@@ -182,7 +191,7 @@ class TestDistributedLayernormMLP:
use_bias,
input_shape,
dtype,
fp8_recipe,
quantization_recipe,
use_shardy,
with_jax_gemm,
):
......@@ -202,7 +211,9 @@ class TestDistributedLayernormMLP:
# Single GPU
with fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()
enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
mesh_resource=MeshResource(),
):
single_jitter = jax.jit(
value_and_grad_func,
......@@ -214,7 +225,9 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
enabled=quantization_recipe is not None,
fp8_recipe=quantization_recipe,
mesh_resource=mesh_resource,
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp"))
......@@ -254,9 +267,15 @@ class TestDistributedLayernormMLP:
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
fwd_test_type = bwd_test_type = dtype
if quantization_recipe is not None:
quantize_config = get_quantize_config_with_recipe(quantization_recipe)
fwd_test_type = quantize_config.FWD_DTYPE
bwd_test_type = quantize_config.BWD_DTYPE
if fwd_test_type == jnp.float16 and use_bias:
assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5)
else:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)):
......@@ -278,13 +297,12 @@ class TestDistributedLayernormMLP:
err_msg=f"multi_grads[{i}] is not close",
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self,
......@@ -293,27 +311,28 @@ class TestDistributedLayernormMLP:
use_bias,
input_shape,
dtype,
fp8_recipe,
quantization_recipe,
with_jax_gemm,
):
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
quantization_recipe,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy(
self,
......@@ -322,18 +341,18 @@ class TestDistributedLayernormMLP:
use_bias,
input_shape,
dtype,
fp8_recipe,
quantization_recipe,
with_jax_gemm,
):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe=fp8_recipe,
quantization_recipe=quantization_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
......@@ -346,7 +365,7 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8,
fp8_recipe,
quantization_recipe,
use_shardy,
with_jax_gemm,
):
......@@ -355,14 +374,16 @@ class TestDistributedLayernormMLP:
layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2)
subkeys = jax.random.split(rng, 3)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {"params": subkeys[1]}
init_rngs = {"params": subkeys[1], "sr_rng": subkeys[2]}
with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
with fp8_autocast(
enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=MeshResource()
):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE,
......@@ -371,7 +392,7 @@ class TestDistributedLayernormMLP:
)
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True
params_single, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
)
# Multi GPUs
......@@ -379,7 +400,7 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=mesh_resource
):
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
......@@ -399,7 +420,7 @@ class TestDistributedLayernormMLP:
)
params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True
params_sharded, x, deterministic=True, rngs={"sr_rng": subkeys[2]}
)
# Make sure params values are the same
......@@ -411,8 +432,8 @@ class TestDistributedLayernormMLP:
rtol = None
l40_tolerance_update = (
get_min_device_compute_capability() == 89
and fp8_recipe == recipe.DelayedScaling()
and use_fp8
and quantization_recipe.delayed()
and dtype == jnp.float16
and activation_type == ("gelu",)
)
......@@ -430,8 +451,8 @@ class TestDistributedLayernormMLP:
# within tolerance to the float32 ground truth.
jax_triton_gemm_precision_tolerance_update = (
with_jax_gemm
and fp8_recipe is not None
and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling())
and quantization_recipe is not None
and (quantization_recipe.delayed() or quantization_recipe.float8_current_scaling())
and dtype in (jnp.bfloat16, jnp.float16)
and activation_type == ("gelu", "linear"),
)
......@@ -457,22 +478,30 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
quantization_recipe=None,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
quantization_recipe,
with_jax_gemm,
):
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp(
mesh_config,
activation_type,
......@@ -480,7 +509,7 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
quantization_recipe=quantization_recipe,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
......@@ -501,24 +530,30 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
quantization_recipe=None,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
quantization_recipe,
with_jax_gemm,
):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4():
pytest.skip("NVFP4 GEMM + Float16 output is unsupported!")
self._test_layernorm_mlp(
mesh_config,
activation_type,
......@@ -526,7 +561,7 @@ class TestDistributedLayernormMLP:
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
quantization_recipe=quantization_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
......@@ -10,20 +10,27 @@ import jax.numpy as jnp
import numpy as np
from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
from transformer_engine.common.recipe import (
DelayedScaling,
MXFP8BlockScaling,
Float8CurrentScaling,
NVFP4BlockScaling,
)
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import (
get_quantize_config,
is_fp8_available,
is_scaling_mode_supported,
ScalingMode,
update_collections,
TensorSource,
)
from transformer_engine.jax.quantize.helper import _format2dtypes
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING)
is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING)
class TestHelper(unittest.TestCase):
......@@ -52,14 +59,16 @@ class TestFP8Functions(unittest.TestCase):
def _check_default_state(self):
self.assertFalse(get_quantize_config().is_fp8_enabled())
def _compare_delay_scaling(self, ref, test):
self.assertTrue(ref.margin == test.margin)
self.assertTrue(ref.fp8_format == test.fp8_format)
self.assertTrue(ref.amax_history_len == test.amax_history_len)
self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
def _compare_delay_scaling(self, test):
self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len)
self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo)
def _compare_current_scaling(self, test):
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source),
......@@ -67,13 +76,26 @@ class TestFP8Functions(unittest.TestCase):
)
def _compare_mxfp8_scaling(self, test):
self.assertEqual(get_quantize_config().MARGIN, test.margin)
self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1])
for tensor_source in TensorSource:
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
)
def _compare_nvfp4_scaling(self, test):
self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0])
self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1])
for tensor_source in TensorSource:
target_scaling_mode = (
ScalingMode.NVFP4_2D_SCALING
if tensor_source == TensorSource.KERNEL
else ScalingMode.NVFP4_1D_SCALING
)
self.assertEqual(
get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode
)
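# Note (an observation on the defaults checked above): with NVFP4BlockScaling,
# kernel tensors are expected to use NVFP4 2D block scaling while activation and
# gradient tensors use NVFP4 1D block scaling, which is exactly what
# _compare_nvfp4_scaling asserts per TensorSource.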
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_delayed_scaling(self):
self._check_default_state()
......@@ -86,14 +108,14 @@ class TestFP8Functions(unittest.TestCase):
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._compare_delay_scaling(ds)
self._check_default_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._compare_delay_scaling(ds)
self._check_default_state()
......@@ -133,16 +155,27 @@ class TestFP8Functions(unittest.TestCase):
self._check_default_state()
bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
bs = MXFP8BlockScaling()
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._check_default_state()
bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
@unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason)
def test_fp8_autocast_nvfp4_block_scaling(self):
self._check_default_state()
with fp8_autocast(
enabled=False, fp8_recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()
):
self._check_default_state()
self._check_default_state()
bs = NVFP4BlockScaling()
with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
self.assertTrue(get_quantize_config().is_fp8_enabled())
self._compare_mxfp8_scaling(bs)
self._compare_nvfp4_scaling(bs)
self._check_default_state()
......@@ -1544,6 +1544,12 @@ def dtype_tols(
rtol = eps_relaxed
if atol is None:
atol = max(ulp, eps_relaxed)
# Manually set tols for nvfp4
if dtype == jnp.float4_e2m1fn:
atol = 0.05
rtol = 0.1
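# Added rationale (illustrative): FP4 E2M1 only represents the magnitudes
# {0, 0.5, 1, 1.5, 2, 3, 4, 6}, so relative quantization error is large and the
# tolerances above are intentionally loose compared to the FP8 formats.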
return {"rtol": rtol, "atol": atol}
......
......@@ -34,7 +34,7 @@ load_framework_extension("jax")
from . import flax
from . import quantize
from .quantize import fp8_autocast, update_collections, get_delayed_scaling
from .quantize import fp8_autocast, update_collections
from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
......@@ -47,7 +47,6 @@ __all__ = [
"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource",
"flax",
"quantize",
......
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Python interface for c++ extensions"""
from .activation import *
from .amax import *
from .attention import *
from .normalization import *
from .quantization import *
......
......@@ -1314,7 +1314,10 @@ def act_lu(
)
return out
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
if (
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
or quantizer.scaling_mode.is_nvfp4_scaling
):
# Current scaling and NVFP4 do not support fused quantization here. Perform the activation in higher precision, then quantize afterward.
out = act_lu(
x=x,
......@@ -1488,7 +1491,10 @@ def quantize_dact_dbias(
if war_output is not None:
return war_output
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
if (
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
or quantizer.scaling_mode.is_nvfp4_scaling
):
# Current scaling and NVFP4 do not support fused quantization here. Perform dact in higher precision, then quantize afterward.
out = dact_lu(
dz=dz,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for amax calculation"""
from enum import Enum
import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
from .base import BasePrimitive, register_primitive
from .misc import (
get_padded_spec,
NamedSharding,
)
from ..sharding import (
global_mesh_resource,
lax_paral_op,
)
from ..quantize import (
get_wgrad_sign_vector,
get_sign_from_vector,
)
__all__ = ["AmaxScope", "calculate_amax", "calculate_post_rht_amax"]
class AmaxScope(Enum):
"""
Scope over which the computed amax is reduced: LOCAL keeps the per-device amax,
TPSP additionally all-reduces (max) across the tensor/sequence-parallel mesh axis,
and FSDP all-reduces (max) across the FSDP mesh axis.
"""
LOCAL = 1
TPSP = 2
FSDP = 3
def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh):
"""Reduce the amax based on its scope"""
gmesh = global_mesh_resource()
sequence_dim = 0 if transpose_batch_sequence else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource:
return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if self is AmaxScope.FSDP:
return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
class AmaxCalculationPrimitive(BasePrimitive):
"""
Amax Calculation Primitive with custom_partitioning
"""
name = "jax_local_amax"
multiple_results = False
impl_static_args = (
1,
2,
) # amax_scope, transpose_batch_sequence
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
*,
amax_scope,
transpose_batch_sequence,
):
"""
amax calculation abstract
"""
del amax_scope, transpose_batch_sequence
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
return out_aval
@staticmethod
def impl(
x,
amax_scope,
transpose_batch_sequence,
):
"""
amax calculation implementation
"""
del amax_scope, transpose_batch_sequence
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation infer_sharding_from_operands
"""
del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculationPrimitive.out_sharding",
)
return amax_sharding
@staticmethod
def partition(
amax_scope,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation partition
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculation.amax_sharding",
)
def sharded_impl(x):
amax = AmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
)
return amax
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
"""
amax calculation shardy_sharding_rule
"""
del amax_scope, transpose_batch_sequence, mesh, result_types
prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",)
return SdyShardingRule((input_spec,), (output_spec,))
register_primitive(AmaxCalculationPrimitive, outer_only=True)
class RHTAmaxCalculationPrimitive(BasePrimitive):
"""
Amax Calculation Primitive with custom_partitioning for calculating the regular
and post-Random Hadamard Transform (RHT) amax using TE's fused kernels.
"""
name = "te_rht_amax_ffi"
multiple_results = True
impl_static_args = (
1, # amax_scope
2, # transpose_batch_sequence
3, # rht_matrix_random_sign_mask_t
4, # produce_regular_amax
5, # flatten_axis
)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
*,
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
):
"""
amax calculation abstract
"""
del (
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
)
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.bfloat16], f"RHT requires input to be bfloat16, but got {dtype}"
amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
post_rht_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
return amax_aval, post_rht_amax_aval
@staticmethod
def lowering(
ctx,
x,
*,
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
):
"""
te_rht_amax_ffi lowering rules
"""
del amax_scope, transpose_batch_sequence
(x_aval,) = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
flatten_axis = flatten_axis if flatten_axis >= 0 else flatten_axis + len(x_aval.shape)
assert 0 < flatten_axis < len(x_aval.shape), "Flatten axis out of bounds!"
return ffi.ffi_lowering(
RHTAmaxCalculationPrimitive.name,
)(
ctx,
x,
rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
produce_regular_amax=produce_regular_amax,
flatten_axis=flatten_axis,
)
@staticmethod
def impl(
x,
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
):
"""
amax calculation implementation
"""
assert RHTAmaxCalculationPrimitive.inner_primitive is not None
(
amax,
post_rht_amax,
) = RHTAmaxCalculationPrimitive.inner_primitive.bind(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
produce_regular_amax=produce_regular_amax,
flatten_axis=flatten_axis,
)
return amax, post_rht_amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation infer_sharding_from_operands
"""
del (
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
arg_infos,
result_infos,
) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="RHTAmaxCalculationPrimitive.out_sharding",
)
return amax_sharding, amax_sharding
@staticmethod
def partition(
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation partition
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="RHTAmaxCalculationPrimitive.amax_sharding",
)
out_shardings = (amax_sharding, amax_sharding)
def sharded_impl(x):
amax, post_rht_amax = RHTAmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
produce_regular_amax=produce_regular_amax,
flatten_axis=flatten_axis,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
)
post_rht_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
post_rht_amax, x_spec, transpose_batch_sequence, mesh
)
return amax, post_rht_amax
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
mesh,
value_types,
result_types,
):
"""
amax calculation shardy_sharding_rule
"""
del (
amax_scope,
transpose_batch_sequence,
rht_matrix_random_sign_mask_t,
produce_regular_amax,
flatten_axis,
mesh,
result_types,
)
prefix = "RHTAmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_amax_spec = (f"{prefix}_amax",)
output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",)
return SdyShardingRule((input_spec,), (output_amax_spec, output_post_rht_amax_spec))
register_primitive(RHTAmaxCalculationPrimitive)
def calculate_amax(x: jnp.ndarray, amax_scope: AmaxScope, transpose_batch_sequence: bool):
"""
Compute the maximum absolute value (amax) of the input tensor.
"""
assert AmaxCalculationPrimitive.outer_primitive is not None
return AmaxCalculationPrimitive.outer_primitive.bind(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
def calculate_post_rht_amax(
x: jnp.ndarray,
amax_scope: AmaxScope,
transpose_batch_sequence: bool,
produce_regular_amax: bool,
flatten_axis: int,
):
"""Compute the post-Random Hadamard Transform (RHT) amax of the input tensor, and optionally the regular amax.
Args:
x: Input tensor.
amax_scope: The scope for amax reduction (local, TPSP, or FSDP).
transpose_batch_sequence: Whether the input tensor has its batch and sequence dimensions transposed.
produce_regular_amax: Whether to compute and return the regular amax alongside the post-RHT amax.
flatten_axis: The axis at which to flatten the input tensor before applying RHT.
Returns:
A tuple containing:
- The regular amax if `produce_regular_amax` is True, otherwise None.
- The post-RHT amax.
"""
amax, post_rht_amax = RHTAmaxCalculationPrimitive.outer_primitive.bind(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
rht_matrix_random_sign_mask_t=get_sign_from_vector(get_wgrad_sign_vector()),
produce_regular_amax=produce_regular_amax,
flatten_axis=flatten_axis,
)
if produce_regular_amax:
return amax, post_rht_amax
return None, post_rht_amax
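# Example usage (a hedged sketch; shapes and argument values are illustrative only):
#
#   x = jnp.zeros((128, 256), jnp.bfloat16)
#   # Per-device amax, optionally max-reduced across the FSDP mesh axis:
#   amax = calculate_amax(x, AmaxScope.FSDP, transpose_batch_sequence=False)
#   # Regular and post-RHT amax from one fused TE kernel:
#   amax, post_rht_amax = calculate_post_rht_amax(
#       x,
#       AmaxScope.LOCAL,
#       transpose_batch_sequence=False,
#       produce_regular_amax=True,
#       flatten_axis=1,
#   )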
......@@ -32,6 +32,7 @@ from ..quantize import (
AbstractBaseTensor,
NoScaleTensor,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
ScalingMode,
......@@ -43,6 +44,7 @@ from ..quantize import (
noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
should_use_rht,
)
from .misc import get_padded_spec, is_all_reduce_in_float32
from ..sharding import (
......@@ -138,6 +140,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
need_lhs_colwise = lhs_is_transposed and (
lhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported()
or lhs_quantizer.scaling_mode.is_nvfp4_scaling
)
flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims)
lhs_q = lhs_quantizer.quantize(
......@@ -153,6 +156,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
need_rhs_colwise = not rhs_is_transposed and (
rhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported()
or rhs_quantizer.scaling_mode.is_nvfp4_scaling
)
flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1
rhs_q = rhs_quantizer.quantize(
......@@ -165,9 +169,27 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x)
def uses_rht(q: AbstractBaseTensor) -> bool:
return isinstance(q, ScaledTensor1x) and should_use_rht(
q.scaling_mode, is_colwise=q.is_colwise
)
# TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class
assert uses_rht(lhs_q) == uses_rht(rhs_q), (
"With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise"
" quantized as well. This is to ensure the RHT is applied to both and will cancel out in"
" the GEMM."
)
return lhs_q, rhs_q
def _get_nvfp4_tensor_scale_inv(amax):
DATA_DTYPE_MAX = jnp.finfo(jnp.float4_e2m1fn.dtype).max.astype(jnp.float32)
SCALE_DTYPE_MAX = jnp.finfo(jnp.float8_e4m3fn.dtype).max.astype(jnp.float32)
return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)
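# Worked example (illustrative only): the FP4 E2M1 max is 6.0 and the FP8 E4M3
# max is 448.0, so a tensor with amax = 268.8 gets a per-tensor scale_inv of
# 268.8 / (6.0 * 448.0) = 0.1. The per-tensor factors of both GEMM operands are
# later folded into the GEMM alpha rather than into the block scale_inv values.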
def collective_gemm_bootstrap(
num_total_devices,
num_devices_per_process,
......@@ -345,7 +367,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi"
multiple_results = True
impl_static_args = 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)
inner_primitive = None
outer_primitive = None
......@@ -357,6 +379,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv,
bias,
gelu_input,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
......@@ -404,7 +428,9 @@ class GemmPrimitive(BasePrimitive):
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
if scaling_mode != ScalingMode.NO_SCALING:
assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), (
assert scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes(
lhs.dtype, rhs.dtype
), (
"cuBLAS GEMM quantized operands have incompatible data types: "
f"{lhs.dtype} x {rhs.dtype}."
)
......@@ -484,6 +510,8 @@ class GemmPrimitive(BasePrimitive):
f"expected {pre_gelu_dtype} but found {gelu_input.dtype}."
)
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
assert alpha.size == 1 and alpha.dtype == jnp.float32
assert beta.size == 1 and beta.dtype == jnp.float32
# Declare cuBLAS workspace
workspace_size = get_cublas_workspace_size_bytes()
......@@ -510,6 +538,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv,
bias,
gelu_input,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
......@@ -530,7 +560,7 @@ class GemmPrimitive(BasePrimitive):
(lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
)
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input)
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta)
kwargs = {
"scaling_mode": int(scaling_mode.value),
"lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
......@@ -563,6 +593,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv,
bias,
gelu_input,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
......@@ -626,6 +658,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv,
bias,
gelu_input,
alpha,
beta,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
scaling_mode=scaling_mode,
......@@ -675,6 +709,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv,
bias,
gelu_input,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
......@@ -694,6 +730,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv,
bias,
gelu_input,
alpha,
beta,
out_dtype,
contracting_dims,
scaling_mode,
......@@ -1001,6 +1039,9 @@ class GemmPrimitive(BasePrimitive):
gelu_input_specs = (None,)
arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),)
# Alpha, beta
arg_shardings += (none_sharding, none_sharding)
# Assemble output shardings
out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))]
......@@ -1014,7 +1055,7 @@ class GemmPrimitive(BasePrimitive):
pre_gelu_specs = (None,)
out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)))
def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input):
def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta):
# We should not fuse bias in the output reduction case
sharded_fuse_bias = fuse_bias and reduce_spec is None
outputs = GemmPrimitive.impl(
......@@ -1024,6 +1065,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_inv,
bias,
gelu_input,
alpha,
beta,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
scaling_mode=scaling_mode,
......@@ -1114,8 +1157,10 @@ class GemmPrimitive(BasePrimitive):
rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
out_spec = (*lhs_non_cspec, *rhs_non_cspec)
bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
dbias_spec = bias_spec if grad else ("…5")
gelu_spec = out_spec if fuse_gelu else ("…6",)
gelu_spec = out_spec if fuse_gelu else ("…5",)
alpha_spec = ("_6",)
beta_spec = ("_7",)
dbias_spec = bias_spec if grad else ("…8")
return SdyShardingRule(
operand_mappings=(
......@@ -1125,6 +1170,8 @@ class GemmPrimitive(BasePrimitive):
rhs_scale_specs,
bias_spec,
gelu_spec,
alpha_spec,
beta_spec,
),
result_mappings=(
out_spec,
......@@ -1178,6 +1225,7 @@ def _te_gemm(
# Quantize operands (if necessary)
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
lhs_amax = rhs_amax = None
# Extract GEMM custom op inputs from quantized operands
if isinstance(lhs_q, ScaledTensor):
assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, (
......@@ -1192,6 +1240,7 @@ def _te_gemm(
lhs_scale_inv = lhs_q.scale_inv
if lhs_q.data_layout == "T":
lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
lhs_amax = lhs_q.amax
if isinstance(rhs_q, ScaledTensor):
assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
......@@ -1201,7 +1250,11 @@ def _te_gemm(
if isinstance(rhs_q, ScaledTensor2x):
# Choose the quantization of the contracting dimension(s)
rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor()
assert rhs_q.scaling_mode == lhs_q.scaling_mode, (
assert (
rhs_q.scaling_mode == lhs_q.scaling_mode
or rhs_q.scaling_mode.is_nvfp4_scaling
and lhs_q.scaling_mode.is_nvfp4_scaling
), (
"cuBLAS GEMM quantized operands have mismatched scaling types, "
f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}."
)
......@@ -1209,6 +1262,15 @@ def _te_gemm(
rhs_scale_inv = rhs_q.scale_inv
if rhs_q.data_layout == "T":
rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
rhs_amax = rhs_q.amax
alpha = jnp.ones((1,), jnp.float32)
beta = jnp.zeros((1,), jnp.float32)
if scaling_mode.is_nvfp4_scaling:
assert lhs_amax is not None and rhs_amax is not None
lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax)
rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax)
alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv
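# Added note: for NVFP4 the per-tensor scales are not baked into the block
# scale_inv factors, so their product is handed to the GEMM custom op as the
# alpha epilogue factor (out = alpha * lhs @ rhs), while beta stays 0 because no
# accumulation into an existing output is fused here.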
# Dummy empties for bias and gelu
out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
......@@ -1224,6 +1286,8 @@ def _te_gemm(
rhs_scale_inv,
bias,
gelu_input,
alpha,
beta,
out_dtype=out_dtype,
contracting_dims=(lhs_cdims, rhs_cdims),
scaling_mode=scaling_mode,
......@@ -1514,15 +1578,17 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
@partial(jax.jit, static_argnums=(2,))
def _jax_gemm_mxfp8_1d(
def _jax_scaled_matmul(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""
JAX GEMM for MXFP8 and NVFP4 block scaling via jax.nn.scaled_matmul
"""
assert (
rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode"
assert rhs.scaling_mode in (
ScalingMode.MXFP8_1D_SCALING,
ScalingMode.NVFP4_1D_SCALING,
ScalingMode.NVFP4_2D_SCALING,
), f"rhs does not have MXFP8 or NVFP4 scaling mode, got rhs.scaling_mode={rhs.scaling_mode}"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
......@@ -1537,21 +1603,48 @@ def _jax_gemm_mxfp8_1d(
f" {rhs.is_colwise}"
)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
out_dtype = lhs.dq_dtype
assert (
lhs.data_layout == "N" and rhs.data_layout == "N"
), f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}"
else:
if lhs.data_layout == "T":
lhs_contract = transpose_dims(
lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis
)
if rhs.data_layout == "T":
rhs_contract = transpose_dims(
rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis
)
out_dtype = jnp.float32
# Reshape + Transpose (if needed)
# [..., M, K] -> [1, reduce(..., M), K]
# [..., K, M] -> [1, reduce(..., M), K]
lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch))
rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch))
lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))
lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch), lhs.data_layout == "T")
rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch), rhs.data_layout == "T")
lhs_scale_3d = _shape_normalization(
lhs.scale_inv, (lhs_contract, lhs_batch), lhs.data_layout == "T"
)
rhs_scale_3d = _shape_normalization(
rhs.scale_inv, (rhs_contract, rhs_batch), rhs.data_layout == "T"
)
# jax.nn.scaled_matmul currently only supports the NT operand layout (a TN GEMM,
# i.e. both operands contract over their last dimension)
# * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K)
# * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
out_3d = jax.nn.scaled_matmul(
lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype
)
if lhs.scaling_mode.is_nvfp4_scaling:
assert lhs.amax is not None and rhs.amax is not None
lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax)
rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax)
alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv
out_3d = (out_3d * alpha).astype(lhs.dq_dtype)
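# Added note: in this JAX fallback path the scaled matmul accumulates in float32
# (out_dtype above) and the per-tensor scale product is applied afterwards as
# alpha, mirroring the alpha handling of the TE custom-op path, before casting
# back to the dequantized dtype.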
# Reshape [1, reduce(..., M), N] -> [..., M, N]
lhs_remain_shape = tuple(
lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract
......@@ -1560,6 +1653,7 @@ def _jax_gemm_mxfp8_1d(
rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract
)
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out
......@@ -1575,7 +1669,7 @@ def _jax_gemm(
"""
dim_nums = (contracting_dims, ((), ()))
def _jax_gemm_fp8_impl(lhs, rhs):
def _jax_gemm_impl(lhs, rhs):
if lhs.scaling_mode.is_tensor_scaling():
assert (
rhs.scaling_mode == lhs.scaling_mode
......@@ -1587,15 +1681,15 @@ def _jax_gemm(
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
if lhs.scaling_mode.is_1d_block_scaling:
return _jax_scaled_matmul(lhs, rhs, dim_nums)
raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")
lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor):
return _jax_gemm_fp8_impl(lhs_q, rhs_q)
return _jax_gemm_impl(lhs_q, rhs_q)
if (
isinstance(lhs, jnp.ndarray)
......
......@@ -6,8 +6,6 @@
import os
import functools
from typing import Tuple
from importlib.metadata import version as get_pkg_version
from packaging.version import Version as PkgVersion
import numpy as np
......@@ -75,7 +73,8 @@ def jax_dtype_to_te_dtype(jax_dtype):
jnp.int64.dtype: TEDType.kInt64,
jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3,
jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2,
jnp.uint8.dtype: TEDType.kByte,
jnp.float8_e8m0fnu.dtype: TEDType.kFloat8E8M0,
jnp.float4_e2m1fn.dtype: TEDType.kFloat4E2M1,
}
if jax_dtype not in converter:
......@@ -151,16 +150,6 @@ def get_cudnn_version() -> Tuple[int, int, int]:
return (major, minor, patch)
@functools.lru_cache(maxsize=None)
def jax_version_meet_requirement(version: str):
"""
Helper function checking if required JAX version is available
"""
jax_version = PkgVersion(get_pkg_version("jax"))
jax_version_required = PkgVersion(version)
return jax_version >= jax_version_required
def get_xla_flag(flag: str, default=None, cast=str):
"""
Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.
......
......@@ -28,7 +28,10 @@ from .misc import (
get_cudnn_version,
)
from .quantization import _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp
from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp_tpsp,
)
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
......@@ -1031,7 +1034,10 @@ def layernorm_fwd(
)
return out, mu, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
if (
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
or quantizer.scaling_mode.is_nvfp4_scaling
):
# Current scaling and NVFP4 do not support fused quantization here. Perform the norm in higher precision, then quantize afterward.
out, mu, rsigma = layernorm_fwd(
x=x,
......@@ -1276,7 +1282,10 @@ def rmsnorm_fwd(
)
return out, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
if (
quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
or quantizer.scaling_mode.is_nvfp4_scaling
):
# Current scaling and NVFP4 do not support fused quantization here. Perform the norm in higher precision, then quantize afterward.
out, rsigma = rmsnorm_fwd(
x=x,
......
......@@ -6,7 +6,6 @@ import operator
from functools import reduce
from typing import Tuple, Optional, Union
import math
from enum import Enum
import jax
......@@ -17,6 +16,7 @@ from jax.sharding import PartitionSpec
import transformer_engine_jax
from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax
from .base import BasePrimitive, register_primitive
from .misc import (
get_padded_spec,
......@@ -31,8 +31,7 @@ from .misc import (
from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp,
global_mesh_resource,
lax_paral_op,
num_of_devices,
)
from ..quantize import (
ScaledTensor2x,
......@@ -45,6 +44,8 @@ from ..quantize import (
ScalingMode,
compute_scale_from_amax,
NoScaleTensor,
get_rht_matrix,
should_use_rht,
)
......@@ -59,14 +60,16 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
name = "te_dbias_quantize_ffi"
multiple_results = True
impl_static_args = (
3,
4,
5,
6,
7,
8,
9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
6, # out_dtype
7, # scaling_mode
8, # q_layout
9, # flatten_axis
10, # scale_dtype
11, # is_dbias
12, # is_outer
13, # stochastic_rounding
14, # use_rht
)
inner_primitive = None
outer_primitive = None
......@@ -75,6 +78,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x_aval,
scale_aval,
amax_aval,
sr_rng_state_aval,
post_rht_amax_aval,
rht_matrix_aval,
*,
out_dtype,
scaling_mode,
......@@ -83,6 +89,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype,
is_dbias,
is_outer,
stochastic_rounding,
use_rht,
):
"""
te_dbias_quantize_p abstract
......@@ -91,6 +99,28 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape
assert scale_aval is None or scale_aval.dtype == jnp.float32
if stochastic_rounding:
assert ScalingMode(
scaling_mode
).is_nvfp4_scaling, "stochastic_rounding can only be used with NVFP4 scaling modes"
# JAX doesn't support 64-bit by default so use 4x uint32 instead of 2x int64
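# (Illustrative only, not this kernel's derivation: four uint32 words per
#  device could e.g. come from concatenating two legacy threefry keys,
#    jnp.concatenate([jax.random.PRNGKey(a), jax.random.PRNGKey(b)]);
#  the real state is produced and managed by the quantizer.)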
assert sr_rng_state_aval is not None and sr_rng_state_aval.dtype == jnp.uint32, (
"sr_rng_state must be a uint32 array when stochastic_rounding is True but"
f" received {sr_rng_state_aval}"
)
if is_outer:
assert (
sr_rng_state_aval.shape[0] == num_of_devices()
and sr_rng_state_aval.shape[1] == 4
), (
"sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
f" True and is_outer is True but received {sr_rng_state_aval.shape}"
)
else:
assert sr_rng_state_aval.shape == (4,), (
"Sharded sr_rng_state must be of shape (4,) per device when"
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
)
if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
rowwise_out_shape = out_shape
......@@ -98,14 +128,50 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
rowwise_out_shape = (1,)
rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
)
updated_amax_aval = amax_aval
if use_rht:
assert (
x_aval.dtype == jnp.bfloat16
), "x must be of dtype bfloat16 to be eligible for RHT cast fusion."
if flatten_axis < 0:
flatten_axis += len(x_aval.shape)
rows = reduce(operator.mul, x_aval.shape[:flatten_axis], 1)
cols = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
assert rows % 64 == 0 and cols % 128 == 0, (
"Rows must be multiple of 64 and cols multiple of 128 when use_rht is True to be"
f" eligible for RHT cast fusion. Received rows {rows} and cols {cols} of 2D shape"
f" from original shape of {x_aval.shape} with flatten_axis {flatten_axis}."
)
assert (
rht_matrix_aval is not None
and rht_matrix_aval.dtype == jnp.bfloat16
and rht_matrix_aval.shape == (16, 16)
), "rht_matrix must be of shape (16, 16) and dtype bfloat16"
assert (
post_rht_amax_aval is not None
and post_rht_amax_aval.dtype == jnp.float32
and post_rht_amax_aval.size == 1
), "post_rht_amax must be of dtype float32"
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
).get_scale_shape_2x(
x_aval.shape,
is_padded=not is_outer,
flatten_axis=flatten_axis,
broadcast_2d_scale_shape_to_1d=True,
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
......@@ -126,6 +192,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype),
jax_dtype_to_te_dtype(scale_dtype),
scaling_mode,
QuantizeLayout(
q_layout
......@@ -172,6 +239,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x,
scale,
amax,
sr_rng_state,
post_rht_amax,
rht_matrix,
*,
out_dtype,
scaling_mode,
......@@ -180,12 +250,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype,
is_dbias,
is_outer,
stochastic_rounding,
use_rht,
):
"""
te_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, is_outer
x_aval, scale_aval, amax_aval = ctx.avals_in
x_aval, scale_aval, amax_aval, _, _, _ = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == amax_aval.dtype == jnp.float32
return ffi.ffi_lowering(
......@@ -196,10 +268,15 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x,
scale,
amax,
sr_rng_state,
post_rht_amax,
rht_matrix,
scaling_mode=scaling_mode.value,
q_layout=q_layout,
flatten_axis=flatten_axis,
is_dbias=is_dbias,
stochastic_rounding=stochastic_rounding,
use_rht=use_rht,
)
@staticmethod
......@@ -207,6 +284,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x,
scale,
amax,
sr_rng_state,
post_rht_amax,
rht_matrix,
out_dtype,
scaling_mode,
q_layout,
......@@ -214,6 +294,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype,
is_dbias,
is_outer,
stochastic_rounding,
use_rht,
):
"""
te_dbias_quantize_p implementation
......@@ -232,6 +314,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x,
scale,
amax,
sr_rng_state,
post_rht_amax,
rht_matrix,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
......@@ -239,10 +324,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype=scale_dtype,
is_dbias=is_dbias,
is_outer=False,
stochastic_rounding=stochastic_rounding,
use_rht=use_rht,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
).get_scale_shape_2x(
x.shape, is_padded=False, flatten_axis=flatten_axis, broadcast_2d_scale_shape_to_1d=True
)
scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
......@@ -271,6 +360,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype,
is_dbias,
is_outer,
stochastic_rounding,
use_rht,
):
"""
to describe batch rules for vmap
......@@ -278,8 +369,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
del is_outer
check_valid_batch_dims(batch_dims)
assert BaseDBiasQuantizePrimitive.outer_primitive is not None
x, scale, amax = batched_args
x_bdim, scale_bdim, amax_bdim = batch_dims
x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args
x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims
out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
return (
......@@ -287,12 +378,17 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x,
scale,
amax,
sr_rng_state,
post_rht_amax,
rht_matrix,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
is_dbias=is_dbias,
stochastic_rounding=stochastic_rounding,
use_rht=use_rht,
),
out_bdims,
)
......@@ -306,11 +402,20 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype,
is_dbias,
is_outer,
stochastic_rounding,
use_rht,
mesh,
arg_infos,
result_infos,
):
del (out_dtype, result_infos, scale_dtype, is_outer) # Unused.
del (
out_dtype,
result_infos,
scale_dtype,
is_outer,
stochastic_rounding,
use_rht,
) # Unused.
x_spec = get_padded_spec(arg_infos[0])
amax_spec = get_padded_spec(arg_infos[2])
......@@ -320,7 +425,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
......@@ -340,10 +445,18 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
scale_inv_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if (
ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed
):
colwise_scale_inv_spec = multidim_transpose(
scale_inv_spec, transpose_axis=flatten_axis
)
else:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
......@@ -376,11 +489,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype,
is_dbias,
is_outer,
stochastic_rounding,
use_rht,
mesh,
arg_infos,
result_infos,
):
del result_infos, is_outer
del result_infos, is_outer # Unused.
x_spec = get_padded_spec(arg_infos[0])
amax_spec = get_padded_spec(arg_infos[2])
......@@ -389,8 +504,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*x_spec),
desc="BaseDBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
......@@ -410,10 +526,18 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
)
scale_inv_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
if ScalingMode(scaling_mode).is_block_scaling:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if (
ScalingMode(scaling_mode).is_block_scaling
and ScalingMode(scaling_mode).is_colwise_transposed
):
colwise_scale_inv_spec = multidim_transpose(
scale_inv_spec, transpose_axis=flatten_axis
)
else:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
......@@ -428,6 +552,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
desc="BaseDBiasQuantizePrimitive.colwise_scale_inv",
)
# TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes
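# (Hedged rationale: the outer (num_devices, 4) buffer is intended to be split
#  so that each shard consumes its own 4-word state, keeping the per-device
#  stochastic-rounding streams independent.)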
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (
out_sharding,
......@@ -438,7 +563,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
dbias_sharding,
)
def sharded_impl(x, scale, amax):
def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix):
(
local_x,
local_colwise_x,
......@@ -450,6 +575,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
x,
scale,
amax,
sr_rng_state,
post_rht_amax,
rht_matrix,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
......@@ -457,6 +585,8 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype=scale_dtype,
is_dbias=is_dbias,
is_outer=True,
stochastic_rounding=stochastic_rounding,
use_rht=use_rht,
)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
......@@ -489,35 +619,54 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
scale_dtype,
is_dbias,
is_outer,
stochastic_rounding,
use_rht,
mesh,
value_types,
result_types,
):
del out_dtype, scale_dtype, is_outer, mesh, result_types
del (
out_dtype,
scale_dtype,
is_outer,
stochastic_rounding,
use_rht,
mesh,
result_types,
)
prefix = "DBiasQuantize_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape,
unique_var=prefix + "x",
flatten_axis=flatten_axis,
broadcast_2d_scale_shape_to_1d=True,
)
x_axes = scale_rules.input_spec
colwise_scale_inv = scale_rules.colwise_rule
out = x_axes
colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "colwise_scale_inv",)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_scale_inv = scale_rules.colwise_rule
if ScalingMode(scaling_mode).is_colwise_transposed:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
colwise_scale_inv = tuple(
multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis)
)
else:
colwise_out = x_axes
dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",)
sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis")
post_rht_amax = (prefix + "post_rht_amax",)
rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2")
return SdyShardingRule(
(x_axes, ("…1",), amax),
(x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
)
......@@ -534,141 +683,6 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
class AmaxScope(Enum):
"""
Amax Scope Enum
"""
LOCAL = 1
TPSP = 2
FSDP = 3
def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh):
"""Reduce the amax based on its scope"""
gmesh = global_mesh_resource()
sequence_dim = 0 if transpose_batch_sequence else 1
# Run AR across TPSP only when tensor-sequence is detected in the input spec
if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource:
return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
# Run AR across FSDP
if self is AmaxScope.FSDP:
return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
class AmaxCalculationPrimitive(BasePrimitive):
"""
Amax Calculation Primitive with custom_partitioning
"""
name = "jax_local_amax"
multiple_results = False
impl_static_args = (
1,
2,
) # amax_scope, transpose_batch_sequence
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
*,
amax_scope,
transpose_batch_sequence,
):
"""
amax calculation abstract
"""
del amax_scope, transpose_batch_sequence
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
return out_aval
@staticmethod
def impl(
x,
amax_scope,
transpose_batch_sequence,
):
"""
amax calculation implementation
"""
del amax_scope, transpose_batch_sequence
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation infer_sharding_from_operands
"""
del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculationPrimitive.out_sharding",
)
return amax_sharding
@staticmethod
def partition(
amax_scope,
transpose_batch_sequence,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation partition
"""
del result_infos
x_spec = get_padded_spec(arg_infos[0])
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculation.amax_sharding",
)
def sharded_impl(x):
amax = AmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
amax, x_spec, transpose_batch_sequence, mesh
)
return amax
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
"""
amax calculation shardy_sharding_rule
"""
del amax_scope, transpose_batch_sequence, mesh, result_types
prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",)
return SdyShardingRule((input_spec,), (output_spec,))
register_primitive(AmaxCalculationPrimitive, outer_only=True)
def _jax_quantize(
x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
......@@ -740,7 +754,11 @@ def _quantize_dbias_impl(
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
is_unsupported = (
quantizer.q_layout == QuantizeLayout.COLWISE
and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING
)
if is_unsupported or not PrimitiveClass.enabled():
if is_dbias:
return _jax_quantize_dbias(
x,
......@@ -767,15 +785,32 @@ def _quantize_dbias_impl(
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
use_rht = False
scale = jnp.empty((1,), jnp.float32)
amax = None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
# until the tensor is dequantized (e.g. in the GEMM).
post_rht_amax = None
rht_matrix = jnp.empty((1, 1), jnp.bfloat16)
amax = x.amax
if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout):
use_rht = True
rht_matrix = get_rht_matrix()
new_amax, post_rht_amax = calculate_post_rht_amax(
x.data,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
produce_regular_amax=amax is None,
flatten_axis=flatten_axis,
)
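# (Hedged note: "RHT" is the random Hadamard transform applied ahead of the
#  NVFP4 cast; because the transform changes per-value magnitudes, the colwise
#  output is scaled from post_rht_amax rather than from the plain input amax.)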
if amax is None:
# amax may already have been computed by a previous layer; only compute and
# update it here when it was not provided (amax is None).
amax = new_amax
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
if amax is None:
amax = AmaxCalculationPrimitive.outer_primitive.bind(
amax = calculate_amax(
x.data,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
......@@ -783,8 +818,17 @@ def _quantize_dbias_impl(
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
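# (Worked sketch, assuming the usual current-scaling convention for
#  compute_scale_from_amax: scale ≈ finfo(q_dtype).max / amax, guarded so a
#  zero or non-finite amax cannot produce an inf/NaN scale.)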
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
# Make sure to reset amax to zeros for DelayedScaling
amax = jnp.zeros((1,), jnp.float32)
elif quantizer.scaling_mode.is_nvfp4_scaling:
if amax is None:
amax = calculate_amax(
x.data,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
)
# Make sure amax is init with zero
# Make sure amax is not None
if amax is None:
amax = jnp.zeros((1,), jnp.float32)
......@@ -796,9 +840,16 @@ def _quantize_dbias_impl(
and is_1x_kernel_supported
)
q_layout = quantizer.q_layout
if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE
sr_rng_state = None
if quantizer.scaling_mode.is_nvfp4_scaling:
# Only NVFP4 scaling modes support stochastic rounding
if quantizer.stochastic_rounding_rng_state is not None:
sr_rng_state = quantizer.stochastic_rounding_rng_state
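# (Hedged aside: stochastic rounding picks the upper of the two representable
#  neighbours with probability (x - lo) / (hi - lo), so the cast is unbiased in
#  expectation; the rng state forwarded below seeds that per-element draw.)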
(
rowwise_casted_output,
colwise_casted_output,
......@@ -810,13 +861,18 @@ def _quantize_dbias_impl(
x.data,
scale,
amax,
sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32),
post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32),
rht_matrix,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
q_layout=q_layout.value,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias,
is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False,
is_outer=True,
stochastic_rounding=sr_rng_state is not None,
use_rht=use_rht,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
......@@ -830,14 +886,17 @@ def _quantize_dbias_impl(
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
)
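# (Example: for x of shape (B, S, H) with flatten_axis=2 the permutation is
#  (2, 0, 1), so the colwise view has shape (H, B, S).)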
quantizer.update(updated_amax)
if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias:
dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis)
out = ScaledTensorFactory.create(
data=rowwise_casted_output,
scale_inv=rowwise_scale_inv,
colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv,
amax=updated_amax,
colwise_amax=post_rht_amax,
scaling_mode=quantizer.scaling_mode,
dq_dtype=dq_dtype,
q_layout=quantizer.q_layout,
......@@ -955,6 +1014,11 @@ class GroupedQuantizePrimitive(BasePrimitive):
# TODO(Phuong): can scale_aval be None?
assert scale_aval is None or scale_aval.dtype == jnp.float32
assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), (
f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must"
f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}"
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_grouped_scale_shape_2x(
......
......@@ -85,7 +85,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype,
DType in_dtype, DType out_dtype, DType scale_dtype,
JAXX_Scaling_Mode scaling_mode,
QuantizeLayout q_layout);
......@@ -138,6 +138,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
// Amax
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler);
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
......