Unverified commit a3ec6a54, authored by Jeng Bai-Cheng, committed by GitHub

add building workflow for TE/Jax (#53)



* add building workflow for jax modules
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* replace bit_cast with reinterpret_cast
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add nvtx to cmake check list
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor layernorm fwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor rmsnorm fwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor layernorm_bwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* set pytorch as default in setup.py
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename extension from *.cc to *.cpp

cpplint cannot recognize *.cc file, so rename the extension
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor style, to align TE/PyTorch
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add pybinding, unittest and qa
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix license
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* disable c-extension-no-member and no-name-in-module
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add dataclass avoid pylint error
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update transformer_engine/__init__.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py

fix typo
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* fix conflict due to PR62
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix c-extension-no-member and no-name-in-module

1. add transformer_engine_jax into extension-pkg-whitelist
2. convert pylintrc from CRLF to LF format
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update setup.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove pylint:disable and refactor import order
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d8a2f352
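For context, a minimal sketch of how the build switch introduced by this PR is meant to be driven at install time. The NVTE_FRAMEWORK values ("pytorch" as the new default, "jax", or "all") come from the setup.py changes further down; the pip invocation itself is just the usual install command and is not part of this diff.

```python
# Hypothetical install helper showing the NVTE_FRAMEWORK switch read by setup.py.
# Equivalent shell usage would be `NVTE_FRAMEWORK=jax pip install .` from the repo root.
import os
import subprocess

env = dict(os.environ, NVTE_FRAMEWORK="jax")   # build only the JAX/XLA bindings
subprocess.check_call(["pip", "install", "."], env=env)
```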
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax
 [MASTER]
 extension-pkg-whitelist=torch,
-                        transformer_engine_extensions
+                        transformer_engine_extensions,
+                        transformer_engine_jax

 disable=too-many-locals,
         invalid-name,
         too-many-arguments,
         abstract-method,
         arguments-differ,
         too-many-instance-attributes,
         unsubscriptable-object,
         import-outside-toplevel,
         too-many-statements,
         import-error,
         too-many-lines,
         use-maxsplit-arg,
         protected-access,
         pointless-string-statement,
         cyclic-import,
         duplicate-code,
         no-member,
         attribute-defined-outside-init,
         global-statement,
         too-many-branches,
         global-variable-not-assigned,
         redefined-argument-from-local

 [TYPECHECK]
 ignored-modules=torch
 ignored-classes=torch
@@ -14,12 +14,12 @@ from setuptools import setup, find_packages, Extension
 from setuptools.command.build_ext import build_ext
 from distutils.version import LooseVersion
 from distutils.file_util import copy_file
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME

 path = os.path.dirname(os.path.realpath(__file__))
 with open(path + "/VERSION", "r") as f:
     te_version = f.readline()

+CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda")

 def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output(
@@ -94,9 +94,10 @@ all_sources = pytorch_sources
 supported_frameworks = {
     "all": all_sources,
     "pytorch": pytorch_sources,
+    "jax": None,  # JAX use transformer_engine/CMakeLists.txt
 }

-framework = "all"
+framework = os.environ.get("NVTE_FRAMEWORK", "pytorch")

 args = sys.argv.copy()
 for s in args:
@@ -113,19 +114,72 @@ class CMakeExtension(Extension):
         super(CMakeExtension, self).__init__(name, sources=sources, **kwargs)
         self.cmake_path = cmake_path

+class FrameworkBuilderBase:
+    def __init__(self, *args, **kwargs) -> None:
+        pass
+
+    def cmake_flags(self):
+        return []
+
+    def initialize_options(self):
+        pass
+
+    def finalize_options(self):
+        pass
+
+    def run(self, extensions):
+        pass
+
+    @staticmethod
+    def install_requires():
+        return []
+
+class PyTorchBuilder(FrameworkBuilderBase):
+    def __init__(self, *args, **kwargs) -> None:
+        pytorch_args = copy.deepcopy(args)
+        pytorch_kwargs = copy.deepcopy(kwargs)
+        from torch.utils.cpp_extension import BuildExtension
+        self.pytorch_build_extensions = BuildExtension(*pytorch_args, **pytorch_kwargs)
+
+    def initialize_options(self):
+        self.pytorch_build_extensions.initialize_options()
+
+    def finalize_options(self):
+        self.pytorch_build_extensions.finalize_options()
+
+    def run(self, extensions):
+        other_ext = [
+            ext for ext in extensions if not isinstance(ext, CMakeExtension)
+        ]
+        self.pytorch_build_extensions.extensions = other_ext
+        print("Building pyTorch extensions!")
+        self.pytorch_build_extensions.run()
+
+    @staticmethod
+    def install_requires():
+        return ["flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",]
+
+class JaxBuilder(FrameworkBuilderBase):
+    def cmake_flags(self):
+        return ["-DENABLE_JAX=ON"]
+
+    def run(self, extensions):
+        print("Building jax extensions!")
+
 ext_modules = []
+dlfw_builder_funcs = []

 ext_modules.append(
     CMakeExtension(
         name="transformer_engine",
-        cmake_path=os.path.join(path, "transformer_engine/common"),
+        cmake_path=os.path.join(path, "transformer_engine"),
         sources=[],
         include_dirs=include_dirs,
     )
 )

 if framework in ("all", "pytorch"):
+    from torch.utils.cpp_extension import CUDAExtension
     ext_modules.append(
         CUDAExtension(
             name="transformer_engine_extensions",
@@ -137,6 +191,14 @@ if framework in ("all", "pytorch"):
             include_dirs=include_dirs,
         )
     )
+    dlfw_builder_funcs.append(PyTorchBuilder)
+
+if framework in ("all", "jax"):
+    dlfw_builder_funcs.append(JaxBuilder)
+
+dlfw_install_requires = []
+for builder in dlfw_builder_funcs:
+    dlfw_install_requires = dlfw_install_requires + builder.install_requires()

 def get_cmake_bin():
@@ -179,6 +241,7 @@ def get_cmake_bin():
 class CMakeBuildExtension(build_ext, object):
     def __init__(self, *args, **kwargs) -> None:
+        self.dlfw_flags = kwargs["dlfw_flags"]
         super(CMakeBuildExtension, self).__init__(*args, **kwargs)

     def build_extensions(self) -> None:
@@ -198,6 +261,7 @@ class CMakeBuildExtension(build_ext, object):
             "-DCMAKE_BUILD_TYPE=" + config,
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(config.upper(), build_dir),
         ]
+        cmake_args = cmake_args + self.dlfw_flags

         cmake_build_args = ["--config", config]
@@ -223,26 +287,35 @@ class CMakeBuildExtension(build_ext, object):
         except OSError as e:
             raise RuntimeError("CMake failed: {}".format(str(e)))

 class TEBuildExtension(build_ext, object):

     def __init__(self, *args, **kwargs) -> None:
+        self.dlfw_builder = []
+        for functor in dlfw_builder_funcs:
+            self.dlfw_builder.append(functor(*args, **kwargs))
+
+        flags = []
+        for builder in self.dlfw_builder:
+            flags = flags + builder.cmake_flags()
+
         cmake_args = copy.deepcopy(args)
         cmake_kwargs = copy.deepcopy(kwargs)
-        pytorch_args = copy.deepcopy(args)
-        pytorch_kwargs = copy.deepcopy(kwargs)
+        cmake_kwargs["dlfw_flags"] = flags
         self.cmake_build_extensions = CMakeBuildExtension(*cmake_args, **cmake_kwargs)
-        self.pytorch_build_extensions = BuildExtension(*pytorch_args, **pytorch_kwargs)
         self.all_outputs = None
         super(TEBuildExtension, self).__init__(*args, **kwargs)

     def initialize_options(self):
         self.cmake_build_extensions.initialize_options()
-        self.pytorch_build_extensions.initialize_options()
+        for builder in self.dlfw_builder:
+            builder.initialize_options()
         super(TEBuildExtension, self).initialize_options()

     def finalize_options(self):
         self.cmake_build_extensions.finalize_options()
-        self.pytorch_build_extensions.finalize_options()
+        for builder in self.dlfw_builder:
+            builder.finalize_options()
         super(TEBuildExtension, self).finalize_options()

     def run(self) -> None:
@@ -250,12 +323,9 @@ class TEBuildExtension(build_ext, object):
         cmake_ext = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)]
         self.cmake_build_extensions.extensions = cmake_ext
         self.cmake_build_extensions.run()
-        other_ext = [
-            ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
-        ]
-        self.pytorch_build_extensions.extensions = other_ext
-        print("Building pyTorch extensions!")
-        self.pytorch_build_extensions.run()
+        for builder in self.dlfw_builder:
+            builder.run(self.extensions)

         self.all_outputs = []
         for f in os.scandir(self.build_lib):
@@ -313,8 +383,6 @@ setup(
     description="Transformer acceleration library",
     ext_modules=ext_modules,
     cmdclass={"build_ext": TEBuildExtension},
-    install_requires = [
-        "flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",
-    ],
+    install_requires=dlfw_install_requires,
     license_files=("LICENSE",),
 )
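To make the new builder hook-up easier to follow, here is a standalone sketch of the FrameworkBuilderBase pattern that setup.py now uses. PaddleBuilder and its -DENABLE_PADDLE flag are purely hypothetical and only illustrate how another framework could be slotted into the same flag/requirement aggregation.

```python
# Minimal re-statement of the builder interface from setup.py above, plus a
# hypothetical extra framework to show how flags and requirements aggregate.
class FrameworkBuilderBase:
    def cmake_flags(self):
        return []

    def run(self, extensions):
        pass

    @staticmethod
    def install_requires():
        return []


class PaddleBuilder(FrameworkBuilderBase):  # hypothetical, not part of this PR
    def cmake_flags(self):
        return ["-DENABLE_PADDLE=ON"]       # hypothetical CMake option

    def run(self, extensions):
        print("Building paddle extensions!")


dlfw_builder_funcs = [PaddleBuilder]
builders = [cls() for cls in dlfw_builder_funcs]
cmake_flags = sum((b.cmake_flags() for b in builders), [])
install_requires = sum((b.install_requires() for b in builders), [])
print(cmake_flags, install_requires)        # ['-DENABLE_PADDLE=ON'] []
```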
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import operator
import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose, is_fp8_supported
from transformer_engine.common.recipe import Format
from transformer_engine.jax.cpp_extensions import dgated_gelu, gated_gelu
from transformer_engine.jax.cpp_extensions import dgated_gelu_cast_transpose, gated_gelu_fp8
from transformer_engine.jax.cpp_extensions import dequantize, quantize
from transformer_engine.jax.dot import fp8_dot
from transformer_engine.jax.fp8 import DType, FP8GemmPackage, FP8Helper, _format2dtypes
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.mlp import fp8_ln_mlp
GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
FP8_COMPUTE_TYPE = [_format2dtypes(Format.E4M3), _format2dtypes(Format.HYBRID)]
LN_CASES = [(512, 1024)]
DTYPES = [jnp.bfloat16, jnp.float32]
class TestFP8Dot:
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_qdq(self):
FP8_E4M3_MAX = 448
x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
amax = jnp.max(jnp.abs(x)).reshape(1)
scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
scale_inv = (1 / scale).reshape(1)
y, new_amax = quantize(x, amax, scale, scale_inv, out_dtype=DType.kFloat8E4M3)
assert_allclose(new_amax, 3.0)
no_use = jnp.zeros(1, jnp.float32)
z = dequantize(y,
no_use,
no_use,
scale_inv,
fp8_dtype=DType.kFloat8E4M3,
out_dtype=DType.kFloat32)
assert_allclose(z, x, rtol=5e-2, atol=5e-2)
def test_compile_bf16(self):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
a = jax.random.normal(subkeys[0], (256, 512), jnp.bfloat16)
b = jax.random.normal(subkeys[1], (512, 256), jnp.bfloat16)
def func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
# x = input, matrix 2d
# y = input, matrix 2d (weight)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None)))
value_n_grad_func = value_and_grad(func, (0, 1))
value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
value_n_grad_func_compiled(a, b)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
def test_compile_fp8(self, compute_type):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
a = jax.random.normal(subkeys[0], (256, 512), jnp.bfloat16)
b = jax.random.normal(subkeys[1], (512, 256), jnp.bfloat16)
def func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *compute_type))
value_n_grad_func = value_and_grad(func, (0, 1))
value_n_grad_func_compiled = jit(value_n_grad_func).lower(a, b).compile()
value_n_grad_func_compiled(a, b)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
def test_forward_bf16(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, a, [b], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
primitive_out = fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None))
ref_out = jnp.dot(a, b)
assert_allclose(primitive_out, ref_out)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
def test_forward_fp8_randint(self, m, n, k, compute_type):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
# TODO(rewang): add float random test
min_val, max_val = -8, 8
a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_meta = [fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv]
# calculate amax
fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
primitive_out = fp8_dot(fp8_gemm_pkg, 0, *compute_type)
# calculate scale by amax
fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)
fp8_gemm_pkg = FP8GemmPackage(1, a, [b], *fp8_meta)
primitive_out = fp8_dot(fp8_gemm_pkg, 0, *compute_type)
ref_out = jnp.dot(a, b)
ref_out = ref_out.astype(jnp.float32)
primitive_out = primitive_out.astype(jnp.float32)
assert_allclose(primitive_out, ref_out)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
def test_grad_bf16(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
def primitive_func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.mean(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None)))
def ref_func(x, y):
return jnp.mean(jnp.dot(x, y))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
assert_allclose(primitive_out, ref_out)
assert_allclose(primitive_a_grad, ref_a_grad)
assert_allclose(primitive_b_grad, ref_b_grad, atol=1e-5)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize('compute_type', FP8_COMPUTE_TYPE)
def test_grad_fp8_randint(self, m, n, k, compute_type):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
# TODO(rewang): add float random test
min_val, max_val = -8, 8
a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(jnp.bfloat16)
b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(jnp.bfloat16)
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_meta = [fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv]
def primitive_func(x, y, metas):
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], *metas)
return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *compute_type))
def ref_func(x, y):
return jnp.sum(jnp.dot(x, y))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
# calculate amax
primitive_out, (primitive_a_grad,
primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta)
# calculate scale by amax
fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)
primitive_out, (primitive_a_grad,
primitive_b_grad) = value_n_grad_primitive_func(a, b, fp8_meta)
assert_allclose(primitive_out, ref_out)
assert_allclose(primitive_a_grad, ref_a_grad)
assert_allclose(primitive_b_grad, ref_b_grad)
def test_contracting_dims_bf16(self):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
a = jax.random.normal(subkeys[0], (32, 8, 16, 64), jnp.bfloat16)
b = jax.random.normal(subkeys[1], (16, 64, 128), jnp.bfloat16)
def primitive_func(x, y):
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_SIZE),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_gemm_pkg = FP8GemmPackage(1, x, [y], fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
return jnp.sum(fp8_dot(fp8_gemm_pkg, 0, *_format2dtypes(None), ((2, 3), (0, 1))))
def ref_func(x, y):
return jnp.sum(lax.dot_general(x, y, dimension_numbers=(((2, 3), (0, 1)), ((), ()))))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
assert_allclose(primitive_out, ref_out)
assert_allclose(primitive_a_grad, ref_a_grad)
assert_allclose(primitive_b_grad, ref_b_grad)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)])
def test_grad_fp8_mlp_randint(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
activations = ('gelu', 'linear')
a = jax.random.uniform(subkeys[0], (m, k), jnp.bfloat16, 5, 8)
k1 = jax.random.uniform(subkeys[1], (k, n * len(activations)), jnp.bfloat16, 5, 8)
k2 = jax.random.uniform(subkeys[2], (n, k), jnp.bfloat16, 5, 8)
s = jax.random.uniform(subkeys[3], (k,), jnp.bfloat16, 5, 8)
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_SIZE),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
fp8_meta = [fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv]
compute_type = _format2dtypes(Format.HYBRID)
def primitive_func(x, ln_s, y, z, metas):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = (x * y) * z
fp8_gemm_pkg = FP8GemmPackage(2, x, [y, z], *metas)
return jnp.mean(
fp8_ln_mlp(fp8_gemm_pkg,
ln_s,
None,
"rmsnorm",
0,
*compute_type,
activations=activations))
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def fp8_ln_mlp_py(inputs: jnp.ndarray,
ln_scale: jnp.ndarray,
kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray,
fp8_maxs: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
amax_history_idx: int,
fwd_dtype,
bwd_dtype,
epsilon=1e-6,
contracting_dims=((-1,), (0,)),
dp_dim_index=0,
activations=('gelu', 'linear')) -> jnp.ndarray:
x = jnp.asarray(inputs, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), jnp.bfloat16)
ln_out = y * ln_scale
ln_out = jnp.asarray(ln_out, jnp.bfloat16)
fp8_gemm_1_pkg = FP8GemmPackage(1, ln_out, [kernel_1],
fp8_maxs[:FP8Helper.NUM_META_PER_GEMM],
amax[:FP8Helper.NUM_META_PER_GEMM],
scale[:FP8Helper.NUM_META_PER_GEMM],
scale_inv[:FP8Helper.NUM_META_PER_GEMM])
linear_1_out = fp8_dot(fp8_gemm_1_pkg,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims,
dp_dim_index=dp_dim_index)
x = jnp.split(linear_1_out, len(activations), axis=-1)
acts = []
for idx, act_fn in enumerate(activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
x = jnp.asarray(x, jnp.bfloat16)
fp8_gemm_2_pkg = FP8GemmPackage(1, x, [kernel_2],
fp8_maxs[FP8Helper.NUM_META_PER_GEMM:],
amax[FP8Helper.NUM_META_PER_GEMM:],
scale[FP8Helper.NUM_META_PER_GEMM:],
scale_inv[FP8Helper.NUM_META_PER_GEMM:])
output = fp8_dot(fp8_gemm_2_pkg,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims,
dp_dim_index=dp_dim_index)
return output
def ref_func(x, ln_s, y, z, metas):
return jnp.mean(
fp8_ln_mlp_py(x, ln_s, y, z, *metas, 0, *compute_type, activations=activations))
value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3)))
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3)))
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad,
ref_k2_grad) = value_n_grad_ref_func(a, s, k1, k2, fp8_meta)
# calculate amax
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta)
# calculate scale by amax
fp8_meta = FP8Helper._update_fp8_metas_impl(fp8_meta)
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad) = value_n_grad_primitive_func(a, s, k1, k2, fp8_meta)
assert_allclose(primitive_out, ref_out, rtol=1e-2)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
rtol=1e-2)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
rtol=1e-2)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
rtol=1e-2)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
rtol=1e-2)
@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
return out
class TestGatedGeLu:
def ref_func(self, inputs):
def jax_gated_gelu(x):
x = jnp.split(x, 2, axis=-1)
acts = [jax.nn.gelu(x[0]), x[1]]
x = functools.reduce(operator.mul, acts)
x = jnp.asarray(x, jnp.bfloat16)
return x
func = jit(value_and_grad(lambda x: jnp.mean(jax_gated_gelu(x))))
return func(inputs)
def prim_func(self, inputs):
@jax.custom_vjp
def primitive(x):
out, _ = primitive_fwd(x)
return out
def primitive_fwd(x):
out = gated_gelu(x)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
out = dgated_gelu(g, x)
return (out,)
primitive.defvjp(primitive_fwd, primitive_bwd)
func = jit(value_and_grad(lambda x: jnp.mean(primitive(x))))
return func(inputs)
@pytest.mark.parametrize('shape', [(32, 64), (64, 256)])
def test_gated_gelu(self, random_inputs):
x = random_inputs
prim_out, prim_grad = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, rtol=1e-2)
assert_allclose(prim_grad, ref_grad, rtol=1e-1, atol=1e-3)
class TestGatedGeLuFP8(TestGatedGeLu):
def prim_func(self, inputs):
amax = self.amax
scale = self.scale
scale_inv = self.scale_inv
no_use = jnp.zeros(1, jnp.float32)
@jax.custom_vjp
def primitive(x, y, z):
out = primitive_fwd(x, y, z)
return out
def primitive_fwd(x, y, z): # pylint: disable=unused-argument
out, _ = gated_gelu_fp8(x, amax, scale, scale_inv, DType.kFloat8E5M2)
out = dequantize(out, no_use, no_use, scale_inv, DType.kFloat8E5M2, DType.kBFloat16)
ctx = x
return out, ctx
def primitive_bwd(ctx, g):
x = ctx
dgelu, dgelu_trans, amax_out = dgated_gelu_cast_transpose(g, x, amax, scale, scale_inv,
DType.kFloat8E5M2)
dgelu = dequantize(dgelu, no_use, no_use, scale_inv, DType.kFloat8E5M2, DType.kFloat32)
dgelu_trans = dequantize(dgelu_trans, no_use, no_use, scale_inv, DType.kFloat8E5M2,
DType.kFloat32)
return dgelu, dgelu_trans, amax_out
primitive.defvjp(primitive_fwd, primitive_bwd)
func = jit(value_and_grad(lambda x, y, z: jnp.mean(primitive(x, y, z)), (0, 1, 2)))
return func(inputs, no_use, no_use)
@pytest.mark.skipif(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
@pytest.mark.parametrize('shape', [(32, 64), (64, 256)])
def test_gated_gelu(self, random_inputs):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
self.scale_inv = jnp.ones(1, jnp.float32)
x = random_inputs
prim_out, (prim_grad, prim_grad_trans, amax) = self.prim_func(x)
ref_out, ref_grad = self.ref_func(x)
assert_allclose(prim_out, ref_out, rtol=1e-2)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
assert_allclose(prim_grad, ref_grad, rtol=1e-1, atol=1e-3)
assert_allclose(prim_grad_trans, jnp.transpose(ref_grad), rtol=1e-1, atol=1e-3)
class TestRMSNorm:
@pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
def test_forward_backward(self, n, hidden, dtype):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1)
scale = jnp.asarray(scale, dtype)
epsilon = 1e-6
def reference_rmsnorm(x, scale):
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * lax.rsqrt(mean2 + epsilon), dtype)
return y * scale
jitted_primitive = jit(
value_and_grad(lambda x, scale: jnp.mean(layernorm(x, scale, None, "rmsnorm")), (0, 1)))
jitted_reference = jit(
value_and_grad(lambda x, scale: jnp.mean(reference_rmsnorm(x, scale)), (0, 1)))
primitive_out, (primitive_dx, primitive_dgamma) = jitted_primitive(x, scale)
reference_out, (reference_dx, reference_dgamma) = jitted_reference(x, scale)
if dtype == jnp.float32:
assert_allclose(primitive_out, reference_out, rtol=1e-7)
assert_allclose(primitive_dx, reference_dx, rtol=1e-7)
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-7)
else:
assert_allclose(primitive_out, reference_out, rtol=1e-3)
assert_allclose(primitive_dx, reference_dx, rtol=1e-4, atol=5e-8)
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-4, atol=5e-8)
class TestLayerNorm:
@pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
def test_forward_backward(self, n, hidden, dtype):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1)
scale = jnp.asarray(scale, dtype)
bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -2, 1)
bias = jnp.asarray(bias, dtype)
epsilon = 1e-6
def reference_layernorm(x, scale, bias):
x = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
normed_input = (x - mean) * jax.lax.rsqrt(var + epsilon)
# Align TE implementation
return jnp.asarray(normed_input * scale + bias)
jitted_primitive = jit(
value_and_grad(lambda x, scale, bias: jnp.mean(layernorm(x, scale, bias, "layernorm")),
(0, 1, 2)))
jitted_reference = jit(
value_and_grad(lambda x, scale, bias: jnp.mean(reference_layernorm(x, scale, bias)),
(0, 1, 2)))
primitive_out, (primitive_dx, primitive_dgamma,
primitive_dbeta) = jitted_primitive(x, scale, bias)
reference_out, (reference_dx, reference_dgamma,
reference_dbeta) = jitted_reference(x, scale, bias)
if dtype == jnp.float32:
assert_allclose(primitive_out, reference_out, rtol=1e-7)
assert_allclose(primitive_dx, reference_dx, rtol=1e-7)
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-7)
assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-7)
else:
assert_allclose(primitive_out, reference_out, rtol=1e-3)
assert_allclose(primitive_dx, reference_dx, rtol=1e-4, atol=5e-8)
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-4, atol=5e-8)
assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-4, atol=5e-8)
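As a plain jax.numpy illustration of the arithmetic that test_qdq above exercises through TE's custom calls (delayed-scaling FP8 with scale = FP8_MAX / amax), the sketch below omits the actual E4M3 cast, so the round trip is exact rather than approximate; the constant and tolerances mirror the test.

```python
import jax.numpy as jnp

FP8_E4M3_MAX = 448.0
x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)

amax = jnp.max(jnp.abs(x))               # observed absolute maximum
scale = FP8_E4M3_MAX / amax              # quantization scale
scale_inv = 1.0 / scale

q = jnp.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)  # stand-in for the E4M3 cast
dq = q * scale_inv                       # dequantize back to float32

assert jnp.allclose(dq, x, rtol=5e-2, atol=5e-2)
```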
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray
from transformer_engine_jax import DType
from transformer_engine.jax.cpp_extensions import te_dtype_to_jax_dtype
from transformer_engine.jax.cpp_extensions import GemmPrimitive
SHAPES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
(16384, 1024, 1024)]
NAMED_SHAPES = [{}, {
"data": 4
}, {
"data": 2
}, {
"model": 4
}, {
"model": 2
}, {
"data": 4,
"model": 2
}, {
"model": 4,
"data": 2
}]
DTYPE = [DType.kFloat32, DType.kFloat16, DType.kBFloat16]
TRANSPOSE = [True, False]
class TestGEMMShapeInfer:
@staticmethod
def _joint_named_shape(ns1, ns2):
output_named_shape = {**ns1}
need_assert = False
for key in ns2:
if key in output_named_shape and output_named_shape[key] != ns2[key]:
need_assert = True
else:
output_named_shape[key] = ns2[key]
return output_named_shape, need_assert
@staticmethod
def _get_shapes(m, n, k, transa, transb):
# te_gemm only supports TN and column-major layouts, so we reorder the a and b
# shapes to compute row-major matrices with the column-major algorithm.
a = (m, k) if transa else (k, m)
b = (k, n) if transb else (n, k)
out = (n, m)
return a, b, out
@pytest.mark.parametrize('shapes', SHAPES)
@pytest.mark.parametrize('named_shape1', NAMED_SHAPES)
@pytest.mark.parametrize('named_shape2', NAMED_SHAPES)
@pytest.mark.parametrize('te_dtype', DTYPE)
@pytest.mark.parametrize('transa', TRANSPOSE)
@pytest.mark.parametrize('transb', TRANSPOSE)
def test_shape_infer(self, shapes, named_shape1, named_shape2, te_dtype, transa, transb):
a_shape, b_shape, out_shape = TestGEMMShapeInfer._get_shapes(*shapes, transa, transb)
dtype = te_dtype_to_jax_dtype(te_dtype)
mat_a = ShapedArray(a_shape, dtype, named_shape=named_shape1)
mat_b = ShapedArray(b_shape, dtype, named_shape=named_shape2)
scale_inv_a = ShapedArray((3, 1), jnp.float32)
scale_inv_b = ShapedArray((3, 1), jnp.float32)
ref_out_named_shape, need_assert = TestGEMMShapeInfer._joint_named_shape(
named_shape1, named_shape2)
ref_out = ShapedArray(out_shape, dtype, named_shape=ref_out_named_shape)
try:
test_out = GemmPrimitive.abstract(mat_a,
mat_b,
scale_inv_a,
scale_inv_b,
A_dtype=te_dtype,
B_dtype=te_dtype,
D_dtype=te_dtype,
transa=transa,
transb=transb,
use_split_accumulator=False)
assert not need_assert
assert ref_out == test_out
except AssertionError as ae:
assert need_assert, f"{ae.args}"
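A small NumPy check of the layout convention encoded in _get_shapes above: because the underlying GEMM works on column-major (TN) operands, a row-major (m, k) @ (k, n) product is written into a buffer declared as (n, m). This is only an illustration of that bookkeeping, not a call into the TE primitive.

```python
import numpy as np

m, n, k = 4, 3, 5
a = np.random.rand(m, k).astype(np.float32)
b = np.random.rand(k, n).astype(np.float32)

out_row_major = a @ b            # (m, n) result in row-major terms
out_colmajor_buffer = b.T @ a.T  # (n, m): what a column-major kernel writes

np.testing.assert_allclose(out_row_major, out_colmajor_buffer.T, rtol=1e-5)
```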
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import unittest
import flax
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import maps
from utils import assert_allclose, is_fp8_supported
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.fp8 import FP8Helper
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import MajorShardingType
from transformer_engine.jax.sharding import ShardingResource
class TestFP8Helper(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_initialize(self):
margin = 5.0
fp8_format = FP8Format.E4M3
update_fp8meta_interval = 10
amax_history_size = 10
FP8Helper.initialize(margin=margin,
fp8_format=fp8_format,
update_fp8meta_interval=update_fp8meta_interval,
amax_history_size=amax_history_size)
self.assertEqual(
FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
f" but got {FP8Helper.MARGIN}.")
self.assertEqual(
FP8Helper.FP8_FORMAT, fp8_format,
f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {FP8Helper.FP8_FORMAT}.")
self.assertEqual(
FP8Helper.UPDATE_FP8META_INTERVAL, update_fp8meta_interval,
"FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be"
f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
self.assertEqual(
FP8Helper.AMAX_HISTORY_SIZE, amax_history_size,
f"FP8Helper.AMAX_HISTORY_SIZE initialization failed, should be {amax_history_size}"
f" but got {FP8Helper.AMAX_HISTORY_SIZE}.")
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_update_fp8_metas(self):
FP8Helper.initialize(margin=3.0, amax_history_size=5)
seed = 0
key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
num_of_gemm = 10
num_of_meta = FP8Helper.NUM_META_PER_GEMM * num_of_gemm
def get_fp8_scale(fp8_max, amax, scale):
fp8_max = np.array(fp8_max)
amax = np.array(amax)
scale = np.array(scale)
exp = np.floor(np.log2(fp8_max / amax)) - FP8Helper.MARGIN
sf = np.round(np.power(2, np.abs(exp)))
sf = np.where(amax > 0.0, sf, scale)
sf = np.where(np.isfinite(amax), sf, scale)
return np.where(exp < 0, 1 / sf, sf)
meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_SIZE)
fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape)
fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1, jnp.ones(meta_shape))
fp8_scale_inv_array1 = 1 / fp8_scale_array1
fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape)
fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2, jnp.ones(meta_shape))
fp8_scale_inv_array2 = 1 / fp8_scale_array2
state = flax.core.frozen_dict.FrozenDict({
FP8Helper.FP8_COLLECTION_NAME: {
"test_update_fp8_metas1": {
FP8Helper.FP8_MAX_NAME: fp8_max_array,
FP8Helper.FP8_AMAX_NAME: fp8_amax_array1,
FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape)
},
"test_update_fp8_metas2": {
FP8Helper.FP8_MAX_NAME: fp8_max_array,
FP8Helper.FP8_AMAX_NAME: fp8_amax_array2,
FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape),
FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape)
}
}
})
updated_state = FP8Helper.update_fp8_metas(state)
state_array, _ = jax.tree_util.tree_flatten(updated_state)
meta_per_gemm = FP8Helper.NUM_META_PER_GEMM + 1
scale_shift = 2
scale_inv_shift = 3
assert_allclose(state_array[0 * meta_per_gemm + scale_shift], fp8_scale_array1)
assert_allclose(state_array[0 * meta_per_gemm + scale_inv_shift], fp8_scale_inv_array1)
assert_allclose(state_array[1 * meta_per_gemm + scale_shift], fp8_scale_array2)
assert_allclose(state_array[1 * meta_per_gemm + scale_inv_shift], fp8_scale_inv_array2)
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_generate_fp8_max_array(self):
num_of_meta = FP8Helper.NUM_META_PER_GEMM * 2
def get_ref(format_for_test):
refer_list = []
for i in range(num_of_meta):
val = format_for_test.value.max_bwd \
if i % FP8Helper.NUM_META_PER_GEMM == FP8Helper.GRAD_META_IDX_PER_GEMM \
else format_for_test.value.max_fwd
refer_list.append([val])
return jnp.asarray(refer_list)
for fp8_format in FP8Format:
FP8Helper.initialize(fp8_format=fp8_format)
assert_allclose(get_ref(fp8_format), FP8Helper.generate_fp8_max_array(num_of_meta))
FP8Helper.finalize()
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_update_collections(self):
original_val = 0.0
updated_val = 10.0
original_state = {
"test1": original_val,
"test2": original_val,
}
updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)
original_state = flax.core.frozen_dict.FrozenDict(original_state)
updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
self.assertEqual(updated_state["test1"], updated_val)
self.assertEqual(updated_state["test2"], original_val)
class TestFP8Functions(unittest.TestCase):
def _check_defult_state(self):
self.assertFalse(FP8Helper.enable_fp8())
self.assertEqual(infer_major_sharding_type(), MajorShardingType.SINGLE)
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_fp8_autocast(self):
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
self.assertFalse(FP8Helper.enable_fp8())
self._check_defult_state()
ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(FP8Helper.enable_fp8())
self.assertEqual(FP8Helper.MARGIN, ds.margin)
self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
self._check_defult_state()
ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
with fp8_autocast(enabled=True, fp8_recipe=ds):
self.assertTrue(FP8Helper.enable_fp8())
self.assertEqual(FP8Helper.MARGIN, ds.margin)
self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
self._check_defult_state()
ds = DelayedScaling(amax_history_len=2)
with self.assertRaises(AssertionError):
with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(amax_history_len=2)):
pass
@unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
def test_fp8_autocast_with_sharding_resource(self):
self._check_defult_state()
ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
# TODO (Ming Huang): Support multi-GPU testing. # pylint: disable=fixme
# srs = (
# (ShardingResource(None, None), MajorShardingType.SINGLE),
# (ShardingResource('dp', None), MajorShardingType.DP),
# (ShardingResource(None, 'tp'), MajorShardingType.TP),
# (ShardingResource('dp', 'tp'), MajorShardingType.DPTP),
# )
srs = (
(ShardingResource(None, None), MajorShardingType.SINGLE),
(ShardingResource('dp', None), MajorShardingType.SINGLE),
(ShardingResource(None, 'tp'), MajorShardingType.SINGLE),
(ShardingResource('dp', 'tp'), MajorShardingType.SINGLE),
)
# TODO (Ming Huang): Support multi-GPU testing. # pylint: disable=fixme
mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
with maps.Mesh(devices, ('dp', 'tp')):
for sr, mst in srs:
with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
self.assertTrue(FP8Helper.enable_fp8())
self.assertEqual(FP8Helper.MARGIN, ds.margin)
self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
self.assertEqual(infer_major_sharding_type(), mst)
self._check_defult_state()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
import numpy as np
import pytest
from jax.experimental import maps
from transformer_engine.jax.sharding import get_dot_sharding_meta
from transformer_engine.jax.sharding import get_elementwise_sharding_meta
from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType
def _get_sharding_resource(mesh_names, sharding_type):
dp_r = None
tp_r = None
if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
dp_r = mesh_names[0]
if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW):
tp_r = mesh_names[0]
if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
tp_r = mesh_names[1]
return ShardingResource(dp_r, tp_r)
DEVICE_COUNT = 4
MESH_CONFIG = [((4,), ("dp",), ShardingType.DP), ((4,), ("tp",), ShardingType.TP_COL),
((4,), ("tp",), ShardingType.TP_ROW), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW)]
LOGICAL_RULES = [[(('a1', None), ('a2', 'ma2')), False],
[(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True]]
SRS = [
ShardingResource(),
ShardingResource('data', None),
ShardingResource(None, 'model'),
ShardingResource('data', 'model')
]
def is_devices_enough():
return len(jax.devices()) >= DEVICE_COUNT
class TestGeneralFunc:
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
def test_infer_major_sharding_type(
self,
mesh_shape, # pylint: disable=unused-argument
mesh_names,
sharding_type):
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with maps.Mesh(devices, mesh_names):
assert infer_major_sharding_type() is sharding_type.value[0]
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
def test_is_dp_enabled(
self,
mesh_shape, # pylint: disable=unused-argument
mesh_names, # pylint: disable=unused-argument
sharding_type):
if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
assert is_dp_enabled(sharding_type.value[0])
else:
assert not is_dp_enabled(sharding_type.value[0])
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
def test_is_tp_enabled(
self,
mesh_shape, # pylint: disable=unused-argument
mesh_names, # pylint: disable=unused-argument
sharding_type):
if sharding_type is ShardingType.DP:
assert not is_tp_enabled(sharding_type.value[0])
else:
assert is_tp_enabled(sharding_type.value[0])
class TestShardingMetaGenerator:
BATCH_AXIS_NAME = 'batch'
MODEL_AXIS_NAME = 'model'
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
def test_fp8_meta(self, mesh_shape, mesh_names, sharding_type, num_of_fp8_meta=4):
def stack_axes_meta(mapping):
return tuple(mapping for _ in range(num_of_fp8_meta))
def get_ref_sm():
if sharding_type == ShardingType.DP:
return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
{TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]}, (),
())
if sharding_type == ShardingType.TP_COL:
return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
{TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, (),
())
if sharding_type == ShardingType.TP_ROW:
return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
{TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]}, (),
())
if sharding_type == ShardingType.DP_TP_COL:
return ShardingMeta(
stack_axes_meta({}), stack_axes_meta({}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, (), ())
if sharding_type == ShardingType.DP_TP_ROW:
return ShardingMeta(
stack_axes_meta({}), stack_axes_meta({}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, (), ())
return None
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with maps.Mesh(devices, mesh_names):
test_sm = get_fp8_meta_sharding_meta(
sharding_type,
num_of_fp8_meta,
dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
assert test_sm == get_ref_sm()
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.parametrize('a_shape, b_shape', [((64, 128, 256), (256, 512)),
((128, 64, 512), (512, 256))])
@pytest.mark.parametrize('batch_dim_of_a', [0, 1])
@pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
def test_dot(self, mesh_shape, mesh_names, sharding_type, a_shape, b_shape, batch_dim_of_a):
model_dim_of_a = len(a_shape) - 1
model_dim_of_b = 0 if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) else 1
contracting_dims = ((-1,), (0,))
def get_ref_sm():
out_shape = (*a_shape[:min(contracting_dims[0])],
*b_shape[max(contracting_dims[1]) + 1:])
if sharding_type == ShardingType.DP:
a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0], -1,
*a_shape[batch_dim_of_a + 1:])
return ShardingMeta(({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
}, {}), ({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
}), {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
[a_new_shape, b_shape], [out_shape])
if sharding_type == ShardingType.TP_COL:
b_new_shape = (b_shape[0], mesh_shape[0], b_shape[1] // mesh_shape[0])
return ShardingMeta(({}, {
1: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
len(out_shape) - 1: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
[a_shape, b_new_shape], [out_shape])
if sharding_type == ShardingType.TP_ROW:
a_new_shape = (*a_shape[:-1], mesh_shape[0], a_shape[-1] // mesh_shape[0])
b_new_shape = (mesh_shape[0], b_shape[0] // mesh_shape[0], b_shape[1])
return ShardingMeta(({
len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}, {
0: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({}), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
[a_new_shape, b_new_shape], [out_shape])
if sharding_type == ShardingType.DP_TP_COL:
a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0],
a_shape[batch_dim_of_a] // mesh_shape[0],
*a_shape[batch_dim_of_a + 1:])
b_new_shape = (b_shape[0], mesh_shape[1], b_shape[1] // mesh_shape[1])
return ShardingMeta(
({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
}, {
1: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME,
len(out_shape): TestShardingMetaGenerator.MODEL_AXIS_NAME
}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, [a_new_shape, b_new_shape], [out_shape])
if sharding_type == ShardingType.DP_TP_ROW:
a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0],
a_shape[batch_dim_of_a] // mesh_shape[0],
*a_shape[batch_dim_of_a + 1:-1], mesh_shape[1],
a_shape[-1] // mesh_shape[1])
b_new_shape = (mesh_shape[1], b_shape[0] // mesh_shape[1], b_shape[1])
return ShardingMeta(
({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME,
len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}, {
0: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME
}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, [a_new_shape, b_new_shape], [out_shape])
return None
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with maps.Mesh(devices, mesh_names):
test_sm = get_dot_sharding_meta(
sharding_type,
a_shape,
b_shape,
batch_dim_of_a,
model_dim_of_a,
model_dim_of_b,
contracting_dims,
dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
assert test_sm == get_ref_sm()
@pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
@pytest.mark.parametrize('input_shape', [(64, 128, 256), (128, 64, 512)])
@pytest.mark.parametrize('other_shape', [(256,), (512,)])
@pytest.mark.parametrize('batch_dim', [0, 1])
@pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
def test_elementwise(self, mesh_shape, mesh_names, sharding_type, input_shape, other_shape,
batch_dim):
def get_ref_sm():
need_assert = True
ref_sharding_meta = None
if input_shape[-1] != other_shape[0]:
need_assert = True
ref_sharding_meta = None
elif sharding_type is (ShardingType.DP_TP_COL, ShardingType.DP):
need_assert = False
input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1,
*input_shape[batch_dim + 1:])
ref_sharding_meta = ShardingMeta(({
batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME
}, {}), ({
batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME
}), {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
[input_new_shape, other_shape], [input_shape])
elif sharding_type is ShardingType.TP_COL:
need_assert = False
ref_sharding_meta = ShardingMeta(({}, {}), ({}), {}, [input_shape, other_shape],
[input_shape])
elif sharding_type is ShardingType.TP_ROW:
need_assert = False
input_new_shape = (*input_shape[:-1], mesh_shape[0], -1)
other_new_shape = (mesh_shape[0], -1)
ref_sharding_meta = ShardingMeta(({
len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}, {
0: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
[input_new_shape, other_new_shape], [input_shape])
elif sharding_type is ShardingType.DP_TP_ROW:
need_assert = False
input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1,
*input_shape[batch_dim + 1:-1], mesh_shape[1],
input_shape[-1] // mesh_shape[1])
other_new_shape = (mesh_shape[0], -1)
ref_sharding_meta = ShardingMeta(
({
batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME,
len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}, {
0: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), ({
batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME,
len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
}), {
TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
}, [input_new_shape, other_new_shape], [input_shape])
return ref_sharding_meta, need_assert
devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
with maps.Mesh(devices, mesh_names):
ref_sm, need_assert = get_ref_sm()
try:
test_sm = get_elementwise_sharding_meta(
sharding_type,
input_shape,
other_shape,
batch_dim,
dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
assert not need_assert
assert test_sm == ref_sm
except (NotImplementedError, AssertionError) as e:
assert need_assert, f"{e.args}"
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Any, Callable, Tuple, Union
import jax.numpy as jnp
import numpy as np
from cuda import cudart
from jax import lax
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
def is_fp8_supported():
"""
Thus JAX doesn't have API to query capability
Use cuda-python for get the compute capability
"""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
ret, sm_major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
assert ret == cudaSuccess
return sm_major >= 9
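# Illustrative usage sketch (an assumption, not part of this file): tests that
# exercise FP8 paths can be gated on the compute capability check above, e.g.
#
#     @pytest.mark.skipif(not is_fp8_supported(), reason='FP8 requires SM90+')
#     def test_fp8_case():
#         ...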
def assert_allclose(actual,
desired,
rtol=1e-05,
atol=1e-08,
equal_nan=True,
err_msg='',
verbose=True):
if not isinstance(actual, float):
actual = actual.astype(jnp.float32)
if not isinstance(desired, float):
desired = desired.astype(jnp.float32)
np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine LANGUAGES CUDA CXX)
list(APPEND CMAKE_CUDA_FLAGS "--threads 4")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
find_package(Python COMPONENTS Interpreter Development REQUIRED)
include_directories(${PROJECT_SOURCE_DIR})
add_subdirectory(common)
option(ENABLE_JAX "Enable JAX in the building workflow." OFF)
if(ENABLE_JAX)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(jax)
endif()
@@ -4,4 +4,14 @@
"""Top level package"""
from . import common
from . import pytorch
try:
from . import pytorch
except ImportError as e:
pass
try:
from . import jax
except ImportError as e:
pass
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine LANGUAGES CUDA CXX)
list(APPEND CMAKE_CUDA_FLAGS "--threads 4")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
list(APPEND CMAKE_CUDA_FLAGS "-G")
endif()
add_library(transformer_engine SHARED
transformer_engine.cpp
transpose/cast_transpose.cu
@@ -38,9 +19,7 @@ add_library(transformer_engine SHARED
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
...
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
pybind11_add_module(
transformer_engine_jax
${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp
)
target_link_libraries(transformer_engine_jax PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .fp8 import fp8_autocast
from .sharding import ShardingResource
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te custom call"""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple
from functools import partial, reduce
import operator
import numpy as np
from jaxlib.hlo_helpers import custom_call
import jax.numpy as jnp
from jax.lib import xla_client
from jax import core, dtypes
from jax.abstract_arrays import ShapedArray
from jax.interpreters import xla, mlir
from jax.interpreters.mlir import ir, dtype_to_ir_type
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
for _name, _value in transformer_engine_jax.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
def te_dtype_to_jax_dtype(te_dtype):
"""
convert TE dtype to jax dtype
"""
assert isinstance(te_dtype, TEDType)
if te_dtype == TEDType.kFloat32:
return jnp.float32
if te_dtype == TEDType.kFloat16:
return jnp.float16
if te_dtype == TEDType.kBFloat16:
return jnp.bfloat16
if te_dtype == TEDType.kInt32:
return jnp.int32
return jnp.int8
def te_dtype_to_ir_dtype(te_dtype):
"""
convert TE dtype to MLIR dtype
"""
return dtype_to_ir_type(np.dtype(te_dtype_to_jax_dtype(te_dtype)))
def jax_dtype_to_te_dtype(jax_dtype):
"""
convert jax dtype to TE dtype
"""
if jax_dtype == jnp.float32:
return TEDType.kFloat32
if jax_dtype == jnp.float16:
return TEDType.kFloat16
if jax_dtype == jnp.bfloat16:
return TEDType.kBFloat16
raise ValueError(f"Not support the {jax_dtype=}")
def merge_named_shape(base, new):
"""
merge two named shapes (i.e. dicts); values of shared keys must match
"""
output_named_shape = {**base}
for key in new:
if key in output_named_shape:
assert output_named_shape[key] == new[key], \
f"The value of named shape with a same name should be equal between" \
f" base and new in merge_named_shape, but got base[{key}]=" \
f"{output_named_shape[key]} and {new[key]=}"
else:
output_named_shape[key] = new[key]
return output_named_shape
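# Behavior sketch (illustrative examples only):
#
#     merge_named_shape({'batch': 32}, {'model': 4})   # -> {'batch': 32, 'model': 4}
#     merge_named_shape({'batch': 32}, {'batch': 32})  # -> {'batch': 32}
#     merge_named_shape({'batch': 32}, {'batch': 16})  # raises AssertionError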
class BasePrimitive(metaclass=ABCMeta):
"""
JAX primitive
"""
@staticmethod
@abstractmethod
def abstract():
"""
to describe computing graph
"""
return NotImplemented
@staticmethod
@abstractmethod
def lowering():
"""
to describe MLIR
"""
return NotImplemented
def register_primitive(cls):
"""
register jax primitive
"""
p = core.Primitive(cls.name)
p.multiple_results = cls.multiple_results
p.def_impl(partial(xla.apply_primitive, p))
p.def_abstract_eval(cls.abstract)
mlir.register_lowering(p, cls.lowering, platform='cuda')
return p
@dataclass
class CustomCallArgsWrapper:
"""
wrapper of XLA custom call args
"""
def __init__(self,
output_types,
operands,
operand_shapes,
operand_specific_layouts=None,
output_specific_layouts=None):
self.output_types = output_types
self.operands = operands
self.operand_layouts = CustomCallArgsWrapper.generate_layouts(operand_shapes,
operand_specific_layouts)
output_shapes = [x.shape for x in output_types]
self.output_layouts = CustomCallArgsWrapper.generate_layouts(output_shapes,
output_specific_layouts)
@staticmethod
def generate_layouts(shapes, specific_layouts):
"""
setup layouts for XLA custom call
"""
def default_layout(shape):
return range(len(shape) - 1, -1, -1)
if specific_layouts is None:
specific_layouts = {}
layouts = []
for idx, shape in enumerate(shapes):
if idx in specific_layouts:
layouts.append(specific_layouts[idx])
else:
layouts.append(default_layout(shape))
return layouts
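# Layout sketch (illustrative): without specific layouts, every shape gets the
# default minor-to-major (row-major) layout, e.g. a rank-3 operand maps to
# (2, 1, 0):
#
#     CustomCallArgsWrapper.generate_layouts([(8, 16, 32), (1,)], None)
#     # -> [range(2, -1, -1), range(0, -1, -1)]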
def custom_caller(name, args, opaque, has_side_effect, **kwargs):
"""
XLA custom call wrapper
"""
out = custom_call(name,
args.output_types,
args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs)
return out
class TransposePrimitive(BasePrimitive):
"""
Transpose Primitive
"""
name = "te_transpose"
multiple_results = False
@staticmethod
def abstract(inputs, *, dtype):
"""
_transpose abstract
"""
in_dtype = dtypes.canonicalize_dtype(inputs.dtype)
out_dtype = te_dtype_to_jax_dtype(dtype)
assert len(inputs.shape) == 2
assert isinstance(dtype, TEDType)
assert in_dtype == out_dtype
return ShapedArray((inputs.shape[1], inputs.shape[0]),
in_dtype,
named_shape=inputs.named_shape)
@staticmethod
def lowering(ctx, inputs, *, dtype):
"""
_transpose cuda lowering
"""
in_aval = ctx.avals_in[0]
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16, jnp.int8]
ir_in_type = ir.RankedTensorType(inputs.type)
ir_in_shape = ir_in_type.shape
ir_out_dtype = te_dtype_to_ir_dtype(dtype)
out_types = [ir.RankedTensorType.get([ir_in_shape[1], ir_in_shape[0]], ir_out_dtype)]
operands = [inputs]
operand_shapes = [ir_in_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
assert len(ir_in_shape) == 2
opaque = transformer_engine_jax.pack_common_descriptor(ir_in_shape, dtype, dtype)
out = custom_caller(TransposePrimitive.name, args, opaque, False)
return [out]
_transpose_p = register_primitive(TransposePrimitive)
def transpose(inputs: jnp.ndarray, dtype: TEDType) -> jnp.ndarray:
"""
transpose wrapper
Assume the input is a 2D tensor
"""
return _transpose_p.bind(inputs, dtype=dtype)
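# Usage sketch (illustrative; runs the te_transpose custom call, so a CUDA
# device is required):
#
#     x = jnp.ones((128, 256), jnp.bfloat16)
#     xt = transpose(x, TEDType.kBFloat16)    # xt.shape == (256, 128)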
class CastTransposePrimitive(BasePrimitive):
"""
Cast Transpose Primitive
"""
name = "te_cast_transpose"
multiple_results = True
@staticmethod
def abstract(inputs, amax, scale, scale_inv, *, out_dtype):
"""
te_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert len(inputs.shape) == 2
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax.dtype == jnp.float32
assert scale.dtype == jnp.float32
assert scale_inv.dtype == jnp.float32
out_dtype = te_dtype_to_jax_dtype(out_dtype)
# input_cast, input_cast_trans, amax
return (ShapedArray((inputs.shape[0], inputs.shape[1]),
out_dtype,
named_shape=inputs.named_shape),
ShapedArray((inputs.shape[1], inputs.shape[0]),
out_dtype,
named_shape=inputs.named_shape),
ShapedArray((1,), amax.dtype, named_shape=amax.named_shape))
@staticmethod
def lowering(ctx, inputs, amax, scale, scale_inv, *, out_dtype):
"""
te_cast_transpose_p lowering rules
"""
in_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_in_type = ir.RankedTensorType(inputs.type)
ir_in_shape = ir_in_type.shape
ir_out_dtype = te_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_types = [
ir.RankedTensorType.get([ir_in_shape[0], ir_in_shape[1]], ir_out_dtype),
ir.RankedTensorType.get([ir_in_shape[1], ir_in_shape[0]], ir_out_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
assert len(ir_in_shape) == 2
opaque = transformer_engine_jax.pack_common_descriptor(ir_in_shape,
jax_dtype_to_te_dtype(in_aval.dtype),
out_dtype)
out = custom_caller(CastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={1: 2})
return out
_cast_transpose_p = register_primitive(CastTransposePrimitive)
def cast_transpose(inputs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose wrapper
Return FP8(inputs), FP8(inputs.T) and the updated amax, where the casts are scaled by `scale`
"""
return _cast_transpose_p.bind(inputs, amax, scale, scale_inv, out_dtype=out_dtype)
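# Usage sketch (illustrative; the FP8 dtype member name is an assumption):
# `amax`, `scale` and `scale_inv` are shape-(1,) float32 buffers holding the
# FP8 scaling state, and the call returns the casted tensor, its transpose
# and the updated amax:
#
#     casted, casted_t, updated_amax = cast_transpose(
#         x, amax, scale, scale_inv, out_dtype=TEDType.kFloat8E4M3)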
class GatedGeluPrimitive(BasePrimitive):
"""
Gated Gelu Primitive
"""
name = "te_gated_gelu"
multiple_results = False
@staticmethod
def abstract(inputs):
"""
te_gated_gelu_p abstract
"""
dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
inputs_shape = inputs.shape
hidden_size = inputs_shape[-1]
# In Transformer, batch_shape = (batch, seqlen, )
batch_shapes = inputs_shape[:-1]
assert hidden_size % 2 == 0
inputs_shape = inputs.shape
out_shape = (batch_shapes) + (hidden_size // 2,)
return ShapedArray(out_shape, dtype, named_shape=inputs.named_shape)
@staticmethod
def lowering(ctx, inputs):
"""
te_gated_gelu_p lowering rules
"""
(in_aval,) = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
ir_in_type = ir.RankedTensorType(inputs.type)
ir_in_shape = ir_in_type.shape
out_shape = ir_in_shape[:-1] + [ir_in_shape[-1] // 2]
out_types = [
ir.RankedTensorType.get(out_shape, ir_in_type.element_type),
]
operands = [inputs]
operand_shapes = [ir_in_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
hidden_size = ir_in_shape[-1]
# In Transformer, batch_size = batch x seqlen
batch_size = reduce(operator.mul, ir_in_shape[:-1])
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((batch_size, hidden_size // 2),
in_dtype, in_dtype)
out = custom_caller(GatedGeluPrimitive.name, args, opaque, False)
return [out]
_gated_gelu_p = register_primitive(GatedGeluPrimitive)
def gated_gelu(inputs: jnp.ndarray) -> jnp.ndarray:
"""
gated gelu wrapper
Return geglu(inputs)
Assume the input is 2D and the memory layout is (N, 2, H)
"""
return _gated_gelu_p.bind(inputs)
class GatedGeluFp8Primitive(BasePrimitive):
"""
Gated Gelu FP8 Primitive
"""
name = "te_gated_gelu_fp8"
multiple_results = True
@staticmethod
def abstract(inputs, amax, scale, scale_inv, *, out_dtype):
"""
te_gated_gelu_fp8_p abstract
"""
dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax.dtype == jnp.float32
assert scale.dtype == jnp.float32
assert scale_inv.dtype == jnp.float32
out_dtype = te_dtype_to_jax_dtype(out_dtype)
assert len(inputs.shape) == 2
hidden_size = inputs.shape[1]
batch_size = inputs.shape[0] # In Transformer, batch_size = batch x seqlen
# geglu_out, amax
return (ShapedArray((batch_size, hidden_size // 2),
out_dtype,
named_shape=inputs.named_shape),
ShapedArray((1,), amax.dtype, named_shape=amax.named_shape))
@staticmethod
def lowering(ctx, inputs, amax, scale, scale_inv, *, out_dtype):
"""
te_gated_gelu_fp8_p lowering rules
"""
in_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_in_type = ir.RankedTensorType(inputs.type)
ir_in_shape = ir_in_type.shape
ir_out_dtype = te_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
hidden_size = ir_in_shape[1]
batch_size = ir_in_shape[0] # In Transformer, batch_size = batch x seqlen
out_types = [
ir.RankedTensorType.get([batch_size, hidden_size // 2], ir_out_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(
(ir_in_shape[0], ir_in_shape[1] // 2), jax_dtype_to_te_dtype(in_aval.dtype), out_dtype)
out = custom_caller(GatedGeluFp8Primitive.name,
args,
opaque,
False,
operand_output_aliases={1: 1})
return out
_gated_gelu_fp8_p = register_primitive(GatedGeluFp8Primitive)
def gated_gelu_fp8(inputs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray,
out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
cast gated gelu wrapper
Return FP8(geglu(inputs)) and the updated amax
Assume the input is 2D and the memory layout is (N, 2, H)
"""
return _gated_gelu_fp8_p.bind(inputs, amax, scale, scale_inv, out_dtype=out_dtype)
class DgatedGeluPrimitive(BasePrimitive):
"""
Dgated Gelu Primitive
"""
name = "te_dgated_gelu"
multiple_results = False
@staticmethod
def abstract(inputs, gelu_inputs):
"""
te_dgated_gelu_p abstract
"""
dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gelu_inputs.dtype == dtype
for axis in range(len(inputs.shape) - 1):
assert inputs.shape[axis] == gelu_inputs.shape[axis]
i_hidden_size = inputs.shape[-1]
g_hidden_size = gelu_inputs.shape[-1]
assert i_hidden_size * 2 == g_hidden_size
return ShapedArray(gelu_inputs.shape, dtype, named_shape=inputs.named_shape)
@staticmethod
def lowering(ctx, inputs, gelu_inputs):
"""
te_dgated_gelu_p lowering rules
"""
in_aval, gi_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gi_aval.dtype == in_aval.dtype
ir_in_type = ir.RankedTensorType(inputs.type)
ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(gelu_inputs.type)
gi_shape = gi_type.shape
for axis in range(len(ir_in_shape) - 1):
assert ir_in_shape[axis] == gi_shape[axis]
# In Transformer, batch_size = batch x seqlen
ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
i_hidden_size = ir_in_shape[-1]
g_hidden_size = gi_shape[-1]
assert i_hidden_size * 2 == g_hidden_size
out_dtype = ir_in_type.element_type
out_shape = gi_shape
out_types = [
ir.RankedTensorType.get(out_shape, out_dtype),
]
operands = [inputs, gelu_inputs]
operand_shapes = [ir_in_shape, gi_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, i_hidden_size),
in_dtype, in_dtype)
out = custom_caller(DgatedGeluPrimitive.name, args, opaque, False)
return [out]
_dgated_gelu_p = register_primitive(DgatedGeluPrimitive)
def dgated_gelu(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray) -> jnp.ndarray:
"""
dgated_gelu fusion wrapper
Return dgeglu(inputs)
"""
return _dgated_gelu_p.bind(inputs, gelu_inputs)
class DgatedGeluCastTransposePrimitive(BasePrimitive):
"""
Dgated Gelu Cast Transpose Primitive
"""
name = "te_dgated_gelu_cast_transpose"
multiple_results = True
@staticmethod
def abstract(inputs, gelu_inputs, amax, scale, scale_inv, *, out_dtype):
"""
te_dgated_gelu_cast_transpose_p abstract
"""
dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gelu_inputs.dtype == dtype
assert len(inputs.shape) == 2
assert len(gelu_inputs.shape) == 2
ir_batch_size = inputs.shape[0]
gi_batch_size = gelu_inputs.shape[0]
assert ir_batch_size == gi_batch_size
ir_hidden_size = inputs.shape[1]
gi_hidden_size = gelu_inputs.shape[1]
assert ir_hidden_size * 2 == gi_hidden_size
assert amax.dtype == jnp.float32
assert scale.dtype == jnp.float32
assert scale_inv.dtype == jnp.float32
out_dtype = te_dtype_to_jax_dtype(out_dtype)
# input_cast, input_cast_trans, amax
return (ShapedArray((gi_batch_size, gi_hidden_size),
out_dtype,
named_shape=inputs.named_shape),
ShapedArray((gi_hidden_size, gi_batch_size),
out_dtype,
named_shape=inputs.named_shape),
ShapedArray((1,), amax.dtype, named_shape=amax.named_shape))
@staticmethod
def lowering(ctx, inputs, gelu_inputs, amax, scale, scale_inv, *, out_dtype):
"""
te_dgated_gelu_cast_transpose_p lowering rules
"""
in_aval, gi_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gi_aval.dtype == in_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_in_type = ir.RankedTensorType(inputs.type)
ir_in_shape = ir_in_type.shape
gi_type = ir.RankedTensorType(gelu_inputs.type)
gi_shape = gi_type.shape
ir_batch_size = ir_in_shape[0]
gi_batch_size = gi_shape[0]
assert ir_batch_size == gi_batch_size
ir_hidden_size = ir_in_shape[1]
gi_hidden_size = gi_shape[1]
assert ir_hidden_size * 2 == gi_hidden_size
ir_out_dtype = te_dtype_to_ir_dtype(out_dtype)
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_types = [
ir.RankedTensorType.get([gi_batch_size, gi_hidden_size], ir_out_dtype),
ir.RankedTensorType.get([gi_hidden_size, gi_batch_size], ir_out_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
]
operands = [inputs, gelu_inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, gi_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor((ir_batch_size, ir_hidden_size),
jax_dtype_to_te_dtype(in_aval.dtype),
out_dtype)
out = custom_caller(DgatedGeluCastTransposePrimitive.name,
args,
opaque,
False,
operand_output_aliases={2: 2})
return out
_dgated_gelu_cast_transpose_p = register_primitive(DgatedGeluCastTransposePrimitive)
def dgated_gelu_cast_transpose(inputs: jnp.ndarray, gelu_inputs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
cast transpose d_gated_gelu fusion wrapper
Return FP8(dgeglu(inputs))
"""
return _dgated_gelu_cast_transpose_p.bind(inputs,
gelu_inputs,
amax,
scale,
scale_inv,
out_dtype=out_dtype)
class GemmPrimitive(BasePrimitive):
"""
Gemm Primitive
"""
name = "te_gemm"
multiple_results = False
@staticmethod
def abstract(A, B, A_scale_inv, B_scale_inv, *, A_dtype, B_dtype, D_dtype, transa, transb,
use_split_accumulator): # pylint: disable=unused-argument
"""
te_gemm_p abstract
"""
atype = dtypes.canonicalize_dtype(A.dtype)
btype = dtypes.canonicalize_dtype(B.dtype)
assert atype == te_dtype_to_jax_dtype(A_dtype)
assert btype == te_dtype_to_jax_dtype(B_dtype)
assert A_scale_inv.dtype == jnp.float32
assert B_scale_inv.dtype == jnp.float32
m = A.shape[0] if transa else A.shape[1]
k = A.shape[1] if transa else A.shape[0]
n = B.shape[1] if transb else B.shape[0]
assert (transb and k == B.shape[0]) or k == B.shape[1]
out_dtype = te_dtype_to_jax_dtype(D_dtype)
return ShapedArray((n, m),
out_dtype,
named_shape=merge_named_shape(A.named_shape, B.named_shape))
@staticmethod
def lowering(ctx, A, B, A_scale_inv, B_scale_inv, *, A_dtype, B_dtype, D_dtype, transa, transb,
use_split_accumulator):
"""
te_gemm_p lowering rules
"""
A_aval, B_aval, A_scale_inv_aval, B_scale_inv_aval = ctx.avals_in
assert A_aval.dtype == te_dtype_to_jax_dtype(A_dtype)
assert B_aval.dtype == te_dtype_to_jax_dtype(B_dtype)
assert A_scale_inv_aval.dtype == jnp.float32
assert B_scale_inv_aval.dtype == jnp.float32
A_type = ir.RankedTensorType(A.type)
B_type = ir.RankedTensorType(B.type)
A_shape = A_type.shape
B_shape = B_type.shape
A_scale_inv_shape = ir.RankedTensorType(A_scale_inv.type).shape
B_scale_inv_shape = ir.RankedTensorType(B_scale_inv.type).shape
m = A_shape[0] if transa else A_shape[1]
k = A_shape[1] if transa else A_shape[0]
n = B_shape[1] if transb else B_shape[0]
assert (transb and k == B_shape[0]) or k == B_shape[1]
ir_out_dtype = dtype_to_ir_type(np.dtype(te_dtype_to_jax_dtype(D_dtype)))
out_types = [
ir.RankedTensorType.get([n, m], ir_out_dtype),
]
operands = [A, B, A_scale_inv, B_scale_inv]
operand_shapes = [A_shape, B_shape, A_scale_inv_shape, B_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
# m, n, k here are the values for the transa=False, transb=False case,
# which is what te_gemm's implementation expects.
# Therefore, m=A_shape[1], n=B_shape[0], k=A_shape[0]
opaque = transformer_engine_jax.pack_gemm_descriptor(A_shape[1], B_shape[0], A_shape[0],
A_dtype, B_dtype, D_dtype, transa,
transb, use_split_accumulator)
out = custom_caller(GemmPrimitive.name, args, opaque, False)
return [out]
_gemm_p = register_primitive(GemmPrimitive)
def gemm(A: jnp.ndarray,
A_scale_inv: jnp.ndarray,
A_type: TEDType,
transa: bool,
B: jnp.ndarray,
B_scale_inv: jnp.ndarray,
B_type: TEDType,
transb: bool,
D_type: TEDType,
use_split_accumulator: bool = False) -> jnp.ndarray:
"""
gemm wrapper
"""
return _gemm_p.bind(A,
B,
A_scale_inv,
B_scale_inv,
A_dtype=A_type,
B_dtype=B_type,
D_dtype=D_type,
transa=transa,
transb=transb,
use_split_accumulator=use_split_accumulator)
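# Usage sketch (illustrative; the FP8 dtype member names are assumptions):
# with transa=False and transb=False, te_gemm reads A as (k, m) and B as
# (n, k) and produces an (n, m) output:
#
#     out = gemm(a_fp8, a_scale_inv, TEDType.kFloat8E4M3, False,
#                b_fp8, b_scale_inv, TEDType.kFloat8E4M3, False,
#                TEDType.kBFloat16)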
class LayerNormFwdPrimitive(BasePrimitive):
"""
Layer Normalization Forward Primitive
"""
name = "te_layernorm_forward"
multiple_results = True
@staticmethod
def abstract(
x,
gamma,
beta,
*,
epsilon # pylint: disable=unused-argument
):
"""
LayerNorm fwd abstract
"""
x_dtype = dtypes.canonicalize_dtype(x.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
mu_dtype = jnp.float32
rsigma_dtype = jnp.float32
assert gamma.size == beta.size
hidden_size = gamma.size
assert x.size % hidden_size == 0
# In Transformer, batch_size = batch x seqlen
batch_size = x.size // hidden_size
return (
ShapedArray(x.shape, x_dtype, named_shape=x.named_shape), # output
ShapedArray((batch_size,), mu_dtype, named_shape=x.named_shape), # mu
ShapedArray((batch_size,), rsigma_dtype, named_shape=x.named_shape), # rsigma
)
@staticmethod
def lowering(ctx, x, gamma, beta, *, epsilon):
"""
LayerNorm fwd lowering rules
"""
x_aval, gamma_aval, beta_aval = ctx.avals_in
assert gamma_aval.dtype == beta_aval.dtype
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(gamma.type)
w_shape = w_type.shape
b_type = ir.RankedTensorType(beta.type)
b_shape = b_type.shape
assert w_type == b_type
assert w_shape == b_shape
# Output shape is same as the input shape, but the output type is same as the weight type.
# See ln_api.cpp
out_shape = x_shape
output_type = w_type.element_type
ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get()
hidden_size = reduce(operator.mul, w_shape)
# In Transformer, batch_size = batch x seqlen
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get((batch_size,), ir_mu_dtype),
ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
]
operands = [x, gamma, beta]
operand_shapes = [x_shape, w_shape, b_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
epsilon,
)
out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)
return out
_layernorm_fwd_p = register_primitive(LayerNormFwdPrimitive)
def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, epsilon: float):
"""
Wrapper for TE layernorm fwd
"""
return _layernorm_fwd_p.bind(x, gamma, beta, epsilon=epsilon)
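# Usage sketch (illustrative): gamma and beta carry the hidden size, and the
# extra outputs are the per-row statistics consumed later by layernorm_bwd:
#
#     out, mu, rsigma = layernorm_fwd(x, gamma, beta, epsilon=1e-6)
#     # out.shape == x.shape, mu.shape == rsigma.shape == (batch * seqlen,)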
class LayerNormFwdFp8Primitive(BasePrimitive):
"""
Layer Normalization Forward FP8 Primitive
"""
name = "te_layernorm_forward_fp8"
multiple_results = True
@staticmethod
def abstract(
x,
gamma,
beta,
amax,
scale,
scale_inv,
*,
epsilon # pylint: disable=unused-argument
):
"""
LayerNorm fwd (fp8 out) abstract
"""
x_dtype = dtypes.canonicalize_dtype(x.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax.dtype == jnp.float32
assert scale.dtype == jnp.float32
assert scale_inv.dtype == jnp.float32
out_dtype = jnp.int8
mu_dtype = jnp.float32
rsigma_dtype = jnp.float32
assert gamma.size == beta.size
hidden_size = gamma.size
# In Transformer, batch_size = batch x seqlen
batch_size = x.size // hidden_size
return (
ShapedArray(x.shape, out_dtype, named_shape=x.named_shape), # output
ShapedArray((batch_size,), mu_dtype, named_shape=x.named_shape), # mu
ShapedArray((batch_size,), rsigma_dtype, named_shape=x.named_shape), # rsigma
ShapedArray((1,), amax.dtype, named_shape=amax.named_shape), # amax
)
@staticmethod
def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, epsilon):
"""
LayerNorm fwd (fp8 out) lowering rules
"""
x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert gamma_aval.dtype == beta_aval.dtype
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(gamma.type)
w_shape = w_type.shape
b_type = ir.RankedTensorType(beta.type)
b_shape = b_type.shape
ir_out_dtype = dtype_to_ir_type(np.dtype(np.int8))
ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get()
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
hidden_size = reduce(operator.mul, w_shape)
# In Transformer, batch_size = batch x seqlen
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get((batch_size,), ir_mu_dtype),
ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
x_shape, w_shape, b_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape
]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
epsilon,
)
out = custom_caller(LayerNormFwdFp8Primitive.name,
args,
opaque,
False,
operand_output_aliases={3: 3})
return out
_layernorm_fwd_fp8_p = register_primitive(LayerNormFwdFp8Primitive)
def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, epsilon: float):
"""
Wrapper for TE layernorm fwd (fp8 out)
"""
return _layernorm_fwd_fp8_p.bind(x, gamma, beta, amax, scale, scale_inv, epsilon=epsilon)
class LayerNormBwdPrimitive(BasePrimitive):
"""
Layer Normalization Backward Primitive
"""
name = "te_layernorm_backward"
multiple_results = True
@staticmethod
def abstract(
grad_output,
mu,
rsigma,
x,
gamma,
*,
epsilon # pylint: disable=unused-argument
):
"""
Layernorm bwd abstract
"""
x_dtype = dtypes.canonicalize_dtype(x.dtype)
w_dtype = dtypes.canonicalize_dtype(gamma.dtype)
mu_dtype = dtypes.canonicalize_dtype(mu.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma.dtype)
hidden_size = gamma.size
# In Transformer, batch_size = batch x seqlen
batch_size = x.size // hidden_size
assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
assert grad_output.shape == x.shape
assert mu.shape == rsigma.shape == (batch_size,)
assert mu_dtype == rsigma_dtype == jnp.float32
assert grad_output.named_shape == x.named_shape
return (
ShapedArray(x.shape, x_dtype, named_shape=grad_output.named_shape), # grad input
ShapedArray(gamma.shape, w_dtype, named_shape=gamma.named_shape), # grad gamma
ShapedArray(gamma.shape, w_dtype, named_shape=gamma.named_shape), # grad beta
)
@staticmethod
def lowering(ctx, grad_output, mu, rsigma, x, gamma, *, epsilon):
"""
Layernorm bwd lowering rules
"""
_, _, _, x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(gamma.type)
w_shape = w_type.shape
b_type = ir.RankedTensorType(gamma.type)
b_shape = b_type.shape
assert w_type == b_type
assert w_shape == b_shape
go_shape = ir.RankedTensorType(grad_output.type).shape
mu_shape = ir.RankedTensorType(mu.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape
hidden_size = reduce(operator.mul, w_shape)
# In Transformer, batch_size = batch x seqlen
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(w_shape, w_type.element_type),
ir.RankedTensorType.get(b_shape, b_type.element_type),
]
operands = [grad_output, mu, rsigma, x, gamma]
operand_shapes = [go_shape, mu_shape, rsigma_shape, x_shape, w_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
epsilon,
)
out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)
return out
_layernorm_bwd_p = register_primitive(LayerNormBwdPrimitive)
def layernorm_bwd(g: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, epsilon: float):
"""
Wrapper for TE layernorm bwd
"""
return _layernorm_bwd_p.bind(g, mu, rsigma, x, gamma, epsilon=epsilon)
class RmsNormFwdPrimitive(BasePrimitive):
"""
RMS Normalization Forward Primitive
"""
name = "te_rmsnorm_forward"
multiple_results = True
@staticmethod
def abstract(
x,
gamma,
*,
epsilon # pylint: disable=unused-argument
):
"""
RMSNorm fwd abstract
"""
x_dtype = dtypes.canonicalize_dtype(x.dtype)
rsigma_dtype = jnp.float32
hidden_size = gamma.size
# In Transformer, batch_size = batch x seqlen
batch_size = x.size // hidden_size
return (
ShapedArray(x.shape, x_dtype, named_shape=x.named_shape), # output
ShapedArray((batch_size,), rsigma_dtype, named_shape=x.named_shape), # rsigma
)
@staticmethod
def lowering(ctx, x, gamma, *, epsilon):
"""
RMSNorm fwd lowering rules
"""
x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(gamma.type)
w_shape = w_type.shape
iv_element_type = ir.F32Type.get()
hidden_size = reduce(operator.mul, w_shape)
# In Transformer, batch_size = batch x seqlen
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(x_shape, w_type.element_type),
ir.RankedTensorType.get((batch_size,), iv_element_type),
]
operands = [x, gamma]
operand_shapes = [x_shape, w_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
epsilon,
)
out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)
return out
_rmsnorm_fwd_p = register_primitive(RmsNormFwdPrimitive)
def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
"""
Wrapper for TE rmsnorm fwd
"""
return _rmsnorm_fwd_p.bind(x, gamma, epsilon=epsilon)
class RmsNormFwdFp8Primitive(BasePrimitive):
"""
RMS Normalization Forward FP8 Primitive
"""
name = "te_rmsnorm_forward_fp8"
multiple_results = True
@staticmethod
def abstract(
x,
gamma,
amax,
scale,
scale_inv,
*,
epsilon # pylint: disable=unused-argument
):
"""
RMSNorm fwd (fp8 out) abstract
"""
x_dtype = dtypes.canonicalize_dtype(x.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax.dtype == jnp.float32
assert scale.dtype == jnp.float32
assert scale_inv.dtype == jnp.float32
out_dtype = jnp.int8
rsigma_dtype = jnp.float32
hidden_size = gamma.size
# In Transformer, batch_size = batch x seqlen
batch_size = x.size // hidden_size
return (
ShapedArray(x.shape, out_dtype, named_shape=x.named_shape), # output
ShapedArray((batch_size,), rsigma_dtype, named_shape=x.named_shape), # rsigma
ShapedArray((1,), amax.dtype, named_shape=amax.named_shape), # amax
)
@staticmethod
def lowering(ctx, x, gamma, amax, scale, scale_inv, *, epsilon):
"""
RMSNorm fwd (fp8 out) lowering rules
"""
x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(gamma.type)
w_shape = w_type.shape
ir_out_dtype = dtype_to_ir_type(np.dtype(np.int8))
ir_rsigma_dtype = ir.F32Type.get()
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_dtype = ir_amax_type.element_type
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
hidden_size = reduce(operator.mul, w_shape)
# In Transformer, batch_size = batch x seqlen
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(x_shape, ir_out_dtype),
ir.RankedTensorType.get((batch_size,), ir_rsigma_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
]
operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, w_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
epsilon,
)
out = custom_caller(RmsNormFwdFp8Primitive.name,
args,
opaque,
False,
operand_output_aliases={2: 2})
return out
_rmsnorm_fwd_fp8_p = register_primitive(RmsNormFwdFp8Primitive)
def rmsnorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray,
scale_inv: jnp.ndarray, epsilon: float):
"""
Wrapper for TE rmsnorm fwd (fp8 out)
"""
return _rmsnorm_fwd_fp8_p.bind(x, gamma, amax, scale, scale_inv, epsilon=epsilon)
class RmsNormBwdPrimitive(BasePrimitive):
"""
RMS Normalization Backward Primitive
"""
name = "te_rmsnorm_backward"
multiple_results = True
@staticmethod
def abstract(
grad_output,
rsigma,
x,
gamma,
*,
epsilon # pylint: disable=unused-argument
):
"""
RMSNorm bwd abstract
"""
w_dtype = dtypes.canonicalize_dtype(gamma.dtype)
x_dtype = dtypes.canonicalize_dtype(x.dtype)
rsigma_dtype = dtypes.canonicalize_dtype(rsigma.dtype)
hidden_size = gamma.size
# In Transformer, batch_size = batch x seqlen
batch_size = x.size // hidden_size
assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
assert grad_output.shape == x.shape
assert rsigma.shape == (batch_size,)
assert rsigma_dtype == jnp.float32
assert grad_output.named_shape == x.named_shape
return (
ShapedArray(x.shape, x_dtype, named_shape=grad_output.named_shape), # grad input
ShapedArray(gamma.shape, w_dtype, named_shape=gamma.named_shape), # grad gamma
)
@staticmethod
def lowering(ctx, grad_output, inv_var, x, gamma, *, epsilon):
"""
RMSNorm bwd lowering rules
"""
_, _, x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
w_type = ir.RankedTensorType(gamma.type)
w_shape = w_type.shape
go_shape = ir.RankedTensorType(grad_output.type).shape
inv_var_shape = ir.RankedTensorType(inv_var.type).shape
hidden_size = reduce(operator.mul, w_shape)
# In Transformer, batch_size = batch x seqlen
batch_size = reduce(operator.mul, x_shape) // hidden_size
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(w_shape, w_type.element_type),
]
operands = [grad_output, inv_var, x, gamma]
operand_shapes = [go_shape, inv_var_shape, x_shape, w_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
epsilon,
)
out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)
return out
_rmsnorm_bwd_p = register_primitive(RmsNormBwdPrimitive)
def rmsnorm_bwd(grad: jnp.ndarray, inv_var: jnp.ndarray, x: jnp.ndarray, gamma: jnp.ndarray,
epsilon: float):
"""
Wrapper for TE rmsnorm bwd
"""
return _rmsnorm_bwd_p.bind(grad, inv_var, x, gamma, epsilon=epsilon)
class QuantizePrimitive(BasePrimitive):
"""
Quantize Primitive
"""
name = "te_quantize"
multiple_results = True
@staticmethod
def abstract(inputs, amax, scale, scale_inv, *, out_dtype):
"""
te_quantize abstract
"""
in_dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert in_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert isinstance(out_dtype, TEDType)
out_dtype = te_dtype_to_jax_dtype(out_dtype)
assert amax.dtype == jnp.float32
assert scale.dtype == jnp.float32
assert scale_inv.dtype == jnp.float32
return (ShapedArray(inputs.shape, out_dtype, named_shape=inputs.named_shape),
ShapedArray((1,), amax.dtype, named_shape=amax.named_shape))
@staticmethod
def lowering(ctx, inputs, amax, scale, scale_inv, *, out_dtype):
"""
te_quantize lowering rules
"""
in_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_in_type = ir.RankedTensorType(inputs.type)
ir_in_shape = ir_in_type.shape
ir_out_dtype = te_dtype_to_ir_dtype(out_dtype)
ir_out_shape = ir_in_shape
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_shape = ir_amax_type.shape
ir_amax_dtype = ir_amax_type.element_type
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_types = [
ir.RankedTensorType.get(ir_out_shape, ir_out_dtype),
ir.RankedTensorType.get((1,), ir_amax_dtype),
]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(in_aval.shape,
jax_dtype_to_te_dtype(in_aval.dtype),
out_dtype)
out = custom_caller(QuantizePrimitive.name,
args,
opaque,
False,
operand_output_aliases={1: 1})
return out
_quantize_p = register_primitive(QuantizePrimitive)
def quantize(inputs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
out_dtype: TEDType) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
quantize wrapper
Return the FP8 tensor and the updated amax
"""
return _quantize_p.bind(inputs, amax, scale, scale_inv, out_dtype=out_dtype)
class DequantizePrimitive(BasePrimitive):
"""
Dequantize Primitive
"""
name = "te_dequantize"
multiple_results = False
@staticmethod
def abstract(inputs, amax, scale, scale_inv, *, fp8_dtype, out_dtype):
"""
te_dequantize abstract
"""
in_dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert in_dtype == jnp.int8
assert isinstance(fp8_dtype, TEDType)
assert isinstance(out_dtype, TEDType)
out_dtype = te_dtype_to_jax_dtype(out_dtype)
assert out_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert amax.dtype == jnp.float32
assert scale.dtype == jnp.float32
assert scale_inv.dtype == jnp.float32
return ShapedArray(inputs.shape, out_dtype, named_shape=inputs.named_shape)
@staticmethod
def lowering(ctx, inputs, amax, scale, scale_inv, *, fp8_dtype, out_dtype):
"""
te_dequantize lowering rules
"""
in_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert in_aval.dtype == jnp.int8
assert amax_aval.dtype == jnp.float32
assert scale_aval.dtype == jnp.float32
assert scale_inv_aval.dtype == jnp.float32
ir_in_type = ir.RankedTensorType(inputs.type)
ir_in_shape = ir_in_type.shape
ir_out_dtype = te_dtype_to_ir_dtype(out_dtype)
ir_out_shape = ir_in_shape
ir_amax_type = ir.RankedTensorType(amax.type)
ir_amax_shape = ir_amax_type.shape
ir_scale_shape = ir_amax_shape
ir_scale_inv_shape = ir_amax_shape
out_types = [ir.RankedTensorType.get(ir_out_shape, ir_out_dtype)]
operands = [inputs, amax, scale, scale_inv]
operand_shapes = [ir_in_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_common_descriptor(in_aval.shape, fp8_dtype, out_dtype)
out = custom_caller(DequantizePrimitive.name, args, opaque, False)
return [out]
_dequantize_p = register_primitive(DequantizePrimitive)
def dequantize(inputs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray,
fp8_dtype: TEDType, out_dtype: TEDType) -> jnp.ndarray:
"""
dequantize wrapper
Return FP16/BF16/FP32 tensor
"""
return _dequantize_p.bind(inputs,
amax,
scale,
scale_inv,
fp8_dtype=fp8_dtype,
out_dtype=out_dtype)
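# Round-trip sketch (illustrative; the FP8 dtype member name is an assumption):
#
#     q, updated_amax = quantize(x, amax, scale, scale_inv,
#                                out_dtype=TEDType.kFloat8E4M3)
#     x_approx = dequantize(q, updated_amax, scale, scale_inv,
#                           fp8_dtype=TEDType.kFloat8E4M3,
#                           out_dtype=TEDType.kFloat32)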
class SoftmaxPrimitive(BasePrimitive):
"""
Softmax Primitive
"""
max_k_seqlen_supported = 4096
@staticmethod
def get_batch_per_block(k_seqlen: int) -> int:
"""Get batch per CTA in Softmax kernels"""
threads_per_warp = 32
threads_per_block = 128 # Depends on the kernel implementation
pow2 = 1 << (k_seqlen - 1).bit_length()
warp_size = pow2 if pow2 < threads_per_warp else threads_per_warp
batches_per_warp = 2 if pow2 <= 128 else 1
warps_per_block = threads_per_block // warp_size
batches_per_block = warps_per_block * batches_per_warp
return batches_per_block
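# Worked example (illustrative): for k_seqlen = 1024, pow2 = 1024, so
# warp_size = 32, batches_per_warp = 1, warps_per_block = 128 // 32 = 4 and
# get_batch_per_block returns 4 batches per block.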
@staticmethod
def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
dtype: jnp.dtype) -> bool:
"""Check Softmax kernel availability based on size"""
raise NotImplementedError
@staticmethod
def softmax_backward_abstract(grad_outputs, softmax_outputs, scale_factor=None): # pylint: disable=unused-argument
"""
softmax backward abstract
"""
grad_outputs_dtype = dtypes.canonicalize_dtype(grad_outputs.dtype)
softmax_outputs_dtype = dtypes.canonicalize_dtype(softmax_outputs.dtype)
assert grad_outputs_dtype == softmax_outputs_dtype
assert grad_outputs_dtype in [jnp.float16, jnp.bfloat16]
assert softmax_outputs_dtype in [jnp.float16, jnp.bfloat16]
assert grad_outputs.shape == softmax_outputs.shape
return ShapedArray(softmax_outputs.shape,
softmax_outputs_dtype,
named_shape=softmax_outputs.named_shape)
@staticmethod
def softmax_backward_lowering(name, ctx, grad_outputs, softmax_outputs, scale_factor):
"""
softmax backward lowering rules
"""
grad_outputs_aval, _ = ctx.avals_in
grad_outputs_type = ir.RankedTensorType(grad_outputs.type)
grad_outputs_shape = grad_outputs_type.shape
batch = grad_outputs_shape[0]
pad_batch = batch # unused
heads = grad_outputs_shape[1]
q_seqlen = grad_outputs_shape[2]
k_seqlen = grad_outputs_shape[3]
softmax_outputs_type = ir.RankedTensorType(softmax_outputs.type)
softmax_outputs_shape = softmax_outputs_type.shape
out_types = [
ir.RankedTensorType.get(softmax_outputs_shape, softmax_outputs_type.element_type)
]
operands = [grad_outputs, softmax_outputs]
operand_shapes = [grad_outputs_shape, softmax_outputs_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(
batch, pad_batch, heads, q_seqlen, k_seqlen,
jax_dtype_to_te_dtype(grad_outputs_aval.dtype), scale_factor)
out = custom_caller(name, args, opaque, False)
return [out]
class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""
Scaled Softmax Fwd Primitive
"""
name = "te_scaled_softmax_forward"
multiple_results = False
@staticmethod
def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
dtype: jnp.dtype) -> bool:
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be in (16, 4096]
and q_seqlen % 4 == 0 # q_seqlen must be divisible by 4
and attn_batches % 4 == 0 # batch * heads must be divisible by 4
):
if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
return q_seqlen % batch_per_block == 0
return False
@staticmethod
def abstract(inputs, *, scale_factor): # pylint: disable=unused-argument
"""
te_scaled_softmax_forward abstract
"""
shape_rank = 4 # batch, heads, q_seqlen and k_seqlen
i_dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert i_dtype in [jnp.float16, jnp.bfloat16]
i_shape = inputs.shape
assert len(i_shape) == shape_rank
q_seqlen = i_shape[2]
k_seqlen = i_shape[3]
assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
assert q_seqlen > 1
return ShapedArray(inputs.shape, i_dtype, named_shape=inputs.named_shape)
@staticmethod
def lowering(ctx, inputs, *, scale_factor):
"""
te_scaled_softmax_forward lowering rules
"""
shape_rank = 4 # batch, heads, q_seqlen and k_seqlen
i_aval, = ctx.avals_in
i_type = ir.RankedTensorType(inputs.type)
i_shape = i_type.shape
assert len(i_shape) == shape_rank
batch = i_shape[0]
pad_batch = batch
heads = i_shape[1]
q_seqlen = i_shape[2]
k_seqlen = i_shape[3]
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [inputs]
operand_shapes = [i_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(i_aval.dtype),
scale_factor)
out = custom_caller(ScaledSoftmaxFwdPrimitive.name, args, opaque, False)
return [out]
_scaled_softmax_fwd_p = register_primitive(ScaledSoftmaxFwdPrimitive)
def scaled_softmax_fwd(inputs: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
"""
scaled_softmax_forward wrapper
Return FP16/BF16 tensor
"""
return _scaled_softmax_fwd_p.bind(inputs, scale_factor=scale_factor)
class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
"""
Scaled Softmax Bwd Primitive
"""
name = "te_scaled_softmax_backward"
multiple_results = False
@staticmethod
def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
dtype: jnp.dtype) -> bool:
"""Check Softmax kernel availability based on size"""
return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
dtype)
@staticmethod
def abstract(grad_outputs, softmax_outputs, *, scale_factor):
"""
te_scaled_softmax_backward abstract
"""
return SoftmaxPrimitive.softmax_backward_abstract(grad_outputs, softmax_outputs,
scale_factor)
@staticmethod
def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor):
"""
te_scaled_softmax_backward lowering rules
"""
out = SoftmaxPrimitive.softmax_backward_lowering(ScaledSoftmaxBwdPrimitive.name, ctx,
grad_outputs, softmax_outputs,
scale_factor)
return out
_scaled_softmax_bwd_p = register_primitive(ScaledSoftmaxBwdPrimitive)
def scaled_softmax_bwd(grad_outputs: jnp.ndarray, softmax_outputs: jnp.ndarray,
scale_factor: float) -> jnp.ndarray:
"""
scaled_softmax_backward wrapper
Return FP16/BF16 tensor
"""
return _scaled_softmax_bwd_p.bind(grad_outputs, softmax_outputs, scale_factor=scale_factor)
class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""
Scaled Masked Softmax Fwd Primitive
"""
name = "te_scaled_masked_softmax_forward"
multiple_results = False
@staticmethod
def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
dtype: jnp.dtype) -> bool:
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be in (16, 4096]
and q_seqlen % 4 == 0 # q_seqlen must be divisible by 4
and attn_batches % 4 == 0 # batch * heads must be divisible by 4
):
if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
return q_seqlen % batch_per_block == 0
return False
@staticmethod
def abstract(inputs, mask, *, scale_factor): # pylint: disable=unused-argument
"""
te_scaled_masked_softmax_forward abstract
"""
shape_rank = 4 # batch, heads, q_seqlen and k_seqlen
i_dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert i_dtype in [jnp.float16, jnp.bfloat16]
i_shape = inputs.shape
assert len(i_shape) == shape_rank
batch = i_shape[0]
q_seqlen = i_shape[2]
k_seqlen = i_shape[3]
assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
assert q_seqlen > 1
mask_dtype = dtypes.canonicalize_dtype(mask.dtype)
assert mask_dtype in [
jnp.uint8,
]
mask_shape = mask.shape
assert len(mask_shape) == shape_rank
pad_batch = mask_shape[0]
assert pad_batch in (1, batch) # 1 means broadcast
assert mask_shape[1] == 1 # 1 means broadcast
assert mask_shape[2] == q_seqlen
assert mask_shape[3] == k_seqlen
return ShapedArray(inputs.shape, i_dtype, named_shape=inputs.named_shape)
@staticmethod
def lowering(ctx, inputs, mask, *, scale_factor):
"""
te_scaled_masked_softmax_forward lowering rules
"""
shape_rank = 4 # batch, heads, q_seqlen and k_seqlen
i_aval, _ = ctx.avals_in
i_type = ir.RankedTensorType(inputs.type)
i_shape = i_type.shape
assert len(i_shape) == shape_rank
batch = i_shape[0]
heads = i_shape[1]
q_seqlen = i_shape[2]
k_seqlen = i_shape[3]
mask_type = ir.RankedTensorType(mask.type)
mask_shape = mask_type.shape
assert len(mask_shape) == shape_rank
pad_batch = mask_shape[0]
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [inputs, mask]
operand_shapes = [i_shape, mask_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(i_aval.dtype),
scale_factor)
out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
return [out]
_scaled_masked_softmax_fwd_p = register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
def scaled_masked_softmax_fwd(inputs: jnp.ndarray, mask: jnp.ndarray,
scale_factor: float) -> jnp.ndarray:
"""
scaled_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
return _scaled_masked_softmax_fwd_p.bind(inputs, mask, scale_factor=scale_factor)
class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
"""
Scaled Masked Softmax Bwd Primitive
"""
name = "te_scaled_masked_softmax_backward"
multiple_results = False
@staticmethod
def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
dtype: jnp.dtype) -> bool:
"""Check Softmax kernel availability based on size"""
return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
dtype)
@staticmethod
def abstract(grad_outputs, softmax_outputs, *, scale_factor):
"""
te_scaled_masked_softmax_backward abstract
"""
return SoftmaxPrimitive.softmax_backward_abstract(grad_outputs, softmax_outputs,
scale_factor)
@staticmethod
def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor):
"""
te_scaled_masked_softmax_backward lowering rules
"""
out = SoftmaxPrimitive.softmax_backward_lowering(ScaledMaskedSoftmaxBwdPrimitive.name, ctx,
grad_outputs, softmax_outputs,
scale_factor)
return out
_scaled_masked_softmax_bwd_p = register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
def scaled_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_outputs: jnp.ndarray,
scale_factor: float) -> jnp.ndarray:
"""
scaled_masked_softmax_backward wrapper
Return FP16/BF16 tensor
"""
return _scaled_masked_softmax_bwd_p.bind(grad_outputs,
softmax_outputs,
scale_factor=scale_factor)
class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""
Scaled Upper Triang Masked Softmax Fwd Primitive
"""
name = "te_scaled_upper_triang_masked_softmax_forward"
multiple_results = False
@staticmethod
def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
dtype: jnp.dtype) -> bool:
"""Check Softmax kernel availability based on size"""
attn_batches = batch * heads
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be in (16, 4096]
and q_seqlen % 4 == 0 # q_seqlen must be divisible by 4
and attn_batches % 4 == 0 # batch * heads must be divisible by 4
):
if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
return attn_batches % batch_per_block == 0
return False
@staticmethod
def abstract(inputs, *, scale_factor): # pylint: disable=unused-argument
"""
te_scaled_upper_triang_masked_softmax_forward abstract
"""
shape_rank = 4 # batch, heads, q_seqlen and k_seqlen
i_dtype = dtypes.canonicalize_dtype(inputs.dtype)
assert i_dtype in [jnp.float16, jnp.bfloat16]
i_shape = inputs.shape
assert len(i_shape) == shape_rank
q_seqlen = i_shape[2]
k_seqlen = i_shape[3]
assert q_seqlen == k_seqlen
assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
assert q_seqlen > 1
return ShapedArray(inputs.shape, i_dtype, named_shape=inputs.named_shape)
@staticmethod
def lowering(ctx, inputs, *, scale_factor):
"""
te_scaled_upper_triang_masked_softmax_forward lowering rules
"""
shape_rank = 4 # batch, heads, q_seqlen and k_seqlen
i_aval, = ctx.avals_in
i_type = ir.RankedTensorType(inputs.type)
i_shape = i_type.shape
assert len(i_shape) == shape_rank
batch = i_shape[0]
pad_batch = batch
heads = i_shape[1]
q_seqlen = i_shape[2]
k_seqlen = i_shape[3]
out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
operands = [inputs]
operand_shapes = [i_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
opaque = transformer_engine_jax.pack_softmax_descriptor(batch, pad_batch, heads, q_seqlen,
k_seqlen,
jax_dtype_to_te_dtype(i_aval.dtype),
scale_factor)
out = custom_caller(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
return [out]
_scaled_upper_triang_masked_softmax_fwd_p = \
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
def scaled_upper_triang_masked_softmax_fwd(inputs: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
"""
scaled_upper_triang_masked_softmax_forward wrapper
Return FP16/BF16 tensor
"""
return _scaled_upper_triang_masked_softmax_fwd_p.bind(inputs, scale_factor=scale_factor)
class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
"""
Scaled Upper Triang Masked Softmax Bwd Primitive
"""
name = "te_scaled_upper_triang_masked_softmax_backward"
multiple_results = False
@staticmethod
def is_kernel_available(batch: int, heads: int, q_seqlen: int, k_seqlen: int,
dtype: jnp.dtype) -> bool:
"""Check Softmax kernel availability based on size"""
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
batch, heads, q_seqlen, k_seqlen, dtype)
@staticmethod
def abstract(grad_outputs, softmax_outputs, *, scale_factor):
"""
te_scaled_upper_triang_masked_softmax_backward abstract
"""
return SoftmaxPrimitive.softmax_backward_abstract(grad_outputs, softmax_outputs,
scale_factor)
@staticmethod
def lowering(ctx, grad_outputs, softmax_outputs, *, scale_factor):
"""
te_scaled_upper_triang_masked_softmax_backward lowering rules
"""
out = SoftmaxPrimitive.softmax_backward_lowering(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name, ctx, grad_outputs, softmax_outputs,
scale_factor)
return [out]
_scaled_upper_triang_masked_softmax_bwd_p = \
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def scaled_upper_triang_masked_softmax_bwd(grad_outputs: jnp.ndarray, softmax_outputs: jnp.ndarray,
scale_factor: float) -> jnp.ndarray:
"""
scaled_upper_triang_masked_softmax_backward wrapper
Return FP16/BF16 tensor
"""
return _scaled_upper_triang_masked_softmax_bwd_p.bind(grad_outputs,
softmax_outputs,
scale_factor=scale_factor)
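# A minimal usage sketch for the causal (upper-triangular masked) softmax wrappers above.
# q_seqlen must equal k_seqlen; the concrete sizes and dtype are illustrative assumptions.
def _example_scaled_upper_triang_masked_softmax():
    batch, heads, seqlen = 2, 8, 128
    logits = jnp.zeros((batch, heads, seqlen, seqlen), dtype=jnp.bfloat16)
    probs = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor=1.0)
    dlogits = scaled_upper_triang_masked_softmax_bwd(jnp.ones_like(probs), probs,
                                                     scale_factor=1.0)
    return probs, dlogits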
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h"
#include "jax/csrc/utils.h"
namespace transformer_engine {
namespace jax {
template <typename T>
pybind11::capsule EncapsulateFunction(T *fn) {
return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}
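// Registrations() exposes each custom-call implementation as a PyCapsule keyed by its
// XLA custom-call target name. A sketch of how the dictionary is typically consumed on
// the Python side (assuming jax.lib.xla_client is available):
//
//   for name, fn in transformer_engine_jax.registrations().items():
//       xla_client.register_custom_call_target(name, fn, platform="CUDA")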
pybind11::dict Registrations() {
pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
dict["te_gemm"] = EncapsulateFunction(Gemm);
dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
dict["te_rmsnorm_forward"] = EncapsulateFunction(RMSNormForward);
dict["te_rmsnorm_forward_fp8"] = EncapsulateFunction(RMSNormForwardFP8);
dict["te_rmsnorm_backward"] = EncapsulateFunction(RMSNormBackward);
dict["te_quantize"] = EncapsulateFunction(Quantize);
dict["te_dequantize"] = EncapsulateFunction(Dequantize);
dict["te_scaled_softmax_forward"] = EncapsulateFunction(ScaledSoftmaxForward);
dict["te_scaled_softmax_backward"] = EncapsulateFunction(ScaledSoftmaxBackward);
dict["te_scaled_masked_softmax_forward"] = EncapsulateFunction(ScaledMaskedSoftmaxForward);
dict["te_scaled_masked_softmax_backward"] = EncapsulateFunction(ScaledMaskedSoftmaxBackward);
dict["te_scaled_upper_triang_masked_softmax_forward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
return dict;
}
PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
pybind11::enum_<DType>(m, "DType")
.value("kByte", DType::kByte)
.value("kInt32", DType::kInt32)
.value("kFloat32", DType::kFloat32)
.value("kFloat16", DType::kFloat16)
.value("kBFloat16", DType::kBFloat16)
.value("kFloat8E4M3", DType::kFloat8E4M3)
.value("kFloat8E5M2", DType::kFloat8E5M2);
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/modules.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>
#include "common/common.h"
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/gemm.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
#include "transformer_engine/softmax.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
#include "utils.h"
namespace transformer_engine {
namespace jax {
constexpr size_t kCublasLtForwardWorkspaceSize = 32 * 1024 * 1024;
constexpr size_t kCublasLtBackwardWorkspaceSize = 32 * 1024 * 1024;
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
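// Descriptors are plain structs: PackOpaque serializes one into the byte string that the
// Python side attaches to the XLA custom call, and UnpackOpaque reinterprets the received
// opaque buffer back into the struct inside each custom-call implementation.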
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
return pybind11::bytes(str);
}
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
if (opaque_len != sizeof(T)) {
throw std::runtime_error("Invalid opaque object size");
}
return reinterpret_cast<const T *>(opaque);
}
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype) {
CustomCallCommonDescriptor desc;
desc.shape.from_vector(shape);
desc.in_dtype = in_dtype;
desc.out_dtype = out_dtype;
return PackOpaque(desc);
}
pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
DType B_dtype, DType D_dtype, bool transa, bool transb,
bool use_split_accumulator) {
return PackOpaque(CustomCallGemmDescriptor{m, n, k, A_dtype, B_dtype, D_dtype, transa, transb,
use_split_accumulator});
}
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
float eps) {
return PackOpaque(CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, eps});
}
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
size_t q_seqlen, size_t k_seqlen, DType dtype,
float scale_factor) {
return PackOpaque(
SoftmaxDescriptor{batch, pad_batch, heads, q_seqlen, k_seqlen, dtype, scale_factor});
}
void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
void *output) {
auto input_shape = std::vector<size_t>{rows, cols};
auto output_shape = std::vector<size_t>{cols, rows};
auto input_tensor = TensorWrapper(input, input_shape, dtype);
auto transposed_tensor = TensorWrapper(output, output_shape, dtype);
nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
}
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
void *input = buffers[0];
void *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto rows = desc.shape.dims[0];
auto cols = desc.shape.dims[1];
assert(desc.in_dtype == desc.out_dtype);
auto dtype = desc.out_dtype;
TransposeImpl(input, rows, cols, dtype, stream, output);
}
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *input_cast = buffers[4];
auto *input_cast_trans = buffers[5];
float *amax_out = reinterpret_cast<float *>(buffers[6]);
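    // buffers[6] is the amax output; it is expected to alias the amax input so that the
    // in-place update performed by the kernel is visible to the caller.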
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto input_trans_shape = std::vector<size_t>{n, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto input_cast_tensor =
TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape,
desc.out_dtype, amax_out, scale, scale_inv);
nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(),
input_cast_trans_tensor.data(), stream);
}
void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
cudaStream_t stream, float *scale_inverse, float *amax, void *output) {
auto input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n};
auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
scale, scale_inverse);
nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
}
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr,
output);
}
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
float *amax = reinterpret_cast<float *>(buffers[1]);
float *scale = reinterpret_cast<float *>(buffers[2]);
float *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
float *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
GatedGeluImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
output);
}
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = std::vector<size_t>{m, n};
auto gelu_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);
nvte_dgeglu(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(), stream);
}
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *gelu_input = buffers[1];
float *amax = reinterpret_cast<float *>(buffers[2]);
float *scale = reinterpret_cast<float *>(buffers[3]);
float *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *output_trans = buffers[6];
float *amax_out = reinterpret_cast<float *>(buffers[7]);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
assert(amax == amax_out);
if (!use_fp8(desc.out_dtype)) {
scale = nullptr;
scale_inv = nullptr;
amax_out = nullptr;
}
auto m = desc.shape.dims[0];
auto n = desc.shape.dims[1];
auto input_shape = desc.shape.to_vector();
auto gelu_input_shape = std::vector<size_t>{m, n * 2};
auto output_shape = std::vector<size_t>{m, n * 2};
auto output_trans_shape = std::vector<size_t>{n * 2, m};
auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
auto gelu_input_tensor = TensorWrapper(gelu_input, gelu_input_shape, desc.in_dtype);
auto output_tensor =
TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
auto output_trans_tensor =
TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
nvte_dgeglu_cast_transpose(input_tensor.data(), gelu_input_tensor.data(), output_tensor.data(),
output_trans_tensor.data(), stream);
}
void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *A = buffers[0];
auto *B = buffers[1];
auto *A_scale_inverse = reinterpret_cast<float *>(buffers[2]);
auto *B_scale_inverse = reinterpret_cast<float *>(buffers[3]);
auto *D = buffers[4];
    // The shapes of A, B and D are passed transposed here: cuBLASLt assumes column-major
    // storage, so interpreting our row-major buffers as column-major is equivalent to
    // operating on the transposed matrices.
const auto &desc = *UnpackOpaque<CustomCallGemmDescriptor>(opaque, opaque_len);
auto m = desc.m;
auto n = desc.n;
auto k = desc.k;
auto A_shape = std::vector<size_t>{k, m};
auto A_tensor = TensorWrapper(A, A_shape, desc.A_dtype, nullptr, nullptr, A_scale_inverse);
auto B_shape = std::vector<size_t>{n, k};
auto B_tensor = TensorWrapper(B, B_shape, desc.B_dtype, nullptr, nullptr, B_scale_inverse);
auto D_shape = std::vector<size_t>{n, m};
auto D_tensor = TensorWrapper(D, D_shape, desc.D_dtype);
auto null_tensor = TensorWrapper(nullptr, std::vector<size_t>{0}, DType::kFloat32);
size_t workspace_size = kCublasLtForwardWorkspaceSize;
auto *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto wk_tensor = TensorWrapper(workspace, std::vector<size_t>{workspace_size}, DType::kByte);
nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(),
null_tensor.data(), (desc.transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(desc.transb) ? CUBLAS_OP_T : CUBLAS_OP_N, false, wk_tensor.data(), false,
desc.use_split_accumulator, stream);
}
void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, void *weight,
DType w_dtype, void *bias, float eps, void *output, DType out_dtype,
void *mu, void *rsigma, float *amax, float *scale, float *scale_inv,
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n};
auto is_layer_norm = (bias) ? true : false;
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto gamma_tensor = TensorWrapper(weight, weight_shape, in_dtype);
// assume output dtype = input dtype
// If we need mixed I/O precision in the future, we need an additional
// parameter for output type
auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
    // Create uninitialized workspace and barrier tensors; their sizes are determined by the
    // query call below.
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
// The first call is to query the required workspace
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
} else {
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), stream, num_sm, dummy_workspace_tensor.data(),
dummy_barrier_tensor.data());
}
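    // After the query call above, the dummy tensors' shapes hold the required workspace and
    // barrier sizes, which are carved out of a single allocation below.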
size_t workspace_size =
dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype()) +
dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());
void *workspace = cublasLtMetaManager::Instance().GetWorkspace(workspace_size);
auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
auto barrier_tensor =
TensorWrapper(reinterpret_cast<char *>(workspace) + dummy_workspace_tensor.shape().data[0],
dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype());
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, workspace_tensor.data(), barrier_tensor.data());
} else {
nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(),
rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
}
}
void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, void *weight,
DType w_dtype, void *ograd, void *mu, void *rsigma, float eps,
void *xgrad, void *wgrad, void *dbeta, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n};
auto intermediates_dtype = DType::kFloat32;
auto is_layer_norm = (dbeta) ? true : false;
// assume input type = output type
auto *grad_output = ograd;
auto x_dtype = in_dtype;
auto dz_tensor = TensorWrapper(grad_output, input_shape, x_dtype);
auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, intermediates_dtype);
auto *x = input;
auto x_tensor = TensorWrapper(x, input_shape, x_dtype);
auto gamma_tensor = TensorWrapper(weight, weight_shape, w_dtype);
auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
size_t dbeta_part_size{};
// The first call is to query the workspace
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), stream,
num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
dbeta_part_size = dummy_dbeta_part_tensor.shape().data[0] *
dummy_dbeta_part_tensor.shape().data[1] *
typeToSize(dummy_dbeta_part_tensor.dtype());
} else {
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dummy_dgamma_part_tensor.data(), stream, num_sm,
dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
}
size_t workspace_size =
dummy_workspace_tensor.shape().data[0] * typeToSize(dummy_workspace_tensor.dtype());
size_t barrier_size =
dummy_barrier_tensor.shape().data[0] * typeToSize(dummy_barrier_tensor.dtype());
size_t dgamma_part_size = dummy_dgamma_part_tensor.shape().data[0] *
dummy_dgamma_part_tensor.shape().data[1] *
typeToSize(dummy_dgamma_part_tensor.dtype());
size_t total_workspace_size =
(workspace_size + barrier_size + dgamma_part_size + dbeta_part_size);
void *workspace = cublasLtMetaManager::Instance().GetWorkspace(total_workspace_size);
void *barrier = static_cast<char *>(workspace) + workspace_size;
void *dgamma_part = static_cast<char *>(barrier) + barrier_size;
void *dbeta_part = static_cast<char *>(dgamma_part) + dgamma_part_size;
auto workspace_tensor =
TensorWrapper(workspace, dummy_workspace_tensor.shape(), dummy_workspace_tensor.dtype());
auto barrier_tensor =
TensorWrapper(barrier, dummy_barrier_tensor.shape(), dummy_barrier_tensor.dtype());
auto dgamma_part_tensor = TensorWrapper(dgamma_part, dummy_dgamma_part_tensor.shape(),
dummy_dgamma_part_tensor.dtype());
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dummy_dbeta_part_tensor.shape(),
dummy_dbeta_part_tensor.dtype());
nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
} else {
nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(),
gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(),
dgamma_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
barrier_tensor.data());
}
}
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
auto *bias = buffers[2];
auto *amax = reinterpret_cast<float *>(buffers[3]);
auto *scale = reinterpret_cast<float *>(buffers[4]);
auto *scale_inv = reinterpret_cast<float *>(buffers[5]);
auto *output = buffers[6];
auto *mu = buffers[7];
auto *rsigma = buffers[8];
auto *amax_out = buffers[9];
assert(amax_out == amax);
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
}
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
auto *bias = buffers[2];
auto *output = buffers[3];
auto *mu = buffers[4];
auto *rsigma = buffers[5];
float *amax = nullptr;
float *scale = nullptr;
float *scale_inv = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto out_dtype = in_dtype;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
}
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto *ograd = buffers[0];
auto *mu = buffers[1];
auto *rsigma = buffers[2];
auto *input = buffers[3];
auto *weight = buffers[4];
auto *xgrad = buffers[5];
auto *wgrad = buffers[6];
auto *dbeta = buffers[7];
LayerNormBackwardImpl(n, hidden, input, in_dtype, weight, w_dtype, ograd, mu, rsigma, eps,
xgrad, wgrad, dbeta, stream);
}
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
auto *amax = reinterpret_cast<float *>(buffers[2]);
auto *scale = reinterpret_cast<float *>(buffers[3]);
auto *scale_inv = reinterpret_cast<float *>(buffers[4]);
auto *output = buffers[5];
auto *rsigma = buffers[6];
auto *amax_out = buffers[7];
assert(amax_out == amax);
void *bias = nullptr;
void *mu = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
}
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
auto *output = buffers[2];
auto *rsigma = buffers[3];
void *bias = nullptr;
void *mu = nullptr;
float *amax = nullptr;
float *scale = nullptr;
float *scale_inv = nullptr;
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto out_dtype = in_dtype;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
}
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *ograd = buffers[0];
auto *rsigma = buffers[1];
auto *input = buffers[2];
auto *weight = buffers[3];
auto *xgrad = buffers[4];
auto *wgrad = buffers[5];
const auto &desc = *UnpackOpaque<CustomCallNormDescriptor>(opaque, opaque_len);
auto n = desc.n;
auto hidden = desc.hidden;
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
void *mu = nullptr;
void *dbeta = nullptr;
LayerNormBackwardImpl(n, hidden, input, in_dtype, weight, w_dtype, ograd, mu, rsigma, eps,
xgrad, wgrad, dbeta, stream);
}
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *amax = reinterpret_cast<float *>(buffers[1]);
auto *scale = reinterpret_cast<float *>(buffers[2]);
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
auto *amax_out = reinterpret_cast<float *>(buffers[5]);
assert(amax == amax_out);
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv);
nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *amax = reinterpret_cast<float *>(buffers[1]);
auto *scale = reinterpret_cast<float *>(buffers[2]);
auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
auto *output = buffers[4];
const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
auto shape = desc.shape.to_vector();
auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);
nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
}
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), desc.scale_factor,
stream);
}
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *grad_output = buffers[0];
auto *softmax_output = buffers[1];
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
dgrad_tensor.data(), desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *mask = buffers[1];
auto *output = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto io_shape = std::vector<size_t>{desc.batch, desc.heads, desc.q_seqlen, desc.k_seqlen};
auto mask_shape = std::vector<size_t>{desc.pad_batch, 1, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, io_shape, dtype);
    // The mask is interpreted as uint8_t
auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);
auto output_tensor = TensorWrapper(output, io_shape, dtype);
nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(),
output_tensor.data(), desc.scale_factor, stream);
}
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
    // The backward of ScaledMaskedSoftmax is equivalent to that of ScaledSoftmax:
    // masked positions have zero probability in the forward output, so they get zero gradient.
ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
auto *output = buffers[1];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch * desc.heads;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto input_tensor = TensorWrapper(input, shape, dtype);
auto output_tensor = TensorWrapper(output, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
desc.scale_factor, stream);
}
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *grad_output = buffers[0];
auto *softmax_output = buffers[1];
auto *dgrad = buffers[2];
const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
auto attn_batch = desc.batch * desc.heads;
auto shape = std::vector<size_t>{attn_batch, desc.q_seqlen, desc.k_seqlen};
auto dtype = desc.dtype;
auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);
auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype);
auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);
nvte_scaled_upper_triang_masked_softmax_backward(
grad_output_tensor.data(), softmax_output_tensor.data(), dgrad_tensor.data(),
desc.scale_factor, stream);
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#include <cuda_runtime_api.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "transformer_engine/logging.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace jax {
constexpr int kMaxNumDim = 8;
struct Shape {
int num_dim;
size_t dims[kMaxNumDim];
void from_vector(const std::vector<size_t> &shape) {
num_dim = shape.size();
assert(num_dim <= kMaxNumDim);
std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
}
std::vector<size_t> to_vector() const {
assert(num_dim <= kMaxNumDim);
std::vector<size_t> shape(num_dim);
std::memcpy(shape.data(), dims, num_dim * sizeof(size_t));
return shape;
}
};
struct CustomCallCommonDescriptor {
Shape shape;
DType in_dtype;
DType out_dtype;
};
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype);
struct CustomCallGemmDescriptor {
size_t m;
size_t n;
size_t k;
DType A_dtype;
DType B_dtype;
DType D_dtype;
bool transa;
bool transb;
bool use_split_accumulator;
};
pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
DType B_dtype, DType D_dtype, bool transa, bool transb,
bool use_split_accumulator);
struct CustomCallNormDescriptor {
size_t n;
size_t hidden;
DType x_dtype;
DType w_dtype;
float eps;
};
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
float eps);
struct SoftmaxDescriptor {
size_t batch;
size_t pad_batch;
size_t heads;
size_t q_seqlen;
size_t k_seqlen;
DType dtype;
float scale_factor;
};
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
size_t q_seqlen, size_t k_seqlen, DType dtype,
float scale_factor);
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#include <pybind11/pybind11.h>
#include <cuda_runtime_api.h>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "transformer_engine/logging.h"
namespace transformer_engine {
namespace jax {
class cublasLtMetaManager {
public:
static cublasLtMetaManager &Instance() {
static thread_local cublasLtMetaManager instance;
return instance;
}
cublasLtMetaManager() {}
~cublasLtMetaManager() { Clear_(); }
void *GetWorkspace(size_t size = 4194304) {
ReallocateIfNeed_(size);
return workspace_;
}
private:
void *workspace_ = nullptr;
size_t size_ = 0;
void Clear_() {
if (workspace_ != nullptr) {
NVTE_CHECK_CUDA(cudaFree(workspace_));
}
workspace_ = nullptr;
size_ = 0;
}
void Allocate_(size_t new_size) {
NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
size_ = new_size;
}
void ReallocateIfNeed_(size_t new_size) {
if (new_size > size_) {
Clear_();
Allocate_(new_size);
}
}
};
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
static thread_local cudaDevicePropertiesManager instance;
return instance;
}
int GetMultiProcessorCount() {
if (!prop_queried_) {
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
            NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop_, device_id));
prop_queried_ = true;
}
return prop_.multiProcessorCount;
}
private:
bool prop_queried_ = false;
cudaDeviceProp prop_;
};
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence
from functools import partial, reduce
import operator
import jax
import jax.numpy as jnp
from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0) -> jnp.ndarray:
"""
FP8 dot wrapper
"""
assert fp8_gemm_pkg.num_of_gemm == 1
inputs = fp8_gemm_pkg.inputs
kernel = fp8_gemm_pkg.kernels[0]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
if sharding_type is ShardingType.SINGLE:
res = _fp8_dot(inputs,
kernel,
fp8_max,
amax,
scale,
scale_inv,
amax_history_idx=amax_history_idx,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
kernel_tp_index = None
# TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
kernel_tp_index = len(kernel.shape) - 1
elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
kernel_tp_index = 0
input_tp_index = len(inputs.shape) - 1
sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name)
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1]) # 1 for kernel
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
dp_axis_name, tp_axis_name)
axis_resources = merge_axis_resources(
[sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
partial_fp8_dot = partial(_fp8_dot,
amax_history_idx=amax_history_idx,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
sharding_meta.out_axes, axis_resources,
(inputs_, kernel_, fp8_max, amax, scale, scale_inv))
res = jnp.reshape(res, sharding_meta.output_shapes[0])
return res
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, amax_history_idx: int, fwd_dtype: TEDType,
bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
sharding_type: ShardingType, dp_axis_name: str, tp_axis_name: str):
res, _ = _fp8_dot_fwd(inputs,
kernel,
fp8_maxs,
amax,
scale,
scale_inv,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
return res
def _fp8_dot_fwd(
inputs,
kernel,
fp8_maxs,
amax,
scale,
scale_inv,
amax_history_idx, # pylint: disable=unused-argument
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
input_contracting_size = reduce(operator.mul, input_shape_suf)
kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
assert input_contracting_size == kernel_contracting_size
inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx]
input_scale = scale[gemm_input_idx]
input_scale_inv = scale_inv[gemm_input_idx]
input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale,
input_scale_inv, fwd_dtype)
kernel_amax = amax[gemm_kernel_idx]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
kernel_scale_inv, fwd_dtype)
res = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, input_cast, input_scale_inv,
fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
res = jax.lax.psum(res, tp_axis_name)
# (input_shape_pre, input_shape_suf)
# x (kernel_shape_pre, kernel_shape_suf)
# = (input_shape_pre, kernel_shape_suf)
output_shape = input_shape_pre + kernel_shape_suf
res = jnp.reshape(res, output_shape)
ctx = (input_cast_trans, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
inputs.shape, kernel.shape)
return res, ctx
def _fp8_dot_bwd(
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
sharding_type,
dp_axis_name,
tp_axis_name,
ctx,
g):
input_cast_trans, kernel_cast, \
fp8_maxs, amax, scale, scale_inv, \
input_amax, kernel_amax, \
inputs_shape, kernel_shape = ctx
gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
g = jnp.reshape(g, (input_cast_trans.shape[1], -1))
grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype)
input_scale_inv = scale_inv[gemm_input_idx]
wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype,
True, input_cast_trans, input_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
amax = amax.at[gemm_input_idx, amax_history_idx].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, amax_history_idx].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, amax_history_idx].set(grad_amax[0])
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name)
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
dgrad = jax.lax.psum(dgrad, tp_axis_name)
dgrad = jnp.reshape(dgrad, inputs_shape)
wgrad = jnp.reshape(wgrad, kernel_shape)
return dgrad, wgrad, fp8_maxs, amax, scale, scale_inv
_fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd)
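# A minimal usage sketch for fp8_dot with a single GEMM on a single device. The tensor
# shapes and the way the FP8 metas are initialized are illustrative assumptions; in a real
# model the metas live in a Flax collection managed through FP8Helper and fp8_autocast.
def _example_fp8_dot():
    num_meta = FP8Helper.NUM_META_PER_GEMM    # input, kernel and grad metas per GEMM
    inputs = jnp.zeros((128, 256), dtype=jnp.bfloat16)
    kernel = jnp.zeros((256, 512), dtype=jnp.bfloat16)
    pkg = FP8GemmPackage(
        num_of_gemm=1,
        inputs=inputs,
        kernels=[kernel],
        fp8_max=FP8Helper.generate_fp8_max_array(num_meta),
        amax=jnp.zeros((num_meta, FP8Helper.AMAX_HISTORY_SIZE), dtype=jnp.float32),
        scale=jnp.ones((num_meta, 1), dtype=jnp.float32),
        scale_inv=jnp.ones((num_meta, 1), dtype=jnp.float32),
    )
    return fp8_dot(pkg,
                   amax_history_idx=0,
                   fwd_dtype=TEDType.kFloat8E4M3,
                   bwd_dtype=TEDType.kFloat8E5M2)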
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
import os
from contextlib import contextmanager
from typing import Optional, Union, Dict, List, Tuple
from flax.core.frozen_dict import FrozenDict
import jax
import jax.numpy as jnp
from transformer_engine_jax import DType
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import ShardingResource
Collection = Union[Dict, FrozenDict]
def _format2dtypes(format_: Format):
if format_ == Format.E4M3:
return DType.kFloat8E4M3, DType.kFloat8E4M3
if format_ == Format.E5M2:
return DType.kFloat8E5M2, DType.kFloat8E5M2
if format_ == Format.HYBRID:
return DType.kFloat8E4M3, DType.kFloat8E5M2
return DType.kBFloat16, DType.kBFloat16
class FP8GemmPackage:
"""
A container that contains all required data for
FP8 GEMM
"""
def __init__(
self,
num_of_gemm: int,
inputs: jnp.ndarray,
kernels: List[jnp.ndarray],
fp8_max: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
) -> None:
self._num_of_gemm = num_of_gemm
self._inputs = inputs
assert len(kernels) == self._num_of_gemm
self._kernels = kernels
total_num_of_meta = self._num_of_gemm * FP8Helper.NUM_META_PER_GEMM
assert fp8_max.shape[0] == total_num_of_meta
self._fp8_max = fp8_max
assert amax.shape[0] == total_num_of_meta
self._amax = amax
assert scale.shape[0] == total_num_of_meta
self._scale = scale
assert scale_inv.shape[0] == total_num_of_meta
self._scale_inv = scale_inv
@property
def num_of_gemm(self) -> int:
"""
num_of_gemm of this package
"""
return self._num_of_gemm
@property
def inputs(self) -> jnp.ndarray:
"""
inputs of this package
"""
return self._inputs
@property
def kernels(self) -> List[jnp.ndarray]:
"""
kernels of this package
"""
return self._kernels
@property
def fp8_max(self) -> jnp.ndarray:
"""
fp8_max of this package
"""
return self._fp8_max
@property
def amax(self) -> jnp.ndarray:
"""
amax of this package
"""
return self._amax
@property
def scale(self) -> jnp.ndarray:
"""
scale of this package
"""
return self._scale
@property
def scale_inv(self) -> jnp.ndarray:
"""
scale_inv of this package
"""
return self._scale_inv
class FP8Helper:
"""
FP8 helper to manage the FP8 meta
"""
INITIALIZED = False
MARGIN: float = 0.0
FP8_FORMAT: Format = Format.HYBRID
FWD_DTYPE: DType = DType.kFloat8E4M3
BWD_DTYPE: DType = DType.kFloat8E5M2
UPDATE_FP8META_INTERVAL: int = 1
AMAX_HISTORY_SIZE: int = 1
NUM_META_PER_GEMM: int = 3
INPUT_META_IDX_PER_GEMM: int = 0
KERNEL_META_IDX_PER_GEMM: int = 1
GRAD_META_IDX_PER_GEMM: int = 2
FP8_COLLECTION_NAME: str = "fp8_meta_collection"
FP8_AMAX_NAME: str = "fp8_meta_amax"
FP8_SCALE_NAME: str = "fp8_meta_scale"
FP8_SCALE_INV_NAME: str = "fp8_meta_scale_inv"
FP8_MAX_NAME: str = "fp8_max"
FP8_2X_ACC_FPROP_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_FPROP"
FP8_2X_ACC_DGRAD_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_DGRAD"
FP8_2X_ACC_WGRAD_ENV_VAR_NAME = "NVTE_JAX_FP8_2X_ACC_WGRAD"
FP8_2X_ACC_FPROP: bool = False
FP8_2X_ACC_DGRAD: bool = False
FP8_2X_ACC_WGRAD: bool = False
@staticmethod
def enable_fp8():
"""
        Indicate whether FP8 training is enabled.
"""
return FP8Helper.INITIALIZED
@staticmethod
def initialize(margin: float = 0.0,
fp8_format: Format = Format.HYBRID,
update_fp8meta_interval: int = 1,
amax_history_size: int = 1) -> None:
"""
Initialize the FP8 meta
"""
FP8Helper.INITIALIZED = True
FP8Helper.MARGIN = margin
FP8Helper.FP8_FORMAT = fp8_format
FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE = \
_format2dtypes(FP8Helper.FP8_FORMAT)
FP8Helper.UPDATE_FP8META_INTERVAL = update_fp8meta_interval
FP8Helper.AMAX_HISTORY_SIZE = amax_history_size
FP8Helper.FP8_2X_ACC_FPROP = bool(
int(os.environ.get(FP8Helper.FP8_2X_ACC_FPROP_ENV_VAR_NAME, False)))
FP8Helper.FP8_2X_ACC_DGRAD = bool(
int(os.environ.get(FP8Helper.FP8_2X_ACC_DGRAD_ENV_VAR_NAME, False)))
FP8Helper.FP8_2X_ACC_WGRAD = bool(
int(os.environ.get(FP8Helper.FP8_2X_ACC_WGRAD_ENV_VAR_NAME, False)))
@staticmethod
def finalize() -> None:
"""
FP8 helper finalize
"""
FP8Helper.INITIALIZED = False
FP8Helper.MARGIN = 0.0
FP8Helper.FP8_FORMAT = Format.HYBRID
FP8Helper.FWD_DTYPE = DType.kFloat8E4M3
FP8Helper.BWD_DTYPE = DType.kFloat8E5M2
FP8Helper.UPDATE_FP8META_INTERVAL = 1
FP8Helper.AMAX_HISTORY_SIZE = 1
@staticmethod
    def update_collections(new: Collection, original: Collection) -> Collection:
"""
Update the collections
"""
if not isinstance(original, FrozenDict):
original = FrozenDict(original)
for key in new:
if key in original:
original, _ = original.pop(key)
return FrozenDict({**new, **original})
@staticmethod
def update_fp8_metas(state: Collection) -> Collection:
"""
Update the FP8 metas
"""
if FP8Helper.FP8_COLLECTION_NAME in state:
if not isinstance(state, FrozenDict):
state = FrozenDict(state)
others, fp8_metas = state.pop(FP8Helper.FP8_COLLECTION_NAME)
fp8_metas = FP8Helper._update_fp8_metas_impl(fp8_metas)
return FrozenDict({**others, FP8Helper.FP8_COLLECTION_NAME: fp8_metas})
return state
@staticmethod
def generate_fp8_max_array(num_of_meta):
"""
Generate the FP8 max array
"""
num_of_gemm = num_of_meta // FP8Helper.NUM_META_PER_GEMM
fp8_max_fwd = FP8Helper.FP8_FORMAT.value.max_fwd
fp8_max_bwd = FP8Helper.FP8_FORMAT.value.max_bwd
fp8_max_per_gemm = []
for i in range(FP8Helper.NUM_META_PER_GEMM):
val = fp8_max_bwd if i == FP8Helper.GRAD_META_IDX_PER_GEMM \
else fp8_max_fwd
fp8_max_per_gemm.append([val])
fp8_max_per_gemm = jnp.asarray(fp8_max_per_gemm, dtype=jnp.float32)
return jnp.vstack([fp8_max_per_gemm] * num_of_gemm)
@staticmethod
    def get_fp8_meta_indices(gemm_idx: int) -> Tuple[int, int, int]:
        """
        Obtain the indices of the FP8 metas for the given GEMM index.
"""
input_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.INPUT_META_IDX_PER_GEMM
kernel_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.KERNEL_META_IDX_PER_GEMM
grad_idx = FP8Helper.NUM_META_PER_GEMM * gemm_idx + FP8Helper.GRAD_META_IDX_PER_GEMM
return input_idx, kernel_idx, grad_idx
@staticmethod
@jax.jit
def _update_fp8_metas_impl(fp8_metas: Collection) -> Collection:
fp8_meta_arrays, treedef = jax.tree_util.tree_flatten(fp8_metas)
num_of_meta_with_max = FP8Helper.NUM_META_PER_GEMM + 1
num_of_gemm = len(fp8_meta_arrays) // num_of_meta_with_max
for i in range(num_of_gemm):
            # the flattened arrays are ordered alphabetically by collection name
fp8_max_idx = i * num_of_meta_with_max
fp8_amax_idx = fp8_max_idx + 1
fp8_scale_idx = fp8_amax_idx + 1
fp8_scale_inv_idx = fp8_scale_idx + 1
fp8_max = fp8_meta_arrays[fp8_max_idx]
amax = fp8_meta_arrays[fp8_amax_idx]
scale = fp8_meta_arrays[fp8_scale_idx]
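            # Delayed-scaling update: pick the scale as a power of two so that amax * scale
            # stays just below fp8_max / 2**MARGIN, i.e.
            # scale = 2**(floor(log2(fp8_max / amax)) - MARGIN).
            # E.g. fp8_max = 448, amax = 1.5 and MARGIN = 0 give exp = 8 and scale = 256.
            # If amax is zero or non-finite, the previous scale is kept.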
exp = jnp.floor(jnp.log2(fp8_max / amax)) - FP8Helper.MARGIN
sf = jnp.round(jnp.power(2, jnp.abs(exp)))
sf = jnp.where(amax > 0.0, sf, scale)
sf = jnp.where(jnp.isfinite(amax), sf, scale)
scale = jnp.where(exp < 0, 1 / sf, sf)
fp8_meta_arrays[fp8_scale_idx] = scale
fp8_meta_arrays[fp8_scale_inv_idx] = 1 / scale
return jax.tree_util.tree_unflatten(treedef, fp8_meta_arrays)
@contextmanager
def fp8_autocast(enabled: bool = False,
fp8_recipe: Optional[DelayedScaling] = None,
sharding_resource: Optional[ShardingResource] = None) -> None:
"""
Context manager for FP8 usage.
.. code-block:: python
mesh_shape = (4, 2)
dp_mesh_axis_name = 'data_parallel'
tp_mesh_axis_name = 'tensor_parallel'
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
sharding_resource=ShardingResource(dp_mesh_axis_name, tp_mesh_axis_name)
with fp8_autocast(enabled=True, sharding_resource=sharding_resource):
rules = extend_logical_axis_rules(tuple())
transformer = TransformerLayer()
with partitioning.axis_rules(rules):
pjit(transformer.init, ...)(...)
.. note::
We only support :attr:`margin`, :attr:`fp8_format`, :attr:`interval` and
:attr:`amax_history_len=1` in recipe.DelayedScaling currently. Other parameters
        in recipe.DelayedScaling are ignored, even if set.
Parameters
----------
enabled: bool, default = False
whether or not to enable fp8
fp8_recipe: recipe.DelayedScaling, default = None
recipe used for FP8 training.
    sharding_resource: ShardingResource, default = None
        specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, a default ShardingResource() is created.
"""
if fp8_recipe is None:
fp8_recipe = DelayedScaling()
assert fp8_recipe.amax_history_len == 1, \
"It only support amax_history_len == 1 for now."
if sharding_resource is None:
sharding_resource = ShardingResource()
try:
with global_shard_guard(sharding_resource):
if enabled:
FP8Helper.initialize(margin=fp8_recipe.margin,
fp8_format=fp8_recipe.fp8_format,
update_fp8meta_interval=fp8_recipe.interval,
amax_history_size=fp8_recipe.amax_history_len)
yield
finally:
FP8Helper.finalize()