Unverified commit a3ec6a54, authored by Jeng Bai-Cheng, committed by GitHub

add building workflow for TE/Jax (#53)



* add building workflow for jax modules
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* replace bit_cast with reinterpret_cast
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add nvtx to cmake check list
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor layernorm fwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor rmsnorm fwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor layernorm_bwd
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* set pytorch as default in setup.py
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* rename extension from *.cc to *.cpp

cpplint cannot recognize *.cc files, so the extension is renamed
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* refactor style to align with TE/PyTorch
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add pybinding, unittest and qa
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix license
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* disable c-extension-no-member and no-name-in-module
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* add dataclass to avoid pylint error
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update transformer_engine/__init__.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py

fix typo
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* Update tests/jax/test_custom_call_shape.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* fix conflict due to PR62
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* fix c-extension-no-member and no-name-in-module

1. add transformer_engine_jax into extension-pkg-whitelist
2. convert pylintrc from CRLF to LF line endings
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

* Update setup.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>

* remove pylint:disable and refactor import order
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>

---------
Signed-off-by: Ryan Jeng <rjeng@nvidia.com>
Signed-off-by: Jeng Bai-Cheng <jeng1220@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d8a2f352
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax
[MASTER]

extension-pkg-whitelist=torch,
                        transformer_engine_extensions,
                        transformer_engine_jax

disable=too-many-locals,
        invalid-name,
@@ -14,12 +14,12 @@ from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion
from distutils.file_util import copy_file

path = os.path.dirname(os.path.realpath(__file__))
with open(path + "/VERSION", "r") as f:
    te_version = f.readline()

CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda")
def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
@@ -94,9 +94,10 @@ all_sources = pytorch_sources
supported_frameworks = {
    "all": all_sources,
    "pytorch": pytorch_sources,
    "jax": None,    # JAX uses transformer_engine/CMakeLists.txt
}

framework = os.environ.get("NVTE_FRAMEWORK", "pytorch")
args = sys.argv.copy()
for s in args:
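The `NVTE_FRAMEWORK` lookup above is the whole selection mechanism added by this commit: the environment variable picks which framework builders run, defaulting to `"pytorch"`. A minimal sketch of that logic, with `choose_framework` as a name invented here for illustration:

```python
import os


def choose_framework(env=None):
    """Mirror setup.py's selection: NVTE_FRAMEWORK overrides, else pytorch."""
    env = os.environ if env is None else env
    return env.get("NVTE_FRAMEWORK", "pytorch")


print(choose_framework({}))                            # no override
print(choose_framework({"NVTE_FRAMEWORK": "jax"}))     # explicit override
```

Installing with `NVTE_FRAMEWORK=jax pip install .` would therefore skip the PyTorch extensions and build only the CMake-based JAX modules.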
@@ -113,19 +114,72 @@ class CMakeExtension(Extension):
        super(CMakeExtension, self).__init__(name, sources=sources, **kwargs)
        self.cmake_path = cmake_path
class FrameworkBuilderBase:

    def __init__(self, *args, **kwargs) -> None:
        pass

    def cmake_flags(self):
        return []

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    def run(self, extensions):
        pass

    @staticmethod
    def install_requires():
        return []
class PyTorchBuilder(FrameworkBuilderBase):

    def __init__(self, *args, **kwargs) -> None:
        pytorch_args = copy.deepcopy(args)
        pytorch_kwargs = copy.deepcopy(kwargs)
        from torch.utils.cpp_extension import BuildExtension
        self.pytorch_build_extensions = BuildExtension(*pytorch_args, **pytorch_kwargs)

    def initialize_options(self):
        self.pytorch_build_extensions.initialize_options()

    def finalize_options(self):
        self.pytorch_build_extensions.finalize_options()

    def run(self, extensions):
        other_ext = [
            ext for ext in extensions if not isinstance(ext, CMakeExtension)
        ]
        self.pytorch_build_extensions.extensions = other_ext
        print("Building PyTorch extensions!")
        self.pytorch_build_extensions.run()

    @staticmethod
    def install_requires():
        return [
            "flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",
        ]
class JaxBuilder(FrameworkBuilderBase):

    def cmake_flags(self):
        return ["-DENABLE_JAX=ON"]

    def run(self, extensions):
        print("Building JAX extensions!")
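The point of `FrameworkBuilderBase` is that each framework contributes CMake flags, build steps, and pip requirements through the same small interface. A minimal standalone sketch of that plug-in pattern; the `TensorFlowBuilder` subclass and its `-DENABLE_TF` flag are hypothetical, invented here only to show how a new framework would slot in:

```python
class FrameworkBuilderBase:
    """Base hooks, mirroring the interface defined in the diff above."""

    def cmake_flags(self):
        return []

    def run(self, extensions):
        pass

    @staticmethod
    def install_requires():
        return []


class TensorFlowBuilder(FrameworkBuilderBase):
    """Hypothetical framework builder: only overrides what it needs."""

    def cmake_flags(self):
        return ["-DENABLE_TF=ON"]  # hypothetical CMake switch


# TEBuildExtension-style aggregation: collect flags from all enabled builders.
flags = []
for builder in (FrameworkBuilderBase(), TensorFlowBuilder()):
    flags = flags + builder.cmake_flags()
print(flags)
```

The default no-op hooks mean a builder only overrides the stages it cares about, which is why `JaxBuilder` above is so short.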
ext_modules = []
dlfw_builder_funcs = []

ext_modules.append(
    CMakeExtension(
        name="transformer_engine",
        cmake_path=os.path.join(path, "transformer_engine"),
        sources=[],
        include_dirs=include_dirs,
    )
)
if framework in ("all", "pytorch"):
    from torch.utils.cpp_extension import CUDAExtension
    ext_modules.append(
        CUDAExtension(
            name="transformer_engine_extensions",
@@ -137,6 +191,14 @@ if framework in ("all", "pytorch"):
            include_dirs=include_dirs,
        )
    )
    dlfw_builder_funcs.append(PyTorchBuilder)

if framework in ("all", "jax"):
    dlfw_builder_funcs.append(JaxBuilder)

dlfw_install_requires = []
for builder in dlfw_builder_funcs:
    dlfw_install_requires = dlfw_install_requires + builder.install_requires()
def get_cmake_bin():
@@ -179,6 +241,7 @@ def get_cmake_bin():
class CMakeBuildExtension(build_ext, object):

    def __init__(self, *args, **kwargs) -> None:
        self.dlfw_flags = kwargs["dlfw_flags"]
        super(CMakeBuildExtension, self).__init__(*args, **kwargs)

    def build_extensions(self) -> None:
@@ -198,6 +261,7 @@ class CMakeBuildExtension(build_ext, object):
            "-DCMAKE_BUILD_TYPE=" + config,
            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(config.upper(), build_dir),
        ]
        cmake_args = cmake_args + self.dlfw_flags

        cmake_build_args = ["--config", config]
@@ -223,26 +287,35 @@ class CMakeBuildExtension(build_ext, object):
        except OSError as e:
            raise RuntimeError("CMake failed: {}".format(str(e)))
class TEBuildExtension(build_ext, object):

    def __init__(self, *args, **kwargs) -> None:
        self.dlfw_builder = []
        for functor in dlfw_builder_funcs:
            self.dlfw_builder.append(functor(*args, **kwargs))

        flags = []
        for builder in self.dlfw_builder:
            flags = flags + builder.cmake_flags()

        cmake_args = copy.deepcopy(args)
        cmake_kwargs = copy.deepcopy(kwargs)
        cmake_kwargs["dlfw_flags"] = flags
        self.cmake_build_extensions = CMakeBuildExtension(*cmake_args, **cmake_kwargs)

        self.all_outputs = None
        super(TEBuildExtension, self).__init__(*args, **kwargs)

    def initialize_options(self):
        self.cmake_build_extensions.initialize_options()
        for builder in self.dlfw_builder:
            builder.initialize_options()
        super(TEBuildExtension, self).initialize_options()

    def finalize_options(self):
        self.cmake_build_extensions.finalize_options()
        for builder in self.dlfw_builder:
            builder.finalize_options()
        super(TEBuildExtension, self).finalize_options()

    def run(self) -> None:
@@ -250,12 +323,9 @@ class TEBuildExtension(build_ext, object):
        cmake_ext = [ext for ext in self.extensions if isinstance(ext, CMakeExtension)]
        self.cmake_build_extensions.extensions = cmake_ext
        self.cmake_build_extensions.run()

        for builder in self.dlfw_builder:
            builder.run(self.extensions)

        self.all_outputs = []
        for f in os.scandir(self.build_lib):
@@ -313,8 +383,6 @@ setup(
    description="Transformer acceleration library",
    ext_modules=ext_modules,
    cmdclass={"build_ext": TEBuildExtension},
    install_requires=dlfw_install_requires,
    license_files=("LICENSE",),
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import jax.numpy as jnp
from jax.abstract_arrays import ShapedArray
from transformer_engine_jax import DType
from transformer_engine.jax.cpp_extensions import te_dtype_to_jax_dtype
from transformer_engine.jax.cpp_extensions import GemmPrimitive
SHAPES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
          (16384, 1024, 1024)]
NAMED_SHAPES = [{}, {"data": 4}, {"data": 2}, {"model": 4}, {"model": 2},
                {"data": 4, "model": 2}, {"model": 4, "data": 2}]
DTYPE = [DType.kFloat32, DType.kFloat16, DType.kBFloat16]
TRANSPOSE = [True, False]
class TestGEMMShapeInfer:

    @staticmethod
    def _joint_named_shape(ns1, ns2):
        output_named_shape = {**ns1}
        need_assert = False
        for key in ns2:
            if key in output_named_shape and output_named_shape[key] != ns2[key]:
                need_assert = True
            else:
                output_named_shape[key] = ns2[key]
        return output_named_shape, need_assert
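The merge rule above unions the two named shapes and flags a conflict whenever a key maps to different sizes. A standalone copy of the same rule (a local re-derivation, not the test helper itself):

```python
def joint_named_shape(ns1, ns2):
    """Union two named-shape dicts; flag a conflict on mismatched sizes."""
    out = {**ns1}
    conflict = False
    for key, val in ns2.items():
        if key in out and out[key] != val:
            conflict = True   # same axis name, different size: invalid join
        else:
            out[key] = val
    return out, conflict


print(joint_named_shape({"data": 4}, {"model": 2}))   # disjoint keys merge
print(joint_named_shape({"data": 4}, {"data": 2}))    # size clash is flagged
```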
    @staticmethod
    def _get_shapes(m, n, k, transa, transb):
        # te_gemm only supports the TN layout in column-major order, so reorder
        # the a and b shapes to compute row-major matrices with col-major algos.
        a = (m, k) if transa else (k, m)
        b = (k, n) if transb else (n, k)
        out = (n, m)
        return a, b, out
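The `(n, m)` output shape follows from a standard identity: a row-major `(m, n)` matrix reinterpreted column-major is its transpose, and `C = A @ B` implies `C.T = B.T @ A.T`, so a column-major GEMM naturally produces the `(n, m)` view. A quick NumPy check of that identity (a re-derivation, independent of `te_gemm`):

```python
import numpy as np

# C = A @ B in row-major; a col-major GEMM effectively computes C.T = B.T @ A.T.
m, n, k = 4, 3, 2
A = np.arange(m * k, dtype=np.float32).reshape(m, k)
B = np.arange(k * n, dtype=np.float32).reshape(k, n)

C = A @ B        # (m, n): the row-major result we want
Ct = B.T @ A.T   # (n, m): what the col-major computation yields

assert Ct.shape == (n, m)
assert np.allclose(C.T, Ct)
print("ok")
```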
    @pytest.mark.parametrize('shapes', SHAPES)
    @pytest.mark.parametrize('named_shape1', NAMED_SHAPES)
    @pytest.mark.parametrize('named_shape2', NAMED_SHAPES)
    @pytest.mark.parametrize('te_dtype', DTYPE)
    @pytest.mark.parametrize('transa', TRANSPOSE)
    @pytest.mark.parametrize('transb', TRANSPOSE)
    def test_shape_infer(self, shapes, named_shape1, named_shape2, te_dtype, transa, transb):
        a_shape, b_shape, out_shape = TestGEMMShapeInfer._get_shapes(*shapes, transa, transb)
        dtype = te_dtype_to_jax_dtype(te_dtype)
        mat_a = ShapedArray(a_shape, dtype, named_shape=named_shape1)
        mat_b = ShapedArray(b_shape, dtype, named_shape=named_shape2)
        scale_inv_a = ShapedArray((3, 1), jnp.float32)
        scale_inv_b = ShapedArray((3, 1), jnp.float32)
        ref_out_named_shape, need_assert = TestGEMMShapeInfer._joint_named_shape(
            named_shape1, named_shape2)
        ref_out = ShapedArray(out_shape, dtype, named_shape=ref_out_named_shape)
        try:
            test_out = GemmPrimitive.abstract(mat_a,
                                              mat_b,
                                              scale_inv_a,
                                              scale_inv_b,
                                              A_dtype=te_dtype,
                                              B_dtype=te_dtype,
                                              D_dtype=te_dtype,
                                              transa=transa,
                                              transb=transb,
                                              use_split_accumulator=False)
            assert not need_assert
            assert ref_out == test_out
        except AssertionError as ae:
            assert need_assert, f"{ae.args}"
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import unittest
import flax
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import maps
from utils import assert_allclose, is_fp8_supported
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.fp8 import FP8Helper
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import MajorShardingType
from transformer_engine.jax.sharding import ShardingResource
class TestFP8Helper(unittest.TestCase):

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_initialize(self):
        margin = 5.0
        fp8_format = FP8Format.E4M3
        update_fp8meta_interval = 10
        amax_history_size = 10

        FP8Helper.initialize(margin=margin,
                             fp8_format=fp8_format,
                             update_fp8meta_interval=update_fp8meta_interval,
                             amax_history_size=amax_history_size)
        self.assertEqual(
            FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
            f" but got {FP8Helper.MARGIN}.")
        self.assertEqual(
            FP8Helper.FP8_FORMAT, fp8_format,
            f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
            f" but got {FP8Helper.FP8_FORMAT}.")
        self.assertEqual(
            FP8Helper.UPDATE_FP8META_INTERVAL, update_fp8meta_interval,
            "FP8Helper.UPDATE_FP8META_INTERVAL initialization failed, should be "
            f"{update_fp8meta_interval} but got {FP8Helper.UPDATE_FP8META_INTERVAL}.")
        self.assertEqual(
            FP8Helper.AMAX_HISTORY_SIZE, amax_history_size,
            f"FP8Helper.AMAX_HISTORY_SIZE initialization failed, should be {amax_history_size}"
            f" but got {FP8Helper.AMAX_HISTORY_SIZE}.")
        FP8Helper.finalize()
    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_update_fp8_metas(self):
        FP8Helper.initialize(margin=3.0, amax_history_size=5)

        seed = 0
        key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
        num_of_gemm = 10
        num_of_meta = FP8Helper.NUM_META_PER_GEMM * num_of_gemm

        def get_fp8_scale(fp8_max, amax, scale):
            fp8_max = np.array(fp8_max)
            amax = np.array(amax)
            scale = np.array(scale)

            exp = np.floor(np.log2(fp8_max / amax)) - FP8Helper.MARGIN
            sf = np.round(np.power(2, np.abs(exp)))
            sf = np.where(amax > 0.0, sf, scale)
            sf = np.where(np.isfinite(amax), sf, scale)
            return np.where(exp < 0, 1 / sf, sf)
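A worked scalar instance of the `get_fp8_scale` formula above (a re-derivation with illustrative numbers; 448 is the E4M3 representable maximum): with `amax = 3.5` and `margin = 0`, `exp = floor(log2(448 / 3.5)) = 7`, so the scale is `2**7 = 128`.

```python
import numpy as np

# Scalar walk-through of the delayed-scaling update: pick the largest
# power-of-two scale that keeps amax * scale within the FP8 range.
fp8_max, amax, margin = 448.0, 3.5, 0.0

exp = np.floor(np.log2(fp8_max / amax)) - margin   # 448 / 3.5 = 128 -> exp = 7
sf = np.round(np.power(2.0, np.abs(exp)))          # 2**7 = 128
scale = 1.0 / sf if exp < 0 else sf                # negative exp would shrink

print(scale)   # 128.0
```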
        meta_shape = (num_of_meta, FP8Helper.AMAX_HISTORY_SIZE)
        fp8_max_array = FP8Helper.generate_fp8_max_array(num_of_meta)
        fp8_amax_array1 = jax.random.uniform(key1, shape=meta_shape)
        fp8_scale_array1 = get_fp8_scale(fp8_max_array, fp8_amax_array1, jnp.ones(meta_shape))
        fp8_scale_inv_array1 = 1 / fp8_scale_array1
        fp8_amax_array2 = jax.random.uniform(key2, shape=meta_shape)
        fp8_scale_array2 = get_fp8_scale(fp8_max_array, fp8_amax_array2, jnp.ones(meta_shape))
        fp8_scale_inv_array2 = 1 / fp8_scale_array2

        state = flax.core.frozen_dict.FrozenDict({
            FP8Helper.FP8_COLLECTION_NAME: {
                "test_update_fp8_metas1": {
                    FP8Helper.FP8_MAX_NAME: fp8_max_array,
                    FP8Helper.FP8_AMAX_NAME: fp8_amax_array1,
                    FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape),
                    FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape)
                },
                "test_update_fp8_metas2": {
                    FP8Helper.FP8_MAX_NAME: fp8_max_array,
                    FP8Helper.FP8_AMAX_NAME: fp8_amax_array2,
                    FP8Helper.FP8_SCALE_NAME: jnp.ones(meta_shape),
                    FP8Helper.FP8_SCALE_INV_NAME: jnp.ones(meta_shape)
                }
            }
        })
        updated_state = FP8Helper.update_fp8_metas(state)
        state_array, _ = jax.tree_util.tree_flatten(updated_state)

        meta_per_gemm = FP8Helper.NUM_META_PER_GEMM + 1
        scale_shift = 2
        scale_inv_shift = 3
        assert_allclose(state_array[0 * meta_per_gemm + scale_shift], fp8_scale_array1)
        assert_allclose(state_array[0 * meta_per_gemm + scale_inv_shift], fp8_scale_inv_array1)
        assert_allclose(state_array[1 * meta_per_gemm + scale_shift], fp8_scale_array2)
        assert_allclose(state_array[1 * meta_per_gemm + scale_inv_shift], fp8_scale_inv_array2)
        FP8Helper.finalize()
    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_generate_fp8_max_array(self):
        num_of_meta = FP8Helper.NUM_META_PER_GEMM * 2

        def get_ref(format_for_test):
            refer_list = []
            for i in range(num_of_meta):
                val = format_for_test.value.max_bwd \
                    if i % FP8Helper.NUM_META_PER_GEMM == FP8Helper.GRAD_META_IDX_PER_GEMM \
                    else format_for_test.value.max_fwd
                refer_list.append([val])
            return jnp.asarray(refer_list)

        for fp8_format in FP8Format:
            FP8Helper.initialize(fp8_format=fp8_format)
            assert_allclose(get_ref(fp8_format), FP8Helper.generate_fp8_max_array(num_of_meta))
            FP8Helper.finalize()
    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_update_collections(self):
        original_val = 0.0
        updated_val = 10.0

        original_state = {
            "test1": original_val,
            "test2": original_val,
        }
        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)

        original_state = flax.core.frozen_dict.FrozenDict(original_state)
        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
        self.assertEqual(updated_state["test1"], updated_val)
        self.assertEqual(updated_state["test2"], original_val)
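The semantics this test exercises are simple overlay-style updates: keys present in the update mapping replace the originals, everything else persists. A standalone sketch of that behavior (`update_collections` here is a local stand-in, not the `FP8Helper` method):

```python
def update_collections(new, original):
    """Overlay `new` onto `original` without mutating the input state."""
    out = dict(original)
    out.update(new)
    return out


state = update_collections({"test1": 10.0}, {"test1": 0.0, "test2": 0.0})
print(state)   # {'test1': 10.0, 'test2': 0.0}
```

The second half of the test repeats the check with a `FrozenDict`, confirming the real helper also works on Flax's immutable state containers.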
class TestFP8Functions(unittest.TestCase):

    def _check_defult_state(self):
        self.assertFalse(FP8Helper.enable_fp8())
        self.assertEqual(infer_major_sharding_type(), MajorShardingType.SINGLE)

    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_fp8_autocast(self):
        with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
            self.assertFalse(FP8Helper.enable_fp8())
        self._check_defult_state()

        ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.enable_fp8())
            self.assertEqual(FP8Helper.MARGIN, ds.margin)
            self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
            self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
            self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
        self._check_defult_state()

        ds = DelayedScaling(margin=3.0, interval=1, fp8_format=FP8Format.HYBRID, amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            self.assertTrue(FP8Helper.enable_fp8())
            self.assertEqual(FP8Helper.MARGIN, ds.margin)
            self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
            self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
            self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
        self._check_defult_state()

        with self.assertRaises(AssertionError):
            with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(amax_history_len=2)):
                pass
    @unittest.skipIf(not is_fp8_supported(), reason='GPU capability is not enough to run FP8')
    def test_fp8_autocast_with_sharding_resource(self):
        self._check_defult_state()

        ds = DelayedScaling(margin=5.0, interval=3, fp8_format=FP8Format.E4M3, amax_history_len=1)
        # TODO (Ming Huang): Support multi-GPU testing. # pylint: disable=fixme
        # srs = (
        #     (ShardingResource(None, None), MajorShardingType.SINGLE),
        #     (ShardingResource('dp', None), MajorShardingType.DP),
        #     (ShardingResource(None, 'tp'), MajorShardingType.TP),
        #     (ShardingResource('dp', 'tp'), MajorShardingType.DPTP),
        # )
        srs = (
            (ShardingResource(None, None), MajorShardingType.SINGLE),
            (ShardingResource('dp', None), MajorShardingType.SINGLE),
            (ShardingResource(None, 'tp'), MajorShardingType.SINGLE),
            (ShardingResource('dp', 'tp'), MajorShardingType.SINGLE),
        )
        # TODO (Ming Huang): Support multi-GPU testing. # pylint: disable=fixme
        mesh_shape = (1, 1)
        devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
        with maps.Mesh(devices, ('dp', 'tp')):
            for sr, mst in srs:
                with fp8_autocast(enabled=True, fp8_recipe=ds, sharding_resource=sr):
                    self.assertTrue(FP8Helper.enable_fp8())
                    self.assertEqual(FP8Helper.MARGIN, ds.margin)
                    self.assertEqual(FP8Helper.UPDATE_FP8META_INTERVAL, ds.interval)
                    self.assertEqual(FP8Helper.FP8_FORMAT, ds.fp8_format)
                    self.assertEqual(FP8Helper.AMAX_HISTORY_SIZE, ds.amax_history_len)
                    self.assertEqual(infer_major_sharding_type(), mst)
        self._check_defult_state()
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
import numpy as np
import pytest
from jax.experimental import maps
from transformer_engine.jax.sharding import get_dot_sharding_meta
from transformer_engine.jax.sharding import get_elementwise_sharding_meta
from transformer_engine.jax.sharding import get_fp8_meta_sharding_meta
from transformer_engine.jax.sharding import global_shard_guard
from transformer_engine.jax.sharding import infer_major_sharding_type
from transformer_engine.jax.sharding import is_dp_enabled, is_tp_enabled
from transformer_engine.jax.sharding import ShardingMeta, ShardingResource, ShardingType
def _get_sharding_resource(mesh_names, sharding_type):
    dp_r = None
    tp_r = None

    if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
        dp_r = mesh_names[0]
    if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW):
        tp_r = mesh_names[0]
    if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
        tp_r = mesh_names[1]
    return ShardingResource(dp_r, tp_r)
DEVICE_COUNT = 4
MESH_CONFIG = [((4,), ("dp",), ShardingType.DP), ((4,), ("tp",), ShardingType.TP_COL),
((4,), ("tp",), ShardingType.TP_ROW), ((2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
((2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW)]
LOGICAL_RULES = [[(('a1', None), ('a2', 'ma2')), False],
[(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True]]
SRS = [
ShardingResource(),
ShardingResource('data', None),
ShardingResource(None, 'model'),
ShardingResource('data', 'model')
]
def is_devices_enough():
    return len(jax.devices()) >= DEVICE_COUNT
class TestGeneralFunc:

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    @pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
    def test_infer_major_sharding_type(
            self,
            mesh_shape,    # pylint: disable=unused-argument
            mesh_names,
            sharding_type):
        devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
        with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
            with maps.Mesh(devices, mesh_names):
                assert infer_major_sharding_type() is sharding_type.value[0]

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    def test_is_dp_enabled(
            self,
            mesh_shape,    # pylint: disable=unused-argument
            mesh_names,    # pylint: disable=unused-argument
            sharding_type):
        if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
            assert is_dp_enabled(sharding_type.value[0])
        else:
            assert not is_dp_enabled(sharding_type.value[0])

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    def test_is_tp_enabled(
            self,
            mesh_shape,    # pylint: disable=unused-argument
            mesh_names,    # pylint: disable=unused-argument
            sharding_type):
        if sharding_type is ShardingType.DP:
            assert not is_tp_enabled(sharding_type.value[0])
        else:
            assert is_tp_enabled(sharding_type.value[0])
class TestShardingMetaGenerator:

    BATCH_AXIS_NAME = 'batch'
    MODEL_AXIS_NAME = 'model'

    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    @pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
    def test_fp8_meta(self, mesh_shape, mesh_names, sharding_type, num_of_fp8_meta=4):

        def stack_axes_meta(mapping):
            return tuple(mapping for _ in range(num_of_fp8_meta))

        def get_ref_sm():
            if sharding_type == ShardingType.DP:
                return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
                                    {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
                                    (), ())
            if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW):
                return ShardingMeta(stack_axes_meta({}), stack_axes_meta({}),
                                    {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
                                    (), ())
            if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
                return ShardingMeta(
                    stack_axes_meta({}), stack_axes_meta({}), {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, (), ())
            return None

        devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
        with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
            with maps.Mesh(devices, mesh_names):
                test_sm = get_fp8_meta_sharding_meta(
                    sharding_type,
                    num_of_fp8_meta,
                    dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
                    tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
                assert test_sm == get_ref_sm()
    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    @pytest.mark.parametrize('a_shape, b_shape', [((64, 128, 256), (256, 512)),
                                                  ((128, 64, 512), (512, 256))])
    @pytest.mark.parametrize('batch_dim_of_a', [0, 1])
    @pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
    def test_dot(self, mesh_shape, mesh_names, sharding_type, a_shape, b_shape, batch_dim_of_a):
        model_dim_of_a = len(a_shape) - 1
        model_dim_of_b = 0 if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) else 1
        contracting_dims = ((-1,), (0,))

        def get_ref_sm():
            out_shape = (*a_shape[:min(contracting_dims[0])],
                         *b_shape[max(contracting_dims[1]) + 1:])
            if sharding_type == ShardingType.DP:
                a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0], -1,
                               *a_shape[batch_dim_of_a + 1:])
                return ShardingMeta(
                    ({batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME}, {}),
                    ({batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME}),
                    {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
                    [a_new_shape, b_shape], [out_shape])
            if sharding_type == ShardingType.TP_COL:
                b_new_shape = (b_shape[0], mesh_shape[0], b_shape[1] // mesh_shape[0])
                return ShardingMeta(
                    ({}, {1: TestShardingMetaGenerator.MODEL_AXIS_NAME}),
                    ({len(out_shape) - 1: TestShardingMetaGenerator.MODEL_AXIS_NAME}),
                    {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
                    [a_shape, b_new_shape], [out_shape])
            if sharding_type == ShardingType.TP_ROW:
                a_new_shape = (*a_shape[:-1], mesh_shape[0], a_shape[-1] // mesh_shape[0])
                b_new_shape = (mesh_shape[0], b_shape[0] // mesh_shape[0], b_shape[1])
                return ShardingMeta(
                    ({len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME},
                     {0: TestShardingMetaGenerator.MODEL_AXIS_NAME}),
                    ({}),
                    {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
                    [a_new_shape, b_new_shape], [out_shape])
            if sharding_type == ShardingType.DP_TP_COL:
                a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0],
                               a_shape[batch_dim_of_a] // mesh_shape[0],
                               *a_shape[batch_dim_of_a + 1:])
                b_new_shape = (b_shape[0], mesh_shape[1], b_shape[1] // mesh_shape[1])
                return ShardingMeta(
                    ({batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME},
                     {1: TestShardingMetaGenerator.MODEL_AXIS_NAME}),
                    ({
                        batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        len(out_shape): TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }),
                    {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, [a_new_shape, b_new_shape], [out_shape])
            if sharding_type == ShardingType.DP_TP_ROW:
                a_new_shape = (*a_shape[:batch_dim_of_a], mesh_shape[0],
                               a_shape[batch_dim_of_a] // mesh_shape[0],
                               *a_shape[batch_dim_of_a + 1:-1], mesh_shape[1],
                               a_shape[-1] // mesh_shape[1])
                b_new_shape = (mesh_shape[1], b_shape[0] // mesh_shape[1], b_shape[1])
                return ShardingMeta(
                    ({
                        batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        len(a_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }, {0: TestShardingMetaGenerator.MODEL_AXIS_NAME}),
                    ({batch_dim_of_a: TestShardingMetaGenerator.BATCH_AXIS_NAME}),
                    {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, [a_new_shape, b_new_shape], [out_shape])
            return None

        devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
        with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
            with maps.Mesh(devices, mesh_names):
                test_sm = get_dot_sharding_meta(
                    sharding_type,
                    a_shape,
                    b_shape,
                    batch_dim_of_a,
                    model_dim_of_a,
                    model_dim_of_b,
                    contracting_dims,
                    dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
                    tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
                assert test_sm == get_ref_sm()
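The `out_shape` rule at the top of `get_ref_sm` above follows standard dot-product shape inference: the output keeps `a`'s leading axes up to its contracted axis and `b`'s axes after its contracted axis. A standalone check with the first parametrized shape pair:

```python
# Re-derivation of the reference output shape for a batched dot:
# contracting a's last axis (-1) with b's first axis (0).
a_shape, b_shape = (64, 128, 256), (256, 512)
contracting_dims = ((-1,), (0,))

out_shape = (*a_shape[:min(contracting_dims[0])],     # a's axes before -1
             *b_shape[max(contracting_dims[1]) + 1:])  # b's axes after 0

print(out_shape)   # (64, 128, 512)
```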
    @pytest.mark.parametrize('mesh_shape,mesh_names,sharding_type', MESH_CONFIG)
    @pytest.mark.parametrize('input_shape', [(64, 128, 256), (128, 64, 512)])
    @pytest.mark.parametrize('other_shape', [(256,), (512,)])
    @pytest.mark.parametrize('batch_dim', [0, 1])
    @pytest.mark.skipif(not is_devices_enough(), reason='Num of GPU is not enough')
    def test_elementwise(self, mesh_shape, mesh_names, sharding_type, input_shape, other_shape,
                         batch_dim):

        def get_ref_sm():
            need_assert = True
            ref_sharding_meta = None
            if input_shape[-1] != other_shape[0]:
                need_assert = True
                ref_sharding_meta = None
            elif sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP):
                need_assert = False
                input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1,
                                   *input_shape[batch_dim + 1:])
                ref_sharding_meta = ShardingMeta(
                    ({batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME}, {}),
                    ({batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME}),
                    {TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0]},
                    [input_new_shape, other_shape], [input_shape])
            elif sharding_type is ShardingType.TP_COL:
                need_assert = False
                ref_sharding_meta = ShardingMeta(({}, {}), ({}), {},
                                                 [input_shape, other_shape], [input_shape])
            elif sharding_type is ShardingType.TP_ROW:
                need_assert = False
                input_new_shape = (*input_shape[:-1], mesh_shape[0], -1)
                other_new_shape = (mesh_shape[0], -1)
                ref_sharding_meta = ShardingMeta(
                    ({len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME},
                     {0: TestShardingMetaGenerator.MODEL_AXIS_NAME}),
                    ({len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME}),
                    {TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[0]},
                    [input_new_shape, other_new_shape], [input_shape])
            elif sharding_type is ShardingType.DP_TP_ROW:
                need_assert = False
                input_new_shape = (*input_shape[:batch_dim], mesh_shape[0], -1,
                                   *input_shape[batch_dim + 1:-1], mesh_shape[1],
                                   input_shape[-1] // mesh_shape[1])
                other_new_shape = (mesh_shape[0], -1)
                ref_sharding_meta = ShardingMeta(
                    ({
                        batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }, {0: TestShardingMetaGenerator.MODEL_AXIS_NAME}),
                    ({
                        batch_dim: TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        len(input_new_shape) - 2: TestShardingMetaGenerator.MODEL_AXIS_NAME
                    }),
                    {
                        TestShardingMetaGenerator.BATCH_AXIS_NAME: mesh_names[0],
                        TestShardingMetaGenerator.MODEL_AXIS_NAME: mesh_names[1]
                    }, [input_new_shape, other_new_shape], [input_shape])
            return ref_sharding_meta, need_assert

        devices = np.asarray(jax.devices()[:DEVICE_COUNT]).reshape(*mesh_shape)
        with global_shard_guard(_get_sharding_resource(mesh_names, sharding_type)):
            with maps.Mesh(devices, mesh_names):
                ref_sm, need_assert = get_ref_sm()
                try:
                    test_sm = get_elementwise_sharding_meta(
                        sharding_type,
                        input_shape,
                        other_shape,
                        batch_dim,
                        dp_axis_name=TestShardingMetaGenerator.BATCH_AXIS_NAME,
                        tp_axis_name=TestShardingMetaGenerator.MODEL_AXIS_NAME)
                    assert not need_assert
                    assert test_sm == ref_sm
                except (NotImplementedError, AssertionError) as e:
                    assert need_assert, f"{e.args}"
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Any, Callable, Tuple, Union
import jax.numpy as jnp
import numpy as np
from cuda import cudart
from jax import lax
PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
lax.Precision]]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
def is_fp8_supported():
"""
JAX does not expose an API to query the GPU compute capability,
so use cuda-python to check it instead.
"""
cudaSuccess = cudart.cudaError_t.cudaSuccess
ret, gpu_id = cudart.cudaGetDevice()
assert ret == cudaSuccess
flag = cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor
ret, sm_major = cudart.cudaDeviceGetAttribute(flag, gpu_id)
assert ret == cudaSuccess
return sm_major >= 9
def assert_allclose(actual,
desired,
rtol=1e-05,
atol=1e-08,
equal_nan=True,
err_msg='',
verbose=True):
if not isinstance(actual, float):
actual = actual.astype(jnp.float32)
if not isinstance(desired, float):
desired = desired.astype(jnp.float32)
np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose)
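The cast-to-float32 comparison helper above can be exercised as follows; a minimal NumPy-only sketch (jnp swapped for np so it runs without a GPU), restating the same idea for illustration:

```python
import numpy as np

def assert_allclose(actual, desired, rtol=1e-05, atol=1e-08):
    # Cast array inputs to float32 first, so fp16/bf16 results are
    # compared in a common precision, mirroring the helper above.
    if not isinstance(actual, float):
        actual = np.asarray(actual, dtype=np.float32)
    if not isinstance(desired, float):
        desired = np.asarray(desired, dtype=np.float32)
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)

# A half-precision product matches its exact reference only under a
# loosened relative tolerance.
half = np.float16(0.1) * np.float16(0.2)
assert_allclose(np.asarray(half), 0.02, rtol=1e-2)
```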
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine LANGUAGES CUDA CXX)
list(APPEND CMAKE_CUDA_FLAGS "--threads 4")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
find_package(Python COMPONENTS Interpreter Development REQUIRED)
include_directories(${PROJECT_SOURCE_DIR})
add_subdirectory(common)
option(ENABLE_JAX "Enable JAX in the building workflow." OFF)
if(ENABLE_JAX)
find_package(pybind11 CONFIG REQUIRED)
add_subdirectory(jax)
endif()
"""Top level package"""
from . import common
try:
    from . import pytorch
except ImportError:
    pass
try:
    from . import jax
except ImportError:
    pass
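The try/except imports above let the top-level package load when only one framework extension was built. The same optional-import pattern can be sketched standalone (the second module name below is hypothetical and stands in for a framework extension that was not built):

```python
import importlib

def optional_import(name):
    """Return the module if importable, else None (extension not built)."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None

# "json" is in the stdlib, so it loads; the other name does not exist.
assert optional_import("json") is not None
assert optional_import("no_such_framework_ext") is None
```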
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 90)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine LANGUAGES CUDA CXX)
list(APPEND CMAKE_CUDA_FLAGS "--threads 4")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
list(APPEND CMAKE_CUDA_FLAGS "-G")
endif()
add_library(transformer_engine SHARED
transformer_engine.cpp
transpose/cast_transpose.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
pybind11_add_module(
transformer_engine_jax
${CMAKE_CURRENT_SOURCE_DIR}/csrc/extensions.cpp
${CMAKE_CURRENT_SOURCE_DIR}/csrc/modules.cpp
)
target_link_libraries(transformer_engine_jax PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt transformer_engine)
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
from .fp8 import fp8_autocast
from .sharding import ShardingResource
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "common/include/transformer_engine/transformer_engine.h"
#include "jax/csrc/modules.h"
#include "jax/csrc/utils.h"
namespace transformer_engine {
namespace jax {
template <typename T>
pybind11::capsule EncapsulateFunction(T *fn) {
return pybind11::capsule(reinterpret_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
}
pybind11::dict Registrations() {
pybind11::dict dict;
dict["te_transpose"] = EncapsulateFunction(Transpose);
dict["te_cast_transpose"] = EncapsulateFunction(CastTranspose);
dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu);
dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8);
dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu);
dict["te_dgated_gelu_cast_transpose"] = EncapsulateFunction(DGatedGeluCastTranspose);
dict["te_gemm"] = EncapsulateFunction(Gemm);
dict["te_layernorm_forward"] = EncapsulateFunction(LayerNormForward);
dict["te_layernorm_forward_fp8"] = EncapsulateFunction(LayerNormForwardFP8);
dict["te_layernorm_backward"] = EncapsulateFunction(LayerNormBackward);
dict["te_rmsnorm_forward"] = EncapsulateFunction(RMSNormForward);
dict["te_rmsnorm_forward_fp8"] = EncapsulateFunction(RMSNormForwardFP8);
dict["te_rmsnorm_backward"] = EncapsulateFunction(RMSNormBackward);
dict["te_quantize"] = EncapsulateFunction(Quantize);
dict["te_dequantize"] = EncapsulateFunction(Dequantize);
dict["te_scaled_softmax_forward"] = EncapsulateFunction(ScaledSoftmaxForward);
dict["te_scaled_softmax_backward"] = EncapsulateFunction(ScaledSoftmaxBackward);
dict["te_scaled_masked_softmax_forward"] = EncapsulateFunction(ScaledMaskedSoftmaxForward);
dict["te_scaled_masked_softmax_backward"] = EncapsulateFunction(ScaledMaskedSoftmaxBackward);
dict["te_scaled_upper_triang_masked_softmax_forward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward);
dict["te_scaled_upper_triang_masked_softmax_backward"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward);
return dict;
}
PYBIND11_MODULE(transformer_engine_jax, m) {
m.def("registrations", &Registrations);
m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor);
m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor);
m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor);
m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor);
pybind11::enum_<DType>(m, "DType")
.value("kByte", DType::kByte)
.value("kInt32", DType::kInt32)
.value("kFloat32", DType::kFloat32)
.value("kFloat16", DType::kFloat16)
.value("kBFloat16", DType::kBFloat16)
.value("kFloat8E4M3", DType::kFloat8E4M3)
.value("kFloat8E5M2", DType::kFloat8E5M2);
}
} // namespace jax
} // namespace transformer_engine
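`Registrations()` returns a dict mapping custom-call names to PyCapsules, which the Python side hands to XLA's custom-call registry. A hedged sketch of that consuming loop, with a plain dict and a stand-in for the real registration function (the actual call on jaxlib's side is assumed, not shown in this diff):

```python
# Sketch: iterate a capsule dict like Registrations() and register each
# entry with XLA. `register_fn` stands in for the real jaxlib API.
def register_custom_calls(registrations, register_fn, platform="CUDA"):
    names = []
    for name, capsule in registrations.items():
        register_fn(name, capsule, platform)
        names.append(name)
    return names

# Stand-in registry for illustration only.
seen = {}
def fake_register(name, fn, platform):
    seen[name] = (fn, platform)

registered = register_custom_calls({"te_transpose": object()}, fake_register)
```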
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#include <cuda_runtime_api.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "transformer_engine/logging.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace jax {
constexpr int kMaxNumDim = 8;
struct Shape {
int num_dim;
size_t dims[kMaxNumDim];
void from_vector(const std::vector<size_t> &shape) {
num_dim = shape.size();
assert(num_dim <= kMaxNumDim);
std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
}
std::vector<size_t> to_vector() const {
assert(num_dim <= kMaxNumDim);
std::vector<size_t> shape(num_dim);
std::memcpy(shape.data(), dims, num_dim * sizeof(size_t));
return shape;
}
};
struct CustomCallCommonDescriptor {
Shape shape;
DType in_dtype;
DType out_dtype;
};
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
DType out_dtype);
struct CustomCallGemmDescriptor {
size_t m;
size_t n;
size_t k;
DType A_dtype;
DType B_dtype;
DType D_dtype;
bool transa;
bool transb;
bool use_split_accumulator;
};
pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType A_dtype,
DType B_dtype, DType D_dtype, bool transa, bool transb,
bool use_split_accumulator);
struct CustomCallNormDescriptor {
size_t n;
size_t hidden;
DType x_dtype;
DType w_dtype;
float eps;
};
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
float eps);
struct SoftmaxDescriptor {
size_t batch;
size_t pad_batch;
size_t heads;
size_t q_seqlen;
size_t k_seqlen;
DType dtype;
float scale_factor;
};
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
size_t q_seqlen, size_t k_seqlen, DType dtype,
float scale_factor);
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
std::size_t opaque_len);
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
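The `Pack*Descriptor` helpers above serialize a descriptor struct into the opaque byte string that XLA passes back to the custom call. Conceptually this is a fixed-layout pack/unpack round trip; a hedged Python sketch for `CustomCallNormDescriptor`, where the `"=QQiif"` layout is an illustrative assumption, not the real C++ ABI (which is defined by the struct's memory layout):

```python
import struct

# Assumed layout: size_t n, size_t hidden, int x_dtype, int w_dtype,
# float eps -- packed without padding purely for illustration.
def pack_norm_descriptor(n, hidden, x_dtype, w_dtype, eps):
    return struct.pack("=QQiif", n, hidden, x_dtype, w_dtype, eps)

def unpack_norm_descriptor(opaque):
    return struct.unpack("=QQiif", opaque)

opaque = pack_norm_descriptor(128, 1024, 2, 2, 1e-6)
n, hidden, x_dtype, w_dtype, eps = unpack_norm_descriptor(opaque)
```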
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#include <pybind11/pybind11.h>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "transformer_engine/logging.h"
namespace transformer_engine {
namespace jax {
class cublasLtMetaManager {
public:
static cublasLtMetaManager &Instance() {
static thread_local cublasLtMetaManager instance;
return instance;
}
cublasLtMetaManager() {}
~cublasLtMetaManager() { Clear_(); }
void *GetWorkspace(size_t size = 4194304) {
ReallocateIfNeed_(size);
return workspace_;
}
private:
void *workspace_ = nullptr;
size_t size_ = 0;
void Clear_() {
if (workspace_ != nullptr) {
NVTE_CHECK_CUDA(cudaFree(workspace_));
}
workspace_ = nullptr;
size_ = 0;
}
void Allocate_(size_t new_size) {
NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
size_ = new_size;
}
void ReallocateIfNeed_(size_t new_size) {
if (new_size > size_) {
Clear_();
Allocate_(new_size);
}
}
};
class cudaDevicePropertiesManager {
public:
static cudaDevicePropertiesManager &Instance() {
static thread_local cudaDevicePropertiesManager instance;
return instance;
}
int GetMultiProcessorCount() {
if (!prop_queried_) {
int device_id;
NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop_, device_id));
prop_queried_ = true;
}
return prop_.multiProcessorCount;
}
private:
bool prop_queried_ = false;
cudaDeviceProp prop_;
};
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
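`cublasLtMetaManager` implements a grow-only workspace cache: a request is served from the cached allocation unless it exceeds the cached size, in which case the buffer is freed and reallocated. A Python sketch of that policy, with a `bytearray` standing in for the `cudaMalloc`'d workspace:

```python
# Sketch of the grow-only reuse policy from cublasLtMetaManager.
class WorkspaceCache:
    def __init__(self):
        self._buf = None
        self._size = 0
        self.alloc_count = 0  # for illustration: real allocations made

    def get(self, size=4 * 1024 * 1024):
        # Reallocate only when the request exceeds the cached size, so
        # repeated requests of the same (or smaller) size are free.
        if size > self._size:
            self._buf = bytearray(size)
            self._size = size
            self.alloc_count += 1
        return self._buf

cache = WorkspaceCache()
a = cache.get(1024)
b = cache.get(512)    # served from the existing 1024-byte buffer
c = cache.get(2048)   # exceeds the cache, so it grows
```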
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence
from functools import partial, reduce
import operator
import jax
import jax.numpy as jnp
from transformer_engine_jax import DType as TEDType
from .cpp_extensions import cast_transpose, gemm, jax_dtype_to_te_dtype
from .fp8 import FP8Helper, FP8GemmPackage
from .sharding import ShardingType, get_dot_sharding_meta, get_fp8_meta_sharding_meta
from .sharding import is_dp_enabled, is_tp_enabled, merge_axis_resources
from .sharding import xmap_runner
jax.config.update('experimental_xmap_spmd_lowering', True)
jax.config.update('experimental_xmap_spmd_lowering_manual', True)
def fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
amax_history_idx: int,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0) -> jnp.ndarray:
"""
FP8 dot wrapper
"""
assert fp8_gemm_pkg.num_of_gemm == 1
inputs = fp8_gemm_pkg.inputs
kernel = fp8_gemm_pkg.kernels[0]
fp8_max = fp8_gemm_pkg.fp8_max
amax = fp8_gemm_pkg.amax
scale = fp8_gemm_pkg.scale
scale_inv = fp8_gemm_pkg.scale_inv
if sharding_type is ShardingType.SINGLE:
res = _fp8_dot(inputs,
kernel,
fp8_max,
amax,
scale,
scale_inv,
amax_history_idx=amax_history_idx,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
kernel_tp_index = None
# TODO (Ming Huang): Should we add a new argument to support general sharding to kernel? # pylint: disable=fixme
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
kernel_tp_index = len(kernel.shape) - 1
elif sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
kernel_tp_index = 0
input_tp_index = len(inputs.shape) - 1
sharding_meta = get_dot_sharding_meta(sharding_type, inputs.shape, kernel.shape,
dp_dim_index, input_tp_index, kernel_tp_index,
contracting_dims, dp_axis_name, tp_axis_name)
inputs_ = jnp.reshape(inputs, sharding_meta.input_shapes[0]) # 0 for input
kernel_ = jnp.reshape(kernel, sharding_meta.input_shapes[1]) # 1 for kernel
num_of_fp8_meta_kind = 4 # fp8_max, amax, scale, scale_inv
fp8_sharding_meta = get_fp8_meta_sharding_meta(sharding_type, num_of_fp8_meta_kind,
dp_axis_name, tp_axis_name)
axis_resources = merge_axis_resources(
[sharding_meta.axis_resources, fp8_sharding_meta.axis_resources])
partial_fp8_dot = partial(_fp8_dot,
amax_history_idx=amax_history_idx,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
res = xmap_runner(partial_fp8_dot, (*sharding_meta.in_axes, *fp8_sharding_meta.in_axes),
sharding_meta.out_axes, axis_resources,
(inputs_, kernel_, fp8_max, amax, scale, scale_inv))
res = jnp.reshape(res, sharding_meta.output_shapes[0])
return res
@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12))
def _fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, amax_history_idx: int, fwd_dtype: TEDType,
bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
sharding_type: ShardingType, dp_axis_name: str, tp_axis_name: str):
res, _ = _fp8_dot_fwd(inputs,
kernel,
fp8_maxs,
amax,
scale,
scale_inv,
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims=contracting_dims,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name)
return res
def _fp8_dot_fwd(
inputs,
kernel,
fp8_maxs,
amax,
scale,
scale_inv,
amax_history_idx, # pylint: disable=unused-argument
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
input_shape_suf = inputs.shape[min(lhs_contracting_dims):]
kernel_shape_pre = kernel.shape[:max(rhs_contracting_dims) + 1]
kernel_shape_suf = kernel.shape[max(rhs_contracting_dims) + 1:]
input_contracting_size = reduce(operator.mul, input_shape_suf)
kernel_contracting_size = reduce(operator.mul, kernel_shape_pre)
assert input_contracting_size == kernel_contracting_size
inputs_ = jnp.reshape(inputs, (-1, input_contracting_size))
kernel_ = jnp.reshape(kernel, (kernel_contracting_size, -1))
gemm_input_idx, gemm_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0)
input_amax = amax[gemm_input_idx]
input_scale = scale[gemm_input_idx]
input_scale_inv = scale_inv[gemm_input_idx]
input_cast, input_cast_trans, input_amax = cast_transpose(inputs_, input_amax, input_scale,
input_scale_inv, fwd_dtype)
kernel_amax = amax[gemm_kernel_idx]
kernel_scale = scale[gemm_kernel_idx]
kernel_scale_inv = scale_inv[gemm_kernel_idx]
kernel_cast, kernel_cast_trans, kernel_amax = cast_transpose(kernel_, kernel_amax, kernel_scale,
kernel_scale_inv, fwd_dtype)
res = gemm(kernel_cast_trans, kernel_scale_inv, fwd_dtype, True, input_cast, input_scale_inv,
fwd_dtype, False, jax_dtype_to_te_dtype(inputs.dtype), FP8Helper.FP8_2X_ACC_FPROP)
if sharding_type in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW):
res = jax.lax.psum(res, tp_axis_name)
# (input_shape_pre, input_shape_suf)
# x (kernel_shape_pre, kernel_shape_suf)
# = (input_shape_pre, kernel_shape_suf)
output_shape = input_shape_pre + kernel_shape_suf
res = jnp.reshape(res, output_shape)
ctx = (input_cast_trans, kernel_cast, fp8_maxs, amax, scale, scale_inv, input_amax, kernel_amax,
inputs.shape, kernel.shape)
return res, ctx
def _fp8_dot_bwd(
amax_history_idx,
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
sharding_type,
dp_axis_name,
tp_axis_name,
ctx,
g):
input_cast_trans, kernel_cast, \
fp8_maxs, amax, scale, scale_inv, \
input_amax, kernel_amax, \
inputs_shape, kernel_shape = ctx
gemm_input_idx, gemm_kernel_idx, gemm_grad_idx = FP8Helper.get_fp8_meta_indices(0)
grad_amax = amax[gemm_grad_idx]
grad_scale = scale[gemm_grad_idx]
grad_scale_inv = scale_inv[gemm_grad_idx]
g = jnp.reshape(g, (input_cast_trans.shape[1], -1))
grad_cast, grad_cast_trans, grad_amax = cast_transpose(g, grad_amax, grad_scale, grad_scale_inv,
bwd_dtype)
input_scale_inv = scale_inv[gemm_input_idx]
wgrad = gemm(grad_cast_trans, grad_scale_inv, bwd_dtype,
True, input_cast_trans, input_scale_inv, fwd_dtype, False,
jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_WGRAD)
kernel_scale_inv = scale_inv[gemm_kernel_idx]
dgrad = gemm(kernel_cast, kernel_scale_inv, fwd_dtype, True, grad_cast, grad_scale_inv,
bwd_dtype, False, jax_dtype_to_te_dtype(g.dtype), FP8Helper.FP8_2X_ACC_DGRAD)
amax = amax.at[gemm_input_idx, amax_history_idx].set(input_amax[0])
amax = amax.at[gemm_kernel_idx, amax_history_idx].set(kernel_amax[0])
amax = amax.at[gemm_grad_idx, amax_history_idx].set(grad_amax[0])
if is_dp_enabled(sharding_type.value[0]):
wgrad = jax.lax.psum(wgrad, dp_axis_name)
amax = jax.lax.pmax(amax, dp_axis_name)
if is_tp_enabled(sharding_type.value[0]):
amax = jax.lax.pmax(amax, tp_axis_name)
if sharding_type in (ShardingType.TP_COL, ShardingType.DP_TP_COL):
dgrad = jax.lax.psum(dgrad, tp_axis_name)
dgrad = jnp.reshape(dgrad, inputs_shape)
wgrad = jnp.reshape(wgrad, kernel_shape)
return dgrad, wgrad, fp8_maxs, amax, scale, scale_inv
_fp8_dot.defvjp(_fp8_dot_fwd, _fp8_dot_bwd)
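The `defvjp` wiring above follows the standard `jax.custom_vjp` contract: the forward function returns the primal output plus a residual `ctx`, and the backward function maps `(ctx, g)` to one gradient per differentiable input. A hedged NumPy sketch of that contract for a plain matmul, with the FP8 casting, amax bookkeeping, and sharding collectives omitted:

```python
import numpy as np

def dot_fwd(x, w):
    res = x @ w
    ctx = (x, w)      # residuals saved for the backward pass
    return res, ctx

def dot_bwd(ctx, g):
    x, w = ctx
    dgrad = g @ w.T   # gradient w.r.t. the input (cf. dgrad above)
    wgrad = x.T @ g   # gradient w.r.t. the kernel (cf. wgrad above)
    return dgrad, wgrad

x = np.ones((2, 3))
w = np.ones((3, 4))
y, ctx = dot_fwd(x, w)
dx, dw = dot_bwd(ctx, np.ones_like(y))
```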