Unverified Commit 9416519d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Apply formatting (#929)



* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
...@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8): ...@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8):
with te.fp8_autocast(enabled=fp8, calibrating=True): with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data) output = model(data)
def test(model, device, test_loader, use_fp8): def test(model, device, test_loader, use_fp8):
"""Testing function.""" """Testing function."""
model.eval() model.eval()
...@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8): ...@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8):
data, target = data.to(device), target.to(device) data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8): with te.fp8_autocast(enabled=use_fp8):
output = model(data) output = model(data)
test_loss += F.nll_loss( test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
output, target, reduction="sum" pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item() correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset) test_loss /= len(test_loader.dataset)
...@@ -150,9 +147,7 @@ def main(): ...@@ -150,9 +147,7 @@ def main():
default=False, default=False,
help="quickly check a single pass", help="quickly check a single pass",
) )
parser.add_argument( parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument( parser.add_argument(
"--log-interval", "--log-interval",
type=int, type=int,
...@@ -167,7 +162,10 @@ def main(): ...@@ -167,7 +162,10 @@ def main():
help="For Saving the current Model", help="For Saving the current Model",
) )
parser.add_argument( parser.add_argument(
"--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration" "--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
) )
parser.add_argument( parser.add_argument(
"--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only" "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
...@@ -215,7 +213,7 @@ def main(): ...@@ -215,7 +213,7 @@ def main():
if args.save_model or args.use_fp8_infer: if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt") torch.save(model.state_dict(), "mnist_cnn.pt")
print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer)) print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt") weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights) model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer) test(model, device, test_loader, args.use_fp8_infer)
......
...@@ -18,21 +18,26 @@ path = sys.argv[1] ...@@ -18,21 +18,26 @@ path = sys.argv[1]
config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json" config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"
class bcolors: class bcolors:
OKGREEN = '\033[92m' OKGREEN = "\033[92m"
WARNING = '\033[93m' WARNING = "\033[93m"
FAIL = '\033[91m' FAIL = "\033[91m"
ENDC = '\033[0m' ENDC = "\033[0m"
def print_ok(msg): def print_ok(msg):
print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}") print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}")
def print_fail(msg): def print_fail(msg):
print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}") print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}")
def print_warn(msg): def print_warn(msg):
print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}") print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}")
with open(config_path, "r") as f: with open(config_path, "r") as f:
c = json.load(f) c = json.load(f)
current_year = datetime.date.today().year current_year = datetime.date.today().year
...@@ -41,7 +46,7 @@ with open(config_path, "r") as f: ...@@ -41,7 +46,7 @@ with open(config_path, "r") as f:
else: else:
year_string = str(c["initial_year"]) + "-" + str(current_year) year_string = str(c["initial_year"]) + "-" + str(current_year)
copyright_string = c["copyright"].replace("<YEAR>", year_string) copyright_string = c["copyright"].replace("<YEAR>", year_string)
license = c["license"].split('\n') license = c["license"].split("\n")
excludes = c["exclude"] excludes = c["exclude"]
root_path = os.path.abspath(path) root_path = os.path.abspath(path)
copyright_only = c["copyright_only"] copyright_only = c["copyright_only"]
...@@ -49,21 +54,25 @@ with open(config_path, "r") as f: ...@@ -49,21 +54,25 @@ with open(config_path, "r") as f:
has_gitignore = os.path.exists(root_path + "/.gitignore") has_gitignore = os.path.exists(root_path + "/.gitignore")
def strip_star_slash(s): def strip_star_slash(s):
ret = s ret = s
if ret.startswith('*'): if ret.startswith("*"):
ret = ret[1:] ret = ret[1:]
if ret.endswith('/'): if ret.endswith("/"):
ret = ret[:-1] ret = ret[:-1]
return ret return ret
if has_gitignore: if has_gitignore:
with open(root_path + "/.gitignore", "r") as f: with open(root_path + "/.gitignore", "r") as f:
for line in f.readlines(): for line in f.readlines():
excludes.append(strip_star_slash(line.strip())) excludes.append(strip_star_slash(line.strip()))
def get_file_type(path): def get_file_type(path):
ext = {"c": ["c", "cpp", "cu", "h", "cuh"], ext = {
"c": ["c", "cpp", "cu", "h", "cuh"],
"py": ["py"], "py": ["py"],
"rst": ["rst"], "rst": ["rst"],
"txt": ["txt"], "txt": ["txt"],
...@@ -77,8 +86,10 @@ def get_file_type(path): ...@@ -77,8 +86,10 @@ def get_file_type(path):
return filetype return filetype
return "unknown" return "unknown"
success = True success = True
def check_file(path): def check_file(path):
global success global success
N = 10 N = 10
...@@ -127,9 +138,10 @@ def check_file(path): ...@@ -127,9 +138,10 @@ def check_file(path):
if copyright_found and license_found: if copyright_found and license_found:
print_ok("OK") print_ok("OK")
for root, dirs, files in os.walk(root_path): for root, dirs, files in os.walk(root_path):
print(f"Entering {root}") print(f"Entering {root}")
hidden = [d for d in dirs if d.startswith('.')] + [f for f in files if f.startswith('.')] hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
all_excludes = excludes + hidden all_excludes = excludes + hidden
to_remove = [] to_remove = []
for d in dirs: for d in dirs:
......
...@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve() ...@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve()
from setuptools.command.build_ext import build_ext as BuildExtension from setuptools.command.build_ext import build_ext as BuildExtension
if "pytorch" in frameworks: if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks: elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks: elif "jax" in frameworks:
install_and_import('pybind11') install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
...@@ -86,34 +87,45 @@ if __name__ == "__main__": ...@@ -86,34 +87,45 @@ if __name__ == "__main__":
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks: if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension from build_tools.pytorch import setup_pytorch_extension
ext_modules.append( ext_modules.append(
setup_pytorch_extension( setup_pytorch_extension(
"transformer_engine/pytorch/csrc", "transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc", current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine")) current_file_path / "transformer_engine",
)
)
if "jax" in frameworks: if "jax" in frameworks:
from build_tools.jax import setup_jax_extension from build_tools.jax import setup_jax_extension
ext_modules.append( ext_modules.append(
setup_jax_extension( setup_jax_extension(
"transformer_engine/jax/csrc", "transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc", current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine")) current_file_path / "transformer_engine",
)
)
if "paddle" in frameworks: if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension from build_tools.paddle import setup_paddle_extension
ext_modules.append( ext_modules.append(
setup_paddle_extension( setup_paddle_extension(
"transformer_engine/paddle/csrc", "transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc", current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine")) current_file_path / "transformer_engine",
)
)
# Configure package # Configure package
setuptools.setup( setuptools.setup(
name="transformer_engine", name="transformer_engine",
version=__version__, version=__version__,
packages=setuptools.find_packages( packages=setuptools.find_packages(
include=["transformer_engine", include=[
"transformer_engine",
"transformer_engine.*", "transformer_engine.*",
"transformer_engine/build_tools"], "transformer_engine/build_tools",
],
), ),
extras_require={ extras_require={
"test": test_requires, "test": test_requires,
...@@ -125,5 +137,5 @@ if __name__ == "__main__": ...@@ -125,5 +137,5 @@ if __name__ == "__main__":
install_requires=install_requires, install_requires=install_requires,
license_files=("LICENSE",), license_files=("LICENSE",),
include_package_data=True, include_package_data=True,
package_data={"": ["VERSION.txt"]} package_data={"": ["VERSION.txt"]},
) )
...@@ -6,7 +6,7 @@ import jax ...@@ -6,7 +6,7 @@ import jax
import pytest import pytest
@pytest.fixture(autouse=True, scope='function') @pytest.fixture(autouse=True, scope="function")
def clear_live_arrays(): def clear_live_arrays():
""" """
Clear all live arrays to keep the resource clean Clear all live arrays to keep the resource clean
......
...@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough ...@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough
def generate_configs(): def generate_configs():
configs = [] configs = []
if is_devices_enough(2): if is_devices_enough(2):
configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')]) configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')]) configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
if is_devices_enough(4): if is_devices_enough(4):
TP_size = 2 TP_size = 2
DP_size = 2 DP_size = 2
configs.append( configs.append(
[4, (DP_size, TP_size), ('dp', 'tp'), [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
MeshResource(dp_resource='dp', tp_resource='tp')]) )
return configs return configs
...@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref): ...@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
bytes_count = 0 bytes_count = 0
def get_bytes_per_txt(t): def get_bytes_per_txt(t):
''' """
The pattern of t would be like: The pattern of t would be like:
'f32[]', 'f32[]',
'(f32[1024]{0}', '(f32[1024]{0}',
...@@ -54,22 +54,22 @@ def assert_equal_collectives(target_hlo, coll_count_ref): ...@@ -54,22 +54,22 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
'f8E4M3FN[1024]{0}', 'f8E4M3FN[1024]{0}',
'i32[1024]{0}', 'i32[1024]{0}',
'bf16[1024,1024]{0}' 'bf16[1024,1024]{0}'
''' """
match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t) match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
_, bits_of_type, shape = match.groups() _, bits_of_type, shape = match.groups()
bytes_of_type = int(bits_of_type) // 8 bytes_of_type = int(bits_of_type) // 8
if shape == '': if shape == "":
num_of_elements = 1 num_of_elements = 1
else: else:
num_of_elements = reduce(operator.mul, map(int, shape.split(','))) num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
return bytes_of_type * num_of_elements return bytes_of_type * num_of_elements
# ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...] # ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
if '(' in hlo_text[2]: if "(" in hlo_text[2]:
for txt in hlo_text[2:]: for txt in hlo_text[2:]:
bytes_count += get_bytes_per_txt(txt) bytes_count += get_bytes_per_txt(txt)
if ')' in txt: if ")" in txt:
break break
else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...] else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
bytes_count = get_bytes_per_txt(hlo_text[2]) bytes_count = get_bytes_per_txt(hlo_text[2])
...@@ -91,11 +91,13 @@ def assert_equal_collectives(target_hlo, coll_count_ref): ...@@ -91,11 +91,13 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
return result return result
target_result = count_collectives(target_splitted_hlo) target_result = count_collectives(target_splitted_hlo)
assert target_result == coll_count_ref, \ assert (
f"Expected collective count is {coll_count_ref}, but got {target_result}." target_result == coll_count_ref
), f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(target_func, def compare_ops(
target_func,
ref_func, ref_func,
inputs, inputs,
coll_count_ref, coll_count_ref,
...@@ -105,7 +107,8 @@ def compare_ops(target_func, ...@@ -105,7 +107,8 @@ def compare_ops(target_func,
metric_bwd_dtype=None, metric_bwd_dtype=None,
in_shardings=_UNSPECIFIED, in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED, out_shardings=_UNSPECIFIED,
**kwargs): **kwargs,
):
assert len(inputs) >= 1 assert len(inputs) >= 1
if metric_fwd_dtype is None: if metric_fwd_dtype is None:
......
...@@ -15,24 +15,10 @@ from jax import jit, value_and_grad ...@@ -15,24 +15,10 @@ from jax import jit, value_and_grad
from flax import linen as nn from flax import linen as nn
from utils import assert_allclose from utils import assert_allclose
from transformer_engine.jax.dot import ( from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
type_safe_dot_general, from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
dequantize, from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
quantize from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
)
from transformer_engine.jax.fp8 import (
FP8MetaPackage,
FP8Helper,
is_fp8_available
)
from transformer_engine.jax.layernorm import (
layernorm,
layernorm_fp8_dot
)
from transformer_engine.jax.layernorm_mlp import (
activation_lu,
fused_layernorm_fp8_mlp
)
from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax import cpp_extensions as tex
GEMM_CASES = [ GEMM_CASES = [
...@@ -50,11 +36,11 @@ is_fp8_supported, reason = is_fp8_available() ...@@ -50,11 +36,11 @@ is_fp8_supported, reason = is_fp8_available()
def _convert_to_activation_function(fn_or_string): def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function.""" """Convert a string to an activation function."""
if fn_or_string == 'linear': if fn_or_string == "linear":
return lambda x: x return lambda x: x
if fn_or_string == 'quick_gelu': if fn_or_string == "quick_gelu":
return lambda x: nn.gelu(x, approximate=True) return lambda x: nn.gelu(x, approximate=True)
if fn_or_string == 'squared_relu': if fn_or_string == "squared_relu":
return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)]) return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
if isinstance(fn_or_string, str): if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string) return getattr(nn, fn_or_string)
...@@ -93,7 +79,7 @@ class TestFP8Dot: ...@@ -93,7 +79,7 @@ class TestFP8Dot:
assert_allclose(z, x, dtype=jnp.float8_e4m3fn) assert_allclose(z, x, dtype=jnp.float8_e4m3fn)
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_forward_bf16(self, m, n, k): def test_forward_bf16(self, m, n, k):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
...@@ -106,7 +92,7 @@ class TestFP8Dot: ...@@ -106,7 +92,7 @@ class TestFP8Dot:
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_forward_fp8_randint(self, m, n, k): def test_forward_fp8_randint(self, m, n, k):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
...@@ -135,7 +121,7 @@ class TestFP8Dot: ...@@ -135,7 +121,7 @@ class TestFP8Dot:
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_grad_bf16(self, m, n, k): def test_grad_bf16(self, m, n, k):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
...@@ -161,7 +147,7 @@ class TestFP8Dot: ...@@ -161,7 +147,7 @@ class TestFP8Dot:
assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16) assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_grad_fp8_dot(self, m, n, k): def test_grad_fp8_dot(self, m, n, k):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2) subkeys = jax.random.split(key, 2)
...@@ -192,33 +178,38 @@ class TestFP8Dot: ...@@ -192,33 +178,38 @@ class TestFP8Dot:
ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b) ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
for _ in range(3): for _ in range(3):
primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
scale_list) = value_n_grad_primitive_func(a, b, amax_list, scale_list) value_n_grad_primitive_func(a, b, amax_list, scale_list)
)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 128, 512), @pytest.mark.parametrize(
(16384, 1024, 2816), "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
(16384, 2816, 1024), )
(16384, 1024, 1024)]) @pytest.mark.parametrize(
@pytest.mark.parametrize('activation_type', [('gelu', ), "activation_type",
('gelu', 'linear'), [
('silu', ), ("gelu",),
('silu', 'linear'), ("gelu", "linear"),
('relu',), ("silu",),
('relu', 'linear'), ("silu", "linear"),
('quick_gelu',), ("relu",),
('quick_gelu', 'linear'), ("relu", "linear"),
('squared_relu',), ("quick_gelu",),
('squared_relu', 'linear')]) ("quick_gelu", "linear"),
@pytest.mark.parametrize('use_bias', [True, False]) ("squared_relu",),
def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, activation_type: Sequence[Union[str, ("squared_relu", "linear"),
Callable]], ],
use_bias: bool): )
""" N/a """ @pytest.mark.parametrize("use_bias", [True, False])
def test_grad_fused_layernorm_fp8_mlp(
self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
):
"""N/a"""
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6) subkeys = jax.random.split(key, 6)
...@@ -233,8 +224,9 @@ class TestFP8Dot: ...@@ -233,8 +224,9 @@ class TestFP8Dot:
b1 = None b1 = None
b2 = None b2 = None
def primitive_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, def primitive_func(
scale_list_2): x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
):
# x is input tensor, matrix 2d # x is input tensor, matrix 2d
# y, z are weights, matrix 2d # y, z are weights, matrix 2d
# out = ((x * y) + w) * z + v # out = ((x * y) + w) * z + v
...@@ -255,18 +247,31 @@ class TestFP8Dot: ...@@ -255,18 +247,31 @@ class TestFP8Dot:
scale_list_2[2], scale_list_2[2],
) )
return jnp.mean( return jnp.mean(
fused_layernorm_fp8_mlp(x, fused_layernorm_fp8_mlp(
x,
ln_s, ln_s,
None, [y, z], [w, v], [fp8_meta_pkg_1, fp8_meta_pkg_2], None,
[y, z],
[w, v],
[fp8_meta_pkg_1, fp8_meta_pkg_2],
"rmsnorm", "rmsnorm",
activation_type=activation_type, activation_type=activation_type,
use_bias=use_bias)) use_bias=use_bias,
)
)
def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, def layernorm_fp8_mlp_ref(
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray, x: jnp.ndarray,
amax_list_1: List[jnp.ndarray], amax_list_2: List[jnp.ndarray], ln_scale: jnp.ndarray,
kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray,
bias_1: jnp.ndarray,
bias_2: jnp.ndarray,
amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray],
scale_list_1: List[jnp.ndarray], scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray]) -> jnp.ndarray: scale_list_2: List[jnp.ndarray],
) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32) x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
...@@ -315,11 +320,14 @@ class TestFP8Dot: ...@@ -315,11 +320,14 @@ class TestFP8Dot:
def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2): def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
return jnp.mean( return jnp.mean(
layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, layernorm_fp8_mlp_ref(
scale_list_2)) x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
)
)
value_n_grad_primitive_func = jit( value_n_grad_primitive_func = jit(
value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))) value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
)
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))) value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
_, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta() _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
...@@ -339,40 +347,87 @@ class TestFP8Dot: ...@@ -339,40 +347,87 @@ class TestFP8Dot:
# Convert str to index as str is not a valid type for JAX JIT # Convert str to index as str is not a valid type for JAX JIT
for _ in range(3): for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad, ref_out, (
ref_amax_list_1, ref_amax_list_2, ref_scale_list_1, ref_a_grad,
ref_scale_list_2) = value_n_grad_ref_func(a, s, k1, k2, b1, b2, ref_s_grad,
ref_amax_list_1, ref_amax_list_2, ref_k1_grad,
ref_scale_list_1, ref_scale_list_2) ref_k2_grad,
ref_b1_grad,
ref_b2_grad,
ref_amax_list_1,
ref_amax_list_2,
ref_scale_list_1,
ref_scale_list_2,
) = value_n_grad_ref_func(
a,
s,
k1,
k2,
b1,
b2,
ref_amax_list_1,
ref_amax_list_2,
ref_scale_list_1,
ref_scale_list_2,
)
for _ in range(3): for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, primitive_out, (
primitive_k2_grad, primitive_b1_grad, primitive_b2_grad, primitive_a_grad,
primitive_amax_list_1, primitive_amax_list_2, primitive_scale_list_1, primitive_s_grad,
primitive_scale_list_2) = value_n_grad_primitive_func( primitive_k1_grad,
a, s, k1, k2, b1, b2, primitive_amax_list_1, primitive_amax_list_2, primitive_k2_grad,
primitive_scale_list_1, primitive_scale_list_2) primitive_b1_grad,
primitive_b2_grad,
primitive_amax_list_1,
primitive_amax_list_2,
primitive_scale_list_1,
primitive_scale_list_2,
) = value_n_grad_primitive_func(
a,
s,
k1,
k2,
b1,
b2,
primitive_amax_list_1,
primitive_amax_list_2,
primitive_scale_list_1,
primitive_scale_list_2,
)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32), assert_allclose(
jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32), jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE,
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32), )
assert_allclose(
jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32), jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE,
assert_allclose(jnp.asarray(primitive_s_grad, np.float32), )
assert_allclose(
jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE,
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32), )
assert_allclose(
jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32), jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE,
)
if use_bias: if use_bias:
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32), assert_allclose(
jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32), jnp.asarray(ref_b2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE,
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), )
assert_allclose(
jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32), jnp.asarray(ref_b1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE,
)
@pytest.fixture(name="random_inputs") @pytest.fixture(name="random_inputs")
...@@ -402,17 +457,22 @@ class TestActivationLu: ...@@ -402,17 +457,22 @@ class TestActivationLu:
def primitive_func(self, inputs): def primitive_func(self, inputs):
return jnp.mean(activation_lu(inputs, activation_type=self.activation_type)) return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))
@pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)]) @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',), @pytest.mark.parametrize(
('gelu', 'linear'), "activation_type",
('silu',), [
('silu', 'linear'), ("gelu",),
('relu',), ("gelu", "linear"),
('relu', 'linear'), ("silu",),
('quick_gelu',), ("silu", "linear"),
('quick_gelu', 'linear'), ("relu",),
('squared_relu',), ("relu", "linear"),
('squared_relu', 'linear') ]) ("quick_gelu",),
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
],
)
def test_activation_lu(self, random_inputs, activation_type): def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1) x = jnp.repeat(x, len(activation_type), axis=1)
...@@ -441,23 +501,34 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -441,23 +501,34 @@ class TestActivationLuFP8(TestActivationLu):
return output return output
def _prim_func_fwd(x, _x_t, _dbias, _amax): def _prim_func_fwd(x, _x_t, _dbias, _amax):
activation_lu_out, _ = tex.act_lu_fp8(x, amax, scale, scale_inv, activation_lu_out, _ = tex.act_lu_fp8(
FP8Helper.FWD_DTYPE, activation_type) x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
)
activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv) activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
ctx = (x) ctx = x
return activation_lu_out, ctx return activation_lu_out, ctx
def _prim_func_bwd(ctx, g): def _prim_func_bwd(ctx, g):
x = ctx x = ctx
if len(self.activation_type) > 1: #gated, no bias if len(self.activation_type) > 1: # gated, no bias
dactivation_lu, dactivation_lu_trans, amax_out = \ dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
tex.dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv, g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
FP8Helper.BWD_DTYPE, -1, activation_type) )
dbias = jnp.empty(x.shape[-1], x.dtype) dbias = jnp.empty(x.shape[-1], x.dtype)
else: #not gated, with bias else: # not gated, with bias
dactivation_lu, dactivation_lu_trans, dbias, amax_out = \ dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
tex.dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, tex.dact_lu_dbias_cast_transpose(
-1, -2, self.activation_type) g,
x,
amax,
scale,
scale_inv,
FP8Helper.BWD_DTYPE,
-1,
-2,
self.activation_type,
)
)
dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv) dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv) dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out) ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
...@@ -468,23 +539,28 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -468,23 +539,28 @@ class TestActivationLuFP8(TestActivationLu):
dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype) dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype) dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
amax_no_use = jnp.zeros(1, jnp.float32) amax_no_use = jnp.zeros(1, jnp.float32)
value_n_grad_primitive_func = value_and_grad(lambda a, b, c, d: value_n_grad_primitive_func = value_and_grad(
jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)) lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
)
return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use) return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)]) @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',), @pytest.mark.parametrize(
('gelu', 'linear'), "activation_type",
('silu',), [
('silu', 'linear'), ("gelu",),
('relu',), ("gelu", "linear"),
('relu', 'linear'), ("silu",),
('quick_gelu',), ("silu", "linear"),
('quick_gelu', 'linear'), ("relu",),
('squared_relu',), ("relu", "linear"),
('squared_relu', 'linear') ]) ("quick_gelu",),
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
],
)
def test_activation_lu(self, random_inputs, activation_type): def test_activation_lu(self, random_inputs, activation_type):
self.amax = jnp.zeros(1, jnp.float32) self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32) self.scale = jnp.ones(1, jnp.float32)
...@@ -500,12 +576,14 @@ class TestActivationLuFP8(TestActivationLu): ...@@ -500,12 +576,14 @@ class TestActivationLuFP8(TestActivationLu):
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2) assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
if 'linear' not in activation_type: if "linear" not in activation_type:
assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1)))) assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans, assert_allclose(
prim_grad_trans,
jnp.transpose(ref_grad, self.transpose_indices), jnp.transpose(ref_grad, self.transpose_indices),
dtype=FP8Helper.BWD_DTYPE) dtype=FP8Helper.BWD_DTYPE,
)
class TestNorm: class TestNorm:
...@@ -536,34 +614,38 @@ class TestNorm: ...@@ -536,34 +614,38 @@ class TestNorm:
""" """
x_ = jnp.asarray(x, jnp.float32) x_ = jnp.asarray(x, jnp.float32)
if bias is None: if bias is None:
mean = 0. mean = 0.0
else: else:
mean = jnp.mean(x_, axis=-1, keepdims=True) mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps) normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma: if zero_centered_gamma:
scale += 1. scale += 1.0
if bias is None: if bias is None:
bias = 0. bias = 0.0
return jnp.asarray(normed_input * scale + bias).astype(x.dtype) return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
@pytest.mark.parametrize('n, hidden', LN_CASES) @pytest.mark.parametrize("n, hidden", LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm']) @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize('zero_centered_gamma', [False, True]) @pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize('epsilon', [1e-2, 1e-6]) @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon, def test_layernorm_forward_backward(
dtype): self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
):
""" """
Test transformer_engine.jax.layernorm.layernorm Test transformer_engine.jax.layernorm.layernorm
""" """
expect_assert = False expect_assert = False
if ln_type == 'rmsnorm' and zero_centered_gamma: if ln_type == "rmsnorm" and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion. # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True expect_assert = True
with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*" with (
) if expect_assert else nullcontext(): pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
if expect_assert
else nullcontext()
):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3) subkeys = jax.random.split(key, 3)
...@@ -571,7 +653,7 @@ class TestNorm: ...@@ -571,7 +653,7 @@ class TestNorm:
gamma_range = (-1, 1) if zero_centered_gamma else (0, 2) gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range) gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
gamma = jnp.asarray(gamma, dtype) gamma = jnp.asarray(gamma, dtype)
if ln_type == 'layernorm': if ln_type == "layernorm":
beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1) beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
beta = jnp.asarray(beta, dtype) beta = jnp.asarray(beta, dtype)
else: else:
...@@ -585,19 +667,27 @@ class TestNorm: ...@@ -585,19 +667,27 @@ class TestNorm:
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad( value_and_grad(
lambda x, gamma, beta: compute_loss( lambda x, gamma, beta: compute_loss(
layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)), layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
(0, 1, 2))) ),
(0, 1, 2),
)
)
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda x, gamma, beta: compute_loss( lambda x, gamma, beta: compute_loss(
self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)), self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
(0, 1, 2))) ),
(0, 1, 2),
)
)
primitive_out, (primitive_dx, primitive_dgamma, primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
primitive_dbeta) = jitted_primitive(x, gamma, beta) x, gamma, beta
reference_out, (reference_dx, reference_dgamma, )
reference_dbeta) = jitted_reference(x, gamma, beta) reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
x, gamma, beta
)
assert_allclose(primitive_out, reference_out, dtype=dtype) assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype) assert_allclose(primitive_dx, reference_dx, dtype=dtype)
...@@ -606,21 +696,24 @@ class TestNorm: ...@@ -606,21 +696,24 @@ class TestNorm:
assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype) assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES) @pytest.mark.parametrize("m,n,k", GEMM_CASES)
@pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm']) @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize('zero_centered_gamma', [True, False]) @pytest.mark.parametrize("zero_centered_gamma", [True, False])
@pytest.mark.parametrize('epsilon', [1e-2, 1e-6]) @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon): def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
""" """
Test transformer_engine.jax.layernorm.layernorm_fp8_dot Test transformer_engine.jax.layernorm.layernorm_fp8_dot
""" """
expect_assert = False expect_assert = False
if ln_type == 'rmsnorm' and zero_centered_gamma: if ln_type == "rmsnorm" and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion. # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True expect_assert = True
with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*" with (
) if expect_assert else nullcontext(): pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
if expect_assert
else nullcontext()
):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4) subkeys = jax.random.split(key, 4)
...@@ -628,7 +721,7 @@ class TestNorm: ...@@ -628,7 +721,7 @@ class TestNorm:
b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
if ln_type == 'layernorm': if ln_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16) beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
else: else:
beta = None beta = None
...@@ -644,8 +737,9 @@ class TestNorm: ...@@ -644,8 +737,9 @@ class TestNorm:
amax_list_1[2], amax_list_1[2],
scale_list_1[2], scale_list_1[2],
) )
primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type, primitive_out = layernorm_fp8_dot(
zero_centered_gamma) x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
)
return jnp.mean(primitive_out) return jnp.mean(primitive_out)
def ref_func(x, y, gamma, beta, zero_centered_gamma): def ref_func(x, y, gamma, beta, zero_centered_gamma):
...@@ -655,14 +749,19 @@ class TestNorm: ...@@ -655,14 +749,19 @@ class TestNorm:
value_n_grad_primitive_func = value_and_grad(primitive_func, range(6)) value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma) value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
)
for _ in range(3): for _ in range(3):
primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad, primitive_out, (
primitive_beta_grad, amax_list_1, primitive_a_grad,
scale_list_1) = value_n_grad_primitive_func( primitive_b_grad,
a, b, gamma, beta, amax_list_1, scale_list_1) primitive_gamma_grad,
primitive_beta_grad,
amax_list_1,
scale_list_1,
) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE) assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
......
...@@ -9,16 +9,8 @@ import jax.numpy as jnp ...@@ -9,16 +9,8 @@ import jax.numpy as jnp
import numpy as np import numpy as np
from flax.linen import dot_product_attention from flax.linen import dot_product_attention
from jax import random from jax import random
from jax.sharding import ( from jax.sharding import Mesh, NamedSharding, PartitionSpec
Mesh, from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
NamedSharding,
PartitionSpec
)
from distributed_test_base import (
generate_configs,
generate_collectives_count,
compare_ops
)
from utils import make_causal_mask, make_self_mask from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import ( from transformer_engine.jax.attention import (
...@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import ( ...@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import (
fused_attn_kvpacked, fused_attn_kvpacked,
AttnBiasType, AttnBiasType,
AttnMaskType, AttnMaskType,
QKVLayout QKVLayout,
) )
...@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16] ...@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSelfAttn: class TestDistributedSelfAttn:
def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, def generate_collectives_count_ref(
dtype): self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype) jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape _, seqlen, _, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None is_dp_enabled = mesh_resource.dp_resource is not None
...@@ -57,8 +50,11 @@ class TestDistributedSelfAttn: ...@@ -57,8 +50,11 @@ class TestDistributedSelfAttn:
qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype) qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \ bias = (
if with_bias else None random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
if with_bias
else None
)
mask = None mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK: if attn_mask_type == AttnMaskType.PADDING_MASK:
...@@ -66,39 +62,66 @@ class TestDistributedSelfAttn: ...@@ -66,39 +62,66 @@ class TestDistributedSelfAttn:
elif attn_mask_type == AttnMaskType.CAUSAL_MASK: elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen) mask = make_self_mask(batch, seqlen)
qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, qkv_pspec = PartitionSpec(
None) mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \ )
if with_bias else None bias_pspec = (
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \ PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
if attn_mask_type != AttnMaskType.NO_MASK else None )
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize(
"attn_bias_type",
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'attn_bias_type', "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS]) )
@pytest.mark.parametrize('attn_mask_type', @pytest.mark.parametrize("dtype", DTYPES)
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) def test_self_attn(
@pytest.mark.parametrize('dtype', DTYPES) self,
def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, device_count,
attn_bias_type, attn_mask_type, dtype): mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
attn_mask_type,
dtype,
):
dropout_prob = 0.0 dropout_prob = 0.0
is_training = True is_training = True
scaling_factor = 1.0 scaling_factor = 1.0
_, seqlen, _, num_head, hidden = data_shape _, seqlen, _, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type, if not is_fused_attn_kernel_available(
attn_mask_type, dropout_prob, num_head, num_head, dtype,
seqlen, seqlen, hidden): dtype,
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found") pytest.skip(f"No FusedAttn backwend found")
def target_func(qkv, bias, mask): def target_func(qkv, bias, mask):
return jnp.mean( return jnp.mean(
fused_attn_qkvpacked(qkv, fused_attn_qkvpacked(
qkv,
bias, bias,
mask, mask,
None, None,
...@@ -106,7 +129,9 @@ class TestDistributedSelfAttn: ...@@ -106,7 +129,9 @@ class TestDistributedSelfAttn:
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_prob, dropout_probability=dropout_prob,
is_training=is_training)) is_training=is_training,
)
)
def ref_func(qkv, bias, mask): def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3) query, key, value = jnp.split(qkv, [1, 2], axis=-3)
...@@ -114,7 +139,8 @@ class TestDistributedSelfAttn: ...@@ -114,7 +139,8 @@ class TestDistributedSelfAttn:
key = jnp.squeeze(key) key = jnp.squeeze(key)
value = jnp.squeeze(value) value = jnp.squeeze(value)
output = dot_product_attention(query, output = dot_product_attention(
query,
key, key,
value, value,
bias=bias, bias=bias,
...@@ -122,37 +148,43 @@ class TestDistributedSelfAttn: ...@@ -122,37 +148,43 @@ class TestDistributedSelfAttn:
deterministic=is_training, deterministic=is_training,
dropout_rate=dropout_prob, dropout_rate=dropout_prob,
dropout_rng=None, dropout_rng=None,
dtype=jnp.float32) dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype) return jnp.mean(output).astype(dtype)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS with_bias = attn_bias_type != AttnBiasType.NO_BIAS
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \ (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, with_bias, data_shape, mesh_resource, with_bias, attn_mask_type, dtype
attn_mask_type, dtype) )
collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes, collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, with_bias, mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
data_shape, dtype) )
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec)) qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \ bias_ = (
if bias is not None else bias jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \ )
if mask is not None else mask mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
grad_args = (0, 1) if with_bias else (0,) grad_args = (0, 1) if with_bias else (0,)
out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,) out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)
compare_ops(target_func, compare_ops(
ref_func, [qkv_, bias_, mask_], target_func,
ref_func,
[qkv_, bias_, mask_],
collective_count_ref, collective_count_ref,
grad_args=grad_args, grad_args=grad_args,
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec), in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings)) out_shardings=(None, out_grad_shardings),
)
class TestDistributedCrossAttn: class TestDistributedCrossAttn:
...@@ -176,20 +208,26 @@ class TestDistributedCrossAttn: ...@@ -176,20 +208,26 @@ class TestDistributedCrossAttn:
q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None) q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, kv_pspec = PartitionSpec(
None) mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \ )
if attn_mask_type != AttnMaskType.NO_MASK else None mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]]) @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize('attn_mask_type', @pytest.mark.parametrize(
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]) "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
@pytest.mark.parametrize('dtype', DTYPES) )
def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, @pytest.mark.parametrize("dtype", DTYPES)
attn_mask_type, dtype): def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
):
attn_bias_type = AttnBiasType.NO_BIAS attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0 dropout_prob = 0.0
is_training = True is_training = True
...@@ -197,14 +235,25 @@ class TestDistributedCrossAttn: ...@@ -197,14 +235,25 @@ class TestDistributedCrossAttn:
_, seqlen, num_head, hidden = data_shape _, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type, if not is_fused_attn_kernel_available(
attn_mask_type, dropout_prob, num_head, num_head, dtype,
seqlen, seqlen, hidden): dtype,
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found") pytest.skip(f"No FusedAttn backwend found")
def target_func(q, kv, mask): def target_func(q, kv, mask):
return jnp.mean( return jnp.mean(
fused_attn_kvpacked(q, fused_attn_kvpacked(
q,
kv, kv,
None, None,
mask, mask,
...@@ -213,7 +262,9 @@ class TestDistributedCrossAttn: ...@@ -213,7 +262,9 @@ class TestDistributedCrossAttn:
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
dropout_probability=dropout_prob, dropout_probability=dropout_prob,
is_training=is_training)) is_training=is_training,
)
)
def ref_func(query, kv, mask): def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3) key, value = jnp.split(kv, [1], axis=-3)
...@@ -221,7 +272,8 @@ class TestDistributedCrossAttn: ...@@ -221,7 +272,8 @@ class TestDistributedCrossAttn:
key = jnp.squeeze(key) key = jnp.squeeze(key)
value = jnp.squeeze(value) value = jnp.squeeze(value)
output = dot_product_attention(query, output = dot_product_attention(
query,
key, key,
value, value,
bias=None, bias=None,
...@@ -229,26 +281,32 @@ class TestDistributedCrossAttn: ...@@ -229,26 +281,32 @@ class TestDistributedCrossAttn:
deterministic=is_training, deterministic=is_training,
dropout_rate=dropout_prob, dropout_rate=dropout_prob,
dropout_rng=None, dropout_rng=None,
dtype=jnp.float32) dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype) return jnp.mean(output).astype(dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \ (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype) data_shape, mesh_resource, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref() collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
q_ = jax.device_put(q, NamedSharding(mesh, q_pspec)) q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec)) kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \ mask_ = (
if mask is not None else mask jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
compare_ops(target_func,
ref_func, [q_, kv_, mask_], compare_ops(
target_func,
ref_func,
[q_, kv_, mask_],
collective_count_ref, collective_count_ref,
grad_args=(0, 1), grad_args=(0, 1),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec), in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec))) out_shardings=(None, (q_pspec, kv_pspec)),
)
...@@ -35,31 +35,44 @@ class TestDistributedLayernorm: ...@@ -35,31 +35,44 @@ class TestDistributedLayernorm:
else: else:
raise NotImplementedError raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None) g_pspec = b_pspec = (
PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec) return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype): def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype) jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None is_dp_enabled = mesh_resource.dp_resource is not None
assert ln_type in ['layernorm', 'rmsnorm'] assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32 all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta # for loss, dgamma and dbeta
weight_count = 2 if ln_type == 'layernorm' else 1 weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize allreduce_total_bytes = (
return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled), all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
allgather=0, )
other=0) return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) )
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('zero_centered_gamma', [False, True]) @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('shard_weights', [False, True]) @pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, @pytest.mark.parametrize("zero_centered_gamma", [False, True])
zero_centered_gamma, shard_weights): @pytest.mark.parametrize("shard_weights", [False, True])
def test_layernorm(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
zero_centered_gamma,
shard_weights,
):
epsilon = 1e-6 epsilon = 1e-6
ln_type = 'layernorm' ln_type = "layernorm"
def target_func(x, gamma, beta): def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)) return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
...@@ -75,10 +88,12 @@ class TestDistributedLayernorm: ...@@ -75,10 +88,12 @@ class TestDistributedLayernorm:
output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype) output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(output) return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \ (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) data_shape, mesh_resource, dtype, shard_weights
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, )
data_shape, dtype) collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
...@@ -88,20 +103,25 @@ class TestDistributedLayernorm: ...@@ -88,20 +103,25 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
compare_ops(target_func, compare_ops(
ref_func, [x_, gamma_, beta_], target_func,
ref_func,
[x_, gamma_, beta_],
collective_count_ref, collective_count_ref,
grad_args=(0, 1, 2), grad_args=(0, 1, 2),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec), in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec))) out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)
except AssertionError as err: except AssertionError as err:
# Layernorm should still produce the correct numerical result with # Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same # gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch # when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here. # and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err): if (
g_pspec[-1] is None and b_pspec[-1] is None
) or "Expected collective count" not in str(err):
raise err raise err
finally: finally:
for w in warns: for w in warns:
...@@ -110,13 +130,15 @@ class TestDistributedLayernorm: ...@@ -110,13 +130,15 @@ class TestDistributedLayernorm:
"unsupported sharding of gamma and/or beta" "unsupported sharding of gamma and/or beta"
) )
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('shard_weights', [False, True]) @pytest.mark.parametrize("shard_weights", [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights): def test_rmsnorm(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
):
epsilon = 1e-6 epsilon = 1e-6
ln_type = 'rmsnorm' ln_type = "rmsnorm"
def target_func(x, gamma): def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon)) return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
...@@ -128,10 +150,12 @@ class TestDistributedLayernorm: ...@@ -128,10 +150,12 @@ class TestDistributedLayernorm:
output = y * gamma output = y * gamma
return jnp.mean(output) return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \ (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) data_shape, mesh_resource, dtype, shard_weights
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, )
data_shape, dtype) collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
...@@ -140,14 +164,17 @@ class TestDistributedLayernorm: ...@@ -140,14 +164,17 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
compare_ops(target_func, compare_ops(
ref_func, [x_, gamma_], target_func,
ref_func,
[x_, gamma_],
collective_count_ref, collective_count_ref,
grad_args=(0, 1), grad_args=(0, 1),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec), in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec))) out_shardings=(None, (x_pspec, g_pspec)),
)
except AssertionError as err: except AssertionError as err:
# RmsNorm should still produce the correct numerical result with # RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same # gamma/beta sharded. However, the collective count may not be the same
......
...@@ -15,18 +15,19 @@ from transformer_engine.jax import fp8_autocast ...@@ -15,18 +15,19 @@ from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import ( from transformer_engine.jax.sharding import (
HIDDEN_AXES, HIDDEN_TP_AXES, HIDDEN_AXES,
HIDDEN_TP_AXES,
BATCH_AXES, BATCH_AXES,
SEQLEN_TP_AXES, SEQLEN_AXES, SEQLEN_TP_AXES,
W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES SEQLEN_AXES,
W_NO_SHARD_AXES,
W_FSDP_AXES,
W_TP_AXES,
W_JOINED_AXES,
) )
from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.sharding import MeshResource
from utils import ( from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
assert_allclose,
assert_tree_like_allclose,
is_devices_enough
)
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
...@@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs(): ...@@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs():
configs = [] configs = []
if is_devices_enough(2): if is_devices_enough(2):
configs.append( configs.append(
[2, (1, 2), ('fsdp', 'tp'), [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
MeshResource(fsdp_resource='fsdp', tp_resource='tp')]) )
if is_devices_enough(4): if is_devices_enough(4):
configs.append( configs.append(
[4, (2, 2), ('fsdp', 'tp'), [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
MeshResource(fsdp_resource='fsdp', tp_resource='tp')]) )
return configs return configs
...@@ -64,10 +65,12 @@ class TestDistributedLayernormMLP: ...@@ -64,10 +65,12 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), k1 = jax.random.normal(
dtype) / jnp.sqrt(hidden_in) subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
k2 = jax.random.normal(subkeys[2], ) / jnp.sqrt(hidden_in)
(INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE) k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias: if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype) b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
...@@ -90,15 +93,27 @@ class TestDistributedLayernormMLP: ...@@ -90,15 +93,27 @@ class TestDistributedLayernormMLP:
scale_list_1: List[jnp.ndarray], scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray], scale_list_2: List[jnp.ndarray],
layernorm_type: str = "rmsnorm", layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ('gelu',), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True, use_bias: bool = True,
multi_gpus: bool = False, multi_gpus: bool = False,
) -> jnp.ndarray: ) -> jnp.ndarray:
fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], amax_list_1[1], fp8_meta_pkg1 = FP8MetaPackage(
scale_list_1[1], amax_list_1[2], scale_list_1[2]) amax_list_1[0],
fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1], scale_list_1[0],
scale_list_2[1], amax_list_2[2], scale_list_2[2]) amax_list_1[1],
scale_list_1[1],
amax_list_1[2],
scale_list_1[2],
)
fp8_meta_pkg2 = FP8MetaPackage(
amax_list_2[0],
scale_list_2[0],
amax_list_2[1],
scale_list_2[1],
amax_list_2[2],
scale_list_2[2],
)
if multi_gpus: if multi_gpus:
layernorm_input_axes = LAYERNORM_INPUT_AXES layernorm_input_axes = LAYERNORM_INPUT_AXES
...@@ -111,60 +126,68 @@ class TestDistributedLayernormMLP: ...@@ -111,60 +126,68 @@ class TestDistributedLayernormMLP:
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2 # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean( return jnp.mean(
fused_layernorm_fp8_mlp(x, fused_layernorm_fp8_mlp(
x,
ln_scale, ln_scale,
None, [kernel_1, kernel_2], [bias_1, bias_2], None,
[kernel_1, kernel_2],
[bias_1, bias_2],
[fp8_meta_pkg1, fp8_meta_pkg2], [fp8_meta_pkg1, fp8_meta_pkg2],
layernorm_type, layernorm_type,
layernorm_input_axes=layernorm_input_axes, layernorm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes, dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes, dot_2_input_axes=dot_2_input_axes,
activation_type=activation_type, activation_type=activation_type,
use_bias=use_bias)) use_bias=use_bias,
)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape, def test_layernorm_fp8_mlp_primitive(
dtype): self, mesh_config, activation_type, use_bias, input_shape, dtype
):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = 'rmsnorm' layernorm_type = "rmsnorm"
fp8_amax_list_1 = [ fp8_amax_list_1 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
] ]
fp8_amax_list_2 = [ fp8_amax_list_2 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
] ]
fp8_scale_list_1 = [ fp8_scale_list_1 = [
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32) jnp.ones((1,), jnp.float32),
] ]
fp8_scale_list_2 = [ fp8_scale_list_2 = [
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32) jnp.ones((1,), jnp.float32),
] ]
inputs = [x, gamma, k1, k2, b1, b2] = \ inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
self.generate_inputs(input_shape, activation_type, use_bias, dtype) input_shape, activation_type, use_bias, dtype
)
inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2] inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
static_inputs = [layernorm_type, activation_type, use_bias] static_inputs = [layernorm_type, activation_type, use_bias]
value_and_grad_func = jax.value_and_grad(self.layernorm_fp8_mlp_prim_func, value_and_grad_func = jax.value_and_grad(
argnums=range(len(inputs))) self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
# Single GPU # Single GPU
single_jitter = jax.jit(value_and_grad_func, single_jitter = jax.jit(
static_argnums=range(len(inputs), value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
len(static_inputs) + len(inputs))) )
with fp8_autocast(enabled=True): with fp8_autocast(enabled=True):
single_fwd, single_grads = single_jitter(*inputs, *static_inputs) single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
...@@ -172,12 +195,12 @@ class TestDistributedLayernormMLP: ...@@ -172,12 +195,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec('fsdp', None, 'tp')) k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec('tp', 'fsdp')) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding) k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding) k2_ = jax.device_put(k2, k2_sharding)
if use_bias: if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, 'tp')) b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding) b1_ = jax.device_put(b1, b1_sharding)
else: else:
b1_sharding = b1_ = None b1_sharding = b1_ = None
...@@ -186,17 +209,29 @@ class TestDistributedLayernormMLP: ...@@ -186,17 +209,29 @@ class TestDistributedLayernormMLP:
# Position ref for sharding pspec lists # Position ref for sharding pspec lists
# x, gamma, k1, k2, b1, # x, gamma, k1, k2, b1,
# b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv # b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, in_shardings = (
None, None) None,
out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None,
None, None, None)) k1_sharding,
k2_sharding,
b1_sharding,
None,
None,
None,
None,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
)
multi_jitter = jax.jit(value_and_grad_func, multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings, in_shardings=in_shardings,
out_shardings=out_shardings, out_shardings=out_shardings,
static_argnums=range(len(multi_inputs), static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
len(static_inputs) + len(multi_inputs) + ) # +1 for multi_gpus
1)) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
...@@ -206,26 +241,28 @@ class TestDistributedLayernormMLP: ...@@ -206,26 +241,28 @@ class TestDistributedLayernormMLP:
if isinstance(multi_grads[i], list): if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list) assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose(m_grad, assert_allclose(
s_grad, m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
dtype=dtype, )
err_msg=f'multi_grads[{i}] is not close')
else: else:
assert_allclose(multi_grads[i], assert_allclose(
multi_grads[i],
single_grads[i], single_grads[i],
dtype=dtype, dtype=dtype,
err_msg=f'multi_grads[{i}] is not close') err_msg=f"multi_grads[{i}] is not close",
)
def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype, def _test_layernorm_mlp(
use_fp8): self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
):
batch, seqlen, hidden_in = input_shape batch, seqlen, hidden_in = input_shape
layernorm_type = 'rmsnorm' layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0) rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2) subkeys = jax.random.split(rng, 2)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {'params': subkeys[1]} init_rngs = {"params": subkeys[1]}
# Single GPUs # Single GPUs
with fp8_autocast(enabled=use_fp8): with fp8_autocast(enabled=use_fp8):
...@@ -238,16 +275,17 @@ class TestDistributedLayernormMLP: ...@@ -238,16 +275,17 @@ class TestDistributedLayernormMLP:
use_bias=use_bias, use_bias=use_bias,
) )
params_single = ln_mlp_single.init(init_rngs, x) params_single = ln_mlp_single.init(init_rngs, x)
mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single, mlp_out_single, ln_out_single = ln_mlp_single.apply(
x, params_single, x, deterministic=True
deterministic=True) )
# Multi GPUs # Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type, ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False, transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
...@@ -262,41 +300,37 @@ class TestDistributedLayernormMLP: ...@@ -262,41 +300,37 @@ class TestDistributedLayernormMLP:
layernorm_input_axes=LAYERNORM_INPUT_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES, dot_2_input_axes=DOT_2_INPUT_AXES,
name='mlp') name="mlp",
)
params_sharded = ln_mlp_sharded.init(init_rngs, x) params_sharded = ln_mlp_sharded.init(init_rngs, x)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded, mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
x, params_sharded, x, deterministic=True
deterministic=True) )
# Make sure params values are the same # Make sure params values are the same
assert_tree_like_allclose(params_sharded['params'], params_single['params']) assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype) assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype): def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
self._test_layernorm_mlp(mesh_config, self._test_layernorm_mlp(
activation_type, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
use_bias, )
input_shape,
dtype,
use_fp8=False)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, def test_layernorm_fp8_mlp_layer(
dtype): self, mesh_config, activation_type, use_bias, input_shape, dtype
self._test_layernorm_mlp(mesh_config, ):
activation_type, self._test_layernorm_mlp(
use_bias, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
input_shape, )
dtype,
use_fp8=True)
...@@ -38,11 +38,13 @@ class TestDistributedSoftmax: ...@@ -38,11 +38,13 @@ class TestDistributedSoftmax:
mask = make_self_mask(batch, sqelen) mask = make_self_mask(batch, sqelen)
if not bad_sharding: if not bad_sharding:
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, x_pspec = PartitionSpec(
None, None) mesh_resource.dp_resource, mesh_resource.tp_resource, None, None
)
else: else:
x_pspec = PartitionSpec(mesh_resource.dp_resource, None, x_pspec = PartitionSpec(
None, mesh_resource.tp_resource) mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec) return (x, mask), (x_pspec, mask_pspec)
...@@ -55,32 +57,46 @@ class TestDistributedSoftmax: ...@@ -55,32 +57,46 @@ class TestDistributedSoftmax:
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16): def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
bias = None bias = None
if mask is not None: if mask is not None:
bias = jax.lax.select(mask > 0, bias = jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype), jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.).astype(dtype)) jnp.full(mask.shape, 0.0).astype(dtype),
)
if bias is not None: if bias is not None:
x = x + bias.astype(dtype) x = x + bias.astype(dtype)
output = jax.nn.softmax(x * scale_factor) output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output) return jnp.mean(output)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]]) @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize( @pytest.mark.parametrize(
'softmax_type', "softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED]) [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
@pytest.mark.parametrize('scale_factor', [1.0, 3.0]) )
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize('bad_sharding', [False, True]) @pytest.mark.parametrize("dtype", DTYPES)
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, @pytest.mark.parametrize("bad_sharding", [False, True])
softmax_type, scale_factor, dtype, bad_sharding): def test_softmax(
self,
target_func = partial(self.target_func, device_count,
scale_factor=scale_factor, mesh_shape,
softmax_type=softmax_type) mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
):
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype) ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = \ (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype, bad_sharding) data_shape, mesh_resource, softmax_type, dtype, bad_sharding
)
collective_count_ref = self.generate_collectives_count_ref() collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
...@@ -90,14 +106,17 @@ class TestDistributedSoftmax: ...@@ -90,14 +106,17 @@ class TestDistributedSoftmax:
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
compare_ops(target_func, compare_ops(
ref_func, [x_, mask_], target_func,
ref_func,
[x_, mask_],
collective_count_ref, collective_count_ref,
grad_args=(0,), grad_args=(0,),
metric_fwd_dtype=dtype, metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype, metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec), in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,))) out_shardings=(None, (x_pspec,)),
)
except AssertionError as err: except AssertionError as err:
# Softmax should still produce the correct numerical result with # Softmax should still produce the correct numerical result with
# bad sharding. However, the collective count may not be the same # bad sharding. However, the collective count may not be the same
......
...@@ -20,12 +20,14 @@ class TestLoRA: ...@@ -20,12 +20,14 @@ class TestLoRA:
out = jnp.einsum(pattern, x, la, lb) out = jnp.einsum(pattern, x, la, lb)
return out * scale return out * scale
@pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)]) @pytest.mark.parametrize("shape", [(32, 1024), (32, 128, 1024)])
@pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
@pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'), @pytest.mark.parametrize(
((-1,), (3, 1024), '...h,hkr,krz->...kz')]) "axis_features_pattern",
@pytest.mark.parametrize('rank', [32, 16]) [((-1,), (1024,), "...h,hr,rk->...k"), ((-1,), (3, 1024), "...h,hkr,krz->...kz")],
@pytest.mark.parametrize('alpha', [None, 4, 8]) )
@pytest.mark.parametrize("rank", [32, 16])
@pytest.mark.parametrize("alpha", [None, 4, 8])
def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha): def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
axis, features, pattern = axis_features_pattern axis, features, pattern = axis_features_pattern
axis = _normalize_axes(axis, len(shape)) axis = _normalize_axes(axis, len(shape))
...@@ -49,16 +51,20 @@ class TestLoRA: ...@@ -49,16 +51,20 @@ class TestLoRA:
assert_allclose(out_target, out_ref, dtype=dtype) assert_allclose(out_target, out_ref, dtype=dtype)
@pytest.mark.parametrize('scope_ref_assert', @pytest.mark.parametrize(
[('none', LoRAScope(False, False, False), False), "scope_ref_assert",
('all', LoRAScope(True, True, True), False), [
('qkv_proj', LoRAScope(True, False, False), False), ("none", LoRAScope(False, False, False), False),
('output_proj', LoRAScope(False, True, False), False), ("all", LoRAScope(True, True, True), False),
('mlp', LoRAScope(False, False, True), False), ("qkv_proj", LoRAScope(True, False, False), False),
('exclude_qkv_proj', LoRAScope(False, True, True), False), ("output_proj", LoRAScope(False, True, False), False),
('exclude_output_proj', LoRAScope(True, False, True), False), ("mlp", LoRAScope(False, False, True), False),
('exclude_mlp', LoRAScope(True, True, False), False), ("exclude_qkv_proj", LoRAScope(False, True, True), False),
('messing_up', LoRAScope(), True)]) ("exclude_output_proj", LoRAScope(True, False, True), False),
("exclude_mlp", LoRAScope(True, True, False), False),
("messing_up", LoRAScope(), True),
],
)
def test_lora_scope_generator(self, scope_ref_assert): def test_lora_scope_generator(self, scope_ref_assert):
scope, reference, need_assert = scope_ref_assert scope, reference, need_assert = scope_ref_assert
try: try:
......
...@@ -24,7 +24,7 @@ from transformer_engine.jax.attention import ( ...@@ -24,7 +24,7 @@ from transformer_engine.jax.attention import (
QKVLayout, QKVLayout,
fused_attn_qkvpacked, fused_attn_qkvpacked,
fused_attn_kvpacked, fused_attn_kvpacked,
fused_attn fused_attn,
) )
from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.jax.cpp_extensions import FusedAttnHelper
...@@ -33,7 +33,7 @@ from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend ...@@ -33,7 +33,7 @@ from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose from utils import assert_allclose
@pytest.fixture(autouse=True, scope='module') @pytest.fixture(autouse=True, scope="module")
def init(): def init():
""" """
WAR for CUDA uninitialize error WAR for CUDA uninitialize error
...@@ -43,10 +43,18 @@ def init(): ...@@ -43,10 +43,18 @@ def init():
yield yield
def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike, def general_dot_product_attention(
bias: ArrayLike, mask: ArrayLike, deterministic: bool, query: ArrayLike,
scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike, key: ArrayLike,
dtype: DTypeLike) -> Array: value: ArrayLike,
bias: ArrayLike,
mask: ArrayLike,
deterministic: bool,
scale_factor: float,
dropout_rate: float,
dropout_rng: ArrayLike,
dtype: DTypeLike,
) -> Array:
""" """
Similar to flax.linen.dot_product_attention but with GQA support Similar to flax.linen.dot_product_attention but with GQA support
""" """
...@@ -59,7 +67,7 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array ...@@ -59,7 +67,7 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
num_groups = h_q // h_kv num_groups = h_q // h_kv
grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d)) grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
# logits with shape (b, h_kv, num_groups, s_q, s_kv) # logits with shape (b, h_kv, num_groups, s_q, s_kv)
logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key) logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
if bias is not None: if bias is not None:
# reshape logits without groups # reshape logits without groups
...@@ -76,13 +84,13 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array ...@@ -76,13 +84,13 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
softmax_out = jax.nn.softmax(logits).astype(dtype) softmax_out = jax.nn.softmax(logits).astype(dtype)
if not deterministic and dropout_rate > 0.: if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape) keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype) multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
softmax_out = softmax_out * multiplier softmax_out = softmax_out * multiplier
context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value) context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
context = jnp.reshape(context, query.shape) context = jnp.reshape(context, query.shape)
return context return context
...@@ -105,6 +113,7 @@ def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array: ...@@ -105,6 +113,7 @@ def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0) inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(inv_causal_mask, inv_padding_mask) return combine_masks(inv_causal_mask, inv_padding_mask)
def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array: def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
""" """
Create attention mask based on mask type. A `True` value in the mask means Create attention mask based on mask type. A `True` value in the mask means
...@@ -118,23 +127,26 @@ def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskT ...@@ -118,23 +127,26 @@ def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskT
mask = jnp.logical_not(inv_mask) mask = jnp.logical_not(inv_mask)
return mask return mask
def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs): def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
""" """
JAX native dot product attention implementation JAX native dot product attention implementation
""" """
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type) mask = make_mask(q_token, kv_token, attn_mask_type)
output = general_dot_product_attention(query, output = general_dot_product_attention(
query,
key, key,
value, value,
bias=bias, bias=bias,
mask=mask, mask=mask,
deterministic=not kwargs['is_training'], deterministic=not kwargs["is_training"],
scale_factor=kwargs['scaling_factor'], scale_factor=kwargs["scaling_factor"],
dropout_rate=kwargs['dropout_probability'], dropout_rate=kwargs["dropout_probability"],
dropout_rng=dropout_rng, dropout_rng=dropout_rng,
dtype=jnp.float32) dtype=jnp.float32,
)
return output.astype(query.dtype) return output.astype(query.dtype)
...@@ -142,10 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng ...@@ -142,10 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
""" """
TE customcall dot product attention implementation TE customcall dot product attention implementation
""" """
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type) mask = make_mask(q_token, kv_token, attn_mask_type)
qkv_layout = kwargs.pop('qkv_layout') qkv_layout = kwargs.pop("qkv_layout")
match qkv_layout: match qkv_layout:
case QKVLayout.BS3HD: case QKVLayout.BS3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
...@@ -154,11 +166,13 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng ...@@ -154,11 +166,13 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
case QKVLayout.BSHD_BS2HD: case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value]) key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3) kv = jnp.concatenate((key, value), axis=-3)
return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, **kwargs).astype(
**kwargs).astype(query.dtype) query.dtype
)
case QKVLayout.BSHD_BSHD_BSHD: case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng, return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(
**kwargs).astype(query.dtype) query.dtype
)
class BiasShape(Enum): class BiasShape(Enum):
...@@ -166,10 +180,10 @@ class BiasShape(Enum): ...@@ -166,10 +180,10 @@ class BiasShape(Enum):
Enum class to represent the different bias shapes used in the fused attention. Enum class to represent the different bias shapes used in the fused attention.
""" """
BIAS_1HSS = '1HSS' BIAS_1HSS = "1HSS"
BIAS_B1SS = 'B1SS' BIAS_B1SS = "B1SS"
BIAS_BHSS = 'BHSS' BIAS_BHSS = "BHSS"
BIAS_11SS = '11SS' BIAS_11SS = "11SS"
@dataclass @dataclass
...@@ -177,6 +191,7 @@ class FusedAttnRunner: ...@@ -177,6 +191,7 @@ class FusedAttnRunner:
""" """
Fused attention runner Fused attention runner
""" """
batch_size: int batch_size: int
max_seqlen_q: int max_seqlen_q: int
max_seqlen_kv: int max_seqlen_kv: int
...@@ -198,21 +213,33 @@ class FusedAttnRunner: ...@@ -198,21 +213,33 @@ class FusedAttnRunner:
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv: if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.") pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value, self.backend = FusedAttnHelper(
self.attn_bias_type.value, self.attn_mask_type.value, self.dtype,
self.dropout_prob, self.num_heads_q, self.num_heads_kv, self.dtype,
self.max_seqlen_q, self.max_seqlen_kv, self.qkv_layout.value,
self.head_dim).get_fused_attn_backend() self.attn_bias_type.value,
self.attn_mask_type.value,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
self.max_seqlen_q,
self.max_seqlen_kv,
self.head_dim,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.") pytest.skip("Unsupported inputs combination or device compute capability.")
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for " pytest.skip(
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.") "B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
)
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for " pytest.skip(
"the F16_arbitrary_seqlen backend.") "B1SS, BHSS and 11SS bias shapes are only supported for "
"the F16_arbitrary_seqlen backend."
)
def _setup_inputs(self): def _setup_inputs(self):
self._check_configs() self._check_configs()
...@@ -235,24 +262,25 @@ class FusedAttnRunner: ...@@ -235,24 +262,25 @@ class FusedAttnRunner:
else: else:
pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.) self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.) self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.) self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
if self.attn_bias_type != AttnBiasType.NO_BIAS: if self.attn_bias_type != AttnBiasType.NO_BIAS:
if self.bias_shape == BiasShape.BIAS_1HSS: if self.bias_shape == BiasShape.BIAS_1HSS:
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.) self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
else: else:
# [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
# an arbitrary mask where (True/False -> 0/-Inf) # an arbitrary mask where (True/False -> 0/-Inf)
cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15. cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype) self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
max_id = min(self.max_seqlen_q, self.max_seqlen_kv) max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist() seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
for i in range(1, len(seq_id)): for i in range(1, len(seq_id)):
self.bias = \ self.bias = self.bias.at[
self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.) :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
].set(0.0)
else: else:
self.bias = None self.bias = None
...@@ -271,7 +299,7 @@ class FusedAttnRunner: ...@@ -271,7 +299,7 @@ class FusedAttnRunner:
self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio) self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1. / sqrt(self.head_dim) self.scaling_factor = 1.0 / sqrt(self.head_dim)
def test_forward(self): def test_forward(self):
""" """
...@@ -281,19 +309,19 @@ class FusedAttnRunner: ...@@ -281,19 +309,19 @@ class FusedAttnRunner:
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng] args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = { kwargs = {
'attn_bias_type': self.attn_bias_type, "attn_bias_type": self.attn_bias_type,
'attn_mask_type': self.attn_mask_type, "attn_mask_type": self.attn_mask_type,
'scaling_factor': self.scaling_factor, "scaling_factor": self.scaling_factor,
'dropout_probability': self.dropout_prob, "dropout_probability": self.dropout_prob,
'is_training': self.is_training, "is_training": self.is_training,
'qkv_layout': self.qkv_layout, "qkv_layout": self.qkv_layout,
} }
# Convert the outputs to float32 for the elementwise comparison # Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32) primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32) reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
if self.is_training and self.dropout_prob > 0.: if self.is_training and self.dropout_prob > 0.0:
return return
primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1) primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
...@@ -322,12 +350,12 @@ class FusedAttnRunner: ...@@ -322,12 +350,12 @@ class FusedAttnRunner:
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng] args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = { kwargs = {
'attn_bias_type': self.attn_bias_type, "attn_bias_type": self.attn_bias_type,
'attn_mask_type': self.attn_mask_type, "attn_mask_type": self.attn_mask_type,
'scaling_factor': self.scaling_factor, "scaling_factor": self.scaling_factor,
'dropout_probability': self.dropout_prob, "dropout_probability": self.dropout_prob,
'is_training': self.is_training, "is_training": self.is_training,
'qkv_layout': self.qkv_layout, "qkv_layout": self.qkv_layout,
} }
# We can compute dBias only for the [1, h, s, s] layout # We can compute dBias only for the [1, h, s, s] layout
...@@ -336,12 +364,18 @@ class FusedAttnRunner: ...@@ -336,12 +364,18 @@ class FusedAttnRunner:
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad( value_and_grad(
lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args, lambda q, k, v, bias, *args: grad_func(
**kwargs), arg_nums)) customcall_fused_dpa, q, k, v, bias, *args, **kwargs
),
arg_nums,
)
)
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs), lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
arg_nums)) arg_nums,
)
)
primitive_out, primitive_dgrad = jitted_primitive(*args) primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args) reference_out, reference_dgrad = jitted_reference(*args)
...@@ -350,9 +384,9 @@ class FusedAttnRunner: ...@@ -350,9 +384,9 @@ class FusedAttnRunner:
if self.dropout_prob > 0.0: if self.dropout_prob > 0.0:
return return
assert_allclose(primitive_out.astype(jnp.float32), assert_allclose(
reference_out.astype(jnp.float32), primitive_out.astype(jnp.float32), reference_out.astype(jnp.float32), dtype=self.dtype
dtype=self.dtype) )
def check_dqkv(primitive, reference, valid_len): def check_dqkv(primitive, reference, valid_len):
primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1) primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
...@@ -374,81 +408,158 @@ class FusedAttnRunner: ...@@ -374,81 +408,158 @@ class FusedAttnRunner:
primitive_dbias = jnp.float32(primitive_dgrad[3]) primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3]) reference_dbias = jnp.float32(reference_dgrad[3])
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], assert_allclose(
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
self.valid_len_kv:]), jnp.zeros_like(primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :]),
dtype=self.dtype) dtype=self.dtype,
)
# dbias padded part # dbias padded part
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], assert_allclose(
reference_dbias[..., self.valid_len_q:, self.valid_len_kv:], primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
dtype=self.dtype) reference_dbias[..., self.valid_len_q :, self.valid_len_kv :],
dtype=self.dtype,
)
# dbias valid part # dbias valid part
assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv], assert_allclose(
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv], primitive_dbias[..., : self.valid_len_q, : self.valid_len_kv],
dtype=self.dtype) reference_dbias[..., : self.valid_len_q, : self.valid_len_kv],
dtype=self.dtype,
)
@pytest.mark.parametrize('attn_bias_type, bias_shape', [
pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'), @pytest.mark.parametrize(
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'), "attn_bias_type, bias_shape",
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'), [
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'), pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
]) pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
@pytest.mark.parametrize('attn_mask_type', [ pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'), pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'), pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'), ],
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'), )
]) @pytest.mark.parametrize(
@pytest.mark.parametrize('qkv_layout', [ "attn_mask_type",
pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'), [
pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'), pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'), pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
]) pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
@pytest.mark.parametrize('dtype', [ pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
],
)
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"), pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"), pytest.param(jnp.float16, id="FP16"),
]) ],
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [ )
pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'), @pytest.mark.parametrize(
pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'), "b, s_q, s_kv, h_q, h_kv, d",
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'), [
pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'), pytest.param(32, 128, 128, 16, 16, 64, id="32-128-128-16-16-64-SELF"),
pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'), pytest.param(4, 2048, 2048, 12, 12, 64, id="4-2048-2048-12-12-64-SELF"),
pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'), pytest.param(32, 512, 128, 16, 16, 64, id="32-512-128-16-16-64-CROSS"),
]) pytest.param(4, 2048, 1024, 12, 12, 64, id="4-2048-1048-12-12-64-CROSS"),
@pytest.mark.parametrize('dropout_prob', [ pytest.param(32, 128, 128, 16, 8, 64, id="32-128-128-16-8-64-GQA"),
pytest.param(4, 2048, 2048, 12, 6, 64, id="4-2048-2048-12-6-64-GQA"),
],
)
@pytest.mark.parametrize(
"dropout_prob",
[
pytest.param(0.0, id="DROP_0.0"), pytest.param(0.0, id="DROP_0.0"),
pytest.param(0.1, id="DROP_0.1"), pytest.param(0.1, id="DROP_0.1"),
]) ],
)
class TestFusedAttn: class TestFusedAttn:
""" """
Fused attention tester Fused attention tester
""" """
@staticmethod @staticmethod
@pytest.mark.parametrize('is_training', [ @pytest.mark.parametrize(
pytest.param(True, id='TRAINING'), "is_training",
pytest.param(False, id='INFERENCE'), [
]) pytest.param(True, id="TRAINING"),
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, pytest.param(False, id="INFERENCE"),
dtype, is_training, qkv_layout, bias_shape): ],
)
def test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
):
""" """
Test forward with parameterized configs Test forward with parameterized configs
""" """
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, runner = FusedAttnRunner(
dropout_prob, dtype, is_training, qkv_layout, bias_shape) b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
)
runner.test_forward() runner.test_forward()
@staticmethod @staticmethod
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, def test_backward(
dtype, qkv_layout, bias_shape): b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
):
""" """
Test backward with parameterized configs Test backward with parameterized configs
""" """
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, runner = FusedAttnRunner(
dropout_prob, dtype, True, qkv_layout, bias_shape) b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
True,
qkv_layout,
bias_shape,
)
runner.test_backward() runner.test_backward()
...@@ -27,21 +27,28 @@ class TestFP8Helper(unittest.TestCase): ...@@ -27,21 +27,28 @@ class TestFP8Helper(unittest.TestCase):
fp8_format = FP8Format.E4M3 fp8_format = FP8Format.E4M3
amax_history_len = 10 amax_history_len = 10
FP8Helper.initialize(margin=margin, FP8Helper.initialize(
fp8_format=fp8_format, margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
amax_history_len=amax_history_len) )
self.assertEqual( self.assertEqual(
FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}" FP8Helper.MARGIN,
f" but got {FP8Helper.MARGIN}.") margin,
f"FP8Helper.MARGIN initialization failed, should be {margin}"
f" but got {FP8Helper.MARGIN}.",
)
self.assertEqual( self.assertEqual(
FP8Helper.FP8_FORMAT, fp8_format, FP8Helper.FP8_FORMAT,
fp8_format,
f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}" f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {FP8Helper.FP8_FORMAT}.") f" but got {FP8Helper.FP8_FORMAT}.",
)
self.assertEqual( self.assertEqual(
FP8Helper.AMAX_HISTORY_LEN, amax_history_len, FP8Helper.AMAX_HISTORY_LEN,
amax_history_len,
f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}" f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {FP8Helper.AMAX_HISTORY_LEN}.") f" but got {FP8Helper.AMAX_HISTORY_LEN}.",
)
FP8Helper.finalize() FP8Helper.finalize()
...@@ -109,14 +116,14 @@ class TestFP8Functions(unittest.TestCase): ...@@ -109,14 +116,14 @@ class TestFP8Functions(unittest.TestCase):
mesh_s = ( mesh_s = (
(MeshResource(None, None)), (MeshResource(None, None)),
(MeshResource('dp', None)), (MeshResource("dp", None)),
(MeshResource(None, 'tp')), (MeshResource(None, "tp")),
(MeshResource('dp', 'tp')), (MeshResource("dp", "tp")),
) )
# TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
mesh_shape = (1, 1) mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
with jax.sharding.Mesh(devices, ('dp', 'tp')): with jax.sharding.Mesh(devices, ("dp", "tp")):
for sr in mesh_s: for sr in mesh_s:
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr): with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
self.assertTrue(FP8Helper.is_fp8_enabled()) self.assertTrue(FP8Helper.is_fp8_enabled())
......
...@@ -22,7 +22,7 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available ...@@ -22,7 +22,7 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='function') @pytest.fixture(autouse=True, scope="function")
def enable_fused_attn(): def enable_fused_attn():
"""Enable fused attention""" """Enable fused attention"""
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
...@@ -31,8 +31,8 @@ def enable_fused_attn(): ...@@ -31,8 +31,8 @@ def enable_fused_attn():
DATA_SHAPE = [ # (batch, seqlen, emb_dim) DATA_SHAPE = [ # (batch, seqlen, emb_dim)
pytest.param((32, 128, 1024), id='32-128-1024'), pytest.param((32, 128, 1024), id="32-128-1024"),
pytest.param((32, 512, 1024), id='32-512-1024'), pytest.param((32, 512, 1024), id="32-512-1024"),
] ]
DTYPE = [jnp.float32, jnp.bfloat16] DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID] FP8_FORMATS = [Format.E4M3, Format.HYBRID]
...@@ -69,123 +69,138 @@ BASE_ATTRS = { ...@@ -69,123 +69,138 @@ BASE_ATTRS = {
_KEY_OF_ATTENTION_DROPOUT: 0, _KEY_OF_ATTENTION_DROPOUT: 0,
_KEY_OF_INTERMEDIATE_DROPOUT: 0, _KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal", _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: 'layernorm', _KEY_OF_LAYERNORM_TYPE: "layernorm",
} }
ATTRS = [{}, { ATTRS = [
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {},
}, { {
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
},
{
_KEY_OF_ZERO_CENTERED_GAMMA: True, _KEY_OF_ZERO_CENTERED_GAMMA: True,
_KEY_OF_LAYERNORM_EPS: 1e-2, _KEY_OF_LAYERNORM_EPS: 1e-2,
}, { },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
_KEY_OF_RESIDUAL_POST_LAYERNORM: True {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
}, { {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_OUTPUT_LAYERNORM: True
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_RESIDUAL_POST_LAYERNORM: True, _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
_KEY_OF_OUTPUT_LAYERNORM: True _KEY_OF_OUTPUT_LAYERNORM: True,
}, { },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
_KEY_OF_DROP_PATH: 0.1 {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
}, { {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_FUSE_QKV_PARAMS: False _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
}, { },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_HIDDEN_DROPOUT: 0.8, _KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5, _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
{
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
}, { },
{
_KEY_OF_NUM_HEADS: 8, _KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4, _KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_MLP_ACTIVATIONS: ('gelu',), _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
}, { _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
},
{
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_HIDDEN_DROPOUT: 0.8, _KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5, _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
{
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
}, { },
{
_KEY_OF_NUM_HEADS: 8, _KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4, _KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm', _KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_MLP_ACTIVATIONS: (('silu',)), _KEY_OF_MLP_ACTIVATIONS: (("silu",)),
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
{
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_NUM_GQA_GROUPS: 1, _KEY_OF_NUM_GQA_GROUPS: 1,
_KEY_OF_ENABLE_ROPE: True, _KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive", _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True, _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, { },
{
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_ENABLE_ROPE: True, _KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive", _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
{
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm', _KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_NUM_GQA_GROUPS: 2, _KEY_OF_NUM_GQA_GROUPS: 2,
_KEY_OF_ENABLE_ROPE: True, _KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate", _KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True, _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, { },
{
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_ENABLE_ROPE: True, _KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate", _KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
{
_KEY_OF_HIDDEN_DROPOUT: 0.3, _KEY_OF_HIDDEN_DROPOUT: 0.3,
_KEY_OF_HIDDEN_DROPOUT_DIMS: (0,), _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5, _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,), _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
}, { },
{
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding", _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
{
_KEY_OF_RELATIVE_EMBEDDING: False, _KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias", _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
}, { },
{
_KEY_OF_ATTENTION_DROPOUT: 0.3, _KEY_OF_ATTENTION_DROPOUT: 0.3,
}, { },
_KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')), {
}] _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
class BaseRunner: class BaseRunner:
"""Base runner to define forward and backward tests""" """Base runner to define forward and backward tests"""
layer_type: TransformerLayerType = None layer_type: TransformerLayerType = None
reference_layer: flax.linen.Module = None reference_layer: flax.linen.Module = None
transformations: Dict[str, str] = None transformations: Dict[str, str] = None
...@@ -194,24 +209,24 @@ class BaseRunner: ...@@ -194,24 +209,24 @@ class BaseRunner:
self.attrs = attrs self.attrs = attrs
self._generate_test_rngs() self._generate_test_rngs()
# Disable fused attention for attention dropout because the different dropout impl # Disable fused attention for attention dropout because the different dropout impl
if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'): if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
os.environ['NVTE_FUSED_ATTN'] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
def _generate_test_rngs(self): def _generate_test_rngs(self):
root_rng = jax.random.PRNGKey(0) root_rng = jax.random.PRNGKey(0)
params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3) params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng} self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
self.apply_rng = {'dropout': apply_dropout_rng} self.apply_rng = {"dropout": apply_dropout_rng}
def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs): def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
layer = layer_cls() layer = layer_cls()
variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs) variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
others, params = flax.core.pop(variables, 'params') others, params = flax.core.pop(variables, "params")
del variables del variables
return layer, params, others return layer, params, others
def _loss_fn(self, diff_xs, no_diff_xs, params, others, model): def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
variables = {'params': params, **others} variables = {"params": params, **others}
output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng) output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
return jnp.mean(output, dtype=jnp.float32).astype(output.dtype) return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
...@@ -259,15 +274,18 @@ class BaseRunner: ...@@ -259,15 +274,18 @@ class BaseRunner:
) )
_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME) _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections( test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others) {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others
)
del tmp_grad, fp8_meta_grad del tmp_grad, fp8_meta_grad
grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False) grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
ref_layer) inputs, ref_masks, ref_params, ref_others, ref_layer
test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others, )
test_layer) test_out, (test_dgrads, test_wgrads) = grad_fn(
inputs, test_masks, test_params, test_others, test_layer
)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol) assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol)
...@@ -278,19 +296,20 @@ class BaseRunner: ...@@ -278,19 +296,20 @@ class BaseRunner:
class EncoderRunner(BaseRunner): class EncoderRunner(BaseRunner):
"""Encoder runner implementations""" """Encoder runner implementations"""
layer_type = TransformerLayerType.ENCODER layer_type = TransformerLayerType.ENCODER
reference_layer = RefEncoderLayer reference_layer = RefEncoderLayer
transformations = { transformations = {
'attention/qkv/scale': 'pre_attention_layer_norm/scale', "attention/qkv/scale": "pre_attention_layer_norm/scale",
'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias', "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
'attention/query/scale': 'pre_attention_layer_norm/scale', "attention/query/scale": "pre_attention_layer_norm/scale",
'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias', "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
'mlp/wi_kernel': 'mlp/wi/kernel', "mlp/wi_kernel": "mlp/wi/kernel",
'mlp/wi_bias': 'mlp/wi/bias', "mlp/wi_bias": "mlp/wi/bias",
'mlp/wo_kernel': 'mlp/wo/kernel', "mlp/wo_kernel": "mlp/wo/kernel",
'mlp/wo_bias': 'mlp/wo/bias', "mlp/wo_bias": "mlp/wo/bias",
'mlp/scale': 'pre_mlp_layer_norm/scale', "mlp/scale": "pre_mlp_layer_norm/scale",
'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias', "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
} }
def generate_inputs(self, data_shape, dtype): def generate_inputs(self, data_shape, dtype):
...@@ -307,7 +326,7 @@ class EncoderRunner(BaseRunner): ...@@ -307,7 +326,7 @@ class EncoderRunner(BaseRunner):
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['casual', 'padding_causal']: if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]:
mask = causal_mask mask = causal_mask
else: else:
mask = padded_mask mask = padded_mask
...@@ -322,23 +341,24 @@ class DecoderRunner(BaseRunner): ...@@ -322,23 +341,24 @@ class DecoderRunner(BaseRunner):
""" """
Decoder runner implementations Decoder runner implementations
""" """
layer_type = TransformerLayerType.DECODER layer_type = TransformerLayerType.DECODER
reference_layer = RefDecoderLayer reference_layer = RefDecoderLayer
transformations = { transformations = {
'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale', "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias', "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale', "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias', "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale', "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias', "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
'self_attention/query/scale': 'pre_self_attention_layer_norm/scale', "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias', "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
'mlp/wi_kernel': 'mlp/wi/kernel', "mlp/wi_kernel": "mlp/wi/kernel",
'mlp/wi_bias': 'mlp/wi/bias', "mlp/wi_bias": "mlp/wi/bias",
'mlp/wo_kernel': 'mlp/wo/kernel', "mlp/wo_kernel": "mlp/wo/kernel",
'mlp/wo_bias': 'mlp/wo/bias', "mlp/wo_bias": "mlp/wo/bias",
'mlp/scale': 'pre_mlp_layer_norm/scale', "mlp/scale": "pre_mlp_layer_norm/scale",
'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias', "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
} }
def generate_inputs(self, data_shape, dtype): def generate_inputs(self, data_shape, dtype):
...@@ -352,12 +372,14 @@ class DecoderRunner(BaseRunner): ...@@ -352,12 +372,14 @@ class DecoderRunner(BaseRunner):
data_rng = jax.random.PRNGKey(0) data_rng = jax.random.PRNGKey(0)
data_rng_0, data_rng_1 = jax.random.split(data_rng, 2) data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
inputs = (jax.random.normal(data_rng_0, data_shape, inputs = (
dtype), jax.random.normal(data_rng_1, data_shape, dtype)) jax.random.normal(data_rng_0, data_shape, dtype),
jax.random.normal(data_rng_1, data_shape, dtype),
)
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['casual', 'padding_causal']: if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["casual", "padding_causal"]:
self_mask = causal_mask self_mask = causal_mask
else: else:
self_mask = padded_mask self_mask = padded_mask
...@@ -368,13 +390,14 @@ class DecoderRunner(BaseRunner): ...@@ -368,13 +390,14 @@ class DecoderRunner(BaseRunner):
return inputs, (ref_masks, test_masks) return inputs, (ref_masks, test_masks)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', ATTRS) @pytest.mark.parametrize("attrs", ATTRS)
class BaseTester(): class BaseTester:
""" """
Pytest interface to invoke the runner Pytest interface to invoke the runner
""" """
runner = BaseRunner runner = BaseRunner
def test_forward(self, data_shape, dtype, attrs): def test_forward(self, data_shape, dtype, attrs):
...@@ -388,7 +411,7 @@ class BaseTester(): ...@@ -388,7 +411,7 @@ class BaseTester():
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5) self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format): def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test forward with fp8 enabled""" """Test forward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format) FP8Helper.initialize(fp8_format=fp8_format)
...@@ -396,7 +419,7 @@ class BaseTester(): ...@@ -396,7 +419,7 @@ class BaseTester():
FP8Helper.finalize() FP8Helper.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test backward with fp8 enabled""" """Test backward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format) FP8Helper.initialize(fp8_format=fp8_format)
...@@ -408,6 +431,7 @@ class TestEncoderLayer(BaseTester): ...@@ -408,6 +431,7 @@ class TestEncoderLayer(BaseTester):
""" """
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder) Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
""" """
runner = EncoderRunner runner = EncoderRunner
...@@ -415,4 +439,5 @@ class TestDecoderLayer(BaseTester): ...@@ -415,4 +439,5 @@ class TestDecoderLayer(BaseTester):
""" """
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder) Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
""" """
runner = DecoderRunner runner = DecoderRunner
...@@ -43,7 +43,7 @@ ENABLE_FP8 = [False, True] ...@@ -43,7 +43,7 @@ ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID] FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope='module') @pytest.fixture(autouse=True, scope="module")
def enable_fused_attn(): def enable_fused_attn():
""" """
Enable fused attn for hopper+ arch. Enable fused attn for hopper+ arch.
...@@ -58,19 +58,16 @@ def enable_fused_attn(): ...@@ -58,19 +58,16 @@ def enable_fused_attn():
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd: for key in ref_fd:
assert key in test_fd, \ assert key in test_fd, f"{key} not found in test dict {test_fd}"
f"{key} not found in test dict {test_fd}" assert isinstance(
assert isinstance(test_fd[key], type(ref_fd[key])), \ test_fd[key], type(ref_fd[key])
f"The data type is not match between ref and test " \ ), f"The data type is not match between ref and test Dict on {key=}"
f" Dict on {key=}"
if isinstance(ref_fd[key], Dict): if isinstance(ref_fd[key], Dict):
compare_dict(ref_fd[key], test_fd[key], rtol, atol) compare_dict(ref_fd[key], test_fd[key], rtol, atol)
else: else:
assert_allclose(ref_fd[key], assert_allclose(
test_fd[key], ref_fd[key], test_fd[key], rtol=rtol, atol=atol, err_msg=f"{key=} is not close"
rtol=rtol, )
atol=atol,
err_msg=f"{key=} is not close")
class TestLayer: class TestLayer:
...@@ -105,9 +102,10 @@ class TestLayer: ...@@ -105,9 +102,10 @@ class TestLayer:
lyr_name = self.get_layer_name() lyr_name = self.get_layer_name()
if 'params' in flax_variables: if "params" in flax_variables:
synced_praxis_variables['params'][lyr_name]['cld'] = \ synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze(
flax.core.unfreeze(flax_variables['params']) flax_variables["params"]
)
return synced_praxis_variables, flax_variables return synced_praxis_variables, flax_variables
...@@ -116,23 +114,19 @@ class TestLayer: ...@@ -116,23 +114,19 @@ class TestLayer:
lyr_name = self.get_layer_name() lyr_name = self.get_layer_name()
if 'params' in synced_praxis_grads: if "params" in synced_praxis_grads:
synced_praxis_grads['params'] = \ synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"]
synced_praxis_grads['params'][lyr_name]['cld']
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \ synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = synced_praxis_grads[
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld'] FP8Helper.FP8_COLLECTION_NAME
][lyr_name]["cld"]
return synced_praxis_grads, flax.core.unfreeze(flax_wgrads) return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)
def forward_backward_runner(self, def forward_backward_runner(
data_shape, self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08
dtype, ):
praxis_p,
flax_cls,
rtol=1e-05,
atol=1e-08):
init_key = jax.random.PRNGKey(seed=1234) init_key = jax.random.PRNGKey(seed=1234)
test_inputs = self.input_getter(data_shape, dtype) test_inputs = self.input_getter(data_shape, dtype)
...@@ -148,28 +142,33 @@ class TestLayer: ...@@ -148,28 +142,33 @@ class TestLayer:
if "params_axes" in flax_variables: if "params_axes" in flax_variables:
flax_variables, _ = flax.core.pop(flax_variables, "params_axes") flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
flax_variables, _ = flax.core.pop(flax_variables, flax_variables, _ = flax.core.pop(
FP8Helper.FP8_COLLECTION_NAME + "_axes") flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
)
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables) praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
iter_times = 5 if FP8Helper.is_fp8_enabled() else 1 iter_times = 5 if FP8Helper.is_fp8_enabled() else 1
for _ in range(iter_times): for _ in range(iter_times):
praxis_loss, praxis_wgrads, praxis_dgrad = \ praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs) praxis_layer, praxis_variables, *test_inputs
flax_loss, flax_wgrads, flax_dgrad = \ )
TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs) flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
flax_layer, flax_variables, *test_inputs
)
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
praxis_wgrads.pop('params') praxis_wgrads.pop("params")
praxis_variables = update_collections(praxis_wgrads, praxis_variables) praxis_variables = update_collections(praxis_wgrads, praxis_variables)
flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params') flax_wgrads, _ = flax.core.pop(flax_wgrads, "params")
flax_variables = update_collections(flax_wgrads, flax_variables) flax_variables = update_collections(flax_wgrads, flax_variables)
praxis_loss, praxis_wgrads, praxis_dgrad = \ praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs) praxis_layer, praxis_variables, *test_inputs
flax_loss, flax_wgrads, flax_dgrad = \ )
TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs) flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
flax_layer, flax_variables, *test_inputs
)
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol) assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol) assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)
...@@ -179,18 +178,13 @@ class TestLayer: ...@@ -179,18 +178,13 @@ class TestLayer:
class LayerNormAttr: class LayerNormAttr:
LN_TYPE = 'layernorm_type' LN_TYPE = "layernorm_type"
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = "zero_centered_gamma"
ATTRS = [{ ATTRS = [
LN_TYPE: "layernorm", {LN_TYPE: "layernorm", ZERO_CEN: False},
ZERO_CEN: False {LN_TYPE: "layernorm", ZERO_CEN: True},
}, { {LN_TYPE: "rmsnorm", ZERO_CEN: False},
LN_TYPE: "layernorm", ]
ZERO_CEN: True
}, {
LN_TYPE: "rmsnorm",
ZERO_CEN: False
}]
class TestLayerNorm(TestLayer): class TestLayerNorm(TestLayer):
...@@ -200,7 +194,7 @@ class TestLayerNorm(TestLayer): ...@@ -200,7 +194,7 @@ class TestLayerNorm(TestLayer):
return (jax.random.normal(data_key, shape, dtype),) return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self): def get_layer_name(self):
return 'layer_norm' return "layer_norm"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
layernorm_type = attrs[LayerNormAttr.LN_TYPE] layernorm_type = attrs[LayerNormAttr.LN_TYPE]
...@@ -209,63 +203,59 @@ class TestLayerNorm(TestLayer): ...@@ -209,63 +203,59 @@ class TestLayerNorm(TestLayer):
bias_init = WeightInit.Constant(0.0) bias_init = WeightInit.Constant(0.0)
transpose_batch_sequence = False transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNorm, praxis_p = pax_fiddle.Config(
name='layer_norm', LayerNorm,
name="layer_norm",
dtype=dtype, dtype=dtype,
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
scale_init=scale_init, scale_init=scale_init,
bias_init=bias_init, bias_init=bias_init,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
flax_cls = partial(flax_LayerNorm, )
flax_cls = partial(
flax_LayerNorm,
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
scale_init=scale_init, scale_init=scale_init,
bias_init=TransformerEngineBaseLayer.generate_params_init( bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init),
"ln_bias", bias_init),
dtype=dtype, dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class FusedSoftmaxAttr: class FusedSoftmaxAttr:
SCALE_FACTOR = 'scale_factor' SCALE_FACTOR = "scale_factor"
ST_TYPE = 'softmax_type' ST_TYPE = "softmax_type"
ATTRS = [{ ATTRS = [
SCALE_FACTOR: 0.0, {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
ST_TYPE: SoftmaxType.SCALED {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
}, { {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
SCALE_FACTOR: 0.0, ]
ST_TYPE: SoftmaxType.SCALED_MASKED
}, {
SCALE_FACTOR: 0.0,
ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED
}]
class TestFusedSoftmax(TestLayer): class TestFusedSoftmax(TestLayer):
def input_getter(self, shape, dtype): def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234) data_key = jax.random.PRNGKey(seed=1234)
return jax.random.normal(data_key, shape, dtype), \ return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8) # Masks
jnp.ones(shape, dtype=jnp.uint8) # Masks
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR] scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE] softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]
praxis_p = pax_fiddle.Config(FusedSoftmax, praxis_p = pax_fiddle.Config(
name='fused_softmax', FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type
scale_factor=scale_factor, )
softmax_type=softmax_type)
flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type) flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)
return praxis_p, flax_cls return praxis_p, flax_cls
...@@ -276,12 +266,13 @@ class TestFusedSoftmax(TestLayer): ...@@ -276,12 +266,13 @@ class TestFusedSoftmax(TestLayer):
def sync_wgrads(self, praxis_wgrads, flax_wgrads): def sync_wgrads(self, praxis_wgrads, flax_wgrads):
return praxis_wgrads, flax_wgrads return praxis_wgrads, flax_wgrads
@pytest.mark.parametrize('data_shape', [(32, 1, 128, 128), (32, 1, 512, 128)]) @pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)])
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', FusedSoftmaxAttr.ATTRS) @pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and \ if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and (
(data_shape[-2] != data_shape[-1]): data_shape[-2] != data_shape[-1]
):
pass # Skip, due to not support pass # Skip, due to not support
else: else:
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
...@@ -289,21 +280,14 @@ class TestFusedSoftmax(TestLayer): ...@@ -289,21 +280,14 @@ class TestFusedSoftmax(TestLayer):
class LinearAttr: class LinearAttr:
FEATURE = 'features' FEATURE = "features"
USE_BIAS = 'use_bias' USE_BIAS = "use_bias"
ATTRS = [{ ATTRS = [
FEATURE: 512, {FEATURE: 512, USE_BIAS: False},
USE_BIAS: False {FEATURE: 512, USE_BIAS: True},
}, { {FEATURE: 1024, USE_BIAS: False},
FEATURE: 512, {FEATURE: 1024, USE_BIAS: True},
USE_BIAS: True ]
}, {
FEATURE: 1024,
USE_BIAS: False
}, {
FEATURE: 1024,
USE_BIAS: True
}]
class TestLinear(TestLayer): class TestLinear(TestLayer):
...@@ -313,7 +297,7 @@ class TestLinear(TestLayer): ...@@ -313,7 +297,7 @@ class TestLinear(TestLayer):
return (jax.random.normal(data_key, shape, dtype),) return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self): def get_layer_name(self):
return 'linear' return "linear"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LinearAttr.FEATURE] out_features = attrs[LinearAttr.FEATURE]
...@@ -323,15 +307,17 @@ class TestLinear(TestLayer): ...@@ -323,15 +307,17 @@ class TestLinear(TestLayer):
axis = -1 axis = -1
transpose_batch_sequence = False transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(Linear, praxis_p = pax_fiddle.Config(
name='linear', Linear,
name="linear",
dtype=dtype, dtype=dtype,
out_features=out_features, out_features=out_features,
params_init=kernel_init, params_init=kernel_init,
use_bias=use_bias, use_bias=use_bias,
bias_init=bias_init, bias_init=bias_init,
axis=axis, axis=axis,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial( flax_cls = partial(
DenseGeneral, DenseGeneral,
features=out_features, features=out_features,
...@@ -340,29 +326,26 @@ class TestLinear(TestLayer): ...@@ -340,29 +326,26 @@ class TestLinear(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis, axis=axis,
dtype=dtype, dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LinearAttr.ATTRS) @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LinearAttr.ATTRS) @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
...@@ -371,54 +354,20 @@ class TestLinear(TestLayer): ...@@ -371,54 +354,20 @@ class TestLinear(TestLayer):
class LayerNormLinearAttr: class LayerNormLinearAttr:
FEATURE = 'features' FEATURE = "features"
USE_BIAS = 'use_bias' USE_BIAS = "use_bias"
ENABLE_LN = 'enable_layernorm' ENABLE_LN = "enable_layernorm"
LN_TYPE = 'layernorm_type' LN_TYPE = "layernorm_type"
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = "zero_centered_gamma"
ATTRS = [{ ATTRS = [
FEATURE: 512, {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
USE_BIAS: True, {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
ENABLE_LN: True, {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
LN_TYPE: 'layernorm', {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
ZERO_CEN: False {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
}, { {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
FEATURE: 512, {FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False},
USE_BIAS: True, ]
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: False,
LN_TYPE: 'layernorm',
ZERO_CEN: False
}]
class TestLayerNormLinear(TestLayer): class TestLayerNormLinear(TestLayer):
...@@ -428,7 +377,7 @@ class TestLayerNormLinear(TestLayer): ...@@ -428,7 +377,7 @@ class TestLayerNormLinear(TestLayer):
return (jax.random.normal(data_key, shape, dtype),) return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self): def get_layer_name(self):
return 'ln_linear' return "ln_linear"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LayerNormLinearAttr.FEATURE] out_features = attrs[LayerNormLinearAttr.FEATURE]
...@@ -441,8 +390,9 @@ class TestLayerNormLinear(TestLayer): ...@@ -441,8 +390,9 @@ class TestLayerNormLinear(TestLayer):
axis = -1 axis = -1
transpose_batch_sequence = False transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNormLinear, praxis_p = pax_fiddle.Config(
name='ln_linear', LayerNormLinear,
name="ln_linear",
dtype=dtype, dtype=dtype,
out_features=out_features, out_features=out_features,
enable_layernorm=enable_layernorm, enable_layernorm=enable_layernorm,
...@@ -452,7 +402,8 @@ class TestLayerNormLinear(TestLayer): ...@@ -452,7 +402,8 @@ class TestLayerNormLinear(TestLayer):
use_bias=use_bias, use_bias=use_bias,
bias_init=bias_init, bias_init=bias_init,
axis=axis, axis=axis,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial( flax_cls = partial(
LayerNormDenseGeneral, LayerNormDenseGeneral,
features=out_features, features=out_features,
...@@ -464,29 +415,26 @@ class TestLayerNormLinear(TestLayer): ...@@ -464,29 +415,26 @@ class TestLayerNormLinear(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis, axis=axis,
dtype=dtype, dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
...@@ -495,62 +443,70 @@ class TestLayerNormLinear(TestLayer): ...@@ -495,62 +443,70 @@ class TestLayerNormLinear(TestLayer):
class LayerNormMLPAttr: class LayerNormMLPAttr:
INTERMEDIATE_DIM = 'intermediate_dim' INTERMEDIATE_DIM = "intermediate_dim"
USE_BIAS = 'use_bias' USE_BIAS = "use_bias"
ENABLE_LN = 'enable_layernorm' ENABLE_LN = "enable_layernorm"
LN_TYPE = 'layernorm_type' LN_TYPE = "layernorm_type"
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = "zero_centered_gamma"
ACTIVATION = 'activations' ACTIVATION = "activations"
ATTRS = [{ ATTRS = [
{
INTERMEDIATE_DIM: 2048, INTERMEDIATE_DIM: 2048,
USE_BIAS: True, USE_BIAS: True,
ENABLE_LN: True, ENABLE_LN: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',) ACTIVATION: ("relu",),
}, { },
{
INTERMEDIATE_DIM: 2048, INTERMEDIATE_DIM: 2048,
USE_BIAS: True, USE_BIAS: True,
ENABLE_LN: True, ENABLE_LN: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',) ACTIVATION: ("relu",),
}, { },
{
INTERMEDIATE_DIM: 2048, INTERMEDIATE_DIM: 2048,
USE_BIAS: True, USE_BIAS: True,
ENABLE_LN: True, ENABLE_LN: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',) ACTIVATION: ("relu",),
}, { },
{
INTERMEDIATE_DIM: 2048, INTERMEDIATE_DIM: 2048,
USE_BIAS: True, USE_BIAS: True,
ENABLE_LN: True, ENABLE_LN: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear') ACTIVATION: ("gelu", "linear"),
}, { },
{
INTERMEDIATE_DIM: 2048, INTERMEDIATE_DIM: 2048,
USE_BIAS: False, USE_BIAS: False,
ENABLE_LN: True, ENABLE_LN: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear') ACTIVATION: ("gelu", "linear"),
}, { },
{
INTERMEDIATE_DIM: 2048, INTERMEDIATE_DIM: 2048,
USE_BIAS: True, USE_BIAS: True,
ENABLE_LN: True, ENABLE_LN: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('silu', 'linear') ACTIVATION: ("silu", "linear"),
}, { },
{
INTERMEDIATE_DIM: 2048, INTERMEDIATE_DIM: 2048,
USE_BIAS: False, USE_BIAS: False,
ENABLE_LN: True, ENABLE_LN: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('silu', 'linear') ACTIVATION: ("silu", "linear"),
}] },
]
class TestLayerNormMLP(TestLayer): class TestLayerNormMLP(TestLayer):
...@@ -560,7 +516,7 @@ class TestLayerNormMLP(TestLayer): ...@@ -560,7 +516,7 @@ class TestLayerNormMLP(TestLayer):
return (jax.random.normal(data_key, shape, dtype),) return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self): def get_layer_name(self):
return 'ln_mlp' return "ln_mlp"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM] intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
...@@ -574,8 +530,9 @@ class TestLayerNormMLP(TestLayer): ...@@ -574,8 +530,9 @@ class TestLayerNormMLP(TestLayer):
axis = -1 axis = -1
transpose_batch_sequence = False transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNormMLP, praxis_p = pax_fiddle.Config(
name='ln_mlp', LayerNormMLP,
name="ln_mlp",
dtype=dtype, dtype=dtype,
intermediate_dim=intermediate_dim, intermediate_dim=intermediate_dim,
enable_layernorm=enable_layernorm, enable_layernorm=enable_layernorm,
...@@ -587,7 +544,8 @@ class TestLayerNormMLP(TestLayer): ...@@ -587,7 +544,8 @@ class TestLayerNormMLP(TestLayer):
activations=activations, activations=activations,
intermediate_dropout_rate=0.0, intermediate_dropout_rate=0.0,
axis=axis, axis=axis,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial( flax_cls = partial(
flax_LayerNormMLP, flax_LayerNormMLP,
intermediate_dim=intermediate_dim, intermediate_dim=intermediate_dim,
...@@ -601,29 +559,26 @@ class TestLayerNormMLP(TestLayer): ...@@ -601,29 +559,26 @@ class TestLayerNormMLP(TestLayer):
intermediate_dropout_rate=0.0, intermediate_dropout_rate=0.0,
axis=axis, axis=axis,
dtype=dtype, dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
...@@ -634,35 +589,40 @@ class TestLayerNormMLP(TestLayer): ...@@ -634,35 +589,40 @@ class TestLayerNormMLP(TestLayer):
class TestRelativePositionBias(TestLayer): class TestRelativePositionBias(TestLayer):
def get_layer_name(self): def get_layer_name(self):
return 'relative_position_bias' return "relative_position_bias"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
num_buckets = 32 num_buckets = 32
max_distance = 128 max_distance = 128
num_attention_heads = 64 num_attention_heads = 64
rb_stddev = (num_attention_heads * num_buckets)**-0.5 rb_stddev = (num_attention_heads * num_buckets) ** -0.5
embedding_init = WeightInit.Gaussian(rb_stddev) embedding_init = WeightInit.Gaussian(rb_stddev)
praxis_p = pax_fiddle.Config(RelativePositionBiases, praxis_p = pax_fiddle.Config(
name='relative_position_bias', RelativePositionBiases,
name="relative_position_bias",
dtype=dtype, dtype=dtype,
num_buckets=num_buckets, num_buckets=num_buckets,
max_distance=max_distance, max_distance=max_distance,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
embedding_init=embedding_init) embedding_init=embedding_init,
flax_cls = partial(flax_RelativePositionBiases, )
flax_cls = partial(
flax_RelativePositionBiases,
num_buckets=num_buckets, num_buckets=num_buckets,
max_distance=max_distance, max_distance=max_distance,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init( embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init), "rel_embedding", embedding_init
dtype=dtype) ),
dtype=dtype,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', [{}]) @pytest.mark.parametrize("attrs", [{}])
def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
...@@ -678,53 +638,64 @@ class TestRelativePositionBias(TestLayer): ...@@ -678,53 +638,64 @@ class TestRelativePositionBias(TestLayer):
if "params_axes" in flax_variables: if "params_axes" in flax_variables:
flax_variables, _ = flax.core.pop(flax_variables, "params_axes") flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
flax_variables, _ = flax.core.pop(flax_variables, flax_variables, _ = flax.core.pop(
FP8Helper.FP8_COLLECTION_NAME + "_axes") flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
)
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables) praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
praxis_loss= \ praxis_loss = TestLayer.loss(
TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False) praxis_variables, *test_input, module=praxis_layer, mean_out=False
flax_loss = \ )
TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False) flax_loss = TestLayer.loss(
flax_variables, *test_input, module=flax_layer, mean_out=False
)
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol) assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr: class DotProductAttnAttr:
ATTN_MASK_TYPE = 'attn_mask_type' ATTN_MASK_TYPE = "attn_mask_type"
NUM_GQA_GROUPS = 'num_gqa_groups' NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = 'transpose_batch_sequence' TRANSPOSE_BS = "transpose_batch_sequence"
SCALE_FACTOR = 'scale_factor' SCALE_FACTOR = "scale_factor"
ATTRS = [{ ATTRS = [
ATTN_MASK_TYPE: 'padding', {
ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125, SCALE_FACTOR: 0.125,
}, { },
ATTN_MASK_TYPE: 'padding_causal', {
ATTN_MASK_TYPE: "padding_causal",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125, SCALE_FACTOR: 0.125,
}, { },
ATTN_MASK_TYPE: 'causal', {
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125, SCALE_FACTOR: 0.125,
}, { },
ATTN_MASK_TYPE: 'padding', {
ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
SCALE_FACTOR: 0.125, SCALE_FACTOR: 0.125,
}, { },
ATTN_MASK_TYPE: 'padding_causal', {
ATTN_MASK_TYPE: "padding_causal",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
SCALE_FACTOR: 2., SCALE_FACTOR: 2.0,
}, { },
ATTN_MASK_TYPE: 'causal', {
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
SCALE_FACTOR: 1., SCALE_FACTOR: 1.0,
}, { },
ATTN_MASK_TYPE: 'no_mask', {
ATTN_MASK_TYPE: "no_mask",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
SCALE_FACTOR: 1., SCALE_FACTOR: 1.0,
}] },
]
class TestDotProductAttn(TestLayer): class TestDotProductAttn(TestLayer):
...@@ -737,11 +708,12 @@ class TestDotProductAttn(TestLayer): ...@@ -737,11 +708,12 @@ class TestDotProductAttn(TestLayer):
shape = (shape[1], shape[0]) + shape[2:] shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [ return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]),
mask,
] ]
def get_layer_name(self): def get_layer_name(self):
return 'dot_product_attn' return "dot_product_attn"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64 head_dim = 64
...@@ -750,27 +722,31 @@ class TestDotProductAttn(TestLayer): ...@@ -750,27 +722,31 @@ class TestDotProductAttn(TestLayer):
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE] attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
praxis_p = pax_fiddle.Config(DotProductAttention, praxis_p = pax_fiddle.Config(
name='mha', DotProductAttention,
name="mha",
dtype=dtype, dtype=dtype,
head_dim=head_dim, head_dim=head_dim,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups, num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
flax_cls = partial(flax_DotProductAttention, )
flax_cls = partial(
flax_DotProductAttention,
dtype=dtype, dtype=dtype,
head_dim=head_dim, head_dim=head_dim,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups, num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)]) @pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)])
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS) @pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
...@@ -778,113 +754,125 @@ class TestDotProductAttn(TestLayer): ...@@ -778,113 +754,125 @@ class TestDotProductAttn(TestLayer):
class MultiHeadAttnAttr: class MultiHeadAttnAttr:
USE_BIAS = 'use_bias' USE_BIAS = "use_bias"
LN_TYPE = 'layernorm_type' LN_TYPE = "layernorm_type"
ATTN_MASK_TYPE = 'attn_mask_type' ATTN_MASK_TYPE = "attn_mask_type"
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = "zero_centered_gamma"
NUM_ATTN_HEADS = 'num_attention_heads' NUM_ATTN_HEADS = "num_attention_heads"
NUM_GQA_GROUPS = 'num_gqa_groups' NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = 'transpose_batch_sequence' TRANSPOSE_BS = "transpose_batch_sequence"
ENABLE_ROPE = 'enable_rotary_pos_emb' ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = 'low_rank_adaptation_scope' LORA_SCOPE = "low_rank_adaptation_scope"
ATTRS = [{ ATTRS = [
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: 'padding', ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: 'padding', ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: 'padding', ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: 'causal', ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: 'causal', ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: 'causal', ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
NUM_ATTN_HEADS: 8, NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4, NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal', ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: True, ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
NUM_ATTN_HEADS: 8, NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4, NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal', ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: True, ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate', ROPE_GROUP_METHOD: "alternate",
NUM_ATTN_HEADS: 8, NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4, NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal', ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: 'padding', ATTN_MASK_TYPE: "padding",
LORA_SCOPE: 'all', LORA_SCOPE: "all",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: 'causal', ATTN_MASK_TYPE: "causal",
LORA_SCOPE: 'all', LORA_SCOPE: "all",
TRANSPOSE_BS: True, TRANSPOSE_BS: True,
}] },
]
class TestMultiHeadAttn(TestLayer): class TestMultiHeadAttn(TestLayer):
...@@ -899,13 +887,16 @@ class TestMultiHeadAttn(TestLayer): ...@@ -899,13 +887,16 @@ class TestMultiHeadAttn(TestLayer):
return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask] return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
def get_layer_name(self): def get_layer_name(self):
return 'multi_head_attn' return "multi_head_attn"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64 head_dim = 64
num_attention_heads = 16 num_attention_heads = 16
num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \ num_gqa_groups = (
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS]
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs
else None
)
layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE] layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN] zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0) kernel_init = WeightInit.Gaussian(1.0)
...@@ -916,15 +907,16 @@ class TestMultiHeadAttn(TestLayer): ...@@ -916,15 +907,16 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE] attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE] enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD] rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none') low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none")
fuse_qkv_params = True fuse_qkv_params = True
transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
scale_attn_logits = False scale_attn_logits = False
scaled_query_init = True scaled_query_init = True
float32_logits = False float32_logits = False
praxis_p = pax_fiddle.Config(MultiHeadAttention, praxis_p = pax_fiddle.Config(
name='mha', MultiHeadAttention,
name="mha",
dtype=dtype, dtype=dtype,
head_dim=head_dim, head_dim=head_dim,
num_attention_heads=num_attention_heads, num_attention_heads=num_attention_heads,
...@@ -944,7 +936,8 @@ class TestMultiHeadAttn(TestLayer): ...@@ -944,7 +936,8 @@ class TestMultiHeadAttn(TestLayer):
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits, scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init, scaled_query_init=scaled_query_init,
float32_logits=float32_logits) float32_logits=float32_logits,
)
flax_cls = partial( flax_cls = partial(
flax_MultiHeadAttention, flax_MultiHeadAttention,
dtype=dtype, dtype=dtype,
...@@ -966,30 +959,27 @@ class TestMultiHeadAttn(TestLayer): ...@@ -966,30 +959,27 @@ class TestMultiHeadAttn(TestLayer):
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits, scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init, scaled_query_init=scaled_query_init,
float32_logits=float32_logits) float32_logits=float32_logits,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS) @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS) @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
self.attrs = attrs self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
...@@ -998,252 +988,279 @@ class TestMultiHeadAttn(TestLayer): ...@@ -998,252 +988,279 @@ class TestMultiHeadAttn(TestLayer):
class TransformerLayerAttr: class TransformerLayerAttr:
USE_BIAS = 'use_bias' USE_BIAS = "use_bias"
LN_TYPE = 'layernorm_type' LN_TYPE = "layernorm_type"
ACTIVATION = 'activations' ACTIVATION = "activations"
LYR_TYPE = 'layer_type' LYR_TYPE = "layer_type"
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = "zero_centered_gamma"
TRANSPOSE_BS = 'transpose_batch_sequence' TRANSPOSE_BS = "transpose_batch_sequence"
ENABLE_ROPE = 'enable_rotary_pos_emb' ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = 'low_rank_adaptation_scope' LORA_SCOPE = "low_rank_adaptation_scope"
ATTRS = [{ ATTRS = [
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu',), ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
LORA_SCOPE: 'all' LORA_SCOPE: "all",
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "rmsnorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('gelu',), ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True, ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate', ROPE_GROUP_METHOD: "alternate",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('gelu',), ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True, ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate', ROPE_GROUP_METHOD: "alternate",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('gelu',), ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.ENCODER, LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True, ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: True, ZERO_CEN: True,
ACTIVATION: ('gelu',), ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True, ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: False,
}, { },
{
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu',), ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
LORA_SCOPE: 'all' LORA_SCOPE: "all",
}] },
]
class TestTransformer(TestLayer): class TestTransformer(TestLayer):
...@@ -1256,11 +1273,13 @@ class TestTransformer(TestLayer): ...@@ -1256,11 +1273,13 @@ class TestTransformer(TestLayer):
shape = (shape[1], shape[0]) + shape[2:] shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [ return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]),
mask,
mask,
] ]
def get_layer_name(self): def get_layer_name(self):
return 'transformerlayer' return "transformerlayer"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
hidden_size = 512 hidden_size = 512
...@@ -1277,29 +1296,34 @@ class TestTransformer(TestLayer): ...@@ -1277,29 +1296,34 @@ class TestTransformer(TestLayer):
layer_type = attrs[TransformerLayerAttr.LYR_TYPE] layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE] enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD] rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none') low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, "none")
enable_relative_embedding = True enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases, relative_embedding = pax_fiddle.Config(
dtype=dtype, RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads
num_attention_heads=num_attention_heads) )
drop_path = 0.0 drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
rel_embedding_init = RelativePositionBiases.generate_embedding_init( rel_embedding_init = RelativePositionBiases.generate_embedding_init(
relative_embedding.embedding_init, relative_embedding.num_attention_heads, relative_embedding.embedding_init,
relative_embedding.num_buckets) relative_embedding.num_attention_heads,
relative_embedding.num_buckets,
)
relative_embedding_flax_module = flax_RelativePositionBiases( relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=relative_embedding.num_buckets, num_buckets=relative_embedding.num_buckets,
max_distance=relative_embedding.max_distance, max_distance=relative_embedding.max_distance,
num_attention_heads=relative_embedding.num_attention_heads, num_attention_heads=relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init( embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", rel_embedding_init), "rel_embedding", rel_embedding_init
),
embedding_axes=relative_embedding.embedding_axes, embedding_axes=relative_embedding.embedding_axes,
dtype=relative_embedding.dtype) dtype=relative_embedding.dtype,
)
praxis_p = pax_fiddle.Config(TransformerLayer, praxis_p = pax_fiddle.Config(
name='transformer_layer', TransformerLayer,
name="transformer_layer",
params_init=kernel_init, params_init=kernel_init,
dtype=dtype, dtype=dtype,
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -1319,8 +1343,10 @@ class TestTransformer(TestLayer): ...@@ -1319,8 +1343,10 @@ class TestTransformer(TestLayer):
low_rank_adaptation_scope=low_rank_adaptation_scope, low_rank_adaptation_scope=low_rank_adaptation_scope,
relative_embedding=relative_embedding, relative_embedding=relative_embedding,
drop_path=drop_path, drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
flax_cls = partial(flax_TransformerLayer, )
flax_cls = partial(
flax_TransformerLayer,
dtype=dtype, dtype=dtype,
hidden_size=hidden_size, hidden_size=hidden_size,
mlp_hidden_size=mlp_hidden_size, mlp_hidden_size=mlp_hidden_size,
...@@ -1331,12 +1357,13 @@ class TestTransformer(TestLayer): ...@@ -1331,12 +1357,13 @@ class TestTransformer(TestLayer):
intermediate_dropout=intermediate_dropout, intermediate_dropout=intermediate_dropout,
mlp_activations=mlp_activations, mlp_activations=mlp_activations,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", kernel_init), "mha_kernel", kernel_init
),
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init( mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mlp_kernel", kernel_init), "mlp_kernel", kernel_init
),
use_bias=use_bias, use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init( bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
"bias", bias_init),
layer_type=layer_type, layer_type=layer_type,
enable_rotary_pos_emb=enable_rotary_pos_emb, enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method, rotary_pos_emb_group_method=rotary_pos_emb_group_method,
...@@ -1344,30 +1371,27 @@ class TestTransformer(TestLayer): ...@@ -1344,30 +1371,27 @@ class TestTransformer(TestLayer):
relative_embedding=relative_embedding_flax_module, relative_embedding=relative_embedding_flax_module,
low_rank_adaptation_scope=low_rank_adaptation_scope, low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path, drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS) @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS) @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
self.attrs = attrs self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
......
...@@ -3,4 +3,5 @@ ...@@ -3,4 +3,5 @@
# See LICENSE for license information. # See LICENSE for license information.
import transformer_engine.jax import transformer_engine.jax
print("OK") print("OK")
...@@ -8,25 +8,25 @@ from transformer_engine.jax.flax import extend_logical_axis_rules ...@@ -8,25 +8,25 @@ from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource from transformer_engine.jax.sharding import global_shard_guard, MeshResource
LOGICAL_RULES = [ LOGICAL_RULES = [
[(('a1', None), ('a2', 'ma2')), False], [(("a1", None), ("a2", "ma2")), False],
[(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True], [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
[(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False], [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True], [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
[(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True], [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
] ]
MeshS = [ MeshS = [
MeshResource(), MeshResource(),
MeshResource('data', None), MeshResource("data", None),
MeshResource(None, 'model'), MeshResource(None, "model"),
MeshResource('data', 'model') MeshResource("data", "model"),
] ]
class TestShardingSideAPI: class TestShardingSideAPI:
@pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES) @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
@pytest.mark.parametrize('sr', MeshS) @pytest.mark.parametrize("sr", MeshS)
def test_extend_logical_axis_rules(self, base_rules, need_assert, sr): def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
with global_shard_guard(sr): with global_shard_guard(sr):
try: try:
......
...@@ -43,6 +43,7 @@ class SoftmaxRunner: ...@@ -43,6 +43,7 @@ class SoftmaxRunner:
""" """
Softmax runner Softmax runner
""" """
batch_size: int batch_size: int
max_seqlen_q: int max_seqlen_q: int
max_seqlen_kv: int max_seqlen_kv: int
...@@ -57,14 +58,22 @@ class SoftmaxRunner: ...@@ -57,14 +58,22 @@ class SoftmaxRunner:
Jax softmax as the reference Jax softmax as the reference
""" """
if mask is not None: if mask is not None:
logits += lax.select(mask > 0, logits += lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype), jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.).astype(logits.dtype)) jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return nn.softmax(logits * scale_factor) return nn.softmax(logits * scale_factor)
def _is_support(self): def _is_support(self):
return is_softmax_kernel_available(self.softmax_type, self.batch_size, self.num_heads, return is_softmax_kernel_available(
self.max_seqlen_q, self.max_seqlen_kv, self.dtype) self.softmax_type,
self.batch_size,
self.num_heads,
self.max_seqlen_q,
self.max_seqlen_kv,
self.dtype,
)
def _setup_inputs(self): def _setup_inputs(self):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
...@@ -73,7 +82,7 @@ class SoftmaxRunner: ...@@ -73,7 +82,7 @@ class SoftmaxRunner:
logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv) logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.) self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
match self.softmax_type: match self.softmax_type:
case SoftmaxType.SCALED: case SoftmaxType.SCALED:
...@@ -81,7 +90,7 @@ class SoftmaxRunner: ...@@ -81,7 +90,7 @@ class SoftmaxRunner:
case SoftmaxType.SCALED_MASKED: case SoftmaxType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8) self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED: case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1. - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8) self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _: case _:
raise ValueError(f"Unknown {self.softmax_type=}") raise ValueError(f"Unknown {self.softmax_type=}")
...@@ -108,18 +117,24 @@ class SoftmaxRunner: ...@@ -108,18 +117,24 @@ class SoftmaxRunner:
args = [self.logits, self.mask] args = [self.logits, self.mask]
kwargs = { kwargs = {
'scale_factor': self.scale_factor, "scale_factor": self.scale_factor,
'softmax_type': self.softmax_type, "softmax_type": self.softmax_type,
} }
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad(lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), value_and_grad(
(0,))) lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,)
)
)
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda logits, *args: grad_func(__class__.reference_softmax, self.logits, *args, ** lambda logits, *args: grad_func(
kwargs), (0,))) __class__.reference_softmax, self.logits, *args, **kwargs
),
(0,),
)
)
primitive_out, (primitive_grad_logits,) = jitted_primitive(*args) primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
reference_out, (reference_grad_logits,) = jitted_reference(*args) reference_out, (reference_grad_logits,) = jitted_reference(*args)
...@@ -128,21 +143,30 @@ class SoftmaxRunner: ...@@ -128,21 +143,30 @@ class SoftmaxRunner:
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype) assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
@pytest.mark.parametrize('b, s_q, s_kv, h', [ @pytest.mark.parametrize(
pytest.param(8, 16, 16, 16, id='8-16-16-16'), "b, s_q, s_kv, h",
pytest.param(8, 512, 512, 16, id='8-512-512-16'), [
pytest.param(2, 8, 16384, 8, id='2-8-16384-8') pytest.param(8, 16, 16, 16, id="8-16-16-16"),
]) pytest.param(8, 512, 512, 16, id="8-512-512-16"),
@pytest.mark.parametrize('scale_factor', [0.125]) pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
@pytest.mark.parametrize('softmax_type', [ ],
pytest.param(SoftmaxType.SCALED, id='SCALED'), )
pytest.param(SoftmaxType.SCALED_MASKED, id='SCALED_MASKED'), @pytest.mark.parametrize("scale_factor", [0.125])
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id='SCALED_UPPER_TRIANG_MASKED') @pytest.mark.parametrize(
]) "softmax_type",
@pytest.mark.parametrize('dtype', [ [
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"), pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"), pytest.param(jnp.float16, id="FP16"),
]) ],
)
class TestSoftmax: class TestSoftmax:
""" """
Test transformer_engine.jax.softmax.softmax Test transformer_engine.jax.softmax.softmax
......
...@@ -24,8 +24,9 @@ PRNGKey = Any ...@@ -24,8 +24,9 @@ PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
DType = jnp.dtype DType = jnp.dtype
Array = Any Array = Any
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, PrecisionLike = Union[
lax.Precision]] None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array] Initializer = Callable[[PRNGKey, Shape, DType], Array]
...@@ -56,7 +57,7 @@ def _canonicalize_tuple(x): ...@@ -56,7 +57,7 @@ def _canonicalize_tuple(x):
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable: def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
"""Convert a string to an activation function.""" """Convert a string to an activation function."""
if fn_or_string == 'linear': if fn_or_string == "linear":
return lambda x: x return lambda x: x
if isinstance(fn_or_string, str): if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string) return getattr(nn, fn_or_string)
...@@ -77,8 +78,9 @@ def combine_biases(*masks: Optional[Array]): ...@@ -77,8 +78,9 @@ def combine_biases(*masks: Optional[Array]):
masks = [m for m in masks if m is not None] masks = [m for m in masks if m is not None]
if not masks: if not masks:
return None return None
assert all(map(lambda x: x.ndim == masks[0].ndim, assert all(
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') map(lambda x: x.ndim == masks[0].ndim, masks)
), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
mask, *other_masks = masks mask, *other_masks = masks
for other_mask in other_masks: for other_mask in other_masks:
mask = mask + other_mask mask = mask + other_mask
...@@ -88,7 +90,7 @@ def combine_biases(*masks: Optional[Array]): ...@@ -88,7 +90,7 @@ def combine_biases(*masks: Optional[Array]):
class DotProductAttention(nn.Module): class DotProductAttention(nn.Module):
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
scale_attn_logits: bool = True scale_attn_logits: bool = True
dropout_rate: float = 0. dropout_rate: float = 0.0
dtype: DType = jnp.float32 dtype: DType = jnp.float32
float32_logits: bool = False float32_logits: bool = False
"""Computes dot-product attention given query, key, and value. """Computes dot-product attention given query, key, and value.
...@@ -105,12 +107,14 @@ class DotProductAttention(nn.Module): ...@@ -105,12 +107,14 @@ class DotProductAttention(nn.Module):
""" """
@nn.compact @nn.compact
def __call__(self, def __call__(
self,
query: Array, query: Array,
key: Array, key: Array,
value: Array, value: Array,
bias: Optional[Array] = None, bias: Optional[Array] = None,
deterministic: bool = False): deterministic: bool = False,
):
""" """
Args: Args:
query: queries for calculating attention with shape of `[batch, q_length, query: queries for calculating attention with shape of `[batch, q_length,
...@@ -127,14 +131,15 @@ class DotProductAttention(nn.Module): ...@@ -127,14 +131,15 @@ class DotProductAttention(nn.Module):
Returns: Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`. Output of shape `[batch, length, num_heads, v_depth_per_head]`.
""" """
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0 batch_dim = 1 if self.transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( assert (
'q, k, v batch dims must match.') query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
), "q, k, v batch dims must match."
sequence_dim = 0 if self.transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.' assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
if self.scale_attn_logits: if self.scale_attn_logits:
head_dim = query.shape[-1] head_dim = query.shape[-1]
...@@ -153,9 +158,9 @@ class DotProductAttention(nn.Module): ...@@ -153,9 +158,9 @@ class DotProductAttention(nn.Module):
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
else: else:
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout # reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
...@@ -170,24 +175,23 @@ class DotProductAttention(nn.Module): ...@@ -170,24 +175,23 @@ class DotProductAttention(nn.Module):
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype) attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
# Apply attention dropout. # Apply attention dropout.
if not deterministic and self.dropout_rate > 0.: if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate keep_prob = 1.0 - self.dropout_rate
# T5 broadcasts along the "length" dim, but unclear which one that # T5 broadcasts along the "length" dim, but unclear which one that
# corresponds to in positional dimensions here, assuming query dim. # corresponds to in positional dimensions here, assuming query dim.
dropout_shape = list(attn_weights.shape) dropout_shape = list(attn_weights.shape)
dropout_rng = self.make_rng('dropout') dropout_rng = self.make_rng("dropout")
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = (keep.astype(attn_weights.dtype) / multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
jnp.asarray(keep_prob, dtype=self.dtype))
attn_weights = attn_weights * multiplier attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`. # Take the linear combination of `value`.
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
class DenseGeneral(nn.Module): class DenseGeneral(nn.Module):
...@@ -201,6 +205,7 @@ class DenseGeneral(nn.Module): ...@@ -201,6 +205,7 @@ class DenseGeneral(nn.Module):
use_bias: whether to add a bias to the output (default: False). use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector. bias_init: initializer function for the bias vector.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
...@@ -212,7 +217,7 @@ class DenseGeneral(nn.Module): ...@@ -212,7 +217,7 @@ class DenseGeneral(nn.Module):
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -233,21 +238,17 @@ class DenseGeneral(nn.Module): ...@@ -233,21 +238,17 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features)) kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes('kernel', kernel = nn_partitioning.param_with_axes(
self.kernel_init, "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
kernel_param_shape, )
jnp.float32,
axes=self.kernel_axes)
kernel = jnp.asarray(kernel, self.dtype) kernel = jnp.asarray(kernel, self.dtype)
kernel = jnp.reshape(kernel, kernel_shape) kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes('bias', bias = nn_partitioning.param_with_axes(
self.bias_init, "bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes
self.features, )
jnp.float32,
axes=self.bias_axes)
bias = bias.astype(self.dtype) bias = bias.astype(self.dtype)
else: else:
bias = None bias = None
...@@ -273,9 +274,10 @@ class MlpBlock(nn.Module): ...@@ -273,9 +274,10 @@ class MlpBlock(nn.Module):
intermediate_dropout_rate: Dropout rate used after the intermediate layers. intermediate_dropout_rate: Dropout rate used after the intermediate layers.
dtype: Type for the dense layer. dtype: Type for the dense layer.
""" """
transpose_batch_sequence: bool transpose_batch_sequence: bool
intermediate_dim: int = 2048 intermediate_dim: int = 2048
activations: Sequence[Union[str, Callable]] = ('relu',) activations: Sequence[Union[str, Callable]] = ("relu",)
kernel_init: Initializer = None kernel_init: Initializer = None
intermediate_dropout_rate: float = 0.1 intermediate_dropout_rate: float = 0.1
intermediate_dropout_dims: Sequence[int] = () intermediate_dropout_dims: Sequence[int] = ()
...@@ -285,7 +287,7 @@ class MlpBlock(nn.Module): ...@@ -285,7 +287,7 @@ class MlpBlock(nn.Module):
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -296,49 +298,57 @@ class MlpBlock(nn.Module): ...@@ -296,49 +298,57 @@ class MlpBlock(nn.Module):
activations = [] activations = []
if self.fuse_wi: if self.fuse_wi:
dense_name = 'wi' dense_name = "wi"
num_activations = len(self.activations) num_activations = len(self.activations)
x = DenseGeneral(self.intermediate_dim * num_activations, x = DenseGeneral(
self.intermediate_dim * num_activations,
dtype=self.dtype, dtype=self.dtype,
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
kernel_axes=('embed', 'mlp'), kernel_axes=("embed", "mlp"),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_axes=('mlp'), bias_axes="mlp",
name=dense_name)(inputs) name=dense_name,
)(inputs)
x = jnp.split(x, num_activations, axis=-1) x = jnp.split(x, num_activations, axis=-1)
for idx, act_fn in enumerate(self.activations): for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
else: else:
for idx, act_fn in enumerate(self.activations): for idx, act_fn in enumerate(self.activations):
dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
x = DenseGeneral(self.intermediate_dim, x = DenseGeneral(
self.intermediate_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
kernel_axes=('embed', 'mlp'), kernel_axes=("embed", "mlp"),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_axes=('mlp'), bias_axes="mlp",
name=dense_name)(inputs) name=dense_name,
)(inputs)
x = _convert_to_activation_function(act_fn)(x) x = _convert_to_activation_function(act_fn)(x)
activations.append(x) activations.append(x)
# Take elementwise product of above intermediate activations. # Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations) x = functools.reduce(operator.mul, activations)
# Apply dropout and final dense output projection. # Apply dropout and final dense output projection.
x = nn.Dropout(rate=self.intermediate_dropout_rate, x = nn.Dropout(
broadcast_dims=self.intermediate_dropout_dims)( rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
x, deterministic=deterministic) # Broadcast along length. )(
x, deterministic=deterministic
) # Broadcast along length.
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp')) x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
else: else:
x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'mlp')) x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
output = DenseGeneral(inputs.shape[-1], output = DenseGeneral(
inputs.shape[-1],
dtype=self.dtype, dtype=self.dtype,
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
kernel_axes=('mlp', 'embed'), kernel_axes=("mlp", "embed"),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_axes=('embed'), bias_axes="embed",
name='wo')(x) name="wo",
)(x)
return output return output
...@@ -351,7 +361,7 @@ def apply_rotary_pos_emb_alternate( ...@@ -351,7 +361,7 @@ def apply_rotary_pos_emb_alternate(
embedding_dim = inputs.shape[-1] embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2 half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
timescale = min_timescale * (max_timescale / min_timescale)**fraction timescale = min_timescale * (max_timescale / min_timescale) ** fraction
timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1))) timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim))) position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale sinusoid_inp = position / timescale
...@@ -386,7 +396,7 @@ def apply_rotary_pos_emb_consecutive( ...@@ -386,7 +396,7 @@ def apply_rotary_pos_emb_consecutive(
inputs_shifted_left, inputs_shifted_left,
) )
fraction = jnp.repeat(fraction, 2) fraction = jnp.repeat(fraction, 2)
timescale = min_timescale * (max_timescale / min_timescale)**fraction timescale = min_timescale * (max_timescale / min_timescale) ** fraction
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim))) position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
...@@ -422,32 +432,34 @@ class MultiHeadAttention(nn.Module): ...@@ -422,32 +432,34 @@ class MultiHeadAttention(nn.Module):
head_dim: int = 64 head_dim: int = 64
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
dtype: DType = jnp.float32 dtype: DType = jnp.float32
dropout_rate: float = 0. dropout_rate: float = 0.0
kernel_init: Initializer = None kernel_init: Initializer = None
float32_logits: bool = False # computes logits in float32 for stability. float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive' rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv: bool = True fuse_qkv: bool = True
use_bias: bool = False use_bias: bool = False
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal') self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads self.num_gqa_groups = self.num_attention_heads
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
def __call__(self, def __call__(
self,
inputs_q: Array, inputs_q: Array,
inputs_kv: Array, inputs_kv: Array,
mask: Optional[Array] = None, mask: Optional[Array] = None,
bias: Optional[Array] = None, bias: Optional[Array] = None,
*, *,
decode: bool = False, decode: bool = False,
deterministic: bool = False) -> Array: deterministic: bool = False,
) -> Array:
"""Applies multi-head dot product attention on the input data. """Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors, Projects the inputs into multi-headed query, key, and value vectors,
...@@ -476,28 +488,33 @@ class MultiHeadAttention(nn.Module): ...@@ -476,28 +488,33 @@ class MultiHeadAttention(nn.Module):
Returns: Returns:
output of shape `[batch, length, q_features]`. output of shape `[batch, length, q_features]`.
""" """
q_projection = functools.partial(DenseGeneral, q_projection = functools.partial(
DenseGeneral,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
kernel_axes=('embed', 'joined_kv'), kernel_axes=("embed", "joined_kv"),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_axes=('joined_kv'), bias_axes="joined_kv",
dtype=self.dtype) dtype=self.dtype,
)
kv_projection = functools.partial(DenseGeneral, kv_projection = functools.partial(
DenseGeneral,
axis=-1, axis=-1,
features=self.num_gqa_groups * self.head_dim, features=self.num_gqa_groups * self.head_dim,
kernel_axes=('embed', 'joined_kv'), kernel_axes=("embed", "joined_kv"),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_axes=('joined_kv'), bias_axes="joined_kv",
dtype=self.dtype) dtype=self.dtype,
)
# NOTE: T5 does not explicitly rescale the attention logits by # NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the # 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor # linear transformations, which is equivalent under Adafactor
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
query_init = lambda *args: self.kernel_init(*args) / (depth_scaling query_init = lambda *args: self.kernel_init(*args) / (
if self.scaled_query_init else 1.0) depth_scaling if self.scaled_query_init else 1.0
)
# Project inputs_q to multi-headed q/k/v # Project inputs_q to multi-headed q/k/v
# dimensions are then [batch, length, num_heads, head_dim] # dimensions are then [batch, length, num_heads, head_dim]
...@@ -515,39 +532,45 @@ class MultiHeadAttention(nn.Module): ...@@ -515,39 +532,45 @@ class MultiHeadAttention(nn.Module):
return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype) return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)
is_self_attn = (inputs_q is inputs_kv) is_self_attn = inputs_q is inputs_kv
is_gqa = (self.num_heads != self.num_gqa_groups) is_gqa = self.num_heads != self.num_gqa_groups
is_qkvpack = (is_self_attn and not is_gqa) is_qkvpack = is_self_attn and not is_gqa
if self.fuse_qkv: if self.fuse_qkv:
if is_qkvpack: if is_qkvpack:
qkv_proj = DenseGeneral(axis=-1, qkv_proj = DenseGeneral(
axis=-1,
features=self.num_heads * self.head_dim * 3, features=self.num_heads * self.head_dim * 3,
kernel_axes=('embed', 'joined_kv'), kernel_axes=("embed", "joined_kv"),
kernel_init=qkv_init, kernel_init=qkv_init,
use_bias=self.use_bias, use_bias=self.use_bias,
bias_axes=('joined_kv'), bias_axes="joined_kv",
name='qkv', name="qkv",
dtype=self.dtype)(inputs_kv) dtype=self.dtype,
)(inputs_kv)
query, key, value = jnp.split( query, key, value = jnp.split(
qkv_proj, [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2], qkv_proj,
axis=-1) [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1,
)
else: else:
query = q_projection(kernel_init=query_init, name='query')(inputs_q) query = q_projection(kernel_init=query_init, name="query")(inputs_q)
kv_proj = DenseGeneral(axis=-1, kv_proj = DenseGeneral(
axis=-1,
features=self.num_gqa_groups * self.head_dim * 2, features=self.num_gqa_groups * self.head_dim * 2,
kernel_axes=('embed', 'joined_kv'), kernel_axes=("embed", "joined_kv"),
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
use_bias=self.use_bias, use_bias=self.use_bias,
bias_axes=('joined_kv'), bias_axes="joined_kv",
name='kv', name="kv",
dtype=self.dtype)(inputs_kv) dtype=self.dtype,
)(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1) key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
else: else:
query = q_projection(kernel_init=query_init, name='query')(inputs_q) query = q_projection(kernel_init=query_init, name="query")(inputs_q)
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv) key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv) value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
if self.enable_rotary_pos_emb: if self.enable_rotary_pos_emb:
batch_dim = 1 if self.transpose_batch_sequence else 0 batch_dim = 1 if self.transpose_batch_sequence else 0
...@@ -556,7 +579,7 @@ class MultiHeadAttention(nn.Module): ...@@ -556,7 +579,7 @@ class MultiHeadAttention(nn.Module):
q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim) q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim) k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
if self.rotary_pos_emb_group_method == 'alternate': if self.rotary_pos_emb_group_method == "alternate":
apply_rotary_pos_emb = apply_rotary_pos_emb_alternate apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
else: else:
apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive
...@@ -571,33 +594,40 @@ class MultiHeadAttention(nn.Module): ...@@ -571,33 +594,40 @@ class MultiHeadAttention(nn.Module):
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint(query, query = nn_partitioning.with_sharding_constraint(
('length', 'batch', 'heads', 'kv')) query, ("length", "batch", "heads", "kv")
key = nn_partitioning.with_sharding_constraint(key, ('length', 'batch', 'heads', 'kv')) )
value = nn_partitioning.with_sharding_constraint(value, key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
('length', 'batch', 'heads', 'kv')) value = nn_partitioning.with_sharding_constraint(
value, ("length", "batch", "heads", "kv")
)
else: else:
query = nn_partitioning.with_sharding_constraint(query, query = nn_partitioning.with_sharding_constraint(
('batch', 'length', 'heads', 'kv')) query, ("batch", "length", "heads", "kv")
key = nn_partitioning.with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) )
value = nn_partitioning.with_sharding_constraint(value, key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
('batch', 'length', 'heads', 'kv')) value = nn_partitioning.with_sharding_constraint(
value, ("batch", "length", "heads", "kv")
)
if decode: if decode:
# Detect if we're initializing by absence of existing cache data. # Detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable('cache', 'cached_key') is_initialized = self.has_variable("cache", "cached_key")
# The key and value have dimension [batch, length, num_heads, head_dim], # The key and value have dimension [batch, length, num_heads, head_dim],
# but we cache them as [batch, num_heads, head_dim, length] as a TPU # but we cache them as [batch, num_heads, head_dim, length] as a TPU
# fusion optimization. This also enables the "scatter via one-hot # fusion optimization. This also enables the "scatter via one-hot
# broadcast" trick, which means we do a one-hot broadcast instead of a # broadcast" trick, which means we do a one-hot broadcast instead of a
# scatter/gather operations, resulting in a 3-4x speedup in practice. # scatter/gather operations, resulting in a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3]) swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape), cached_key = self.variable(
key.dtype) "cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape), )
value.dtype) cached_value = self.variable(
cache_index = self.variable('cache', 'cache_index', "cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
lambda: jnp.array(0, dtype=jnp.int32)) )
cache_index = self.variable(
"cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
)
if is_initialized: if is_initialized:
batch, num_heads, head_dim, length = cached_key.value.shape batch, num_heads, head_dim, length = cached_key.value.shape
# During fast autoregressive decoding, we feed one position at a time, # During fast autoregressive decoding, we feed one position at a time,
...@@ -606,8 +636,9 @@ class MultiHeadAttention(nn.Module): ...@@ -606,8 +636,9 @@ class MultiHeadAttention(nn.Module):
expected_shape = (batch, 1, num_heads, head_dim) expected_shape = (batch, 1, num_heads, head_dim)
if expected_shape != query.shape: if expected_shape != query.shape:
raise ValueError( raise ValueError(
'Autoregressive cache shape error, ' "Autoregressive cache shape error, "
f"expected query shape {expected_shape} instead got {query.shape}.") f"expected query shape {expected_shape} instead got {query.shape}."
)
# Create a OHE of the current index. NOTE: the index is increased below. # Create a OHE of the current index. NOTE: the index is increased below.
cur_index = cache_index.value cur_index = cache_index.value
...@@ -642,7 +673,9 @@ class MultiHeadAttention(nn.Module): ...@@ -642,7 +673,9 @@ class MultiHeadAttention(nn.Module):
# query length is 1 because during decoding we deal with one # query length is 1 because during decoding we deal with one
# index. # index.
# The same mask is applied to all batch elements and heads. # The same mask is applied to all batch elements and heads.
(batch, 1, 1, length))) (batch, 1, 1, length),
),
)
# Grab the correct relative attention bias during decoding. This is # Grab the correct relative attention bias during decoding. This is
# only required during single step decoding. # only required during single step decoding.
...@@ -650,15 +683,18 @@ class MultiHeadAttention(nn.Module): ...@@ -650,15 +683,18 @@ class MultiHeadAttention(nn.Module):
# The bias is a full attention matrix, but during decoding we only # The bias is a full attention matrix, but during decoding we only
# have to take a slice of it. # have to take a slice of it.
# This is equivalent to bias[..., cur_index:cur_index+1, :]. # This is equivalent to bias[..., cur_index:cur_index+1, :].
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), bias = dynamic_vector_slice_in_dim(
jnp.reshape(cur_index, (-1)), 1, -2) jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
)
# Convert the boolean attention mask to an attention bias. # Convert the boolean attention mask to an attention bias.
if mask is not None: if mask is not None:
# attention mask in the form of attention bias # attention mask in the form of attention bias
attention_bias = lax.select(mask > 0, attention_bias = lax.select(
jnp.full(mask.shape, 0.).astype(self.dtype), mask > 0,
jnp.full(mask.shape, -1e10).astype(self.dtype)) jnp.full(mask.shape, 0.0).astype(self.dtype),
jnp.full(mask.shape, -1e10).astype(self.dtype),
)
else: else:
attention_bias = None attention_bias = None
...@@ -667,41 +703,41 @@ class MultiHeadAttention(nn.Module): ...@@ -667,41 +703,41 @@ class MultiHeadAttention(nn.Module):
attention_bias = combine_biases(attention_bias, bias) attention_bias = combine_biases(attention_bias, bias)
# Apply attention. # Apply attention.
x = DotProductAttention(transpose_batch_sequence=self.transpose_batch_sequence, x = DotProductAttention(
transpose_batch_sequence=self.transpose_batch_sequence,
scale_attn_logits=self.scale_attn_logits, scale_attn_logits=self.scale_attn_logits,
dropout_rate=self.dropout_rate, dropout_rate=self.dropout_rate,
dtype=self.dtype, dtype=self.dtype,
float32_logits=self.float32_logits)(query, float32_logits=self.float32_logits,
key, )(query, key, value, bias=attention_bias, deterministic=deterministic)
value,
bias=attention_bias,
deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'joined_kv')) x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
else: else:
x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'joined_kv')) x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions. # Back to the original inputs dimensions.
out = DenseGeneral( out = DenseGeneral(
features=inputs_q.shape[-1], # output dim is set to the input dim. features=inputs_q.shape[-1], # output dim is set to the input dim.
axis=-1, axis=-1,
kernel_init=self.kernel_init, kernel_init=self.kernel_init,
kernel_axes=('joined_kv', 'embed'), kernel_axes=("joined_kv", "embed"),
use_bias=self.use_bias, use_bias=self.use_bias,
bias_axes=('embed'), bias_axes="embed",
dtype=self.dtype, dtype=self.dtype,
name='out')(x) name="out",
)(x)
return out return out
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
"""T5 Layer normalization operating on the last axis of the input data.""" """T5 Layer normalization operating on the last axis of the input data."""
epsilon: float = 1e-6 epsilon: float = 1e-6
dtype: Any = jnp.float32 dtype: Any = jnp.float32
layernorm_type: str = 'layernorm' layernorm_type: str = "layernorm"
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
scale_init: Initializer = None scale_init: Initializer = None
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
...@@ -721,29 +757,27 @@ class LayerNorm(nn.Module): ...@@ -721,29 +757,27 @@ class LayerNorm(nn.Module):
x = jnp.asarray(x, jnp.float32) x = jnp.asarray(x, jnp.float32)
features = x.shape[-1] features = x.shape[-1]
scale = nn_partitioning.param_with_axes('scale', scale = nn_partitioning.param_with_axes(
self.scale_init, (features,), "scale", self.scale_init, (features,), jnp.float32, axes=("embed",)
jnp.float32, )
axes=('embed',))
scale = jnp.asarray(scale, self.dtype) scale = jnp.asarray(scale, self.dtype)
if self.layernorm_type == 'layernorm': if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True) mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
y = (x - mean) * lax.rsqrt(var + self.epsilon) y = (x - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes('ln_bias', bias = nn_partitioning.param_with_axes(
self.bias_init, (features,), "ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",)
jnp.float32, )
axes=('embed',))
bias = jnp.asarray(bias, self.dtype) bias = jnp.asarray(bias, self.dtype)
if not self.zero_centered_gamma: if not self.zero_centered_gamma:
z = y * scale + bias z = y * scale + bias
else: else:
z = y * (scale + 1.) + bias z = y * (scale + 1.0) + bias
else: else:
assert self.layernorm_type == 'rmsnorm' assert self.layernorm_type == "rmsnorm"
assert not self.zero_centered_gamma assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = x * lax.rsqrt(mean2 + self.epsilon) y = x * lax.rsqrt(mean2 + self.epsilon)
...@@ -765,6 +799,7 @@ class RelativePositionBiases(nn.Module): ...@@ -765,6 +799,7 @@ class RelativePositionBiases(nn.Module):
dtype: Type of arrays through this module. dtype: Type of arrays through this module.
embedding_init: initializer for relative embedding table. embedding_init: initializer for relative embedding table.
""" """
num_buckets: int num_buckets: int
max_distance: int max_distance: int
num_heads: int num_heads: int
...@@ -772,10 +807,9 @@ class RelativePositionBiases(nn.Module): ...@@ -772,10 +807,9 @@ class RelativePositionBiases(nn.Module):
embedding_init: Callable[..., Array] = nn.linear.default_embed_init embedding_init: Callable[..., Array] = nn.linear.default_embed_init
@staticmethod @staticmethod
def _relative_position_bucket(relative_position, def _relative_position_bucket(
bidirectional=True, relative_position, bidirectional=True, num_buckets=32, max_distance=128
num_buckets=32, ):
max_distance=128):
"""Translate relative position to a bucket number for relative attention. """Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e. The relative position is defined as memory_position - query_position, i.e.
...@@ -811,8 +845,10 @@ class RelativePositionBiases(nn.Module): ...@@ -811,8 +845,10 @@ class RelativePositionBiases(nn.Module):
max_exact = num_buckets // 2 max_exact = num_buckets // 2
is_small = n < max_exact is_small = n < max_exact
val_if_large = max_exact + ( val_if_large = max_exact + (
np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) / np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
np.log(max_distance / max_exact) * (num_buckets - max_exact)).astype(np.int32) / np.log(max_distance / max_exact)
* (num_buckets - max_exact)
).astype(np.int32)
val_if_large = np.minimum(val_if_large, num_buckets - 1) val_if_large = np.minimum(val_if_large, num_buckets - 1)
ret += np.where(is_small, n, val_if_large) ret += np.where(is_small, n, val_if_large)
return ret return ret
...@@ -833,15 +869,19 @@ class RelativePositionBiases(nn.Module): ...@@ -833,15 +869,19 @@ class RelativePositionBiases(nn.Module):
context_position = np.arange(qlen, dtype=jnp.int32)[:, None] context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
memory_position = np.arange(klen, dtype=jnp.int32)[None, :] memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen) relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(relative_position, rp_bucket = self._relative_position_bucket(
relative_position,
bidirectional=bidirectional, bidirectional=bidirectional,
num_buckets=self.num_buckets, num_buckets=self.num_buckets,
max_distance=self.max_distance) max_distance=self.max_distance,
)
relative_attention_bias = nn_partitioning.param_with_axes( relative_attention_bias = nn_partitioning.param_with_axes(
'rel_embedding', "rel_embedding",
self.embedding_init, (self.num_heads, self.num_buckets), self.embedding_init,
(self.num_heads, self.num_buckets),
jnp.float32, jnp.float32,
axes=('heads', 'relpos_buckets')) axes=("heads", "relpos_buckets"),
)
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
# Instead of using a slow gather, we create a leading-dimension one-hot # Instead of using a slow gather, we create a leading-dimension one-hot
...@@ -855,9 +895,8 @@ class RelativePositionBiases(nn.Module): ...@@ -855,9 +895,8 @@ class RelativePositionBiases(nn.Module):
values = lax.dot_general( values = lax.dot_general(
relative_attention_bias, relative_attention_bias,
rp_bucket_one_hot, rp_bucket_one_hot,
( (((1,), (0,)), ((), ())), # rhs, lhs contracting dims
((1,), (0,)), # rhs, lhs contracting dims ) # no batched dims
((), ()))) # no batched dims
# Add a singleton batch dimension. # Add a singleton batch dimension.
# --> shape (1, num_heads, qlen, klen) # --> shape (1, num_heads, qlen, klen)
return values[jnp.newaxis, ...] return values[jnp.newaxis, ...]
...@@ -865,6 +904,7 @@ class RelativePositionBiases(nn.Module): ...@@ -865,6 +904,7 @@ class RelativePositionBiases(nn.Module):
class EncoderLayer(nn.Module): class EncoderLayer(nn.Module):
"""Transformer encoder layer.""" """Transformer encoder layer."""
enable_relative_embedding: bool = True enable_relative_embedding: bool = True
relative_embedding: nn.Module = None relative_embedding: nn.Module = None
num_attention_heads: int = 8 num_attention_heads: int = 8
...@@ -880,17 +920,17 @@ class EncoderLayer(nn.Module): ...@@ -880,17 +920,17 @@ class EncoderLayer(nn.Module):
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
mlp_dim: int = 2048 mlp_dim: int = 2048
mlp_activations: Sequence[str] = ('relu',) mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False use_bias: bool = False
dtype: Any = jnp.float32 dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
layernorm_type: str = 'layernorm' layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
output_layernorm: bool = False output_layernorm: bool = False
drop_path: float = 0.0 drop_path: float = 0.0
enable_rotary_pos_emb: bool = False enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive' rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None self_attn_bias_type: Any = None
...@@ -910,13 +950,14 @@ class EncoderLayer(nn.Module): ...@@ -910,13 +950,14 @@ class EncoderLayer(nn.Module):
if self.enable_relative_embedding: if self.enable_relative_embedding:
if self.relative_embedding is None: if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32, rel_emb = RelativePositionBiases(
num_buckets=32,
max_distance=128, max_distance=128,
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling( embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
1.0, 'fan_avg', 'uniform'), name="relpos_bias",
name='relpos_bias') )
else: else:
rel_emb = self.relative_embedding rel_emb = self.relative_embedding
encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True) encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
...@@ -928,11 +969,13 @@ class EncoderLayer(nn.Module): ...@@ -928,11 +969,13 @@ class EncoderLayer(nn.Module):
if not self.output_layernorm: if not self.output_layernorm:
# Attention block. # Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type, x = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="pre_attention_layer_norm")(inputs) name="pre_attention_layer_norm",
)(inputs)
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = x residual = x
...@@ -940,7 +983,8 @@ class EncoderLayer(nn.Module): ...@@ -940,7 +983,8 @@ class EncoderLayer(nn.Module):
x = inputs x = inputs
# [batch, length, emb_dim] -> [batch, length, emb_dim] # [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(num_heads=self.num_attention_heads, x = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
...@@ -953,26 +997,27 @@ class EncoderLayer(nn.Module): ...@@ -953,26 +997,27 @@ class EncoderLayer(nn.Module):
enable_rotary_pos_emb=self.enable_rotary_pos_emb, enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias, use_bias=self.use_bias,
name='attention')(x, name="attention",
x, )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
encoder_mask, x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
encoder_bias, x, deterministic=deterministic
deterministic=deterministic) )
x = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic)
if self.drop_path > 0.0: if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path, x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
broadcast_dims=drop_path_shape)(x, deterministic=deterministic) x, deterministic=deterministic
)
x = x + residual x = x + residual
# MLP block. # MLP block.
residual = x residual = x
y = LayerNorm(layernorm_type=self.layernorm_type, y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name='pre_mlp_layer_norm')(x) name="pre_mlp_layer_norm",
)(x)
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = y residual = y
...@@ -987,27 +1032,32 @@ class EncoderLayer(nn.Module): ...@@ -987,27 +1032,32 @@ class EncoderLayer(nn.Module):
use_bias=self.use_bias, use_bias=self.use_bias,
dtype=self.dtype, dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi, fuse_wi=self.fuse_mlp_wi,
name='mlp', name="mlp",
)(y, deterministic=deterministic) )(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic) y, deterministic=deterministic
)
if self.drop_path > 0.0: if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim) drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
y = nn.Dropout(rate=self.drop_path, y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
broadcast_dims=drop_path_shape)(y, deterministic=deterministic) y, deterministic=deterministic
)
y = y + residual y = y + residual
if self.output_layernorm: if self.output_layernorm:
y = LayerNorm(layernorm_type=self.layernorm_type, y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="output_layernorm")(y) name="output_layernorm",
)(y)
return y return y
class DecoderLayer(nn.Module): class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder.""" """Transformer decoder layer that attends to the encoder."""
enable_relative_embedding: bool = True enable_relative_embedding: bool = True
relative_embedding: nn.Module = None relative_embedding: nn.Module = None
num_attention_heads: int = 8 num_attention_heads: int = 8
...@@ -1023,17 +1073,17 @@ class DecoderLayer(nn.Module): ...@@ -1023,17 +1073,17 @@ class DecoderLayer(nn.Module):
scale_attn_logits: bool = False scale_attn_logits: bool = False
scaled_query_init: bool = True scaled_query_init: bool = True
mlp_dim: int = 2048 mlp_dim: int = 2048
mlp_activations: Sequence[str] = ('relu',) mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False use_bias: bool = False
dtype: Any = jnp.float32 dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False output_layernorm: bool = False
layernorm_type: str = 'layernorm' layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
drop_path: float = 0.0 drop_path: float = 0.0
enable_rotary_pos_emb: bool = False enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive' rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None self_attn_bias_type: Any = None
...@@ -1045,14 +1095,16 @@ class DecoderLayer(nn.Module): ...@@ -1045,14 +1095,16 @@ class DecoderLayer(nn.Module):
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
def __call__(self, def __call__(
self,
inputs, inputs,
encoded, encoded,
decoder_mask=None, decoder_mask=None,
encoder_decoder_mask=None, encoder_decoder_mask=None,
deterministic=False, deterministic=False,
decode=False, decode=False,
max_decode_length=None): max_decode_length=None,
):
del self.self_attn_mask_type # dummy, just align to TE's impl del self.self_attn_mask_type # dummy, just align to TE's impl
# Relative position embedding as attention biases. # Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
...@@ -1061,13 +1113,14 @@ class DecoderLayer(nn.Module): ...@@ -1061,13 +1113,14 @@ class DecoderLayer(nn.Module):
if self.enable_relative_embedding: if self.enable_relative_embedding:
l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim] l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
if self.relative_embedding is None: if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32, rel_emb = RelativePositionBiases(
num_buckets=32,
max_distance=128, max_distance=128,
num_heads=self.num_attention_heads, num_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling( embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
1.0, 'fan_avg', 'uniform'), name="relpos_bias",
name='relpos_bias') )
else: else:
rel_emb = self.relative_embedding rel_emb = self.relative_embedding
decoder_bias = rel_emb(l, l, False) decoder_bias = rel_emb(l, l, False)
...@@ -1079,11 +1132,13 @@ class DecoderLayer(nn.Module): ...@@ -1079,11 +1132,13 @@ class DecoderLayer(nn.Module):
if not self.output_layernorm: if not self.output_layernorm:
# Attention block. # Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type, x = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="pre_self_attention_layer_norm")(inputs) name="pre_self_attention_layer_norm",
)(inputs)
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = x residual = x
...@@ -1091,7 +1146,8 @@ class DecoderLayer(nn.Module): ...@@ -1091,7 +1146,8 @@ class DecoderLayer(nn.Module):
x = inputs x = inputs
# Self-attention block # Self-attention block
x = MultiHeadAttention(num_heads=self.num_attention_heads, x = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
...@@ -1104,31 +1160,32 @@ class DecoderLayer(nn.Module): ...@@ -1104,31 +1160,32 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias, use_bias=self.use_bias,
name='self_attention')(x, name="self_attention",
x, )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
decoder_mask, x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
decoder_bias, x, deterministic=deterministic
deterministic=deterministic, )
decode=decode)
x = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic)
if self.drop_path > 0.0: if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path, x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
broadcast_dims=drop_path_shape)(x, deterministic=deterministic) x, deterministic=deterministic
)
x = x + residual x = x + residual
# Encoder-Decoder block. # Encoder-Decoder block.
residual = x residual = x
y = LayerNorm(layernorm_type=self.layernorm_type, y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name='pre_cross_attention_layer_norm')(x) name="pre_cross_attention_layer_norm",
)(x)
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = y residual = y
y = MultiHeadAttention(num_heads=self.num_attention_heads, y = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype, dtype=self.dtype,
head_dim=self.head_dim, head_dim=self.head_dim,
...@@ -1141,21 +1198,22 @@ class DecoderLayer(nn.Module): ...@@ -1141,21 +1198,22 @@ class DecoderLayer(nn.Module):
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params, fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias, use_bias=self.use_bias,
name='encoder_decoder_attention')(y, name="encoder_decoder_attention",
encoded, )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
encoder_decoder_mask, y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
deterministic=deterministic) y, deterministic=deterministic
y = nn.Dropout(rate=self.hidden_dropout, )
broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic)
y = y + residual y = y + residual
# MLP block. # MLP block.
residual = y residual = y
z = LayerNorm(layernorm_type=self.layernorm_type, z = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name='pre_mlp_layer_norm')(y) name="pre_mlp_layer_norm",
)(y)
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = z residual = z
z = MlpBlock( z = MlpBlock(
...@@ -1167,22 +1225,26 @@ class DecoderLayer(nn.Module): ...@@ -1167,22 +1225,26 @@ class DecoderLayer(nn.Module):
use_bias=self.use_bias, use_bias=self.use_bias,
dtype=self.dtype, dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi, fuse_wi=self.fuse_mlp_wi,
name='mlp', name="mlp",
)(z, deterministic=deterministic) )(z, deterministic=deterministic)
z = nn.Dropout(rate=self.hidden_dropout, z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
broadcast_dims=self.hidden_dropout_dims)(z, deterministic=deterministic) z, deterministic=deterministic
)
if self.drop_path > 0.0: if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim) drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
z = nn.Dropout(rate=self.drop_path, z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
broadcast_dims=drop_path_shape)(z, deterministic=deterministic) z, deterministic=deterministic
)
z = z + residual z = z + residual
if self.output_layernorm: if self.output_layernorm:
z = LayerNorm(layernorm_type=self.layernorm_type, z = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma, zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="output_layernorm")(z) name="output_layernorm",
)(z)
return z return z
...@@ -1261,15 +1323,18 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08): ...@@ -1261,15 +1323,18 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected) flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual) flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)
for (expected_path, expected_value), (actual_path, for (expected_path, expected_value), (actual_path, actual_value) in zip(
actual_value) in zip(flatten_expected, flatten_actual): flatten_expected, flatten_actual
):
assert expected_path == actual_path assert expected_path == actual_path
key_str = jax.tree_util.keystr(expected_path) key_str = jax.tree_util.keystr(expected_path)
assert_allclose(expected_value, assert_allclose(
expected_value,
actual_value, actual_value,
rtol=rtol, rtol=rtol,
atol=atol, atol=atol,
err_msg=f'Value of expected{key_str} and actual{key_str} is not close') err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
)
def dtype_tols( def dtype_tols(
...@@ -1323,7 +1388,7 @@ def dtype_tols( ...@@ -1323,7 +1388,7 @@ def dtype_tols(
) )
def sync_params_values(dst, src, transformations, sep='/'): def sync_params_values(dst, src, transformations, sep="/"):
""" """
This function will reconstuct a tree with dst's tree_def/shape and src's value. This function will reconstuct a tree with dst's tree_def/shape and src's value.
transformations is a map that records the key mappings between dst and src. transformations is a map that records the key mappings between dst and src.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment