"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "3f875fb57fcf2872d238f8c7cb199b171c424536"
Unverified commit 214e2a4a authored by Alp Dener, committed by GitHub

[JAX] GEMM custom op (#1855)



* added XLA FFI custom op for TE/common nvte_cublas_gemm
Signed-off-by: Alp Dener <adener@nvidia.com>

started GemmPrimitive, abstract done
Signed-off-by: Alp Dener <adener@nvidia.com>

GEMM custom op working with BF16, needs testing for FP8/MXFP8
Signed-off-by: Alp Dener <adener@nvidia.com>

converted the TE GEMM API to use ScaledTensor and added an OS environment variable flag to use TE GEMM under the general gemm() call
Signed-off-by: Alp Dener <adener@nvidia.com>
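
For reference, the environment flag in question is the NVTE_JAX_CUSTOM_CALLS_RE filter that this PR's test utilities toggle; a minimal sketch of flipping it (regex copied from the tests, everything else illustrative):

    import os

    # Exclude only GemmPrimitive so gemm() falls back to the native JAX path
    # (jax.lax.dot_general / jax.nn.scaled_matmul); same regex as in the tests.
    os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"

    # Drop the filter to re-enable all TE custom calls, including the new GEMM op.
    os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE", None)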

BF16 tests passing; FP8 tests should be passing, but contracting_dims has a scoping issue
Signed-off-by: Alp Dener <adener@nvidia.com>

FP8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2
Signed-off-by: Alp Dener <adener@nvidia.com>
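
The E5M2 failure matches the FP8 operand pairs the tests below end up allowing: cuBLAS FP8 GEMM needs at least one E4M3 operand, so an E5M2 x E5M2 product is rejected with CUBLAS_STATUS_NOT_SUPPORTED. The list is copied from this PR's test changes:

    import jax.numpy as jnp

    # At least one operand must be E4M3; E5M2 x E5M2 is not supported by cuBLAS.
    valid_fp8_gemm_operand_types = [
        (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
        (jnp.float8_e5m2, jnp.float8_e4m3fn),
        (jnp.float8_e4m3fn, jnp.float8_e5m2),
    ]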

updated the GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet
Signed-off-by: Alp Dener <adener@nvidia.com>

new GemmPrimitive passing all Dense tests
Signed-off-by: Alp Dener <adener@nvidia.com>

import cleanup and reverted code chunk movement
Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused .transpose() implementations from ScaledTensors
Signed-off-by: Alp Dener <adener@nvidia.com>

all custom call tests passing on Hopper; GEMM-related tests cover both GemmPrimitive and the native JAX implementation
Signed-off-by: Alp Dener <adener@nvidia.com>

removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions
Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused changes to ScaledTensor classes and debug prints
Signed-off-by: Alp Dener <adener@nvidia.com>

* minor unit test cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* FP8 tests passing on Blackwell, but MXFP8 outputs NaN
Signed-off-by: Alp Dener <adener@nvidia.com>

* reverted dense and fuseddense changes; FP8 tests passing on Hopper and Blackwell, MXFP8 has issues with E5M2
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 issue traced to scale factor padding with NaNs instead of zeros
Signed-off-by: Alp Dener <adener@nvidia.com>

* padding scales with 2^-127 instead of NaNs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
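
A minimal sketch of the idea, not the PR's exact helper: an all-zero E8M0 scale byte decodes to 2^-127, so padding the scale-inverse tensor with that value keeps padded blocks finite instead of propagating NaN. The (128, 4) tile alignment below is an illustrative assumption.

    import jax.numpy as jnp

    def pad_block_scales(scale_inv, tile=(128, 4)):  # hypothetical helper
        rows, cols = scale_inv.shape
        pad_r = (-rows) % tile[0]
        pad_c = (-cols) % tile[1]
        # 2^-127 corresponds to E8M0 byte 0; NaN padding here was the MXFP8 NaN source.
        return jnp.pad(scale_inv, ((0, pad_r), (0, pad_c)), constant_values=2.0**-127)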

* fixed a bug in rhs_scale_inv usage
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* cleaned up the E8M0 type converter and used it in gemm.cpp
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* segfault fixed, passing all unit tests on Blackwell
Signed-off-by: Alp Dener <adener@nvidia.com>

* fix for fuseddense tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixed workspace alignment
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul
Signed-off-by: Alp Dener <adener@nvidia.com>

all unit tests passing on an H100x8 node
Signed-off-by: Alp Dener <adener@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci


linting fixes
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed batch dimension numbers
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed FP8 scale sharding rule when there are no FP8 scales
Signed-off-by: Alp Dener <adener@nvidia.com>

added error message for unsupported Shardy partitioner
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed test tolerances for FP8 cases
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed Shardy test skip cases
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* moved the reshape of the encoder output in the encoder examples to make custom partitioning rules work correctly
Signed-off-by: Alp Dener <adener@nvidia.com>

* added helper functions for padding and unpadding block scales; changed GemmPrimitive to accept unpadded scales and pad them after sharding
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated Shardy rules for all custom ops to decouple block scale rules from their tensors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* changed the unit test use_jax_gemm option into a context manager to preserve external custom op settings, tightened multi-GPU encoder test tolerances, and changed the gemm() API to take contracting_dims and batched_dims separately instead of dimension_numbers
Signed-off-by: Alp Dener <adener@nvidia.com>
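
As a usage reference for the reworked API (shapes here are illustrative; use_jax_gemm is the tests/jax/utils.py helper added in this PR):

    import jax.numpy as jnp
    from transformer_engine.jax import cpp_extensions as tex
    from utils import use_jax_gemm

    x = jnp.ones((64, 64), jnp.bfloat16)
    w = jnp.ones((64, 32), jnp.bfloat16)

    # enabled=True temporarily excludes GemmPrimitive so gemm() takes the native JAX
    # path; the prior NVTE_JAX_CUSTOM_CALLS_RE setting is restored on exit.
    with use_jax_gemm(enabled=True):
        out = tex.gemm(x, w, contracting_dims=((1,), (0,)))
    # batched_dims is passed the same way when the operands carry batch dimensions.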

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed typo in test utils
Signed-off-by: Alp Dener <adener@nvidia.com>

* added sequence-first input warnings
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed datasets version for JAX examples
Signed-off-by: Alp Dener <adener@nvidia.com>

* reverted modification to the force_1x_quantization decision
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected gemm function syntax in unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 397c4be6
......@@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
LOG_FILE="${TEST_CASE}_gpu_${i}.log"
# Run pytest and redirect stdout and stderr to the log file
pytest -c "$TE_PATH/tests/jax/pytest.ini" \
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
--num-process=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
......
......@@ -25,6 +25,7 @@ from common import (
assert_params_sufficiently_sharded,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
......@@ -465,14 +466,14 @@ class TestEncoder(unittest.TestCase):
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
def setUp(self):
"""Run 3 epochs for testing"""
self.args = encoder_parser(["--epochs", "3"])
"""Run 5 epochs for testing"""
self.args = encoder_parser(["--epochs", "5"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -480,7 +481,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -488,14 +489,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
......@@ -504,7 +505,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
......@@ -513,14 +514,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
......@@ -529,7 +530,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
......@@ -539,9 +540,32 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.43 and actual[1] > 0.80
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_with_sp_shardy(self):
"""Test Transformer Engine with MXFP8 + SP"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.43 and actual[1] > 0.80
if __name__ == "__main__":
......
......@@ -21,6 +21,7 @@ from jax.sharding import PartitionSpec, NamedSharding
from common import is_bf16_supported, get_fp8_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
......@@ -430,14 +431,14 @@ class TestEncoder(unittest.TestCase):
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
def setUp(self):
"""Run 3 epochs for testing"""
self.args = encoder_parser(["--epochs", "3"])
"""Run 5 epochs for testing"""
self.args = encoder_parser(["--epochs", "6"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
assert actual[0] < 0.50 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -445,7 +446,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
assert actual[0] < 0.50 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
......@@ -453,7 +454,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
assert actual[0] < 0.50 and actual[1] > 0.75
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -461,14 +462,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
assert actual[0] < 0.50 and actual[1] > 0.75
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
assert actual[0] < 0.50 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
......@@ -477,9 +478,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
assert actual[0] < 0.50 and actual[1] > 0.75
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self):
......@@ -488,7 +487,19 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
assert actual[0] < 0.50 and actual[1] > 0.75
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.75
if __name__ == "__main__":
......
......@@ -28,8 +28,8 @@ from common import (
get_fp8_recipe_from_name_string,
)
import transformer_engine.jax as te
import transformer_engine.jax.cpp_extensions as tex
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
......@@ -584,8 +584,8 @@ class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
"""Run 3 epochs for testing"""
args = encoder_parser([])
"""Run 5 epochs for testing"""
args = encoder_parser(["--epochs", "5"])
num_gpu = self.num_process
tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
......@@ -607,7 +607,7 @@ class TestEncoder(unittest.TestCase):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
result = self.exec(False, None)
assert result[0] < 0.505 and result[1] > 0.755
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
......@@ -615,7 +615,7 @@ class TestEncoder(unittest.TestCase):
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling")
assert result[0] < 0.506 and result[1] > 0.753
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
......@@ -623,7 +623,7 @@ class TestEncoder(unittest.TestCase):
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling")
assert result[0] < 0.507 and result[1] > 0.753
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
......@@ -631,13 +631,13 @@ class TestEncoder(unittest.TestCase):
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
result = self.exec(True, "MXFP8BlockScaling")
assert result[0] < 0.505 and result[1] > 0.754
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
result = self.exec(False, None, enable_shardy=True)
assert result[0] < 0.505 and result[1] > 0.755
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
......@@ -645,9 +645,7 @@ class TestEncoder(unittest.TestCase):
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling", enable_shardy=True)
assert result[0] < 0.506 and result[1] > 0.753
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
......@@ -655,7 +653,18 @@ class TestEncoder(unittest.TestCase):
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
assert result[0] < 0.507 and result[1] > 0.753
assert result[0] < 0.43 and result[1] > 0.80
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
)
@unittest.skipIf(
tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
)
def test_te_mxfp8_shardy(self):
"""Test Transformer Engine with MXFP8"""
result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
assert result[0] < 0.43 and result[1] > 0.80
if __name__ == "__main__":
......
......@@ -13,6 +13,7 @@ import operator
from utils import (
assert_allclose,
pytest_parametrize_wrapper,
use_jax_gemm,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
......@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import (
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
DelayedScaleQuantizer,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
......@@ -851,6 +851,22 @@ class TestFusedQuantize:
)
valid_fp8_gemm_operand_types = [
(jnp.float8_e4m3fn, jnp.float8_e4m3fn),
(jnp.float8_e5m2, jnp.float8_e4m3fn),
(jnp.float8_e4m3fn, jnp.float8_e5m2),
]
def _use_jax_fp8_gemm(enabled=False):
import os
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T":
......@@ -883,27 +899,47 @@ class TestDense:
def test_gemm_bf16(self, m, n, k, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
primitive_out = tex.gemm(x, w, contracting_dims)
primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
if (
not with_jax_gemm
and scaling_mode.is_1d_block_scaling()
and jnp.float8_e5m2 in (x_qtype, w_qtype)
):
pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
)
primitive_out = tex.gemm(
x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=False,
)
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=(
quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
),
rhs_quantizer=(
quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
),
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=q_dtype)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
def test_dense_grad_bf16(self, m, n, k):
......@@ -932,9 +968,9 @@ class TestDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
......@@ -956,23 +992,27 @@ class TestDense:
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
)
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
)
ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
x, w, bias, data_layout
)
assert_allclose(primitive_out, ref_out, dtype=q_dtype)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)
@pytest.fixture(name="random_inputs")
......@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
"""
Test layernorm_dense VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False
eps = 1e-6
......@@ -1025,8 +1058,8 @@ class TestFusedDense:
quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode,
fwd_dtype=q_dtype,
bwd_dtype=q_dtype,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
......@@ -1064,41 +1097,35 @@ class TestFusedDense:
)
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
prim_w_grad,
prim_gamma_grad,
prim_beta_grad,
) = value_n_grad_prim_func(x, w, gamma, beta)
assert_allclose(prim_out, ref_out, dtype=q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
prim_w_grad,
prim_gamma_grad,
prim_beta_grad,
) = value_n_grad_prim_func(x, w, gamma, beta)
assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
):
"""
Test layernorm_mlp VJP Rule
"""
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False
eps = 1e-6
......@@ -1123,8 +1150,8 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2,
scaling_mode=scaling_mode,
fwd_dtype=q_dtype,
bwd_dtype=q_dtype,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
)
......@@ -1153,14 +1180,13 @@ class TestFusedDense:
ln_out = _ref_jax_norm_impl(
x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
)
# TODO: replace gemm with jnp.dot
linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
if use_bias:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = _jax_act_lu(linear_1_out, activation_type)
linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
if use_bias:
bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
linear_2_out += jnp.reshape(bias_2, bias_2_shape)
......@@ -1174,15 +1200,16 @@ class TestFusedDense:
value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
prim_gamma_grad,
prim_kernel_1_grad,
prim_kernel_2_grad,
prim_bias_1_grad,
prim_bias_2_grad,
) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations):
prim_out, (
prim_x_grad,
prim_gamma_grad,
prim_kernel_1_grad,
prim_kernel_2_grad,
prim_bias_1_grad,
prim_bias_2_grad,
) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
ref_out, (
ref_x_grad,
......@@ -1193,18 +1220,18 @@ class TestFusedDense:
ref_bias_2_grad,
) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
assert_allclose(prim_out, ref_out, dtype=q_dtype)
assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
if use_bias:
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
if use_bias:
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
# E5M2 * E5M2 is not supported
......
......@@ -13,6 +13,7 @@ from utils import (
assert_tree_like_allclose,
is_devices_enough,
pytest_parametrize_wrapper,
use_jax_gemm,
)
from transformer_engine.common import recipe
......@@ -147,7 +148,15 @@ class TestDistributedLayernormMLP:
)
def _test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
use_shardy,
with_jax_gemm,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
......@@ -157,72 +166,83 @@ class TestDistributedLayernormMLP:
input_shape, activation_type, use_bias, dtype
)
static_inputs = [layernorm_type, activation_type]
value_and_grad_func = jax.value_and_grad(
self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
# Single GPU
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
single_jitter = jax.jit(
value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
)
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2
in_shardings = (
None,
None,
k1_sharding,
k2_sharding,
b1_sharding,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None),
with use_jax_gemm(enabled=with_jax_gemm):
value_and_grad_func = jax.value_and_grad(
self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
assert_allclose(multi_fwd, single_fwd, dtype=dtype)
# Single GPU
with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
single_jitter = jax.jit(
value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
)
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
# Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
multi_inputs = [*inputs[:2], k1_, k2_, b1_, *inputs[5:]]
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2
in_shardings = (
None,
None,
k1_sharding,
k2_sharding,
b1_sharding,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None),
)
multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(
len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
),
) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)):
if multi_grads[i] is not None:
if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose(
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
m_grad,
s_grad,
dtype=bwd_test_type,
err_msg=f"multi_grads[{i}] is not close",
)
else:
assert_allclose(
multi_grads[i],
single_grads[i],
dtype=dtype,
dtype=bwd_test_type,
err_msg=f"multi_grads[{i}] is not close",
)
......@@ -233,8 +253,16 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
with_jax_gemm,
):
self._test_layernorm_mlp_grad(
mesh_config,
......@@ -244,6 +272,7 @@ class TestDistributedLayernormMLP:
dtype,
fp8_recipe,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -252,19 +281,29 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype
self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
with_jax_gemm,
):
# We don't test block scaling with Shardy because at the time of writing,
# it is not supported in JAX's scaled_matmul_stablehlo.
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
self._test_layernorm_mlp_grad(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe=recipe.DelayedScaling(),
fp8_recipe=fp8_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
def _test_layernorm_mlp(
......@@ -277,6 +316,7 @@ class TestDistributedLayernormMLP:
use_fp8,
fp8_recipe,
use_shardy,
with_jax_gemm,
):
jax.config.update("jax_use_shardy_partitioner", use_shardy)
batch, seqlen, hidden_in = input_shape
......@@ -288,48 +328,49 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {"params": subkeys[1]}
# Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True
)
# Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
scale_axes=LN_SCALE_AXES,
ln_bias_axes=LN_BIAS_AXES,
kernel_axes_1=KERNEL_1_AXES,
kernel_axes_2=KERNEL_2_AXES,
use_bias=use_bias,
bias_axes_1=BIAS_1_AXES,
bias_axes_2=BIAS_2_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
name="mlp",
)
params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True
)
with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True
)
# Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(
enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
scale_axes=LN_SCALE_AXES,
ln_bias_axes=LN_BIAS_AXES,
kernel_axes_1=KERNEL_1_AXES,
kernel_axes_2=KERNEL_2_AXES,
use_bias=use_bias,
bias_axes_1=BIAS_1_AXES,
bias_axes_2=BIAS_2_AXES,
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
name="mlp",
)
params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True
)
# Make sure params values are the same
assert_tree_like_allclose(params_sharded["params"], params_single["params"])
......@@ -355,9 +396,9 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("use_shardy", [False, True])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy
self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
):
self._test_layernorm_mlp(
mesh_config,
......@@ -367,7 +408,8 @@ class TestDistributedLayernormMLP:
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=use_shardy,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -377,8 +419,9 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
):
self._test_layernorm_mlp(
mesh_config,
......@@ -389,4 +432,51 @@ class TestDistributedLayernormMLP:
use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
):
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
......@@ -3,11 +3,12 @@
# See LICENSE for license information.
"""Utility for the TE layer tests"""
import os
import functools
import math
import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
import os
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
from contextlib import contextmanager
import jax
import jax.numpy as jnp
......@@ -20,7 +21,6 @@ from jax import random as jax_random
import pytest
from transformer_engine.jax.attention import (
AttnMaskType,
canonicalize_attn_mask_type,
make_swa_mask,
)
......@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
......@@ -1519,7 +1519,7 @@ def dtype_tols(
TEDType.kFloat8E5M2: jnp.float8_e5m2,
}[dtype]
elif isinstance(dtype, np.dtype):
dtype = jnp.dtype(dtype)
dtype = DType(dtype)
# Expect bit-wise accuracy for integer dtypes
if not jnp.issubdtype(dtype, jnp.floating):
......@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
fmt = fmt + "\n {}\n {}"
jax.debug.print(fmt, *args)
@contextmanager
def use_jax_gemm(enabled=False):
orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
try:
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
yield
finally:
if enabled:
if orig_custom_calls_filter is None:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
else:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
......@@ -415,37 +415,35 @@ class ActLuPrimitive(BasePrimitive):
result_types,
):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
prefix = "ActLuPrimitive_"
x_rank = len(value_types[0].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2
x_rank - 1, unique_var=prefix + "x", flatten_axis=-2
)
x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",)
out = (*x_axes[:-2], x_axes[-1])
scale_inv = scale_rules.rowwise_rule
colwise_scale_inv = scale_rules.colwise_rule
colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "scale_inv_colwise",)
if is_2x:
colwise_scale_inv = scale_rules.colwise_rule
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(
multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
)
else:
colwise_out = out
else:
colwise_out = ("j",)
colwise_scale_inv = ("k",)
# amax is always a unit tensor.
amax = ("l",)
amax = (prefix + "amax",)
return SdyShardingRule(
(
x_axes,
"…1",
("…1",),
),
(out, colwise_out, scale_inv, colwise_scale_inv, amax),
**scale_rules.factor_sizes,
)
......@@ -890,28 +888,26 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
result_types,
):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
x_rank = len(value_types[1].shape)
prefix = "BaseDActLuDBiasQuantizePrimitive_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2
len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2
)
x_axes = scale_rules.input_spec
dz_axes = (*x_axes[:-2], x_axes[-1])
out = x_axes
colwise_out = (prefix + "out_colwise",)
if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
else:
colwise_out = tuple(x_axes)
else:
colwise_out = ("j",)
colwise_out = out
dbias = x_axes[-2:] if is_dbias else ("k",)
amax = ("…4",)
dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",)
return SdyShardingRule(
(("…0",), tuple(x_axes), ("…2",)),
(dz_axes, x_axes, ("…2",)),
(out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
**scale_rules.factor_sizes,
)
......@@ -985,6 +981,7 @@ def act_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization.
......@@ -993,6 +990,7 @@ def act_lu(
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
If quantizer is None:
......@@ -1037,6 +1035,10 @@ def act_lu(
is_outer=True,
)
out = out.reshape(output_shape)
if noop_scaled_tensor:
return ScaledTensorFactory.create_2x(
out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype
)
return out
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
......@@ -1090,6 +1092,7 @@ def quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
is_dbias: bool = True,
quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization.
......@@ -1100,6 +1103,7 @@ def quantize_dact_dbias(
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
......@@ -1113,13 +1117,49 @@ def quantize_dact_dbias(
f" {x.shape} and act_len {act_len}"
)
scale = jnp.empty((), jnp.float32)
act_type_id = ActivationEnum[activation_type]
PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
if not PrimitiveClass.enabled():
if not PrimitiveClass.enabled() or (
quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
if quantizer is None:
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
)
output = output.astype(x.dtype)
dbias = None
if is_dbias:
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
if noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output,
None,
output,
None,
ScalingMode.NO_SCALING,
dq_dtype=output.dtype,
),
dbias,
)
return output, dbias
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
......@@ -1145,31 +1185,6 @@ def quantize_dact_dbias(
if war_output is not None:
return war_output
scale = jnp.empty((), jnp.float32)
act_type_id = ActivationEnum[activation_type]
if quantizer is None:
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
)
dbias = None
if is_dbias:
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = dact_lu(
......@@ -1183,7 +1198,7 @@ def quantize_dact_dbias(
)
return out, dbias
if isinstance(quantizer, DelayedScaleQuantizer):
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
# TE/common dact_dbias_quantize does not support gated act yet
......@@ -1243,6 +1258,7 @@ def dact_lu(
x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None,
noop_scale_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
"""
Backward pass for activation with optional quantization.
......@@ -1252,6 +1268,7 @@ def dact_lu(
x: Input tensor that was used in forward pass.
activation_type: Type of activation function that was applied.
quantizer: Optional quantizer for FP8 quantization of the output gradient.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
The gradient of the activation with respect to the input.
......@@ -1262,5 +1279,6 @@ def dact_lu(
activation_type=activation_type,
is_dbias=False,
quantizer=quantizer,
noop_scaled_tensor=noop_scale_tensor,
)
return output
......@@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
calculate dbias separately. This function checks if the workaround should be applied.
"""
if quantizer is None:
return False
arch_l_100 = False
for local_gpu_id in range(len(jax.local_devices())):
if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
arch_l_100 = True
break
# _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
# but this fails when bias fusion is turned on with arch < 100.
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
return (
quantizer is not None
and quantizer.q_layout == QuantizeLayout.ROWWISE
(force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
and arch_l_100
and is_dbias
)
......
......@@ -587,16 +587,17 @@ class NormFwdPrimitive(BasePrimitive):
result_types,
)
prefix = "NormFwdPrimitive_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1
len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1
)
x_axes = scale_rules.input_spec
out = x_axes[:-1] + ("k",)
colwise_out = out if is_2x else ("…4",)
out = x_axes
colwise_out = out if is_2x else (prefix + "out_colwise",)
rsigma = x_axes[:-1]
mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = ("…6",)
mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = (prefix + "amax",)
return SdyShardingRule(
(x_axes, ("…1",), ("…2",), ("…3",)),
......@@ -609,7 +610,6 @@ class NormFwdPrimitive(BasePrimitive):
mu,
rsigma,
),
**scale_rules.factor_sizes,
)
......@@ -1276,6 +1276,7 @@ def normalization_fwd(
epsilon: float,
norm_type: str,
quantizer: Optional[Quantizer],
noop_scaled_tensor: bool = False,
):
"""Common wrapper for normalization forward pass.
......@@ -1292,6 +1293,7 @@ def normalization_fwd(
- 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns:
A tuple containing:
......@@ -1319,6 +1321,15 @@ def normalization_fwd(
else:
raise ValueError(f"{norm_type=} is not supported.")
if quantizer is None and noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype
),
mu,
rsigma,
)
return output, mu, rsigma
......
......@@ -36,7 +36,6 @@ from ..quantize import (
Quantizer,
GroupedQuantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
compute_scale_from_amax,
)
......@@ -489,9 +488,10 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
):
del out_dtype, scale_dtype, is_outer, mesh, result_types
prefix = "BaseDBiasQuantizePrimitive_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape),
unique_var="BaseDBiasQuantizePrimitive_i",
unique_var=prefix + "x",
flatten_axis=flatten_axis,
)
......@@ -499,22 +499,19 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv = scale_rules.colwise_rule
out = x_axes
colwise_out = (prefix + "out_colwise",)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
else:
colwise_out = x_axes
else:
colwise_out = ("j",)
colwise_scale_inv = ("k",)
dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
amax = ("m",)
dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",)
return SdyShardingRule(
(x_axes, ("…1",)),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
)
......@@ -538,11 +535,12 @@ def _jax_quantize(
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
assert flatten_axis < 0
sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
dtype = dtype or dx.dtype
dbias = jnp.sum(
dx.astype(jnp.float32),
axis=tuple(range(dx.ndim + flatten_axis)),
axis=tuple(range(sum_axis)),
keepdims=False,
)
return dbias.astype(dtype)
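A small worked example of the reduction above (shapes are illustrative): with flatten_axis=-1 every leading dimension is summed away, leaving one bias-gradient element per feature.

import jax.numpy as jnp

dx = jnp.ones((4, 16, 32), dtype=jnp.bfloat16)   # (batch, seq, hidden)
# flatten_axis=-1 -> sum_axis = 3 + (-1) = 2 -> reduce over axes (0, 1) in fp32
dbias = _jax_dbias(dx, flatten_axis=-1)
assert dbias.shape == (32,) and dbias.dtype == jnp.bfloat16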
......@@ -568,6 +566,7 @@ def _quantize_dbias_impl(
is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""
Cast wrapper: quantize the input tensor and, when is_dbias is True, also compute the bias gradient.
......@@ -577,24 +576,34 @@ def _quantize_dbias_impl(
quantizer is not None
), "quantizer must be provided if dq_dtype is provided"
# Early-exit for non-quantized call
dq_dtype = dq_dtype or x.dtype
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
if not PrimitiveClass.enabled():
if quantizer is None:
dbias = None
if is_dbias:
return _jax_quantize_dbias(
x,
quantizer=quantizer,
dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
if noop_scaled_tensor:
# Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor()
# always work.
return (
ScaledTensorFactory.create_2x(
x,
None,
x,
None,
ScalingMode.NO_SCALING,
dq_dtype=x.dtype,
data_layout="NN",
flatten_axis=flatten_axis,
),
dbias,
)
return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
return x, dbias
# TE/common doesn't support colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
# If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
# fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
if is_dbias:
return _jax_quantize_dbias(
x,
......@@ -606,9 +615,8 @@ def _quantize_dbias_impl(
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100
# TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_dbias_impl(
x=x,
......@@ -620,29 +628,23 @@ def _quantize_dbias_impl(
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias
if quantizer is None:
if is_dbias:
return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None
scale = jnp.empty((), jnp.float32)
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation, which uses a local per-device amax and scale and
# persists them until the tensor is dequantized (e.g. in the GEMM).
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
if isinstance(quantizer, DelayedScaleQuantizer):
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
# It is faster to use 1x quantization for tensor scaling
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x()
and is_1x_kernel_supported
)
q_layout = quantizer.q_layout
if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE
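For reference, a hedged sketch of what compute_scale_from_amax is assumed to do for current tensor scaling with E4M3: the single global amax is mapped onto the FP8 representable range (the actual helper lives in ..quantize and may handle more cases; 448 is the standard float8_e4m3fn maximum).

import jax.numpy as jnp

E4M3_MAX = 448.0  # finite max of float8_e4m3fn

def compute_scale_from_amax_sketch(amax: jnp.ndarray) -> jnp.ndarray:
    # Map the observed dynamic range onto the FP8 range; guard against amax == 0.
    return jnp.where(amax > 0.0, E4M3_MAX / amax, jnp.ones_like(amax))

amax = jnp.amax(jnp.abs(jnp.array([[-3.0, 7.0], [1.5, -0.25]])), keepdims=True)
scale = compute_scale_from_amax_sketch(amax)  # 448 / 7 = 64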
......@@ -698,6 +700,7 @@ def quantize(
x: jnp.ndarray,
quantizer: Quantizer,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer.
......@@ -707,6 +710,8 @@ def quantize(
quantizer: Quantizer for FP8 quantization of the output.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
is None.
Returns:
A ScaledTensor containing the quantized input tensor.
......@@ -715,6 +720,7 @@ def quantize(
x,
quantizer=quantizer,
flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
)
return out
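A hedged usage sketch of the two modes this wrapper now supports (the quantizer object is a placeholder):

# With a real quantizer: 'q_out' is a quantized ScaledTensor.
q_out = quantize(x, quantizer=my_fp8_quantizer, flatten_axis=-1)

# Without a quantizer but with noop_scaled_tensor=True: 'x' comes back unquantized,
# wrapped in a dummy ScaledTensor2x so .get_rowwise_tensor()/.get_colwise_tensor()
# still work in layers that expect the ScaledTensor interface.
noop_out = quantize(x, quantizer=None, noop_scaled_tensor=True)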
......@@ -724,6 +730,7 @@ def quantize_dbias(
quantizer: Quantizer,
is_dbias: bool = True,
flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient.
......@@ -734,6 +741,8 @@ def quantize_dbias(
is_dbias: If True, compute bias gradient. Defaults to True.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1.
noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
quantizer is None.
Returns:
A tuple containing:
......@@ -743,7 +752,11 @@ def quantize_dbias(
Shape: (K,) or empty if is_dbias is False.
"""
return _quantize_dbias_impl(
dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
dz,
quantizer=quantizer,
is_dbias=is_dbias,
flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
)
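And a matching hedged sketch for the dbias-fused path (gradient tensor and quantizer are placeholders):

# 'casted_grad' is a ScaledTensor2x of the quantized gradient; 'dbias' has shape (K,).
casted_grad, dbias = quantize_dbias(
    dz, quantizer=dgrad_quantizer, is_dbias=True, flatten_axis=-1
)
# With quantizer=None and noop_scaled_tensor=True, the gradient is returned unquantized
# (wrapped as a dummy ScaledTensor2x) while dbias is still reduced in native JAX.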
......
......@@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right);
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
......
......@@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E4M3FN:
return DType::kFloat8E4M3;
break;
// case xla::ffi::DataType::F8E8M0FNU:
// return DType::kFloat8E8M0;
// break;
case xla::ffi::DataType::F8E8M0FNU:
return DType::kFloat8E8M0;
break;
default:
auto type_num = static_cast<XLA_FFI_DataType>(type);
if (type_num == 33) return DType::kFloat8E8M0;
NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
static_cast<int>(type_num));
break;
......
......@@ -6,11 +6,13 @@
#include "transformer_engine/gemm.h"
#include <memory>
#include <string_view>
#include <tuple>
#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/string.h"
#include "common/util/system.h"
#include "transformer_engine/multi_stream.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h"
......@@ -25,6 +27,181 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
~static_cast<uintptr_t>(255));
}
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv,
JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
product(buffer_dims, axis_boundary, buffer_dims.size())};
auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type());
TensorWrapper input(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
} else {
input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
}
// Set scaling factor for quantized tensors
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands.");
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1};
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Block scaling also needs to be collapsed to match 2D data
scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary),
product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())};
}
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
if (rowwise) {
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
// Swizzle scaling factors for MXFP8
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Get the swizzle buffer
NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
"Missing swizzled inverse scale buffer in the JAX primitive.");
auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
auto swizzled_scale_inv_dtype =
convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
}
}
return std::make_tuple(std::move(input), input_shape);
}
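A hedged Python re-statement of the 2D collapse applied above to operand (and MXFP8 scale) shapes, where everything before axis_boundary folds into rows and the rest into columns; this is pure illustration of the shape math, not the FFI code itself:

import math

def collapse_to_2d(shape, axis_boundary):
    # All dims before the boundary fold into rows, the rest into columns.
    rows = math.prod(shape[:axis_boundary])
    cols = math.prod(shape[axis_boundary:])
    return (rows, cols)

# e.g. an (8, 128, 512) activation contracted on its last axis:
assert collapse_to_2d((8, 128, 512), axis_boundary=2) == (1024, 512)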
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
// Operands (this includes swizzling MXFP8 scaling factors)
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or for FP8 GEMM when the
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
// Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, "
"expected ",
out_.numel(), " elements ", to_string_like(out_shape), " but got ",
output->element_count(), " elements ", to_string_like(output->dimensions()));
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
std::vector<size_t> bias_shape = {0};
DType bias_dtype = out_dtype;
if (fuse_bias) {
if (!grad) {
NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad");
}
bias_ptr = bias_grad->untyped_data();
bias_shape.at(0) = bias_grad->dimensions().front();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type());
}
auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
// Pre-GeLU output from forward pass or input to backward pass
void *pre_gelu_ptr = nullptr;
std::vector<size_t> pre_gelu_shape = {0};
DType pre_gelu_dtype = out_dtype;
if (gelu_input.element_count() > 0) {
if (grad) {
NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out");
}
pre_gelu_ptr = pre_gelu_out->untyped_data();
pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1),
static_cast<size_t>(pre_gelu_out->dimensions().back())};
pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type());
}
auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype);
// cuBLAS workspace + 256 alignment enforcement
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
return ffi_with_cuda_error_check();
}
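A hedged restatement of the output-shape logic used above, given the collapsed 2D operand shapes and the transpose flags (illustrative only):

def gemm_out_shape(lhs_2d, rhs_2d, lhs_transposed, rhs_transposed):
    # Rows come from LHS (its columns if transposed), columns from RHS.
    m = lhs_2d[1] if lhs_transposed else lhs_2d[0]
    n = rhs_2d[0] if rhs_transposed else rhs_2d[1]
    return (m, n)

assert gemm_out_shape((1024, 512), (512, 256), False, False) == (1024, 256)

The kernel launch itself then passes RHS as the first operand and LHS as the second because cuBLAS expects column-major layouts, as noted in the comment above the nvte_cublas_gemm call.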
XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs
.Arg<Buffer_Type>() // lhs_scale_inv
.Arg<Buffer_Type>() // rhs
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // lhs_swizzled
.Ret<Buffer_Type>() // rhs_swizzled
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
.Attr<int64_t>("rhs_axis_boundary")
.Attr<bool>("lhs_transposed")
.Attr<bool>("rhs_transposed")
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator"),
FFI_CudaGraph_Traits);
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
......
......@@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t {
CURRENT_TENSOR_SCALING = 3,
};
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING ||
mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING);
}
inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
}
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
switch (mode) {
case JAXX_Scaling_Mode::NO_SCALING:
......
......@@ -55,6 +55,11 @@ pybind11::dict Registrations() {
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));
// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM
dict["te_grouped_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
......@@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......
......@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
import warnings
from typing import Tuple, Sequence
from functools import partial
import jax
......@@ -23,6 +23,16 @@ from .quantize import (
)
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
def dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
......@@ -30,6 +40,7 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
......@@ -43,25 +54,28 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
# Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set:
if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
x = with_sharding_constraint_by_logical_axes(x, input_axes)
output = tex.gemm(x, kernel, contracting_dims)
output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set)
output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
)
return output
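A hedged usage sketch of the updated dense() signature (shapes and dtypes are illustrative):

import jax.numpy as jnp

x = jnp.zeros((4, 128, 256), dtype=jnp.bfloat16)      # (batch, seq, hidden_in)
kernel = jnp.zeros((256, 512), dtype=jnp.bfloat16)    # (hidden_in, hidden_out)
bias = jnp.zeros((512,), dtype=jnp.bfloat16)

out = dense(
    x, kernel, bias,
    contracting_dims=((2,), (0,)),   # contract hidden_in on both operands
    batch_first=True,                # x carries its batch in dim 0
)
# out.shape == (4, 128, 512)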
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
......@@ -75,44 +89,81 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
)
return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set):
def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
):
"""Forward pass rule for dense layer transformation.
Returns:
Tuple of (output, context) for backward pass
"""
x_contracting_dims, k_contracting_dims = contracting_dims
x_contracting_dims, k_contracting_dims = map(
tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
)
# Check supported input layout
x_is_transposed = x.ndim - 1 not in x_contracting_dims
k_is_transposed = kernel.ndim - 1 in k_contracting_dims
assert (
not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least two more dimensions than the number of contracting
# dimensions (see the shape sketch at the end of this hunk).
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x)
casted_x = tex.quantize(
x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel
kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
noop_scaled_tensor=True,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN
use_bias = bias is not None
output = tex.gemm(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
use_bias = bias is not None
if use_bias:
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
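A hedged shape example for the batch-dimension inference in this forward rule (illustrative sizes; a single batch dimension is assumed, as in the comment above):

# x: (batch, seq, hidden_in) with one contracting dim -> ndim 3 >= 1 + 2, so batched.
x_shape = (4, 128, 256)
x_contracting_dims = (2,)
num_cdims = len(x_contracting_dims)

is_batched = len(x_shape) >= num_cdims + 2        # True
x_bdim_batch_first = 0                            # batch_first=True
x_bdim_seq_first = len(x_shape) - num_cdims - 1   # 1, i.e. (seq, batch, hidden_in)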
......@@ -124,20 +175,19 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes,
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
)
return output, ctx
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad
contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad
): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
Returns:
Tuple of gradients with respect to inputs
"""
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
(
casted_x_lhs,
casted_kernel_rhs,
......@@ -146,10 +196,19 @@ def _dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
)
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad
grad,
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# GEMM NT
......@@ -164,7 +223,8 @@ def _dense_bwd_rule(
dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
(g_contracting_dim, k_contracting_dim),
contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
......@@ -177,7 +237,8 @@ def _dense_bwd_rule(
wgrad = tex.gemm(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dim, g_contracting_dim),
contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
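A hedged einsum restatement of the dgrad/wgrad GEMMs above, just to make the contracting and batched dimensions concrete (shapes illustrative; the real path goes through tex.gemm):

import jax.numpy as jnp

B, S, H_in, H_out = 2, 8, 16, 32
x = jnp.zeros((B, S, H_in))
grad = jnp.zeros((B, S, H_out))
kernel = jnp.zeros((H_in, H_out))

dgrad = jnp.einsum("bso,io->bsi", grad, kernel)  # grad x kernel^T, batched over b
wgrad = jnp.einsum("bsi,bso->io", x, grad)       # contracts batch and sequence dims
assert dgrad.shape == x.shape and wgrad.shape == kernel.shape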
......
......@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
import numpy as np
import jax.numpy as jnp
......@@ -15,12 +15,12 @@ from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from ..dense import dense
from ..dense import dense, _issue_batch_first_warning as _dense_warning
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense
from ..layernorm_mlp import layernorm_mlp
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
......@@ -35,8 +35,8 @@ from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = jnp.ndarray
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
......@@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase):
input_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
......@@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float = None
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0,
......@@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
......
......@@ -9,6 +9,7 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""
import warnings
from functools import partial
from typing import Tuple
......@@ -25,6 +26,16 @@ from .quantize import (
)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
......@@ -37,6 +48,7 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
......@@ -57,6 +69,7 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types
Returns:
......@@ -80,6 +93,7 @@ def layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
......@@ -94,6 +108,7 @@ def layernorm_dense(
8,
9,
10,
11,
),
)
def _layernorm_dense(
......@@ -108,6 +123,7 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
......@@ -127,6 +143,7 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers
Returns:
......@@ -144,6 +161,7 @@ def _layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
......@@ -161,6 +179,7 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
......@@ -178,6 +197,17 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
......@@ -187,25 +217,31 @@ def _layernorm_dense_fwd_rule(
zero_centered_gamma,
epsilon,
norm_type,
quantizer_set.x,
quantizer=quantizer_set.x,
noop_scaled_tensor=True,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...)
flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
use_bias = bias is not None
output = tex.gemm(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
use_bias = bias is not None
if use_bias:
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
......@@ -224,6 +260,7 @@ def _layernorm_dense_fwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
)
return output, ctx
......@@ -236,6 +273,7 @@ def _layernorm_dense_bwd_rule(
layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument
kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx,
grad,
):
......@@ -265,10 +303,15 @@ def _layernorm_dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
) = ctx
casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad
grad,
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
......@@ -284,7 +327,8 @@ def _layernorm_dense_bwd_rule(
dgrad = tex.gemm(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
(g_constracting_dim, k_constracting_dim),
contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
......@@ -297,7 +341,8 @@ def _layernorm_dense_bwd_rule(
wgrad = tex.gemm(
casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS),
(x_constracting_dim, g_constracting_dim),
contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......