Commit 214e2a4a authored by Alp Dener, committed by GitHub

[JAX] GEMM custom op (#1855)



* added XLA FFI custom op for TE/common nvte_cublas_gemm
Signed-off-by: Alp Dener <adener@nvidia.com>

started GemmPrimitive, abstract done
Signed-off-by: Alp Dener <adener@nvidia.com>

gemm custom op working with BF16, needs testing for FP8/MXFP8
Signed-off-by: Alp Dener <adener@nvidia.com>

converted TE GEMM API to use ScaledTensor and added os ENV flag to use TE GEMM under general gemm() call
Signed-off-by: Alp Dener <adener@nvidia.com>

BF16 tests passing, FP8 tests should be passing but contracting_dims has a scoping issue
Signed-off-by: Alp Dener <adener@nvidia.com>

fp8 tests passing for E4M3, getting CUBLAS_STATUS_NOT_SUPPORTED for E5M2
Signed-off-by: Alp Dener <adener@nvidia.com>

updated GEMM API to use separate LHS and RHS quantizers instead of a QuantizerSet
Signed-off-by: Alp Dener <adener@nvidia.com>

new GemmPrimitive passing all Dense tests
Signed-off-by: Alp Dener <adener@nvidia.com>

import cleanup and reverted code chunk movement
Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused .transpose() implementations from ScaledTensors
Signed-off-by: Alp Dener <adener@nvidia.com>

all custom call tests passing on Hopper, GEMM-related tests cover both GemmPrimitive and native JAX impl
Signed-off-by: Alp Dener <adener@nvidia.com>

removed direct calls to GemmPrimitive.enabled() from outside of cpp_extensions
Signed-off-by: Alp Dener <adener@nvidia.com>

removed unused changes to ScaledTensor classes and debug prints
Signed-off-by: Alp Dener <adener@nvidia.com>

* minor unit test cleanup
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* FP8 tests passing on Blackwell but MXFP8 outputs NaN
Signed-off-by: Alp Dener <adener@nvidia.com>

* reverted dense and fuseddense changes, FP8 test passing on Hopper and Blackwell, MXFP8 has issues with E5M2
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* MXFP8 issue traced to scale factor padding with NaNs instead of zeros
Signed-off-by: Alp Dener <adener@nvidia.com>

* padding scale with 2^-127 instead of NaNs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
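
For context on the two fixes above: MXFP8 block scales use the E8M0 format from the OCP Microscaling specification, an 8-bit biased exponent with no sign or mantissa bits. The byte `0xFF` encodes NaN and every other byte `e` encodes `2^(e - 127)`, so a zero byte is the smallest representable scale, 2^-127. That is why zero-filled padding is harmless while NaN-filled padding can contaminate the GEMM output. The sketch below is illustrative only; `e8m0_decode` is a hypothetical helper and not a function from this commit.

```python
import math

E8M0_BIAS = 127  # exponent bias of the E8M0 scale format
E8M0_NAN = 0xFF  # the only non-finite E8M0 encoding


def e8m0_decode(byte: int) -> float:
    """Decode one E8M0 byte into the scale it represents (hypothetical helper)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0 values are single bytes")
    if byte == E8M0_NAN:
        return math.nan
    return 2.0 ** (byte - E8M0_BIAS)


# Padding with zero bytes decodes to the smallest finite scale, not to NaN.
zero_padding = e8m0_decode(0x00)
nan_padding = e8m0_decode(0xFF)
```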

* fix bug on rhs_scale_inv usage
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* cleanup E8M0 type converter use it in gemm.cpp
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* segfault fixed, passing all unittests on Blackwell
Signed-off-by: Alp Dener <adener@nvidia.com>

* fix for fuseddense tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix workspace alignment
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed GemmPrimitive custom partitioning to match jax.nn.scaled_matmul
Signed-off-by: Alp Dener <adener@nvidia.com>

all unit tests passing on H100x8 node
Signed-off-by: Alp Dener <adener@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



linting fixes
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed batch dimension numbers
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed FP8 scale sharding rule when there are no FP8 scales
Signed-off-by: Alp Dener <adener@nvidia.com>

added error message for unsupported Shardy partitioner
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed test tolerances for FP8 cases
Signed-off-by: Alp Dener <adener@nvidia.com>

fixed shardy test skip cases
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* moved reshape of encoder output in encoder examples to make custom partitioning rules work correctly
Signed-off-by: Alp Dener <adener@nvidia.com>

* added helper functions for padding and unpadding block scales, changed GemmPrimitive to accept unpadded scales and pad them after sharding
Signed-off-by: Alp Dener <adener@nvidia.com>
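
The padding helpers mentioned above can be pictured as a round trip: pad the scale array up to the alignment the GEMM kernel expects, then slice the padding off again when the logical shape is needed. This is a minimal pure-Python sketch under assumptions; the function names, the list-of-lists layout, and the alignment values in the usage lines are hypothetical and are not taken from the Transformer Engine sources.

```python
def round_up(n, multiple):
    """Round n up to the nearest multiple of `multiple`."""
    return -(-n // multiple) * multiple


def pad_block_scales(scales, align_rows, align_cols, fill=0):
    """Pad a 2D list of scale bytes so both dimensions meet the alignment."""
    rows = len(scales)
    cols = len(scales[0]) if rows else 0
    padded_rows = round_up(rows, align_rows)
    padded_cols = round_up(cols, align_cols)
    # Extend each existing row, then append fully padded rows.
    padded = [list(row) + [fill] * (padded_cols - cols) for row in scales]
    padded += [[fill] * padded_cols for _ in range(padded_rows - rows)]
    return padded


def unpad_block_scales(padded, rows, cols):
    """Slice the padding off again to recover the logical scale shape."""
    return [row[:cols] for row in padded[:rows]]


# Hypothetical usage: a 2x3 scale array padded to a 4x4 alignment and back.
scales = [[127, 128, 126], [125, 127, 129]]
padded = pad_block_scales(scales, align_rows=4, align_cols=4)
restored = unpad_block_scales(padded, rows=2, cols=3)
```

Keeping the scales unpadded until after sharding, as the commit message describes, means each shard can be padded to the alignment on its own.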

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated shardy rules for all custom ops to decouple block scale rules from their tensors
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed linting errors
Signed-off-by: Alp Dener <adener@nvidia.com>

* changed unit test use_jax_gemm option to be a context to preserve external custom op settings, tightened multi-GPU encoder test tolerances, changed gemm() API to use contracting_dims and batched_dims separately instead of dimension_numbers
Signed-off-by: Alp Dener <adener@nvidia.com>
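
Turning the test option into a context lets it restore whatever custom-op setting the caller had before the test ran. The sketch below illustrates that pattern with `contextlib.contextmanager`. The environment variable name and the regular expression are the ones that appear in this commit's test diff. The function name `jax_gemm_override` is hypothetical, and the reading that the pattern selects which custom calls stay enabled is an assumption, not confirmed by the source.

```python
import os
import re
from contextlib import contextmanager

ENV_VAR = "NVTE_JAX_CUSTOM_CALLS_RE"
# Matches every name except the exact string "GemmPrimitive".
ALL_BUT_GEMM = "^(?!GemmPrimitive$).+$"


@contextmanager
def jax_gemm_override(enabled=False):
    """Hypothetical sketch: filter out GemmPrimitive, then restore the old value."""
    previous = os.environ.get(ENV_VAR)  # remember the caller's setting, if any
    try:
        if enabled:
            os.environ[ENV_VAR] = ALL_BUT_GEMM
        yield
    finally:
        # Put back exactly what was there before entering the context.
        if previous is None:
            os.environ.pop(ENV_VAR, None)
        else:
            os.environ[ENV_VAR] = previous


# Usage: an externally configured value survives the override.
os.environ[ENV_VAR] = ".*"
with jax_gemm_override(enabled=True):
    inside = os.environ[ENV_VAR]
after = os.environ[ENV_VAR]
```

A plain setter function, by contrast, would discard any value the user had exported before running the tests.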

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed typo in test utils
Signed-off-by: Alp Dener <adener@nvidia.com>

* added sequence-first input warnings
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixed datasets version for JAX examples
Signed-off-by: Alp Dener <adener@nvidia.com>

* reverting modification to force_1x_quantization decision
Signed-off-by: Alp Dener <adener@nvidia.com>

* corrected gemm function syntax in unit tests
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 397c4be6
@@ -30,7 +30,7 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
     LOG_FILE="${TEST_CASE}_gpu_${i}.log"

     # Run pytest and redirect stdout and stderr to the log file
-    pytest -c "$TE_PATH/tests/jax/pytest.ini" \
+    pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
       -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
       --num-process=$NUM_GPUS \
       --process-id=$i > "$LOG_FILE" 2>&1 &
......
@@ -25,6 +25,7 @@ from common import (
     assert_params_sufficiently_sharded,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
@@ -465,14 +466,14 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "5"])

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
@@ -480,7 +481,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
@@ -488,14 +489,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_with_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp(self):
@@ -504,7 +505,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8_with_sp(self):
@@ -513,14 +514,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
@@ -529,7 +530,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_with_sp_shardy(self):
@@ -539,9 +540,32 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.43 and actual[1] > 0.80

-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.43 and actual[1] > 0.80
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_with_sp_shardy(self):
+        """Test Transformer Engine with MXFP8 + SP"""
+        self.args.enable_shardy = True
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.43 and actual[1] > 0.80


 if __name__ == "__main__":
......
@@ -21,6 +21,7 @@ from jax.sharding import PartitionSpec, NamedSharding
 from common import is_bf16_supported, get_fp8_recipe_from_name_string
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
 from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
@@ -430,14 +431,14 @@ class TestEncoder(unittest.TestCase):
     is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

     def setUp(self):
-        """Run 3 epochs for testing"""
-        self.args = encoder_parser(["--epochs", "3"])
+        """Run 5 epochs for testing"""
+        self.args = encoder_parser(["--epochs", "5"])

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8(self):
@@ -445,7 +446,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8(self):
@@ -453,7 +454,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
@@ -461,14 +462,14 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         self.args.enable_shardy = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_delayed_scaling_fp8_shardy(self):
@@ -477,9 +478,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75

-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
-
     @unittest.skipIf(not is_fp8_supported, fp8_reason)
     def test_te_current_scaling_fp8_shardy(self):
@@ -488,7 +487,19 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "Float8CurrentScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.535 and actual[1] > 0.73
+        assert actual[0] < 0.50 and actual[1] > 0.75
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.50 and actual[1] > 0.75


 if __name__ == "__main__":
......
@@ -28,8 +28,8 @@ from common import (
     get_fp8_recipe_from_name_string,
 )
 import transformer_engine.jax as te
+import transformer_engine.jax.cpp_extensions as tex
 import transformer_engine.jax.flax as te_flax
-from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -584,8 +584,8 @@ class TestEncoder(unittest.TestCase):
     """Encoder unittests"""

     def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
-        """Run 3 epochs for testing"""
-        args = encoder_parser([])
+        """Run 5 epochs for testing"""
+        args = encoder_parser(["--epochs", "5"])

         num_gpu = self.num_process
         tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
@@ -607,7 +607,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         result = self.exec(False, None)
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
@@ -615,7 +615,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling")
-        assert result[0] < 0.506 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
@@ -623,7 +623,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_current_scaling_fp8(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling")
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
@@ -631,13 +631,13 @@ class TestEncoder(unittest.TestCase):
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
         result = self.exec(True, "MXFP8BlockScaling")
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
         """Test Transformer Engine with BF16"""
         result = self.exec(False, None, enable_shardy=True)
-        assert result[0] < 0.505 and result[1] > 0.755
+        assert result[0] < 0.43 and result[1] > 0.80

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
@@ -645,9 +645,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8_shardy(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling", enable_shardy=True)
-        assert result[0] < 0.506 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80

-    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
-
     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
@@ -655,7 +653,18 @@ class TestEncoder(unittest.TestCase):
     def test_te_current_scaling_fp8_shardy(self):
         """Test Transformer Engine with CurrentScaling FP8"""
         result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
-        assert result[0] < 0.507 and result[1] > 0.753
+        assert result[0] < 0.43 and result[1] > 0.80
+
+    @unittest.skipIf(
+        not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
+    )
+    @unittest.skipIf(
+        tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner."
+    )
+    def test_te_mxfp8_shardy(self):
+        """Test Transformer Engine with MXFP8"""
+        result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True)
+        assert result[0] < 0.43 and result[1] > 0.80


 if __name__ == "__main__":
......
@@ -13,6 +13,7 @@ import operator
 from utils import (
     assert_allclose,
     pytest_parametrize_wrapper,
+    use_jax_gemm,
 )
 from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.layernorm_mlp import layernorm_mlp
@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import (
 from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
 from transformer_engine.jax import cpp_extensions as tex
 from transformer_engine.jax.quantize import (
-    DelayedScaleQuantizer,
     ScaledTensor,
     ScaledTensor1x,
     ScaledTensor2x,
@@ -851,6 +851,22 @@ class TestFusedQuantize:
         )


+valid_fp8_gemm_operand_types = [
+    (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
+    (jnp.float8_e5m2, jnp.float8_e4m3fn),
+    (jnp.float8_e4m3fn, jnp.float8_e5m2),
+]
+
+
+def _use_jax_fp8_gemm(enabled=False):
+    import os
+
+    if enabled:
+        os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
+    elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ:
+        os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
+
+
 class TestDense:
     def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
         if data_layout[0] == "T":
@@ -883,27 +899,47 @@ class TestDense:
     def test_gemm_bf16(self, m, n, k, data_layout):
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
-        primitive_out = tex.gemm(x, w, contracting_dims)
+        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
         ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
         assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
-    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
+    @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
-    def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
+        if (
+            not with_jax_gemm
+            and scaling_mode.is_1d_block_scaling()
+            and jnp.float8_e5m2 in (x_qtype, w_qtype)
+        ):
+            pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
+            scaling_mode=scaling_mode,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2,
+            is_2x2x=False,
         )
+        with use_jax_gemm(enabled=with_jax_gemm):
             primitive_out = tex.gemm(
-            x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
+                x,
+                w,
+                contracting_dims=contracting_dims,
+                lhs_quantizer=(
+                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
+                ),
+                rhs_quantizer=(
+                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
+                ),
             )
         ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
-        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
+        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)

     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
     def test_dense_grad_bf16(self, m, n, k):
...@@ -932,9 +968,9 @@ class TestDense: ...@@ -932,9 +968,9 @@ class TestDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode): @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
data_layout = "NN" data_layout = "NN"
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
...@@ -956,10 +992,14 @@ class TestDense: ...@@ -956,10 +992,14 @@ class TestDense:
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True scaling_mode=scaling_mode,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True,
) )
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations): for _ in range(n_iterations):
primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set) value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
...@@ -969,10 +1009,10 @@ class TestDense: ...@@ -969,10 +1009,10 @@ class TestDense:
x, w, bias, data_layout x, w, bias, data_layout
) )
assert_allclose(primitive_out, ref_out, dtype=q_dtype) assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype) assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype) assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)
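The updated assertions check the forward output against E4M3 tolerances and the gradients against E5M2 tolerances. As a rough illustration of why the two differ, the sketch below derives machine epsilon from the mantissa width of each format. It assumes tolerances scale with epsilon; `FP8_MANTISSA_BITS` and `machine_epsilon` are hypothetical names, and how `assert_allclose` actually maps `dtype` to tolerances is not shown in this diff.

```python
# Mantissa (fraction) bit counts of the two FP8 formats used in these tests.
FP8_MANTISSA_BITS = {
    "float8_e4m3fn": 3,  # 4 exponent bits, 3 mantissa bits
    "float8_e5m2": 2,    # 5 exponent bits, 2 mantissa bits
}


def machine_epsilon(mantissa_bits: int) -> float:
    """Gap between 1.0 and the next representable value."""
    return 2.0 ** -mantissa_bits


for name, bits in FP8_MANTISSA_BITS.items():
    print(f"{name}: eps = {machine_epsilon(bits)}")
```

E5M2 trades one mantissa bit for extra exponent range, so comparisons made at E5M2 precision would plausibly need the looser bound.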
@pytest.fixture(name="random_inputs") @pytest.fixture(name="random_inputs")
...@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan ...@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense: class TestFusedDense:
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type): @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
""" """
Test layernorm_dense VJP Rule Test layernorm_dense VJP Rule
""" """
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm # zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False zero_centered_gamma = False
eps = 1e-6 eps = 1e-6
...@@ -1025,8 +1058,8 @@ class TestFusedDense: ...@@ -1025,8 +1058,8 @@ class TestFusedDense:
quantizer_set = QuantizerFactory.create_set( quantizer_set = QuantizerFactory.create_set(
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fwd_dtype=q_dtype, fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=q_dtype, bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True, is_2x2x=True,
) )
...@@ -1064,6 +1097,7 @@ class TestFusedDense: ...@@ -1064,6 +1097,7 @@ class TestFusedDense:
) )
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations): for _ in range(n_iterations):
prim_out, ( prim_out, (
prim_x_grad, prim_x_grad,
...@@ -1072,33 +1106,26 @@ class TestFusedDense: ...@@ -1072,33 +1106,26 @@ class TestFusedDense:
prim_beta_grad, prim_beta_grad,
) = value_n_grad_prim_func(x, w, gamma, beta) ) = value_n_grad_prim_func(x, w, gamma, beta)
assert_allclose(prim_out, ref_out, dtype=q_dtype) assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype) assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
if beta is not None: if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype) assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad( def test_layernorm_mlp_grad(
self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
): ):
""" """
Test layernorm_mlp VJP Rule Test layernorm_mlp VJP Rule
""" """
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode in (
ScalingMode.DELAYED_TENSOR_SCALING,
ScalingMode.CURRENT_TENSOR_SCALING,
):
pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm # zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False zero_centered_gamma = False
eps = 1e-6 eps = 1e-6
...@@ -1123,8 +1150,8 @@ class TestFusedDense: ...@@ -1123,8 +1150,8 @@ class TestFusedDense:
quantizer_sets = QuantizerFactory.create_set( quantizer_sets = QuantizerFactory.create_set(
n_quantizer_sets=2, n_quantizer_sets=2,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fwd_dtype=q_dtype, fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=q_dtype, bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
is_2x2x=True, is_2x2x=True,
) )
...@@ -1153,14 +1180,13 @@ class TestFusedDense: ...@@ -1153,14 +1180,13 @@ class TestFusedDense:
ln_out = _ref_jax_norm_impl( ln_out = _ref_jax_norm_impl(
x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
) )
# TODO: replace gemm with jnp.dot linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
if use_bias: if use_bias:
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape) linear_1_out += jnp.reshape(bias_1, bias_1_shape)
x = _jax_act_lu(linear_1_out, activation_type) x = _jax_act_lu(linear_1_out, activation_type)
linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,))) linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
if use_bias: if use_bias:
bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
linear_2_out += jnp.reshape(bias_2, bias_2_shape) linear_2_out += jnp.reshape(bias_2, bias_2_shape)
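The reference path now calls `jax.lax.dot_general` with dimension numbers `(((1,), (0,)), ((), ()))`: contract axis 1 of the left operand with axis 0 of the right operand, with no batch axes. For two 2-D operands that is an ordinary matrix product, which this dependency-free sketch spells out on nested lists. `contract_axis1_axis0` is a hypothetical helper for illustration only, not part of the test suite, and it covers only the 2-D case.

```python
def contract_axis1_axis0(lhs, rhs):
    """Compute out[i][j] = sum over k of lhs[i][k] * rhs[k][j]."""
    inner = len(rhs)
    if any(len(row) != inner for row in lhs):
        raise ValueError("lhs axis 1 must match rhs axis 0")
    cols = len(rhs[0])
    return [
        [sum(row[k] * rhs[k][j] for k in range(inner)) for j in range(cols)]
        for row in lhs
    ]


lhs = [[1, 2, 3], [4, 5, 6]]    # shape (2, 3)
rhs = [[1, 0], [0, 1], [1, 1]]  # shape (3, 2)
print(contract_axis1_axis0(lhs, rhs))
```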
...@@ -1174,6 +1200,7 @@ class TestFusedDense: ...@@ -1174,6 +1200,7 @@ class TestFusedDense:
value_n_grad_ref_func = value_and_grad(ref_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
with use_jax_gemm(enabled=with_jax_gemm):
for _ in range(n_iterations): for _ in range(n_iterations):
prim_out, ( prim_out, (
prim_x_grad, prim_x_grad,
...@@ -1193,18 +1220,18 @@ class TestFusedDense: ...@@ -1193,18 +1220,18 @@ class TestFusedDense:
ref_bias_2_grad, ref_bias_2_grad,
) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)
assert_allclose(prim_out, ref_out, dtype=q_dtype) assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype) assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
if use_bias: if use_bias:
assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype) assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype) assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
if use_bias: if use_bias:
assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype) assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype) assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype) assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
# E5M2 * E5M2 is not supported # E5M2 * E5M2 is not supported
......
...@@ -13,6 +13,7 @@ from utils import ( ...@@ -13,6 +13,7 @@ from utils import (
assert_tree_like_allclose, assert_tree_like_allclose,
is_devices_enough, is_devices_enough,
pytest_parametrize_wrapper, pytest_parametrize_wrapper,
use_jax_gemm,
) )
from transformer_engine.common import recipe from transformer_engine.common import recipe
...@@ -147,7 +148,15 @@ class TestDistributedLayernormMLP: ...@@ -147,7 +148,15 @@ class TestDistributedLayernormMLP:
) )
def _test_layernorm_mlp_grad( def _test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
use_shardy,
with_jax_gemm,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy) jax.config.update("jax_use_shardy_partitioner", use_shardy)
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
...@@ -157,6 +166,8 @@ class TestDistributedLayernormMLP: ...@@ -157,6 +166,8 @@ class TestDistributedLayernormMLP:
input_shape, activation_type, use_bias, dtype input_shape, activation_type, use_bias, dtype
) )
static_inputs = [layernorm_type, activation_type] static_inputs = [layernorm_type, activation_type]
with use_jax_gemm(enabled=with_jax_gemm):
value_and_grad_func = jax.value_and_grad( value_and_grad_func = jax.value_and_grad(
self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs)) self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
) )
...@@ -172,7 +183,9 @@ class TestDistributedLayernormMLP: ...@@ -172,7 +183,9 @@ class TestDistributedLayernormMLP:
# Multi GPUs # Multi GPUs
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): with mesh, fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
):
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding) k1_ = jax.device_put(k1, k1_sharding)
...@@ -204,25 +217,32 @@ class TestDistributedLayernormMLP: ...@@ -204,25 +217,32 @@ class TestDistributedLayernormMLP:
value_and_grad_func, value_and_grad_func,
in_shardings=in_shardings, in_shardings=in_shardings,
out_shardings=out_shardings, out_shardings=out_shardings,
static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1), static_argnums=range(
len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
),
) # +1 for multi_gpus ) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
assert_allclose(multi_fwd, single_fwd, dtype=dtype) fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)): for i in range(len(inputs)):
if multi_grads[i] is not None: if multi_grads[i] is not None:
if isinstance(multi_grads[i], list): if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list) assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose( assert_allclose(
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close" m_grad,
s_grad,
dtype=bwd_test_type,
err_msg=f"multi_grads[{i}] is not close",
) )
else: else:
assert_allclose( assert_allclose(
multi_grads[i], multi_grads[i],
single_grads[i], single_grads[i],
dtype=dtype, dtype=bwd_test_type,
err_msg=f"multi_grads[{i}] is not close", err_msg=f"multi_grads[{i}] is not close",
) )
...@@ -233,8 +253,16 @@ class TestDistributedLayernormMLP: ...@@ -233,8 +253,16 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad( def test_layernorm_mlp_grad(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
with_jax_gemm,
): ):
self._test_layernorm_mlp_grad( self._test_layernorm_mlp_grad(
mesh_config, mesh_config,
...@@ -244,6 +272,7 @@ class TestDistributedLayernormMLP: ...@@ -244,6 +272,7 @@ class TestDistributedLayernormMLP:
dtype, dtype,
fp8_recipe, fp8_recipe,
use_shardy=False, use_shardy=False,
with_jax_gemm=with_jax_gemm,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -252,19 +281,29 @@ class TestDistributedLayernormMLP: ...@@ -252,19 +281,29 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_grad_shardy( def test_layernorm_mlp_grad_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype self,
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
fp8_recipe,
with_jax_gemm,
): ):
# We don't test block scaling with Shardy because at the time of writing, if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
# it is not supported in JAX's scaled_matmul_stablehlo. pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
self._test_layernorm_mlp_grad( self._test_layernorm_mlp_grad(
mesh_config, mesh_config,
activation_type, activation_type,
use_bias, use_bias,
input_shape, input_shape,
dtype, dtype,
fp8_recipe=recipe.DelayedScaling(), fp8_recipe=fp8_recipe,
use_shardy=True, use_shardy=True,
with_jax_gemm=with_jax_gemm,
) )
def _test_layernorm_mlp( def _test_layernorm_mlp(
...@@ -277,6 +316,7 @@ class TestDistributedLayernormMLP: ...@@ -277,6 +316,7 @@ class TestDistributedLayernormMLP:
use_fp8, use_fp8,
fp8_recipe, fp8_recipe,
use_shardy, use_shardy,
with_jax_gemm,
): ):
jax.config.update("jax_use_shardy_partitioner", use_shardy) jax.config.update("jax_use_shardy_partitioner", use_shardy)
batch, seqlen, hidden_in = input_shape batch, seqlen, hidden_in = input_shape
...@@ -288,6 +328,7 @@ class TestDistributedLayernormMLP: ...@@ -288,6 +328,7 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {"params": subkeys[1]} init_rngs = {"params": subkeys[1]}
with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs # Single GPUs
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
ln_mlp_single = LayerNormMLP( ln_mlp_single = LayerNormMLP(
...@@ -355,9 +396,9 @@ class TestDistributedLayernormMLP: ...@@ -355,9 +396,9 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("use_shardy", [False, True]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer( def test_layernorm_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
): ):
self._test_layernorm_mlp( self._test_layernorm_mlp(
mesh_config, mesh_config,
...@@ -367,7 +408,8 @@ class TestDistributedLayernormMLP: ...@@ -367,7 +408,8 @@ class TestDistributedLayernormMLP:
dtype, dtype,
use_fp8=False, use_fp8=False,
fp8_recipe=None, fp8_recipe=None,
use_shardy=use_shardy, use_shardy=False,
with_jax_gemm=with_jax_gemm,
) )
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...@@ -377,8 +419,9 @@ class TestDistributedLayernormMLP: ...@@ -377,8 +419,9 @@ class TestDistributedLayernormMLP:
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8( def test_layernorm_mlp_layer_fp8(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
): ):
self._test_layernorm_mlp( self._test_layernorm_mlp(
mesh_config, mesh_config,
...@@ -389,4 +432,51 @@ class TestDistributedLayernormMLP: ...@@ -389,4 +432,51 @@ class TestDistributedLayernormMLP:
use_fp8=True, use_fp8=True,
fp8_recipe=fp8_recipe, fp8_recipe=fp8_recipe,
use_shardy=False, use_shardy=False,
with_jax_gemm=with_jax_gemm,
)
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
):
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False,
fp8_recipe=None,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
@pytest_parametrize_wrapper("use_bias", [True, False])
@pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
@pytest_parametrize_wrapper("with_jax_gemm", [False, True])
def test_layernorm_mlp_layer_fp8_shardy(
self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
):
if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
self._test_layernorm_mlp(
mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=True,
fp8_recipe=fp8_recipe,
use_shardy=True,
with_jax_gemm=with_jax_gemm,
) )
...@@ -3,11 +3,12 @@ ...@@ -3,11 +3,12 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Utility for the TE layer tests""" """Utility for the TE layer tests"""
import os
import functools import functools
import math import math
import operator import operator
from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
import os from contextlib import contextmanager
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -20,7 +21,6 @@ from jax import random as jax_random ...@@ -20,7 +21,6 @@ from jax import random as jax_random
import pytest import pytest
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
AttnMaskType,
canonicalize_attn_mask_type, canonicalize_attn_mask_type,
make_swa_mask, make_swa_mask,
) )
...@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType ...@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
DType = jnp.dtype DType = NewType("DType", jnp.dtype)
Array = Any Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[ PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
] ]
...@@ -1519,7 +1519,7 @@ def dtype_tols( ...@@ -1519,7 +1519,7 @@ def dtype_tols(
TEDType.kFloat8E5M2: jnp.float8_e5m2, TEDType.kFloat8E5M2: jnp.float8_e5m2,
}[dtype] }[dtype]
elif isinstance(dtype, np.dtype): elif isinstance(dtype, np.dtype):
dtype = jnp.dtype(dtype) dtype = DType(dtype)
# Expect bit-wise accuracy for integer dtypes # Expect bit-wise accuracy for integer dtypes
if not jnp.issubdtype(dtype, jnp.floating): if not jnp.issubdtype(dtype, jnp.floating):
...@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): ...@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
fmt = fmt + "\n {}\n {}" fmt = fmt + "\n {}\n {}"
jax.debug.print(fmt, *args) jax.debug.print(fmt, *args)
@contextmanager
def use_jax_gemm(enabled=False):
    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
    try:
        if enabled:
            os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
        yield
    finally:
        if enabled:
            if orig_custom_calls_filter is None:
                os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
            else:
                os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
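The pattern `^(?!GemmPrimitive$).+$` uses a negative lookahead to match every non-empty name except the exact string `GemmPrimitive`. The standalone sketch below shows that behaviour together with the save-and-restore handling of the environment variable. It assumes primitives are filtered by matching their class name against the pattern with `re.match`; `is_allowed` and `override_env` are hypothetical helpers, not the library's API, and the primitive names are only sample inputs.

```python
import os
import re
from contextlib import contextmanager

ENV_VAR = "NVTE_JAX_CUSTOM_CALLS_RE"
NO_GEMM_PATTERN = "^(?!GemmPrimitive$).+$"


def is_allowed(primitive_name: str, pattern: str) -> bool:
    """Hypothetical stand-in for the filter check: True if the name matches."""
    return re.match(pattern, primitive_name) is not None


@contextmanager
def override_env(name: str, value: str):
    """Set an environment variable for the duration of a block, then restore it."""
    original = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if original is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = original


with override_env(ENV_VAR, NO_GEMM_PATTERN):
    pattern = os.environ[ENV_VAR]
    for name in ("GemmPrimitive", "ActLuPrimitive", "GroupedGemmPrimitive"):
        print(name, is_allowed(name, pattern))
```

Because the lookahead is anchored with `$`, only the exact name is excluded; a longer name that merely contains `GemmPrimitive` still matches.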
...@@ -415,37 +415,35 @@ class ActLuPrimitive(BasePrimitive): ...@@ -415,37 +415,35 @@ class ActLuPrimitive(BasePrimitive):
result_types, result_types,
): ):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
prefix = "ActLuPrimitive_"
x_rank = len(value_types[0].shape) x_rank = len(value_types[0].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2 x_rank - 1, unique_var=prefix + "x", flatten_axis=-2
) )
x_axes = scale_rules.input_spec + (f"x{x_rank-1}",) x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",)
out = (*x_axes[:-2], x_axes[-1]) out = (*x_axes[:-2], x_axes[-1])
scale_inv = scale_rules.rowwise_rule scale_inv = scale_rules.rowwise_rule
colwise_scale_inv = scale_rules.colwise_rule
colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "scale_inv_colwise",)
if is_2x: if is_2x:
colwise_scale_inv = scale_rules.colwise_rule
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple( colwise_out = tuple(
multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
) )
else: else:
colwise_out = out colwise_out = out
else:
colwise_out = ("j",)
colwise_scale_inv = ("k",)
# amax is always a unit tensor. # amax is always a unit tensor.
amax = ("l",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
( (
x_axes, x_axes,
"…1", ("…1",),
), ),
(out, colwise_out, scale_inv, colwise_scale_inv, amax), (out, colwise_out, scale_inv, colwise_scale_inv, amax),
**scale_rules.factor_sizes,
) )
...@@ -890,28 +888,26 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -890,28 +888,26 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
result_types, result_types,
): ):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
prefix = "BaseDActLuDBiasQuantizePrimitive_"
x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2 len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
dz_axes = (*x_axes[:-2], x_axes[-1])
out = x_axes out = x_axes
colwise_out = (prefix + "out_colwise",)
if is_2x: if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
else: else:
colwise_out = tuple(x_axes) colwise_out = out
else:
colwise_out = ("j",)
dbias = x_axes[-2:] if is_dbias else ("k",) dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
amax = ("…4",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(("…0",), tuple(x_axes), ("…2",)), (dz_axes, x_axes, ("…2",)),
(out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
**scale_rules.factor_sizes,
) )
...@@ -985,6 +981,7 @@ def act_lu( ...@@ -985,6 +981,7 @@ def act_lu(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization. """Activation with optional quantization.
...@@ -993,6 +990,7 @@ def act_lu( ...@@ -993,6 +990,7 @@ def act_lu(
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply. activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns: Returns:
If quantizer is None: If quantizer is None:
...@@ -1037,6 +1035,10 @@ def act_lu( ...@@ -1037,6 +1035,10 @@ def act_lu(
is_outer=True, is_outer=True,
) )
out = out.reshape(output_shape) out = out.reshape(output_shape)
if noop_scaled_tensor:
return ScaledTensorFactory.create_2x(
out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype
)
return out return out
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
...@@ -1090,6 +1092,7 @@ def quantize_dact_dbias( ...@@ -1090,6 +1092,7 @@ def quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
is_dbias: bool = True, is_dbias: bool = True,
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]: ) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization. """Compute gradients of activation and bias with optional quantization.
...@@ -1100,6 +1103,7 @@ def quantize_dact_dbias( ...@@ -1100,6 +1103,7 @@ def quantize_dact_dbias(
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns: Returns:
Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
...@@ -1113,13 +1117,49 @@ def quantize_dact_dbias( ...@@ -1113,13 +1117,49 @@ def quantize_dact_dbias(
f" {x.shape} and act_len {act_len}" f" {x.shape} and act_len {act_len}"
) )
scale = jnp.empty((), jnp.float32)
act_type_id = ActivationEnum[activation_type]
PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
if not PrimitiveClass.enabled(): if not PrimitiveClass.enabled() or (
quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
):
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet if quantizer is None:
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) dz,
x,
scale,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling; TE/common ignores this value when scale is unset
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
)
output = output.astype(x.dtype)
dbias = None
if is_dbias:
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
if noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output,
None,
output,
None,
ScalingMode.NO_SCALING,
dq_dtype=output.dtype,
),
dbias,
)
return output, dbias
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
...@@ -1145,31 +1185,6 @@ def quantize_dact_dbias( ...@@ -1145,31 +1185,6 @@ def quantize_dact_dbias(
if war_output is not None: if war_output is not None:
return war_output return war_output
scale = jnp.empty((), jnp.float32)
act_type_id = ActivationEnum[activation_type]
if quantizer is None:
output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
dz,
x,
scale,
# outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
is_outer=True,
)
dbias = None
if is_dbias:
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after. # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = dact_lu( out = dact_lu(
...@@ -1183,7 +1198,7 @@ def quantize_dact_dbias( ...@@ -1183,7 +1198,7 @@ def quantize_dact_dbias(
) )
return out, dbias return out, dbias
if isinstance(quantizer, DelayedScaleQuantizer): if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale scale = quantizer.scale
# TE/common dact_dbias_quantize does not support gated act yet # TE/common dact_dbias_quantize does not support gated act yet
...@@ -1243,6 +1258,7 @@ def dact_lu( ...@@ -1243,6 +1258,7 @@ def dact_lu(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
noop_scale_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
""" """
Backward pass for activation with optional quantization. Backward pass for activation with optional quantization.
...@@ -1252,6 +1268,7 @@ def dact_lu( ...@@ -1252,6 +1268,7 @@ def dact_lu(
x: Input tensor that was used in forward pass. x: Input tensor that was used in forward pass.
activation_type: Type of activation function that was applied. activation_type: Type of activation function that was applied.
quantizer: Optional quantizer for FP8 quantization of the output gradient. quantizer: Optional quantizer for FP8 quantization of the output gradient.
noop_scale_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns: Returns:
The gradient of the activation with respect to the input. The gradient of the activation with respect to the input.
...@@ -1262,5 +1279,6 @@ def dact_lu( ...@@ -1262,5 +1279,6 @@ def dact_lu(
activation_type=activation_type, activation_type=activation_type,
is_dbias=False, is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
noop_scaled_tensor=noop_scale_tensor,
) )
return output return output
...@@ -3,19 +3,26 @@ ...@@ -3,19 +3,26 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX te modules""" """JAX te modules"""
from typing import Tuple, Sequence, Union, Dict
from functools import partial, reduce
import operator
import math import math
import operator
from collections.abc import Iterable
from typing import Tuple, Sequence, Union
from functools import partial, reduce
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams from jax import dtypes
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax as tex
from transformer_engine_jax import get_num_compute_streams
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize from .quantization import grouped_quantize
from ..quantize import ( from ..quantize import (
ScaledTensor, ScaledTensor,
ScaledTensor2x,
GroupedScaledTensor1x, GroupedScaledTensor1x,
ScalingMode, ScalingMode,
Quantizer, Quantizer,
...@@ -25,10 +32,20 @@ from ..quantize import ( ...@@ -25,10 +32,20 @@ from ..quantize import (
QuantizeLayout, QuantizeLayout,
noop_quantizer_set, noop_quantizer_set,
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
remove_padding_from_scale_inv,
) )
from .misc import get_padded_spec
__all__ = ["gemm", "grouped_gemm"] __all__ = [
"gemm",
"grouped_gemm",
"gemm_uses_jax_dot",
"sanitize_dims",
"get_non_contracting_dims",
"transpose_dims",
]
num_cublas_streams = get_num_compute_streams() num_cublas_streams = get_num_compute_streams()
...@@ -36,11 +53,936 @@ num_cublas_streams = get_num_compute_streams() ...@@ -36,11 +53,936 @@ num_cublas_streams = get_num_compute_streams()
def get_cublas_workspace_size_bytes() -> None: def get_cublas_workspace_size_bytes() -> None:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures.""" """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
if get_device_compute_capability(0) >= 90: if tex.get_device_compute_capability(0) >= 90:
return 33_554_432 return 33_554_432
return 4_194_304 return 4_194_304
def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]:
"""Convert relative (negative) indexes to absolute dimension numbers."""
dims_ = dims if isinstance(dims, Iterable) else (dims,)
if len(dims_) == 0:
return dims_
return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None)
def get_non_contracting_dims(ndim, contracting_dims):
"""Return a tuple of dimensions not included in the contracting dimensions."""
contracting_dims = sanitize_dims(ndim, contracting_dims)
return tuple(dim for dim in range(ndim) if dim not in contracting_dims)
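To illustrate the two helpers above, this standalone sketch copies `sanitize_dims` and `get_non_contracting_dims` as written in the diff and applies them to a rank-3 operand. The operand rank and the dimension tuples are illustrative assumptions, not values taken from the PR.

```python
from collections.abc import Iterable
from typing import Sequence, Union


def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]:
    """Convert relative (negative) indexes to absolute dimension numbers."""
    dims_ = dims if isinstance(dims, Iterable) else (dims,)
    if len(dims_) == 0:
        return dims_
    return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None)


def get_non_contracting_dims(ndim, contracting_dims):
    """Return a tuple of dimensions not included in the contracting dimensions."""
    contracting_dims = sanitize_dims(ndim, contracting_dims)
    return tuple(dim for dim in range(ndim) if dim not in contracting_dims)


# Assumed rank-3 operand, e.g. (batch, sequence, hidden).
print(sanitize_dims(3, -1))                # (2,)
print(sanitize_dims(3, (-2, -1)))          # (1, 2)
print(sanitize_dims(3, ()))                # ()
print(get_non_contracting_dims(3, (-1,)))  # (0, 1)
```

A bare integer is wrapped into a one-element tuple, and an empty tuple passes through unchanged.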
def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1):
"""Compute the new dimension numbers after transpose."""
if len(dims_to_transpose) == 0:
return dims_to_transpose
flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis
transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis))
return tuple(transposed_dims.index(dim) for dim in dims_to_transpose)
def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool:
lhs, rhs, e4m3, e5m2 = map(
dtypes.canonicalize_dtype,
(
lhs_dtype,
rhs_dtype,
jnp.float8_e4m3fn,
jnp.float8_e5m2,
),
)
# FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3)
if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3):
return True
# Any other combination of data types is not supported
return False
def _get_gemm_layout(
operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]]
) -> Tuple[bool, bool]:
lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims)
lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting
rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting
return lhs_is_transposed, rhs_is_transposed
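The layout rule above can be checked in isolation. The sketch below reproduces the logic of `_get_gemm_layout` under the hypothetical name `get_gemm_layout`, with a simplified copy of `sanitize_dims`; the 2D operand shapes in the comments are assumptions chosen for illustration.

```python
def sanitize_dims(ndim, dims):
    """Simplified copy: map negative dimension indexes to absolute ones."""
    dims_ = dims if isinstance(dims, (tuple, list)) else (dims,)
    return tuple(ndim + d if d < 0 else d for d in dims_ if d is not None)


def get_gemm_layout(operand_ndims, contracting_dims):
    """Hypothetical stand-in for _get_gemm_layout with the same logic."""
    lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims)
    # LHS is "transposed" when its last dimension is NOT contracted.
    lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting
    # RHS is "transposed" when its last dimension IS contracted.
    rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting
    return lhs_is_transposed, rhs_is_transposed


# (M, K) x (K, N): neither operand is transposed.
print(get_gemm_layout((2, 2), ((-1,), (0,))))   # (False, False)
# (M, K) x (N, K): only the RHS is transposed.
print(get_gemm_layout((2, 2), ((-1,), (-1,))))  # (False, True)
# (K, M) x (K, N): only the LHS is transposed.
print(get_gemm_layout((2, 2), ((0,), (0,))))    # (True, False)
```

The second case, `contracting_dims=((-1,), (-1,))`, is the layout the FP8 assertion in `GemmPrimitive.abstract` requires when non-NT FP8 GEMMs are not supported.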
def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims):
lhs_q = lhs
rhs_q = rhs
if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None:
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0])
lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims
need_lhs_colwise = lhs_is_transposed and (
lhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported()
)
flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims)
lhs_q = lhs_quantizer.quantize(
lhs,
is_rowwise=not need_lhs_colwise,
is_colwise=need_lhs_colwise,
flatten_axis=flatten_axis,
)
if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None:
rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1])
rhs_is_transposed = rhs.ndim - 1 in rhs_cdims
need_rhs_colwise = not rhs_is_transposed and (
rhs_quantizer.scaling_mode.is_1d_block_scaling()
or not is_fp8_gemm_with_all_layouts_supported()
)
flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1
rhs_q = rhs_quantizer.quantize(
rhs,
is_rowwise=not need_rhs_colwise,
is_colwise=need_rhs_colwise,
flatten_axis=flatten_axis,
)
assert not isinstance(lhs_q, ScaledTensor2x)
assert not isinstance(rhs_q, ScaledTensor2x)
return lhs_q, rhs_q
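The `flatten_axis` expressions used above (and reused as `lhs_axis_boundary` / `rhs_axis_boundary` in the lowering) can be read as the point where an N-D operand is split into a 2D matrix. The sketch below assumes that interpretation; `lhs_axis_boundary`, `rhs_axis_boundary`, and `collapse_to_2d` are hypothetical helper names, and the contracting dimensions are assumed to be already sanitized to absolute indexes.

```python
import math


def lhs_axis_boundary(ndim, cdims):
    """Same expression as the LHS flatten_axis in _quantize_gemm_operands."""
    is_transposed = ndim - 1 not in cdims
    return max(cdims) + 1 if is_transposed else min(cdims)


def rhs_axis_boundary(ndim, cdims):
    """Same expression as the RHS flatten_axis in _quantize_gemm_operands."""
    is_transposed = ndim - 1 in cdims
    return min(cdims) if is_transposed else max(cdims) + 1


def collapse_to_2d(shape, boundary):
    """Hypothetical helper: flatten dims before/after the boundary into rows/cols."""
    return math.prod(shape[:boundary]), math.prod(shape[boundary:])


# LHS (B, M, K) contracting on K: rows = B * M, cols = K.
print(collapse_to_2d((8, 128, 256), lhs_axis_boundary(3, (2,))))  # (1024, 256)
# RHS (K, N) contracting on K: rows = K, cols = N.
print(collapse_to_2d((256, 512), rhs_axis_boundary(2, (0,))))     # (256, 512)
# RHS (N, K) contracting on K (transposed layout): rows = N, cols = K.
print(collapse_to_2d((512, 256), rhs_axis_boundary(2, (1,))))     # (512, 256)
```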
class GemmPrimitive(BasePrimitive):
"""
Primitive for cuBLAS GEMM
"""
name = "te_gemm_ffi"
multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator
def _dims_are_consecutive(dims):
if len(dims) <= 1:
return True
return sorted(dims) == list(range(min(dims), max(dims) + 1))
# Sanity-check operand layouts and types
operand_ndims = (lhs.ndim, rhs.ndim)
(
lhs_contracting_dims,
rhs_contracting_dims,
) = map(sanitize_dims, operand_ndims, contracting_dims)
assert _dims_are_consecutive(lhs_contracting_dims), (
"cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got "
f"{lhs_contracting_dims}."
)
assert _dims_are_consecutive(rhs_contracting_dims), (
"cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got "
f"{rhs_contracting_dims}."
)
(
lhs_batch_dims,
rhs_batch_dims,
) = map(sanitize_dims, operand_ndims, batched_dims)
assert _dims_are_consecutive(lhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got "
f"{lhs_batch_dims}."
)
assert _dims_are_consecutive(rhs_batch_dims), (
"cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got "
f"{rhs_batch_dims}."
)
if len(lhs_batch_dims) == 0:
assert (
len(rhs_batch_dims) == 0
), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched."
elif len(rhs_batch_dims) != 0:
assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all(
bdim in rhs_contracting_dims for bdim in rhs_batch_dims
), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched."
lhs_contracting_size, rhs_contracting_size = map(
lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
(lhs.shape, rhs.shape),
(lhs_contracting_dims, rhs_contracting_dims),
)
assert lhs_contracting_size == rhs_contracting_size, (
"cuBLAS GEMM operands have incompatible contracting dimensions: "
f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}."
)
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
if scaling_mode != ScalingMode.NO_SCALING:
assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), (
"cuBLAS GEMM quantized operands have incompatible data types: "
f"{lhs.dtype} x {rhs.dtype}."
)
assert (
lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0
), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands."
if (
scaling_mode != ScalingMode.MXFP8_1D_SCALING
and not tex.is_non_nt_fp8_gemm_supported()
):
assert not lhs_is_transposed and rhs_is_transposed, (
"cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) "
"requires non-transposed LHS and transposed RHS operands "
"(`contracting_dims=((-1, ), (-1, ))`)."
)
# Determine output shape and dtype
assert (
dtypes.canonicalize_dtype(out_dtype).itemsize > 1
), "cuBLAS GEMM custom op does not support 8-bit quantized output types."
lhs_non_contracting_shape, rhs_non_contracting_shape = map(
lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims],
(lhs.shape, rhs.shape),
(lhs_contracting_dims, rhs_contracting_dims),
)
out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape)
output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
# Validate bias
bias_shape = (0,)
bias_dtype = out_dtype
if fuse_bias:
expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape)
if not grad:
assert bias.size == expected_bias_size, (
"cuBLAS GEMM bias tensor has incorrect shape, "
f"expected ({expected_bias_size}, ) but found {bias.shape}."
)
assert bias.dtype == out_dtype, (
"cuBLAS GEMM bias tensor has incorrect data type, "
f"expected {bias_dtype} but found {bias.dtype}."
)
bias_shape = bias.shape
else:
bias_shape = rhs_non_contracting_shape
bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype)
# Validate pre-GeLU
pre_gelu_shape = (0,)
pre_gelu_dtype = out_dtype
if fuse_gelu:
pre_gelu_shape = out_shape
if grad:
pre_gelu_ndim = len(pre_gelu_shape)
assert gelu_input.ndim == pre_gelu_ndim and all(
gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim)
), (
"cuBLAS GEMM pre-GeLU tensor has incorrect shape, "
f"expected {pre_gelu_shape} but found {gelu_input.shape}."
)
assert gelu_input.dtype == out_dtype, (
"cuBLAS GEMM pre-GeLU tensor has incorrect data type, "
f"expected {pre_gelu_dtype} but found {gelu_input.dtype}."
)
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
# Need extra workspace for swizzled scale factors
lhs_swizzle_size = 0
rhs_swizzle_size = 0
swizzle_dtype = jnp.uint8
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_swizzle_size = lhs_scale_inv.size
rhs_swizzle_size = rhs_scale_inv.size
lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype)
rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype)
# Declare cuBLAS workspace
# The cuBLAS workspace pointer must be 256-byte aligned, but JAX buffers are not
# necessarily 256-byte aligned, so we add 256 bytes of padding to ensure alignment.
workspace_size = get_cublas_workspace_size_bytes() + 256
workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace
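The shape inference done in `GemmPrimitive.abstract` boils down to two steps: check that the flattened contracting sizes agree, then concatenate the non-contracting dimensions of the LHS and RHS. A minimal sketch, assuming already-sanitized (absolute) contracting dimensions; `gemm_output_shape` is a hypothetical name and raises `ValueError` where the primitive uses an assertion.

```python
import operator
from functools import reduce


def gemm_output_shape(lhs_shape, rhs_shape, lhs_cdims, rhs_cdims):
    """Hypothetical helper mirroring the output-shape logic of the abstract rule."""
    lhs_k = reduce(operator.mul, [lhs_shape[d] for d in lhs_cdims])
    rhs_k = reduce(operator.mul, [rhs_shape[d] for d in rhs_cdims])
    if lhs_k != rhs_k:
        raise ValueError(
            f"Incompatible contracting dimensions: {lhs_shape} @ {lhs_cdims} "
            f"x {rhs_shape} @ {rhs_cdims}."
        )
    lhs_free = [lhs_shape[d] for d in range(len(lhs_shape)) if d not in lhs_cdims]
    rhs_free = [rhs_shape[d] for d in range(len(rhs_shape)) if d not in rhs_cdims]
    return (*lhs_free, *rhs_free)


# (B, M, K) x (K, N) -> (B, M, N)
print(gemm_output_shape((8, 128, 256), (256, 512), (2,), (0,)))  # (8, 128, 512)
# (B, M, K) x (N, K) -> (B, M, N)
print(gemm_output_shape((8, 128, 256), (512, 256), (2,), (1,)))  # (8, 128, 512)
```

Only the product of the contracting sizes is compared, so an LHS contracting over `(4, 64)` would be accepted against an RHS contracting over `(256,)`.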
@staticmethod
def outer_abstract(*args, **kwargs):
outputs = GemmPrimitive.abstract(*args, **kwargs)
return outputs[:-3] # discard workspace arrays
@staticmethod
def lowering(
ctx,
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype
lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
)
args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input)
kwargs = {
"scaling_mode": int(scaling_mode.value),
"lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
"rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
"lhs_transposed": lhs_transposed,
"rhs_transposed": rhs_transposed,
"fuse_bias": fuse_bias,
"fuse_gelu": fuse_gelu,
"grad": grad,
"use_split_accumulator": use_split_accumulator,
}
operand_output_aliases = {}
if fuse_bias and not grad:
operand_output_aliases.update({4: 1}) # bias <-> bias_grad
if fuse_gelu and grad:
operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out
return jax.ffi.ffi_lowering(
GemmPrimitive.name,
operand_output_aliases=operand_output_aliases,
)(ctx, *args, **kwargs)
@staticmethod
def impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
lhs_transposed, rhs_transposed = _get_gemm_layout(
(lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
)
lhs_scale_inv = apply_padding_to_scale_inv(
lhs_scale_inv,
scaling_mode,
lhs.shape,
is_colwise=lhs_quantized_colwise,
flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
)
rhs_scale_inv = apply_padding_to_scale_inv(
rhs_scale_inv,
scaling_mode,
rhs.shape,
is_colwise=rhs_quantized_colwise,
flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
)
outputs = GemmPrimitive.inner_primitive.bind(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
)
return outputs[:-3] # discard workspace arrays
@staticmethod
def batcher(
batched_args,
jax_batch_dims,
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
):
assert GemmPrimitive.outer_primitive is not None
lhs, _, rhs, *_ = batched_args
lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
)
arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
"User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch "
f"dimensions inferred by JAX/XLA, expected {rhs_bdims} but got {arg_rhs_bdims}."
)
# Output is batched like the non-contracting batch dimensions of the LHS operand
lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0])
lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims)
out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims
# Bias gradient is never batched
bias_bdims = (None,)
# Pre-GeLU output, if exists, is batched like GEMM output
pre_gelu_bdims = (None,)
if fuse_gelu and not grad:
pre_gelu_bdims = out_bdims
return (
GemmPrimitive.outer_primitive.bind(
*batched_args,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
),
(out_bdims, bias_bdims, pre_gelu_bdims),
)
@staticmethod
def _decompose_operand_specs(specs, contracting_dims, batch_dims):
ndims = len(specs)
cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims))
# Batch specs
bspecs = tuple(specs[i] for i in bdims)
# Non-batch leading dimension specs
lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims)
# Non-batch contracting dimension specs
cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims)
return bspecs, lspecs, cspecs
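A small worked example may help with what `_decompose_operand_specs` returns. The sketch below repeats its three comprehensions under the hypothetical name `decompose_operand_specs`, assuming the dimensions are already absolute tuples; the mesh axis names `"dp"` and `"tp"` are made up for illustration.

```python
def decompose_operand_specs(specs, cdims, bdims):
    """Hypothetical stand-in: split partition specs into batch/leading/contracting."""
    ndims = len(specs)
    # Batch specs
    bspecs = tuple(specs[i] for i in bdims)
    # Non-batch leading (non-contracting) dimension specs
    lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims)
    # Non-batch contracting dimension specs
    cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims)
    return bspecs, lspecs, cspecs


# Assumed (B, M, K) operand: B sharded over "dp", M unsharded, K sharded over "tp".
specs = ("dp", None, "tp")
print(decompose_operand_specs(specs, cdims=(2,), bdims=(0,)))
# (('dp',), (None,), ('tp',))
```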
@staticmethod
def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map(
sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims
)
(
(lhs_bspecs, lhs_lspecs, lhs_cspecs),
(rhs_bspecs, rhs_lspecs, rhs_cspecs),
) = map(
GemmPrimitive._decompose_operand_specs,
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
)
# Batched dimensions must have the same sharding
if len(lhs_bdims) > 0 and len(rhs_bdims) > 0:
assert all(
lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs)
), (
"cuBLAS GEMM operand batch dimensions must have the same sharding: "
f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}."
)
# Only one each of the non-batched leading dimensions and non-batched contracting
# dimensions can be sharded
lhs_ldims, rhs_ldims = map(
lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude),
(lhs_ndim, rhs_ndim),
(lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims),
)
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map(
lambda specs: tuple(spec for spec in specs if spec is not None),
(lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs),
)
assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched leading dimension: "
f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}."
)
assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, (
"cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: "
f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}."
)
# Extract single leading and contracting dimension specs
(lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map(
lambda specs: None if len(specs) == 0 else specs[0],
(lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none),
)
# Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts
# with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands.
# 1. K1 == K2 != None and N == None
# LHS: (B, M, K)
# RHS: (B, None, K)
# OUT: (B, M, None) --(AR)-> (B, M, None)
# 2. K1 == K2 != None and M == N != None
# LHS: (B, M, K)
# RHS: (B, N, K)--(AG)->(B, None, K)
# OUT: (B, M, None) --(RS)--> (B, M, N)
# 3. M == N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, M, K)--(AG)->(B, None, None)
# OUT: (B, M, None)
# 4. M != N
# LHS: (B, M, K)--(AG)->(B, M, None)
# RHS: (B, N, K)--(AG)->(B, N, None)
# OUT: (B, M, N)
reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec
all_reduce_output = reduce_flag and rhs_lspec is None
reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec
all_reduce_spec = reduce_scatter_spec = scatter_dim = None
lhs_non_contracting_specs, rhs_non_contracting_specs = map(
lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims),
(lhs_specs, rhs_specs),
(lhs_cdims, rhs_cdims),
)
out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs)
if reduce_scatter_output:
# All-gather (if necessary) the non-batch non-contracting dimension of RHS
# (B, N, K) --(AG)-> (B, None, K)
# (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N)
rhs_specs = tuple(
rhs_specs[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim)
)
reduce_scatter_spec = lhs_cspec
scatter_dim = out_specs.index(rhs_lspec)
elif all_reduce_output:
# Leave all trailing (RHS non-contracting) output dimensions unsharded
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
all_reduce_spec = lhs_cspec
else:
# All-gather (if necessary) the non-batch contracting dimensions
# (B, M, K) --(AG)-> (B, M, None)
# (B, N, K) --(AG)-> (B, N, None)
# (B, M, None) x (B, N, None)^T = (B, M, N)
lhs_specs = tuple(
None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i]
for i in range(lhs_ndim)
)
rhs_specs = tuple(
None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i]
for i in range(rhs_ndim)
)
# Check if RHS non-contracting spec also appears in the LHS non-contracting specs
if rhs_lspec is not None and rhs_lspec in tuple(
lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims
):
# All-gather (if necessary) the non-batch non-contracting dimensions of RHS
# (B, N, None) --(AG)-> (B, None, None)
# (B, M, None) x (B, None, None)^T = (B, M, None)
rhs_specs = tuple(
None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
for i in range(rhs_ndim)
)
# Leave all trailing (RHS non-contracting) output dimensions unsharded
out_specs = (
*lhs_non_contracting_specs,
*[None for _ in range(len(rhs_non_contracting_specs))],
)
# Bias and Pre-GeLU sharding is based on GEMM output
bias_specs = out_specs[len(lhs_non_contracting_specs) :]
gelu_specs = out_specs
return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs),
all_reduce_spec,
reduce_scatter_spec,
scatter_dim,
)
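The choice between all-reduce, reduce-scatter, and all-gathering the inputs in `_parse_operand_output_specs` depends only on four single specs: the sharded leading and contracting spec of each operand. The sketch below isolates that decision. `choose_output_collective` and its string return values are hypothetical, and the mesh axis names are assumptions.

```python
def choose_output_collective(lhs_lspec, lhs_cspec, rhs_lspec, rhs_cspec):
    """Hypothetical helper mirroring the reduce_flag logic of the partitioning rule."""
    # Both operands shard the contracting dimension over the same mesh axis.
    reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec
    all_reduce_output = reduce_flag and rhs_lspec is None
    reduce_scatter_output = (
        reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec
    )
    if reduce_scatter_output:
        return "reduce_scatter"
    if all_reduce_output:
        return "all_reduce"
    # Otherwise the contracting dimensions are all-gathered before the GEMM.
    return "all_gather_inputs"


# Case 1: K sharded identically, N unsharded.
print(choose_output_collective("dp", "tp", None, "tp"))   # all_reduce
# Case 2: K sharded identically, M and N share a mesh axis.
print(choose_output_collective("dp", "tp", "dp", "tp"))   # reduce_scatter
# Cases 3 and 4: contracting shardings do not line up.
print(choose_output_collective("dp", "tp", None, None))   # all_gather_inputs
```

The two reduce branches are mutually exclusive: all-reduce needs an unsharded RHS leading dimension, while reduce-scatter needs it sharded and equal to the LHS leading spec.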
@staticmethod
def infer_sharding_from_operands(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
mesh,
arg_infos,
result_infos,
):
del (
out_dtype,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
grad,
)
del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
# Discard bias gradient spec if there is no bias fusion
if not fuse_bias:
dbias_specs = (None,)
dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs))
# Discard pre-GeLU output spec if there is no GeLU fusion
if not fuse_gelu:
pre_gelu_specs = (None,)
pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))
return [out_sharding, dbias_sharding, pre_gelu_sharding]
@staticmethod
def partition(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
mesh,
arg_infos,
result_infos,
):
del result_infos
(
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs),
all_reduce_spec,
reduce_scatter_spec,
scatter_dim,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
# Assemble argument shardings
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
none_sharding = NamedSharding(mesh, PartitionSpec(None))
lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs))
rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs))
arg_shardings = (
lhs_sharding,
lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding,
rhs_sharding,
rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding,
)
# Discard bias input spec if there is no bias fusion
if not fuse_bias:
bias_input_specs = (None,)
arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),)
# Discard pre-GeLU input spec if there is no GeLU fusion
if not fuse_gelu:
gelu_input_specs = (None,)
arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),)
# Assemble output shardings
out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))]
# Discard bias gradient spec if there is no bias fusion
if not fuse_bias:
dbias_specs = (None,)
out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs)))
# Discard pre-GeLU output spec if there is no GeLU fusion
if not fuse_gelu:
pre_gelu_specs = (None,)
out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)))
def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input):
outputs = GemmPrimitive.impl(
lhs,
lhs_scale_inv,
rhs,
rhs_scale_inv,
bias,
gelu_input,
out_dtype=out_dtype,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
lhs_quantized_colwise=lhs_quantized_colwise,
rhs_quantized_colwise=rhs_quantized_colwise,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
)
# All-Reduce/Reduce-Scatter GEMM output
if all_reduce_spec is not None:
outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec)
elif reduce_scatter_spec is not None:
outputs[0] = jax.lax.psum_scatter(
outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
if fuse_gelu and not grad:
outputs[2] = jax.lax.psum_scatter(
outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
)
return outputs
return mesh, _sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
out_dtype,
contracting_dims,
batched_dims,
lhs_quantized_colwise,
rhs_quantized_colwise,
scaling_mode,
fuse_bias,
fuse_gelu,
grad,
use_split_accumulator,
mesh,
operand_types,
result_types,
):
del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator
del mesh, result_types
prefix = "GemmPrimitive_"
def _generate_operand_rules(name, ndim, cdims, bdims):
specs = []
ldims = tuple(i for i in range(ndim) if i not in bdims + cdims)
for i in range(ndim):
dim_name = None
if i in bdims:
dim_idx = bdims.index(i) if len(bdims) > 1 else ""
dim_name = f"b{dim_idx}"
elif i in cdims:
dim_idx = cdims.index(i) if len(cdims) > 1 else ""
dim_name = f"k{dim_idx}"
else:
dim_idx = ldims.index(i) if len(ldims) > 1 else ""
dim_name = f"{name}_l{dim_idx}"
specs.append(prefix + dim_name)
return specs
lhs, _, rhs, *_ = operand_types
operand_ndims = (len(lhs.shape), len(rhs.shape))
(lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map(
lambda dims: map(sanitize_dims, operand_ndims, dims),
(contracting_dims, batched_dims),
)
lhs_specs, rhs_specs = map(
_generate_operand_rules,
("lhs", "rhs"),
operand_ndims,
(lhs_cdims, rhs_cdims),
(lhs_bdims, rhs_bdims),
)
lhs_scale_specs = ("…1",)
rhs_scale_specs = ("…2",)
if scaling_mode.is_1d_block_scaling():
# Shardy rules for MXFP8 scales cannot be related to the operands because of the
# global-unpadding and local-padding workflow. This can potentially insert expensive
# re-shards in the partition call later if the scales are not already sharded correctly.
lhs_scale_specs, rhs_scale_specs = map(
lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs),
(lhs_specs, rhs_specs),
)
lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims)
rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
out_spec = (*lhs_non_cspec, *rhs_non_cspec)
bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
gelu_spec = out_spec if fuse_gelu else ("…5",)
return SdyShardingRule(
operand_mappings=(
lhs_specs,
lhs_scale_specs,
rhs_specs,
rhs_scale_specs,
bias_spec,
gelu_spec,
),
result_mappings=(
out_spec,
bias_spec,
gelu_spec,
),
)
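To see the factor names that `shardy_sharding_rule` generates, the sketch below lifts the nested `_generate_operand_rules` helper out as a standalone function (hypothetical name `generate_operand_rules`, with `prefix` turned into a parameter). The operand ranks and dimension tuples are illustrative assumptions.

```python
def generate_operand_rules(name, ndim, cdims, bdims, prefix="GemmPrimitive_"):
    """Standalone copy of the nested helper: one factor name per operand dimension."""
    specs = []
    ldims = tuple(i for i in range(ndim) if i not in bdims + cdims)
    for i in range(ndim):
        if i in bdims:
            dim_idx = bdims.index(i) if len(bdims) > 1 else ""
            dim_name = f"b{dim_idx}"
        elif i in cdims:
            dim_idx = cdims.index(i) if len(cdims) > 1 else ""
            dim_name = f"k{dim_idx}"
        else:
            dim_idx = ldims.index(i) if len(ldims) > 1 else ""
            dim_name = f"{name}_l{dim_idx}"
        specs.append(prefix + dim_name)
    return specs


# Assumed LHS (B, M, K) contracting on K, with no declared batch dimensions.
print(generate_operand_rules("lhs", 3, cdims=(2,), bdims=()))
# ['GemmPrimitive_lhs_l0', 'GemmPrimitive_lhs_l1', 'GemmPrimitive_k']
# Assumed RHS (K, N) contracting on K.
print(generate_operand_rules("rhs", 2, cdims=(0,), bdims=()))
# ['GemmPrimitive_k', 'GemmPrimitive_rhs_l']
```

Both operands end up with the same `GemmPrimitive_k` factor on their contracting dimension, while leading dimensions get operand-specific names.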
register_primitive(GemmPrimitive)
def gemm_uses_jax_dot() -> bool:
"""Return True if the GEMM call dispatches to native JAX dot instead of the TE custom cuBLAS call."""
return not GemmPrimitive.enabled()
def _get_scale_inv_without_padding(scaled_tensor):
return remove_padding_from_scale_inv(
scaled_tensor.scale_inv,
scaled_tensor.scaling_mode,
scaled_tensor.data.shape,
is_colwise=scaled_tensor.is_colwise,
flatten_axis=scaled_tensor.flatten_axis,
)


def _te_gemm(
    lhs: Union[jax.Array, ScaledTensor],
    rhs: Union[jax.Array, ScaledTensor],
    bias: jax.Array = None,
    gelu_input: jax.Array = None,
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
    batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
    fuse_bias: bool = False,
    fuse_gelu: bool = False,
    grad: bool = False,
    use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
) -> Tuple[jax.Array, ...]:
    # Prepare non-quantized GEMM operands
    lhs_data = lhs
    rhs_data = rhs
    lhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
    rhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
    scaling_mode = ScalingMode.NO_SCALING
    lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
    lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
    lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)

    # Quantize operands (if necessary)
    lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)

    # Extract GEMM custom op inputs from quantized operands
    if isinstance(lhs_q, ScaledTensor):
        assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, (
            "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid "
            "`Quantizer` object to quantize the RHS operand."
        )
        if isinstance(lhs_q, ScaledTensor2x):
            # Choose the quantization of the contracting dimension(s)
            lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor()
        scaling_mode = lhs_q.scaling_mode
        lhs_data = lhs_q.data
        lhs_scale_inv = _get_scale_inv_without_padding(lhs_q)
        if lhs_q.data_layout == "T":
            lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
            lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)

    if isinstance(rhs_q, ScaledTensor):
        assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
            "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid "
            "`Quantizer` object to quantize the LHS operand."
        )
        if isinstance(rhs_q, ScaledTensor2x):
            # Choose the quantization of the contracting dimension(s)
            rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor()
        assert rhs_q.scaling_mode == lhs_q.scaling_mode, (
            "cuBLAS GEMM quantized operands have mismatched scaling types, "
            f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}."
        )
        rhs_data = rhs_q.data
        rhs_scale_inv = _get_scale_inv_without_padding(rhs_q)
        if rhs_q.data_layout == "T":
            rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
            rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)

    # Dummy empties for bias and gelu
    out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
    if bias is None or not (fuse_bias and not grad):
        bias = jnp.empty(0, dtype=out_dtype)
    if gelu_input is None or not (fuse_gelu and grad):
        gelu_input = jnp.empty(0, dtype=out_dtype)

    return GemmPrimitive.outer_primitive.bind(
        lhs_data,
        lhs_scale_inv,
        rhs_data,
        rhs_scale_inv,
        bias,
        gelu_input,
        out_dtype=out_dtype,
        contracting_dims=(lhs_cdims, rhs_cdims),
        batched_dims=(lhs_bdims, rhs_bdims),
        lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False,
        rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False,
        scaling_mode=scaling_mode,
        fuse_bias=fuse_bias,
        fuse_gelu=fuse_gelu,
        grad=grad,
        use_split_accumulator=use_split_accumulator,
    )
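The "dummy empties" step in `_te_gemm()` means that `bias` is only forwarded in the forward pass and `gelu_input` only in the backward pass. A minimal plain-Python sketch of that gating follows. The function name `select_fusion_inputs` is hypothetical, and `None` stands in for the zero-sized `jnp.empty(0, ...)` placeholder used in the diff.

```python
def select_fusion_inputs(bias, gelu_input, fuse_bias, fuse_gelu, grad):
    """Return (bias, gelu_input) as they would reach the primitive.

    None stands in for the empty placeholder array used in the real code.
    """
    # Bias is consumed only for forward GEMM with bias fusion.
    keep_bias = bias is not None and fuse_bias and not grad
    # The pre-GeLU tensor is consumed only for backward GEMM with dGeLU fusion.
    keep_gelu_input = gelu_input is not None and fuse_gelu and grad
    return (bias if keep_bias else None, gelu_input if keep_gelu_input else None)


# Forward pass: bias is kept, gelu_input is replaced by the placeholder.
print(select_fusion_inputs("bias", "pre_gelu", True, True, grad=False))
# Backward pass: the roles are swapped.
print(select_fusion_inputs("bias", "pre_gelu", True, True, grad=True))
```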
class GroupedGemmPrimitive(BasePrimitive): class GroupedGemmPrimitive(BasePrimitive):
""" """
Primitive for grouped GEMM Primitive for grouped GEMM
@@ -230,11 +1172,8 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False)
def _calculate_remaining_shape(shape, contracting_dims): def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims) contracting_dims_ = sanitize_dims(len(shape), contracting_dims)
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_)
def _transpose_contract_dims(ndim, contracting_dims):
return tuple(ndim - i - 1 for i in contracting_dims)[::-1]
# Apply jit to guarantee correctness of FP8 GEMM. # Apply jit to guarantee correctness of FP8 GEMM.
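The change to `_calculate_remaining_shape` matters once negative contracting dimensions are allowed, as with the new `((-1,), (0,))` default of `gemm()`. The sketch below shows the difference in plain Python. The `sanitize_dims` here is a hypothetical stand-in that only maps negative indices to non-negative ones; the real helper is not shown in this diff and may do more.

```python
from typing import Sequence, Tuple


def sanitize_dims(ndim: int, dims: Sequence[int]) -> Tuple[int, ...]:
    # Hypothetical stand-in: normalize negative indices, e.g. -1 -> ndim - 1.
    return tuple(ndim + d if d < 0 else d for d in dims)


def remaining_shape_old(shape, contracting_dims):
    # Pre-change behavior: a negative index never matches range(len(shape)).
    return tuple(shape[d] for d in range(len(shape)) if d not in contracting_dims)


def remaining_shape_new(shape, contracting_dims):
    # Post-change behavior: normalize first, then filter.
    cdims = sanitize_dims(len(shape), contracting_dims)
    return tuple(shape[d] for d in range(len(shape)) if d not in cdims)


shape = (8, 16, 32)
print(remaining_shape_old(shape, (-1,)))  # contracting dim is not removed
print(remaining_shape_new(shape, (-1,)))  # contracting dim is removed
```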
@@ -242,9 +1181,11 @@ def _transpose_contract_dims(ndim, contracting_dims):
def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T": if lhs.data_layout == "T":
lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract) lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis)
if rhs.data_layout == "T": if rhs.data_layout == "T":
rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract) rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)
rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
@@ -315,12 +1256,12 @@ def _jax_gemm(
lhs: Union[jnp.ndarray, ScaledTensor], lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
quantizer_set: Dict["str", Quantizer] = noop_quantizer_set, lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None,
) -> jnp.ndarray: ) -> jnp.ndarray:
""" """
FP8 GEMM via JAX FP8 GEMM via JAX
""" """
dim_nums = (contracting_dims, ((), ())) dim_nums = (contracting_dims, ((), ()))
def _jax_gemm_fp8_impl(lhs, rhs): def _jax_gemm_fp8_impl(lhs, rhs):
@@ -340,37 +1281,16 @@ def _jax_gemm(
raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}") raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")
if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor): lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)
return _jax_gemm_fp8_impl(lhs, rhs)
if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor): if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor):
if quantizer_set != noop_quantizer_set:
assert type(quantizer_set.x) is type(quantizer_set.kernel)
if (
quantizer_set.x.scaling_mode.is_tensor_scaling()
and is_fp8_gemm_with_all_layouts_supported()
):
lhs_is_rowwise = rhs_is_rowwise = True
else:
(((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
lhs_q = quantizer_set.x.quantize(
lhs,
is_rowwise=lhs_is_rowwise,
is_colwise=not lhs_is_rowwise,
)
rhs_q = quantizer_set.kernel.quantize(
rhs,
is_rowwise=rhs_is_rowwise,
is_colwise=not rhs_is_rowwise,
)
return _jax_gemm_fp8_impl(lhs_q, rhs_q) return _jax_gemm_fp8_impl(lhs_q, rhs_q)
if ( if (
isinstance(lhs, jnp.ndarray) isinstance(lhs, jnp.ndarray)
and isinstance(rhs, jnp.ndarray) and isinstance(rhs, jnp.ndarray)
and quantizer_set == noop_quantizer_set and lhs_quantizer is None
and rhs_quantizer is None
): ):
return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype) return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)
@@ -380,30 +1300,109 @@ def _jax_gemm(
def gemm( def gemm(
lhs: Union[jnp.ndarray, ScaledTensor], lhs: Union[jnp.ndarray, ScaledTensor],
rhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor],
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
quantizer_set: QuantizerSet = noop_quantizer_set, batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
) -> jnp.ndarray: lhs_quantizer: Quantizer = None,
"""General matrix multiplication with optional quantization. rhs_quantizer: Quantizer = None,
**kwargs,
Args: ) -> Tuple[jnp.ndarray, ...]:
lhs: First input matrix. r"""General matrix multiplication with optional quantization.
rhs: Second input matrix.
contracting_dims: Tuple of two sequences representing the contracting dimensions. Parameters
The first sequence represents the contracting dimensions of the first matrix, ----------
and the second sequence represents the contracting dimensions of the second matrix. lhs: Union[jax.Array, ScaledTensor]
quantizer_set: Set of quantizers for FP8 quantization of the output. Left-hand side operand in the matrix multiplication.
If None, no quantization is applied and the output has the same dtype as the inputs. rhs: Union[jax.Array, ScaledTensor]
Right-hand side operand in the matrix multiplication.
Returns: lhs_quantizer: Quantizer, default = None
If quantizer_set is None: Object for down-casting the LHS operand for quantized GEMM.
The matrix multiplication result. rhs_quantizer: Quantizer, default = None
Shape: (M, N) Object for down-casting the RHS operand for quantized GEMM.
Dtype: Same as input dtype contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
If quantizer_set is provided: Tuple of sequences representing the contracting dimensions of the operands.
A ScaledTensor containing the quantized matrix multiplication result. batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ())
Tuple of sequences representing the batched dimensions of the operands. This is *not* used
to perform a batched matrix multiplication, but it is required to avoid a potentially
undesirable reduction in any batched contracting dimensions when invoked with sharded
operands (e.g. when computing weight gradients in a Flax module).
bias: jax.Array, default = None
Optional additive bias term, required for forward GEMM with bias fusion. Only supported
with TE's custom call to cuBLAS GEMM.
gelu_input: jax.Array, default = None
Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only
supported with TE's custom call to cuBLAS GEMM.
fuse_bias: bool, default = False
Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with
TE's custom call to cuBLAS GEMM.
fuse_gelu: bool, default = False
Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported
with TE's custom call to cuBLAS GEMM.
grad: bool, default = False
Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with
TE's custom call to cuBLAS GEMM.
use_split_accumulator: bool, default = True
Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.
Returns
-------
jax.Array:
Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the
GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution
when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and
`grad=False`.
Optional[jax.Array]:
Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call
to cuBLAS GEMM.
Optional[jax.Array]:
Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input
to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to
compute the GeLU contribution to the gradient. Only supported with TE's custom call to
cuBLAS GEMM.
""" """
# Try to get LHS and RHS quantizers from a quantizer set for backward compatibility
if lhs_quantizer is None or rhs_quantizer is None:
quantizer_set = kwargs.get("quantizer_set", None)
if quantizer_set is not None:
lhs_quantizer = quantizer_set.x
rhs_quantizer = quantizer_set.kernel
# Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled
fuse_bias = kwargs.get("fuse_bias", False)
fuse_gelu = kwargs.get("fuse_gelu", False)
if not GemmPrimitive.enabled():
assert kwargs.get("bias", None) is None and not fuse_bias, (
"TE GEMM was invoked with bias fusion options that are not supported by the "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
assert kwargs.get("gelu_input", None) is None and not fuse_gelu, (
"TE GEMM was invoked with GeLU fusion options that are not supported by the "
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled."
)
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
outputs = _te_gemm(
lhs,
rhs,
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
contracting_dims=contracting_dims,
batched_dims=batched_dims,
**kwargs,
)
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) # Discard empty outputs
grad = kwargs.get("grad", False)
clean_outputs = outputs[0] # first output is the final result and is never empty
if (fuse_bias and grad) or (fuse_gelu and not grad):
clean_outputs = (outputs[0],)
if fuse_bias and grad: # only return bias gradient if it exists
clean_outputs += (outputs[1],)
if fuse_gelu and not grad: # only return pre-GeLU output if it exists
clean_outputs += (outputs[2],)
return clean_outputs
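The output-cleaning step at the end of `gemm()` can be checked in isolation. The sketch below reproduces its branching in plain Python, with strings standing in for arrays. The function name `select_outputs` is hypothetical, and the three-element `outputs` tuple layout (result, bias gradient, pre-GeLU output) is taken from the docstring above.

```python
def select_outputs(outputs, fuse_bias=False, fuse_gelu=False, grad=False):
    """Drop the empty auxiliary outputs, mirroring the logic in gemm()."""
    result, bias_grad, pre_gelu_out = outputs
    # Without any auxiliary output, the bare result is returned (not a tuple).
    if not ((fuse_bias and grad) or (fuse_gelu and not grad)):
        return result
    cleaned = (result,)
    if fuse_bias and grad:  # bias gradient exists only in the backward pass
        cleaned += (bias_grad,)
    if fuse_gelu and not grad:  # pre-GeLU output exists only in the forward pass
        cleaned += (pre_gelu_out,)
    return cleaned


raw = ("out", "dbias", "pre_gelu")
print(select_outputs(raw))
print(select_outputs(raw, fuse_bias=True, grad=True))
print(select_outputs(raw, fuse_gelu=True))
```

One consequence worth keeping in mind for callers: the return type is a bare array in the unfused case and a tuple otherwise, so code that always unpacks the result needs to account for both.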
def grouped_gemm( def grouped_gemm(
......
@@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
calculate dbias separately. This function checks if the workaround should be applied. calculate dbias separately. This function checks if the workaround should be applied.
""" """
if quantizer is None:
return False
arch_l_100 = False arch_l_100 = False
for local_gpu_id in range(len(jax.local_devices())): for local_gpu_id in range(len(jax.local_devices())):
if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100: if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
arch_l_100 = True arch_l_100 = True
break break
# _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
# but this fails when bias fusion is turned on with arch < 100.
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
return ( return (
quantizer is not None (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
and quantizer.q_layout == QuantizeLayout.ROWWISE
and arch_l_100 and arch_l_100
and is_dbias and is_dbias
) )
......
@@ -587,16 +587,17 @@ class NormFwdPrimitive(BasePrimitive):
result_types, result_types,
) )
prefix = "NormFwdPrimitive_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1 len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
out = x_axes[:-1] + ("k",) out = x_axes
colwise_out = out if is_2x else ("…4",) colwise_out = out if is_2x else (prefix + "out_colwise",)
rsigma = x_axes[:-1] rsigma = x_axes[:-1]
mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = ("…6",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), ("…2",), ("…3",)), (x_axes, ("…1",), ("…2",), ("…3",)),
@@ -609,7 +610,6 @@ class NormFwdPrimitive(BasePrimitive):
mu, mu,
rsigma, rsigma,
), ),
**scale_rules.factor_sizes,
) )
@@ -1276,6 +1276,7 @@ def normalization_fwd(
epsilon: float, epsilon: float,
norm_type: str, norm_type: str,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
noop_scaled_tensor: bool = False,
): ):
"""Common wrapper for normalization forward pass. """Common wrapper for normalization forward pass.
@@ -1292,6 +1293,7 @@ def normalization_fwd(
- 'layernorm': Layer normalization - 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization - 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns: Returns:
A tuple containing: A tuple containing:
@@ -1319,6 +1321,15 @@ def normalization_fwd(
else: else:
raise ValueError(f"{norm_type=} is not supported.") raise ValueError(f"{norm_type=} is not supported.")
if quantizer is None and noop_scaled_tensor:
return (
ScaledTensorFactory.create_2x(
output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype
),
mu,
rsigma,
)
return output, mu, rsigma return output, mu, rsigma
......
@@ -36,7 +36,6 @@ from ..quantize import (
Quantizer, Quantizer,
GroupedQuantizer, GroupedQuantizer,
QuantizeLayout, QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode, ScalingMode,
compute_scale_from_amax, compute_scale_from_amax,
) )
@@ -489,9 +488,10 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
): ):
del out_dtype, scale_dtype, is_outer, mesh, result_types del out_dtype, scale_dtype, is_outer, mesh, result_types
prefix = "BaseDBiasQuantizePrimitive_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), len(value_types[0].shape),
unique_var="BaseDBiasQuantizePrimitive_i", unique_var=prefix + "x",
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
@@ -499,22 +499,19 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
colwise_scale_inv = scale_rules.colwise_rule colwise_scale_inv = scale_rules.colwise_rule
out = x_axes out = x_axes
colwise_out = (prefix + "out_colwise",)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if ScalingMode(scaling_mode).is_tensor_scaling(): if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
else: else:
colwise_out = x_axes colwise_out = x_axes
else:
colwise_out = ("j",)
colwise_scale_inv = ("k",)
dbias = x_axes[flatten_axis:] if is_dbias else ("l",) dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
amax = ("m",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",)), (x_axes, ("…1",)),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
) )
@@ -538,11 +535,12 @@ def _jax_quantize(
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
assert flatten_axis < 0 sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
dtype = dtype or dx.dtype dtype = dtype or dx.dtype
dbias = jnp.sum( dbias = jnp.sum(
dx.astype(jnp.float32), dx.astype(jnp.float32),
axis=tuple(range(dx.ndim + flatten_axis)), axis=tuple(range(sum_axis)),
keepdims=False, keepdims=False,
) )
return dbias.astype(dtype) return dbias.astype(dtype)
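The new `sum_axis` computation in `_jax_dbias` accepts both negative and non-negative `flatten_axis` values, where the old code asserted `flatten_axis < 0`. A plain-Python sketch of just the axis selection, with the hypothetical helper name `dbias_reduction_axes`, shows that equivalent negative and non-negative values reduce over the same leading axes.

```python
def dbias_reduction_axes(ndim, flatten_axis=-1):
    """Axes summed over when computing the bias gradient (sketch)."""
    # Normalize a negative flatten_axis to its non-negative equivalent.
    sum_axis = ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < ndim, "Flatten axis out of bounds!"
    # Every axis before the flatten axis is reduced.
    return tuple(range(sum_axis))


# For a 3D gradient, flatten_axis=-1 and flatten_axis=2 are equivalent.
print(dbias_reduction_axes(3, -1), dbias_reduction_axes(3, 2))
```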
@@ -568,6 +566,7 @@ def _quantize_dbias_impl(
is_dbias: bool = False, is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None, dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1, flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
""" """
Cast wrapper Cast wrapper
@@ -577,24 +576,34 @@ def _quantize_dbias_impl(
quantizer is not None quantizer is not None
), "quantizer must be provided if dq_dtype is provided" ), "quantizer must be provided if dq_dtype is provided"
# Early-exit for non-quantized call
dq_dtype = dq_dtype or x.dtype dq_dtype = dq_dtype or x.dtype
if quantizer is None:
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive dbias = None
if not PrimitiveClass.enabled():
if is_dbias: if is_dbias:
return _jax_quantize_dbias( dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
x, if noop_scaled_tensor:
quantizer=quantizer, # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor()
dq_dtype=dq_dtype, # always work.
flatten_axis=flatten_axis,
)
return ( return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), ScaledTensorFactory.create_2x(
x,
None, None,
x,
None,
ScalingMode.NO_SCALING,
dq_dtype=x.dtype,
data_layout="NN",
flatten_axis=flatten_axis,
),
dbias,
) )
return x, dbias
# TE/common doesn't support colwise only quantization yet # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: # fall back on the native-JAX quantize implementation
PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
if is_dbias: if is_dbias:
return _jax_quantize_dbias( return _jax_quantize_dbias(
x, x,
@@ -606,9 +615,8 @@ def _quantize_dbias_impl(
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis), _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None, None,
) )
scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100 # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_dbias_impl( out, _ = _quantize_dbias_impl(
x=x, x=x,
@@ -620,29 +628,23 @@ def _quantize_dbias_impl(
dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias return out, dbias
if quantizer is None: scale = jnp.empty((), jnp.float32)
if is_dbias:
return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale. # Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
# until the tensor is dequantized (e.g. in the GEMM). # until the tensor is dequantized (e.g. in the GEMM).
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
scale = compute_scale_from_amax(amax, quantizer.q_dtype) scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale scale = quantizer.scale
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
# It is faster to use 1x quantization for tensor scaling # It is faster to use 1x quantization for tensor scaling
is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
force_1x_quantization = ( force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() quantizer.scaling_mode.is_tensor_scaling()
and quantizer.is_2x2x() and quantizer.is_2x2x()
and is_1x_kernel_supported and is_1x_kernel_supported
) )
q_layout = quantizer.q_layout q_layout = quantizer.q_layout
if force_1x_quantization: if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE q_layout = QuantizeLayout.ROWWISE
@@ -698,6 +700,7 @@ def quantize(
x: jnp.ndarray, x: jnp.ndarray,
quantizer: Quantizer, quantizer: Quantizer,
flatten_axis: int = -1, flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor]: ) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer. """Quantize input tensor according to the quantizer.
@@ -707,6 +710,8 @@ def quantize(
quantizer: Quantizer for FP8 quantization of the output. quantizer: Quantizer for FP8 quantization of the output.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1. Defaults to -1.
noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
is None.
Returns: Returns:
A ScaledTensor containing the quantized input tensor. A ScaledTensor containing the quantized input tensor.
@@ -715,6 +720,7 @@ def quantize(
x, x,
quantizer=quantizer, quantizer=quantizer,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
) )
return out return out
@@ -724,6 +730,7 @@ def quantize_dbias(
quantizer: Quantizer, quantizer: Quantizer,
is_dbias: bool = True, is_dbias: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient. """Quantize input tensor and compute bias gradient.
@@ -734,6 +741,8 @@ def quantize_dbias(
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1. Defaults to -1.
noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
quantizer is None.
Returns: Returns:
A tuple containing: A tuple containing:
@@ -743,7 +752,11 @@ def quantize_dbias(
Shape: (K,) or empty if is_dbias is False. Shape: (K,) or empty if is_dbias is False.
""" """
return _quantize_dbias_impl( return _quantize_dbias_impl(
dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis dz,
quantizer=quantizer,
is_dbias=is_dbias,
flatten_axis=flatten_axis,
noop_scaled_tensor=noop_scaled_tensor,
) )
......
@@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
int64_t window_size_right); int64_t window_size_right);
// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
// Grouped GEMM // Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
......
@@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
case xla::ffi::DataType::F8E4M3FN: case xla::ffi::DataType::F8E4M3FN:
return DType::kFloat8E4M3; return DType::kFloat8E4M3;
break; break;
// case xla::ffi::DataType::F8E8M0FNU: case xla::ffi::DataType::F8E8M0FNU:
// return DType::kFloat8E8M0; return DType::kFloat8E8M0;
// break; break;
default: default:
auto type_num = static_cast<XLA_FFI_DataType>(type); auto type_num = static_cast<XLA_FFI_DataType>(type);
if (type_num == 33) return DType::kFloat8E8M0;
NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
static_cast<int>(type_num)); static_cast<int>(type_num));
break; break;
......
@@ -6,11 +6,13 @@
#include "transformer_engine/gemm.h" #include "transformer_engine/gemm.h"
#include <memory> #include <memory>
#include <string_view>
#include <tuple>
#include "../extensions.h" #include "../extensions.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#include "common/util/string.h"
#include "common/util/system.h" #include "common/util/system.h"
#include "transformer_engine/multi_stream.h"
#include "transformer_engine/swizzle.h" #include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api.h"
@@ -25,6 +27,181 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
~static_cast<uintptr_t>(255)); ~static_cast<uintptr_t>(255));
} }
std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv,
JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
// Set tensor data with collapsed 2D shape
auto buffer_dims = buffer.dimensions();
std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
product(buffer_dims, axis_boundary, buffer_dims.size())};
auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type());
TensorWrapper input(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
} else {
input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
}
// Set scaling factor for quantized tensors
if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands.");
NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");
std::vector<size_t> scale_shape = {1};
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Block scaling also needs to be collapsed to match 2D data
scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary),
product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())};
}
auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
if (rowwise) {
input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
}
// Swizzle scaling factors for MXFP8
if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
// Get the swizzle buffer
NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
"Missing swizzled inverse scale buffer in the JAX primitive.");
auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
auto swizzled_scale_inv_dtype =
convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
"Inverse scale factors need to have an 8-bit data type.");
// Create tensor to hold swizzled scale factor
TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
if (rowwise) {
output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
// Launch swizzle kernel
nvte_swizzle_scaling_factors(input.data(), output.data(), stream);
// Set swizzled scales into the input tensor
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
scale_shape);
}
}
}
return std::make_tuple(std::move(input), input_shape);
}
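How `xla_buffer_to_nvte_gemm_operand` collapses an N-D buffer into the 2D shape handed to `TensorWrapper` can be sketched in Python. This assumes the `product(dims, begin, end)` helper multiplies the extents in the half-open range `[begin, end)` and yields 1 for an empty range; `collapse_2d` is a hypothetical name used only for illustration.

```python
from math import prod


def collapse_2d(dims, axis_boundary):
    """Collapse an N-D shape into (rows, cols) around axis_boundary.

    Dimensions before axis_boundary are merged into rows and the
    remaining dimensions are merged into cols.
    """
    rows = prod(dims[:axis_boundary])  # prod of an empty slice is 1
    cols = prod(dims[axis_boundary:])
    return (rows, cols)


# A (batch, seq, hidden) activation contracted over `hidden`
print(collapse_2d((4, 8, 16), 2))  # (32, 16)
# A (hidden_in, heads, head_dim) kernel contracted over `hidden_in`
print(collapse_2d((16, 2, 32), 1))  # (16, 64)
```

For MXFP8 block scaling, the function above applies the same collapse to the `scale_inv` dimensions so that the scale tensor stays consistent with the 2D data.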
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
// Operands (this includes swizzling MXFP8 scaling factors)
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
(is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(
stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);
// Output tensor
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, "
"expected ",
out_.numel(), " elements ", to_string_like(out_shape), " but got ",
output->element_count(), " elements ", to_string_like(output->dimensions()));
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
std::vector<size_t> bias_shape = {0};
DType bias_dtype = out_dtype;
if (fuse_bias) {
if (!grad) {
NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad");
}
bias_ptr = bias_grad->untyped_data();
bias_shape.at(0) = bias_grad->dimensions().front();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type());
}
auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
// Pre-GeLU output from forward pass or input to backward pass
void *pre_gelu_ptr = nullptr;
std::vector<size_t> pre_gelu_shape = {0};
DType pre_gelu_dtype = out_dtype;
if (gelu_input.element_count() > 0) {
if (grad) {
NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out");
}
pre_gelu_ptr = pre_gelu_out->untyped_data();
pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1),
static_cast<size_t>(pre_gelu_out->dimensions().back())};
pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type());
}
auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype);
// cuBLAS workspace: advance the pointer to the next 256-byte boundary and shrink the usable size by 256 bytes
auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
return ffi_with_cuda_error_check();
}
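The layout selection and output-shape logic in `GemmFFI` can be restated as a small Python sketch. Scaling modes are modeled as plain strings and `non_tn_fp8_supported` stands in for the result of `nvte_is_non_tn_fp8_gemm_supported()`; the function names `operand_layouts` and `gemm_out_shape` are hypothetical.

```python
TENSOR_SCALING_MODES = ("DELAYED_TENSOR_SCALING", "CURRENT_TENSOR_SCALING")


def operand_layouts(scaling_mode, non_tn_fp8_supported, lhs_transposed, rhs_transposed):
    """Return (lhs_rowwise, rhs_rowwise) following the rules in GemmFFI."""
    always_rowwise = scaling_mode == "NO_SCALING" or (
        scaling_mode in TENSOR_SCALING_MODES and non_tn_fp8_supported
    )
    lhs_rowwise = True if always_rowwise else not lhs_transposed
    rhs_rowwise = True if always_rowwise else rhs_transposed
    return lhs_rowwise, rhs_rowwise


def gemm_out_shape(lhs_shape, rhs_shape, lhs_transposed, rhs_transposed):
    """2D output shape from the collapsed 2D operand shapes."""
    rows = lhs_shape[1] if lhs_transposed else lhs_shape[0]
    cols = rhs_shape[0] if rhs_transposed else rhs_shape[1]
    return (rows, cols)


# BF16 GEMM: both operands stay rowwise regardless of transposition
print(operand_layouts("NO_SCALING", False, True, False))  # (True, True)
# MXFP8 NN GEMM: LHS rowwise, RHS columnwise
print(operand_layouts("MXFP8_1D_SCALING", True, False, False))  # (True, False)
print(gemm_out_shape((32, 16), (16, 64), False, False))  # (32, 64)
```

Note that the call to `nvte_cublas_gemm` passes the operands and their transpose flags in swapped order (RHS first), which the source comment attributes to cuBLAS column-major ordering; the sketch does not model that swap.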
XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // lhs
.Arg<Buffer_Type>() // lhs_scale_inv
.Arg<Buffer_Type>() // rhs
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // lhs_swizzled
.Ret<Buffer_Type>() // rhs_swizzled
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
.Attr<int64_t>("rhs_axis_boundary")
.Attr<bool>("lhs_transposed")
.Attr<bool>("rhs_transposed")
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator"),
FFI_CudaGraph_Traits);
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
......
@@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t {
CURRENT_TENSOR_SCALING = 3,
};
inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING ||
mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING);
}
inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
}
static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
switch (mode) {
case JAXX_Scaling_Mode::NO_SCALING:
......
@@ -55,6 +55,11 @@ pybind11::dict Registrations() {
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));
// GEMM
dict["te_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(GemmHandler));
// Grouped GEMM
dict["te_grouped_gemm_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
@@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);
pybind11::enum_<DType>(m, "DType", pybind11::module_local())
.value("kByte", DType::kByte)
......
...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation. ...@@ -8,7 +8,7 @@ architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations. customizable contracting dimensions for flexible tensor operations.
""" """
import warnings
from typing import Tuple, Sequence from typing import Tuple, Sequence
from functools import partial from functools import partial
import jax import jax
...@@ -23,6 +23,16 @@ from .quantize import ( ...@@ -23,6 +23,16 @@ from .quantize import (
) )
DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global DENSE_BATCH_FIRST_WARNING_ISSUED
if not DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
DENSE_BATCH_FIRST_WARNING_ISSUED = True
def dense( def dense(
x: jnp.ndarray, x: jnp.ndarray,
kernel: jnp.ndarray, kernel: jnp.ndarray,
...@@ -30,6 +40,7 @@ def dense( ...@@ -30,6 +40,7 @@ def dense(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
input_axes: Tuple[str, ...] = None, input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
): ):
"""Perform dense layer transformation with optional quantization. """Perform dense layer transformation with optional quantization.
...@@ -43,25 +54,28 @@ def dense( ...@@ -43,25 +54,28 @@ def dense(
kernel: Weight matrix for the dense layer transformation kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract contracting_dims: Tuple of sequences specifying which dimensions to contract
batch_first: Assume that X is batched in the first dimension.
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
# Remove when tex.quantize() can handle quantizer=None # Remove when tex.quantize() can handle quantizer=None
if quantizer_set == noop_quantizer_set: if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot():
x = with_sharding_constraint_by_logical_axes(x, input_axes) x = with_sharding_constraint_by_logical_axes(x, input_axes)
output = tex.gemm(x, kernel, contracting_dims) output = tex.gemm(x, kernel, contracting_dims=contracting_dims)
if bias is not None: if bias is not None:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
else: else:
output = _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set) output = _dense(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
)
return output return output
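The bias reshape in `dense()` pads the bias shape with leading singleton dimensions so that it broadcasts against the GEMM output. A shape-only sketch in plain Python follows; `bias_broadcast_shape` is a hypothetical helper that operates on shape tuples rather than arrays.

```python
def bias_broadcast_shape(output_shape, bias_shape):
    """Pad bias_shape with leading 1s up to the rank of output_shape."""
    num_leading = len(output_shape) - len(bias_shape)
    return (1,) * num_leading + tuple(bias_shape)


# (batch, seq, hidden_out) output with a (hidden_out,) bias
print(bias_broadcast_shape((8, 128, 64), (64,)))  # (1, 1, 64)
# Multi-dimensional kernel output, e.g. (batch, seq, heads, head_dim)
print(bias_broadcast_shape((8, 128, 4, 16), (4, 16)))  # (1, 1, 4, 16)
```

Standard broadcasting would also align a lower-rank bias on the trailing axes without the explicit reshape; spelling the shape out makes the intended alignment visible.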
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set):
"""Internal implementation of dense layer transformation with custom VJP. """Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support This function implements the core dense layer transformation logic with support
...@@ -75,44 +89,81 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer ...@@ -75,44 +89,81 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer
input_axes: Logical axes for sharding the activation input input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types quantizer_set: QuantizerSet which contains quantizers for different tensor types
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
Returns: Returns:
Transformed output tensor Transformed output tensor
""" """
output, _ = _dense_fwd_rule( output, _ = _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
) )
return output return output
def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, quantizer_set): def _dense_fwd_rule(
x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set
):
"""Forward pass rule for dense layer transformation. """Forward pass rule for dense layer transformation.
Returns: Returns:
Tuple of (output, context) for backward pass Tuple of (output, context) for backward pass
""" """
x_contracting_dims, k_contracting_dims = contracting_dims x_contracting_dims, k_contracting_dims = map(
tex.sanitize_dims, (x.ndim, kernel.ndim), contracting_dims
)
# Check supported input layout
x_is_transposed = x.ndim - 1 not in x_contracting_dims
k_is_transposed = kernel.ndim - 1 in k_contracting_dims
assert (
not x_is_transposed and not k_is_transposed
), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel."
# Determine X batch dimension
# - If `batch_first=True` -> (batch, leading..., contracting...)
# - Otherwise -> (leading..., batch, contracting...)
# NOTE: Always assume a single batch dimension
x_bdim = None
num_cdims = len(x_contracting_dims)
if x.ndim >= num_cdims + 2:
# Assume X is batched if it has at least two more dimensions than the number of
# contracting dimensions.
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `dense()` layer implementation does not officially support sequence-first "
"inputs and may produce incorrect results when `batch_first=False`. Use "
"sequence-first inputs at your own discretion.",
)
x_bdim = 0 if batch_first else x.ndim - num_cdims - 1
flatten_axis_x = -len(x_contracting_dims) flatten_axis_x = -len(x_contracting_dims)
flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape)
casted_x = tex.quantize(x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x) casted_x = tex.quantize(
x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True
)
casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes)
casted_kernel = tex.quantize( casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel kernel,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.kernel,
noop_scaled_tensor=True,
) )
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# GEMM NN # GEMM NN
use_bias = bias is not None
output = tex.gemm( output = tex.gemm(
casted_x.get_tensor(usage=TensorUsage.LHS), casted_x.get_tensor(usage=TensorUsage.LHS),
casted_kernel.get_tensor(usage=TensorUsage.RHS), casted_kernel.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
use_bias = bias is not None if use_bias and tex.gemm_uses_jax_dot():
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
...@@ -124,20 +175,19 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes, ...@@ -124,20 +175,19 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, input_axes, kernel_axes,
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
x_bdim,
) )
return output, ctx return output, ctx
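The layout check and batch-dimension inference in `_dense_fwd_rule` can be sketched on ranks and axis indices alone. The behavior of `tex.sanitize_dims` is not shown in this diff, so the `sanitize_dims` below is an assumed stand-in that only maps negative axes to non-negative ones; `infer_x_batch_dim` is a hypothetical name.

```python
def sanitize_dims(ndim, dims):
    """Assumed behavior: normalize negative axis indices against ndim."""
    return tuple(d + ndim if d < 0 else d for d in dims)


def infer_x_batch_dim(x_ndim, x_contracting_dims, batch_first=True):
    """Return the batch axis of X, or None when X is treated as unbatched."""
    cdims = sanitize_dims(x_ndim, x_contracting_dims)
    # NN layout: the last axis of X must be a contracting axis.
    assert x_ndim - 1 in cdims, "X must be non-transposed"
    num_cdims = len(cdims)
    if x_ndim >= num_cdims + 2:
        # batch_first      -> (batch, leading..., contracting...)
        # not batch_first  -> (leading..., batch, contracting...)
        return 0 if batch_first else x_ndim - num_cdims - 1
    return None


print(infer_x_batch_dim(3, (-1,)))                    # 0
print(infer_x_batch_dim(3, (2,), batch_first=False))  # 1
print(infer_x_batch_dim(2, (1,)))                     # None
```

As in the source, only a single batch dimension is assumed, so a 4D input with one contracting axis and `batch_first=False` resolves to axis 2.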
def _dense_bwd_rule( def _dense_bwd_rule(
contracting_dims, input_axes, kernel_axes, ctx, grad contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation. """Backward pass rule for dense layer transformation.
Returns: Returns:
Tuple of gradients with respect to inputs Tuple of gradients with respect to inputs
""" """
fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
( (
casted_x_lhs, casted_x_lhs,
casted_kernel_rhs, casted_kernel_rhs,
...@@ -146,10 +196,19 @@ def _dense_bwd_rule( ...@@ -146,10 +196,19 @@ def _dense_bwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis_k, flatten_axis_k,
x_bdim,
) = ctx ) = ctx
fwd_x_contracting_dims, fwd_k_contracting_dims = map(
tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims
)
casted_grad, dbias = tex.quantize_dbias( casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad grad,
is_dbias=use_bias,
flatten_axis=flatten_axis_k,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
) )
# GEMM NT # GEMM NT
...@@ -164,7 +223,8 @@ def _dense_bwd_rule( ...@@ -164,7 +223,8 @@ def _dense_bwd_rule(
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_tensor(usage=TensorUsage.LHS), casted_grad.get_tensor(usage=TensorUsage.LHS),
casted_kernel_rhs, casted_kernel_rhs,
(g_contracting_dim, k_contracting_dim), contracting_dims=(g_contracting_dim, k_contracting_dim),
batched_dims=((x_bdim,), ()),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)
...@@ -177,7 +237,8 @@ def _dense_bwd_rule( ...@@ -177,7 +237,8 @@ def _dense_bwd_rule(
wgrad = tex.gemm( wgrad = tex.gemm(
casted_x_lhs, casted_x_lhs,
casted_grad.get_tensor(usage=TensorUsage.RHS), casted_grad.get_tensor(usage=TensorUsage.RHS),
(x_contracting_dim, g_contracting_dim), contracting_dims=(x_contracting_dim, g_contracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......
...@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support. ...@@ -6,7 +6,7 @@ Wrapper module for Transformer related layers with FP8 support.
""" """
from functools import reduce from functools import reduce
import operator import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
...@@ -15,12 +15,12 @@ from jax import lax ...@@ -15,12 +15,12 @@ from jax import lax
from jax import random as jax_random from jax import random as jax_random
from jax.ad_checkpoint import checkpoint_name from jax.ad_checkpoint import checkpoint_name
from ..dense import dense from ..dense import dense, _issue_batch_first_warning as _dense_warning
from ..layernorm import canonicalize_norm_type from ..layernorm import canonicalize_norm_type
from ..layernorm import layernorm from ..layernorm import layernorm
from ..layernorm_dense import layernorm_dense from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning
from ..layernorm_mlp import layernorm_mlp from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning
from ..activation import activation from ..activation import activation
from ..softmax import softmax, SoftmaxType from ..softmax import softmax, SoftmaxType
from ..sharding import with_sharding_constraint_by_logical_axes from ..sharding import with_sharding_constraint_by_logical_axes
...@@ -35,8 +35,8 @@ from ..sharding import get_non_contracting_logical_axes ...@@ -35,8 +35,8 @@ from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
DType = jnp.dtype DType = NewType("DType", jnp.dtype)
Array = jnp.ndarray Array = NewType("Array", jnp.ndarray)
PrecisionLike = Union[ PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
] ]
...@@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase): ...@@ -441,6 +441,12 @@ class DenseGeneral(TransformerEngineBase):
input_axes: Tuple[str, ...] = () input_axes: Tuple[str, ...] = ()
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_dense_warning(
"TE/JAX DenseGeneral() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
...@@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -657,6 +663,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
depth_scaling: float = None depth_scaling: float = None
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_ln_dense_warning(
"TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first "
"inputs and may produce incorrect results when `transpose_batch_sequence=True`. "
"Use sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, 1.0,
...@@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -967,6 +979,12 @@ class LayerNormMLP(TransformerEngineBase):
dot_2_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] = None
def __post_init__(self): def __post_init__(self):
if self.transpose_batch_sequence:
_ln_mlp_warning(
"TE/JAX LayerNormMLP() module does not officially support sequence-first inputs "
"and may produce incorrect results when `transpose_batch_sequence=True`. Use "
"sequence-first inputs at your own discretion."
)
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
......
...@@ -9,6 +9,7 @@ architectures. It supports various normalization types, quantization, and ...@@ -9,6 +9,7 @@ architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints. distributed training through sharding constraints.
""" """
import warnings
from functools import partial from functools import partial
from typing import Tuple from typing import Tuple
...@@ -25,6 +26,16 @@ from .quantize import ( ...@@ -25,6 +26,16 @@ from .quantize import (
) )
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False
def _issue_batch_first_warning(msg):
global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED
if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED:
warnings.warn(msg, UserWarning)
LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True
def layernorm_dense( def layernorm_dense(
x: jnp.ndarray, x: jnp.ndarray,
kernel: jnp.ndarray, kernel: jnp.ndarray,
...@@ -37,6 +48,7 @@ def layernorm_dense( ...@@ -37,6 +48,7 @@ def layernorm_dense(
layernorm_input_axes: Tuple[str, ...] = None, layernorm_input_axes: Tuple[str, ...] = None,
dot_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None,
kernel_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None,
batch_first: bool = True,
quantizer_set: QuantizerSet = noop_quantizer_set, quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray: ) -> jnp.ndarray:
"""Apply layer normalization followed by dense layer transformation. """Apply layer normalization followed by dense layer transformation.
...@@ -57,6 +69,7 @@ def layernorm_dense( ...@@ -57,6 +69,7 @@ def layernorm_dense(
layernorm_input_axes: Logical axes for sharding the layernorm input layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix kernel_axes: Logical axes for sharding the weight matrix
batch_first: Assume that X is batched in the first dimension if it has more than 2 dims.
quantizer_set: Set of quantizers for different tensor types quantizer_set: Set of quantizers for different tensor types
Returns: Returns:
...@@ -80,6 +93,7 @@ def layernorm_dense( ...@@ -80,6 +93,7 @@ def layernorm_dense(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -94,6 +108,7 @@ def layernorm_dense( ...@@ -94,6 +108,7 @@ def layernorm_dense(
8, 8,
9, 9,
10, 10,
11,
), ),
) )
def _layernorm_dense( def _layernorm_dense(
...@@ -108,6 +123,7 @@ def _layernorm_dense( ...@@ -108,6 +123,7 @@ def _layernorm_dense(
layernorm_input_axes: Tuple[str, ...], layernorm_input_axes: Tuple[str, ...],
dot_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...],
kernel_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...],
batch_first: bool,
quantizer_set, quantizer_set,
): ):
"""Internal implementation of layernorm_dense with custom VJP. """Internal implementation of layernorm_dense with custom VJP.
...@@ -127,6 +143,7 @@ def _layernorm_dense( ...@@ -127,6 +143,7 @@ def _layernorm_dense(
epsilon: Small constant for numerical stability epsilon: Small constant for numerical stability
layernorm_input_axes: Logical axes for layernorm sharding layernorm_input_axes: Logical axes for layernorm sharding
dot_input_axes: Logical axes for matrix multiplication sharding dot_input_axes: Logical axes for matrix multiplication sharding
batch_first: Assume that X is batched in the first dimension.
quantizer_set: Set of quantizers quantizer_set: Set of quantizers
Returns: Returns:
...@@ -144,6 +161,7 @@ def _layernorm_dense( ...@@ -144,6 +161,7 @@ def _layernorm_dense(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
) )
return output return output
...@@ -161,6 +179,7 @@ def _layernorm_dense_fwd_rule( ...@@ -161,6 +179,7 @@ def _layernorm_dense_fwd_rule(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, dot_input_axes,
kernel_axes, kernel_axes,
batch_first,
quantizer_set, quantizer_set,
): ):
"""Forward pass rule for layernorm_dense. """Forward pass rule for layernorm_dense.
...@@ -178,6 +197,17 @@ def _layernorm_dense_fwd_rule( ...@@ -178,6 +197,17 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims = (0,) k_contracting_dims = (0,)
assert x.shape[-1] == kernel.shape[0] assert x.shape[-1] == kernel.shape[0]
x_bdim = None
if x.ndim > 2:
if not batch_first:
_issue_batch_first_warning(
"TE/JAX `layernorm_dense()` fused-layer implementation does not officially "
"support sequence-first inputs and may produce incorrect results when "
"`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first "
"inputs at your own discretion."
)
x_bdim = 0 if batch_first else x.ndim - 2
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
casted_ln_out, mu, rsigma = tex.normalization_fwd( casted_ln_out, mu, rsigma = tex.normalization_fwd(
...@@ -187,25 +217,31 @@ def _layernorm_dense_fwd_rule( ...@@ -187,25 +217,31 @@ def _layernorm_dense_fwd_rule(
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
norm_type, norm_type,
quantizer_set.x, quantizer=quantizer_set.x,
noop_scaled_tensor=True,
) )
casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)
# Kernel in (hidden_in, hidden_out...) # Kernel in (hidden_in, hidden_out...)
flatten_axis = 1 - len(kernel.shape) flatten_axis = 1 - len(kernel.shape)
casted_kernel = tex.quantize(kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel) casted_kernel = tex.quantize(
kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True
)
casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes)
# NN GEMM # NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...) # (batch..., hidden_in) x (hidden_in, hidden_out...)
use_bias = bias is not None
output = tex.gemm( output = tex.gemm(
casted_ln_out.get_tensor(TensorUsage.LHS), casted_ln_out.get_tensor(TensorUsage.LHS),
casted_kernel.get_tensor(TensorUsage.RHS), casted_kernel.get_tensor(TensorUsage.RHS),
(x_contracting_dims, k_contracting_dims), contracting_dims=(x_contracting_dims, k_contracting_dims),
batched_dims=((x_bdim,), ()),
bias=bias if not tex.gemm_uses_jax_dot() else None,
fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False,
) )
use_bias = bias is not None if use_bias and tex.gemm_uses_jax_dot():
if use_bias:
bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
output += jnp.reshape(bias, bias_new_shape) output += jnp.reshape(bias, bias_new_shape)
...@@ -224,6 +260,7 @@ def _layernorm_dense_fwd_rule( ...@@ -224,6 +260,7 @@ def _layernorm_dense_fwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis, flatten_axis,
x_bdim,
) )
return output, ctx return output, ctx
...@@ -236,6 +273,7 @@ def _layernorm_dense_bwd_rule( ...@@ -236,6 +273,7 @@ def _layernorm_dense_bwd_rule(
layernorm_input_axes, layernorm_input_axes,
dot_input_axes, # pylint: disable=unused-argument dot_input_axes, # pylint: disable=unused-argument
kernel_axes, kernel_axes,
batch_first, # pylint: disable=unused-argument
ctx, ctx,
grad, grad,
): ):
...@@ -265,10 +303,15 @@ def _layernorm_dense_bwd_rule( ...@@ -265,10 +303,15 @@ def _layernorm_dense_bwd_rule(
use_bias, use_bias,
quantizer_set, quantizer_set,
flatten_axis, flatten_axis,
x_bdim,
) = ctx ) = ctx
casted_grad, dbias = tex.quantize_dbias( casted_grad, dbias = tex.quantize_dbias(
grad, is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad grad,
is_dbias=use_bias,
flatten_axis=flatten_axis,
quantizer=quantizer_set.dgrad,
noop_scaled_tensor=True,
) )
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...@@ -284,7 +327,8 @@ def _layernorm_dense_bwd_rule( ...@@ -284,7 +327,8 @@ def _layernorm_dense_bwd_rule(
dgrad = tex.gemm( dgrad = tex.gemm(
casted_grad.get_tensor(TensorUsage.LHS), casted_grad.get_tensor(TensorUsage.LHS),
casted_kernel, casted_kernel,
(g_constracting_dim, k_constracting_dim), contracting_dims=(g_constracting_dim, k_constracting_dim),
batched_dims=((x_bdim,), ()),
) )
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
...@@ -297,7 +341,8 @@ def _layernorm_dense_bwd_rule( ...@@ -297,7 +341,8 @@ def _layernorm_dense_bwd_rule(
wgrad = tex.gemm( wgrad = tex.gemm(
casted_ln_out, casted_ln_out,
casted_grad.get_tensor(TensorUsage.RHS), casted_grad.get_tensor(TensorUsage.RHS),
(x_constracting_dim, g_constracting_dim), contracting_dims=(x_constracting_dim, g_constracting_dim),
batched_dims=((x_bdim,), (x_bdim,)),
) )
wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
......