Unverified Commit cf9a7c2f authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Refactor + MXFP8 + GroupedGEMM (#1627)



* refactor + mxfp8

* added grouped gemm

* rename linear to dense

* added cublas init phase for groupedGemm

* relax the tol of test encoder multiprocessing mxfp8 by 0.001
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarHua Huang <huah@nvidia.com>
Co-authored-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent be055eb0
......@@ -6,6 +6,7 @@ from functools import lru_cache
import transformer_engine
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.common import recipe
@lru_cache
......@@ -20,3 +21,21 @@ def is_fp8_supported():
"""Return if FP8 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 90
@lru_cache
def is_mxfp8_supported():
"""Return if FP8 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 100
def get_fp8_recipe_from_name_string(name: str):
"""Query recipe from a given name string"""
match name:
case "DelayedScaling":
return recipe.DelayedScaling()
case "MXFP8BlockScaling":
return recipe.MXFP8BlockScaling()
case _:
raise ValueError(f"Invalid fp8_recipe, got {name}")
......@@ -12,6 +12,12 @@ wait
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
......@@ -19,10 +19,11 @@ from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, get_fp8_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from common import is_bf16_supported
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
......@@ -217,9 +218,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
......@@ -272,6 +272,19 @@ def train_and_evaluate(args):
args.test_batch_size % num_gpu_dp == 0
), f"Test batch size needs to be multiple of {num_gpu_dp}"
if args.fp8_recipe == "MXFP8BlockScaling":
assert (
args.batch_size / num_gpu_dp % 32 == 0
), "Batch size needs to be multiple of 32 for MXFP8"
assert (
args.test_batch_size / num_gpu_dp % 32 == 0
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
......@@ -287,7 +300,9 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]
with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......@@ -371,21 +386,21 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=64,
default=128,
metavar="N",
help="input batch size for training (default: 64)",
help="input batch size for training (default: 128)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
default=128,
metavar="N",
help="input batch size for testing (default: 64)",
help="input batch size for testing (default: 128)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
default=64,
metavar="N",
help="maximum sequence length (default: 32)",
)
......@@ -416,6 +431,12 @@ def encoder_parser(args):
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--fp8-recipe",
action="store_true",
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
......@@ -426,7 +447,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......@@ -437,29 +459,48 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.50 and actual[1] > 0.76
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.50 and actual[1] > 0.76
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self):
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.50 and actual[1] > 0.76
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
"""Test Transformer Engine with DelayedScaling FP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_sp(self):
"""Test Transformer Engine with FP8 + SP"""
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
"""Test Transformer Engine with MXFP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.50 and actual[1] > 0.76
if __name__ == "__main__":
......
......@@ -19,10 +19,11 @@ from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, get_fp8_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from common import is_bf16_supported
DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
......@@ -198,9 +199,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
......@@ -243,6 +243,18 @@ def train_and_evaluate(args):
num_gpu = jax.local_device_count()
assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
if args.fp8_recipe == "MXFP8BlockScaling":
assert (
args.batch_size / num_gpu % 32 == 0
), "Batch size needs to be multiple of 32 for MXFP8"
assert (
args.test_batch_size / num_gpu % 32 == 0
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
......@@ -257,7 +269,9 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]
with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......@@ -344,16 +358,16 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for training (default: 128)",
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for testing (default: 128)",
help="input batch size for testing (default: 256)",
)
parser.add_argument(
"--max-seq-len",
......@@ -389,6 +403,12 @@ def encoder_parser(args):
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--fp8-recipe",
action="store_true",
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
return parser.parse_args(args)
......@@ -396,7 +416,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......@@ -407,14 +428,23 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.535 and actual[1] > 0.73
if __name__ == "__main__":
......
......@@ -21,9 +21,15 @@ from flax.training import train_state
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, is_fp8_supported
from common import (
is_bf16_supported,
is_fp8_supported,
is_mxfp8_supported,
get_fp8_recipe_from_name_string,
)
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
......@@ -298,9 +304,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
def get_params_sharding(sharding_rules, abs_var_collect, mesh):
......@@ -359,10 +364,16 @@ def train_and_evaluate(args):
num_gpu_dp = 1
num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
if args.fp8_recipe == "MXFP8BlockScaling":
assert args.batch_size % 32 == 0, "Batch size needs to be multiple of 32 for MXFP8"
assert (
args.test_batch_size % num_gpu_dp == 0
), f"Test batch size needs to be multiple of {num_gpu_dp}"
args.test_batch_size % 32 == 0
), "Test batch size needs to be multiple of 32 for MXFP8"
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
......@@ -379,7 +390,9 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]
with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......@@ -482,23 +495,23 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=64,
default=128,
metavar="N",
help="input batch size for training (default: 64)",
help="input batch size for training (default: 128)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
default=128,
metavar="N",
help="input batch size for testing (default: 64)",
help="input batch size for testing (default: 128)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
default=64,
metavar="N",
help="maximum sequence length (default: 32)",
help="maximum sequence length (default: 64)",
)
parser.add_argument(
"--epochs",
......@@ -527,6 +540,12 @@ def encoder_parser(args):
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--fp8-recipe",
action="store_true",
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
parser.add_argument(
"--coordinator-address",
type=str,
......@@ -554,37 +573,46 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8 = is_fp8_supported()
gpu_has_bf16 = is_bf16_supported()
def exec(self, use_fp8):
def exec(self, use_fp8, fp8_recipe):
"""Run 3 epochs for testing"""
args = encoder_parser([])
num_gpu = self.num_process
tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
dp_size = num_gpu // tp_size
batch_size = 64 // dp_size
assert args.batch_size % dp_size == 0, f"Batch size needs to be multiple of {dp_size}"
batch_size = args.batch_size // dp_size
args.use_fp8 = use_fp8
args.batch_size = batch_size
args.test_batch_size = batch_size
args.num_process = num_gpu
args.process_id = self.process_id
args.fp8_recipe = fp8_recipe
return train_and_evaluate(args)
@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
result = self.exec(False)
assert result[0] < 0.45 and result[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8")
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
result = self.exec(True)
assert result[0] < 0.455 and result[1] > 0.79
result = self.exec(False, None)
assert result[0] < 0.505 and result[1] > 0.755
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling")
assert result[0] < 0.505 and result[1] > 0.755
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
result = self.exec(True, "MXFP8BlockScaling")
assert result[0] < 0.505 and result[1] > 0.754
if __name__ == "__main__":
......
......@@ -16,10 +16,11 @@ from datasets import load_dataset
from flax import linen as nn
from flax.training import train_state
from common import is_bf16_supported, get_fp8_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from common import is_bf16_supported
PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
......@@ -59,7 +60,7 @@ class Net(nn.Module):
return x
@partial(jax.jit)
@jax.jit
def train_step(state, inputs, masks, labels, var_collect, rngs):
"""Computes gradients, loss and accuracy for a single batch."""
......@@ -195,9 +196,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
def train_and_evaluate(args):
......@@ -214,7 +214,12 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
with te.fp8_autocast(enabled=args.use_fp8):
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
encoder = Net(num_embed)
# We use nn.Embed, thus inputs need to be in int
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......@@ -309,6 +314,12 @@ def encoder_parser(args):
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--fp8-recipe",
action="store_true",
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
return parser.parse_args(args)
......@@ -316,7 +327,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......@@ -329,10 +341,19 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
......
......@@ -5,6 +5,8 @@
import argparse
import unittest
from functools import partial
import sys
from pathlib import Path
import jax
import jax.numpy as jnp
......@@ -16,6 +18,11 @@ from flax.training import train_state
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
DIR = str(Path(__file__).resolve().parents[1])
sys.path.append(str(DIR))
from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string
IMAGE_H = 28
IMAGE_W = 28
......@@ -37,6 +44,7 @@ class Net(nn.Module):
else:
nn_Dense = nn.Dense
# dtype is used for param init in TE but computation in Linen.nn
dtype = jnp.float32 if self.use_te else jnp.bfloat16
x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
......@@ -50,8 +58,8 @@ class Net(nn.Module):
x = nn_Dense(features=128, dtype=dtype)(x)
x = nn.relu(x)
x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
x = nn_Dense(features=16, dtype=dtype)(x)
x = nn_Dense(features=10, dtype=dtype)(x)
x = nn_Dense(features=32, dtype=dtype)(x)
x = nn_Dense(features=32, dtype=dtype)(x)
assert x.dtype == jnp.bfloat16
return x
......@@ -62,7 +70,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, images, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 10)
one_hot = jax.nn.one_hot(labels, 32)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -153,7 +161,7 @@ def get_datasets():
def check_fp8(state, var_collect, input_shape, label_shape):
"Check if model includes FP8."
assert "f8_" in str(
func_jaxpr = str(
jax.make_jaxpr(apply_model)(
state,
jnp.empty(input_shape, dtype=jnp.bfloat16),
......@@ -161,6 +169,7 @@ def check_fp8(state, var_collect, input_shape, label_shape):
var_collect,
)
)
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
def train_and_evaluate(args):
......@@ -179,7 +188,12 @@ def train_and_evaluate(args):
input_shape = [args.batch_size, IMAGE_H, IMAGE_W, IMAGE_C]
label_shape = [args.batch_size]
with te.fp8_autocast(enabled=args.use_fp8):
if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None
with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum)
......@@ -276,6 +290,12 @@ def mnist_parser(args):
"It also enables Transformer Engine implicitly."
),
)
parser.add_argument(
"--fp8-recipe",
action="store_true",
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
parser.add_argument(
"--use-te", action="store_true", default=False, help="Use Transformer Engine"
)
......@@ -286,7 +306,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......@@ -298,13 +319,14 @@ class TestMNIST(unittest.TestCase):
"""Check If loss and accuracy match target"""
desired_traing_loss = 0.055
desired_traing_accuracy = 0.98
desired_test_loss = 0.04
desired_test_loss = 0.045
desired_test_accuracy = 0.098
assert actual[0] < desired_traing_loss
assert actual[1] > desired_traing_accuracy
assert actual[2] < desired_test_loss
assert actual[3] > desired_test_accuracy
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
self.args.use_te = True
......@@ -312,10 +334,19 @@ class TestMNIST(unittest.TestCase):
actual = train_and_evaluate(self.args)
self.verify(actual)
@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
self.verify(actual)
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
self.verify(actual)
......
......@@ -20,16 +20,15 @@ pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${TE_PATH:=/opt/transformerengine}
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "test_praxis_layers.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py without TE custom calls"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
pip install "nltk>=3.8.2"
pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
# Test without custom calls
NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py
......@@ -82,7 +82,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
'i32[1024]{0}',
'bf16[1024,1024]{0}'
"""
match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
match = re.search(r"(i|f|u)(\d+).*\[([0-9,]*)\]", t)
_, bits_of_type, shape = match.groups()
bytes_of_type = int(bits_of_type) // 8
if shape == "":
......
This diff is collapsed.
......@@ -6,7 +6,6 @@ import os
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from distributed_test_base import (
generate_configs,
......@@ -104,7 +103,7 @@ class TestDistributedSelfAttn:
hidden,
None, # no window
):
pytest.skip(f"No FusedAttn backend found")
pytest.skip("No FusedAttn backend found")
col_ref = self.generate_collectives_count_ref(
mesh_shape,
......@@ -176,7 +175,7 @@ class TestDistributedCrossAttn:
hidden,
None, # no window
):
pytest.skip(f"No FusedAttn backend found")
pytest.skip("No FusedAttn backend found")
col_ref = self.generate_collectives_count_ref()
runner = FusedAttnRunner(
......@@ -256,7 +255,6 @@ class TestDistributedContextParallelSelfAttn:
dropout_prob = 0.0
is_training = True
dp_size, cp_size, tp_size = mesh_shape
qkv_format = qkv_layout.get_qkv_format()
batch, seqlen, num_head, hidden = data_shape
......@@ -382,7 +380,7 @@ class TestDistributedContextParallelSelfAttn:
if qkv_layout.is_thd() and not load_balanced:
pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
return self.impl_test_context_parallel_attn(
self.impl_test_context_parallel_attn(
device_count,
mesh_shape,
mesh_axes,
......@@ -396,6 +394,7 @@ class TestDistributedContextParallelSelfAttn:
CPStrategy.RING,
)
del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
return
class TestReorderCausalLoadBalancing:
......
......@@ -13,11 +13,30 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import pytest_parametrize_wrapper
from transformer_engine.jax import fp8_autocast
from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available
DTYPES = [jnp.bfloat16, jnp.float32]
NORM_INPUT_SHAPES = {
"L0": [[64, 64]],
"L2": [[64, 64]],
}
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
class TestDistributedLayernorm:
......@@ -41,25 +60,32 @@ class TestDistributedLayernorm:
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
def generate_collectives_count_ref(
self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None
assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
weight_count = 2 if ln_type == "layernorm" else 1
# TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
allreduce_total_bytes = (
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
other_bytes = 384 # required for small scale shapes that require padding
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("shard_weights", [False, True])
@pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
@pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm(
self,
device_count,
......@@ -70,12 +96,19 @@ class TestDistributedLayernorm:
dtype,
zero_centered_gamma,
shard_weights,
fp8_recipe,
):
epsilon = 1e-6
ln_type = "layernorm"
q_dtype = jnp.float8_e4m3fn
def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
quantizer = QuantizerFactory.create_set().x
return jnp.mean(
layernorm(
x, gamma, beta, ln_type, zero_centered_gamma, epsilon, quantizer=quantizer
)
)
def ref_func(x, gamma, beta):
x_ = jnp.asarray(x, jnp.float32)
......@@ -92,11 +125,11 @@ class TestDistributedLayernorm:
data_shape, mesh_resource, dtype, shard_weights
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
......@@ -109,8 +142,8 @@ class TestDistributedLayernorm:
[x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)
......@@ -131,17 +164,28 @@ class TestDistributedLayernorm:
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shard_weights", [False, True])
@pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("shard_weights", [False, True])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_rmsnorm(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
shard_weights,
fp8_recipe,
):
epsilon = 1e-6
ln_type = "rmsnorm"
q_dtype = jnp.float8_e4m3fn
def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
quantizer = QuantizerFactory.create_set().x
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))
def ref_func(x, gamma):
x = jnp.asarray(x, jnp.float32)
......@@ -154,11 +198,11 @@ class TestDistributedLayernorm:
data_shape, mesh_resource, dtype, shard_weights
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
......@@ -170,8 +214,8 @@ class TestDistributedLayernorm:
[x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)),
)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Sequence, Union, Optional
import pytest
from typing import Callable, List, Sequence, Union
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from utils import (
assert_allclose,
assert_tree_like_allclose,
is_devices_enough,
pytest_parametrize_wrapper,
)
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.common import recipe
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.sharding import (
HIDDEN_AXES,
HIDDEN_TP_AXES,
......@@ -26,17 +32,25 @@ from transformer_engine.jax.sharding import (
W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from transformer_engine.jax.quantize import QuantizerFactory
from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
SUPPORTED_RECIPES = []
if is_fp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in]
INPUT_SHAPE = [[2, 64, 64]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
INTERMEDIATE = 16
INTERMEDIATE = 64
# Only test with FSDP and TP as DP is not used
......@@ -66,13 +80,13 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
else:
b1 = None
......@@ -86,35 +100,13 @@ class TestDistributedLayernormMLP:
ln_scale: jnp.ndarray,
kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray,
bias_1: jnp.ndarray,
bias_2: jnp.ndarray,
amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray],
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
bias_1: Optional[jnp.ndarray],
bias_2: Optional[jnp.ndarray],
layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True,
multi_gpus: bool = False,
) -> jnp.ndarray:
fp8_meta_pkg1 = FP8MetaPackage(
amax_list_1[0],
scale_list_1[0],
amax_list_1[1],
scale_list_1[1],
amax_list_1[2],
scale_list_1[2],
)
fp8_meta_pkg2 = FP8MetaPackage(
amax_list_2[0],
scale_list_2[0],
amax_list_2[1],
scale_list_2[1],
amax_list_2[2],
scale_list_2[2],
)
if multi_gpus:
layernorm_input_axes = LAYERNORM_INPUT_AXES
dot_1_input_axes = DOT_1_INPUT_AXES
......@@ -124,83 +116,64 @@ class TestDistributedLayernormMLP:
dot_1_input_axes = None
dot_2_input_axes = None
quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean(
fused_layernorm_fp8_mlp(
layernorm_mlp(
x,
ln_scale,
None,
[kernel_1, kernel_2],
[bias_1, bias_2],
[fp8_meta_pkg1, fp8_meta_pkg2],
layernorm_type,
layernorm_input_axes=layernorm_input_axes,
norm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes,
activation_type=activation_type,
use_bias=use_bias,
quantizer_sets=quantizer_sets,
)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
def test_layernorm_fp8_mlp_primitive(
self, mesh_config, activation_type, use_bias, input_shape, dtype
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = "rmsnorm"
fp8_amax_list_1 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
]
fp8_amax_list_2 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
]
fp8_scale_list_1 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
]
fp8_scale_list_2 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
]
inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
input_shape, activation_type, use_bias, dtype
)
inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
static_inputs = [layernorm_type, activation_type, use_bias]
static_inputs = [layernorm_type, activation_type]
value_and_grad_func = jax.value_and_grad(
self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
# Single GPU
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
single_jitter = jax.jit(
value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
)
with fp8_autocast(enabled=True):
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_sharding = NamedSharding(mesh, PartitionSpec("tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
......@@ -208,7 +181,7 @@ class TestDistributedLayernormMLP:
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
# b2
in_shardings = (
None,
None,
......@@ -216,14 +189,10 @@ class TestDistributedLayernormMLP:
k2_sharding,
b1_sharding,
None,
None,
None,
None,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
(None, None, k1_sharding, k2_sharding, b1_sharding, None),
)
multi_jitter = jax.jit(
......@@ -245,15 +214,42 @@ class TestDistributedLayernormMLP:
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
)
else:
is_gated = len(activation_type) > 1
rtol = None
atol = None
if is_gated:
if dtype == jnp.bfloat16:
if i == 2:
rtol = 800
atol = 9e-2
if i == 4:
atol = 300
rtol = 1e-1
if dtype == jnp.float16:
if i == 1: # gamma
rtol = 200
atol = 1e-2
if i == 2:
rtol = 2000
atol = 7e-2
if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling(): # bias_1
# Accumulating dbias across a large tensor introduces a larger difference
rtol = 200
atol = 4e-2
if i == 4 and fp8_recipe == recipe.DelayedScaling():
rtol = 2200
atol = 9e-2
assert_allclose(
multi_grads[i],
single_grads[i],
dtype=dtype,
rtol=rtol,
atol=atol,
err_msg=f"multi_grads[{i}] is not close",
)
def _test_layernorm_mlp(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None
):
batch, seqlen, hidden_in = input_shape
layernorm_type = "rmsnorm"
......@@ -265,7 +261,7 @@ class TestDistributedLayernormMLP:
init_rngs = {"params": subkeys[1]}
# Single GPUs
with fp8_autocast(enabled=use_fp8):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
......@@ -282,7 +278,9 @@ class TestDistributedLayernormMLP:
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
with mesh, fp8_autocast(
enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False,
......@@ -310,25 +308,30 @@ class TestDistributedLayernormMLP:
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm_fp8_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype
):
self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
)
# TODO: debug
# @pytest.mark.skipif(not is_fp8_supported, reason=reason)
# @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
# @pytest_parametrize_wrapper(
# "activation_type", [("gelu",), ("gelu", "linear")]
# )
# @pytest_parametrize_wrapper("use_bias", [True, False])
# @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
# @pytest_parametrize_wrapper("dtype", DTYPES)
# @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
# def test_layernorm_fp8_mlp_layer(
# self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
# ):
# self._test_layernorm_mlp(
# mesh_config, activation_type, use_bias, input_shape, dtype,
# use_fp8=True, fp8_recipe=fp8_recipe
# )
......@@ -3,8 +3,8 @@
# See LICENSE for license information.
import warnings
import pytest
from functools import partial
import pytest
import jax
import jax.numpy as jnp
......
......@@ -13,13 +13,13 @@ from utils import assert_allclose
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast, get_delayed_scaling
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
is_fp8_supported, reason = is_fp8_available()
class TestFP8Helper(unittest.TestCase):
class TestQuantizeConfig(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_initialize(self):
......@@ -27,30 +27,30 @@ class TestFP8Helper(unittest.TestCase):
fp8_format = FP8Format.E4M3
amax_history_len = 10
FP8Helper.initialize(
QuantizeConfig.initialize(
margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
)
self.assertEqual(
FP8Helper.MARGIN,
QuantizeConfig.MARGIN,
margin,
f"FP8Helper.MARGIN initialization failed, should be {margin}"
f" but got {FP8Helper.MARGIN}.",
f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
f" but got {QuantizeConfig.MARGIN}.",
)
self.assertEqual(
FP8Helper.FP8_FORMAT,
QuantizeConfig.FP8_FORMAT,
fp8_format,
f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {FP8Helper.FP8_FORMAT}.",
f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {QuantizeConfig.FP8_FORMAT}.",
)
self.assertEqual(
FP8Helper.AMAX_HISTORY_LEN,
QuantizeConfig.AMAX_HISTORY_LEN,
amax_history_len,
f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {FP8Helper.AMAX_HISTORY_LEN}.",
f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
)
FP8Helper.finalize()
QuantizeConfig.finalize()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_update_collections(self):
......@@ -61,12 +61,12 @@ class TestFP8Helper(unittest.TestCase):
"test1": original_val,
"test2": original_val,
}
updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)
original_state = flax.core.frozen_dict.FrozenDict(original_state)
updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)
......@@ -74,7 +74,7 @@ class TestFP8Helper(unittest.TestCase):
class TestFP8Functions(unittest.TestCase):
def _check_defult_state(self):
self.assertFalse(FP8Helper.is_fp8_enabled())
self.assertFalse(QuantizeConfig.is_fp8_enabled())
def _compare_delay_scaling(self, ref, test):
self.assertTrue(ref.margin == test.margin)
......@@ -84,32 +84,32 @@ class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self):
FP8Helper.finalize() # Ensure the testing not affect by previous tests.
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
self.assertFalse(FP8Helper.is_fp8_enabled())
self.assertFalse(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())
self._check_defult_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(FP8Helper.is_fp8_enabled())
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_defult_state()
ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(FP8Helper.is_fp8_enabled())
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self._check_defult_state()
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self):
FP8Helper.finalize() # Ensure the testing not affect by previous tests.
QuantizeConfig.finalize() # Ensure the testing not affect by previous tests.
self._check_defult_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
......@@ -126,7 +126,7 @@ class TestFP8Functions(unittest.TestCase):
with jax.sharding.Mesh(devices, ("dp", "tp")):
for sr in mesh_s:
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
self.assertTrue(FP8Helper.is_fp8_enabled())
self.assertTrue(QuantizeConfig.is_fp8_enabled())
self._compare_delay_scaling(get_delayed_scaling(), ds)
self.assertEqual(sr, global_mesh_resource())
......
......@@ -20,11 +20,14 @@ from utils import (
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer
from transformer_engine.common.recipe import Format
from transformer_engine.common import recipe
from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available()
from transformer_engine.jax.quantize import (
QuantizeConfig,
ScalingMode,
is_fp8_available,
update_collections,
)
@pytest.fixture(autouse=True, scope="function")
......@@ -35,12 +38,21 @@ def enable_fused_attn():
del os.environ["NVTE_FUSED_ATTN"]
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
QUANTIZE_RECIPES = []
""" Find supported scaling modes"""
if is_fp8_supported:
QUANTIZE_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
QUANTIZE_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DATA_SHAPE = [ # (batch, seqlen, emb_dim)
pytest.param((32, 128, 1024), id="32-128-1024"),
pytest.param((32, 512, 1024), id="32-512-1024"),
]
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
DTYPE = [jnp.bfloat16]
_KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
......@@ -80,27 +92,37 @@ BASE_ATTRS = {
}
ATTRS = [
# attrs0
{},
# attrs1
{
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
},
# attrs2
{
_KEY_OF_ZERO_CENTERED_GAMMA: True,
_KEY_OF_LAYERNORM_EPS: 1e-2,
},
# attrs3
{_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
# attrs4
{_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
# attrs5
{
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_RESIDUAL_POST_LAYERNORM: True,
_KEY_OF_OUTPUT_LAYERNORM: True,
},
# attrs6
{_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
# attrs7
{_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
# attrs8
{
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
},
# attrs9
{
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
......@@ -109,12 +131,14 @@ ATTRS = [
_KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
_KEY_OF_USE_BIAS: True,
},
# attrs10
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
},
# attrs11
{
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
......@@ -123,33 +147,7 @@ ATTRS = [
_KEY_OF_MLP_ACTIVATIONS: ("gelu",),
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
},
{
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
},
{
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_MLP_ACTIVATIONS: (("silu",)),
_KEY_OF_USE_BIAS: True,
},
# attrs12
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
......@@ -158,12 +156,14 @@ ATTRS = [
_KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs13
{
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_USE_BIAS: True,
},
# attrs14
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: "layernorm",
......@@ -173,6 +173,7 @@ ATTRS = [
_KEY_OF_USE_BIAS: True,
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs15
{
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
......@@ -180,26 +181,32 @@ ATTRS = [
_KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True,
},
# attrs16
{
_KEY_OF_HIDDEN_DROPOUT: 0.3,
_KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
},
# attrs17
{
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_USE_BIAS: True,
},
# attrs18
{
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
},
# attrs19
{
_KEY_OF_ATTENTION_DROPOUT: 0.3,
},
# attrs20
{
_KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
},
# attrs21
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
......@@ -207,6 +214,7 @@ ATTRS = [
_KEY_OF_WINDOW_SIZE: (64, 0), # Left size must < DATA_SHAPE seqlen
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
# attrs22
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_RELATIVE_EMBEDDING: False,
......@@ -296,20 +304,24 @@ class BaseRunner:
ref_params, test_params = self._sync_params(ref_params, test_params)
if FP8Helper.is_fp8_enabled():
if QuantizeConfig.is_fp8_enabled():
for _ in range(4):
_, tmp_grad = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
_, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs,
test_masks,
test_params,
test_others,
test_layer,
)
_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others
if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
_, updated_quantize_meta = flax.core.pop(
updated_state[0], QuantizeConfig.COLLECTION_NAME
)
test_others = update_collections(
{QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
)
del tmp_grad, fp8_meta_grad
del updated_quantize_meta
del updated_state
grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
......@@ -436,29 +448,29 @@ class BaseTester:
def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward"""
FP8Helper.finalize() # Ensure FP8 disabled.
QuantizeConfig.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
FP8Helper.finalize() # Ensure FP8 disabled.
QuantizeConfig.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format)
QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
FP8Helper.finalize()
QuantizeConfig.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format)
QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
FP8Helper.finalize()
QuantizeConfig.finalize()
class TestEncoderLayer(BaseTester):
......
This diff is collapsed.
......@@ -18,13 +18,14 @@ from flax.linen.attention import combine_masks
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
import pytest
from transformer_engine.jax.attention import (
AttnMaskType,
canonicalize_attn_mask_type,
make_swa_mask,
)
from transformer_engine.jax.fp8 import DType as TEDType
from transformer_engine.jax.quantize.helper import DType as TEDType
PRNGKey = Any
Shape = Tuple[int, ...]
......@@ -96,6 +97,62 @@ def combine_biases(*masks: Optional[Array]):
return mask
def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
"""
Takes an input dictionary of parameters keyed by test type "L0", etc.
Returns a list of pytest parameters to be used in a parameterized test for the current test type
"""
DEFAULT_TEST_LEVEL = "L0"
test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
if test_level not in param_dict:
raise ValueError("Unsupported test level")
return values_to_named_params(param_dict[test_level], id_prefix)
def value_to_test_name_str(value):
"""Converts a value to how it should appear in a test name."""
if isinstance(value, tuple) or isinstance(value, list):
return "_".join([value_to_test_name_str(v) for v in value])
dtype_type = type(jnp.float32)
if isinstance(value, dtype_type):
return value.dtype
return str(value)
def value_to_named_param(value, id_prefix: str = ""):
param_type = type(pytest.param(0))
if isinstance(value, param_type):
return value
x = pytest.param(value, id=f"{id_prefix}_{value_to_test_name_str(value)}")
return x
def values_to_named_params(params, id_prefix: str = ""):
return [value_to_named_param(v, id_prefix=id_prefix) for v in params]
def pytest_parametrize_wrapper(param_name, param_values):
"""
A wrapper for pytest.mark.parametrize to allow for automatic
naming of tests based on the parameter values.
"""
id_prefix = param_name
if isinstance(param_values, dict):
param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
elif "," not in param_name:
param_values = values_to_named_params(param_values, id_prefix=id_prefix)
# Currently comma separated parameters in one parametrize call aren't supported for automatic naming
# and will just be passed through with default pytest names
def decorator(func):
return pytest.mark.parametrize(param_name, param_values)(func)
return decorator
class DotProductAttention(nn.Module):
transpose_batch_sequence: bool = True
scale_attn_logits: bool = True
......@@ -140,6 +197,7 @@ class DotProductAttention(nn.Module):
Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
"""
input_dtype = query.dtype
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0
assert (
......@@ -152,7 +210,7 @@ class DotProductAttention(nn.Module):
if self.scale_attn_logits:
head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(self.dtype)
depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
query = query / depth_scaling
# Casting logits and softmax computation for float32 for model stability.
......@@ -181,7 +239,7 @@ class DotProductAttention(nn.Module):
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)
# Apply attention dropout.
if not deterministic and self.dropout_rate > 0.0:
......@@ -191,16 +249,20 @@ class DotProductAttention(nn.Module):
dropout_shape = list(attn_weights.shape)
dropout_rng = self.make_rng("dropout")
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
attn_weights = attn_weights.astype(value.dtype)
# attn_weights = attn_weights.astype(input_dtype)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
assert (
attn_weights.dtype == input_dtype
), f"input.dtype={input_dtype}, output.dtype={attn_weights.dtype}"
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
......@@ -246,7 +308,6 @@ class DenseGeneral(nn.Module):
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
......@@ -268,11 +329,14 @@ class DenseGeneral(nn.Module):
contract_ind = tuple(range(0, len(axis)))
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
y = y.astype(input_dtype)
y = lax.dot_general(
inputs, kernel, ((axis, contract_ind), ((), ())), preferred_element_type=input_dtype
)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
assert y.dtype == inputs.dtype, f"input.dtype={inputs.dtype}, output.dtype={y.dtype}"
return y
......@@ -352,6 +416,7 @@ class MlpBlock(nn.Module):
)(
x, deterministic=deterministic
) # Broadcast along length.
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
else:
......@@ -365,6 +430,7 @@ class MlpBlock(nn.Module):
bias_axes="embed",
name="wo",
)(x)
assert (
output.dtype == inputs.dtype
), f"input.dtype={input.dtype}, output.dtype={output.dtype}"
......@@ -391,7 +457,7 @@ def apply_rotary_pos_emb_alternate(
second_part = second_half * cos + first_half * sin
first_part = first_part.astype(inputs.dtype)
second_part = second_part.astype(inputs.dtype)
return jnp.concatenate([first_part, second_part], axis=-1)
return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)
def apply_rotary_pos_emb_consecutive(
......@@ -425,7 +491,7 @@ def apply_rotary_pos_emb_consecutive(
sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
outputs = inputs * cos + inputs_shifted * sin * sign
return outputs
return outputs.astype(inputs.dtype)
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
......@@ -559,6 +625,7 @@ class MultiHeadAttention(nn.Module):
if self.fuse_qkv:
if is_qkvpack:
qkv_proj = DenseGeneral(
axis=-1,
features=self.num_heads * self.head_dim * 3,
......@@ -569,11 +636,13 @@ class MultiHeadAttention(nn.Module):
name="qkv",
dtype=self.dtype,
)(inputs_kv)
query, key, value = jnp.split(
qkv_proj,
[self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1,
)
else:
query = q_projection(kernel_init=query_init, name="query")(inputs_q)
......@@ -711,6 +780,7 @@ class MultiHeadAttention(nn.Module):
# Convert the boolean attention mask to an attention bias.
if mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, 0.0).astype(self.dtype),
......@@ -740,6 +810,7 @@ class MultiHeadAttention(nn.Module):
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions.
out = DenseGeneral(
features=inputs_q.shape[-1], # output dim is set to the input dim.
axis=-1,
......@@ -750,6 +821,7 @@ class MultiHeadAttention(nn.Module):
dtype=self.dtype,
name="out",
)(x)
assert (
inputs_q.dtype == inputs_kv.dtype == out.dtype
), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
......@@ -784,12 +856,11 @@ class LayerNorm(nn.Module):
scale = nn_partitioning.param_with_axes(
"scale", self.scale_init, (features,), self.dtype, axes=("embed",)
)
scale = jnp.asarray(scale, input_dtype)
x_ = x.astype(jnp.float32)
if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
y = (x - mean) * lax.rsqrt(var + self.epsilon)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes(
"ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
......@@ -803,9 +874,10 @@ class LayerNorm(nn.Module):
else:
assert self.layernorm_type == "rmsnorm"
assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = x * lax.rsqrt(mean2 + self.epsilon)
mean2 = jnp.mean(lax.square(x_), axis=-1, keepdims=True)
y = x_ * lax.rsqrt(mean2 + self.epsilon)
z = y * scale
z = z.astype(input_dtype)
assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
return z
......@@ -1085,9 +1157,11 @@ class EncoderLayer(nn.Module):
fuse_wi=self.fuse_mlp_wi,
name="mlp",
)(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
y, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
......@@ -1103,6 +1177,7 @@ class EncoderLayer(nn.Module):
dtype=self.dtype,
name="output_layernorm",
)(y)
assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
return y
......
......@@ -19,9 +19,4 @@ try:
except (ImportError, StopIteration) as e:
pass
try:
import transformer_engine_jax
except ImportError:
pass
__version__ = str(metadata.version("transformer_engine"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment