Unverified Commit 214e2a4a authored by Alp Dener, committed by GitHub

[JAX] GEMM custom op (#1855)



* added XLA FFI custom op for TE/common nvte_cublas_gemm
Signed-off-by: Alp Dener <adener@nvidia.com>

started GemmPrimitive, abstract done
Signed-off-by: Alp Dener <adener@nvidia.com>

gemm custom op working with BF16, needs testing for FP8/MXFP8
Signed-off-by: Alp Dener <adener@nvidia.com>

converted TE GEMM API to use ScaledTensor and added an env flag to use TE GEMM under the general gemm() call
Signed-off-by: Alp Dener <adener@nvidia.com>
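The env flag mentioned here surfaces later in this diff as the `NVTE_JAX_CUSTOM_CALLS_RE` regex filter. A minimal sketch of how such a regex gate can decide whether a named custom call is enabled — the `primitive_enabled` helper below is hypothetical, standing in for TE's actual `GemmPrimitive.enabled()` check:

```python
import os
import re

def primitive_enabled(name, env=os.environ):
    """A primitive is enabled only if its name fully matches the filter regex."""
    pattern = env.get("NVTE_JAX_CUSTOM_CALLS_RE", ".*")
    return re.fullmatch(pattern, name) is not None

# Disable only GemmPrimitive (fall back to native JAX for GEMM) while
# keeping every other custom call active:
env = {"NVTE_JAX_CUSTOM_CALLS_RE": r"^(?!GemmPrimitive$).+$"}
print(primitive_enabled("GemmPrimitive", env))          # False
print(primitive_enabled("DBiasQuantizePrimitive", env))  # True
```

The negative lookahead `(?!GemmPrimitive$)` is the pattern the unit tests in this PR use to exercise the native-JAX code path.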

BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue
Signed-off-by: Alp Dener <adener@nvidia.com>

fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2
Signed-off-by: Alp Dener <adener@nvidia.com>

updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet
Signed-off-by: Alp Dener <adener@nvidia.com>

new GemmPrimitive passing all Dense tests
Signed-off-by: Alp Dener <adener@nvidia.com>

import cleanup and reverted code chunk movement
Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused .transpose() implementations from ScaledTensors
Signed-off-by: Alp Dener <adener@nvidia.com>

all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and the native JAX impl
Signed-off-by: Alp Dener <adener@nvidia.com>

removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions
Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused changes to ScaledTensor classes and debug prints
Signed-off-by: Alp Dener <adener@nvidia.com>

* minor unit test cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* FP8 tests passing on Blackwell but MXFP8 outputs NaN
Signed-off-by: Alp Dener <adener@nvidia.com>

* reverted dense and fuseddense changes, FP8 tests passing on Hopper and Blackwell, MXFP8 has issues with E5M2
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 issue traced to scale factor padding with NaNs instead of zeros
Signed-off-by: Alp Dener <adener@nvidia.com>

* padding scale with 2^-127 instead of nans
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
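Why 2^-127 specifically: MXFP8 block scales use the E8M0 format, which stores only an 8-bit biased exponent (bias 127, per the OCP Microscaling spec) and reserves the all-ones byte 0xFF for NaN. A zero byte therefore decodes to 2^-127, the smallest finite scale, so zero-byte padding is harmless while 0xFF padding poisons the GEMM output. A small sketch of the decoding:

```python
import math

E8M0_BIAS = 127  # exponent bias per the OCP Microscaling (MX) spec

def e8m0_decode(byte: int) -> float:
    """Decode one E8M0 scale byte: a biased power-of-two exponent, 0xFF = NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - E8M0_BIAS)

print(e8m0_decode(0x00))               # 2**-127: the padding value adopted here
print(e8m0_decode(127))                # 1.0
print(math.isnan(e8m0_decode(0xFF)))   # True: NaN padding would propagate
```

Padded scale entries multiply regions of the tensor that are discarded anyway, so any finite value works; 2^-127 is simply the one a zeroed buffer gives for free.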

* fix bug on rhs_scale_inv usage
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* cleanup E8M0 type converter and use it in gemm.cpp
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* segfault fixed, passing all unit tests on Blackwell
Signed-off-by: Alp Dener <adener@nvidia.com>

* fix for fuseddense tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix workspace alignment
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul
Signed-off-by: Alp Dener <adener@nvidia.com>

all unit tests passing on H100x8 node
Signed-off-by: Alp Dener <adener@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

linting fixes
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed batch dimension numbers
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed FP8 scale sharding rule when there are no FP8 scales
Signed-off-by: Alp Dener <adener@nvidia.com>

added error message for unsupported Shardy partitioner
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed test tolerances for FP8 cases
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed shardy test skip cases
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* moved reshape of encoder output in encoder examples to make custom partitioning rules work correctly
Signed-off-by: Alp Dener <adener@nvidia.com>

* added helper functions for padding and unpadding block scales, changed GemmPrimitive to accept unpadded scales and pad them after sharding
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated shardy rules for all custom ops to decouple block scale rules from their tensors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* changed the unit test use_jax_gemm option into a context manager to preserve external custom op settings; tightened multi-GPU encoder test tolerances; changed the gemm() API to take contracting_dims and batched_dims separately instead of dimension_numbers
Signed-off-by: Alp Dener <adener@nvidia.com>
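A context manager preserves whatever filter the caller had already exported, unlike a set-and-forget helper. The sketch below is a simplified stand-in for the `use_jax_gemm` test utility; the variable name and regex come from the test changes in this diff:

```python
import os
from contextlib import contextmanager

@contextmanager
def use_jax_gemm(enabled=False):
    """Temporarily route gemm() through native JAX by filtering out
    GemmPrimitive, restoring the caller's prior setting on exit."""
    var = "NVTE_JAX_CUSTOM_CALLS_RE"
    saved = os.environ.get(var)
    try:
        if enabled:
            os.environ[var] = r"^(?!GemmPrimitive$).+$"
        yield
    finally:
        # Restore whatever the caller had, including "unset".
        if saved is None:
            os.environ.pop(var, None)
        else:
            os.environ[var] = saved
```

Exiting the context restores any externally configured filter, which is the "preserve external custom op settings" behavior described above.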

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed typo in test utils
Signed-off-by: Alp Dener <adener@nvidia.com>

* added sequence-first input warnings
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed datasets version for JAX examples
Signed-off-by: Alp Dener <adener@nvidia.com>

* reverting modification to force_1x_quantization decision
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected gemm function syntax in unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 397c4be6
@@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
     LOG_FILE="${TEST_CASE}_gpu_${i}.log"
     # Run pytest and redirect stdout and stderr to the log file
-    pytest -c "$TE_PATH/tests/jax/pytest.ini" \
+    pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
         -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
         --num-process=$NUM_GPUS \
         --process-id=$i > "$LOG_FILE" 2>&1 &
@@ -25,6 +25,7 @@ from common import (
     assert_params_sufficiently_sharded,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
@@ -465,14 +466,14 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "5"])

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
@@ -480,7 +481,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
@@ -488,14 +489,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_with_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp(self):
@@ -504,7 +505,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_with_sp(self):
@@ -513,14 +514,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
@@ -529,7 +530,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp_shardy(self):
@@ -539,9 +540,32 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.43 and actual[1] > 0.80
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_with_sp_shardy(self):
+        """Test Transformer Engine with MXFP8 + SP"""
+        self.args.enable_shardy = True
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.43 and actual[1] > 0.80

 if __name__ == "__main__":
@@ -21,6 +21,7 @@ from jax.sharding import PartitionSpec, NamedSharding
 from common import is_bf16_supported, get_fp8_recipe_from_name_string
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
@@ -430,14 +431,14 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "6"])

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
@@ -445,7 +446,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8(self):
@@ -453,7 +454,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
@@ -461,14 +462,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
@@ -477,9 +478,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
-
-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8_shardy(self):
@@ -488,7 +487,19 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.50 and actual[1] > 0.75

 if __name__ == "__main__":
@@ -28,8 +28,8 @@ from common import (
     get_fp8_recipe_from_name_string,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
-from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -584,8 +584,8 @@ class TestEncoder(unittest.TestCase):
     """Encoder unittests"""

     def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
-        """Run 3 epochs for testing"""
-        args = encoder_parser([])
+        """Run 5 epochs for testing"""
+        args = encoder_parser(["--epochs", "5"])
         num_gpu = self.num_process
         tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
@@ -607,7 +607,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         result = self.exec(False, None)
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
@@ -615,7 +615,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling")
-        assert result[0] < 0.506 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
@@ -623,7 +623,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_current_scaling_fp8(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling")
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
@@ -631,13 +631,13 @@ class TestEncoder(unittest.TestCase):
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
         result = self.exec(True, "MXFP8BlockScaling")
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         result = self.exec(False, None, enable_shardy=True)
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
@@ -645,9 +645,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8_shardy(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling", enable_shardy=True)
-        assert result[0] < 0.506 and result[1] > 0.753
-
-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
@@ -655,7 +653,18 @@ class TestEncoder(unittest.TestCase):
     def test_te_current_scaling_fp8_shardy(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80
+
+    @unittest.skipIf(
+        not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
+    )
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
+        assert result[0] < 0.43 and result[1] > 0.80

 if __name__ == "__main__":
...@@ -13,6 +13,7 @@ import operator ...@@ -13,6 +13,7 @@ import operator
from utils import ( from utils import (
assert_allclose, assert_allclose,
pytest_parametrize_wrapper, pytest_parametrize_wrapper,
use_jax_gemm,
) )
from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp from transformer_engine.jax.layernorm_mlp import layernorm_mlp
...@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import ( ...@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import (
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import ( from transformer_engine.jax.quantize import (
DelayedScaleQuantizer,
ScaledTensor, ScaledTensor,
ScaledTensor1x, ScaledTensor1x,
ScaledTensor2x, ScaledTensor2x,
...@@ -851,6 +851,22 @@ class TestFusedQuantize: ...@@ -851,6 +851,22 @@ class TestFusedQuantize:
) )
valid_fp8_gemm_operand_types = [
(jnp.float8_e4m3fn, jnp.float8_e4m3fn),
(jnp.float8_e5m2, jnp.float8_e4m3fn),
(jnp.float8_e4m3fn, jnp.float8_e5m2),
]
def _use_jax_fp8_gemm(enabled=False):
import os
if enabled:
os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ:
os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
class TestDense: class TestDense:
def _ref_gemm_with_jnp_dot(self, a, b, data_layout): def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T": if data_layout[0] == "T":
...@@ -883,27 +899,47 @@ class TestDense: ...@@ -883,27 +899,47 @@ class TestDense:
def test_gemm_bf16(self, m, n, k, data_layout): def test_gemm_bf16(self, m, n, k, data_layout):
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
primitive_out = tex.gemm(x, w, contracting_dims) primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout): @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
if (
not with_jax_gemm
and scaling_mode.is_1d_block_scaling()
and jnp.float8_e5m2 in (x_qtype, w_qtype)
):
pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2,
            is_2x2x=False,
        )
        with use_jax_gemm(enabled=with_jax_gemm):
            primitive_out = tex.gemm(
                x,
                w,
                contracting_dims=contracting_dims,
                lhs_quantizer=(
                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
                rhs_quantizer=(
                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
                ),
            )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)

    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):

@@ -932,9 +968,9 @@ class TestDense:

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
        data_layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

@@ -956,10 +992,14 @@ class TestDense:
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                    value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)

@@ -969,10 +1009,10 @@ class TestDense:
                    x, w, bias, data_layout
                )

        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)

    @pytest.fixture(name="random_inputs")

@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
        """
        Test layernorm_dense VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

@@ -1025,8 +1058,8 @@ class TestFusedDense:

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )
@@ -1064,6 +1097,7 @@ class TestFusedDense:
        )

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,

@@ -1072,33 +1106,26 @@ class TestFusedDense:
                    prim_beta_grad,
                ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6
@@ -1123,8 +1150,8 @@ class TestFusedDense:

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            scaling_mode=scaling_mode,
            fwd_dtype=jnp.float8_e4m3fn,
            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
            is_2x2x=True,
        )

@@ -1153,14 +1180,13 @@ class TestFusedDense:
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)
            linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)
@@ -1174,6 +1200,7 @@ class TestFusedDense:
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        with use_jax_gemm(enabled=with_jax_gemm):
            for _ in range(n_iterations):
                prim_out, (
                    prim_x_grad,

@@ -1193,18 +1220,18 @@ class TestFusedDense:
                    ref_bias_2_grad,
                ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
        if use_bias:
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)

        # E5M2 * E5M2 is not supported
...
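The reference function above swaps `tex.gemm` for `jax.lax.dot_general` with `dimension_numbers = (((1,), (0,)), ((), ()))`, i.e. contract axis 1 of the LHS with axis 0 of the RHS, with no batch dimensions. A minimal sketch of that same contraction, using NumPy's `tensordot` as a dependency-free stand-in for `dot_general`:

```python
import numpy as np

# Contract axis 1 of x with axis 0 of w, no batch axes -- the same
# contraction jax.lax.dot_general performs with (((1,), (0,)), ((), ())).
x = np.arange(6.0).reshape(2, 3)
w = np.arange(12.0).reshape(3, 4)
out = np.tensordot(x, w, axes=((1,), (0,)))

# For this 2D case the contraction is identical to a plain matrix multiply.
assert out.shape == (2, 4)
assert np.allclose(out, x @ w)
```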
@@ -13,6 +13,7 @@ from utils import (
    assert_tree_like_allclose,
    is_devices_enough,
    pytest_parametrize_wrapper,
    use_jax_gemm,
)

from transformer_engine.common import recipe

@@ -147,7 +148,15 @@ class TestDistributedLayernormMLP:
    )

    def _test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
@@ -157,6 +166,8 @@ class TestDistributedLayernormMLP:
            input_shape, activation_type, use_bias, dtype
        )
        static_inputs = [layernorm_type, activation_type]
        with use_jax_gemm(enabled=with_jax_gemm):
            value_and_grad_func = jax.value_and_grad(
                self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
            )

@@ -172,7 +183,9 @@ class TestDistributedLayernormMLP:
            # Multi GPUs
            devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
            mesh = Mesh(devices, mesh_axes)
            with mesh, fp8_autocast(
                enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
            ):
                k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
                k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
                k1_ = jax.device_put(k1, k1_sharding)
@@ -204,25 +217,32 @@ class TestDistributedLayernormMLP:
                    value_and_grad_func,
                    in_shardings=in_shardings,
                    out_shardings=out_shardings,
                    static_argnums=range(
                        len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
                    ),
                )  # +1 for multi_gpus
                multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

        fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
        bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
        assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
        for i in range(len(inputs)):
            if multi_grads[i] is not None:
                if isinstance(multi_grads[i], list):
                    assert isinstance(single_grads[i], list)
                    for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                        assert_allclose(
                            m_grad,
                            s_grad,
                            dtype=bwd_test_type,
                            err_msg=f"multi_grads[{i}] is not close",
                        )
                else:
                    assert_allclose(
                        multi_grads[i],
                        single_grads[i],
                        dtype=bwd_test_type,
                        err_msg=f"multi_grads[{i}] is not close",
                    )
@@ -233,8 +253,16 @@ class TestDistributedLayernormMLP:
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        with_jax_gemm,
    ):
        self._test_layernorm_mlp_grad(
            mesh_config,

@@ -244,6 +272,7 @@ class TestDistributedLayernormMLP:
            dtype,
            fp8_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@@ -252,19 +281,29 @@ class TestDistributedLayernormMLP:
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_grad_shardy(
        self,
        mesh_config,
        activation_type,
        use_bias,
        input_shape,
        dtype,
        fp8_recipe,
        with_jax_gemm,
    ):
        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
        self._test_layernorm_mlp_grad(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            fp8_recipe=fp8_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    def _test_layernorm_mlp(
@@ -277,6 +316,7 @@ class TestDistributedLayernormMLP:
        use_fp8,
        fp8_recipe,
        use_shardy,
        with_jax_gemm,
    ):
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        batch, seqlen, hidden_in = input_shape

@@ -288,6 +328,7 @@ class TestDistributedLayernormMLP:
        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        init_rngs = {"params": subkeys[1]}

        with use_jax_gemm(enabled=with_jax_gemm):
            # Single GPUs
            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                ln_mlp_single = LayerNormMLP(
@@ -355,9 +396,9 @@ class TestDistributedLayernormMLP:
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,

@@ -367,7 +408,8 @@ class TestDistributedLayernormMLP:
            dtype,
            use_fp8=False,
            fp8_recipe=None,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@@ -377,8 +419,9 @@ class TestDistributedLayernormMLP:
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,

@@ -389,4 +432,51 @@ class TestDistributedLayernormMLP:
            use_fp8=True,
            fp8_recipe=fp8_recipe,
            use_shardy=False,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
    ):
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=False,
            fp8_recipe=None,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest_parametrize_wrapper("use_bias", [True, False])
    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
    def test_layernorm_mlp_layer_fp8_shardy(
        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
    ):
        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
        self._test_layernorm_mlp(
            mesh_config,
            activation_type,
            use_bias,
            input_shape,
            dtype,
            use_fp8=True,
            fp8_recipe=fp8_recipe,
            use_shardy=True,
            with_jax_gemm=with_jax_gemm,
        )
@@ -3,11 +3,12 @@
# See LICENSE for license information.
"""Utility for the TE layer tests"""
import os
import functools
import math
import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
from contextlib import contextmanager

import jax
import jax.numpy as jnp

@@ -20,7 +21,6 @@ from jax import random as jax_random
import pytest

from transformer_engine.jax.attention import (
    canonicalize_attn_mask_type,
    make_swa_mask,
)

@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType
PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
    None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
@@ -1519,7 +1519,7 @@ def dtype_tols(
            TEDType.kFloat8E5M2: jnp.float8_e5m2,
        }[dtype]
    elif isinstance(dtype, np.dtype):
        dtype = DType(dtype)

    # Expect bit-wise accuracy for integer dtypes
    if not jnp.issubdtype(dtype, jnp.floating):

@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
            fmt = fmt + "\n    {}\n    {}"
        jax.debug.print(fmt, *args)


@contextmanager
def use_jax_gemm(enabled=False):
    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)

    try:
        if enabled:
            os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
        yield

    finally:
        if enabled:
            if orig_custom_calls_filter is None:
                os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
            else:
                os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
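The `NVTE_JAX_CUSTOM_CALLS_RE` value installed above, `^(?!GemmPrimitive$).+$`, is a negative lookahead that matches every primitive name except exactly `GemmPrimitive`, so only the TE GEMM custom call is disabled while all other custom calls stay active. A small sketch of how that filter behaves (the non-GEMM primitive names here are just illustrative):

```python
import re

# The filter use_jax_gemm() installs: match any name that is not
# exactly "GemmPrimitive".
pattern = re.compile(r"^(?!GemmPrimitive$).+$")

names = ["GemmPrimitive", "ActLuPrimitive", "DActLuQuantizePrimitive"]
enabled = [name for name in names if pattern.match(name)]
# "GemmPrimitive" is filtered out; everything else still matches.
```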
@@ -415,37 +415,35 @@ class ActLuPrimitive(BasePrimitive):
            result_types,
        ):
            del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types

            prefix = "ActLuPrimitive_"

            x_rank = len(value_types[0].shape)
            scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
                x_rank - 1, unique_var=prefix + "x", flatten_axis=-2
            )
            x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",)
            out = (*x_axes[:-2], x_axes[-1])
            scale_inv = scale_rules.rowwise_rule

            colwise_out = (prefix + "out_colwise",)
            colwise_scale_inv = (prefix + "scale_inv_colwise",)
            if is_2x:
                colwise_scale_inv = scale_rules.colwise_rule
                if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                    colwise_out = tuple(
                        multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                    )
                else:
                    colwise_out = out

            # amax is always a unit tensor.
            amax = (prefix + "amax",)

            return SdyShardingRule(
                (
                    x_axes,
                    ("…1",),
                ),
                (out, colwise_out, scale_inv, colwise_scale_inv, amax),
                **scale_rules.factor_sizes,
            )
@@ -890,28 +888,26 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
            result_types,
        ):
            del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types

            prefix = "BaseDActLuDBiasQuantizePrimitive_"

            scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
                len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2
            )
            x_axes = scale_rules.input_spec
            dz_axes = (*x_axes[:-2], x_axes[-1])
            out = x_axes

            colwise_out = (prefix + "out_colwise",)
            if is_2x:
                if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                    colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
                else:
                    colwise_out = out

            dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
            amax = (prefix + "amax",)

            return SdyShardingRule(
                (dz_axes, x_axes, ("…2",)),
                (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
                **scale_rules.factor_sizes,
            )
@@ -985,6 +981,7 @@ def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

@@ -993,6 +990,7 @@ def act_lu(
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        If quantizer is None:

@@ -1037,6 +1035,10 @@ def act_lu(
            is_outer=True,
        )
        out = out.reshape(output_shape)
        if noop_scaled_tensor:
            return ScaledTensorFactory.create_2x(
                out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype
            )
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
@@ -1090,6 +1092,7 @@ def quantize_dact_dbias(
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

@@ -1100,6 +1103,7 @@ def quantize_dact_dbias(
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.
        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
@@ -1113,13 +1117,49 @@ def quantize_dact_dbias(
            f" {x.shape} and act_len {act_len}"
        )

    scale = jnp.empty((), jnp.float32)
    act_type_id = ActivationEnum[activation_type]

    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
    if not PrimitiveClass.enabled() or (
        quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
    ):
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

    if quantizer is None:
        output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
            dz,
            x,
            scale,
            # outputs float32 for dbias accumulation
            out_dtype=(jnp.float32 if is_dbias else x.dtype),
            # default value for no scaling, TE/common ignore this value when scale is unset
            scaling_mode=ScalingMode.NO_SCALING.value,
            is_2x=False,  # unused
            scale_dtype=jnp.float32,  # unused
            is_dbias=False,
            act_enum=act_type_id,
            act_len=act_len,
            is_outer=True,
        )
        output = output.astype(x.dtype)
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)

        if noop_scaled_tensor:
            return (
                ScaledTensorFactory.create_2x(
                    output,
                    None,
                    output,
                    None,
                    ScalingMode.NO_SCALING,
                    dq_dtype=output.dtype,
                ),
                dbias,
            )

        return output, dbias

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
...@@ -1145,31 +1185,6 @@ def quantize_dact_dbias( ...@@ -1145,31 +1185,6 @@ def quantize_dact_dbias(
if war_output is not None: if war_output is not None:
return war_output return war_output
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
        out = dact_lu(
@@ -1183,7 +1198,7 @@ def quantize_dact_dbias(
        )
        return out, dbias

    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet
@@ -1243,6 +1258,7 @@ def dact_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.
@@ -1252,6 +1268,7 @@ def dact_lu(
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.
        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        The gradient of the activation with respect to the input.
@@ -1262,5 +1279,6 @@ def dact_lu(
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
        noop_scaled_tensor=noop_scaled_tensor,
    )
    return output
@@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
    Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
    calculate dbias separately. This function checks if the workaround should be applied.
    """
    if quantizer is None:
        return False

    arch_l_100 = False
    for local_gpu_id in range(len(jax.local_devices())):
        if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
            arch_l_100 = True
            break

    # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
    # but this fails when bias fusion is turned on with arch < 100.
    force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
    return (
        (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
        and arch_l_100
        and is_dbias
    )
...
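The `noop_scaled_tensor` flag added across these wrappers exists so that, when no quantizer is supplied, the unquantized output is still wrapped in a 2x scaled-tensor container and downstream code can call `.get_rowwise_tensor()` / `.get_colwise_tensor()` unconditionally. A hypothetical stand-in (not TE's actual `ScaledTensorFactory`/`ScaledTensor2x` API) illustrating the idea:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class NoopScaledTensor2x:
    """Illustrative stand-in for a NO_SCALING ScaledTensor2x wrapper."""

    rowwise_data: Any
    colwise_data: Any
    scale_inv: Optional[Any] = None  # no scaling factors in NO_SCALING mode

    def get_rowwise_tensor(self):
        return self.rowwise_data

    def get_colwise_tensor(self):
        return self.colwise_data


# Both views expose the same unquantized data, so callers need no None checks
x = [1.0, 2.0, 3.0]
wrapped = NoopScaledTensor2x(rowwise_data=x, colwise_data=x)
assert wrapped.get_rowwise_tensor() is wrapped.get_colwise_tensor()
```

This mirrors the `ScaledTensorFactory.create_2x(output, None, output, None, ScalingMode.NO_SCALING, ...)` calls above, where the same array is passed as both the rowwise and colwise payload with no inverse scales.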
@@ -587,16 +587,17 @@ class NormFwdPrimitive(BasePrimitive):
            result_types,
        )

        prefix = "NormFwdPrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1
        )
        x_axes = scale_rules.input_spec
        out = x_axes
        colwise_out = out if is_2x else (prefix + "out_colwise",)
        rsigma = x_axes[:-1]
        mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (x_axes, ("…1",), ("…2",), ("…3",)),
@@ -609,7 +610,6 @@ class NormFwdPrimitive(BasePrimitive):
                mu,
                rsigma,
            ),
        )
@@ -1276,6 +1276,7 @@ def normalization_fwd(
    epsilon: float,
    norm_type: str,
    quantizer: Optional[Quantizer],
    noop_scaled_tensor: bool = False,
):
    """Common wrapper for normalization forward pass.
@@ -1292,6 +1293,7 @@ def normalization_fwd(
            - 'layernorm': Layer normalization
            - 'rmsnorm': Root mean square normalization
        quantizer: Optional quantizer for FP8 quantization of the output.
        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        A tuple containing:
@@ -1319,6 +1321,15 @@ def normalization_fwd(
    else:
        raise ValueError(f"{norm_type=} is not supported.")

    if quantizer is None and noop_scaled_tensor:
        return (
            ScaledTensorFactory.create_2x(
                output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype
            ),
            mu,
            rsigma,
        )

    return output, mu, rsigma
...
@@ -36,7 +36,6 @@ from ..quantize import (
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    ScalingMode,
    compute_scale_from_amax,
)
@@ -489,9 +488,10 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
        ):
            del out_dtype, scale_dtype, is_outer, mesh, result_types

            prefix = "BaseDBiasQuantizePrimitive_"
            scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
                len(value_types[0].shape),
                unique_var=prefix + "x",
                flatten_axis=flatten_axis,
            )
@@ -499,22 +499,19 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
            colwise_scale_inv = scale_rules.colwise_rule

            out = x_axes
            colwise_out = (prefix + "out_colwise",)
            if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
                if ScalingMode(scaling_mode).is_tensor_scaling():
                    colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
                else:
                    colwise_out = x_axes

            dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
            amax = (prefix + "amax",)

            return SdyShardingRule(
                (x_axes, ("…1",)),
                (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            )
@@ -538,11 +535,12 @@ def _jax_quantize(

def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
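The change to `_jax_dbias` replaces the old assertion-only handling of negative axes with an explicit normalization, so negative and non-negative `flatten_axis` values select the same leading axes for the reduction. A minimal sketch of that index arithmetic in plain Python (independent of JAX):

```python
def dbias_sum_axes(ndim: int, flatten_axis: int = -1) -> tuple:
    """Return the leading axes summed over when reducing dx to a bias gradient."""
    # Normalize a negative flatten_axis the same way the patched _jax_dbias does
    sum_axis = ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < ndim, "Flatten axis out of bounds!"
    return tuple(range(sum_axis))


# A (B, S, H) gradient with flatten_axis=-1 reduces over the batch and sequence axes
print(dbias_sum_axes(3, -1))  # -> (0, 1)
# The same reduction expressed with a non-negative axis
print(dbias_sum_axes(3, 2))   # -> (0, 1)
```

With the old `assert flatten_axis < 0`, a caller passing a non-negative axis would trip the assertion even though the reduction itself is well defined; the normalized form accepts both conventions.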
@@ -568,6 +566,7 @@ def _quantize_dbias_impl(
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
@@ -577,24 +576,34 @@ def _quantize_dbias_impl(
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.dtype
    if quantizer is None:
        dbias = None
        if is_dbias:
            dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)

        if noop_scaled_tensor:
            # Return a dummy ScaledTensor2x so that .get_rowwise_tensor() and .get_colwise_tensor()
            # always work.
            return (
                ScaledTensorFactory.create_2x(
                    x,
                    None,
                    x,
                    None,
                    ScalingMode.NO_SCALING,
                    dq_dtype=x.dtype,
                    data_layout="NN",
                    flatten_axis=flatten_axis,
                ),
                dbias,
            )

        return x, dbias

    # If the TE/common custom quantize op is disabled, or the quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation.
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
@@ -606,9 +615,8 @@ def _quantize_dbias_impl(
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )

    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
@@ -620,29 +628,23 @@ def _quantize_dbias_impl(
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    scale = jnp.empty((), jnp.float32)
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a single global scale.
        # This differs from the PyTorch implementation, which uses a local amax and scale per device and
        # persists them until the tensor is dequantized (e.g. in the GEMM).
        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )

    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE
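For current tensor scaling, the scale is derived from the globally reduced amax. A common formulation, assumed here for illustration (TE's actual `compute_scale_from_amax` may differ in details such as clamping or margin handling), divides the FP8 dtype's representable maximum by the amax:

```python
# Representable maxima of the two common FP8 formats (E4M3 and E5M2)
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}


def scale_from_amax(amax: float, q_dtype: str, eps: float = 1e-12) -> float:
    """Scale that maps the observed absolute maximum onto the FP8 range.

    Illustrative sketch only; not TE's compute_scale_from_amax.
    """
    return FP8_MAX[q_dtype] / max(amax, eps)


# A tensor whose largest magnitude is 2.0 gets scaled by 224 before E4M3 rounding
print(scale_from_amax(2.0, "e4m3"))  # -> 224.0
```

Quantization then stores `x * scale` in FP8 and keeps `1 / scale` as the inverse scale used at dequantization time (e.g. inside the GEMM).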
@@ -698,6 +700,7 @@ def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.
@@ -707,6 +710,8 @@ def quantize(
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.
        noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
            is None.

    Returns:
        A ScaledTensor containing the quantized input tensor.
@@ -715,6 +720,7 @@ def quantize(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
        noop_scaled_tensor=noop_scaled_tensor,
    )
    return out
@@ -724,6 +730,7 @@ def quantize_dbias(
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.
@@ -734,6 +741,8 @@ def quantize_dbias(
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.
        noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
            quantizer is None.

    Returns:
        A tuple containing:
@@ -743,7 +752,11 @@ def quantize_dbias(
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
        noop_scaled_tensor=noop_scaled_tensor,
    )
...
@@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);

// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);

// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
...
@@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
    case xla::ffi::DataType::F8E4M3FN:
      return DType::kFloat8E4M3;
      break;
    case xla::ffi::DataType::F8E8M0FNU:
      return DType::kFloat8E8M0;
      break;
    default:
      auto type_num = static_cast<XLA_FFI_DataType>(type);
      NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
                 static_cast<int>(type_num));
      break;
...
@@ -6,11 +6,13 @@

#include "transformer_engine/gemm.h"

#include <memory>
#include <string_view>
#include <tuple>

#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/string.h"
#include "common/util/system.h"
#include "transformer_engine/multi_stream.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h"
@@ -25,6 +27,181 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
                                          ~static_cast<uintptr_t>(255));
}
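`move_ptr_to_next_256B_aligned` rounds the cuBLAS workspace pointer up to the next 256-byte boundary with the usual add-and-mask trick; the primitive reserves 256 extra bytes (`workspace->element_count() - 256` below) precisely so this shift never runs past the buffer. The same arithmetic on integer addresses, as a sketch:

```python
def align_up_256(addr: int) -> int:
    """Round a non-negative address up to the next 256-byte boundary.

    Equivalent to the C++ (ptr + 255) & ~uintptr_t(255) for addr >= 0.
    """
    return (addr + 255) & ~255


print(align_up_256(0))    # -> 0 (already aligned)
print(align_up_256(1))    # -> 256
print(align_up_256(256))  # -> 256
print(align_up_256(257))  # -> 512
```

Aligning up by at most 255 bytes is why the usable workspace size shrinks by exactly 256 bytes regardless of where the allocation happens to start.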
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv,
JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
product(buffer_dims, axis_boundary, buffer_dims.size())};
auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type());
TensorWrapper input(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
} else {
input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
}
// Set scaling factor for quantized tensors
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands.");
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1};
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Block scaling also needs to be collapsed to match 2D data
scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary),
product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())};
}
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
if (rowwise) {
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
// Swizzle scaling factors for MXFP8
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Get the swizzle buffer
NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
"Missing swizzled inverse scale buffer in the JAX primitive.");
auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
auto swizzled_scale_inv_dtype =
convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
}
}
return std::make_tuple(std::move(input), input_shape);
}
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
// Operands (this includes swizzling MXFP8 scaling factors)
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
// Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, "
"expected ",
out_.numel(), " elements ", to_string_like(out_shape), " but got ",
output->element_count(), " elements ", to_string_like(output->dimensions()));
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
std::vector<size_t> bias_shape = {0};
DType bias_dtype = out_dtype;
if (fuse_bias) {
if (!grad) {
NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad");
}
bias_ptr = bias_grad->untyped_data();
bias_shape.at(0) = bias_grad->dimensions().front();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type());
}
auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
// Pre-GeLU output from forward pass or input to backward pass
void *pre_gelu_ptr = nullptr;
std::vector<size_t> pre_gelu_shape = {0};
DType pre_gelu_dtype = out_dtype;
if (gelu_input.element_count() > 0) {
if (grad) {
NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out");
}
pre_gelu_ptr = pre_gelu_out->untyped_data();
pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1),
static_cast<size_t>(pre_gelu_out->dimensions().back())};
pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type());
}
auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype);
// cuBLAS workspace + 256 alignment enforcement
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs
.Arg<Buffer_Type>() // lhs_scale_inv
.Arg<Buffer_Type>() // rhs
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // lhs_swizzled
.Ret<Buffer_Type>() // rhs_swizzled
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
.Attr<int64_t>("rhs_axis_boundary")
.Attr<bool>("lhs_transposed")
.Attr<bool>("rhs_transposed")
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator"),
FFI_CudaGraph_Traits);
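`GemmFFI` collapses each operand to 2D at its `axis_boundary` and derives the output shape from the transpose flags (the LHS/RHS swap when calling `nvte_cublas_gemm` accounts for cuBLAS's column-major convention). A sketch of that shape bookkeeping using plain Python tuples (illustrative only, not the FFI code itself):

```python
from math import prod


def collapse_2d(shape, axis_boundary):
    """Collapse an N-D shape to 2D at the given axis boundary."""
    return (prod(shape[:axis_boundary]), prod(shape[axis_boundary:]))


def gemm_out_shape(lhs_shape, rhs_shape, lhs_transposed, rhs_transposed):
    """Output rows come from the LHS, columns from the RHS, honoring transpose flags."""
    m = lhs_shape[1] if lhs_transposed else lhs_shape[0]
    n = rhs_shape[0] if rhs_transposed else rhs_shape[1]
    return (m, n)


# An (8, 16, 32) activation collapsed at axis 2 becomes a (128, 32) matrix
lhs = collapse_2d((8, 16, 32), axis_boundary=2)   # -> (128, 32)
rhs = collapse_2d((32, 64), axis_boundary=1)      # -> (32, 64)
print(gemm_out_shape(lhs, rhs, False, False))     # -> (128, 64)
```

The `NVTE_CHECK` on `out_.numel()` in the handler guards exactly this computation against a mismatched output buffer from the JAX primitive.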
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
                          Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
                          Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
...
@@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t {
  CURRENT_TENSOR_SCALING = 3,
};
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING ||
mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING);
}
inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
}
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::NO_SCALING:
...
@@ -55,6 +55,11 @@ pybind11::dict Registrations() {
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));
// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
  // Grouped GEMM
  dict["te_grouped_gemm_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
@@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
  m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
  m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
  m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
  pybind11::enum_<DType>(m, "DType", pybind11::module_local())
      .value("kByte", DType::kByte)
...
...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation. ...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations. customizable contracting dimensions for flexible tensor operations.
""" """
import warnings
from typing import Tuple, Sequence
from functools import partial
import jax
@@ -23,6 +23,16 @@ from .quantize import (
)
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
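The module-level guard above makes the batch-first warning fire at most once per process. A standalone sketch of the same issue-once pattern (hypothetical `_issue_once` and `DEMO_WARNING_ISSUED` names, mirroring the flag logic in this diff):

```python
import warnings

DEMO_WARNING_ISSUED = False


def _issue_once(msg):
    # Mirror of the module-level guard above: warn on the first call only.
    global DEMO_WARNING_ISSUED
    if not DEMO_WARNING_ISSUED:
        warnings.warn(msg, UserWarning)
        DEMO_WARNING_ISSUED = True


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _issue_once("sequence-first inputs are unsupported")
    _issue_once("sequence-first inputs are unsupported")

print(len(caught))  # only the first call emits a warning
```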
def dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
@@ -30,6 +40,7 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
):
"""Perform dense layer transformation with optional quantization.
@@ -43,25 +54,28 @@ def dense(
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
# Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
x = with_sharding_constraint_by_logical_axes(x, input_axes)
output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
else:
output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
)
return output
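In the JAX-dot fallback branch, the bias is reshaped so that it broadcasts over every non-hidden dimension of the GEMM output. A quick NumPy check of that reshape (the shapes are illustrative only):

```python
import numpy as np

# Hypothetical shapes: batched activations (batch, seq, hidden_out) and a 1-D bias.
output = np.zeros((4, 16, 32))
bias = np.ones((32,))

# Same reshape as the fallback path above: prepend singleton dims so the
# bias broadcasts over all non-hidden dimensions.
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
result = output + bias.reshape(bias_new_shape)

print(bias_new_shape)  # (1, 1, 32)
print(result.shape)    # (4, 16, 32)
```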
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
@@ -75,44 +89,81 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
Returns:
Transformed output tensor
"""
output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
)
return output
def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
):
"""Forward pass rule for dense layer transformation.
Returns:
Tuple of (output, context) for backward pass
"""
x_contracting_dims, k_contracting_dims = map(
tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
)
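`tex.sanitize_dims` is applied here so that negative contracting-dim indices compare correctly against positions such as `x.ndim - 1` in the layout check below. A minimal stand-in, assuming the helper simply normalizes negative axes against the tensor rank:

```python
# Hypothetical stand-in for tex.sanitize_dims: map negative axis indices
# to their non-negative equivalents for a tensor of the given rank.
def sanitize_dims(ndim, dims):
    return tuple(d % ndim for d in dims)


# contracting_dims=((-1,), (0,)) on a 3-D x and a 2-D kernel:
x_cdims = sanitize_dims(3, (-1,))
k_cdims = sanitize_dims(2, (0,))
print(x_cdims, k_cdims)  # (2,) (0,)
```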
# Check supported input layout
x_is_transposed = x.ndim - 1 not in x_contracting_dims
k_is_transposed = kernel.ndim - 1 in k_contracting_dims
assert (
not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least two more dimensions than the number of
# contracting dimensions.
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
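The batch-dimension rule above can be condensed into a small pure function; a sketch with a hypothetical `infer_batch_dim` helper:

```python
# Standalone sketch of the batch-dimension rule above (hypothetical helper).
def infer_batch_dim(ndim, num_cdims, batch_first):
    # Batched only if there are at least 2 more dims than contracting dims.
    if ndim < num_cdims + 2:
        return None
    return 0 if batch_first else ndim - num_cdims - 1


# (batch, seq, hidden) with one contracting dim:
print(infer_batch_dim(3, 1, True))   # 0
# (seq, batch, hidden), sequence-first:
print(infer_batch_dim(3, 1, False))  # 1
# Plain 2-D (tokens, hidden) input: no batch dim.
print(infer_batch_dim(2, 1, True))   # None
```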
flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(
x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize(
kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
noop_scaled_tensor=True,
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN
use_bias = bias is not None
output = tex.gemm(
casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
@@ -124,20 +175,19 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes,
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
)
return output, ctx
def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad
):  # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
Returns:
Tuple of gradients with respect to inputs
"""
(
casted_x_lhs,
casted_kernel_rhs,
@@ -146,10 +196,19 @@ def _dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis_k,
x_bdim,
) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
)
casted_grad, dbias = tex.quantize_dbias(
grad,
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# GEMM NT # GEMM NT
@@ -164,7 +223,8 @@ def _dense_bwd_rule(
dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs,
contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
@@ -177,7 +237,8 @@ def _dense_bwd_rule(
wgrad = tex.gemm(
casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS),
contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
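The dgrad/wgrad contractions in this backward rule follow the usual dense-layer calculus: dgrad contracts the gradient against the kernel's output dims (GEMM NT), and wgrad contracts the input against the gradient over the batch and leading dims (GEMM TN). A NumPy sanity check with illustrative shapes:

```python
import numpy as np

# Hypothetical shapes for an NN forward: x:(B,M,K) x kernel:(K,N) -> out:(B,M,N)
B, M, K, N = 2, 3, 4, 5
x = np.random.rand(B, M, K)
kernel = np.random.rand(K, N)
grad = np.random.rand(B, M, N)

# GEMM NT: dgrad contracts grad's output dims against kernel's output dims.
dgrad = np.tensordot(grad, kernel, axes=((2,), (1,)))
# GEMM TN: wgrad contracts x and grad over the batch + leading dims.
wgrad = np.tensordot(x, grad, axes=((0, 1), (0, 1)))

print(dgrad.shape)  # (2, 3, 4) -- same as x
print(wgrad.shape)  # (4, 5)    -- same as kernel
```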
...
@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
"""
from functools import reduce
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
import numpy as np
import jax.numpy as jnp
@@ -15,12 +15,12 @@ from jax import lax
from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name
from ..dense import dense, _issue_batch_first_warning as _dense_warning
from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation
from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes
@@ -35,8 +35,8 @@ from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any
Shape = Tuple[int, ...]
DType = NewType("DType", jnp.dtype)
Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
@@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase):
input_axes: Tuple[str, ...] = ()
def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
@@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float = None
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0,
@@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None
def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
...
@@ -9,6 +9,7 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""
import warnings
from functools import partial
from typing import Tuple
@@ -25,6 +26,16 @@ from .quantize import (
)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense(
x: jnp.ndarray,
kernel: jnp.ndarray,
@@ -37,6 +48,7 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation.
@@ -57,6 +69,7 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types
Returns:
@@ -80,6 +93,7 @@ def layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
@@ -94,6 +108,7 @@ def layernorm_dense(
8,
9,
10,
11,
),
)
def _layernorm_dense(
@@ -108,6 +123,7 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set,
):
"""Internal implementation of layernorm_dense with custom VJP.
@@ -127,6 +143,7 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers
Returns:
@@ -144,6 +161,7 @@ def _layernorm_dense(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
)
return output
@@ -161,6 +179,7 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes,
dot_input_axes,
kernel_axes,
batch_first,
quantizer_set,
):
"""Forward pass rule for layernorm_dense.
@@ -178,6 +197,17 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd(
@@ -187,25 +217,31 @@ def _layernorm_dense_fwd_rule(
zero_centered_gamma,
epsilon,
norm_type,
quantizer=quantizer_set.x,
noop_scaled_tensor=True,
)
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...)
flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
use_bias = bias is not None
output = tex.gemm(
casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS),
contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
)
if use_bias and tex.gemm_uses_jax_dot():
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape)
@@ -224,6 +260,7 @@ def _layernorm_dense_fwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
)
return output, ctx
@@ -236,6 +273,7 @@ def _layernorm_dense_bwd_rule(
layernorm_input_axes,
dot_input_axes,  # pylint: disable=unused-argument
kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx,
grad,
):
@@ -265,10 +303,15 @@ def _layernorm_dense_bwd_rule(
use_bias,
quantizer_set,
flatten_axis,
x_bdim,
) = ctx
casted_grad, dbias = tex.quantize_dbias(
grad,
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
@@ -284,7 +327,8 @@ def _layernorm_dense_bwd_rule(
dgrad = tex.gemm(
casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel,
contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
)
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
@@ -297,7 +341,8 @@ def _layernorm_dense_bwd_rule(
wgrad = tex.gemm(
casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS),
contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
)
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
...