Unverified commit 9416519d, authored by Kirthi Shankar Sivamani and committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8):
            with te.fp8_autocast(enabled=fp8, calibrating=True):
                output = model(data)


def test(model, device, test_loader, use_fp8):
    """Testing function."""
    model.eval()
@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8):
            data, target = data.to(device), target.to(device)
            with te.fp8_autocast(enabled=use_fp8):
                output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
@@ -150,9 +147,7 @@ def main():
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
@@ -167,7 +162,10 @@ def main():
        help="For Saving the current Model",
    )
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )
    parser.add_argument(
        "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
@@ -215,7 +213,7 @@ def main():
    if args.save_model or args.use_fp8_infer:
        torch.save(model.state_dict(), "mnist_cnn.pt")
        print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
        weights = torch.load("mnist_cnn.pt")
        model.load_state_dict(weights)
        test(model, device, test_loader, args.use_fp8_infer)
......
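A minimal standalone sketch of the calibration-then-inference pattern exercised by the MNIST hunks above (assuming transformer_engine.pytorch is importable as te and an FP8-capable GPU is available; the toy model below is illustrative, not the example's CNN):

import torch
import transformer_engine.pytorch as te

# Toy model built from Transformer Engine layers (dimensions kept FP8-friendly).
model = torch.nn.Sequential(te.Linear(784, 256), torch.nn.ReLU(), te.Linear(256, 16)).cuda()
data = torch.randn(32, 784, device="cuda")

# Calibration pass: run without FP8 while recording scaling statistics.
with torch.no_grad(), te.fp8_autocast(enabled=False, calibrating=True):
    model(data)

# Inference pass: enable FP8 and reuse the calibrated scales.
with torch.no_grad(), te.fp8_autocast(enabled=True):
    output = model(data)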
@@ -18,21 +18,26 @@ path = sys.argv[1]
config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"


class bcolors:
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"


def print_ok(msg):
    print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}")


def print_fail(msg):
    print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}")


def print_warn(msg):
    print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}")


with open(config_path, "r") as f:
    c = json.load(f)
    current_year = datetime.date.today().year
@@ -41,7 +46,7 @@ with open(config_path, "r") as f:
    else:
        year_string = str(c["initial_year"]) + "-" + str(current_year)
    copyright_string = c["copyright"].replace("<YEAR>", year_string)
    license = c["license"].split("\n")
    excludes = c["exclude"]
    root_path = os.path.abspath(path)
    copyright_only = c["copyright_only"]
@@ -49,36 +54,42 @@ with open(config_path, "r") as f:
has_gitignore = os.path.exists(root_path + "/.gitignore")


def strip_star_slash(s):
    ret = s
    if ret.startswith("*"):
        ret = ret[1:]
    if ret.endswith("/"):
        ret = ret[:-1]
    return ret


if has_gitignore:
    with open(root_path + "/.gitignore", "r") as f:
        for line in f.readlines():
            excludes.append(strip_star_slash(line.strip()))


def get_file_type(path):
    ext = {
        "c": ["c", "cpp", "cu", "h", "cuh"],
        "py": ["py"],
        "rst": ["rst"],
        "txt": ["txt"],
        "cfg": ["cfg"],
        "sh": ["sh"],
        "md": ["md"],
    }
    tmp = path.split(".")
    for filetype, ext_list in ext.items():
        if tmp[-1] in ext_list:
            return filetype
    return "unknown"


success = True


def check_file(path):
    global success
    N = 10
@@ -127,9 +138,10 @@ def check_file(path):
    if copyright_found and license_found:
        print_ok("OK")


for root, dirs, files in os.walk(root_path):
    print(f"Entering {root}")
    hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
    all_excludes = excludes + hidden
    to_remove = []
    for d in dirs:
......
@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve()
from setuptools.command.build_ext import build_ext as BuildExtension

if "pytorch" in frameworks:
    from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
    from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
    install_and_import("pybind11")
    from pybind11.setup_helpers import build_ext as BuildExtension
@@ -86,34 +87,45 @@ if __name__ == "__main__":
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            from build_tools.pytorch import setup_pytorch_extension

            ext_modules.append(
                setup_pytorch_extension(
                    "transformer_engine/pytorch/csrc",
                    current_file_path / "transformer_engine" / "pytorch" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )

        if "jax" in frameworks:
            from build_tools.jax import setup_jax_extension

            ext_modules.append(
                setup_jax_extension(
                    "transformer_engine/jax/csrc",
                    current_file_path / "transformer_engine" / "jax" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )

        if "paddle" in frameworks:
            from build_tools.paddle import setup_paddle_extension

            ext_modules.append(
                setup_paddle_extension(
                    "transformer_engine/paddle/csrc",
                    current_file_path / "transformer_engine" / "paddle" / "csrc",
                    current_file_path / "transformer_engine",
                )
            )

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=__version__,
        packages=setuptools.find_packages(
            include=[
                "transformer_engine",
                "transformer_engine.*",
                "transformer_engine/build_tools",
            ],
        ),
        extras_require={
            "test": test_requires,
@@ -125,5 +137,5 @@ if __name__ == "__main__":
        install_requires=install_requires,
        license_files=("LICENSE",),
        include_package_data=True,
        package_data={"": ["VERSION.txt"]},
    )
@@ -132,7 +132,7 @@ void compute_bwd_ref(
  for (int b = 0; b < batches; ++b) {
    for (int h = 0; h < heads; ++h) {
      size_t offset = b * batch_size + h * head_size;
      compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
                              buff + offset, scaling_factor, batches, heads, rows, cols);
    }
  }
......
@@ -6,7 +6,7 @@ import jax
import pytest


@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
......
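The conftest hunk above only shows the fixture's decorator and docstring; a minimal sketch of what such a function-scoped cleanup fixture could look like (the body here is an assumption, not taken from this commit):

import jax
import pytest


@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    # Assumed cleanup: delete any arrays still alive on the devices after each
    # test so that subsequent tests start from a clean memory state.
    for arr in jax.live_arrays():
        arr.delete()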
@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
    configs = []
    if is_devices_enough(2):
        configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
        configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])

    if is_devices_enough(4):
        TP_size = 2
        DP_size = 2
        configs.append(
            [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
        )

    return configs
@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
    bytes_count = 0

    def get_bytes_per_txt(t):
        """
        The pattern of t would be like:
            'f32[]',
            '(f32[1024]{0}',
@@ -54,24 +54,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
            'f8E4M3FN[1024]{0}',
            'i32[1024]{0}',
            'bf16[1024,1024]{0}'
        """
        match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
        _, bits_of_type, shape = match.groups()
        bytes_of_type = int(bits_of_type) // 8
        if shape == "":
            num_of_elements = 1
        else:
            num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
        return bytes_of_type * num_of_elements

    # ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
    if "(" in hlo_text[2]:
        for txt in hlo_text[2:]:
            bytes_count += get_bytes_per_txt(txt)
            if ")" in txt:
                break
    else:  # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
        bytes_count = get_bytes_per_txt(hlo_text[2])
    return bytes_count
@@ -91,21 +91,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
        return result

    target_result = count_collectives(target_splitted_hlo)
    assert (
        target_result == coll_count_ref
    ), f"Expected collective count is {coll_count_ref}, but got {target_result}."


def compare_ops(
    target_func,
    ref_func,
    inputs,
    coll_count_ref,
    *,
    grad_args=None,
    metric_fwd_dtype=None,
    metric_bwd_dtype=None,
    in_shardings=_UNSPECIFIED,
    out_shardings=_UNSPECIFIED,
    **kwargs,
):
    assert len(inputs) >= 1

    if metric_fwd_dtype is None:
......
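To make the byte counting in get_bytes_per_txt concrete, a small worked example that re-implements the same parsing logic shown above (standalone, for illustration only; the function name here is hypothetical):

import operator
import re
from functools import reduce


def bytes_per_hlo_operand(t):
    # Same idea as get_bytes_per_txt above: dtype bit-width times element count.
    match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
    _, bits_of_type, shape = match.groups()
    bytes_of_type = int(bits_of_type) // 8
    num_of_elements = 1 if shape == "" else reduce(operator.mul, map(int, shape.split(",")))
    return bytes_of_type * num_of_elements


print(bytes_per_hlo_operand("f32[]"))               # 4: one fp32 scalar
print(bytes_per_hlo_operand("bf16[1024,1024]{0}"))  # 2 * 1024 * 1024 = 2097152 bytes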
@@ -15,24 +15,10 @@ from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose

from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax import cpp_extensions as tex

GEMM_CASES = [
@@ -50,11 +36,11 @@ is_fp8_supported, reason = is_fp8_available()
def _convert_to_activation_function(fn_or_string):
    """Convert a string to an activation function."""
    if fn_or_string == "linear":
        return lambda x: x
    if fn_or_string == "quick_gelu":
        return lambda x: nn.gelu(x, approximate=True)
    if fn_or_string == "squared_relu":
        return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
    if isinstance(fn_or_string, str):
        return getattr(nn, fn_or_string)
@@ -93,7 +79,7 @@ class TestFP8Dot:
        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
@@ -106,7 +92,7 @@ class TestFP8Dot:
        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_forward_fp8_randint(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
@@ -135,7 +121,7 @@ class TestFP8Dot:
        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)

    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_bf16(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
@@ -161,7 +147,7 @@ class TestFP8Dot:
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    def test_grad_fp8_dot(self, m, n, k):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
@@ -192,33 +178,38 @@ class TestFP8Dot:
        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)

        for _ in range(3):
            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
                value_n_grad_primitive_func(a, b, amax_list, scale_list)
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize(
        "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
    )
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_grad_fused_layernorm_fp8_mlp(
        self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
    ):
        """N/a"""
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)
@@ -233,8 +224,9 @@ class TestFP8Dot:
        b1 = None
        b2 = None

        def primitive_func(
            x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
        ):
            # x is input tensor, matrix 2d
            # y, z are weights, matrix 2d
            # out = ((x * y) + w) * z + v
@@ -255,18 +247,31 @@ class TestFP8Dot:
                scale_list_2[2],
            )
            return jnp.mean(
                fused_layernorm_fp8_mlp(
                    x,
                    ln_s,
                    None,
                    [y, z],
                    [w, v],
                    [fp8_meta_pkg_1, fp8_meta_pkg_2],
                    "rmsnorm",
                    activation_type=activation_type,
                    use_bias=use_bias,
                )
            )

        def layernorm_fp8_mlp_ref(
            x: jnp.ndarray,
            ln_scale: jnp.ndarray,
            kernel_1: jnp.ndarray,
            kernel_2: jnp.ndarray,
            bias_1: jnp.ndarray,
            bias_2: jnp.ndarray,
            amax_list_1: List[jnp.ndarray],
            amax_list_2: List[jnp.ndarray],
            scale_list_1: List[jnp.ndarray],
            scale_list_2: List[jnp.ndarray],
        ) -> jnp.ndarray:
            x = jnp.asarray(x, jnp.float32)
            mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
@@ -315,11 +320,14 @@ class TestFP8Dot:
        def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
            return jnp.mean(
                layernorm_fp8_mlp_ref(
                    x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
                )
            )

        value_n_grad_primitive_func = jit(
            value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
        )
        value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))

        _, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
@@ -339,40 +347,87 @@ class TestFP8Dot:
        # Convert str to index as str is not a valid type for JAX JIT
        for _ in range(3):
            ref_out, (
                ref_a_grad,
                ref_s_grad,
                ref_k1_grad,
                ref_k2_grad,
                ref_b1_grad,
                ref_b2_grad,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            ) = value_n_grad_ref_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                ref_amax_list_1,
                ref_amax_list_2,
                ref_scale_list_1,
                ref_scale_list_2,
            )

        for _ in range(3):
            primitive_out, (
                primitive_a_grad,
                primitive_s_grad,
                primitive_k1_grad,
                primitive_k2_grad,
                primitive_b1_grad,
                primitive_b2_grad,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            ) = value_n_grad_primitive_func(
                a,
                s,
                k1,
                k2,
                b1,
                b2,
                primitive_amax_list_1,
                primitive_amax_list_2,
                primitive_scale_list_1,
                primitive_scale_list_2,
            )

        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(
            jnp.asarray(primitive_a_grad, np.float32),
            jnp.asarray(ref_a_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k1_grad, np.float32),
            jnp.asarray(ref_k1_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_s_grad, np.float32),
            jnp.asarray(ref_s_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        assert_allclose(
            jnp.asarray(primitive_k2_grad, np.float32),
            jnp.asarray(ref_k2_grad, np.float32),
            dtype=FP8Helper.BWD_DTYPE,
        )
        if use_bias:
            assert_allclose(
                jnp.asarray(primitive_b2_grad, np.float32),
                jnp.asarray(ref_b2_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )
            assert_allclose(
                jnp.asarray(primitive_b1_grad, np.float32),
                jnp.asarray(ref_b1_grad, np.float32),
                dtype=FP8Helper.BWD_DTYPE,
            )
    @pytest.fixture(name="random_inputs")
@@ -402,17 +457,22 @@ class TestActivationLu:
    def primitive_func(self, inputs):
        return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

    @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=1)
@@ -441,23 +501,34 @@ class TestActivationLuFP8(TestActivationLu):
            return output

        def _prim_func_fwd(x, _x_t, _dbias, _amax):
            activation_lu_out, _ = tex.act_lu_fp8(
                x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
            )
            activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
            ctx = x
            return activation_lu_out, ctx

        def _prim_func_bwd(ctx, g):
            x = ctx
            if len(self.activation_type) > 1:  # gated, no bias
                dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
                    g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
                )
                dbias = jnp.empty(x.shape[-1], x.dtype)
            else:  # not gated, with bias
                dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
                    tex.dact_lu_dbias_cast_transpose(
                        g,
                        x,
                        amax,
                        scale,
                        scale_inv,
                        FP8Helper.BWD_DTYPE,
                        -1,
                        -2,
                        self.activation_type,
                    )
                )
            dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
            dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
            ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
@@ -468,23 +539,28 @@ class TestActivationLuFP8(TestActivationLu):
        dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
        dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
        amax_no_use = jnp.zeros(1, jnp.float32)
        value_n_grad_primitive_func = value_and_grad(
            lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
        )

        return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
    @pytest.mark.parametrize(
        "activation_type",
        [
            ("gelu",),
            ("gelu", "linear"),
            ("silu",),
            ("silu", "linear"),
            ("relu",),
            ("relu", "linear"),
            ("quick_gelu",),
            ("quick_gelu", "linear"),
            ("squared_relu",),
            ("squared_relu", "linear"),
        ],
    )
    def test_activation_lu(self, random_inputs, activation_type):
        self.amax = jnp.zeros(1, jnp.float32)
        self.scale = jnp.ones(1, jnp.float32)
@@ -500,12 +576,14 @@ class TestActivationLuFP8(TestActivationLu):
        assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
        assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)

        if "linear" not in activation_type:
            assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))

        assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
        assert_allclose(
            prim_grad_trans,
            jnp.transpose(ref_grad, self.transpose_indices),
            dtype=FP8Helper.BWD_DTYPE,
        )

class TestNorm:
@@ -536,34 +614,38 @@ class TestNorm:
        """
        x_ = jnp.asarray(x, jnp.float32)
        if bias is None:
            mean = 0.0
        else:
            mean = jnp.mean(x_, axis=-1, keepdims=True)

        var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
        normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
        if zero_centered_gamma:
            scale += 1.0
        if bias is None:
            bias = 0.0
        return jnp.asarray(normed_input * scale + bias).astype(x.dtype)

    @pytest.mark.parametrize("n, hidden", LN_CASES)
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_layernorm_forward_backward(
        self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 3)
@@ -571,7 +653,7 @@ class TestNorm:
            gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
            gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
            gamma = jnp.asarray(gamma, dtype)
            if ln_type == "layernorm":
                beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
                beta = jnp.asarray(beta, dtype)
            else:
@@ -585,19 +667,27 @@ class TestNorm:
            jitted_primitive = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )
            jitted_reference = jit(
                value_and_grad(
                    lambda x, gamma, beta: compute_loss(
                        self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
                    ),
                    (0, 1, 2),
                )
            )

            primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
                x, gamma, beta
            )
            reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
                x, gamma, beta
            )

            assert_allclose(primitive_out, reference_out, dtype=dtype)
            assert_allclose(primitive_dx, reference_dx, dtype=dtype)
@@ -606,21 +696,24 @@ class TestNorm:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
    @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("zero_centered_gamma", [True, False])
    @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
    def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
        """
        Test transformer_engine.jax.layernorm.layernorm_fp8_dot
        """
        expect_assert = False
        if ln_type == "rmsnorm" and zero_centered_gamma:
            # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
            expect_assert = True

        with (
            pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
            if expect_assert
            else nullcontext()
        ):
            key = jax.random.PRNGKey(0)
            subkeys = jax.random.split(key, 4)
@@ -628,7 +721,7 @@ class TestNorm:
            b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
            gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
            if ln_type == "layernorm":
                beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
            else:
                beta = None
@@ -644,8 +737,9 @@ class TestNorm:
                    amax_list_1[2],
                    scale_list_1[2],
                )
                primitive_out = layernorm_fp8_dot(
                    x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
                )
                return jnp.mean(primitive_out)

            def ref_func(x, y, gamma, beta, zero_centered_gamma):
@@ -655,14 +749,19 @@ class TestNorm:
            value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
            value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

            ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
                value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
            )

            for _ in range(3):
                primitive_out, (
                    primitive_a_grad,
                    primitive_b_grad,
                    primitive_gamma_grad,
                    primitive_beta_grad,
                    amax_list_1,
                    scale_list_1,
                ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)

            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
......
@@ -9,16 +9,8 @@ import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import (
    fused_attn_kvpacked,
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
)
@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16]

class TestDistributedSelfAttn:

    def generate_collectives_count_ref(
        self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
    ):
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        _, seqlen, _, heads, _ = shape
        is_dp_enabled = mesh_resource.dp_resource is not None
@@ -46,7 +39,7 @@ class TestDistributedSelfAttn:
            idx = mesh_axes.index(mesh_resource.tp_resource)
            tp_size = mesh_shape[idx]

        all_reduce_loss_bytes = 4  # 1 * FP32
        bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
        allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
        # for loss and dbias
@@ -57,8 +50,11 @@ class TestDistributedSelfAttn:
        qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
        bias = (
            random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
            if with_bias
            else None
        )

        mask = None
        if attn_mask_type == AttnMaskType.PADDING_MASK:
@@ -66,47 +62,76 @@ class TestDistributedSelfAttn:
        elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
            mask = make_self_mask(batch, seqlen)

        qkv_pspec = PartitionSpec(
            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
        )
        bias_pspec = (
            PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
        )
        mask_pspec = (
            PartitionSpec(mesh_resource.dp_resource, None, None, None)
            if attn_mask_type != AttnMaskType.NO_MASK
            else None
        )

        return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
    @pytest.mark.parametrize(
        "attn_bias_type",
        [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
    )
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_self_attn(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        data_shape,
        attn_bias_type,
        attn_mask_type,
        dtype,
    ):
        dropout_prob = 0.0
        is_training = True
        scaling_factor = 1.0

        _, seqlen, _, num_head, hidden = data_shape
        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BS3HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
        ):
            pytest.skip(f"No FusedAttn backwend found")

        def target_func(qkv, bias, mask):
            return jnp.mean(
                fused_attn_qkvpacked(
                    qkv,
                    bias,
                    mask,
                    None,
                    attn_bias_type=attn_bias_type,
                    attn_mask_type=attn_mask_type,
                    scaling_factor=scaling_factor,
                    dropout_probability=dropout_prob,
                    is_training=is_training,
                )
            )

        def ref_func(qkv, bias, mask):
            query, key, value = jnp.split(qkv, [1, 2], axis=-3)
@@ -114,52 +139,59 @@ class TestDistributedSelfAttn:
            key = jnp.squeeze(key)
            value = jnp.squeeze(value)

            output = dot_product_attention(
                query,
                key,
                value,
                bias=bias,
                mask=mask,
                deterministic=is_training,
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )

            return jnp.mean(output).astype(dtype)

        with_bias = attn_bias_type != AttnBiasType.NO_BIAS
        (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, with_bias, attn_mask_type, dtype
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
            bias_ = (
                jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
            )
            mask_ = (
                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
            )

            grad_args = (0, 1) if with_bias else (0,)
            out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)

            compare_ops(
                target_func,
                ref_func,
                [qkv_, bias_, mask_],
                collective_count_ref,
                grad_args=grad_args,
                metric_fwd_dtype=dtype,
                metric_bwd_dtype=dtype,
                in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
                out_shardings=(None, out_grad_shardings),
            )


class TestDistributedCrossAttn:

    def generate_collectives_count_ref(self):
        # for loss
        all_reduce_loss_bytes = 4  # 1 * FP32
        return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

    def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
@@ -176,20 +208,26 @@ class TestDistributedCrossAttn:
        q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
        kv_pspec = PartitionSpec(
            mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
        )
        mask_pspec = (
            PartitionSpec(mesh_resource.dp_resource, None, None, None)
            if attn_mask_type != AttnMaskType.NO_MASK
            else None
        )

        return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)

    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
    @pytest.mark.parametrize(
        "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
    )
    @pytest.mark.parametrize("dtype", DTYPES)
    def test_cross_attn(
        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
    ):
        attn_bias_type = AttnBiasType.NO_BIAS
        dropout_prob = 0.0
        is_training = True
@@ -197,23 +235,36 @@ class TestDistributedCrossAttn:
        _, seqlen, num_head, hidden = data_shape
        if not is_fused_attn_kernel_available(
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            num_head,
            num_head,
            seqlen,
            seqlen,
            hidden,
        ):
            pytest.skip(f"No FusedAttn backwend found")

        def target_func(q, kv, mask):
            return jnp.mean(
                fused_attn_kvpacked(
                    q,
                    kv,
                    None,
                    mask,
                    None,
                    attn_bias_type=attn_bias_type,
                    attn_mask_type=attn_mask_type,
                    scaling_factor=scaling_factor,
                    dropout_probability=dropout_prob,
                    is_training=is_training,
                )
            )

        def ref_func(query, kv, mask):
            key, value = jnp.split(kv, [1], axis=-3)
@@ -221,34 +272,41 @@ class TestDistributedCrossAttn:
            key = jnp.squeeze(key)
            value = jnp.squeeze(value)

            output = dot_product_attention(
                query,
                key,
                value,
                bias=None,
                mask=mask,
                deterministic=is_training,
                dropout_rate=dropout_prob,
                dropout_rng=None,
                dtype=jnp.float32,
            )

            return jnp.mean(output).astype(dtype)

        (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
            data_shape, mesh_resource, attn_mask_type, dtype
        )
        collective_count_ref = self.generate_collectives_count_ref()
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
            q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
            kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
            mask_ = (
                jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
            )

            compare_ops(
                target_func,
                ref_func,
                [q_, kv_, mask_],
                collective_count_ref,
                grad_args=(0, 1),
                metric_fwd_dtype=dtype,
                metric_bwd_dtype=dtype,
                in_shardings=(q_pspec, kv_pspec, mask_pspec),
                out_shardings=(None, (q_pspec, kv_pspec)),
            )
)
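# Note on the harness (an assumption about the shared test utility, which lives in the
# distributed test helpers and is not shown here): compare_ops is expected to jit both
# functions under the active mesh, check that forward values and the requested gradients
# (grad_args) agree within the given dtypes, and verify the observed collectives against
# collective_count_ref.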
...@@ -35,31 +35,44 @@ class TestDistributedLayernorm:
else: else:
raise NotImplementedError raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None) g_pspec = b_pspec = (
PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec) return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype): def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype) jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None is_dp_enabled = mesh_resource.dp_resource is not None
assert ln_type in ['layernorm', 'rmsnorm'] assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32 all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta # for loss, dgamma and dbeta
weight_count = 2 if ln_type == 'layernorm' else 1 weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize allreduce_total_bytes = (
return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled), all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
allgather=0, )
other=0) return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) )
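# Illustrative arithmetic for the reference above (a sketch, not part of the test): with
# data_shape [32, 128, 1024], bfloat16 (itemsize 2) and ln_type "layernorm" (weight_count 2),
# allreduce_total_bytes = 4 + 2 * 1024 * 2 = 4100 — the FP32 loss scalar plus dgamma and
# dbeta all-reduced across the data-parallel group when DP is enabled.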
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('zero_centered_gamma', [False, True]) @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('shard_weights', [False, True]) @pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, @pytest.mark.parametrize("zero_centered_gamma", [False, True])
zero_centered_gamma, shard_weights): @pytest.mark.parametrize("shard_weights", [False, True])
def test_layernorm(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
zero_centered_gamma,
shard_weights,
):
epsilon = 1e-6 epsilon = 1e-6
ln_type = 'layernorm' ln_type = "layernorm"
def target_func(x, gamma, beta): def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)) return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
...@@ -75,10 +88,12 @@ class TestDistributedLayernorm:
output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype) output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(output) return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \ (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) data_shape, mesh_resource, dtype, shard_weights
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, )
data_shape, dtype) collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
...@@ -88,20 +103,25 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
compare_ops(target_func, compare_ops(
ref_func, [x_, gamma_, beta_], target_func,
collective_count_ref, ref_func,
grad_args=(0, 1, 2), [x_, gamma_, beta_],
metric_fwd_dtype=dtype, collective_count_ref,
metric_bwd_dtype=dtype, grad_args=(0, 1, 2),
in_shardings=(x_pspec, g_pspec, b_pspec), metric_fwd_dtype=dtype,
out_shardings=(None, (x_pspec, g_pspec, b_pspec))) metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)
except AssertionError as err: except AssertionError as err:
# Layernorm should still produce the correct numerical result with # Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same # gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch # when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here. # and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err): if (
g_pspec[-1] is None and b_pspec[-1] is None
) or "Expected collective count" not in str(err):
raise err raise err
finally: finally:
for w in warns: for w in warns:
...@@ -110,13 +130,15 @@ class TestDistributedLayernorm:
"unsupported sharding of gamma and/or beta" "unsupported sharding of gamma and/or beta"
) )
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('shard_weights', [False, True]) @pytest.mark.parametrize("shard_weights", [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights): def test_rmsnorm(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
):
epsilon = 1e-6 epsilon = 1e-6
ln_type = 'rmsnorm' ln_type = "rmsnorm"
def target_func(x, gamma): def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon)) return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
...@@ -128,10 +150,12 @@ class TestDistributedLayernorm:
output = y * gamma output = y * gamma
return jnp.mean(output) return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \ (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) data_shape, mesh_resource, dtype, shard_weights
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, )
data_shape, dtype) collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource): with mesh, fp8_autocast(mesh_resource=mesh_resource):
...@@ -140,14 +164,17 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
compare_ops(target_func, compare_ops(
ref_func, [x_, gamma_], target_func,
collective_count_ref, ref_func,
grad_args=(0, 1), [x_, gamma_],
metric_fwd_dtype=dtype, collective_count_ref,
metric_bwd_dtype=dtype, grad_args=(0, 1),
in_shardings=(x_pspec, g_pspec), metric_fwd_dtype=dtype,
out_shardings=(None, (x_pspec, g_pspec))) metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)),
)
except AssertionError as err: except AssertionError as err:
# RmsNorm should still produce the correct numerical result with # RmsNorm should still produce the correct numerical result with
# gamma sharded. However, the collective count may not be the same # gamma sharded. However, the collective count may not be the same
...
...@@ -15,22 +15,23 @@ from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import ( from transformer_engine.jax.sharding import (
HIDDEN_AXES, HIDDEN_TP_AXES, HIDDEN_AXES,
HIDDEN_TP_AXES,
BATCH_AXES, BATCH_AXES,
SEQLEN_TP_AXES, SEQLEN_AXES, SEQLEN_TP_AXES,
W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES SEQLEN_AXES,
W_NO_SHARD_AXES,
W_FSDP_AXES,
W_TP_AXES,
W_JOINED_AXES,
) )
from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.sharding import MeshResource
from utils import ( from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
assert_allclose,
assert_tree_like_allclose,
is_devices_enough
)
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in] INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
...@@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs():
configs = [] configs = []
if is_devices_enough(2): if is_devices_enough(2):
configs.append( configs.append(
[2, (1, 2), ('fsdp', 'tp'), [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
MeshResource(fsdp_resource='fsdp', tp_resource='tp')]) )
if is_devices_enough(4): if is_devices_enough(4):
configs.append( configs.append(
[4, (2, 2), ('fsdp', 'tp'), [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
MeshResource(fsdp_resource='fsdp', tp_resource='tp')]) )
return configs return configs
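# Each config is [device_count, mesh_shape, mesh_axis_names, MeshResource]; for example the
# 4-device entry builds a 2x2 mesh whose "fsdp" axis shards data and optimizer state while
# the "tp" axis shards weights (tensor parallelism).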
...@@ -64,10 +65,12 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype) gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), k1 = jax.random.normal(
dtype) / jnp.sqrt(hidden_in) subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
k2 = jax.random.normal(subkeys[2], ) / jnp.sqrt(hidden_in)
(INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE) k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias: if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype) b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype) b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
...@@ -90,15 +93,27 @@ class TestDistributedLayernormMLP:
scale_list_1: List[jnp.ndarray], scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray], scale_list_2: List[jnp.ndarray],
layernorm_type: str = "rmsnorm", layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ('gelu',), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True, use_bias: bool = True,
multi_gpus: bool = False, multi_gpus: bool = False,
) -> jnp.ndarray: ) -> jnp.ndarray:
fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], amax_list_1[1], fp8_meta_pkg1 = FP8MetaPackage(
scale_list_1[1], amax_list_1[2], scale_list_1[2]) amax_list_1[0],
fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1], scale_list_1[0],
scale_list_2[1], amax_list_2[2], scale_list_2[2]) amax_list_1[1],
scale_list_1[1],
amax_list_1[2],
scale_list_1[2],
)
fp8_meta_pkg2 = FP8MetaPackage(
amax_list_2[0],
scale_list_2[0],
amax_list_2[1],
scale_list_2[1],
amax_list_2[2],
scale_list_2[2],
)
if multi_gpus: if multi_gpus:
layernorm_input_axes = LAYERNORM_INPUT_AXES layernorm_input_axes = LAYERNORM_INPUT_AXES
...@@ -111,60 +126,68 @@ class TestDistributedLayernormMLP:
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2 # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
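# A slightly fuller, hedged reading of the fused op (illustrative only, ignoring the FP8
# casts handled by the meta packages): h = activation(rmsnorm(x, ln_scale) @ kernel_1 + bias_1)
# and out = h @ kernel_2 + bias_2; jnp.mean below reduces it to a scalar so gradients can be
# taken with value_and_grad.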
return jnp.mean( return jnp.mean(
fused_layernorm_fp8_mlp(x, fused_layernorm_fp8_mlp(
ln_scale, x,
None, [kernel_1, kernel_2], [bias_1, bias_2], ln_scale,
[fp8_meta_pkg1, fp8_meta_pkg2], None,
layernorm_type, [kernel_1, kernel_2],
layernorm_input_axes=layernorm_input_axes, [bias_1, bias_2],
dot_1_input_axes=dot_1_input_axes, [fp8_meta_pkg1, fp8_meta_pkg2],
dot_2_input_axes=dot_2_input_axes, layernorm_type,
activation_type=activation_type, layernorm_input_axes=layernorm_input_axes,
use_bias=use_bias)) dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes,
activation_type=activation_type,
use_bias=use_bias,
)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape, def test_layernorm_fp8_mlp_primitive(
dtype): self, mesh_config, activation_type, use_bias, input_shape, dtype
):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = 'rmsnorm' layernorm_type = "rmsnorm"
fp8_amax_list_1 = [ fp8_amax_list_1 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
] ]
fp8_amax_list_2 = [ fp8_amax_list_2 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32), jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32) jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
] ]
fp8_scale_list_1 = [ fp8_scale_list_1 = [
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32) jnp.ones((1,), jnp.float32),
] ]
fp8_scale_list_2 = [ fp8_scale_list_2 = [
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32), jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32) jnp.ones((1,), jnp.float32),
] ]
inputs = [x, gamma, k1, k2, b1, b2] = \ inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
self.generate_inputs(input_shape, activation_type, use_bias, dtype) input_shape, activation_type, use_bias, dtype
)
inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2] inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
static_inputs = [layernorm_type, activation_type, use_bias] static_inputs = [layernorm_type, activation_type, use_bias]
value_and_grad_func = jax.value_and_grad(self.layernorm_fp8_mlp_prim_func, value_and_grad_func = jax.value_and_grad(
argnums=range(len(inputs))) self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
# Single GPU # Single GPU
single_jitter = jax.jit(value_and_grad_func, single_jitter = jax.jit(
static_argnums=range(len(inputs), value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
len(static_inputs) + len(inputs))) )
with fp8_autocast(enabled=True): with fp8_autocast(enabled=True):
single_fwd, single_grads = single_jitter(*inputs, *static_inputs) single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
...@@ -172,12 +195,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec('fsdp', None, 'tp')) k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec('tp', 'fsdp')) k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding) k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding) k2_ = jax.device_put(k2, k2_sharding)
if use_bias: if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, 'tp')) b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding) b1_ = jax.device_put(b1, b1_sharding)
else: else:
b1_sharding = b1_ = None b1_sharding = b1_ = None
...@@ -186,17 +209,29 @@ class TestDistributedLayernormMLP:
# Positional reference for the sharding pspec lists # Positional reference for the sharding pspec lists
# x, gamma, k1, k2, b1, # x, gamma, k1, k2, b1,
# b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv # b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, in_shardings = (
None, None) None,
out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None,
None, None, None)) k1_sharding,
k2_sharding,
multi_jitter = jax.jit(value_and_grad_func, b1_sharding,
in_shardings=in_shardings, None,
out_shardings=out_shardings, None,
static_argnums=range(len(multi_inputs), None,
len(static_inputs) + len(multi_inputs) + None,
1)) # +1 for multi_gpus None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
)
multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
...@@ -206,97 +241,96 @@ class TestDistributedLayernormMLP:
if isinstance(multi_grads[i], list): if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list) assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]): for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose(m_grad, assert_allclose(
s_grad, m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
dtype=dtype, )
err_msg=f'multi_grads[{i}] is not close')
else: else:
assert_allclose(multi_grads[i], assert_allclose(
single_grads[i], multi_grads[i],
dtype=dtype, single_grads[i],
err_msg=f'multi_grads[{i}] is not close') dtype=dtype,
err_msg=f"multi_grads[{i}] is not close",
def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype, )
use_fp8):
def _test_layernorm_mlp(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
):
batch, seqlen, hidden_in = input_shape batch, seqlen, hidden_in = input_shape
layernorm_type = 'rmsnorm' layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0) rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2) subkeys = jax.random.split(rng, 2)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {'params': subkeys[1]} init_rngs = {"params": subkeys[1]}
# Single GPU # Single GPU
with fp8_autocast(enabled=use_fp8): with fp8_autocast(enabled=use_fp8):
ln_mlp_single = LayerNormMLP( ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
transpose_batch_sequence=False, # input: [batch, seqlen, hidden] transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE, intermediate_dim=INTERMEDIATE,
activations=activation_type, activations=activation_type,
dtype=dtype, dtype=dtype,
use_bias=use_bias, use_bias=use_bias,
) )
params_single = ln_mlp_single.init(init_rngs, x) params_single = ln_mlp_single.init(init_rngs, x)
mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single, mlp_out_single, ln_out_single = ln_mlp_single.apply(
x, params_single, x, deterministic=True
deterministic=True) )
# Multi-GPU # Multi-GPU
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource): with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type, ln_mlp_sharded = LayerNormMLP(
transpose_batch_sequence=False, layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE, transpose_batch_sequence=False,
activations=activation_type, intermediate_dim=INTERMEDIATE,
dtype=dtype, activations=activation_type,
scale_axes=(W_NO_SHARD_AXES,), dtype=dtype,
ln_bias_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
use_bias=use_bias, kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
bias_axes_1=(W_JOINED_AXES, W_TP_AXES), use_bias=use_bias,
bias_axes_2=(W_NO_SHARD_AXES,), bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
layernorm_input_axes=LAYERNORM_INPUT_AXES, bias_axes_2=(W_NO_SHARD_AXES,),
dot_1_input_axes=DOT_1_INPUT_AXES, layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES, dot_1_input_axes=DOT_1_INPUT_AXES,
name='mlp') dot_2_input_axes=DOT_2_INPUT_AXES,
name="mlp",
)
params_sharded = ln_mlp_sharded.init(init_rngs, x) params_sharded = ln_mlp_sharded.init(init_rngs, x)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded, mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
x, params_sharded, x, deterministic=True
deterministic=True) )
# Make sure params values are the same # Make sure params values are the same
assert_tree_like_allclose(params_sharded['params'], params_single['params']) assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype) assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype): def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
self._test_layernorm_mlp(mesh_config, self._test_layernorm_mlp(
activation_type, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
use_bias, )
input_shape,
dtype,
use_fp8=False)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs()) @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize('use_bias', [True, False]) @pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize('input_shape', INPUT_SHAPE) @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, def test_layernorm_fp8_mlp_layer(
dtype): self, mesh_config, activation_type, use_bias, input_shape, dtype
self._test_layernorm_mlp(mesh_config, ):
activation_type, self._test_layernorm_mlp(
use_bias, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
input_shape, )
dtype,
use_fp8=True)
...@@ -25,7 +25,7 @@ class TestDistributedSoftmax:
def generate_collectives_count_ref(self): def generate_collectives_count_ref(self):
# for loss # for loss
all_reduce_loss_bytes = 4 # 1 * FP32 all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding): def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
...@@ -38,49 +38,65 @@ class TestDistributedSoftmax:
mask = make_self_mask(batch, sqelen) mask = make_self_mask(batch, sqelen)
if not bad_sharding: if not bad_sharding:
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, x_pspec = PartitionSpec(
None, None) mesh_resource.dp_resource, mesh_resource.tp_resource, None, None
)
else: else:
x_pspec = PartitionSpec(mesh_resource.dp_resource, None, x_pspec = PartitionSpec(
None, mesh_resource.tp_resource) mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec) return (x, mask), (x_pspec, mask_pspec)
@staticmethod @staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED): def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type)) return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
@staticmethod @staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16): def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
bias = None bias = None
if mask is not None: if mask is not None:
bias = jax.lax.select(mask > 0, bias = jax.lax.select(
jnp.full(mask.shape, -1e10).astype(dtype), mask > 0,
jnp.full(mask.shape, 0.).astype(dtype)) jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.0).astype(dtype),
)
if bias is not None: if bias is not None:
x = x + bias.astype(dtype) x = x + bias.astype(dtype)
output = jax.nn.softmax(x * scale_factor) output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output) return jnp.mean(output)
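# Worked example of the mask-to-bias trick above (a sketch): a masked position gets logit
# x + (-1e10), so its softmax weight underflows to ~0, matching the fused kernel's hard
# masking up to numerical precision.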
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]]) @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize( @pytest.mark.parametrize(
'softmax_type', "softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED]) [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
@pytest.mark.parametrize('scale_factor', [1.0, 3.0]) )
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize('bad_sharding', [False, True]) @pytest.mark.parametrize("dtype", DTYPES)
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, @pytest.mark.parametrize("bad_sharding", [False, True])
softmax_type, scale_factor, dtype, bad_sharding): def test_softmax(
self,
target_func = partial(self.target_func, device_count,
scale_factor=scale_factor, mesh_shape,
softmax_type=softmax_type) mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
):
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype) ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = \ (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype, bad_sharding) data_shape, mesh_resource, softmax_type, dtype, bad_sharding
)
collective_count_ref = self.generate_collectives_count_ref() collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes) mesh = Mesh(devices, mesh_axes)
...@@ -90,14 +106,17 @@ class TestDistributedSoftmax:
with warnings.catch_warnings(record=True) as warns: with warnings.catch_warnings(record=True) as warns:
try: try:
compare_ops(target_func, compare_ops(
ref_func, [x_, mask_], target_func,
collective_count_ref, ref_func,
grad_args=(0,), [x_, mask_],
metric_fwd_dtype=dtype, collective_count_ref,
metric_bwd_dtype=dtype, grad_args=(0,),
in_shardings=(x_pspec, mask_pspec), metric_fwd_dtype=dtype,
out_shardings=(None, (x_pspec,))) metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)),
)
except AssertionError as err: except AssertionError as err:
# Softmax should still produce the correct numerical result with # Softmax should still produce the correct numerical result with
# bad sharding. However, the collective count may not be the same # bad sharding. However, the collective count may not be the same
...
...@@ -20,12 +20,14 @@ class TestLoRA:
out = jnp.einsum(pattern, x, la, lb) out = jnp.einsum(pattern, x, la, lb)
return out * scale return out * scale
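# Hedged sketch of what the einsum reduces to for the simplest pattern "...h,hr,rk->...k"
# (illustration only, not used by the test):
#   delta = x @ la @ lb        # [..., h] @ [h, r] @ [r, k] -> [..., k]
#   out   = delta * scale      # scale is typically alpha / rank when alpha is given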
@pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)]) @pytest.mark.parametrize("shape", [(32, 1024), (32, 128, 1024)])
@pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16]) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
@pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'), @pytest.mark.parametrize(
((-1,), (3, 1024), '...h,hkr,krz->...kz')]) "axis_features_pattern",
@pytest.mark.parametrize('rank', [32, 16]) [((-1,), (1024,), "...h,hr,rk->...k"), ((-1,), (3, 1024), "...h,hkr,krz->...kz")],
@pytest.mark.parametrize('alpha', [None, 4, 8]) )
@pytest.mark.parametrize("rank", [32, 16])
@pytest.mark.parametrize("alpha", [None, 4, 8])
def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha): def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
axis, features, pattern = axis_features_pattern axis, features, pattern = axis_features_pattern
axis = _normalize_axes(axis, len(shape)) axis = _normalize_axes(axis, len(shape))
...@@ -49,16 +51,20 @@ class TestLoRA:
assert_allclose(out_target, out_ref, dtype=dtype) assert_allclose(out_target, out_ref, dtype=dtype)
@pytest.mark.parametrize('scope_ref_assert', @pytest.mark.parametrize(
[('none', LoRAScope(False, False, False), False), "scope_ref_assert",
('all', LoRAScope(True, True, True), False), [
('qkv_proj', LoRAScope(True, False, False), False), ("none", LoRAScope(False, False, False), False),
('output_proj', LoRAScope(False, True, False), False), ("all", LoRAScope(True, True, True), False),
('mlp', LoRAScope(False, False, True), False), ("qkv_proj", LoRAScope(True, False, False), False),
('exclude_qkv_proj', LoRAScope(False, True, True), False), ("output_proj", LoRAScope(False, True, False), False),
('exclude_output_proj', LoRAScope(True, False, True), False), ("mlp", LoRAScope(False, False, True), False),
('exclude_mlp', LoRAScope(True, True, False), False), ("exclude_qkv_proj", LoRAScope(False, True, True), False),
('messing_up', LoRAScope(), True)]) ("exclude_output_proj", LoRAScope(True, False, True), False),
("exclude_mlp", LoRAScope(True, True, False), False),
("messing_up", LoRAScope(), True),
],
)
def test_lora_scope_generator(self, scope_ref_assert): def test_lora_scope_generator(self, scope_ref_assert):
scope, reference, need_assert = scope_ref_assert scope, reference, need_assert = scope_ref_assert
try: try:
...
...@@ -24,7 +24,7 @@ from transformer_engine.jax.attention import (
QKVLayout, QKVLayout,
fused_attn_qkvpacked, fused_attn_qkvpacked,
fused_attn_kvpacked, fused_attn_kvpacked,
fused_attn fused_attn,
) )
from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.jax.cpp_extensions import FusedAttnHelper
...@@ -33,7 +33,7 @@ from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose from utils import assert_allclose
@pytest.fixture(autouse=True, scope='module') @pytest.fixture(autouse=True, scope="module")
def init(): def init():
""" """
WAR for CUDA uninitialized error WAR for CUDA uninitialized error
...@@ -43,10 +43,18 @@ def init():
yield yield
def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike, def general_dot_product_attention(
bias: ArrayLike, mask: ArrayLike, deterministic: bool, query: ArrayLike,
scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike, key: ArrayLike,
dtype: DTypeLike) -> Array: value: ArrayLike,
bias: ArrayLike,
mask: ArrayLike,
deterministic: bool,
scale_factor: float,
dropout_rate: float,
dropout_rng: ArrayLike,
dtype: DTypeLike,
) -> Array:
""" """
Similar to flax.linen.dot_product_attention but with GQA support Similar to flax.linen.dot_product_attention but with GQA support
""" """
...@@ -59,7 +67,7 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
num_groups = h_q // h_kv num_groups = h_q // h_kv
grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d)) grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
# logits with shape (b, h_kv, num_groups, s_q, s_kv) # logits with shape (b, h_kv, num_groups, s_q, s_kv)
logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key) logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
if bias is not None: if bias is not None:
# reshape logits without groups # reshape logits without groups
...@@ -76,13 +84,13 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
softmax_out = jax.nn.softmax(logits).astype(dtype) softmax_out = jax.nn.softmax(logits).astype(dtype)
if not deterministic and dropout_rate > 0.: if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape) keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype) multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
softmax_out = softmax_out * multiplier softmax_out = softmax_out * multiplier
context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value) context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
context = jnp.reshape(context, query.shape) context = jnp.reshape(context, query.shape)
return context return context
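# Hedged shape sketch for the GQA path above (hypothetical helper, not used by the tests):
def _gqa_shape_example():
    b, s, h_q, h_kv, d = 4, 128, 16, 8, 64  # each KV head serves h_q // h_kv = 2 query heads
    q = jnp.zeros((b, s, h_q, d))
    k = v = jnp.zeros((b, s, h_kv, d))
    grouped_q = jnp.reshape(q, (b, s, h_kv, h_q // h_kv, d))      # (4, 128, 8, 2, 64)
    logits = jnp.einsum("...qhgd,...khd->...hgqk", grouped_q, k)  # (4, 8, 2, 128, 128)
    ctx = jnp.einsum("...hgqk,...khd->...qhgd", jax.nn.softmax(logits), v)
    return jnp.reshape(ctx, q.shape)                              # back to (4, 128, 16, 64)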
...@@ -105,6 +113,7 @@ def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0) inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(inv_causal_mask, inv_padding_mask) return combine_masks(inv_causal_mask, inv_padding_mask)
def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array: def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
""" """
Create attention mask based on mask type. A `True` value in the mask means Create attention mask based on mask type. A `True` value in the mask means
...@@ -118,23 +127,26 @@ def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskT
mask = jnp.logical_not(inv_mask) mask = jnp.logical_not(inv_mask)
return mask return mask
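# Small usage sketch (hypothetical helper, not called by the tests): a True entry in the
# returned mask marks a position that must NOT be attended to, i.e. the logical inverse of
# the flax-style "keep" mask combined above.
def _mask_example():
    q_tok = kv_tok = jnp.asarray([[1, 1, 0]])  # last position is padding (token id 0)
    return make_mask(q_tok, kv_tok, AttnMaskType.PADDING_MASK)  # True where q or kv is padding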
def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs): def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
""" """
JAX native dot product attention implementation JAX native dot product attention implementation
""" """
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type) mask = make_mask(q_token, kv_token, attn_mask_type)
output = general_dot_product_attention(query, output = general_dot_product_attention(
key, query,
value, key,
bias=bias, value,
mask=mask, bias=bias,
deterministic=not kwargs['is_training'], mask=mask,
scale_factor=kwargs['scaling_factor'], deterministic=not kwargs["is_training"],
dropout_rate=kwargs['dropout_probability'], scale_factor=kwargs["scaling_factor"],
dropout_rng=dropout_rng, dropout_rate=kwargs["dropout_probability"],
dtype=jnp.float32) dropout_rng=dropout_rng,
dtype=jnp.float32,
)
return output.astype(query.dtype) return output.astype(query.dtype)
...@@ -142,10 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
""" """
TE customcall dot product attention implementation TE customcall dot product attention implementation
""" """
attn_mask_type = kwargs['attn_mask_type'] attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type) mask = make_mask(q_token, kv_token, attn_mask_type)
qkv_layout = kwargs.pop('qkv_layout') qkv_layout = kwargs.pop("qkv_layout")
match qkv_layout: match qkv_layout:
case QKVLayout.BS3HD: case QKVLayout.BS3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
...@@ -154,11 +166,13 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
case QKVLayout.BSHD_BS2HD: case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value]) key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3) kv = jnp.concatenate((key, value), axis=-3)
return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, **kwargs).astype(
**kwargs).astype(query.dtype) query.dtype
)
case QKVLayout.BSHD_BSHD_BSHD: case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng, return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(
**kwargs).astype(query.dtype) query.dtype
)
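# Packing cheat-sheet for the match statement above (a hedged illustration; helper not used
# by the tests): BS3HD packs q/k/v into one (b, s, 3, h, d) tensor, BSHD_BS2HD keeps q
# separate and packs k/v into (b, s_kv, 2, h, d), and BSHD_BSHD_BSHD passes three separate
# (b, s, h, d) tensors.
def _packing_example(q, k, v):
    qkv = jnp.concatenate([jnp.expand_dims(t, axis=-3) for t in (q, k, v)], axis=-3)
    kv = jnp.concatenate([jnp.expand_dims(t, axis=-3) for t in (k, v)], axis=-3)
    return qkv, kv  # (b, s, 3, h, d) and (b, s_kv, 2, h, d)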
class BiasShape(Enum): class BiasShape(Enum):
...@@ -166,10 +180,10 @@ class BiasShape(Enum):
Enum class to represent the different bias shapes used in the fused attention. Enum class to represent the different bias shapes used in the fused attention.
""" """
BIAS_1HSS = '1HSS' BIAS_1HSS = "1HSS"
BIAS_B1SS = 'B1SS' BIAS_B1SS = "B1SS"
BIAS_BHSS = 'BHSS' BIAS_BHSS = "BHSS"
BIAS_11SS = '11SS' BIAS_11SS = "11SS"
@dataclass @dataclass
...@@ -177,6 +191,7 @@ class FusedAttnRunner:
""" """
Fused attention runner Fused attention runner
""" """
batch_size: int batch_size: int
max_seqlen_q: int max_seqlen_q: int
max_seqlen_kv: int max_seqlen_kv: int
...@@ -198,21 +213,33 @@ class FusedAttnRunner:
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv: if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.") pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value, self.backend = FusedAttnHelper(
self.attn_bias_type.value, self.attn_mask_type.value, self.dtype,
self.dropout_prob, self.num_heads_q, self.num_heads_kv, self.dtype,
self.max_seqlen_q, self.max_seqlen_kv, self.qkv_layout.value,
self.head_dim).get_fused_attn_backend() self.attn_bias_type.value,
self.attn_mask_type.value,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
self.max_seqlen_q,
self.max_seqlen_kv,
self.head_dim,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.") pytest.skip("Unsupported inputs combination or device compute capability.")
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for " pytest.skip(
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.") "B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
)
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for " pytest.skip(
"the F16_arbitrary_seqlen backend.") "B1SS, BHSS and 11SS bias shapes are only supported for "
"the F16_arbitrary_seqlen backend."
)
def _setup_inputs(self): def _setup_inputs(self):
self._check_configs() self._check_configs()
...@@ -235,24 +262,25 @@ class FusedAttnRunner:
else: else:
pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.) self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.) self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.) self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
if self.attn_bias_type != AttnBiasType.NO_BIAS: if self.attn_bias_type != AttnBiasType.NO_BIAS:
if self.bias_shape == BiasShape.BIAS_1HSS: if self.bias_shape == BiasShape.BIAS_1HSS:
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.) self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
else: else:
# [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
# an arbitrary mask where (True/False -> 0/-Inf) # an arbitrary mask where (True/False -> 0/-Inf)
cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15. cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype) self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
max_id = min(self.max_seqlen_q, self.max_seqlen_kv) max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist() seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
for i in range(1, len(seq_id)): for i in range(1, len(seq_id)):
self.bias = \ self.bias = self.bias.at[
self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.) :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
].set(0.0)
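# Illustrative effect of the loop above (hypothetical seq_id values): with seq_id = [16, 48, 96]
# the blocks [16:48, 16:48] and [48:96, 48:96] are reset to 0.0 while everything else stays at
# cudnn_neg_inf, so the additive bias acts like a block-diagonal attention mask.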
else: else:
self.bias = None self.bias = None
...@@ -271,7 +299,7 @@ class FusedAttnRunner:
self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio) self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1. / sqrt(self.head_dim) self.scaling_factor = 1.0 / sqrt(self.head_dim)
def test_forward(self): def test_forward(self):
""" """
...@@ -281,19 +309,19 @@ class FusedAttnRunner:
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng] args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = { kwargs = {
'attn_bias_type': self.attn_bias_type, "attn_bias_type": self.attn_bias_type,
'attn_mask_type': self.attn_mask_type, "attn_mask_type": self.attn_mask_type,
'scaling_factor': self.scaling_factor, "scaling_factor": self.scaling_factor,
'dropout_probability': self.dropout_prob, "dropout_probability": self.dropout_prob,
'is_training': self.is_training, "is_training": self.is_training,
'qkv_layout': self.qkv_layout, "qkv_layout": self.qkv_layout,
} }
# Convert the outputs to float32 for the elementwise comparison # Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32) primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32) reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
if self.is_training and self.dropout_prob > 0.: if self.is_training and self.dropout_prob > 0.0:
return return
primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1) primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
...@@ -322,12 +350,12 @@ class FusedAttnRunner:
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng] args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = { kwargs = {
'attn_bias_type': self.attn_bias_type, "attn_bias_type": self.attn_bias_type,
'attn_mask_type': self.attn_mask_type, "attn_mask_type": self.attn_mask_type,
'scaling_factor': self.scaling_factor, "scaling_factor": self.scaling_factor,
'dropout_probability': self.dropout_prob, "dropout_probability": self.dropout_prob,
'is_training': self.is_training, "is_training": self.is_training,
'qkv_layout': self.qkv_layout, "qkv_layout": self.qkv_layout,
} }
# We can compute dBias only for the [1, h, s, s] layout # We can compute dBias only for the [1, h, s, s] layout
...@@ -336,12 +364,18 @@ class FusedAttnRunner:
# Summing in FP16/BF16 may overflow, so use FP32 for the summation # Summing in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad( value_and_grad(
lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args, lambda q, k, v, bias, *args: grad_func(
**kwargs), arg_nums)) customcall_fused_dpa, q, k, v, bias, *args, **kwargs
),
arg_nums,
)
)
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs), lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
arg_nums)) arg_nums,
)
)
primitive_out, primitive_dgrad = jitted_primitive(*args) primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args) reference_out, reference_dgrad = jitted_reference(*args)
...@@ -350,9 +384,9 @@ class FusedAttnRunner:
if self.dropout_prob > 0.0: if self.dropout_prob > 0.0:
return return
assert_allclose(primitive_out.astype(jnp.float32), assert_allclose(
reference_out.astype(jnp.float32), primitive_out.astype(jnp.float32), reference_out.astype(jnp.float32), dtype=self.dtype
dtype=self.dtype) )
def check_dqkv(primitive, reference, valid_len): def check_dqkv(primitive, reference, valid_len):
primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1) primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
...@@ -374,81 +408,158 @@ class FusedAttnRunner:
primitive_dbias = jnp.float32(primitive_dgrad[3]) primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3]) reference_dbias = jnp.float32(reference_dgrad[3])
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], assert_allclose(
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
self.valid_len_kv:]), jnp.zeros_like(primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :]),
dtype=self.dtype) dtype=self.dtype,
)
# dbias padded part # dbias padded part
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], assert_allclose(
reference_dbias[..., self.valid_len_q:, self.valid_len_kv:], primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
dtype=self.dtype) reference_dbias[..., self.valid_len_q :, self.valid_len_kv :],
dtype=self.dtype,
)
# dbias valid part # dbias valid part
assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv], assert_allclose(
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv], primitive_dbias[..., : self.valid_len_q, : self.valid_len_kv],
dtype=self.dtype) reference_dbias[..., : self.valid_len_q, : self.valid_len_kv],
dtype=self.dtype,
)
@pytest.mark.parametrize('attn_bias_type, bias_shape', [
pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'), @pytest.mark.parametrize(
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'), "attn_bias_type, bias_shape",
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'), [
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'), pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
]) pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
@pytest.mark.parametrize('attn_mask_type', [ pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'), pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'), pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'), ],
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'), )
]) @pytest.mark.parametrize(
@pytest.mark.parametrize('qkv_layout', [ "attn_mask_type",
pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'), [
pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'), pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'), pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
]) pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
@pytest.mark.parametrize('dtype', [ pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(jnp.bfloat16, id="BF16"), ],
pytest.param(jnp.float16, id="FP16"), )
]) @pytest.mark.parametrize(
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [ "qkv_layout",
pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'), [
pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'), pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'), pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1024-12-12-64-CROSS'), pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'), ],
pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'), )
]) @pytest.mark.parametrize(
@pytest.mark.parametrize('dropout_prob', [ "dtype",
pytest.param(0.0, id="DROP_0.0"), [
pytest.param(0.1, id="DROP_0.1"), pytest.param(jnp.bfloat16, id="BF16"),
]) pytest.param(jnp.float16, id="FP16"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d",
[
pytest.param(32, 128, 128, 16, 16, 64, id="32-128-128-16-16-64-SELF"),
pytest.param(4, 2048, 2048, 12, 12, 64, id="4-2048-2048-12-12-64-SELF"),
pytest.param(32, 512, 128, 16, 16, 64, id="32-512-128-16-16-64-CROSS"),
pytest.param(4, 2048, 1024, 12, 12, 64, id="4-2048-1048-12-12-64-CROSS"),
pytest.param(32, 128, 128, 16, 8, 64, id="32-128-128-16-8-64-GQA"),
pytest.param(4, 2048, 2048, 12, 6, 64, id="4-2048-2048-12-6-64-GQA"),
],
)
@pytest.mark.parametrize(
"dropout_prob",
[
pytest.param(0.0, id="DROP_0.0"),
pytest.param(0.1, id="DROP_0.1"),
],
)
class TestFusedAttn: class TestFusedAttn:
""" """
Fused attention tester Fused attention tester
""" """
@staticmethod @staticmethod
@pytest.mark.parametrize('is_training', [ @pytest.mark.parametrize(
pytest.param(True, id='TRAINING'), "is_training",
pytest.param(False, id='INFERENCE'), [
]) pytest.param(True, id="TRAINING"),
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, pytest.param(False, id="INFERENCE"),
dtype, is_training, qkv_layout, bias_shape): ],
)
def test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
):
""" """
Test forward with parameterized configs Test forward with parameterized configs
""" """
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, runner = FusedAttnRunner(
dropout_prob, dtype, is_training, qkv_layout, bias_shape) b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
)
runner.test_forward() runner.test_forward()
@staticmethod @staticmethod
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, def test_backward(
dtype, qkv_layout, bias_shape): b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
):
""" """
Test backward with parameterized configs Test backward with parameterized configs
""" """
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, runner = FusedAttnRunner(
dropout_prob, dtype, True, qkv_layout, bias_shape) b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
True,
qkv_layout,
bias_shape,
)
runner.test_backward() runner.test_backward()
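For readers unfamiliar with the pattern above: stacking pytest.mark.parametrize decorators on a class expands every test method into the Cartesian product of all parameter sets, which is how TestFusedAttn covers each combination of bias, mask, layout, dtype, shape, and dropout. A minimal, self-contained sketch of the mechanism (the names below are illustrative only, not from this repository):

import pytest

@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("dropout_prob", [0.0, 0.1])
class TestParametrizeGrid:
    def test_combination(self, dtype, dropout_prob):
        # pytest generates 2 x 2 = 4 test cases from the two decorators above
        assert dtype in ("bf16", "fp16")
        assert dropout_prob in (0.0, 0.1)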
@@ -27,21 +27,28 @@ class TestFP8Helper(unittest.TestCase):
fp8_format = FP8Format.E4M3 fp8_format = FP8Format.E4M3
amax_history_len = 10 amax_history_len = 10
FP8Helper.initialize(margin=margin, FP8Helper.initialize(
fp8_format=fp8_format, margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
amax_history_len=amax_history_len) )
self.assertEqual( self.assertEqual(
FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}" FP8Helper.MARGIN,
f" but got {FP8Helper.MARGIN}.") margin,
f"FP8Helper.MARGIN initialization failed, should be {margin}"
f" but got {FP8Helper.MARGIN}.",
)
self.assertEqual( self.assertEqual(
FP8Helper.FP8_FORMAT, fp8_format, FP8Helper.FP8_FORMAT,
fp8_format,
f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}" f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {FP8Helper.FP8_FORMAT}.") f" but got {FP8Helper.FP8_FORMAT}.",
)
self.assertEqual( self.assertEqual(
FP8Helper.AMAX_HISTORY_LEN, amax_history_len, FP8Helper.AMAX_HISTORY_LEN,
amax_history_len,
f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}" f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {FP8Helper.AMAX_HISTORY_LEN}.") f" but got {FP8Helper.AMAX_HISTORY_LEN}.",
)
FP8Helper.finalize() FP8Helper.finalize()
@@ -77,7 +84,7 @@ class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self): def test_fp8_autocast(self):
        FP8Helper.finalize()  # Ensure this test is not affected by previous tests.
self._check_defult_state() self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()): with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
@@ -102,21 +109,21 @@ class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason) @unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self): def test_fp8_autocast_with_sharding_resource(self):
        FP8Helper.finalize()  # Ensure this test is not affected by previous tests.
self._check_defult_state() self._check_defult_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
mesh_s = ( mesh_s = (
(MeshResource(None, None)), (MeshResource(None, None)),
(MeshResource('dp', None)), (MeshResource("dp", None)),
(MeshResource(None, 'tp')), (MeshResource(None, "tp")),
(MeshResource('dp', 'tp')), (MeshResource("dp", "tp")),
) )
# TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
mesh_shape = (1, 1) mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape) devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
with jax.sharding.Mesh(devices, ('dp', 'tp')): with jax.sharding.Mesh(devices, ("dp", "tp")):
for sr in mesh_s: for sr in mesh_s:
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr): with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
self.assertTrue(FP8Helper.is_fp8_enabled()) self.assertTrue(FP8Helper.is_fp8_enabled())
......
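As background for the recipe fields exercised in these tests (margin, fp8_format, amax_history_len): delayed scaling derives an FP8 scale factor from a history of observed absolute maxima, leaving some power-of-two headroom. The snippet below is only a rough sketch of that idea with an assumed E4M3 maximum of 448; the exact update rule inside Transformer Engine may differ.

import numpy as np

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3

def delayed_scale(amax_history, margin=0):
    # Use the largest amax seen in the history window, then back off by
    # `margin` powers of two for headroom.
    amax = float(np.max(amax_history))
    return (E4M3_MAX / amax) / (2.0 ** margin)

print(delayed_scale(np.array([0.5, 1.0, 2.0]), margin=0))  # -> 224.0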
@@ -22,7 +22,7 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='function') @pytest.fixture(autouse=True, scope="function")
def enable_fused_attn(): def enable_fused_attn():
"""Enable fused attention""" """Enable fused attention"""
os.environ["NVTE_FUSED_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1"
@@ -30,9 +30,9 @@ def enable_fused_attn():
del os.environ["NVTE_FUSED_ATTN"] del os.environ["NVTE_FUSED_ATTN"]
DATA_SHAPE = [ # (batch, seqlen, emb_dim) DATA_SHAPE = [ # (batch, seqlen, emb_dim)
pytest.param((32, 128, 1024), id='32-128-1024'), pytest.param((32, 128, 1024), id="32-128-1024"),
pytest.param((32, 512, 1024), id='32-512-1024'), pytest.param((32, 512, 1024), id="32-512-1024"),
] ]
DTYPE = [jnp.float32, jnp.bfloat16] DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID] FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@@ -69,123 +69,138 @@ BASE_ATTRS = {
_KEY_OF_ATTENTION_DROPOUT: 0, _KEY_OF_ATTENTION_DROPOUT: 0,
_KEY_OF_INTERMEDIATE_DROPOUT: 0, _KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal", _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: 'layernorm', _KEY_OF_LAYERNORM_TYPE: "layernorm",
} }
ATTRS = [{}, { ATTRS = [
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {},
}, { {
_KEY_OF_ZERO_CENTERED_GAMMA: True, _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_LAYERNORM_EPS: 1e-2, },
}, { {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_ZERO_CENTERED_GAMMA: True,
_KEY_OF_RESIDUAL_POST_LAYERNORM: True _KEY_OF_LAYERNORM_EPS: 1e-2,
}, { },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
_KEY_OF_OUTPUT_LAYERNORM: True {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
}, { {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_RESIDUAL_POST_LAYERNORM: True, _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
_KEY_OF_OUTPUT_LAYERNORM: True _KEY_OF_OUTPUT_LAYERNORM: True,
}, { },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
_KEY_OF_DROP_PATH: 0.1 {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
}, { {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_FUSE_QKV_PARAMS: False _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
}, { },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), _KEY_OF_SCALE_ATTN_LOGITS: True,
}, { _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_HIDDEN_DROPOUT: 0.8, _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5, _KEY_OF_USE_BIAS: True,
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), },
_KEY_OF_USE_BIAS: True, {
}, { _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'), },
}, { {
_KEY_OF_NUM_HEADS: 8, _KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4, _KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_MLP_ACTIVATIONS: ('gelu',), _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
}, { _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
_KEY_OF_SCALE_ATTN_LOGITS: True, },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {
_KEY_OF_HIDDEN_DROPOUT: 0.8, _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5, _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), _KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_USE_BIAS: True, _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
}, { _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_USE_BIAS: True,
_KEY_OF_SCALE_ATTN_LOGITS: True, },
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', {
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')), _KEY_OF_TRANSPOSE_BS: False,
}, { _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_NUM_HEADS: 8, _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_NUM_GQA_GROUPS: 4, _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
_KEY_OF_TRANSPOSE_BS: False, },
_KEY_OF_SCALE_ATTN_LOGITS: True, {
_KEY_OF_LAYERNORM_TYPE: 'layernorm', _KEY_OF_NUM_HEADS: 8,
_KEY_OF_MLP_ACTIVATIONS: (('silu',)), _KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_USE_BIAS: True, _KEY_OF_TRANSPOSE_BS: False,
}, { _KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_TRANSPOSE_BS: False, _KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_MLP_ACTIVATIONS: (("silu",)),
_KEY_OF_NUM_GQA_GROUPS: 1, _KEY_OF_USE_BIAS: True,
_KEY_OF_ENABLE_ROPE: True, },
_KEY_OF_ROPE_GROUP_METHOD: "consecutive", {
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True, _KEY_OF_TRANSPOSE_BS: False,
}, { _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_NUM_GQA_GROUPS: 1,
_KEY_OF_ENABLE_ROPE: True, _KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive", _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_USE_BIAS: True, _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, { },
_KEY_OF_TRANSPOSE_BS: False, {
_KEY_OF_LAYERNORM_TYPE: 'layernorm', _KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_GQA_GROUPS: 2, _KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ENABLE_ROPE: True, _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_ROPE_GROUP_METHOD: "alternate", _KEY_OF_USE_BIAS: True,
_KEY_OF_USE_BIAS: True, },
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True, {
}, { _KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_TRANSPOSE_BS: True, _KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_NUM_GQA_GROUPS: 2,
_KEY_OF_ENABLE_ROPE: True, _KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate", _KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
_KEY_OF_HIDDEN_DROPOUT: 0.3, },
_KEY_OF_HIDDEN_DROPOUT_DIMS: (0,), {
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5, _KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,), _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
}, { _KEY_OF_ENABLE_ROPE: True,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding", _KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True, _KEY_OF_USE_BIAS: True,
}, { },
_KEY_OF_RELATIVE_EMBEDDING: False, {
_KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias", _KEY_OF_HIDDEN_DROPOUT: 0.3,
}, { _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
_KEY_OF_ATTENTION_DROPOUT: 0.3, _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
}, { _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
_KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')), },
}] {
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
},
{
_KEY_OF_ATTENTION_DROPOUT: 0.3,
},
{
_KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
class BaseRunner: class BaseRunner:
"""Base runner to define forward and backward tests""" """Base runner to define forward and backward tests"""
layer_type: TransformerLayerType = None layer_type: TransformerLayerType = None
reference_layer: flax.linen.Module = None reference_layer: flax.linen.Module = None
transformations: Dict[str, str] = None transformations: Dict[str, str] = None
@@ -194,24 +209,24 @@ class BaseRunner:
self.attrs = attrs self.attrs = attrs
self._generate_test_rngs() self._generate_test_rngs()
        # Disable fused attention when attention dropout is used, because the dropout implementations differ.
if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'): if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
os.environ['NVTE_FUSED_ATTN'] = "0" os.environ["NVTE_FUSED_ATTN"] = "0"
def _generate_test_rngs(self): def _generate_test_rngs(self):
root_rng = jax.random.PRNGKey(0) root_rng = jax.random.PRNGKey(0)
params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3) params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng} self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
self.apply_rng = {'dropout': apply_dropout_rng} self.apply_rng = {"dropout": apply_dropout_rng}
def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs): def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
layer = layer_cls() layer = layer_cls()
variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs) variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
others, params = flax.core.pop(variables, 'params') others, params = flax.core.pop(variables, "params")
del variables del variables
return layer, params, others return layer, params, others
def _loss_fn(self, diff_xs, no_diff_xs, params, others, model): def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
variables = {'params': params, **others} variables = {"params": params, **others}
output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng) output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
return jnp.mean(output, dtype=jnp.float32).astype(output.dtype) return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
@@ -259,15 +274,18 @@ class BaseRunner:
) )
_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME) _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections( test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others) {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others
)
del tmp_grad, fp8_meta_grad del tmp_grad, fp8_meta_grad
grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False) grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
ref_layer) inputs, ref_masks, ref_params, ref_others, ref_layer
test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others, )
test_layer) test_out, (test_dgrads, test_wgrads) = grad_fn(
inputs, test_masks, test_params, test_others, test_layer
)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol) assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol)
@@ -278,19 +296,20 @@ class BaseRunner:
class EncoderRunner(BaseRunner): class EncoderRunner(BaseRunner):
"""Encoder runner implementations""" """Encoder runner implementations"""
layer_type = TransformerLayerType.ENCODER layer_type = TransformerLayerType.ENCODER
reference_layer = RefEncoderLayer reference_layer = RefEncoderLayer
transformations = { transformations = {
'attention/qkv/scale': 'pre_attention_layer_norm/scale', "attention/qkv/scale": "pre_attention_layer_norm/scale",
'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias', "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
'attention/query/scale': 'pre_attention_layer_norm/scale', "attention/query/scale": "pre_attention_layer_norm/scale",
'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias', "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
'mlp/wi_kernel': 'mlp/wi/kernel', "mlp/wi_kernel": "mlp/wi/kernel",
'mlp/wi_bias': 'mlp/wi/bias', "mlp/wi_bias": "mlp/wi/bias",
'mlp/wo_kernel': 'mlp/wo/kernel', "mlp/wo_kernel": "mlp/wo/kernel",
'mlp/wo_bias': 'mlp/wo/bias', "mlp/wo_bias": "mlp/wo/bias",
'mlp/scale': 'pre_mlp_layer_norm/scale', "mlp/scale": "pre_mlp_layer_norm/scale",
'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias', "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
} }
def generate_inputs(self, data_shape, dtype): def generate_inputs(self, data_shape, dtype):
@@ -307,13 +326,13 @@ class EncoderRunner(BaseRunner):
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask mask = causal_mask
else: else:
mask = padded_mask mask = padded_mask
ref_masks = (1 - mask,) ref_masks = (1 - mask,)
test_masks = (None, mask) # The second arg of Transformer is encoded tokens. test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
return inputs, (ref_masks, test_masks) return inputs, (ref_masks, test_masks)
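A small illustration of the mask convention used by generate_inputs above (assuming seqlen = 4): jnp.triu(..., k=1) marks future positions with 1, i.e. positions to be masked out, while the reference layer expects the inverted convention, hence the 1 - mask passed as ref_masks.

import jax.numpy as jnp

seqlen = 4
causal_mask = jnp.triu(jnp.ones((1, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
print(causal_mask[0, 0])        # 1 above the diagonal: masked-out (future) positions
print((1 - causal_mask)[0, 0])  # inverted convention used by the reference layer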
@@ -322,23 +341,24 @@ class DecoderRunner(BaseRunner):
""" """
Decoder runner implementations Decoder runner implementations
""" """
layer_type = TransformerLayerType.DECODER layer_type = TransformerLayerType.DECODER
reference_layer = RefDecoderLayer reference_layer = RefDecoderLayer
transformations = { transformations = {
'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale', "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias', "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale', "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias', "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale', "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias', "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
'self_attention/query/scale': 'pre_self_attention_layer_norm/scale', "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias', "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
'mlp/wi_kernel': 'mlp/wi/kernel', "mlp/wi_kernel": "mlp/wi/kernel",
'mlp/wi_bias': 'mlp/wi/bias', "mlp/wi_bias": "mlp/wi/bias",
'mlp/wo_kernel': 'mlp/wo/kernel', "mlp/wo_kernel": "mlp/wo/kernel",
'mlp/wo_bias': 'mlp/wo/bias', "mlp/wo_bias": "mlp/wo/bias",
'mlp/scale': 'pre_mlp_layer_norm/scale', "mlp/scale": "pre_mlp_layer_norm/scale",
'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias', "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
} }
def generate_inputs(self, data_shape, dtype): def generate_inputs(self, data_shape, dtype):
@@ -352,12 +372,14 @@ class DecoderRunner(BaseRunner):
data_rng = jax.random.PRNGKey(0) data_rng = jax.random.PRNGKey(0)
data_rng_0, data_rng_1 = jax.random.split(data_rng, 2) data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
inputs = (jax.random.normal(data_rng_0, data_shape, inputs = (
dtype), jax.random.normal(data_rng_1, data_shape, dtype)) jax.random.normal(data_rng_0, data_shape, dtype),
jax.random.normal(data_rng_1, data_shape, dtype),
)
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8) padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1) causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
self_mask = causal_mask self_mask = causal_mask
else: else:
self_mask = padded_mask self_mask = padded_mask
@@ -368,27 +390,28 @@ class DecoderRunner(BaseRunner):
return inputs, (ref_masks, test_masks) return inputs, (ref_masks, test_masks)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', ATTRS) @pytest.mark.parametrize("attrs", ATTRS)
class BaseTester(): class BaseTester:
""" """
Pytest interface to invoke the runner Pytest interface to invoke the runner
""" """
runner = BaseRunner runner = BaseRunner
def test_forward(self, data_shape, dtype, attrs): def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward""" """Test normal datatype forward"""
FP8Helper.finalize() # Ensure FP8 disabled. FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5) self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5)
def test_backward(self, data_shape, dtype, attrs): def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward""" """Test normal datatype backward"""
FP8Helper.finalize() # Ensure FP8 disabled. FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5) self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format): def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test forward with fp8 enabled""" """Test forward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format) FP8Helper.initialize(fp8_format=fp8_format)
@@ -396,7 +419,7 @@ class BaseTester():
FP8Helper.finalize() FP8Helper.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test backward with fp8 enabled""" """Test backward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format) FP8Helper.initialize(fp8_format=fp8_format)
@@ -408,6 +431,7 @@ class TestEncoderLayer(BaseTester):
""" """
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder) Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
""" """
runner = EncoderRunner runner = EncoderRunner
@@ -415,4 +439,5 @@ class TestDecoderLayer(BaseTester):
""" """
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder) Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
""" """
runner = DecoderRunner runner = DecoderRunner
@@ -37,13 +37,13 @@ from transformer_engine.jax.softmax import SoftmaxType
is_fp8_supported, reason = is_fp8_available() is_fp8_supported, reason = is_fp8_available()
DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H) DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16] DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True] ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID] FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope='module') @pytest.fixture(autouse=True, scope="module")
def enable_fused_attn(): def enable_fused_attn():
""" """
Enable fused attn for hopper+ arch. Enable fused attn for hopper+ arch.
@@ -58,19 +58,16 @@ def enable_fused_attn():
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    for key in ref_fd:
        assert key in test_fd, f"{key} not found in test dict {test_fd}"
        assert isinstance(
            test_fd[key], type(ref_fd[key])
        ), f"The data types do not match between the ref and test dicts on {key=}"
        if isinstance(ref_fd[key], Dict):
            compare_dict(ref_fd[key], test_fd[key], rtol, atol)
        else:
            assert_allclose(
                ref_fd[key], test_fd[key], rtol=rtol, atol=atol, err_msg=f"{key=} is not close"
            )
class TestLayer: class TestLayer:
@@ -105,9 +102,10 @@ class TestLayer:
lyr_name = self.get_layer_name() lyr_name = self.get_layer_name()
if 'params' in flax_variables: if "params" in flax_variables:
synced_praxis_variables['params'][lyr_name]['cld'] = \ synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze(
flax.core.unfreeze(flax_variables['params']) flax_variables["params"]
)
return synced_praxis_variables, flax_variables return synced_praxis_variables, flax_variables
@@ -116,23 +114,19 @@ class TestLayer:
lyr_name = self.get_layer_name() lyr_name = self.get_layer_name()
if 'params' in synced_praxis_grads: if "params" in synced_praxis_grads:
synced_praxis_grads['params'] = \ synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"]
synced_praxis_grads['params'][lyr_name]['cld']
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \ synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = synced_praxis_grads[
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld'] FP8Helper.FP8_COLLECTION_NAME
][lyr_name]["cld"]
return synced_praxis_grads, flax.core.unfreeze(flax_wgrads) return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)
def forward_backward_runner(self, def forward_backward_runner(
data_shape, self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08
dtype, ):
praxis_p,
flax_cls,
rtol=1e-05,
atol=1e-08):
init_key = jax.random.PRNGKey(seed=1234) init_key = jax.random.PRNGKey(seed=1234)
test_inputs = self.input_getter(data_shape, dtype) test_inputs = self.input_getter(data_shape, dtype)
@@ -148,28 +142,33 @@ class TestLayer:
if "params_axes" in flax_variables: if "params_axes" in flax_variables:
flax_variables, _ = flax.core.pop(flax_variables, "params_axes") flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
flax_variables, _ = flax.core.pop(flax_variables, flax_variables, _ = flax.core.pop(
FP8Helper.FP8_COLLECTION_NAME + "_axes") flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
)
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables) praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
iter_times = 5 if FP8Helper.is_fp8_enabled() else 1 iter_times = 5 if FP8Helper.is_fp8_enabled() else 1
for _ in range(iter_times): for _ in range(iter_times):
praxis_loss, praxis_wgrads, praxis_dgrad = \ praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs) praxis_layer, praxis_variables, *test_inputs
flax_loss, flax_wgrads, flax_dgrad = \ )
TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs) flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
flax_layer, flax_variables, *test_inputs
)
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
praxis_wgrads.pop('params') praxis_wgrads.pop("params")
praxis_variables = update_collections(praxis_wgrads, praxis_variables) praxis_variables = update_collections(praxis_wgrads, praxis_variables)
flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params') flax_wgrads, _ = flax.core.pop(flax_wgrads, "params")
flax_variables = update_collections(flax_wgrads, flax_variables) flax_variables = update_collections(flax_wgrads, flax_variables)
praxis_loss, praxis_wgrads, praxis_dgrad = \ praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs) praxis_layer, praxis_variables, *test_inputs
flax_loss, flax_wgrads, flax_dgrad = \ )
TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs) flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
flax_layer, flax_variables, *test_inputs
)
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol) assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol) assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)
@@ -179,18 +178,13 @@ class TestLayer:
class LayerNormAttr:
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ATTRS = [
        {LN_TYPE: "layernorm", ZERO_CEN: False},
        {LN_TYPE: "layernorm", ZERO_CEN: True},
        {LN_TYPE: "rmsnorm", ZERO_CEN: False},
    ]
class TestLayerNorm(TestLayer): class TestLayerNorm(TestLayer):
@@ -200,7 +194,7 @@ class TestLayerNorm(TestLayer):
return (jax.random.normal(data_key, shape, dtype),) return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self): def get_layer_name(self):
return 'layer_norm' return "layer_norm"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
layernorm_type = attrs[LayerNormAttr.LN_TYPE] layernorm_type = attrs[LayerNormAttr.LN_TYPE]
@@ -209,63 +203,59 @@ class TestLayerNorm(TestLayer):
bias_init = WeightInit.Constant(0.0) bias_init = WeightInit.Constant(0.0)
transpose_batch_sequence = False transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNorm, praxis_p = pax_fiddle.Config(
name='layer_norm', LayerNorm,
dtype=dtype, name="layer_norm",
layernorm_type=layernorm_type, dtype=dtype,
zero_centered_gamma=zero_centered_gamma, layernorm_type=layernorm_type,
scale_init=scale_init, zero_centered_gamma=zero_centered_gamma,
bias_init=bias_init, scale_init=scale_init,
transpose_batch_sequence=transpose_batch_sequence) bias_init=bias_init,
flax_cls = partial(flax_LayerNorm, transpose_batch_sequence=transpose_batch_sequence,
layernorm_type=layernorm_type, )
zero_centered_gamma=zero_centered_gamma, flax_cls = partial(
scale_init=scale_init, flax_LayerNorm,
bias_init=TransformerEngineBaseLayer.generate_params_init( layernorm_type=layernorm_type,
"ln_bias", bias_init), zero_centered_gamma=zero_centered_gamma,
dtype=dtype, scale_init=scale_init,
transpose_batch_sequence=transpose_batch_sequence) bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init),
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class FusedSoftmaxAttr:
    SCALE_FACTOR = "scale_factor"
    ST_TYPE = "softmax_type"
    ATTRS = [
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
    ]
class TestFusedSoftmax(TestLayer): class TestFusedSoftmax(TestLayer):
def input_getter(self, shape, dtype): def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234) data_key = jax.random.PRNGKey(seed=1234)
return jax.random.normal(data_key, shape, dtype), \ return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8) # Masks
jnp.ones(shape, dtype=jnp.uint8) # Masks
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR] scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE] softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]
praxis_p = pax_fiddle.Config(FusedSoftmax, praxis_p = pax_fiddle.Config(
name='fused_softmax', FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type
scale_factor=scale_factor, )
softmax_type=softmax_type)
flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type) flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)
return praxis_p, flax_cls return praxis_p, flax_cls
@@ -276,34 +266,28 @@ class TestFusedSoftmax(TestLayer):
def sync_wgrads(self, praxis_wgrads, flax_wgrads): def sync_wgrads(self, praxis_wgrads, flax_wgrads):
return praxis_wgrads, flax_wgrads return praxis_wgrads, flax_wgrads
@pytest.mark.parametrize('data_shape', [(32, 1, 128, 128), (32, 1, 512, 128)]) @pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)])
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', FusedSoftmaxAttr.ATTRS) @pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and (
            data_shape[-2] != data_shape[-1]
        ):
            pass  # Skipped: SCALED_UPPER_TRIANG_MASKED does not support non-square shapes
else: else:
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LinearAttr:
    FEATURE = "features"
    USE_BIAS = "use_bias"
    ATTRS = [
        {FEATURE: 512, USE_BIAS: False},
        {FEATURE: 512, USE_BIAS: True},
        {FEATURE: 1024, USE_BIAS: False},
        {FEATURE: 1024, USE_BIAS: True},
    ]
class TestLinear(TestLayer): class TestLinear(TestLayer):
@@ -313,7 +297,7 @@ class TestLinear(TestLayer):
return (jax.random.normal(data_key, shape, dtype),) return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self): def get_layer_name(self):
return 'linear' return "linear"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LinearAttr.FEATURE] out_features = attrs[LinearAttr.FEATURE]
@@ -323,15 +307,17 @@ class TestLinear(TestLayer):
axis = -1 axis = -1
transpose_batch_sequence = False transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(Linear, praxis_p = pax_fiddle.Config(
name='linear', Linear,
dtype=dtype, name="linear",
out_features=out_features, dtype=dtype,
params_init=kernel_init, out_features=out_features,
use_bias=use_bias, params_init=kernel_init,
bias_init=bias_init, use_bias=use_bias,
axis=axis, bias_init=bias_init,
transpose_batch_sequence=transpose_batch_sequence) axis=axis,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial( flax_cls = partial(
DenseGeneral, DenseGeneral,
features=out_features, features=out_features,
@@ -340,29 +326,26 @@ class TestLinear(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis, axis=axis,
dtype=dtype, dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LinearAttr.ATTRS) @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LinearAttr.ATTRS) @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
@@ -371,54 +354,20 @@ class TestLinear(TestLayer):
class LayerNormLinearAttr:
    FEATURE = "features"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ATTRS = [
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False},
    ]
class TestLayerNormLinear(TestLayer): class TestLayerNormLinear(TestLayer):
@@ -428,7 +377,7 @@ class TestLayerNormLinear(TestLayer):
return (jax.random.normal(data_key, shape, dtype),) return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self): def get_layer_name(self):
return 'ln_linear' return "ln_linear"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LayerNormLinearAttr.FEATURE] out_features = attrs[LayerNormLinearAttr.FEATURE]
@@ -441,18 +390,20 @@ class TestLayerNormLinear(TestLayer):
axis = -1 axis = -1
transpose_batch_sequence = False transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNormLinear, praxis_p = pax_fiddle.Config(
name='ln_linear', LayerNormLinear,
dtype=dtype, name="ln_linear",
out_features=out_features, dtype=dtype,
enable_layernorm=enable_layernorm, out_features=out_features,
layernorm_type=layernorm_type, enable_layernorm=enable_layernorm,
zero_centered_gamma=zero_centered_gamma, layernorm_type=layernorm_type,
params_init=kernel_init, zero_centered_gamma=zero_centered_gamma,
use_bias=use_bias, params_init=kernel_init,
bias_init=bias_init, use_bias=use_bias,
axis=axis, bias_init=bias_init,
transpose_batch_sequence=transpose_batch_sequence) axis=axis,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial( flax_cls = partial(
LayerNormDenseGeneral, LayerNormDenseGeneral,
features=out_features, features=out_features,
@@ -464,29 +415,26 @@ class TestLayerNormLinear(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init), bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis, axis=axis,
dtype=dtype, dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
@@ -495,62 +443,70 @@ class TestLayerNormLinear(TestLayer):
class LayerNormMLPAttr: class LayerNormMLPAttr:
INTERMEDIATE_DIM = 'intermediate_dim' INTERMEDIATE_DIM = "intermediate_dim"
USE_BIAS = 'use_bias' USE_BIAS = "use_bias"
ENABLE_LN = 'enable_layernorm' ENABLE_LN = "enable_layernorm"
LN_TYPE = 'layernorm_type' LN_TYPE = "layernorm_type"
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = "zero_centered_gamma"
ACTIVATION = 'activations' ACTIVATION = "activations"
ATTRS = [{ ATTRS = [
INTERMEDIATE_DIM: 2048, {
USE_BIAS: True, INTERMEDIATE_DIM: 2048,
ENABLE_LN: True, USE_BIAS: True,
LN_TYPE: 'layernorm', ENABLE_LN: True,
ZERO_CEN: False, LN_TYPE: "layernorm",
ACTIVATION: ('relu',) ZERO_CEN: False,
}, { ACTIVATION: ("relu",),
INTERMEDIATE_DIM: 2048, },
USE_BIAS: True, {
ENABLE_LN: True, INTERMEDIATE_DIM: 2048,
LN_TYPE: 'layernorm', USE_BIAS: True,
ZERO_CEN: True, ENABLE_LN: True,
ACTIVATION: ('relu',) LN_TYPE: "layernorm",
}, { ZERO_CEN: True,
INTERMEDIATE_DIM: 2048, ACTIVATION: ("relu",),
USE_BIAS: True, },
ENABLE_LN: True, {
LN_TYPE: 'rmsnorm', INTERMEDIATE_DIM: 2048,
ZERO_CEN: False, USE_BIAS: True,
ACTIVATION: ('relu',) ENABLE_LN: True,
}, { LN_TYPE: "rmsnorm",
INTERMEDIATE_DIM: 2048, ZERO_CEN: False,
USE_BIAS: True, ACTIVATION: ("relu",),
ENABLE_LN: True, },
LN_TYPE: 'rmsnorm', {
ZERO_CEN: False, INTERMEDIATE_DIM: 2048,
ACTIVATION: ('gelu', 'linear') USE_BIAS: True,
}, { ENABLE_LN: True,
INTERMEDIATE_DIM: 2048, LN_TYPE: "rmsnorm",
USE_BIAS: False, ZERO_CEN: False,
ENABLE_LN: True, ACTIVATION: ("gelu", "linear"),
LN_TYPE: 'rmsnorm', },
ZERO_CEN: False, {
ACTIVATION: ('gelu', 'linear') INTERMEDIATE_DIM: 2048,
}, { USE_BIAS: False,
INTERMEDIATE_DIM: 2048, ENABLE_LN: True,
USE_BIAS: True, LN_TYPE: "rmsnorm",
ENABLE_LN: True, ZERO_CEN: False,
LN_TYPE: 'rmsnorm', ACTIVATION: ("gelu", "linear"),
ZERO_CEN: False, },
ACTIVATION: ('silu', 'linear') {
}, { INTERMEDIATE_DIM: 2048,
INTERMEDIATE_DIM: 2048, USE_BIAS: True,
USE_BIAS: False, ENABLE_LN: True,
ENABLE_LN: True, LN_TYPE: "rmsnorm",
LN_TYPE: 'rmsnorm', ZERO_CEN: False,
ZERO_CEN: False, ACTIVATION: ("silu", "linear"),
ACTIVATION: ('silu', 'linear') },
}] {
INTERMEDIATE_DIM: 2048,
USE_BIAS: False,
ENABLE_LN: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("silu", "linear"),
},
]
class TestLayerNormMLP(TestLayer): class TestLayerNormMLP(TestLayer):
@@ -560,7 +516,7 @@ class TestLayerNormMLP(TestLayer):
return (jax.random.normal(data_key, shape, dtype),) return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self): def get_layer_name(self):
return 'ln_mlp' return "ln_mlp"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM] intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
@@ -574,20 +530,22 @@ class TestLayerNormMLP(TestLayer):
axis = -1 axis = -1
transpose_batch_sequence = False transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNormMLP, praxis_p = pax_fiddle.Config(
name='ln_mlp', LayerNormMLP,
dtype=dtype, name="ln_mlp",
intermediate_dim=intermediate_dim, dtype=dtype,
enable_layernorm=enable_layernorm, intermediate_dim=intermediate_dim,
layernorm_type=layernorm_type, enable_layernorm=enable_layernorm,
zero_centered_gamma=zero_centered_gamma, layernorm_type=layernorm_type,
params_init=kernel_init, zero_centered_gamma=zero_centered_gamma,
use_bias=use_bias, params_init=kernel_init,
bias_init=bias_init, use_bias=use_bias,
activations=activations, bias_init=bias_init,
intermediate_dropout_rate=0.0, activations=activations,
axis=axis, intermediate_dropout_rate=0.0,
transpose_batch_sequence=transpose_batch_sequence) axis=axis,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial( flax_cls = partial(
flax_LayerNormMLP, flax_LayerNormMLP,
intermediate_dim=intermediate_dim, intermediate_dim=intermediate_dim,
@@ -601,29 +559,26 @@ class TestLayerNormMLP(TestLayer):
intermediate_dropout_rate=0.0, intermediate_dropout_rate=0.0,
axis=axis, axis=axis,
dtype=dtype, dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence) transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS) @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
@@ -634,35 +589,40 @@ class TestLayerNormMLP(TestLayer):
class TestRelativePositionBias(TestLayer): class TestRelativePositionBias(TestLayer):
def get_layer_name(self): def get_layer_name(self):
return 'relative_position_bias' return "relative_position_bias"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
num_buckets = 32 num_buckets = 32
max_distance = 128 max_distance = 128
num_attention_heads = 64 num_attention_heads = 64
rb_stddev = (num_attention_heads * num_buckets)**-0.5 rb_stddev = (num_attention_heads * num_buckets) ** -0.5
embedding_init = WeightInit.Gaussian(rb_stddev) embedding_init = WeightInit.Gaussian(rb_stddev)
praxis_p = pax_fiddle.Config(RelativePositionBiases, praxis_p = pax_fiddle.Config(
name='relative_position_bias', RelativePositionBiases,
dtype=dtype, name="relative_position_bias",
num_buckets=num_buckets, dtype=dtype,
max_distance=max_distance, num_buckets=num_buckets,
num_attention_heads=num_attention_heads, max_distance=max_distance,
embedding_init=embedding_init) num_attention_heads=num_attention_heads,
flax_cls = partial(flax_RelativePositionBiases, embedding_init=embedding_init,
num_buckets=num_buckets, )
max_distance=max_distance, flax_cls = partial(
num_attention_heads=num_attention_heads, flax_RelativePositionBiases,
embedding_init=TransformerEngineBaseLayer.generate_params_init( num_buckets=num_buckets,
"rel_embedding", embedding_init), max_distance=max_distance,
dtype=dtype) num_attention_heads=num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init
),
dtype=dtype,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', [{}]) @pytest.mark.parametrize("attrs", [{}])
def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
@@ -678,53 +638,64 @@ class TestRelativePositionBias(TestLayer):
if "params_axes" in flax_variables: if "params_axes" in flax_variables:
flax_variables, _ = flax.core.pop(flax_variables, "params_axes") flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled(): if FP8Helper.is_fp8_enabled():
flax_variables, _ = flax.core.pop(flax_variables, flax_variables, _ = flax.core.pop(
FP8Helper.FP8_COLLECTION_NAME + "_axes") flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
)
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables) praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
praxis_loss= \ praxis_loss = TestLayer.loss(
TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False) praxis_variables, *test_input, module=praxis_layer, mean_out=False
flax_loss = \ )
TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False) flax_loss = TestLayer.loss(
flax_variables, *test_input, module=flax_layer, mean_out=False
)
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol) assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr: class DotProductAttnAttr:
ATTN_MASK_TYPE = 'attn_mask_type' ATTN_MASK_TYPE = "attn_mask_type"
NUM_GQA_GROUPS = 'num_gqa_groups' NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = 'transpose_batch_sequence' TRANSPOSE_BS = "transpose_batch_sequence"
SCALE_FACTOR = 'scale_factor' SCALE_FACTOR = "scale_factor"
ATTRS = [{ ATTRS = [
ATTN_MASK_TYPE: 'padding', {
TRANSPOSE_BS: True, ATTN_MASK_TYPE: "padding",
SCALE_FACTOR: 0.125, TRANSPOSE_BS: True,
}, { SCALE_FACTOR: 0.125,
ATTN_MASK_TYPE: 'padding_causal', },
TRANSPOSE_BS: True, {
SCALE_FACTOR: 0.125, ATTN_MASK_TYPE: "padding_causal",
}, { TRANSPOSE_BS: True,
ATTN_MASK_TYPE: 'causal', SCALE_FACTOR: 0.125,
TRANSPOSE_BS: True, },
SCALE_FACTOR: 0.125, {
}, { ATTN_MASK_TYPE: "causal",
ATTN_MASK_TYPE: 'padding', TRANSPOSE_BS: True,
TRANSPOSE_BS: False, SCALE_FACTOR: 0.125,
SCALE_FACTOR: 0.125, },
}, { {
ATTN_MASK_TYPE: 'padding_causal', ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: False, TRANSPOSE_BS: False,
SCALE_FACTOR: 2., SCALE_FACTOR: 0.125,
}, { },
ATTN_MASK_TYPE: 'causal', {
TRANSPOSE_BS: False, ATTN_MASK_TYPE: "padding_causal",
SCALE_FACTOR: 1., TRANSPOSE_BS: False,
}, { SCALE_FACTOR: 2.0,
ATTN_MASK_TYPE: 'no_mask', },
TRANSPOSE_BS: False, {
SCALE_FACTOR: 1., ATTN_MASK_TYPE: "causal",
}] TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
},
{
ATTN_MASK_TYPE: "no_mask",
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
},
]
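Note on the SCALE_FACTOR values above: 0.125 is simply 1/sqrt(head_dim) for the head_dim of 64 used by this test class. A one-line sanity check (plain Python, illustrative only):
import math
head_dim = 64  # matches generate_praxis_p_and_flax_cls below
assert 1.0 / math.sqrt(head_dim) == 0.125  # the SCALE_FACTOR used in several ATTRS entries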
class TestDotProductAttn(TestLayer): class TestDotProductAttn(TestLayer):
...@@ -737,11 +708,12 @@ class TestDotProductAttn(TestLayer):
shape = (shape[1], shape[0]) + shape[2:] shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [ return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]),
mask,
] ]
def get_layer_name(self): def get_layer_name(self):
return 'dot_product_attn' return "dot_product_attn"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64 head_dim = 64
...@@ -750,27 +722,31 @@ class TestDotProductAttn(TestLayer):
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE] attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
praxis_p = pax_fiddle.Config(DotProductAttention, praxis_p = pax_fiddle.Config(
name='mha', DotProductAttention,
dtype=dtype, name="mha",
head_dim=head_dim, dtype=dtype,
num_attention_heads=num_attention_heads, head_dim=head_dim,
num_gqa_groups=num_gqa_groups, num_attention_heads=num_attention_heads,
attn_mask_type=attn_mask_type, num_gqa_groups=num_gqa_groups,
transpose_batch_sequence=transpose_batch_sequence) attn_mask_type=attn_mask_type,
flax_cls = partial(flax_DotProductAttention, transpose_batch_sequence=transpose_batch_sequence,
dtype=dtype, )
head_dim=head_dim, flax_cls = partial(
num_attention_heads=num_attention_heads, flax_DotProductAttention,
num_gqa_groups=num_gqa_groups, dtype=dtype,
attn_mask_type=attn_mask_type, head_dim=head_dim,
transpose_batch_sequence=transpose_batch_sequence) num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)]) @pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)])
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS) @pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
...@@ -778,113 +754,125 @@ class TestDotProductAttn(TestLayer):
class MultiHeadAttnAttr: class MultiHeadAttnAttr:
USE_BIAS = 'use_bias' USE_BIAS = "use_bias"
LN_TYPE = 'layernorm_type' LN_TYPE = "layernorm_type"
ATTN_MASK_TYPE = 'attn_mask_type' ATTN_MASK_TYPE = "attn_mask_type"
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = "zero_centered_gamma"
NUM_ATTN_HEADS = 'num_attention_heads' NUM_ATTN_HEADS = "num_attention_heads"
NUM_GQA_GROUPS = 'num_gqa_groups' NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = 'transpose_batch_sequence' TRANSPOSE_BS = "transpose_batch_sequence"
ENABLE_ROPE = 'enable_rotary_pos_emb' ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = 'low_rank_adaptation_scope' LORA_SCOPE = "low_rank_adaptation_scope"
ATTRS = [{ ATTRS = [
USE_BIAS: True, {
LN_TYPE: 'layernorm', USE_BIAS: True,
ZERO_CEN: False, LN_TYPE: "layernorm",
ENABLE_ROPE: False, ZERO_CEN: False,
ROPE_GROUP_METHOD: 'consecutive', ENABLE_ROPE: False,
ATTN_MASK_TYPE: 'padding', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True, ATTN_MASK_TYPE: "padding",
}, { TRANSPOSE_BS: True,
USE_BIAS: True, },
LN_TYPE: 'layernorm', {
ZERO_CEN: True, USE_BIAS: True,
ENABLE_ROPE: False, LN_TYPE: "layernorm",
ROPE_GROUP_METHOD: 'consecutive', ZERO_CEN: True,
ATTN_MASK_TYPE: 'padding', ENABLE_ROPE: False,
TRANSPOSE_BS: False, ROPE_GROUP_METHOD: "consecutive",
}, { ATTN_MASK_TYPE: "padding",
USE_BIAS: True, TRANSPOSE_BS: False,
LN_TYPE: 'rmsnorm', },
ZERO_CEN: False, {
ENABLE_ROPE: False, USE_BIAS: True,
ROPE_GROUP_METHOD: 'consecutive', LN_TYPE: "rmsnorm",
ATTN_MASK_TYPE: 'padding', ZERO_CEN: False,
TRANSPOSE_BS: True, ENABLE_ROPE: False,
}, { ROPE_GROUP_METHOD: "consecutive",
USE_BIAS: True, ATTN_MASK_TYPE: "padding",
LN_TYPE: 'layernorm', TRANSPOSE_BS: True,
ZERO_CEN: False, },
ENABLE_ROPE: False, {
ROPE_GROUP_METHOD: 'consecutive', USE_BIAS: True,
ATTN_MASK_TYPE: 'causal', LN_TYPE: "layernorm",
TRANSPOSE_BS: False, ZERO_CEN: False,
}, { ENABLE_ROPE: False,
USE_BIAS: True, ROPE_GROUP_METHOD: "consecutive",
LN_TYPE: 'layernorm', ATTN_MASK_TYPE: "causal",
ZERO_CEN: True, TRANSPOSE_BS: False,
ENABLE_ROPE: False, },
ROPE_GROUP_METHOD: 'consecutive', {
ATTN_MASK_TYPE: 'causal', USE_BIAS: True,
TRANSPOSE_BS: True, LN_TYPE: "layernorm",
}, { ZERO_CEN: True,
USE_BIAS: True, ENABLE_ROPE: False,
LN_TYPE: 'rmsnorm', ROPE_GROUP_METHOD: "consecutive",
ZERO_CEN: False, ATTN_MASK_TYPE: "causal",
ENABLE_ROPE: False, TRANSPOSE_BS: True,
ROPE_GROUP_METHOD: 'consecutive', },
ATTN_MASK_TYPE: 'causal', {
TRANSPOSE_BS: False, USE_BIAS: True,
}, { LN_TYPE: "rmsnorm",
USE_BIAS: True, ZERO_CEN: False,
LN_TYPE: 'rmsnorm', ENABLE_ROPE: False,
ZERO_CEN: False, ROPE_GROUP_METHOD: "consecutive",
ENABLE_ROPE: False, ATTN_MASK_TYPE: "causal",
ROPE_GROUP_METHOD: 'consecutive', TRANSPOSE_BS: False,
NUM_ATTN_HEADS: 8, },
NUM_GQA_GROUPS: 4, {
ATTN_MASK_TYPE: 'causal', USE_BIAS: True,
TRANSPOSE_BS: True, LN_TYPE: "rmsnorm",
}, { ZERO_CEN: False,
USE_BIAS: True, ENABLE_ROPE: False,
LN_TYPE: 'rmsnorm', ROPE_GROUP_METHOD: "consecutive",
ZERO_CEN: False, NUM_ATTN_HEADS: 8,
ENABLE_ROPE: True, NUM_GQA_GROUPS: 4,
ROPE_GROUP_METHOD: 'consecutive', ATTN_MASK_TYPE: "causal",
NUM_ATTN_HEADS: 8, TRANSPOSE_BS: True,
NUM_GQA_GROUPS: 4, },
ATTN_MASK_TYPE: 'causal', {
TRANSPOSE_BS: False, USE_BIAS: True,
}, { LN_TYPE: "rmsnorm",
USE_BIAS: True, ZERO_CEN: False,
LN_TYPE: 'rmsnorm', ENABLE_ROPE: True,
ZERO_CEN: False, ROPE_GROUP_METHOD: "consecutive",
ENABLE_ROPE: True, NUM_ATTN_HEADS: 8,
ROPE_GROUP_METHOD: 'alternate', NUM_GQA_GROUPS: 4,
NUM_ATTN_HEADS: 8, ATTN_MASK_TYPE: "causal",
NUM_GQA_GROUPS: 4, TRANSPOSE_BS: False,
ATTN_MASK_TYPE: 'causal', },
TRANSPOSE_BS: True, {
}, { USE_BIAS: True,
USE_BIAS: True, LN_TYPE: "rmsnorm",
LN_TYPE: 'layernorm', ZERO_CEN: False,
ZERO_CEN: False, ENABLE_ROPE: True,
ENABLE_ROPE: False, ROPE_GROUP_METHOD: "alternate",
ROPE_GROUP_METHOD: 'consecutive', NUM_ATTN_HEADS: 8,
ATTN_MASK_TYPE: 'padding', NUM_GQA_GROUPS: 4,
LORA_SCOPE: 'all', ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False, TRANSPOSE_BS: True,
}, { },
USE_BIAS: True, {
LN_TYPE: 'layernorm', USE_BIAS: True,
ZERO_CEN: False, LN_TYPE: "layernorm",
ENABLE_ROPE: False, ZERO_CEN: False,
ROPE_GROUP_METHOD: 'consecutive', ENABLE_ROPE: False,
ATTN_MASK_TYPE: 'causal', ROPE_GROUP_METHOD: "consecutive",
LORA_SCOPE: 'all', ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: True, LORA_SCOPE: "all",
}] TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "causal",
LORA_SCOPE: "all",
TRANSPOSE_BS: True,
},
]
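The GQA entries above pair NUM_ATTN_HEADS: 8 with NUM_GQA_GROUPS: 4, i.e. two query heads share each KV head. A minimal illustration (plain Python, mirroring the grouping used by DotProductAttention further down; not part of the diff):
num_attention_heads = 8
num_gqa_groups = 4
group_size = num_attention_heads // num_gqa_groups
assert group_size == 2  # two query heads attend against each shared KV head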
class TestMultiHeadAttn(TestLayer): class TestMultiHeadAttn(TestLayer):
...@@ -899,13 +887,16 @@ class TestMultiHeadAttn(TestLayer):
return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask] return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
def get_layer_name(self): def get_layer_name(self):
return 'multi_head_attn' return "multi_head_attn"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64 head_dim = 64
num_attention_heads = 16 num_attention_heads = 16
num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \ num_gqa_groups = (
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS]
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs
else None
)
layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE] layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN] zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0) kernel_init = WeightInit.Gaussian(1.0)
...@@ -916,35 +907,37 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE] attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE] enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD] rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none') low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none")
fuse_qkv_params = True fuse_qkv_params = True
transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
scale_attn_logits = False scale_attn_logits = False
scaled_query_init = True scaled_query_init = True
float32_logits = False float32_logits = False
praxis_p = pax_fiddle.Config(MultiHeadAttention, praxis_p = pax_fiddle.Config(
name='mha', MultiHeadAttention,
dtype=dtype, name="mha",
head_dim=head_dim, dtype=dtype,
num_attention_heads=num_attention_heads, head_dim=head_dim,
num_gqa_groups=num_gqa_groups, num_attention_heads=num_attention_heads,
layernorm_type=layernorm_type, num_gqa_groups=num_gqa_groups,
zero_centered_gamma=zero_centered_gamma, layernorm_type=layernorm_type,
params_init=kernel_init, zero_centered_gamma=zero_centered_gamma,
use_bias=use_bias, params_init=kernel_init,
bias_init=bias_init, use_bias=use_bias,
return_layernorm_output=return_layernorm_output, bias_init=bias_init,
input_layernorm=input_layernorm, return_layernorm_output=return_layernorm_output,
attn_mask_type=attn_mask_type, input_layernorm=input_layernorm,
enable_rotary_pos_emb=enable_rotary_pos_emb, attn_mask_type=attn_mask_type,
rotary_pos_emb_group_method=rotary_pos_emb_group_method, enable_rotary_pos_emb=enable_rotary_pos_emb,
low_rank_adaptation_scope=low_rank_adaptation_scope, rotary_pos_emb_group_method=rotary_pos_emb_group_method,
fuse_qkv_params=fuse_qkv_params, low_rank_adaptation_scope=low_rank_adaptation_scope,
transpose_batch_sequence=transpose_batch_sequence, fuse_qkv_params=fuse_qkv_params,
scale_attn_logits=scale_attn_logits, transpose_batch_sequence=transpose_batch_sequence,
scaled_query_init=scaled_query_init, scale_attn_logits=scale_attn_logits,
float32_logits=float32_logits) scaled_query_init=scaled_query_init,
float32_logits=float32_logits,
)
flax_cls = partial( flax_cls = partial(
flax_MultiHeadAttention, flax_MultiHeadAttention,
dtype=dtype, dtype=dtype,
...@@ -966,30 +959,27 @@ class TestMultiHeadAttn(TestLayer):
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits, scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init, scaled_query_init=scaled_query_init,
float32_logits=float32_logits) float32_logits=float32_logits,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS) @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS) @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
self.attrs = attrs self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
...@@ -998,252 +988,279 @@ class TestMultiHeadAttn(TestLayer):
class TransformerLayerAttr: class TransformerLayerAttr:
USE_BIAS = 'use_bias' USE_BIAS = "use_bias"
LN_TYPE = 'layernorm_type' LN_TYPE = "layernorm_type"
ACTIVATION = 'activations' ACTIVATION = "activations"
LYR_TYPE = 'layer_type' LYR_TYPE = "layer_type"
ZERO_CEN = 'zero_centered_gamma' ZERO_CEN = "zero_centered_gamma"
TRANSPOSE_BS = 'transpose_batch_sequence' TRANSPOSE_BS = "transpose_batch_sequence"
ENABLE_ROPE = 'enable_rotary_pos_emb' ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method' ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = 'low_rank_adaptation_scope' LORA_SCOPE = "low_rank_adaptation_scope"
ATTRS = [{ ATTRS = [
USE_BIAS: True, {
LN_TYPE: 'layernorm', USE_BIAS: True,
ZERO_CEN: False, LN_TYPE: "layernorm",
ACTIVATION: ('relu',), ZERO_CEN: False,
LYR_TYPE: TransformerLayerType.ENCODER, ACTIVATION: ("relu",),
ENABLE_ROPE: False, LYR_TYPE: TransformerLayerType.ENCODER,
ROPE_GROUP_METHOD: 'consecutive', ENABLE_ROPE: False,
TRANSPOSE_BS: True ROPE_GROUP_METHOD: "consecutive",
}, { TRANSPOSE_BS: True,
USE_BIAS: True, },
LN_TYPE: 'layernorm', {
ZERO_CEN: False, USE_BIAS: True,
ACTIVATION: ('relu',), LN_TYPE: "layernorm",
LYR_TYPE: TransformerLayerType.ENCODER, ZERO_CEN: False,
ENABLE_ROPE: False, ACTIVATION: ("relu",),
ROPE_GROUP_METHOD: 'consecutive', LYR_TYPE: TransformerLayerType.ENCODER,
TRANSPOSE_BS: False ENABLE_ROPE: False,
}, { ROPE_GROUP_METHOD: "consecutive",
USE_BIAS: True, TRANSPOSE_BS: False,
LN_TYPE: 'layernorm', },
ZERO_CEN: True, {
ACTIVATION: ('relu',), USE_BIAS: True,
LYR_TYPE: TransformerLayerType.ENCODER, LN_TYPE: "layernorm",
ENABLE_ROPE: False, ZERO_CEN: True,
ROPE_GROUP_METHOD: 'consecutive', ACTIVATION: ("relu",),
TRANSPOSE_BS: True LYR_TYPE: TransformerLayerType.ENCODER,
}, { ENABLE_ROPE: False,
USE_BIAS: True, ROPE_GROUP_METHOD: "consecutive",
LN_TYPE: 'layernorm', TRANSPOSE_BS: True,
ZERO_CEN: True, },
ACTIVATION: ('relu',), {
LYR_TYPE: TransformerLayerType.ENCODER, USE_BIAS: True,
ENABLE_ROPE: False, LN_TYPE: "layernorm",
ROPE_GROUP_METHOD: 'consecutive', ZERO_CEN: True,
TRANSPOSE_BS: False ACTIVATION: ("relu",),
}, { LYR_TYPE: TransformerLayerType.ENCODER,
USE_BIAS: True, ENABLE_ROPE: False,
LN_TYPE: 'rmsnorm', ROPE_GROUP_METHOD: "consecutive",
ZERO_CEN: False, TRANSPOSE_BS: False,
ACTIVATION: ('relu',), },
LYR_TYPE: TransformerLayerType.ENCODER, {
ENABLE_ROPE: False, USE_BIAS: True,
ROPE_GROUP_METHOD: 'consecutive', LN_TYPE: "rmsnorm",
TRANSPOSE_BS: True ZERO_CEN: False,
}, { ACTIVATION: ("relu",),
USE_BIAS: True, LYR_TYPE: TransformerLayerType.ENCODER,
LN_TYPE: 'rmsnorm', ENABLE_ROPE: False,
ZERO_CEN: False, ROPE_GROUP_METHOD: "consecutive",
ACTIVATION: ('relu',), TRANSPOSE_BS: True,
LYR_TYPE: TransformerLayerType.ENCODER, },
ENABLE_ROPE: False, {
ROPE_GROUP_METHOD: 'consecutive', USE_BIAS: True,
TRANSPOSE_BS: False LN_TYPE: "rmsnorm",
}, { ZERO_CEN: False,
USE_BIAS: True, ACTIVATION: ("relu",),
LN_TYPE: 'layernorm', LYR_TYPE: TransformerLayerType.ENCODER,
ZERO_CEN: True, ENABLE_ROPE: False,
ACTIVATION: ('relu',), ROPE_GROUP_METHOD: "consecutive",
LYR_TYPE: TransformerLayerType.DECODER, TRANSPOSE_BS: False,
ENABLE_ROPE: False, },
ROPE_GROUP_METHOD: 'consecutive', {
TRANSPOSE_BS: True USE_BIAS: True,
}, { LN_TYPE: "layernorm",
USE_BIAS: True, ZERO_CEN: True,
LN_TYPE: 'layernorm', ACTIVATION: ("relu",),
ZERO_CEN: True, LYR_TYPE: TransformerLayerType.DECODER,
ACTIVATION: ('relu',), ENABLE_ROPE: False,
LYR_TYPE: TransformerLayerType.DECODER, ROPE_GROUP_METHOD: "consecutive",
ENABLE_ROPE: False, TRANSPOSE_BS: True,
ROPE_GROUP_METHOD: 'consecutive', },
TRANSPOSE_BS: False {
}, { USE_BIAS: True,
USE_BIAS: True, LN_TYPE: "layernorm",
LN_TYPE: 'layernorm', ZERO_CEN: True,
ZERO_CEN: False, ACTIVATION: ("relu",),
ACTIVATION: ('relu',), LYR_TYPE: TransformerLayerType.DECODER,
LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
ENABLE_ROPE: False, ROPE_GROUP_METHOD: "consecutive",
ROPE_GROUP_METHOD: 'consecutive', TRANSPOSE_BS: False,
TRANSPOSE_BS: True },
}, { {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'layernorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('relu',), ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False TRANSPOSE_BS: True,
}, { },
USE_BIAS: True, {
LN_TYPE: 'rmsnorm', USE_BIAS: True,
ZERO_CEN: False, LN_TYPE: "layernorm",
ACTIVATION: ('relu',), ZERO_CEN: False,
LYR_TYPE: TransformerLayerType.DECODER, ACTIVATION: ("relu",),
ENABLE_ROPE: False, LYR_TYPE: TransformerLayerType.DECODER,
ROPE_GROUP_METHOD: 'consecutive', ENABLE_ROPE: False,
TRANSPOSE_BS: True ROPE_GROUP_METHOD: "consecutive",
}, { TRANSPOSE_BS: False,
USE_BIAS: True, },
LN_TYPE: 'rmsnorm', {
ZERO_CEN: False, USE_BIAS: True,
ACTIVATION: ('relu',), LN_TYPE: "rmsnorm",
LYR_TYPE: TransformerLayerType.DECODER, ZERO_CEN: False,
ENABLE_ROPE: False, ACTIVATION: ("relu",),
ROPE_GROUP_METHOD: 'consecutive', LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: False ENABLE_ROPE: False,
}, { ROPE_GROUP_METHOD: "consecutive",
USE_BIAS: True, TRANSPOSE_BS: True,
LN_TYPE: 'layernorm', },
ZERO_CEN: False, {
ACTIVATION: ('gelu', 'linear'), USE_BIAS: True,
LYR_TYPE: TransformerLayerType.ENCODER, LN_TYPE: "rmsnorm",
ENABLE_ROPE: False, ZERO_CEN: False,
ROPE_GROUP_METHOD: 'consecutive', ACTIVATION: ("relu",),
TRANSPOSE_BS: True LYR_TYPE: TransformerLayerType.DECODER,
}, { ENABLE_ROPE: False,
USE_BIAS: True, ROPE_GROUP_METHOD: "consecutive",
LN_TYPE: 'layernorm', TRANSPOSE_BS: False,
ZERO_CEN: False, },
ACTIVATION: ('gelu', 'linear'), {
LYR_TYPE: TransformerLayerType.ENCODER, USE_BIAS: True,
ENABLE_ROPE: False, LN_TYPE: "layernorm",
ROPE_GROUP_METHOD: 'consecutive', ZERO_CEN: False,
TRANSPOSE_BS: False ACTIVATION: ("gelu", "linear"),
}, { LYR_TYPE: TransformerLayerType.ENCODER,
USE_BIAS: True, ENABLE_ROPE: False,
LN_TYPE: 'rmsnorm', ROPE_GROUP_METHOD: "consecutive",
ZERO_CEN: False, TRANSPOSE_BS: True,
ACTIVATION: ('gelu', 'linear'), },
LYR_TYPE: TransformerLayerType.ENCODER, {
ENABLE_ROPE: False, USE_BIAS: True,
ROPE_GROUP_METHOD: 'consecutive', LN_TYPE: "layernorm",
TRANSPOSE_BS: True ZERO_CEN: False,
}, { ACTIVATION: ("gelu", "linear"),
USE_BIAS: True, LYR_TYPE: TransformerLayerType.ENCODER,
LN_TYPE: 'rmsnorm', ENABLE_ROPE: False,
ZERO_CEN: False, ROPE_GROUP_METHOD: "consecutive",
ACTIVATION: ('gelu', 'linear'), TRANSPOSE_BS: False,
LYR_TYPE: TransformerLayerType.ENCODER, },
ENABLE_ROPE: False, {
ROPE_GROUP_METHOD: 'consecutive', USE_BIAS: True,
TRANSPOSE_BS: False LN_TYPE: "rmsnorm",
}, { ZERO_CEN: False,
USE_BIAS: True, ACTIVATION: ("gelu", "linear"),
LN_TYPE: 'layernorm', LYR_TYPE: TransformerLayerType.ENCODER,
ZERO_CEN: False, ENABLE_ROPE: False,
ACTIVATION: ('gelu',), ROPE_GROUP_METHOD: "consecutive",
LYR_TYPE: TransformerLayerType.ENCODER, TRANSPOSE_BS: True,
ENABLE_ROPE: False, },
ROPE_GROUP_METHOD: 'consecutive', {
TRANSPOSE_BS: False, USE_BIAS: True,
LORA_SCOPE: 'all' LN_TYPE: "rmsnorm",
}, { ZERO_CEN: False,
USE_BIAS: True, ACTIVATION: ("gelu", "linear"),
LN_TYPE: 'layernorm', LYR_TYPE: TransformerLayerType.ENCODER,
ZERO_CEN: False, ENABLE_ROPE: False,
ACTIVATION: ('gelu', 'linear'), ROPE_GROUP_METHOD: "consecutive",
LYR_TYPE: TransformerLayerType.DECODER, TRANSPOSE_BS: False,
ENABLE_ROPE: False, },
ROPE_GROUP_METHOD: 'consecutive', {
TRANSPOSE_BS: True USE_BIAS: True,
}, { LN_TYPE: "layernorm",
USE_BIAS: True, ZERO_CEN: False,
LN_TYPE: 'layernorm', ACTIVATION: ("gelu",),
ZERO_CEN: False, LYR_TYPE: TransformerLayerType.ENCODER,
ACTIVATION: ('gelu', 'linear'), ENABLE_ROPE: False,
LYR_TYPE: TransformerLayerType.DECODER, ROPE_GROUP_METHOD: "consecutive",
ENABLE_ROPE: False, TRANSPOSE_BS: False,
ROPE_GROUP_METHOD: 'consecutive', LORA_SCOPE: "all",
TRANSPOSE_BS: False },
}, { {
USE_BIAS: True, USE_BIAS: True,
LN_TYPE: 'rmsnorm', LN_TYPE: "layernorm",
ZERO_CEN: False, ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'), ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER, LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False, ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive', ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True TRANSPOSE_BS: True,
}, { },
USE_BIAS: True, {
LN_TYPE: 'rmsnorm', USE_BIAS: True,
ZERO_CEN: False, LN_TYPE: "layernorm",
ACTIVATION: ('gelu', 'linear'), ZERO_CEN: False,
LYR_TYPE: TransformerLayerType.DECODER, ACTIVATION: ("gelu", "linear"),
ENABLE_ROPE: False, LYR_TYPE: TransformerLayerType.DECODER,
ROPE_GROUP_METHOD: 'consecutive', ENABLE_ROPE: False,
TRANSPOSE_BS: False ROPE_GROUP_METHOD: "consecutive",
}, { TRANSPOSE_BS: False,
USE_BIAS: True, },
LN_TYPE: 'layernorm', {
ZERO_CEN: True, USE_BIAS: True,
ACTIVATION: ('gelu',), LN_TYPE: "rmsnorm",
LYR_TYPE: TransformerLayerType.ENCODER, ZERO_CEN: False,
ENABLE_ROPE: True, ACTIVATION: ("gelu", "linear"),
ROPE_GROUP_METHOD: 'alternate', LYR_TYPE: TransformerLayerType.DECODER,
TRANSPOSE_BS: False ENABLE_ROPE: False,
}, { ROPE_GROUP_METHOD: "consecutive",
USE_BIAS: True, TRANSPOSE_BS: True,
LN_TYPE: 'layernorm', },
ZERO_CEN: True, {
ACTIVATION: ('gelu',), USE_BIAS: True,
LYR_TYPE: TransformerLayerType.DECODER, LN_TYPE: "rmsnorm",
ENABLE_ROPE: True, ZERO_CEN: False,
ROPE_GROUP_METHOD: 'alternate', ACTIVATION: ("gelu", "linear"),
TRANSPOSE_BS: False LYR_TYPE: TransformerLayerType.DECODER,
}, { ENABLE_ROPE: False,
USE_BIAS: True, ROPE_GROUP_METHOD: "consecutive",
LN_TYPE: 'layernorm', TRANSPOSE_BS: False,
ZERO_CEN: True, },
ACTIVATION: ('gelu',), {
LYR_TYPE: TransformerLayerType.ENCODER, USE_BIAS: True,
ENABLE_ROPE: True, LN_TYPE: "layernorm",
ROPE_GROUP_METHOD: 'consecutive', ZERO_CEN: True,
TRANSPOSE_BS: False ACTIVATION: ("gelu",),
}, { LYR_TYPE: TransformerLayerType.ENCODER,
USE_BIAS: True, ENABLE_ROPE: True,
LN_TYPE: 'layernorm', ROPE_GROUP_METHOD: "alternate",
ZERO_CEN: True, TRANSPOSE_BS: False,
ACTIVATION: ('gelu',), },
LYR_TYPE: TransformerLayerType.DECODER, {
ENABLE_ROPE: True, USE_BIAS: True,
ROPE_GROUP_METHOD: 'consecutive', LN_TYPE: "layernorm",
TRANSPOSE_BS: False ZERO_CEN: True,
}, { ACTIVATION: ("gelu",),
USE_BIAS: True, LYR_TYPE: TransformerLayerType.DECODER,
LN_TYPE: 'layernorm', ENABLE_ROPE: True,
ZERO_CEN: False, ROPE_GROUP_METHOD: "alternate",
ACTIVATION: ('gelu',), TRANSPOSE_BS: False,
LYR_TYPE: TransformerLayerType.DECODER, },
ENABLE_ROPE: False, {
ROPE_GROUP_METHOD: 'consecutive', USE_BIAS: True,
TRANSPOSE_BS: False, LN_TYPE: "layernorm",
LORA_SCOPE: 'all' ZERO_CEN: True,
}] ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
LORA_SCOPE: "all",
},
]
class TestTransformer(TestLayer): class TestTransformer(TestLayer):
...@@ -1256,11 +1273,13 @@ class TestTransformer(TestLayer):
shape = (shape[1], shape[0]) + shape[2:] shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [ return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]),
mask,
mask,
] ]
def get_layer_name(self): def get_layer_name(self):
return 'transformerlayer' return "transformerlayer"
def generate_praxis_p_and_flax_cls(self, dtype, attrs): def generate_praxis_p_and_flax_cls(self, dtype, attrs):
hidden_size = 512 hidden_size = 512
...@@ -1277,97 +1296,102 @@ class TestTransformer(TestLayer):
layer_type = attrs[TransformerLayerAttr.LYR_TYPE] layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE] enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD] rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none') low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, "none")
enable_relative_embedding = True enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases, relative_embedding = pax_fiddle.Config(
dtype=dtype, RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads
num_attention_heads=num_attention_heads) )
drop_path = 0.0 drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS] transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
rel_embedding_init = RelativePositionBiases.generate_embedding_init( rel_embedding_init = RelativePositionBiases.generate_embedding_init(
relative_embedding.embedding_init, relative_embedding.num_attention_heads, relative_embedding.embedding_init,
relative_embedding.num_buckets) relative_embedding.num_attention_heads,
relative_embedding.num_buckets,
)
relative_embedding_flax_module = flax_RelativePositionBiases( relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=relative_embedding.num_buckets, num_buckets=relative_embedding.num_buckets,
max_distance=relative_embedding.max_distance, max_distance=relative_embedding.max_distance,
num_attention_heads=relative_embedding.num_attention_heads, num_attention_heads=relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init( embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", rel_embedding_init), "rel_embedding", rel_embedding_init
),
embedding_axes=relative_embedding.embedding_axes, embedding_axes=relative_embedding.embedding_axes,
dtype=relative_embedding.dtype) dtype=relative_embedding.dtype,
)
praxis_p = pax_fiddle.Config(TransformerLayer,
name='transformer_layer', praxis_p = pax_fiddle.Config(
params_init=kernel_init, TransformerLayer,
dtype=dtype, name="transformer_layer",
hidden_size=hidden_size, params_init=kernel_init,
mlp_hidden_size=mlp_hidden_size, dtype=dtype,
num_attention_heads=num_attention_heads, hidden_size=hidden_size,
layernorm_type=layernorm_type, mlp_hidden_size=mlp_hidden_size,
hidden_dropout=hidden_dropout, num_attention_heads=num_attention_heads,
attention_dropout=attention_dropout, layernorm_type=layernorm_type,
intermediate_dropout=intermediate_dropout, hidden_dropout=hidden_dropout,
mlp_activations=mlp_activations, attention_dropout=attention_dropout,
use_bias=use_bias, intermediate_dropout=intermediate_dropout,
bias_init=bias_init, mlp_activations=mlp_activations,
layer_type=layer_type, use_bias=use_bias,
enable_relative_embedding=enable_relative_embedding, bias_init=bias_init,
enable_rotary_pos_emb=enable_rotary_pos_emb, layer_type=layer_type,
rotary_pos_emb_group_method=rotary_pos_emb_group_method, enable_relative_embedding=enable_relative_embedding,
low_rank_adaptation_scope=low_rank_adaptation_scope, enable_rotary_pos_emb=enable_rotary_pos_emb,
relative_embedding=relative_embedding, rotary_pos_emb_group_method=rotary_pos_emb_group_method,
drop_path=drop_path, low_rank_adaptation_scope=low_rank_adaptation_scope,
transpose_batch_sequence=transpose_batch_sequence) relative_embedding=relative_embedding,
flax_cls = partial(flax_TransformerLayer, drop_path=drop_path,
dtype=dtype, transpose_batch_sequence=transpose_batch_sequence,
hidden_size=hidden_size, )
mlp_hidden_size=mlp_hidden_size, flax_cls = partial(
num_attention_heads=num_attention_heads, flax_TransformerLayer,
layernorm_type=layernorm_type, dtype=dtype,
hidden_dropout=hidden_dropout, hidden_size=hidden_size,
attention_dropout=attention_dropout, mlp_hidden_size=mlp_hidden_size,
intermediate_dropout=intermediate_dropout, num_attention_heads=num_attention_heads,
mlp_activations=mlp_activations, layernorm_type=layernorm_type,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init( hidden_dropout=hidden_dropout,
"mha_kernel", kernel_init), attention_dropout=attention_dropout,
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init( intermediate_dropout=intermediate_dropout,
"mlp_kernel", kernel_init), mlp_activations=mlp_activations,
use_bias=use_bias, mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
bias_init=TransformerEngineBaseLayer.generate_params_init( "mha_kernel", kernel_init
"bias", bias_init), ),
layer_type=layer_type, mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
enable_rotary_pos_emb=enable_rotary_pos_emb, "mlp_kernel", kernel_init
rotary_pos_emb_group_method=rotary_pos_emb_group_method, ),
enable_relative_embedding=enable_relative_embedding, use_bias=use_bias,
relative_embedding=relative_embedding_flax_module, bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
low_rank_adaptation_scope=low_rank_adaptation_scope, layer_type=layer_type,
drop_path=drop_path, enable_rotary_pos_emb=enable_rotary_pos_emb,
transpose_batch_sequence=transpose_batch_sequence) rotary_pos_emb_group_method=rotary_pos_emb_group_method,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS) @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE) @pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS) @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS) @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(self, def test_forward_backward_fp8(
data_shape, self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
dtype, ):
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
self.attrs = attrs self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format) ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds): with fp8_autocast(enabled=True, fp8_recipe=ds):
......
...@@ -3,4 +3,5 @@
# See LICENSE for license information. # See LICENSE for license information.
import transformer_engine.jax import transformer_engine.jax
print("OK") print("OK")
...@@ -8,25 +8,25 @@ from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource from transformer_engine.jax.sharding import global_shard_guard, MeshResource
LOGICAL_RULES = [ LOGICAL_RULES = [
[(('a1', None), ('a2', 'ma2')), False], [(("a1", None), ("a2", "ma2")), False],
[(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True], [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
[(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False], [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True], [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
[(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True], [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
] ]
MeshS = [ MeshS = [
MeshResource(), MeshResource(),
MeshResource('data', None), MeshResource("data", None),
MeshResource(None, 'model'), MeshResource(None, "model"),
MeshResource('data', 'model') MeshResource("data", "model"),
] ]
class TestShardingSideAPI: class TestShardingSideAPI:
@pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES) @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
@pytest.mark.parametrize('sr', MeshS) @pytest.mark.parametrize("sr", MeshS)
def test_extend_logical_axis_rules(self, base_rules, need_assert, sr): def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
with global_shard_guard(sr): with global_shard_guard(sr):
try: try:
......
...@@ -43,6 +43,7 @@ class SoftmaxRunner:
""" """
Softmax runner Softmax runner
""" """
batch_size: int batch_size: int
max_seqlen_q: int max_seqlen_q: int
max_seqlen_kv: int max_seqlen_kv: int
...@@ -57,14 +58,22 @@ class SoftmaxRunner:
Jax softmax as the reference Jax softmax as the reference
""" """
if mask is not None: if mask is not None:
logits += lax.select(mask > 0, logits += lax.select(
jnp.full(mask.shape, -1e10).astype(logits.dtype), mask > 0,
jnp.full(mask.shape, 0.).astype(logits.dtype)) jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return nn.softmax(logits * scale_factor) return nn.softmax(logits * scale_factor)
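A standalone numeric check of the reference path above. This is an illustrative sketch that assumes only jax is installed; it is not part of the test file:
import jax
import jax.numpy as jnp
from jax import lax

logits = jnp.array([[1.0, 2.0, 3.0]])
mask = jnp.array([[0, 0, 1]], dtype=jnp.uint8)  # 1 marks a masked-out position
scale_factor = 0.125
biased = logits + lax.select(
    mask > 0,
    jnp.full(mask.shape, -1e10).astype(logits.dtype),
    jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
probs = jax.nn.softmax(biased * scale_factor)
# probs[0, 2] is numerically zero; the two unmasked entries sum to ~1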
def _is_support(self): def _is_support(self):
return is_softmax_kernel_available(self.softmax_type, self.batch_size, self.num_heads, return is_softmax_kernel_available(
self.max_seqlen_q, self.max_seqlen_kv, self.dtype) self.softmax_type,
self.batch_size,
self.num_heads,
self.max_seqlen_q,
self.max_seqlen_kv,
self.dtype,
)
def _setup_inputs(self): def _setup_inputs(self):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
...@@ -73,7 +82,7 @@ class SoftmaxRunner:
logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv) logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.) self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
match self.softmax_type: match self.softmax_type:
case SoftmaxType.SCALED: case SoftmaxType.SCALED:
...@@ -81,7 +90,7 @@ class SoftmaxRunner:
case SoftmaxType.SCALED_MASKED: case SoftmaxType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8) self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED: case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1. - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8) self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _: case _:
raise ValueError(f"Unknown {self.softmax_type=}") raise ValueError(f"Unknown {self.softmax_type=}")
...@@ -108,18 +117,24 @@ class SoftmaxRunner:
args = [self.logits, self.mask] args = [self.logits, self.mask]
kwargs = { kwargs = {
'scale_factor': self.scale_factor, "scale_factor": self.scale_factor,
'softmax_type': self.softmax_type, "softmax_type": self.softmax_type,
} }
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad(lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), value_and_grad(
(0,))) lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,)
)
)
jitted_reference = jit( jitted_reference = jit(
value_and_grad( value_and_grad(
lambda logits, *args: grad_func(__class__.reference_softmax, self.logits, *args, ** lambda logits, *args: grad_func(
kwargs), (0,))) __class__.reference_softmax, self.logits, *args, **kwargs
),
(0,),
)
)
primitive_out, (primitive_grad_logits,) = jitted_primitive(*args) primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
reference_out, (reference_grad_logits,) = jitted_reference(*args) reference_out, (reference_grad_logits,) = jitted_reference(*args)
...@@ -128,21 +143,30 @@ class SoftmaxRunner:
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype) assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
@pytest.mark.parametrize('b, s_q, s_kv, h', [ @pytest.mark.parametrize(
pytest.param(8, 16, 16, 16, id='8-16-16-16'), "b, s_q, s_kv, h",
pytest.param(8, 512, 512, 16, id='8-512-512-16'), [
pytest.param(2, 8, 16384, 8, id='2-8-16384-8') pytest.param(8, 16, 16, 16, id="8-16-16-16"),
]) pytest.param(8, 512, 512, 16, id="8-512-512-16"),
@pytest.mark.parametrize('scale_factor', [0.125]) pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
@pytest.mark.parametrize('softmax_type', [ ],
pytest.param(SoftmaxType.SCALED, id='SCALED'), )
pytest.param(SoftmaxType.SCALED_MASKED, id='SCALED_MASKED'), @pytest.mark.parametrize("scale_factor", [0.125])
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id='SCALED_UPPER_TRIANG_MASKED') @pytest.mark.parametrize(
]) "softmax_type",
@pytest.mark.parametrize('dtype', [ [
pytest.param(jnp.bfloat16, id="BF16"), pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(jnp.float16, id="FP16"), pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
]) pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
],
)
class TestSoftmax: class TestSoftmax:
""" """
Test transformer_engine.jax.softmax.softmax Test transformer_engine.jax.softmax.softmax
......
...@@ -24,8 +24,9 @@ PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
DType = jnp.dtype DType = jnp.dtype
Array = Any Array = Any
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, PrecisionLike = Union[
lax.Precision]] None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array] Initializer = Callable[[PRNGKey, Shape, DType], Array]
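For readers unfamiliar with the Initializer alias above: any callable taking (rng_key, shape, dtype) and returning an array satisfies it. A hedged sketch (not from this module):
import jax
import jax.numpy as jnp

def zeros_init(key, shape, dtype=jnp.float32):
    # Conforms to Initializer = Callable[[PRNGKey, Shape, DType], Array]; the key is unused here.
    del key
    return jnp.zeros(shape, dtype)

w = zeros_init(jax.random.PRNGKey(0), (4, 8), jnp.float32)
assert w.shape == (4, 8)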
...@@ -56,7 +57,7 @@ def _canonicalize_tuple(x):
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable: def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
"""Convert a string to an activation function.""" """Convert a string to an activation function."""
if fn_or_string == 'linear': if fn_or_string == "linear":
return lambda x: x return lambda x: x
if isinstance(fn_or_string, str): if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string) return getattr(nn, fn_or_string)
...@@ -68,17 +69,18 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla
def combine_biases(*masks: Optional[Array]): def combine_biases(*masks: Optional[Array]):
"""Combine attention biases. """Combine attention biases.
Args: Args:
*masks: set of attention bias arguments to combine, some can be None. *masks: set of attention bias arguments to combine, some can be None.
Returns: Returns:
Combined mask, reduced by summation, returns None if no masks given. Combined mask, reduced by summation, returns None if no masks given.
""" """
masks = [m for m in masks if m is not None] masks = [m for m in masks if m is not None]
if not masks: if not masks:
return None return None
assert all(map(lambda x: x.ndim == masks[0].ndim, assert all(
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}') map(lambda x: x.ndim == masks[0].ndim, masks)
), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
mask, *other_masks = masks mask, *other_masks = masks
for other_mask in other_masks: for other_mask in other_masks:
mask = mask + other_mask mask = mask + other_mask
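Typical use of combine_biases as documented above. This is an illustrative sketch only; it assumes the combine_biases helper from this module is importable in scope:
import jax.numpy as jnp

causal_bias = jnp.triu(jnp.full((1, 1, 4, 4), -1e10), k=1)   # bias out future positions
relative_bias = jnp.zeros((1, 1, 4, 4))                      # e.g. a learned position bias
combined = combine_biases(causal_bias, relative_bias, None)  # None entries are skipped, rest summed
assert combined.shape == (1, 1, 4, 4)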
...@@ -88,7 +90,7 @@ def combine_biases(*masks: Optional[Array]):
class DotProductAttention(nn.Module): class DotProductAttention(nn.Module):
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
scale_attn_logits: bool = True scale_attn_logits: bool = True
dropout_rate: float = 0. dropout_rate: float = 0.0
dtype: DType = jnp.float32 dtype: DType = jnp.float32
float32_logits: bool = False float32_logits: bool = False
"""Computes dot-product attention given query, key, and value. """Computes dot-product attention given query, key, and value.
...@@ -105,12 +107,14 @@ class DotProductAttention(nn.Module):
""" """
@nn.compact @nn.compact
def __call__(self, def __call__(
query: Array, self,
key: Array, query: Array,
value: Array, key: Array,
bias: Optional[Array] = None, value: Array,
deterministic: bool = False): bias: Optional[Array] = None,
deterministic: bool = False,
):
""" """
Args: Args:
query: queries for calculating attention with shape of `[batch, q_length, query: queries for calculating attention with shape of `[batch, q_length,
...@@ -127,14 +131,15 @@ class DotProductAttention(nn.Module):
Returns: Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`. Output of shape `[batch, length, num_heads, v_depth_per_head]`.
""" """
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0 batch_dim = 1 if self.transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], ( assert (
'q, k, v batch dims must match.') query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
), "q, k, v batch dims must match."
sequence_dim = 0 if self.transpose_batch_sequence else 1 sequence_dim = 0 if self.transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.' assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.' assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.' assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
if self.scale_attn_logits: if self.scale_attn_logits:
head_dim = query.shape[-1] head_dim = query.shape[-1]
...@@ -153,9 +158,9 @@ class DotProductAttention(nn.Module):
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key) attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
else: else:
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key) attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout # reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
...@@ -170,37 +175,37 @@ class DotProductAttention(nn.Module):
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype) attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
# Apply attention dropout. # Apply attention dropout.
if not deterministic and self.dropout_rate > 0.: if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate keep_prob = 1.0 - self.dropout_rate
# T5 broadcasts along the "length" dim, but unclear which one that # T5 broadcasts along the "length" dim, but unclear which one that
# corresponds to in positional dimensions here, assuming query dim. # corresponds to in positional dimensions here, assuming query dim.
dropout_shape = list(attn_weights.shape) dropout_shape = list(attn_weights.shape)
dropout_rng = self.make_rng('dropout') dropout_rng = self.make_rng("dropout")
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = (keep.astype(attn_weights.dtype) / multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
jnp.asarray(keep_prob, dtype=self.dtype))
attn_weights = attn_weights * multiplier attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`. # Take the linear combination of `value`.
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape) return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape) return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
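Shape sketch for the grouped-query einsums used above (toy sizes, not taken from the test suite), showing how the query is folded into (kv_heads, group) before the contraction:
import jax.numpy as jnp

b, q_len, kv_len, h_kv, group, d = 2, 5, 7, 2, 3, 8   # 6 query heads grouped over 2 KV heads
grouped_query = jnp.ones((b, q_len, h_kv, group, d))
key = jnp.ones((b, kv_len, h_kv, d))
value = jnp.ones((b, kv_len, h_kv, d))

attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
assert attn_weights.shape == (b, h_kv, group, q_len, kv_len)

out = jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value)
assert out.shape == grouped_query.shape   # reshaped back to the original query shape in the module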
class DenseGeneral(nn.Module): class DenseGeneral(nn.Module):
"""A linear transformation with flexible axes and FP8 support. """A linear transformation with flexible axes and FP8 support.
Attributes: Attributes:
features: tuple with numbers of output features. features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on. axis: tuple with axes to apply the transformation on.
dtype: the dtype of the computation (default: float32). dtype: the dtype of the computation (default: float32).
kernel_init: initializer function for the weight matrix. kernel_init: initializer function for the weight matrix.
use_bias: whether to add a bias to the output (default: False). use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector. bias_init: initializer function for the bias vector.
""" """
features: Union[Iterable[int], int] features: Union[Iterable[int], int]
axis: Union[Iterable[int], int] = -1 axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32 dtype: DType = jnp.float32
...@@ -212,7 +217,7 @@ class DenseGeneral(nn.Module):
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -233,21 +238,17 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features)) kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes('kernel', kernel = nn_partitioning.param_with_axes(
self.kernel_init, "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
kernel_param_shape, )
jnp.float32,
axes=self.kernel_axes)
kernel = jnp.asarray(kernel, self.dtype) kernel = jnp.asarray(kernel, self.dtype)
kernel = jnp.reshape(kernel, kernel_shape) kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias: if self.use_bias:
bias = nn_partitioning.param_with_axes('bias', bias = nn_partitioning.param_with_axes(
self.bias_init, "bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes
self.features, )
jnp.float32,
axes=self.bias_axes)
bias = bias.astype(self.dtype) bias = bias.astype(self.dtype)
else: else:
bias = None bias = None
...@@ -264,18 +265,19 @@ class MlpBlock(nn.Module):
class MlpBlock(nn.Module): class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block. """Transformer MLP / feed-forward block.
Attributes: Attributes:
intermediate_dim: Shared dimension of hidden layers. intermediate_dim: Shared dimension of hidden layers.
activations: Type of activations for each layer. Each element is either activations: Type of activations for each layer. Each element is either
'linear', a string function name in flax.linen, or a function. 'linear', a string function name in flax.linen, or a function.
kernel_init: Kernel function, passed to the dense layers. kernel_init: Kernel function, passed to the dense layers.
deterministic: Whether the dropout layers should be deterministic. deterministic: Whether the dropout layers should be deterministic.
intermediate_dropout_rate: Dropout rate used after the intermediate layers. intermediate_dropout_rate: Dropout rate used after the intermediate layers.
dtype: Type for the dense layer. dtype: Type for the dense layer.
""" """
transpose_batch_sequence: bool transpose_batch_sequence: bool
intermediate_dim: int = 2048 intermediate_dim: int = 2048
activations: Sequence[Union[str, Callable]] = ('relu',) activations: Sequence[Union[str, Callable]] = ("relu",)
kernel_init: Initializer = None kernel_init: Initializer = None
intermediate_dropout_rate: float = 0.1 intermediate_dropout_rate: float = 0.1
intermediate_dropout_dims: Sequence[int] = () intermediate_dropout_dims: Sequence[int] = ()
...@@ -285,7 +287,7 @@ class MlpBlock(nn.Module):
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -296,49 +298,57 @@ class MlpBlock(nn.Module):
activations = [] activations = []
if self.fuse_wi: if self.fuse_wi:
dense_name = 'wi' dense_name = "wi"
num_activations = len(self.activations) num_activations = len(self.activations)
x = DenseGeneral(self.intermediate_dim * num_activations, x = DenseGeneral(
dtype=self.dtype, self.intermediate_dim * num_activations,
kernel_init=self.kernel_init, dtype=self.dtype,
kernel_axes=('embed', 'mlp'), kernel_init=self.kernel_init,
use_bias=self.use_bias, kernel_axes=("embed", "mlp"),
bias_axes=('mlp'), use_bias=self.use_bias,
name=dense_name)(inputs) bias_axes="mlp",
name=dense_name,
)(inputs)
x = jnp.split(x, num_activations, axis=-1) x = jnp.split(x, num_activations, axis=-1)
for idx, act_fn in enumerate(self.activations): for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
else: else:
for idx, act_fn in enumerate(self.activations): for idx, act_fn in enumerate(self.activations):
dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}' dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
x = DenseGeneral(self.intermediate_dim, x = DenseGeneral(
dtype=self.dtype, self.intermediate_dim,
kernel_init=self.kernel_init, dtype=self.dtype,
kernel_axes=('embed', 'mlp'), kernel_init=self.kernel_init,
use_bias=self.use_bias, kernel_axes=("embed", "mlp"),
bias_axes=('mlp'), use_bias=self.use_bias,
name=dense_name)(inputs) bias_axes="mlp",
name=dense_name,
)(inputs)
x = _convert_to_activation_function(act_fn)(x) x = _convert_to_activation_function(act_fn)(x)
activations.append(x) activations.append(x)
# Take elementwise product of above intermediate activations. # Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations) x = functools.reduce(operator.mul, activations)
# Apply dropout and final dense output projection. # Apply dropout and final dense output projection.
x = nn.Dropout(rate=self.intermediate_dropout_rate, x = nn.Dropout(
broadcast_dims=self.intermediate_dropout_dims)( rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
x, deterministic=deterministic) # Broadcast along length. )(
x, deterministic=deterministic
) # Broadcast along length.
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp')) x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
else: else:
x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'mlp')) x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
output = DenseGeneral(inputs.shape[-1], output = DenseGeneral(
dtype=self.dtype, inputs.shape[-1],
kernel_init=self.kernel_init, dtype=self.dtype,
kernel_axes=('mlp', 'embed'), kernel_init=self.kernel_init,
use_bias=self.use_bias, kernel_axes=("mlp", "embed"),
bias_axes=('embed'), use_bias=self.use_bias,
name='wo')(x) bias_axes="embed",
name="wo",
)(x)
return output return output
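
For context, MlpBlock implements the (optionally gated) transformer MLP: each entry in `activations` gets its own intermediate projection, the branches are multiplied elementwise, and the result is projected back by "wo". A minimal usage sketch follows; the `utils` import path, the toy shapes, and the assumption that the remaining attributes keep their defaults are mine, not from this diff:

    import jax
    import jax.numpy as jnp
    from utils import MlpBlock  # hypothetical import path for this test-utility module

    mlp = MlpBlock(
        transpose_batch_sequence=False,   # inputs are [batch, length, embed]
        intermediate_dim=1024,
        activations=("gelu", "linear"),   # two branches -> gated MLP
        fuse_wi=True,                     # one fused "wi" projection, split per branch
    )
    x = jnp.ones((2, 16, 256))
    params = mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
    y = mlp.apply(params, x, deterministic=True)  # same shape as x: (2, 16, 256)
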
...@@ -351,7 +361,7 @@ def apply_rotary_pos_emb_alternate(
    embedding_dim = inputs.shape[-1]
    half_embedding_dim = embedding_dim // 2
    fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
    sinusoid_inp = position / timescale
...@@ -386,7 +396,7 @@ def apply_rotary_pos_emb_consecutive(
        inputs_shifted_left,
    )
    fraction = jnp.repeat(fraction, 2)
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
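
Both rotary helpers derive their rotation angles from the same geometric timescale schedule seen above. A standalone sketch with toy values (not taken from this diff):

    import jax.numpy as jnp

    min_timescale, max_timescale = 1.0, 10000.0
    embedding_dim = 8
    fraction = 2 * jnp.arange(0, embedding_dim // 2) / embedding_dim  # [0., 0.25, 0.5, 0.75]
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    # timescale == [1., 10., 100., 1000.]; position / timescale gives the angles
    # that the apply_rotary_pos_emb_* helpers turn into sin/cos rotation factors.
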
...@@ -415,89 +425,96 @@ class MultiHeadAttention(nn.Module):
      kernel_init: initializer for the kernel of the Dense layers.
      float32_logits: bool, if True then compute logits in float32 to avoid
        numerical issues with bfloat16.
    """

    num_heads: int = 8
    num_gqa_groups: int | None = None
    head_dim: int = 64
    transpose_batch_sequence: bool = True
    dtype: DType = jnp.float32
    dropout_rate: float = 0.0
    kernel_init: Initializer = None
    float32_logits: bool = False  # computes logits in float32 for stability.
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv: bool = True
    use_bias: bool = False

    def __post_init__(self):
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_heads
        super().__post_init__()
    @nn.compact
    def __call__(
        self,
        inputs_q: Array,
        inputs_kv: Array,
        mask: Optional[Array] = None,
        bias: Optional[Array] = None,
        *,
        decode: bool = False,
        deterministic: bool = False,
    ) -> Array:
        """Applies multi-head dot product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention and projects the results to an output vector.

        There are two modes: decoding and non-decoding (e.g., training). The mode is
        determined by the `decode` argument. For decoding, this method is called twice,
        first to initialize the cache and then for the actual decoding process. The
        two calls are differentiated by the presence of 'cached_key' in the variable
        dict. In the cache initialization stage, the cache variables are initialized
        as zeros and will be filled in the subsequent decoding process.

        In the cache initialization call, `inputs_q` has a shape [batch, length,
        q_features] and `inputs_kv`: [batch, length, kv_features]. During the
        incremental decoding stage, query, key and value all have the shape [batch,
        1, qkv_features] corresponding to a single step.

        Args:
          inputs_q: input queries of shape `[batch, q_length, q_features]`.
          inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
          mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
          bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
          decode: Whether to prepare and use an autoregressive cache.
          deterministic: Disables dropout if set to True.

        Returns:
          output of shape `[batch, length, q_features]`.
        """
        q_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_heads * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        kv_projection = functools.partial(
            DenseGeneral,
            axis=-1,
            features=self.num_gqa_groups * self.head_dim,
            kernel_axes=("embed", "joined_kv"),
            use_bias=self.use_bias,
            bias_axes="joined_kv",
            dtype=self.dtype,
        )

        # NOTE: T5 does not explicitly rescale the attention logits by
        # 1/sqrt(depth_kq)! This is folded into the initializers of the
        # linear transformations, which is equivalent under Adafactor.
        depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
        query_init = lambda *args: self.kernel_init(*args) / (
            depth_scaling if self.scaled_query_init else 1.0
        )

        # Project inputs_q to multi-headed q/k/v
        # dimensions are then [batch, length, num_heads, head_dim]
...@@ -515,39 +532,45 @@ class MultiHeadAttention(nn.Module):
            return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)

        is_self_attn = inputs_q is inputs_kv
        is_gqa = self.num_heads != self.num_gqa_groups
        is_qkvpack = is_self_attn and not is_gqa

        if self.fuse_qkv:
            if is_qkvpack:
                qkv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_heads * self.head_dim * 3,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=qkv_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="qkv",
                    dtype=self.dtype,
                )(inputs_kv)
                query, key, value = jnp.split(
                    qkv_proj,
                    [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
                    axis=-1,
                )
            else:
                query = q_projection(kernel_init=query_init, name="query")(inputs_q)
                kv_proj = DenseGeneral(
                    axis=-1,
                    features=self.num_gqa_groups * self.head_dim * 2,
                    kernel_axes=("embed", "joined_kv"),
                    kernel_init=self.kernel_init,
                    use_bias=self.use_bias,
                    bias_axes="joined_kv",
                    name="kv",
                    dtype=self.dtype,
                )(inputs_kv)
                key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
        else:
            query = q_projection(kernel_init=query_init, name="query")(inputs_q)
            key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
            value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)

        if self.enable_rotary_pos_emb:
            batch_dim = 1 if self.transpose_batch_sequence else 0
...@@ -556,7 +579,7 @@ class MultiHeadAttention(nn.Module):
            q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
            k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)

            if self.rotary_pos_emb_group_method == "alternate":
                apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
            else:
                apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive
        value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

        if self.transpose_batch_sequence:
            query = nn_partitioning.with_sharding_constraint(
                query, ("length", "batch", "heads", "kv")
            )
            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
            value = nn_partitioning.with_sharding_constraint(
                value, ("length", "batch", "heads", "kv")
            )
        else:
            query = nn_partitioning.with_sharding_constraint(
                query, ("batch", "length", "heads", "kv")
            )
            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
            value = nn_partitioning.with_sharding_constraint(
                value, ("batch", "length", "heads", "kv")
            )

        if decode:
            # Detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable("cache", "cached_key")

            # The key and value have dimension [batch, length, num_heads, head_dim],
            # but we cache them as [batch, num_heads, head_dim, length] as a TPU
            # fusion optimization. This also enables the "scatter via one-hot
            # broadcast" trick, which means we do a one-hot broadcast instead of
            # scatter/gather operations, resulting in a 3-4x speedup in practice.
            swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
            cached_key = self.variable(
                "cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
            )
            cached_value = self.variable(
                "cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
            )
            cache_index = self.variable(
                "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
            )
            if is_initialized:
                batch, num_heads, head_dim, length = cached_key.value.shape
                # During fast autoregressive decoding, we feed one position at a time,
...@@ -606,8 +636,9 @@ class MultiHeadAttention(nn.Module):
                expected_shape = (batch, 1, num_heads, head_dim)
                if expected_shape != query.shape:
                    raise ValueError(
                        "Autoregressive cache shape error, "
                        f"expected query shape {expected_shape} instead got {query.shape}."
                    )

                # Create an OHE of the current index. NOTE: the index is increased below.
                cur_index = cache_index.value
                    jnp.logical_not(mask),
                    jnp.broadcast_to(
                        jnp.arange(length) <= cur_index,
                        # (1, 1, length) represent (head dim, query length, key length)
                        # query length is 1 because during decoding we deal with one
                        # index.
                        # The same mask is applied to all batch elements and heads.
                        (batch, 1, 1, length),
                    ),
                )

                # Grab the correct relative attention bias during decoding. This is
                # only required during single step decoding.
...@@ -650,15 +683,18 @@ class MultiHeadAttention(nn.Module):
                    # The bias is a full attention matrix, but during decoding we only
                    # have to take a slice of it.
                    # This is equivalent to bias[..., cur_index:cur_index+1, :].
                    bias = dynamic_vector_slice_in_dim(
                        jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
                    )

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.0).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None
            attention_bias = combine_biases(attention_bias, bias)

        # Apply attention.
        x = DotProductAttention(
            transpose_batch_sequence=self.transpose_batch_sequence,
            scale_attn_logits=self.scale_attn_logits,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            float32_logits=self.float32_logits,
        )(query, key, value, bias=attention_bias, deterministic=deterministic)

        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

        if self.transpose_batch_sequence:
            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
        else:
            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))

        # Back to the original inputs dimensions.
        out = DenseGeneral(
            features=inputs_q.shape[-1],  # output dim is set to the input dim.
            axis=-1,
            kernel_init=self.kernel_init,
            kernel_axes=("joined_kv", "embed"),
            use_bias=self.use_bias,
            bias_axes="embed",
            dtype=self.dtype,
            name="out",
        )(x)
        return out
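
A hedged self-attention usage sketch for MultiHeadAttention in the batch-first layout; the `utils` import path and the toy shapes are assumptions, and dropout is disabled so no RNG is needed at apply time:

    import jax
    import jax.numpy as jnp
    from utils import MultiHeadAttention  # hypothetical import path

    attn = MultiHeadAttention(
        num_heads=4,
        head_dim=16,
        transpose_batch_sequence=False,  # inputs are [batch, length, features]
        dropout_rate=0.0,
    )
    x = jnp.ones((2, 8, 64))  # [batch, q_length, q_features]
    variables = attn.init(jax.random.PRNGKey(0), x, x, deterministic=True)
    y = attn.apply(variables, x, x, deterministic=True)  # output: [2, 8, 64]
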
class LayerNorm(nn.Module):
    """T5 Layer normalization operating on the last axis of the input data."""

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    layernorm_type: str = "layernorm"
    zero_centered_gamma: bool = False
    scale_init: Initializer = None
    bias_init: Initializer = nn.initializers.zeros
...@@ -721,29 +757,27 @@ class LayerNorm(nn.Module):
        x = jnp.asarray(x, jnp.float32)
        features = x.shape[-1]
        scale = nn_partitioning.param_with_axes(
            "scale", self.scale_init, (features,), jnp.float32, axes=("embed",)
        )
        scale = jnp.asarray(scale, self.dtype)

        if self.layernorm_type == "layernorm":
            mean = jnp.mean(x, axis=-1, keepdims=True)
            var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
            y = (x - mean) * lax.rsqrt(var + self.epsilon)

            bias = nn_partitioning.param_with_axes(
                "ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",)
            )
            bias = jnp.asarray(bias, self.dtype)

            if not self.zero_centered_gamma:
                z = y * scale + bias
            else:
                z = y * (scale + 1.0) + bias
        else:
            assert self.layernorm_type == "rmsnorm"
            assert not self.zero_centered_gamma
            mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
            y = x * lax.rsqrt(mean2 + self.epsilon)
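
A standalone numerical sketch of the two normalization branches above, using only the formulas visible here:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    eps = 1e-6

    # "layernorm" branch: subtract the mean, then normalize by the variance.
    mean = jnp.mean(x, keepdims=True)
    var = jnp.mean(jnp.square(x - mean), keepdims=True)
    y_ln = (x - mean) * lax.rsqrt(var + eps)  # zero mean, unit variance

    # "rmsnorm" branch: no centering, normalize by the root-mean-square only.
    mean2 = jnp.mean(lax.square(x), keepdims=True)
    y_rms = x * lax.rsqrt(mean2 + eps)
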
...@@ -755,16 +789,17 @@ class LayerNorm(nn.Module):


class RelativePositionBiases(nn.Module):
    """Adds T5-style relative positional embeddings to the attention logits.

    Attributes:
      num_buckets: Number of buckets to bucket distances between key and query
        positions into.
      max_distance: Maximum distance before everything is lumped into the last
        distance bucket.
      num_heads: Number of heads in the attention layer. Each head will get a
        different relative position weighting.
      dtype: Type of arrays through this module.
      embedding_init: initializer for relative embedding table.
    """

    num_buckets: int
    max_distance: int
    num_heads: int
...@@ -772,33 +807,32 @@ class RelativePositionBiases(nn.Module):
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """Translate relative position to a bucket number for relative attention.

        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are
        invalid.

        We use smaller buckets for small absolute relative_position and larger
        buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative
        positions <=-max_distance map to the same bucket. This should allow for
        more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
          relative_position: an int32 array
          bidirectional: a boolean - whether the attention is bidirectional
          num_buckets: an integer
          max_distance: an integer

        Returns:
          a Tensor with the same shape as relative_position, containing int32
          values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
...@@ -811,8 +845,10 @@ class RelativePositionBiases(nn.Module):
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
            np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
            / np.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).astype(np.int32)
        val_if_large = np.minimum(val_if_large, num_buckets - 1)
        ret += np.where(is_small, n, val_if_large)
        return ret
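
A quick sanity check of the bucketing rule as a hedged sketch (the `utils` import path is an assumption): nearby offsets land in distinct small buckets, while offsets at or beyond max_distance collapse into the outermost bucket for their direction.

    import numpy as np
    from utils import RelativePositionBiases  # hypothetical import path

    rel_pos = np.array([[-130, -8, -1, 0, 1, 8, 130]])
    buckets = RelativePositionBiases._relative_position_bucket(
        rel_pos, bidirectional=True, num_buckets=32, max_distance=128
    )
    print(buckets)  # int32 bucket indices in [0, 32)
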
...@@ -821,27 +857,31 @@ class RelativePositionBiases(nn.Module):
    def __call__(self, qlen, klen, bidirectional=True):
        """Produce relative position embedding attention biases.

        Args:
          qlen: attention query length.
          klen: attention key length.
          bidirectional: whether to allow positive memory-query relative position
            embeddings.

        Returns:
          output: `(1, num_heads, q_len, k_len)` attention bias
        """
        context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
        memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        relative_attention_bias = nn_partitioning.param_with_axes(
            "rel_embedding",
            self.embedding_init,
            (self.num_heads, self.num_buckets),
            jnp.float32,
            axes=("heads", "relpos_buckets"),
        )

        relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
        # Instead of using a slow gather, we create a leading-dimension one-hot
...@@ -855,9 +895,8 @@ class RelativePositionBiases(nn.Module):
        values = lax.dot_general(
            relative_attention_bias,
            rp_bucket_one_hot,
            (((1,), (0,)), ((), ())),  # rhs, lhs contracting dims
        )  # no batched dims
        # Add a singleton batch dimension.
        # --> shape (1, num_heads, qlen, klen)
        return values[jnp.newaxis, ...]
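
A hedged sketch of instantiating the bias module and checking the emitted shape; the `utils` import path and the default dtype are assumptions:

    import jax
    from utils import RelativePositionBiases  # hypothetical import path

    rel_emb = RelativePositionBiases(num_buckets=32, max_distance=128, num_heads=8)
    variables = rel_emb.init(jax.random.PRNGKey(0), 16, 16, True)
    bias = rel_emb.apply(variables, 16, 16, True)
    assert bias.shape == (1, 8, 16, 16)  # (1, num_heads, qlen, klen), added to attention logits
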
...@@ -865,6 +904,7 @@ class EncoderLayer(nn.Module):


class EncoderLayer(nn.Module):
    """Transformer encoder layer."""

    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    num_attention_heads: int = 8
...@@ -880,17 +920,17 @@ class EncoderLayer(nn.Module):
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    output_layernorm: bool = False
    drop_path: float = 0.0
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
...@@ -903,20 +943,21 @@ class EncoderLayer(nn.Module):

    @nn.compact
    def __call__(self, inputs, encoder_mask=None, deterministic=False):
        del self.self_attn_mask_type  # dummy, just align to TE's impl

        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim

        if self.enable_relative_embedding:
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding
            encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
...@@ -928,11 +969,13 @@ class EncoderLayer(nn.Module):

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_attention_layer_norm",
            )(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
...@@ -940,39 +983,41 @@ class EncoderLayer(nn.Module):
            x = inputs

        # [batch, length, emb_dim] -> [batch, length, emb_dim]
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            fuse_qkv=self.fuse_qkv_params,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            use_bias=self.use_bias,
            name="attention",
        )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
        x = x + residual

        # MLP block.
        residual = x
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y
...@@ -987,27 +1032,32 @@ class EncoderLayer(nn.Module):
            use_bias=self.use_bias,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name="mlp",
        )(y, deterministic=deterministic)
        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
            y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                y, deterministic=deterministic
            )
        y = y + residual

        if self.output_layernorm:
            y = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(y)

        return y
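
The `encoder_mask` passed above is a [batch, 1, q_length, kv_length] attention mask. A minimal sketch of building such masks from padding with stock Flax helpers (toy tokens, not from this diff):

    import jax.numpy as jnp
    from flax import linen as nn

    tokens = jnp.array([[5, 7, 9, 0, 0]])  # 0 = padding
    keep = tokens > 0
    encoder_mask = nn.make_attention_mask(keep, keep)  # [1, 1, 5, 5]
    # For a decoder, a causal mask can be combined the same way:
    decoder_mask = nn.combine_masks(encoder_mask, nn.make_causal_mask(tokens))
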
class DecoderLayer(nn.Module):
    """Transformer decoder layer that attends to the encoder."""

    enable_relative_embedding: bool = True
    relative_embedding: nn.Module = None
    num_attention_heads: int = 8
...@@ -1023,17 +1073,17 @@ class DecoderLayer(nn.Module):
    scale_attn_logits: bool = False
    scaled_query_init: bool = True
    mlp_dim: int = 2048
    mlp_activations: Sequence[str] = ("relu",)
    use_bias: bool = False
    dtype: Any = jnp.float32
    apply_residual_connection_post_layernorm: bool = False
    output_layernorm: bool = False
    layernorm_type: str = "layernorm"
    layernorm_epsilon: float = 1e-6
    zero_centered_gamma: bool = False
    drop_path: float = 0.0
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_group_method: str = "consecutive"
    fuse_qkv_params: bool = True
    fuse_mlp_wi: bool = True
    self_attn_bias_type: Any = None
...@@ -1045,15 +1095,17 @@ class DecoderLayer(nn.Module):
        super().__post_init__()

    @nn.compact
    def __call__(
        self,
        inputs,
        encoded,
        decoder_mask=None,
        encoder_decoder_mask=None,
        deterministic=False,
        decode=False,
        max_decode_length=None,
    ):
        del self.self_attn_mask_type  # dummy, just align to TE's impl

        # Relative position embedding as attention biases.
        sequence_dim = 0 if self.transpose_batch_sequence else 1
        batch_dim = 1 - sequence_dim
...@@ -1061,13 +1113,14 @@ class DecoderLayer(nn.Module):
        if self.enable_relative_embedding:
            l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
            if self.relative_embedding is None:
                rel_emb = RelativePositionBiases(
                    num_buckets=32,
                    max_distance=128,
                    num_heads=self.num_attention_heads,
                    dtype=self.dtype,
                    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                    name="relpos_bias",
                )
            else:
                rel_emb = self.relative_embedding
            decoder_bias = rel_emb(l, l, False)
...@@ -1079,11 +1132,13 @@ class DecoderLayer(nn.Module):

        if not self.output_layernorm:
            # Attention block.
            x = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="pre_self_attention_layer_norm",
            )(inputs)

            if self.apply_residual_connection_post_layernorm:
                residual = x
...@@ -1091,71 +1146,74 @@ class DecoderLayer(nn.Module):
            x = inputs

        # Self-attention block
        x = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="self_attention",
        )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
        x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            x, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
            x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                x, deterministic=deterministic
            )
        x = x + residual

        # Encoder-Decoder block.
        residual = x
        y = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_cross_attention_layer_norm",
        )(x)

        if self.apply_residual_connection_post_layernorm:
            residual = y

        y = MultiHeadAttention(
            num_heads=self.num_attention_heads,
            num_gqa_groups=self.num_gqa_groups,
            dtype=self.dtype,
            head_dim=self.head_dim,
            transpose_batch_sequence=self.transpose_batch_sequence,
            dropout_rate=self.attention_dropout,
            float32_logits=self.float32_attention_logits,
            scale_attn_logits=self.scale_attn_logits,
            scaled_query_init=self.scaled_query_init,
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            fuse_qkv=self.fuse_qkv_params,
            use_bias=self.use_bias,
            name="encoder_decoder_attention",
        )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
        y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            y, deterministic=deterministic
        )
        y = y + residual

        # MLP block.
        residual = y
        z = LayerNorm(
            layernorm_type=self.layernorm_type,
            epsilon=self.layernorm_epsilon,
            zero_centered_gamma=self.zero_centered_gamma,
            dtype=self.dtype,
            name="pre_mlp_layer_norm",
        )(y)

        if self.apply_residual_connection_post_layernorm:
            residual = z

        z = MlpBlock(
...@@ -1167,22 +1225,26 @@ class DecoderLayer(nn.Module):
            use_bias=self.use_bias,
            dtype=self.dtype,
            fuse_wi=self.fuse_mlp_wi,
            name="mlp",
        )(z, deterministic=deterministic)
        z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
            z, deterministic=deterministic
        )
        if self.drop_path > 0.0:
            drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
            z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
                z, deterministic=deterministic
            )
        z = z + residual

        if self.output_layernorm:
            z = LayerNorm(
                layernorm_type=self.layernorm_type,
                epsilon=self.layernorm_epsilon,
                zero_centered_gamma=self.zero_centered_gamma,
                dtype=self.dtype,
                name="output_layernorm",
            )(z)

        return z
...@@ -1261,15 +1323,18 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
    flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
    flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)
    for (expected_path, expected_value), (actual_path, actual_value) in zip(
        flatten_expected, flatten_actual
    ):
        assert expected_path == actual_path
        key_str = jax.tree_util.keystr(expected_path)
        assert_allclose(
            expected_value,
            actual_value,
            rtol=rtol,
            atol=atol,
            err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
        )
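
A hedged sketch of exercising this helper on two small pytrees (the `utils` import path is an assumption):

    import jax.numpy as jnp
    from utils import assert_tree_like_allclose  # hypothetical import path

    expected = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}
    actual = {"dense": {"kernel": jnp.ones((2, 2)) + 1e-9, "bias": jnp.zeros((2,))}}
    assert_tree_like_allclose(expected, actual)  # passes: same structure, values within tolerance
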
def dtype_tols(
...@@ -1323,7 +1388,7 @@ def dtype_tols(
    )
def sync_params_values(dst, src, transformations, sep="/"):
    """
    This function will reconstruct a tree with dst's tree_def/shape and src's values.
    transformations is a map that records the key mappings between dst and src.
...