Unverified commit 9416519d authored by Kirthi Shankar Sivamani, committed by GitHub

Apply formatting (#929)



* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Apply formatting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d99142a0
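
The hunks below are consistent with a mechanical, formatter-driven pass (a Black-style tool with a roughly 100-character line limit is an assumption inferred from the diff, not stated in the commit message): string literals switch from single to double quotes, short calls are joined onto one line, and calls that would exceed the limit are exploded to one argument per line with a trailing comma. A minimal sketch of the two wrapping rules, using the argparse calls from the MNIST example touched in this commit:

    import argparse

    parser = argparse.ArgumentParser()

    # Fits within the ~100-character limit, so the formatter keeps (or joins) it on one line
    # and normalizes the quotes.
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")

    # Exceeds the limit, so the formatter explodes it: one argument per line,
    # with a trailing comma after the last argument.
    parser.add_argument(
        "--use-fp8",
        action="store_true",
        default=False,
        help="Use FP8 for inference and training without recalibration",
    )

The remaining hunks apply the same transformations to the PyTorch MNIST example, the copyright checker, setup.py, a C++ test, and the JAX test suite; the changes appear to be purely cosmetic.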
......@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8):
with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data)
def test(model, device, test_loader, use_fp8):
"""Testing function."""
model.eval()
......@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8):
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
output = model(data)
test_loss += F.nll_loss(
output, target, reduction="sum"
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
......@@ -150,9 +147,7 @@ def main():
default=False,
help="quickly check a single pass",
)
parser.add_argument(
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
......@@ -167,7 +162,10 @@ def main():
help="For Saving the current Model",
)
parser.add_argument(
"--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
"--use-fp8",
action="store_true",
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
......@@ -215,7 +213,7 @@ def main():
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer)
......
......@@ -18,21 +18,26 @@ path = sys.argv[1]
config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"
class bcolors:
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
def print_ok(msg):
print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}")
def print_fail(msg):
print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}")
def print_warn(msg):
print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}")
with open(config_path, "r") as f:
c = json.load(f)
current_year = datetime.date.today().year
......@@ -41,7 +46,7 @@ with open(config_path, "r") as f:
else:
year_string = str(c["initial_year"]) + "-" + str(current_year)
copyright_string = c["copyright"].replace("<YEAR>", year_string)
license = c["license"].split('\n')
license = c["license"].split("\n")
excludes = c["exclude"]
root_path = os.path.abspath(path)
copyright_only = c["copyright_only"]
......@@ -49,36 +54,42 @@ with open(config_path, "r") as f:
has_gitignore = os.path.exists(root_path + "/.gitignore")
def strip_star_slash(s):
ret = s
if ret.startswith('*'):
if ret.startswith("*"):
ret = ret[1:]
if ret.endswith('/'):
if ret.endswith("/"):
ret = ret[:-1]
return ret
if has_gitignore:
with open(root_path + "/.gitignore", "r") as f:
for line in f.readlines():
excludes.append(strip_star_slash(line.strip()))
def get_file_type(path):
ext = {"c": ["c", "cpp", "cu", "h", "cuh"],
"py": ["py"],
"rst": ["rst"],
"txt": ["txt"],
"cfg": ["cfg"],
"sh": ["sh"],
"md": ["md"],
}
ext = {
"c": ["c", "cpp", "cu", "h", "cuh"],
"py": ["py"],
"rst": ["rst"],
"txt": ["txt"],
"cfg": ["cfg"],
"sh": ["sh"],
"md": ["md"],
}
tmp = path.split(".")
for filetype, ext_list in ext.items():
if tmp[-1] in ext_list:
return filetype
return "unknown"
success = True
def check_file(path):
global success
N = 10
......@@ -127,9 +138,10 @@ def check_file(path):
if copyright_found and license_found:
print_ok("OK")
for root, dirs, files in os.walk(root_path):
print(f"Entering {root}")
hidden = [d for d in dirs if d.startswith('.')] + [f for f in files if f.startswith('.')]
hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
all_excludes = excludes + hidden
to_remove = []
for d in dirs:
......
......@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve()
from setuptools.command.build_ext import build_ext as BuildExtension
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import('pybind11')
install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension
......@@ -86,34 +87,45 @@ if __name__ == "__main__":
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
if "jax" in frameworks:
from build_tools.jax import setup_jax_extension
ext_modules.append(
setup_jax_extension(
"transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension
ext_modules.append(
setup_paddle_extension(
"transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine"))
current_file_path / "transformer_engine",
)
)
# Configure package
setuptools.setup(
name="transformer_engine",
version=__version__,
packages=setuptools.find_packages(
include=["transformer_engine",
"transformer_engine.*",
"transformer_engine/build_tools"],
include=[
"transformer_engine",
"transformer_engine.*",
"transformer_engine/build_tools",
],
),
extras_require={
"test": test_requires,
......@@ -125,5 +137,5 @@ if __name__ == "__main__":
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=True,
package_data={"": ["VERSION.txt"]}
package_data={"": ["VERSION.txt"]},
)
......@@ -132,7 +132,7 @@ void compute_bwd_ref(
for (int b = 0; b < batches; ++b) {
for (int h = 0; h < heads; ++h) {
size_t offset = b * batch_size + h * head_size;
compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
buff + offset, scaling_factor, batches, heads, rows, cols);
}
}
......
......@@ -6,7 +6,7 @@ import jax
import pytest
@pytest.fixture(autouse=True, scope='function')
@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
......
......@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')])
configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')])
configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
if is_devices_enough(4):
TP_size = 2
DP_size = 2
configs.append(
[4, (DP_size, TP_size), ('dp', 'tp'),
MeshResource(dp_resource='dp', tp_resource='tp')])
[4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
)
return configs
......@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
bytes_count = 0
def get_bytes_per_txt(t):
'''
"""
The pattern of t would be like:
'f32[]',
'(f32[1024]{0}',
......@@ -54,24 +54,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
'f8E4M3FN[1024]{0}',
'i32[1024]{0}',
'bf16[1024,1024]{0}'
'''
match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t)
"""
match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
_, bits_of_type, shape = match.groups()
bytes_of_type = int(bits_of_type) // 8
if shape == '':
if shape == "":
num_of_elements = 1
else:
num_of_elements = reduce(operator.mul, map(int, shape.split(',')))
num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
return bytes_of_type * num_of_elements
# ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
if '(' in hlo_text[2]:
if "(" in hlo_text[2]:
for txt in hlo_text[2:]:
bytes_count += get_bytes_per_txt(txt)
if ')' in txt:
if ")" in txt:
break
else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
bytes_count = get_bytes_per_txt(hlo_text[2])
return bytes_count
......@@ -91,21 +91,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
return result
target_result = count_collectives(target_splitted_hlo)
assert target_result == coll_count_ref, \
f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(target_func,
ref_func,
inputs,
coll_count_ref,
*,
grad_args=None,
metric_fwd_dtype=None,
metric_bwd_dtype=None,
in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED,
**kwargs):
assert (
target_result == coll_count_ref
), f"Expected collective count is {coll_count_ref}, but got {target_result}."
def compare_ops(
target_func,
ref_func,
inputs,
coll_count_ref,
*,
grad_args=None,
metric_fwd_dtype=None,
metric_bwd_dtype=None,
in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED,
**kwargs,
):
assert len(inputs) >= 1
if metric_fwd_dtype is None:
......
......@@ -15,24 +15,10 @@ from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose
from transformer_engine.jax.dot import (
type_safe_dot_general,
dequantize,
quantize
)
from transformer_engine.jax.fp8 import (
FP8MetaPackage,
FP8Helper,
is_fp8_available
)
from transformer_engine.jax.layernorm import (
layernorm,
layernorm_fp8_dot
)
from transformer_engine.jax.layernorm_mlp import (
activation_lu,
fused_layernorm_fp8_mlp
)
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax import cpp_extensions as tex
GEMM_CASES = [
......@@ -50,11 +36,11 @@ is_fp8_supported, reason = is_fp8_available()
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == 'quick_gelu':
if fn_or_string == "quick_gelu":
return lambda x: nn.gelu(x, approximate=True)
if fn_or_string == 'squared_relu':
if fn_or_string == "squared_relu":
return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
......@@ -93,7 +79,7 @@ class TestFP8Dot:
assert_allclose(z, x, dtype=jnp.float8_e4m3fn)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_forward_bf16(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
......@@ -106,7 +92,7 @@ class TestFP8Dot:
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_forward_fp8_randint(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
......@@ -135,7 +121,7 @@ class TestFP8Dot:
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_grad_bf16(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
......@@ -161,7 +147,7 @@ class TestFP8Dot:
assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_grad_fp8_dot(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
......@@ -192,33 +178,38 @@ class TestFP8Dot:
ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_b_grad, amax_list,
scale_list) = value_n_grad_primitive_func(a, b, amax_list, scale_list)
primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
value_n_grad_primitive_func(a, b, amax_list, scale_list)
)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', [(256, 128, 512),
(16384, 1024, 2816),
(16384, 2816, 1024),
(16384, 1024, 1024)])
@pytest.mark.parametrize('activation_type', [('gelu', ),
('gelu', 'linear'),
('silu', ),
('silu', 'linear'),
('relu',),
('relu', 'linear'),
('quick_gelu',),
('quick_gelu', 'linear'),
('squared_relu',),
('squared_relu', 'linear')])
@pytest.mark.parametrize('use_bias', [True, False])
def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, activation_type: Sequence[Union[str,
Callable]],
use_bias: bool):
""" N/a """
@pytest.mark.parametrize(
"m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
)
@pytest.mark.parametrize(
"activation_type",
[
("gelu",),
("gelu", "linear"),
("silu",),
("silu", "linear"),
("relu",),
("relu", "linear"),
("quick_gelu",),
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
],
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_grad_fused_layernorm_fp8_mlp(
self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
):
"""N/a"""
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
......@@ -233,8 +224,9 @@ class TestFP8Dot:
b1 = None
b2 = None
def primitive_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1,
scale_list_2):
def primitive_func(
x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = ((x * y) + w) * z + v
......@@ -255,18 +247,31 @@ class TestFP8Dot:
scale_list_2[2],
)
return jnp.mean(
fused_layernorm_fp8_mlp(x,
ln_s,
None, [y, z], [w, v], [fp8_meta_pkg_1, fp8_meta_pkg_2],
"rmsnorm",
activation_type=activation_type,
use_bias=use_bias))
def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
amax_list_1: List[jnp.ndarray], amax_list_2: List[jnp.ndarray],
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray]) -> jnp.ndarray:
fused_layernorm_fp8_mlp(
x,
ln_s,
None,
[y, z],
[w, v],
[fp8_meta_pkg_1, fp8_meta_pkg_2],
"rmsnorm",
activation_type=activation_type,
use_bias=use_bias,
)
)
def layernorm_fp8_mlp_ref(
x: jnp.ndarray,
ln_scale: jnp.ndarray,
kernel_1: jnp.ndarray,
kernel_2: jnp.ndarray,
bias_1: jnp.ndarray,
bias_2: jnp.ndarray,
amax_list_1: List[jnp.ndarray],
amax_list_2: List[jnp.ndarray],
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
......@@ -315,11 +320,14 @@ class TestFP8Dot:
def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
return jnp.mean(
layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1,
scale_list_2))
layernorm_fp8_mlp_ref(
x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
)
)
value_n_grad_primitive_func = jit(
value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
)
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
_, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
......@@ -339,40 +347,87 @@ class TestFP8Dot:
# Convert str to index as str is not a valid type for JAX JIT
for _ in range(3):
ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
ref_amax_list_1, ref_amax_list_2, ref_scale_list_1,
ref_scale_list_2) = value_n_grad_ref_func(a, s, k1, k2, b1, b2,
ref_amax_list_1, ref_amax_list_2,
ref_scale_list_1, ref_scale_list_2)
ref_out, (
ref_a_grad,
ref_s_grad,
ref_k1_grad,
ref_k2_grad,
ref_b1_grad,
ref_b2_grad,
ref_amax_list_1,
ref_amax_list_2,
ref_scale_list_1,
ref_scale_list_2,
) = value_n_grad_ref_func(
a,
s,
k1,
k2,
b1,
b2,
ref_amax_list_1,
ref_amax_list_2,
ref_scale_list_1,
ref_scale_list_2,
)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
primitive_k2_grad, primitive_b1_grad, primitive_b2_grad,
primitive_amax_list_1, primitive_amax_list_2, primitive_scale_list_1,
primitive_scale_list_2) = value_n_grad_primitive_func(
a, s, k1, k2, b1, b2, primitive_amax_list_1, primitive_amax_list_2,
primitive_scale_list_1, primitive_scale_list_2)
primitive_out, (
primitive_a_grad,
primitive_s_grad,
primitive_k1_grad,
primitive_k2_grad,
primitive_b1_grad,
primitive_b2_grad,
primitive_amax_list_1,
primitive_amax_list_2,
primitive_scale_list_1,
primitive_scale_list_2,
) = value_n_grad_primitive_func(
a,
s,
k1,
k2,
b1,
b2,
primitive_amax_list_1,
primitive_amax_list_2,
primitive_scale_list_1,
primitive_scale_list_2,
)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(
jnp.asarray(primitive_a_grad, np.float32),
jnp.asarray(ref_a_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
jnp.asarray(primitive_k1_grad, np.float32),
jnp.asarray(ref_k1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
jnp.asarray(primitive_s_grad, np.float32),
jnp.asarray(ref_s_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
jnp.asarray(primitive_k2_grad, np.float32),
jnp.asarray(ref_k2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
if use_bias:
assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(
jnp.asarray(primitive_b2_grad, np.float32),
jnp.asarray(ref_b2_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
jnp.asarray(primitive_b1_grad, np.float32),
jnp.asarray(ref_b1_grad, np.float32),
dtype=FP8Helper.BWD_DTYPE,
)
@pytest.fixture(name="random_inputs")
......@@ -402,17 +457,22 @@ class TestActivationLu:
def primitive_func(self, inputs):
return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))
@pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'),
('silu',),
('silu', 'linear'),
('relu',),
('relu', 'linear'),
('quick_gelu',),
('quick_gelu', 'linear'),
('squared_relu',),
('squared_relu', 'linear') ])
@pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize(
"activation_type",
[
("gelu",),
("gelu", "linear"),
("silu",),
("silu", "linear"),
("relu",),
("relu", "linear"),
("quick_gelu",),
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
],
)
def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
......@@ -441,23 +501,34 @@ class TestActivationLuFP8(TestActivationLu):
return output
def _prim_func_fwd(x, _x_t, _dbias, _amax):
activation_lu_out, _ = tex.act_lu_fp8(x, amax, scale, scale_inv,
FP8Helper.FWD_DTYPE, activation_type)
activation_lu_out, _ = tex.act_lu_fp8(
x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
)
activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
ctx = (x)
ctx = x
return activation_lu_out, ctx
def _prim_func_bwd(ctx, g):
x = ctx
if len(self.activation_type) > 1: #gated, no bias
dactivation_lu, dactivation_lu_trans, amax_out = \
tex.dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
FP8Helper.BWD_DTYPE, -1, activation_type)
if len(self.activation_type) > 1: # gated, no bias
dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
)
dbias = jnp.empty(x.shape[-1], x.dtype)
else: #not gated, with bias
dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
tex.dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE,
-1, -2, self.activation_type)
else: # not gated, with bias
dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
tex.dact_lu_dbias_cast_transpose(
g,
x,
amax,
scale,
scale_inv,
FP8Helper.BWD_DTYPE,
-1,
-2,
self.activation_type,
)
)
dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
......@@ -468,23 +539,28 @@ class TestActivationLuFP8(TestActivationLu):
dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
amax_no_use = jnp.zeros(1, jnp.float32)
value_n_grad_primitive_func = value_and_grad(lambda a, b, c, d:
jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3))
value_n_grad_primitive_func = value_and_grad(
lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
)
return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize('activation_type', [('gelu',),
('gelu', 'linear'),
('silu',),
('silu', 'linear'),
('relu',),
('relu', 'linear'),
('quick_gelu',),
('quick_gelu', 'linear'),
('squared_relu',),
('squared_relu', 'linear') ])
@pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
@pytest.mark.parametrize(
"activation_type",
[
("gelu",),
("gelu", "linear"),
("silu",),
("silu", "linear"),
("relu",),
("relu", "linear"),
("quick_gelu",),
("quick_gelu", "linear"),
("squared_relu",),
("squared_relu", "linear"),
],
)
def test_activation_lu(self, random_inputs, activation_type):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
......@@ -500,12 +576,14 @@ class TestActivationLuFP8(TestActivationLu):
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
if 'linear' not in activation_type:
if "linear" not in activation_type:
assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(prim_grad_trans,
jnp.transpose(ref_grad, self.transpose_indices),
dtype=FP8Helper.BWD_DTYPE)
assert_allclose(
prim_grad_trans,
jnp.transpose(ref_grad, self.transpose_indices),
dtype=FP8Helper.BWD_DTYPE,
)
class TestNorm:
......@@ -536,34 +614,38 @@ class TestNorm:
"""
x_ = jnp.asarray(x, jnp.float32)
if bias is None:
mean = 0.
mean = 0.0
else:
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
scale += 1.
scale += 1.0
if bias is None:
bias = 0.
bias = 0.0
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
@pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
@pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon,
dtype):
@pytest.mark.parametrize("n, hidden", LN_CASES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
def test_layernorm_forward_backward(
self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
):
"""
Test transformer_engine.jax.layernorm.layernorm
"""
expect_assert = False
if ln_type == 'rmsnorm' and zero_centered_gamma:
if ln_type == "rmsnorm" and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True
with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
) if expect_assert else nullcontext():
with (
pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
if expect_assert
else nullcontext()
):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
......@@ -571,7 +653,7 @@ class TestNorm:
gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
gamma = jnp.asarray(gamma, dtype)
if ln_type == 'layernorm':
if ln_type == "layernorm":
beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
beta = jnp.asarray(beta, dtype)
else:
......@@ -585,19 +667,27 @@ class TestNorm:
jitted_primitive = jit(
value_and_grad(
lambda x, gamma, beta: compute_loss(
layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)),
(0, 1, 2)))
layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
),
(0, 1, 2),
)
)
jitted_reference = jit(
value_and_grad(
lambda x, gamma, beta: compute_loss(
self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)),
(0, 1, 2)))
self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
),
(0, 1, 2),
)
)
primitive_out, (primitive_dx, primitive_dgamma,
primitive_dbeta) = jitted_primitive(x, gamma, beta)
reference_out, (reference_dx, reference_dgamma,
reference_dbeta) = jitted_reference(x, gamma, beta)
primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
x, gamma, beta
)
reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
x, gamma, beta
)
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
......@@ -606,21 +696,24 @@ class TestNorm:
assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
@pytest.mark.parametrize('zero_centered_gamma', [True, False])
@pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
@pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("zero_centered_gamma", [True, False])
@pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
"""
Test transformer_engine.jax.layernorm.layernorm_fp8_dot
"""
expect_assert = False
if ln_type == 'rmsnorm' and zero_centered_gamma:
if ln_type == "rmsnorm" and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True
with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
) if expect_assert else nullcontext():
with (
pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
if expect_assert
else nullcontext()
):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
......@@ -628,7 +721,7 @@ class TestNorm:
b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
if ln_type == 'layernorm':
if ln_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
else:
beta = None
......@@ -644,8 +737,9 @@ class TestNorm:
amax_list_1[2],
scale_list_1[2],
)
primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type,
zero_centered_gamma)
primitive_out = layernorm_fp8_dot(
x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
)
return jnp.mean(primitive_out)
def ref_func(x, y, gamma, beta, zero_centered_gamma):
......@@ -655,14 +749,19 @@ class TestNorm:
value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad,
ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad,
primitive_beta_grad, amax_list_1,
scale_list_1) = value_n_grad_primitive_func(
a, b, gamma, beta, amax_list_1, scale_list_1)
primitive_out, (
primitive_a_grad,
primitive_b_grad,
primitive_gamma_grad,
primitive_beta_grad,
amax_list_1,
scale_list_1,
) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
......
......@@ -9,16 +9,8 @@ import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
from jax.sharding import (
Mesh,
NamedSharding,
PartitionSpec
)
from distributed_test_base import (
generate_configs,
generate_collectives_count,
compare_ops
)
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
......@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import (
fused_attn_kvpacked,
AttnBiasType,
AttnMaskType,
QKVLayout
QKVLayout,
)
......@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSelfAttn:
def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape,
dtype):
def generate_collectives_count_ref(
self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
......@@ -46,7 +39,7 @@ class TestDistributedSelfAttn:
idx = mesh_axes.index(mesh_resource.tp_resource)
tp_size = mesh_shape[idx]
all_reduce_loss_bytes = 4 # 1 * FP32
all_reduce_loss_bytes = 4 # 1 * FP32
bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
# for loss and dbias
......@@ -57,8 +50,11 @@ class TestDistributedSelfAttn:
qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \
if with_bias else None
bias = (
random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
if with_bias
else None
)
mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK:
......@@ -66,47 +62,76 @@ class TestDistributedSelfAttn:
elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen)
qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
None)
bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \
if with_bias else None
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
if attn_mask_type != AttnMaskType.NO_MASK else None
qkv_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
)
bias_pspec = (
PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
)
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize(
'attn_bias_type',
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_bias_type, attn_mask_type, dtype):
"attn_bias_type",
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
)
@pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
attn_bias_type,
attn_mask_type,
dtype,
):
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
_, seqlen, _, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
attn_mask_type, dropout_prob, num_head, num_head,
seqlen, seqlen, hidden):
if not is_fused_attn_kernel_available(
dtype,
dtype,
QKVLayout.BS3HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found")
def target_func(qkv, bias, mask):
return jnp.mean(
fused_attn_qkvpacked(qkv,
bias,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training))
fused_attn_qkvpacked(
qkv,
bias,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
)
)
def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
......@@ -114,52 +139,59 @@ class TestDistributedSelfAttn:
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=bias,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32)
output = dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, with_bias,
attn_mask_type, dtype)
collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes,
mesh_resource, with_bias,
data_shape, dtype)
(qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, with_bias, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \
if bias is not None else bias
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
if mask is not None else mask
bias_ = (
jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
)
mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
grad_args = (0, 1) if with_bias else (0,)
out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)
compare_ops(target_func,
ref_func, [qkv_, bias_, mask_],
collective_count_ref,
grad_args=grad_args,
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings))
compare_ops(
target_func,
ref_func,
[qkv_, bias_, mask_],
collective_count_ref,
grad_args=grad_args,
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
out_shardings=(None, out_grad_shardings),
)
class TestDistributedCrossAttn:
def generate_collectives_count_ref(self):
# for loss
all_reduce_loss_bytes = 4 # 1 * FP32
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
......@@ -176,20 +208,26 @@ class TestDistributedCrossAttn:
q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
None)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
if attn_mask_type != AttnMaskType.NO_MASK else None
kv_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
)
mask_pspec = (
PartitionSpec(mesh_resource.dp_resource, None, None, None)
if attn_mask_type != AttnMaskType.NO_MASK
else None
)
return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize('attn_mask_type',
[AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
@pytest.mark.parametrize('dtype', DTYPES)
def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
attn_mask_type, dtype):
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
@pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
......@@ -197,23 +235,36 @@ class TestDistributedCrossAttn:
_, seqlen, num_head, hidden = data_shape
if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
attn_mask_type, dropout_prob, num_head, num_head,
seqlen, seqlen, hidden):
if not is_fused_attn_kernel_available(
dtype,
dtype,
QKVLayout.BSHD_BS2HD,
attn_bias_type,
attn_mask_type,
dropout_prob,
num_head,
num_head,
seqlen,
seqlen,
hidden,
):
pytest.skip(f"No FusedAttn backwend found")
def target_func(q, kv, mask):
return jnp.mean(
fused_attn_kvpacked(q,
kv,
None,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training))
fused_attn_kvpacked(
q,
kv,
None,
mask,
None,
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
scaling_factor=scaling_factor,
dropout_probability=dropout_prob,
is_training=is_training,
)
)
def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3)
......@@ -221,34 +272,41 @@ class TestDistributedCrossAttn:
key = jnp.squeeze(key)
value = jnp.squeeze(value)
output = dot_product_attention(query,
key,
value,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32)
output = dot_product_attention(
query,
key,
value,
bias=None,
mask=mask,
deterministic=is_training,
dropout_rate=dropout_prob,
dropout_rng=None,
dtype=jnp.float32,
)
return jnp.mean(output).astype(dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype)
(q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, attn_mask_type, dtype
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
if mask is not None else mask
compare_ops(target_func,
ref_func, [q_, kv_, mask_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)))
mask_ = (
jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
)
compare_ops(
target_func,
ref_func,
[q_, kv_, mask_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(q_pspec, kv_pspec, mask_pspec),
out_shardings=(None, (q_pspec, kv_pspec)),
)
......@@ -35,31 +35,44 @@ class TestDistributedLayernorm:
else:
raise NotImplementedError
g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
g_pspec = b_pspec = (
PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
)
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None
assert ln_type in ['layernorm', 'rmsnorm']
all_reduce_loss_bytes = 4 # 1 * FP32
assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
weight_count = 2 if ln_type == 'layernorm' else 1
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled),
allgather=0,
other=0)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
@pytest.mark.parametrize('shard_weights', [False, True])
def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
zero_centered_gamma, shard_weights):
weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes = (
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("shard_weights", [False, True])
def test_layernorm(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
dtype,
zero_centered_gamma,
shard_weights,
):
epsilon = 1e-6
ln_type = 'layernorm'
ln_type = "layernorm"
def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
......@@ -75,10 +88,12 @@ class TestDistributedLayernorm:
output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(output)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
(x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
data_shape, mesh_resource, dtype, shard_weights
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
......@@ -88,20 +103,25 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
compare_ops(
target_func,
ref_func,
[x_, gamma_, beta_],
collective_count_ref,
grad_args=(0, 1, 2),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
)
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here.
if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err):
if (
g_pspec[-1] is None and b_pspec[-1] is None
) or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
......@@ -110,13 +130,15 @@ class TestDistributedLayernorm:
"unsupported sharding of gamma and/or beta"
)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('shard_weights', [False, True])
def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights):
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("shard_weights", [False, True])
def test_rmsnorm(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
):
epsilon = 1e-6
ln_type = 'rmsnorm'
ln_type = "rmsnorm"
def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
......@@ -128,10 +150,12 @@ class TestDistributedLayernorm:
output = y * gamma
return jnp.mean(output)
(x, gamma, _), (x_pspec, g_pspec, _) = \
self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
data_shape, dtype)
(x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
data_shape, mesh_resource, dtype, shard_weights
)
collective_count_ref = self.generate_collectives_count_ref(
mesh_resource, ln_type, data_shape, dtype
)
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
......@@ -140,14 +164,17 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)))
compare_ops(
target_func,
ref_func,
[x_, gamma_],
collective_count_ref,
grad_args=(0, 1),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)),
)
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
......
......@@ -15,22 +15,23 @@ from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import (
HIDDEN_AXES, HIDDEN_TP_AXES,
HIDDEN_AXES,
HIDDEN_TP_AXES,
BATCH_AXES,
SEQLEN_TP_AXES, SEQLEN_AXES,
W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
SEQLEN_TP_AXES,
SEQLEN_AXES,
W_NO_SHARD_AXES,
W_FSDP_AXES,
W_TP_AXES,
W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
from utils import (
assert_allclose,
assert_tree_like_allclose,
is_devices_enough
)
from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in]
INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
......@@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs():
configs = []
if is_devices_enough(2):
configs.append(
[2, (1, 2), ('fsdp', 'tp'),
MeshResource(fsdp_resource='fsdp', tp_resource='tp')])
[2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ('fsdp', 'tp'),
MeshResource(fsdp_resource='fsdp', tp_resource='tp')])
[4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
)
return configs
......@@ -64,10 +65,12 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE),
dtype) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2],
(INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE)
k1 = jax.random.normal(
subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
) / jnp.sqrt(hidden_in)
k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
INTERMEDIATE
)
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
......@@ -90,15 +93,27 @@ class TestDistributedLayernormMLP:
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
layernorm_type: str = "rmsnorm",
activation_type: Sequence[Union[str, Callable]] = ('gelu',),
activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True,
multi_gpus: bool = False,
) -> jnp.ndarray:
fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], amax_list_1[1],
scale_list_1[1], amax_list_1[2], scale_list_1[2])
fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1],
scale_list_2[1], amax_list_2[2], scale_list_2[2])
fp8_meta_pkg1 = FP8MetaPackage(
amax_list_1[0],
scale_list_1[0],
amax_list_1[1],
scale_list_1[1],
amax_list_1[2],
scale_list_1[2],
)
fp8_meta_pkg2 = FP8MetaPackage(
amax_list_2[0],
scale_list_2[0],
amax_list_2[1],
scale_list_2[1],
amax_list_2[2],
scale_list_2[2],
)
if multi_gpus:
layernorm_input_axes = LAYERNORM_INPUT_AXES
......@@ -111,60 +126,68 @@ class TestDistributedLayernormMLP:
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean(
fused_layernorm_fp8_mlp(x,
ln_scale,
None, [kernel_1, kernel_2], [bias_1, bias_2],
[fp8_meta_pkg1, fp8_meta_pkg2],
layernorm_type,
layernorm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes,
activation_type=activation_type,
use_bias=use_bias))
fused_layernorm_fp8_mlp(
x,
ln_scale,
None,
[kernel_1, kernel_2],
[bias_1, bias_2],
[fp8_meta_pkg1, fp8_meta_pkg2],
layernorm_type,
layernorm_input_axes=layernorm_input_axes,
dot_1_input_axes=dot_1_input_axes,
dot_2_input_axes=dot_2_input_axes,
activation_type=activation_type,
use_bias=use_bias,
)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('use_bias', [True, False])
def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape,
dtype):
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_fp8_mlp_primitive(
self, mesh_config, activation_type, use_bias, input_shape, dtype
):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
layernorm_type = 'rmsnorm'
layernorm_type = "rmsnorm"
fp8_amax_list_1 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
]
fp8_amax_list_2 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
]
fp8_scale_list_1 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32)
jnp.ones((1,), jnp.float32),
]
fp8_scale_list_2 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32)
jnp.ones((1,), jnp.float32),
]
inputs = [x, gamma, k1, k2, b1, b2] = \
self.generate_inputs(input_shape, activation_type, use_bias, dtype)
inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
input_shape, activation_type, use_bias, dtype
)
inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
static_inputs = [layernorm_type, activation_type, use_bias]
value_and_grad_func = jax.value_and_grad(self.layernorm_fp8_mlp_prim_func,
argnums=range(len(inputs)))
value_and_grad_func = jax.value_and_grad(
self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
)
# Single GPU
single_jitter = jax.jit(value_and_grad_func,
static_argnums=range(len(inputs),
len(static_inputs) + len(inputs)))
single_jitter = jax.jit(
value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
)
with fp8_autocast(enabled=True):
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
......@@ -172,12 +195,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
k1_sharding = NamedSharding(mesh, PartitionSpec('fsdp', None, 'tp'))
k2_sharding = NamedSharding(mesh, PartitionSpec('tp', 'fsdp'))
k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
b1_sharding = NamedSharding(mesh, PartitionSpec(None, 'tp'))
b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
......@@ -186,17 +209,29 @@ class TestDistributedLayernormMLP:
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None,
None, None)
out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None,
None, None, None))
multi_jitter = jax.jit(value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(len(multi_inputs),
len(static_inputs) + len(multi_inputs) +
1)) # +1 for multi_gpus
in_shardings = (
None,
None,
k1_sharding,
k2_sharding,
b1_sharding,
None,
None,
None,
None,
None,
)
out_shardings = (
None,
(None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
)
multi_jitter = jax.jit(
value_and_grad_func,
in_shardings=in_shardings,
out_shardings=out_shardings,
static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
......@@ -206,97 +241,96 @@ class TestDistributedLayernormMLP:
if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
assert_allclose(m_grad,
s_grad,
dtype=dtype,
err_msg=f'multi_grads[{i}] is not close')
assert_allclose(
m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
)
else:
assert_allclose(multi_grads[i],
single_grads[i],
dtype=dtype,
err_msg=f'multi_grads[{i}] is not close')
def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype,
use_fp8):
assert_allclose(
multi_grads[i],
single_grads[i],
dtype=dtype,
err_msg=f"multi_grads[{i}] is not close",
)
def _test_layernorm_mlp(
self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
):
batch, seqlen, hidden_in = input_shape
layernorm_type = 'rmsnorm'
layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
init_rngs = {'params': subkeys[1]}
init_rngs = {"params": subkeys[1]}
# Single GPUs
with fp8_autocast(enabled=use_fp8):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
dtype=dtype,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single,
x,
deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
params_single, x, deterministic=True
)
# Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
dtype=dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
name='mlp')
ln_mlp_sharded = LayerNormMLP(
layernorm_type=layernorm_type,
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
dtype=dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
use_bias=use_bias,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
layernorm_input_axes=LAYERNORM_INPUT_AXES,
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
name="mlp",
)
params_sharded = ln_mlp_sharded.init(init_rngs, x)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded,
x,
deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
params_sharded, x, deterministic=True
)
# Make sure params values are the same
assert_tree_like_allclose(params_sharded['params'], params_single['params'])
assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('use_bias', [True, False])
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
self._test_layernorm_mlp(mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=False)
self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
@pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')])
@pytest.mark.parametrize('use_bias', [True, False])
@pytest.mark.parametrize('input_shape', INPUT_SHAPE)
@pytest.mark.parametrize('dtype', DTYPES)
def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape,
dtype):
self._test_layernorm_mlp(mesh_config,
activation_type,
use_bias,
input_shape,
dtype,
use_fp8=True)
@pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("input_shape", INPUT_SHAPE)
@pytest.mark.parametrize("dtype", DTYPES)
def test_layernorm_fp8_mlp_layer(
self, mesh_config, activation_type, use_bias, input_shape, dtype
):
self._test_layernorm_mlp(
mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
)
......@@ -25,7 +25,7 @@ class TestDistributedSoftmax:
def generate_collectives_count_ref(self):
# for loss
all_reduce_loss_bytes = 4 # 1 * FP32
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
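# The only collective expected here is a single all-reduce of the scalar loss
# produced by jnp.mean in target_func below: 1 FP32 value * 4 bytes. No
# all-gather or other collectives should be needed when the softmax input is
# sharded correctly.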
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
......@@ -38,49 +38,65 @@ class TestDistributedSoftmax:
mask = make_self_mask(batch, sqelen)
if not bad_sharding:
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource,
None, None)
x_pspec = PartitionSpec(
mesh_resource.dp_resource, mesh_resource.tp_resource, None, None
)
else:
x_pspec = PartitionSpec(mesh_resource.dp_resource, None,
None, mesh_resource.tp_resource)
x_pspec = PartitionSpec(
mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec)
@staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
@staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
bias = None
if mask is not None:
bias = jax.lax.select(mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.).astype(dtype))
bias = jax.lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.0).astype(dtype),
)
if bias is not None:
x = x + bias.astype(dtype)
output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output)
@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
'softmax_type',
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
@pytest.mark.parametrize('scale_factor', [1.0, 3.0])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('bad_sharding', [False, True])
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
softmax_type, scale_factor, dtype, bad_sharding):
target_func = partial(self.target_func,
scale_factor=scale_factor,
softmax_type=softmax_type)
"softmax_type",
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
)
@pytest.mark.parametrize("scale_factor", [1.0, 3.0])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("bad_sharding", [False, True])
def test_softmax(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
data_shape,
softmax_type,
scale_factor,
dtype,
bad_sharding,
):
target_func = partial(
self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
(x, mask), (x_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype, bad_sharding)
(x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
data_shape, mesh_resource, softmax_type, dtype, bad_sharding
)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
......@@ -90,14 +106,17 @@ class TestDistributedSoftmax:
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, mask_],
collective_count_ref,
grad_args=(0,),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)))
compare_ops(
target_func,
ref_func,
[x_, mask_],
collective_count_ref,
grad_args=(0,),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)),
)
except AssertionError as err:
# Softmax should still produce the correct numerical result with
# bad sharding. However, the collective count may not be the same
......
......@@ -20,12 +20,14 @@ class TestLoRA:
out = jnp.einsum(pattern, x, la, lb)
return out * scale
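# Minimal illustrative sketch (hypothetical helper, not used by the tests above):
# how the first parametrized pattern "...h,hr,rk->...k" composes the two LoRA
# factors. x @ la projects the hidden axis h down to the low rank r, lb projects
# back up to k output features, and the result is then scaled (commonly by
# alpha / rank when alpha is given).
def _lora_pattern_sketch():
    import jax.numpy as jnp

    batch, hidden, rank, out_dim = 4, 1024, 16, 1024
    x = jnp.ones((batch, hidden))
    la = 0.01 * jnp.ones((hidden, rank))   # low-rank "down" factor
    lb = 0.01 * jnp.ones((rank, out_dim))  # low-rank "up" factor
    out = jnp.einsum("...h,hr,rk->...k", x, la, lb)
    assert out.shape == (batch, out_dim)
    return out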
@pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
@pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
@pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'),
((-1,), (3, 1024), '...h,hkr,krz->...kz')])
@pytest.mark.parametrize('rank', [32, 16])
@pytest.mark.parametrize('alpha', [None, 4, 8])
@pytest.mark.parametrize("shape", [(32, 1024), (32, 128, 1024)])
@pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
@pytest.mark.parametrize(
"axis_features_pattern",
[((-1,), (1024,), "...h,hr,rk->...k"), ((-1,), (3, 1024), "...h,hkr,krz->...kz")],
)
@pytest.mark.parametrize("rank", [32, 16])
@pytest.mark.parametrize("alpha", [None, 4, 8])
def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
axis, features, pattern = axis_features_pattern
axis = _normalize_axes(axis, len(shape))
......@@ -49,16 +51,20 @@ class TestLoRA:
assert_allclose(out_target, out_ref, dtype=dtype)
@pytest.mark.parametrize('scope_ref_assert',
[('none', LoRAScope(False, False, False), False),
('all', LoRAScope(True, True, True), False),
('qkv_proj', LoRAScope(True, False, False), False),
('output_proj', LoRAScope(False, True, False), False),
('mlp', LoRAScope(False, False, True), False),
('exclude_qkv_proj', LoRAScope(False, True, True), False),
('exclude_output_proj', LoRAScope(True, False, True), False),
('exclude_mlp', LoRAScope(True, True, False), False),
('messing_up', LoRAScope(), True)])
@pytest.mark.parametrize(
"scope_ref_assert",
[
("none", LoRAScope(False, False, False), False),
("all", LoRAScope(True, True, True), False),
("qkv_proj", LoRAScope(True, False, False), False),
("output_proj", LoRAScope(False, True, False), False),
("mlp", LoRAScope(False, False, True), False),
("exclude_qkv_proj", LoRAScope(False, True, True), False),
("exclude_output_proj", LoRAScope(True, False, True), False),
("exclude_mlp", LoRAScope(True, True, False), False),
("messing_up", LoRAScope(), True),
],
)
def test_lora_scope_generator(self, scope_ref_assert):
scope, reference, need_assert = scope_ref_assert
try:
......
......@@ -24,7 +24,7 @@ from transformer_engine.jax.attention import (
QKVLayout,
fused_attn_qkvpacked,
fused_attn_kvpacked,
fused_attn
fused_attn,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
......@@ -33,7 +33,7 @@ from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose
@pytest.fixture(autouse=True, scope='module')
@pytest.fixture(autouse=True, scope="module")
def init():
"""
WAR for CUDA uninitialization error
......@@ -43,10 +43,18 @@ def init():
yield
def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
bias: ArrayLike, mask: ArrayLike, deterministic: bool,
scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike,
dtype: DTypeLike) -> Array:
def general_dot_product_attention(
query: ArrayLike,
key: ArrayLike,
value: ArrayLike,
bias: ArrayLike,
mask: ArrayLike,
deterministic: bool,
scale_factor: float,
dropout_rate: float,
dropout_rng: ArrayLike,
dtype: DTypeLike,
) -> Array:
"""
Similar to flax.linen.dot_product_attention but with GQA support
"""
......@@ -59,7 +67,7 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
num_groups = h_q // h_kv
grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
# logits with shape (b, h_kv, num_groups, s_q, s_kv)
logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)
logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
if bias is not None:
# reshape logits without groups
......@@ -76,13 +84,13 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
softmax_out = jax.nn.softmax(logits).astype(dtype)
if not deterministic and dropout_rate > 0.:
if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
softmax_out = softmax_out * multiplier
context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
context = jnp.reshape(context, query.shape)
return context
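# Minimal illustrative sketch (hypothetical helper; values are arbitrary): how the
# GQA grouping above maps shapes, assuming h_q is a multiple of h_kv. Each of the
# h_kv key/value heads serves num_groups query heads through a single einsum.
def _gqa_grouping_shape_sketch():
    import jax.numpy as jnp

    b, s_q, s_kv, h_q, h_kv, d = 2, 4, 4, 8, 2, 16
    query = jnp.zeros((b, s_q, h_q, d))
    key = jnp.zeros((b, s_kv, h_kv, d))
    num_groups = h_q // h_kv
    grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
    # logits carry one axis per KV head and one per query group
    logits = jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
    assert logits.shape == (b, h_kv, num_groups, s_q, s_kv)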
......@@ -105,6 +113,7 @@ def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(inv_causal_mask, inv_padding_mask)
def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
"""
Create attention mask based on mask type. A `True` value in the mask means
......@@ -118,23 +127,26 @@ def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskT
mask = jnp.logical_not(inv_mask)
return mask
def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
JAX native dot product attention implementation
"""
attn_mask_type = kwargs['attn_mask_type']
attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type)
output = general_dot_product_attention(query,
key,
value,
bias=bias,
mask=mask,
deterministic=not kwargs['is_training'],
scale_factor=kwargs['scaling_factor'],
dropout_rate=kwargs['dropout_probability'],
dropout_rng=dropout_rng,
dtype=jnp.float32)
output = general_dot_product_attention(
query,
key,
value,
bias=bias,
mask=mask,
deterministic=not kwargs["is_training"],
scale_factor=kwargs["scaling_factor"],
dropout_rate=kwargs["dropout_probability"],
dropout_rng=dropout_rng,
dtype=jnp.float32,
)
return output.astype(query.dtype)
......@@ -142,10 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
"""
TE customcall dot product attention implementation
"""
attn_mask_type = kwargs['attn_mask_type']
attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type)
qkv_layout = kwargs.pop('qkv_layout')
qkv_layout = kwargs.pop("qkv_layout")
match qkv_layout:
case QKVLayout.BS3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
......@@ -154,11 +166,13 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3)
return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, **kwargs).astype(
query.dtype
)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)
return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(
query.dtype
)
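# Minimal illustrative sketch (hypothetical helper; values are arbitrary): the
# shape effect of the KV-packed branch above. expand_dims(axis=-3) inserts a
# packing axis before the head axis, and concatenating key and value along it
# yields one (b, s_kv, 2, h, d) tensor for fused_attn_kvpacked.
def _kv_packing_shape_sketch():
    import jax.numpy as jnp
    from functools import partial

    b, s_kv, h, d = 2, 8, 4, 16
    key = jnp.zeros((b, s_kv, h, d))
    value = jnp.zeros((b, s_kv, h, d))
    key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])  # (b, s_kv, 1, h, d)
    kv = jnp.concatenate((key, value), axis=-3)
    assert kv.shape == (b, s_kv, 2, h, d)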
class BiasShape(Enum):
......@@ -166,10 +180,10 @@ class BiasShape(Enum):
Enum class to represent the different bias shapes used in the fused attention.
"""
BIAS_1HSS = '1HSS'
BIAS_B1SS = 'B1SS'
BIAS_BHSS = 'BHSS'
BIAS_11SS = '11SS'
BIAS_1HSS = "1HSS"
BIAS_B1SS = "B1SS"
BIAS_BHSS = "BHSS"
BIAS_11SS = "11SS"
@dataclass
......@@ -177,6 +191,7 @@ class FusedAttnRunner:
"""
Fused attention runner
"""
batch_size: int
max_seqlen_q: int
max_seqlen_kv: int
......@@ -198,21 +213,33 @@ class FusedAttnRunner:
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
self.attn_bias_type.value, self.attn_mask_type.value,
self.dropout_prob, self.num_heads_q, self.num_heads_kv,
self.max_seqlen_q, self.max_seqlen_kv,
self.head_dim).get_fused_attn_backend()
self.backend = FusedAttnHelper(
self.dtype,
self.dtype,
self.qkv_layout.value,
self.attn_bias_type.value,
self.attn_mask_type.value,
self.dropout_prob,
self.num_heads_q,
self.num_heads_kv,
self.max_seqlen_q,
self.max_seqlen_kv,
self.head_dim,
).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
pytest.skip(
"B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
)
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
"the F16_arbitrary_seqlen backend.")
pytest.skip(
"B1SS, BHSS and 11SS bias shapes are only supported for "
"the F16_arbitrary_seqlen backend."
)
def _setup_inputs(self):
self._check_configs()
......@@ -235,24 +262,25 @@ class FusedAttnRunner:
else:
pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)
self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
if self.attn_bias_type != AttnBiasType.NO_BIAS:
if self.bias_shape == BiasShape.BIAS_1HSS:
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
else:
# [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
# an arbitrary mask where (True/False -> 0/-Inf)
cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
for i in range(1, len(seq_id)):
self.bias = \
self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
self.bias = self.bias.at[
:, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
].set(0.0)
else:
self.bias = None
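# Worked illustration of the workaround above (hypothetical values): with
# max_seqlen_q == max_seqlen_kv == 8 and seq_id == [2, 5], the loop zeroes
# bias[:, :, 2:5, 2:5] while every other position keeps cudnn_neg_inf, so the
# additive bias lets attention through only inside the zeroed blocks -- an
# arbitrary True/False mask expressed as 0/-Inf POST_SCALE_BIAS values.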
......@@ -271,7 +299,7 @@ class FusedAttnRunner:
self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1. / sqrt(self.head_dim)
self.scaling_factor = 1.0 / sqrt(self.head_dim)
def test_forward(self):
"""
......@@ -281,19 +309,19 @@ class FusedAttnRunner:
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': self.attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_prob,
'is_training': self.is_training,
'qkv_layout': self.qkv_layout,
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
"scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
}
# Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
if self.is_training and self.dropout_prob > 0.:
if self.is_training and self.dropout_prob > 0.0:
return
primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
......@@ -322,12 +350,12 @@ class FusedAttnRunner:
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = {
'attn_bias_type': self.attn_bias_type,
'attn_mask_type': self.attn_mask_type,
'scaling_factor': self.scaling_factor,
'dropout_probability': self.dropout_prob,
'is_training': self.is_training,
'qkv_layout': self.qkv_layout,
"attn_bias_type": self.attn_bias_type,
"attn_mask_type": self.attn_mask_type,
"scaling_factor": self.scaling_factor,
"dropout_probability": self.dropout_prob,
"is_training": self.is_training,
"qkv_layout": self.qkv_layout,
}
# We can compute dBias only for the [1, h, s, s] layout
......@@ -336,12 +364,18 @@ class FusedAttnRunner:
# Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
**kwargs), arg_nums))
lambda q, k, v, bias, *args: grad_func(
customcall_fused_dpa, q, k, v, bias, *args, **kwargs
),
arg_nums,
)
)
jitted_reference = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
arg_nums))
arg_nums,
)
)
primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args)
......@@ -350,9 +384,9 @@ class FusedAttnRunner:
if self.dropout_prob > 0.0:
return
assert_allclose(primitive_out.astype(jnp.float32),
reference_out.astype(jnp.float32),
dtype=self.dtype)
assert_allclose(
primitive_out.astype(jnp.float32), reference_out.astype(jnp.float32), dtype=self.dtype
)
def check_dqkv(primitive, reference, valid_len):
primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
......@@ -374,81 +408,158 @@ class FusedAttnRunner:
primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3])
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
self.valid_len_kv:]),
dtype=self.dtype)
assert_allclose(
primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
jnp.zeros_like(primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :]),
dtype=self.dtype,
)
# dbias padded part
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
dtype=self.dtype)
assert_allclose(
primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
reference_dbias[..., self.valid_len_q :, self.valid_len_kv :],
dtype=self.dtype,
)
# dbias valid part
assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
dtype=self.dtype)
@pytest.mark.parametrize('attn_bias_type, bias_shape', [
pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
])
@pytest.mark.parametrize('attn_mask_type', [
pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
])
@pytest.mark.parametrize('qkv_layout', [
pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
])
@pytest.mark.parametrize('dtype', [
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
])
@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
])
@pytest.mark.parametrize('dropout_prob', [
pytest.param(0.0, id="DROP_0.0"),
pytest.param(0.1, id="DROP_0.1"),
])
assert_allclose(
primitive_dbias[..., : self.valid_len_q, : self.valid_len_kv],
reference_dbias[..., : self.valid_len_q, : self.valid_len_kv],
dtype=self.dtype,
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
],
)
@pytest.mark.parametrize(
"qkv_layout",
[
pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d",
[
pytest.param(32, 128, 128, 16, 16, 64, id="32-128-128-16-16-64-SELF"),
pytest.param(4, 2048, 2048, 12, 12, 64, id="4-2048-2048-12-12-64-SELF"),
pytest.param(32, 512, 128, 16, 16, 64, id="32-512-128-16-16-64-CROSS"),
pytest.param(4, 2048, 1024, 12, 12, 64, id="4-2048-1048-12-12-64-CROSS"),
pytest.param(32, 128, 128, 16, 8, 64, id="32-128-128-16-8-64-GQA"),
pytest.param(4, 2048, 2048, 12, 6, 64, id="4-2048-2048-12-6-64-GQA"),
],
)
@pytest.mark.parametrize(
"dropout_prob",
[
pytest.param(0.0, id="DROP_0.0"),
pytest.param(0.1, id="DROP_0.1"),
],
)
class TestFusedAttn:
"""
Fused attention tester
"""
@staticmethod
@pytest.mark.parametrize('is_training', [
pytest.param(True, id='TRAINING'),
pytest.param(False, id='INFERENCE'),
])
def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, is_training, qkv_layout, bias_shape):
@pytest.mark.parametrize(
"is_training",
[
pytest.param(True, id="TRAINING"),
pytest.param(False, id="INFERENCE"),
],
)
def test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
):
"""
Test forward with parameterized configs
"""
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, is_training, qkv_layout, bias_shape)
runner = FusedAttnRunner(
b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
)
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
dtype, qkv_layout, bias_shape):
def test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
):
"""
Test backward with parameterized configs
"""
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, True, qkv_layout, bias_shape)
runner = FusedAttnRunner(
b,
s_q,
s_kv,
h_q,
h_kv,
d,
attn_bias_type,
attn_mask_type,
dropout_prob,
dtype,
True,
qkv_layout,
bias_shape,
)
runner.test_backward()
......@@ -27,21 +27,28 @@ class TestFP8Helper(unittest.TestCase):
fp8_format = FP8Format.E4M3
amax_history_len = 10
FP8Helper.initialize(margin=margin,
fp8_format=fp8_format,
amax_history_len=amax_history_len)
FP8Helper.initialize(
margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
)
self.assertEqual(
FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
f" but got {FP8Helper.MARGIN}.")
FP8Helper.MARGIN,
margin,
f"FP8Helper.MARGIN initialization failed, should be {margin}"
f" but got {FP8Helper.MARGIN}.",
)
self.assertEqual(
FP8Helper.FP8_FORMAT, fp8_format,
FP8Helper.FP8_FORMAT,
fp8_format,
f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
f" but got {FP8Helper.FP8_FORMAT}.")
f" but got {FP8Helper.FP8_FORMAT}.",
)
self.assertEqual(
FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
FP8Helper.AMAX_HISTORY_LEN,
amax_history_len,
f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
f" but got {FP8Helper.AMAX_HISTORY_LEN}.")
f" but got {FP8Helper.AMAX_HISTORY_LEN}.",
)
FP8Helper.finalize()
......@@ -77,7 +84,7 @@ class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self):
FP8Helper.finalize() # Ensure the testing is not affected by previous tests.
FP8Helper.finalize()  # Ensure the testing is not affected by previous tests.
self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
......@@ -102,21 +109,21 @@ class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self):
FP8Helper.finalize() # Ensure the testing is not affected by previous tests.
FP8Helper.finalize()  # Ensure the testing is not affected by previous tests.
self._check_defult_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
mesh_s = (
(MeshResource(None, None)),
(MeshResource('dp', None)),
(MeshResource(None, 'tp')),
(MeshResource('dp', 'tp')),
(MeshResource("dp", None)),
(MeshResource(None, "tp")),
(MeshResource("dp", "tp")),
)
# TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme
mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
with jax.sharding.Mesh(devices, ('dp', 'tp')):
with jax.sharding.Mesh(devices, ("dp", "tp")):
for sr in mesh_s:
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
self.assertTrue(FP8Helper.is_fp8_enabled())
......
......@@ -22,7 +22,7 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='function')
@pytest.fixture(autouse=True, scope="function")
def enable_fused_attn():
"""Enable fused attention"""
os.environ["NVTE_FUSED_ATTN"] = "1"
......@@ -30,9 +30,9 @@ def enable_fused_attn():
del os.environ["NVTE_FUSED_ATTN"]
DATA_SHAPE = [ # (batch, seqlen, emb_dim)
pytest.param((32, 128, 1024), id='32-128-1024'),
pytest.param((32, 512, 1024), id='32-512-1024'),
DATA_SHAPE = [ # (batch, seqlen, emb_dim)
pytest.param((32, 128, 1024), id="32-128-1024"),
pytest.param((32, 512, 1024), id="32-512-1024"),
]
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
......@@ -69,123 +69,138 @@ BASE_ATTRS = {
_KEY_OF_ATTENTION_DROPOUT: 0,
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_LAYERNORM_TYPE: "layernorm",
}
ATTRS = [{}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
}, {
_KEY_OF_ZERO_CENTERED_GAMMA: True,
_KEY_OF_LAYERNORM_EPS: 1e-2,
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_RESIDUAL_POST_LAYERNORM: True
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_OUTPUT_LAYERNORM: True
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_RESIDUAL_POST_LAYERNORM: True,
_KEY_OF_OUTPUT_LAYERNORM: True
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROP_PATH: 0.1
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_FUSE_QKV_PARAMS: False
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_MLP_ACTIVATIONS: ('gelu',),
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
}, {
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_MLP_ACTIVATIONS: (('silu',)),
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_NUM_GQA_GROUPS: 1,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_NUM_GQA_GROUPS: 2,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True,
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_HIDDEN_DROPOUT: 0.3,
_KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
}, {
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
}, {
_KEY_OF_ATTENTION_DROPOUT: 0.3,
}, {
_KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')),
}]
ATTRS = [
{},
{
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
},
{
_KEY_OF_ZERO_CENTERED_GAMMA: True,
_KEY_OF_LAYERNORM_EPS: 1e-2,
},
{_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
{_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
{
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_RESIDUAL_POST_LAYERNORM: True,
_KEY_OF_OUTPUT_LAYERNORM: True,
},
{_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
{_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
{
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
},
{
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
},
{
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_MLP_ACTIVATIONS: ("gelu",),
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
},
{
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
},
{
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_MLP_ACTIVATIONS: (("silu",)),
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_NUM_GQA_GROUPS: 1,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
{
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: "layernorm",
_KEY_OF_NUM_GQA_GROUPS: 2,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True,
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
},
{
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: "rmsnorm",
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_HIDDEN_DROPOUT: 0.3,
_KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
},
{
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_USE_BIAS: True,
},
{
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
},
{
_KEY_OF_ATTENTION_DROPOUT: 0.3,
},
{
_KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
},
]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
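# Each entry of ATTRS is BASE_ATTRS with a per-case override merged on top, so a
# key that appears in both takes the per-case value. Minimal illustration with
# hypothetical values:
#     base = {"layernorm_type": "layernorm", "use_bias": False}
#     case = {"layernorm_type": "rmsnorm"}
#     assert {**base, **case} == {"layernorm_type": "rmsnorm", "use_bias": False}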
class BaseRunner:
"""Base runner to define forward and backward tests"""
layer_type: TransformerLayerType = None
reference_layer: flax.linen.Module = None
transformations: Dict[str, str] = None
......@@ -194,24 +209,24 @@ class BaseRunner:
self.attrs = attrs
self._generate_test_rngs()
# Disable fused attention for attention dropout because of the different dropout implementations
if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'):
os.environ['NVTE_FUSED_ATTN'] = "0"
if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
os.environ["NVTE_FUSED_ATTN"] = "0"
def _generate_test_rngs(self):
root_rng = jax.random.PRNGKey(0)
params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng}
self.apply_rng = {'dropout': apply_dropout_rng}
self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
self.apply_rng = {"dropout": apply_dropout_rng}
def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
layer = layer_cls()
variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
others, params = flax.core.pop(variables, 'params')
others, params = flax.core.pop(variables, "params")
del variables
return layer, params, others
def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
variables = {'params': params, **others}
variables = {"params": params, **others}
output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
......@@ -259,15 +274,18 @@ class BaseRunner:
)
_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others
)
del tmp_grad, fp8_meta_grad
grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others,
ref_layer)
test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others,
test_layer)
ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
inputs, ref_masks, ref_params, ref_others, ref_layer
)
test_out, (test_dgrads, test_wgrads) = grad_fn(
inputs, test_masks, test_params, test_others, test_layer
)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol)
......@@ -278,19 +296,20 @@ class BaseRunner:
class EncoderRunner(BaseRunner):
"""Encoder runner implementations"""
layer_type = TransformerLayerType.ENCODER
reference_layer = RefEncoderLayer
transformations = {
'attention/qkv/scale': 'pre_attention_layer_norm/scale',
'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias',
'attention/query/scale': 'pre_attention_layer_norm/scale',
'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias',
'mlp/wi_kernel': 'mlp/wi/kernel',
'mlp/wi_bias': 'mlp/wi/bias',
'mlp/wo_kernel': 'mlp/wo/kernel',
'mlp/wo_bias': 'mlp/wo/bias',
'mlp/scale': 'pre_mlp_layer_norm/scale',
'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
"attention/qkv/scale": "pre_attention_layer_norm/scale",
"attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
"attention/query/scale": "pre_attention_layer_norm/scale",
"attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
"mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel",
"mlp/wo_bias": "mlp/wo/bias",
"mlp/scale": "pre_mlp_layer_norm/scale",
"mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
}
def generate_inputs(self, data_shape, dtype):
......@@ -307,13 +326,13 @@ class EncoderRunner(BaseRunner):
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['causal', 'padding_causal']:
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask
else:
mask = padded_mask
ref_masks = (1 - mask,)
test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
return inputs, (ref_masks, test_masks)
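# Minimal illustrative sketch (hypothetical helper; values are arbitrary): the mask
# convention used above. jnp.triu(ones, k=1) marks strictly-future positions with 1,
# i.e. masked out in the mask handed to the TE layer, while the reference layer
# receives the complement (1 - mask), where 1 means attendable.
def _causal_mask_convention_sketch():
    import jax.numpy as jnp

    seqlen = 4
    causal_mask = jnp.triu(jnp.ones((1, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
    # row i may not attend to columns j > i
    assert bool(causal_mask[0, 0, 0, 3] == 1) and bool(causal_mask[0, 0, 3, 0] == 0)
    ref_mask = 1 - causal_mask  # reference convention: 1 == keep
    assert bool(ref_mask[0, 0, 3, 0] == 1)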
......@@ -322,23 +341,24 @@ class DecoderRunner(BaseRunner):
"""
Decoder runner implementations
"""
layer_type = TransformerLayerType.DECODER
reference_layer = RefDecoderLayer
transformations = {
'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale',
'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale',
'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale',
'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
'self_attention/query/scale': 'pre_self_attention_layer_norm/scale',
'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
'mlp/wi_kernel': 'mlp/wi/kernel',
'mlp/wi_bias': 'mlp/wi/bias',
'mlp/wo_kernel': 'mlp/wo/kernel',
'mlp/wo_bias': 'mlp/wo/bias',
'mlp/scale': 'pre_mlp_layer_norm/scale',
'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
"encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
"encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
"encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
"self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
"self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"self_attention/query/scale": "pre_self_attention_layer_norm/scale",
"self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
"mlp/wi_kernel": "mlp/wi/kernel",
"mlp/wi_bias": "mlp/wi/bias",
"mlp/wo_kernel": "mlp/wo/kernel",
"mlp/wo_bias": "mlp/wo/bias",
"mlp/scale": "pre_mlp_layer_norm/scale",
"mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
}
def generate_inputs(self, data_shape, dtype):
......@@ -352,12 +372,14 @@ class DecoderRunner(BaseRunner):
data_rng = jax.random.PRNGKey(0)
data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
inputs = (jax.random.normal(data_rng_0, data_shape,
dtype), jax.random.normal(data_rng_1, data_shape, dtype))
inputs = (
jax.random.normal(data_rng_0, data_shape, dtype),
jax.random.normal(data_rng_1, data_shape, dtype),
)
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['causal', 'padding_causal']:
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
self_mask = causal_mask
else:
self_mask = padded_mask
......@@ -368,27 +390,28 @@ class DecoderRunner(BaseRunner):
return inputs, (ref_masks, test_masks)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', ATTRS)
class BaseTester():
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", ATTRS)
class BaseTester:
"""
Pytest interface to invoke the runner
"""
runner = BaseRunner
def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward"""
FP8Helper.finalize() # Ensure FP8 disabled.
FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
FP8Helper.finalize() # Ensure FP8 disabled.
FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test forward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format)
......@@ -396,7 +419,7 @@ class BaseTester():
FP8Helper.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test backward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format)
......@@ -408,6 +431,7 @@ class TestEncoderLayer(BaseTester):
"""
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
"""
runner = EncoderRunner
......@@ -415,4 +439,5 @@ class TestDecoderLayer(BaseTester):
"""
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
"""
runner = DecoderRunner
......@@ -37,13 +37,13 @@ from transformer_engine.jax.softmax import SoftmaxType
is_fp8_supported, reason = is_fp8_available()
DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H)
DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@pytest.fixture(autouse=True, scope='module')
@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
"""
Enable fused attn for hopper+ arch.
......@@ -58,19 +58,16 @@ def enable_fused_attn():
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, \
f"{key} not found in test dict {test_fd}"
assert isinstance(test_fd[key], type(ref_fd[key])), \
f"The data type is not match between ref and test " \
f" Dict on {key=}"
assert key in test_fd, f"{key} not found in test dict {test_fd}"
assert isinstance(
test_fd[key], type(ref_fd[key])
), f"The data type is not match between ref and test Dict on {key=}"
if isinstance(ref_fd[key], Dict):
compare_dict(ref_fd[key], test_fd[key], rtol, atol)
else:
assert_allclose(ref_fd[key],
test_fd[key],
rtol=rtol,
atol=atol,
err_msg=f"{key=} is not close")
assert_allclose(
ref_fd[key], test_fd[key], rtol=rtol, atol=atol, err_msg=f"{key=} is not close"
)
class TestLayer:
......@@ -105,9 +102,10 @@ class TestLayer:
lyr_name = self.get_layer_name()
if 'params' in flax_variables:
synced_praxis_variables['params'][lyr_name]['cld'] = \
flax.core.unfreeze(flax_variables['params'])
if "params" in flax_variables:
synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze(
flax_variables["params"]
)
return synced_praxis_variables, flax_variables
......@@ -116,23 +114,19 @@ class TestLayer:
lyr_name = self.get_layer_name()
if 'params' in synced_praxis_grads:
synced_praxis_grads['params'] = \
synced_praxis_grads['params'][lyr_name]['cld']
if "params" in synced_praxis_grads:
synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"]
if FP8Helper.is_fp8_enabled():
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld']
synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = synced_praxis_grads[
FP8Helper.FP8_COLLECTION_NAME
][lyr_name]["cld"]
return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)
def forward_backward_runner(self,
data_shape,
dtype,
praxis_p,
flax_cls,
rtol=1e-05,
atol=1e-08):
def forward_backward_runner(
self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08
):
init_key = jax.random.PRNGKey(seed=1234)
test_inputs = self.input_getter(data_shape, dtype)
......@@ -148,28 +142,33 @@ class TestLayer:
if "params_axes" in flax_variables:
flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled():
flax_variables, _ = flax.core.pop(flax_variables,
FP8Helper.FP8_COLLECTION_NAME + "_axes")
flax_variables, _ = flax.core.pop(
flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
)
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
iter_times = 5 if FP8Helper.is_fp8_enabled() else 1
for _ in range(iter_times):
praxis_loss, praxis_wgrads, praxis_dgrad = \
TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
flax_loss, flax_wgrads, flax_dgrad = \
TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)
praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
praxis_layer, praxis_variables, *test_inputs
)
flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
flax_layer, flax_variables, *test_inputs
)
if FP8Helper.is_fp8_enabled():
praxis_wgrads.pop('params')
praxis_wgrads.pop("params")
praxis_variables = update_collections(praxis_wgrads, praxis_variables)
flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params')
flax_wgrads, _ = flax.core.pop(flax_wgrads, "params")
flax_variables = update_collections(flax_wgrads, flax_variables)
praxis_loss, praxis_wgrads, praxis_dgrad = \
TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
flax_loss, flax_wgrads, flax_dgrad = \
TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)
praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
praxis_layer, praxis_variables, *test_inputs
)
flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
flax_layer, flax_variables, *test_inputs
)
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)
......@@ -179,18 +178,13 @@ class TestLayer:
class LayerNormAttr:
LN_TYPE = 'layernorm_type'
ZERO_CEN = 'zero_centered_gamma'
ATTRS = [{
LN_TYPE: "layernorm",
ZERO_CEN: False
}, {
LN_TYPE: "layernorm",
ZERO_CEN: True
}, {
LN_TYPE: "rmsnorm",
ZERO_CEN: False
}]
LN_TYPE = "layernorm_type"
ZERO_CEN = "zero_centered_gamma"
ATTRS = [
{LN_TYPE: "layernorm", ZERO_CEN: False},
{LN_TYPE: "layernorm", ZERO_CEN: True},
{LN_TYPE: "rmsnorm", ZERO_CEN: False},
]
class TestLayerNorm(TestLayer):
......@@ -200,7 +194,7 @@ class TestLayerNorm(TestLayer):
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
return 'layer_norm'
return "layer_norm"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
layernorm_type = attrs[LayerNormAttr.LN_TYPE]
......@@ -209,63 +203,59 @@ class TestLayerNorm(TestLayer):
bias_init = WeightInit.Constant(0.0)
transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNorm,
name='layer_norm',
dtype=dtype,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
scale_init=scale_init,
bias_init=bias_init,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(flax_LayerNorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
scale_init=scale_init,
bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", bias_init),
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence)
praxis_p = pax_fiddle.Config(
LayerNorm,
name="layer_norm",
dtype=dtype,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
scale_init=scale_init,
bias_init=bias_init,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial(
flax_LayerNorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
scale_init=scale_init,
bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init),
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormAttr.ATTRS)
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class FusedSoftmaxAttr:
SCALE_FACTOR = 'scale_factor'
ST_TYPE = 'softmax_type'
ATTRS = [{
SCALE_FACTOR: 0.0,
ST_TYPE: SoftmaxType.SCALED
}, {
SCALE_FACTOR: 0.0,
ST_TYPE: SoftmaxType.SCALED_MASKED
}, {
SCALE_FACTOR: 0.0,
ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED
}]
SCALE_FACTOR = "scale_factor"
ST_TYPE = "softmax_type"
ATTRS = [
{SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
{SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
{SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
]
class TestFusedSoftmax(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
return jax.random.normal(data_key, shape, dtype), \
jnp.ones(shape, dtype=jnp.uint8) # Masks
return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8) # Masks
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]
praxis_p = pax_fiddle.Config(FusedSoftmax,
name='fused_softmax',
scale_factor=scale_factor,
softmax_type=softmax_type)
praxis_p = pax_fiddle.Config(
FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type
)
flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)
return praxis_p, flax_cls
......@@ -276,34 +266,28 @@ class TestFusedSoftmax(TestLayer):
def sync_wgrads(self, praxis_wgrads, flax_wgrads):
return praxis_wgrads, flax_wgrads
@pytest.mark.parametrize('data_shape', [(32, 1, 128, 128), (32, 1, 512, 128)])
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', FusedSoftmaxAttr.ATTRS)
@pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)])
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and \
(data_shape[-2] != data_shape[-1]):
pass    # Skip: this configuration is not supported
if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and (
data_shape[-2] != data_shape[-1]
):
pass  # Skip: this configuration is not supported
else:
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
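# The skip above reflects that the upper-triangular (causal) mask is built with jnp.tril over the
# logits' last two dimensions, which only forms a valid causal pattern when the query and key
# lengths match. A small sketch of that mask construction, mirroring the SoftmaxRunner later in
# this diff (entries > 0 are masked out):
#
# import jax.numpy as jnp
#
# s = 4
# logits = jnp.zeros((1, 1, s, s))  # (batch, heads, q_len, kv_len), square on purpose
# causal_mask = (1.0 - jnp.tril(jnp.ones_like(logits))).astype(jnp.uint8)
# # Row i keeps positions 0..i and masks out the strictly upper triangle.
# assert int(causal_mask.sum()) == s * (s - 1) // 2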
class LinearAttr:
FEATURE = 'features'
USE_BIAS = 'use_bias'
ATTRS = [{
FEATURE: 512,
USE_BIAS: False
}, {
FEATURE: 512,
USE_BIAS: True
}, {
FEATURE: 1024,
USE_BIAS: False
}, {
FEATURE: 1024,
USE_BIAS: True
}]
FEATURE = "features"
USE_BIAS = "use_bias"
ATTRS = [
{FEATURE: 512, USE_BIAS: False},
{FEATURE: 512, USE_BIAS: True},
{FEATURE: 1024, USE_BIAS: False},
{FEATURE: 1024, USE_BIAS: True},
]
class TestLinear(TestLayer):
......@@ -313,7 +297,7 @@ class TestLinear(TestLayer):
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
return 'linear'
return "linear"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LinearAttr.FEATURE]
......@@ -323,15 +307,17 @@ class TestLinear(TestLayer):
axis = -1
transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(Linear,
name='linear',
dtype=dtype,
out_features=out_features,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence)
praxis_p = pax_fiddle.Config(
Linear,
name="linear",
dtype=dtype,
out_features=out_features,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial(
DenseGeneral,
features=out_features,
......@@ -340,29 +326,26 @@ class TestLinear(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis,
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence)
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(
self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
......@@ -371,54 +354,20 @@ class TestLinear(TestLayer):
class LayerNormLinearAttr:
FEATURE = 'features'
USE_BIAS = 'use_bias'
ENABLE_LN = 'enable_layernorm'
LN_TYPE = 'layernorm_type'
ZERO_CEN = 'zero_centered_gamma'
ATTRS = [{
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False
}, {
FEATURE: 512,
USE_BIAS: True,
ENABLE_LN: False,
LN_TYPE: 'layernorm',
ZERO_CEN: False
}]
FEATURE = "features"
USE_BIAS = "use_bias"
ENABLE_LN = "enable_layernorm"
LN_TYPE = "layernorm_type"
ZERO_CEN = "zero_centered_gamma"
ATTRS = [
{FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
{FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
{FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
{FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
{FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
{FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
{FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False},
]
class TestLayerNormLinear(TestLayer):
......@@ -428,7 +377,7 @@ class TestLayerNormLinear(TestLayer):
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
return 'ln_linear'
return "ln_linear"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LayerNormLinearAttr.FEATURE]
......@@ -441,18 +390,20 @@ class TestLayerNormLinear(TestLayer):
axis = -1
transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNormLinear,
name='ln_linear',
dtype=dtype,
out_features=out_features,
enable_layernorm=enable_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence)
praxis_p = pax_fiddle.Config(
LayerNormLinear,
name="ln_linear",
dtype=dtype,
out_features=out_features,
enable_layernorm=enable_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial(
LayerNormDenseGeneral,
features=out_features,
......@@ -464,29 +415,26 @@ class TestLayerNormLinear(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis,
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence)
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(
self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
......@@ -495,62 +443,70 @@ class TestLayerNormLinear(TestLayer):
class LayerNormMLPAttr:
INTERMEDIATE_DIM = 'intermediate_dim'
USE_BIAS = 'use_bias'
ENABLE_LN = 'enable_layernorm'
LN_TYPE = 'layernorm_type'
ZERO_CEN = 'zero_centered_gamma'
ACTIVATION = 'activations'
ATTRS = [{
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',)
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',)
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',)
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear')
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: False,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear')
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('silu', 'linear')
}, {
INTERMEDIATE_DIM: 2048,
USE_BIAS: False,
ENABLE_LN: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('silu', 'linear')
}]
INTERMEDIATE_DIM = "intermediate_dim"
USE_BIAS = "use_bias"
ENABLE_LN = "enable_layernorm"
LN_TYPE = "layernorm_type"
ZERO_CEN = "zero_centered_gamma"
ACTIVATION = "activations"
ATTRS = [
{
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
},
{
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("relu",),
},
{
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
},
{
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
},
{
INTERMEDIATE_DIM: 2048,
USE_BIAS: False,
ENABLE_LN: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
},
{
INTERMEDIATE_DIM: 2048,
USE_BIAS: True,
ENABLE_LN: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("silu", "linear"),
},
{
INTERMEDIATE_DIM: 2048,
USE_BIAS: False,
ENABLE_LN: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("silu", "linear"),
},
]
class TestLayerNormMLP(TestLayer):
......@@ -560,7 +516,7 @@ class TestLayerNormMLP(TestLayer):
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
return 'ln_mlp'
return "ln_mlp"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
......@@ -574,20 +530,22 @@ class TestLayerNormMLP(TestLayer):
axis = -1
transpose_batch_sequence = False
praxis_p = pax_fiddle.Config(LayerNormMLP,
name='ln_mlp',
dtype=dtype,
intermediate_dim=intermediate_dim,
enable_layernorm=enable_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
activations=activations,
intermediate_dropout_rate=0.0,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence)
praxis_p = pax_fiddle.Config(
LayerNormMLP,
name="ln_mlp",
dtype=dtype,
intermediate_dim=intermediate_dim,
enable_layernorm=enable_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
activations=activations,
intermediate_dropout_rate=0.0,
axis=axis,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial(
flax_LayerNormMLP,
intermediate_dim=intermediate_dim,
......@@ -601,29 +559,26 @@ class TestLayerNormMLP(TestLayer):
intermediate_dropout_rate=0.0,
axis=axis,
dtype=dtype,
transpose_batch_sequence=transpose_batch_sequence)
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls
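# In the attribute sets above, a two-element activation tuple such as ("gelu", "linear") denotes a
# gated MLP in the T5X style (assumed convention here): two parallel projections from the
# LayerNorm output, one passed through the named activation and one left linear, combined by
# elementwise multiplication before the output projection. A minimal shape-only sketch of that
# gating (not the fused TE kernel):
#
# import jax
# import jax.numpy as jnp
#
# def gated_mlp(x, w_gate, w_up, w_out):
#     # ("gelu", "linear"): the gelu branch gates the linear branch elementwise.
#     hidden = jax.nn.gelu(x @ w_gate) * (x @ w_up)
#     return hidden @ w_out
#
# k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
# x = jax.random.normal(k1, (2, 8, 512))            # (batch, seq, hidden)
# w_gate = jax.random.normal(k2, (512, 2048)) * 0.02
# w_up = jax.random.normal(k3, (512, 2048)) * 0.02
# w_out = jax.random.normal(k4, (2048, 512)) * 0.02
# assert gated_mlp(x, w_gate, w_up, w_out).shape == (2, 8, 512)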
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(
self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
......@@ -634,35 +589,40 @@ class TestLayerNormMLP(TestLayer):
class TestRelativePositionBias(TestLayer):
def get_layer_name(self):
return 'relative_position_bias'
return "relative_position_bias"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
num_buckets = 32
max_distance = 128
num_attention_heads = 64
rb_stddev = (num_attention_heads * num_buckets)**-0.5
rb_stddev = (num_attention_heads * num_buckets) ** -0.5
embedding_init = WeightInit.Gaussian(rb_stddev)
praxis_p = pax_fiddle.Config(RelativePositionBiases,
name='relative_position_bias',
dtype=dtype,
num_buckets=num_buckets,
max_distance=max_distance,
num_attention_heads=num_attention_heads,
embedding_init=embedding_init)
flax_cls = partial(flax_RelativePositionBiases,
num_buckets=num_buckets,
max_distance=max_distance,
num_attention_heads=num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init),
dtype=dtype)
praxis_p = pax_fiddle.Config(
RelativePositionBiases,
name="relative_position_bias",
dtype=dtype,
num_buckets=num_buckets,
max_distance=max_distance,
num_attention_heads=num_attention_heads,
embedding_init=embedding_init,
)
flax_cls = partial(
flax_RelativePositionBiases,
num_buckets=num_buckets,
max_distance=max_distance,
num_attention_heads=num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init
),
dtype=dtype,
)
return praxis_p, flax_cls
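# The Gaussian init above uses a stddev of (heads * buckets) ** -0.5, i.e. roughly one over the
# square root of the relative-bias embedding's total fan, keeping the initial biases small. A
# quick arithmetic check for the values used here:
#
# num_buckets, num_attention_heads = 32, 64
# rb_stddev = (num_attention_heads * num_buckets) ** -0.5
# assert abs(rb_stddev - 1 / 2048 ** 0.5) < 1e-12   # (64 * 32) ** -0.5 ~= 0.0221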
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', [{}])
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", [{}])
def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
......@@ -678,53 +638,64 @@ class TestRelativePositionBias(TestLayer):
if "params_axes" in flax_variables:
flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled():
flax_variables, _ = flax.core.pop(flax_variables,
FP8Helper.FP8_COLLECTION_NAME + "_axes")
flax_variables, _ = flax.core.pop(
flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
)
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
praxis_loss= \
TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False)
flax_loss = \
TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False)
praxis_loss = TestLayer.loss(
praxis_variables, *test_input, module=praxis_layer, mean_out=False
)
flax_loss = TestLayer.loss(
flax_variables, *test_input, module=flax_layer, mean_out=False
)
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr:
ATTN_MASK_TYPE = 'attn_mask_type'
NUM_GQA_GROUPS = 'num_gqa_groups'
TRANSPOSE_BS = 'transpose_batch_sequence'
SCALE_FACTOR = 'scale_factor'
ATTRS = [{
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding_causal',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: False,
SCALE_FACTOR: 0.125,
}, {
ATTN_MASK_TYPE: 'padding_causal',
TRANSPOSE_BS: False,
SCALE_FACTOR: 2.,
}, {
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.,
}, {
ATTN_MASK_TYPE: 'no_mask',
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.,
}]
ATTN_MASK_TYPE = "attn_mask_type"
NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = "transpose_batch_sequence"
SCALE_FACTOR = "scale_factor"
ATTRS = [
{
ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
},
{
ATTN_MASK_TYPE: "padding_causal",
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
},
{
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: True,
SCALE_FACTOR: 0.125,
},
{
ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: False,
SCALE_FACTOR: 0.125,
},
{
ATTN_MASK_TYPE: "padding_causal",
TRANSPOSE_BS: False,
SCALE_FACTOR: 2.0,
},
{
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
},
{
ATTN_MASK_TYPE: "no_mask",
TRANSPOSE_BS: False,
SCALE_FACTOR: 1.0,
},
]
class TestDotProductAttn(TestLayer):
......@@ -737,11 +708,12 @@ class TestDotProductAttn(TestLayer):
shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]),
mask,
]
def get_layer_name(self):
return 'dot_product_attn'
return "dot_product_attn"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
......@@ -750,27 +722,31 @@ class TestDotProductAttn(TestLayer):
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
praxis_p = pax_fiddle.Config(DotProductAttention,
name='mha',
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(flax_DotProductAttention,
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence)
praxis_p = pax_fiddle.Config(
DotProductAttention,
name="mha",
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial(
flax_DotProductAttention,
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
attn_mask_type=attn_mask_type,
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls
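# The SCALE_FACTOR values in DotProductAttnAttr line up with the usual 1/sqrt(head_dim) attention
# scaling for head_dim=64 (0.125), while the 1.0 and 2.0 entries exercise non-default scales. A
# quick check of that relationship:
#
# import math
# head_dim = 64
# assert 1 / math.sqrt(head_dim) == 0.125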
@pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)])
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
@pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)])
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
......@@ -778,113 +754,125 @@ class TestDotProductAttn(TestLayer):
class MultiHeadAttnAttr:
USE_BIAS = 'use_bias'
LN_TYPE = 'layernorm_type'
ATTN_MASK_TYPE = 'attn_mask_type'
ZERO_CEN = 'zero_centered_gamma'
NUM_ATTN_HEADS = 'num_attention_heads'
NUM_GQA_GROUPS = 'num_gqa_groups'
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal',
TRANSPOSE_BS: True,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding',
LORA_SCOPE: 'all',
TRANSPOSE_BS: False,
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal',
LORA_SCOPE: 'all',
TRANSPOSE_BS: True,
}]
USE_BIAS = "use_bias"
LN_TYPE = "layernorm_type"
ATTN_MASK_TYPE = "attn_mask_type"
ZERO_CEN = "zero_centered_gamma"
NUM_ATTN_HEADS = "num_attention_heads"
NUM_GQA_GROUPS = "num_gqa_groups"
TRANSPOSE_BS = "transpose_batch_sequence"
ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = "low_rank_adaptation_scope"
ATTRS = [
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "padding",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: "consecutive",
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: "alternate",
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: "causal",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "padding",
LORA_SCOPE: "all",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
ATTN_MASK_TYPE: "causal",
LORA_SCOPE: "all",
TRANSPOSE_BS: True,
},
]
class TestMultiHeadAttn(TestLayer):
......@@ -899,13 +887,16 @@ class TestMultiHeadAttn(TestLayer):
return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
def get_layer_name(self):
return 'multi_head_attn'
return "multi_head_attn"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
num_attention_heads = 16
num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None
num_gqa_groups = (
attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS]
if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs
else None
)
layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0)
......@@ -916,35 +907,37 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none")
fuse_qkv_params = True
transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
scale_attn_logits = False
scaled_query_init = True
float32_logits = False
praxis_p = pax_fiddle.Config(MultiHeadAttention,
name='mha',
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits)
praxis_p = pax_fiddle.Config(
MultiHeadAttention,
name="mha",
dtype=dtype,
head_dim=head_dim,
num_attention_heads=num_attention_heads,
num_gqa_groups=num_gqa_groups,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
params_init=kernel_init,
use_bias=use_bias,
bias_init=bias_init,
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits,
)
flax_cls = partial(
flax_MultiHeadAttention,
dtype=dtype,
......@@ -966,30 +959,27 @@ class TestMultiHeadAttn(TestLayer):
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
float32_logits=float32_logits)
float32_logits=float32_logits,
)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(
self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
):
self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
......@@ -998,252 +988,279 @@ class TestMultiHeadAttn(TestLayer):
class TransformerLayerAttr:
USE_BIAS = 'use_bias'
LN_TYPE = 'layernorm_type'
ACTIVATION = 'activations'
LYR_TYPE = 'layer_type'
ZERO_CEN = 'zero_centered_gamma'
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False,
LORA_SCOPE: 'all'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False,
LORA_SCOPE: 'all'
}]
USE_BIAS = "use_bias"
LN_TYPE = "layernorm_type"
ACTIVATION = "activations"
LYR_TYPE = "layer_type"
ZERO_CEN = "zero_centered_gamma"
TRANSPOSE_BS = "transpose_batch_sequence"
ENABLE_ROPE = "enable_rotary_pos_emb"
ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
LORA_SCOPE = "low_rank_adaptation_scope"
ATTRS = [
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("relu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
LORA_SCOPE: "all",
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: True,
},
{
USE_BIAS: True,
LN_TYPE: "rmsnorm",
ZERO_CEN: False,
ACTIVATION: ("gelu", "linear"),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: "alternate",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: "alternate",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: True,
ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
},
{
USE_BIAS: True,
LN_TYPE: "layernorm",
ZERO_CEN: False,
ACTIVATION: ("gelu",),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: "consecutive",
TRANSPOSE_BS: False,
LORA_SCOPE: "all",
},
]
class TestTransformer(TestLayer):
......@@ -1256,11 +1273,13 @@ class TestTransformer(TestLayer):
shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]),
mask,
mask,
]
def get_layer_name(self):
return 'transformerlayer'
return "transformerlayer"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
hidden_size = 512
......@@ -1277,97 +1296,102 @@ class TestTransformer(TestLayer):
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, "none")
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
dtype=dtype,
num_attention_heads=num_attention_heads)
relative_embedding = pax_fiddle.Config(
RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads
)
drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
rel_embedding_init = RelativePositionBiases.generate_embedding_init(
relative_embedding.embedding_init, relative_embedding.num_attention_heads,
relative_embedding.num_buckets)
relative_embedding.embedding_init,
relative_embedding.num_attention_heads,
relative_embedding.num_buckets,
)
relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=relative_embedding.num_buckets,
max_distance=relative_embedding.max_distance,
num_attention_heads=relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", rel_embedding_init),
"rel_embedding", rel_embedding_init
),
embedding_axes=relative_embedding.embedding_axes,
dtype=relative_embedding.dtype)
praxis_p = pax_fiddle.Config(TransformerLayer,
name='transformer_layer',
params_init=kernel_init,
dtype=dtype,
hidden_size=hidden_size,
mlp_hidden_size=mlp_hidden_size,
num_attention_heads=num_attention_heads,
layernorm_type=layernorm_type,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
intermediate_dropout=intermediate_dropout,
mlp_activations=mlp_activations,
use_bias=use_bias,
bias_init=bias_init,
layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
flax_cls = partial(flax_TransformerLayer,
dtype=dtype,
hidden_size=hidden_size,
mlp_hidden_size=mlp_hidden_size,
num_attention_heads=num_attention_heads,
layernorm_type=layernorm_type,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
intermediate_dropout=intermediate_dropout,
mlp_activations=mlp_activations,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", kernel_init),
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mlp_kernel", kernel_init),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init(
"bias", bias_init),
layer_type=layer_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
dtype=relative_embedding.dtype,
)
praxis_p = pax_fiddle.Config(
TransformerLayer,
name="transformer_layer",
params_init=kernel_init,
dtype=dtype,
hidden_size=hidden_size,
mlp_hidden_size=mlp_hidden_size,
num_attention_heads=num_attention_heads,
layernorm_type=layernorm_type,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
intermediate_dropout=intermediate_dropout,
mlp_activations=mlp_activations,
use_bias=use_bias,
bias_init=bias_init,
layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence,
)
flax_cls = partial(
flax_TransformerLayer,
dtype=dtype,
hidden_size=hidden_size,
mlp_hidden_size=mlp_hidden_size,
num_attention_heads=num_attention_heads,
layernorm_type=layernorm_type,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
intermediate_dropout=intermediate_dropout,
mlp_activations=mlp_activations,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", kernel_init
),
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mlp_kernel", kernel_init
),
use_bias=use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
layer_type=layer_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence,
)
return praxis_p, flax_cls
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
def test_forward_backward_fp8(self,
data_shape,
dtype,
attrs,
fp8_format,
rtol=1e-05,
atol=1e-08):
@pytest.mark.parametrize("data_shape", DATA_SHAPE)
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
@pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_backward_fp8(
self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
):
self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
......
......@@ -3,4 +3,5 @@
# See LICENSE for license information.
import transformer_engine.jax
print("OK")
......@@ -8,25 +8,25 @@ from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
LOGICAL_RULES = [
[(('a1', None), ('a2', 'ma2')), False],
[(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True],
[(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False],
[(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True],
[(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True],
[(("a1", None), ("a2", "ma2")), False],
[(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
[(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
[(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
[(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
]
MeshS = [
MeshResource(),
MeshResource('data', None),
MeshResource(None, 'model'),
MeshResource('data', 'model')
MeshResource("data", None),
MeshResource(None, "model"),
MeshResource("data", "model"),
]
class TestShardingSideAPI:
@pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES)
@pytest.mark.parametrize('sr', MeshS)
@pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
@pytest.mark.parametrize("sr", MeshS)
def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
with global_shard_guard(sr):
try:
......
......@@ -43,6 +43,7 @@ class SoftmaxRunner:
"""
Softmax runner
"""
batch_size: int
max_seqlen_q: int
max_seqlen_kv: int
......@@ -57,14 +58,22 @@ class SoftmaxRunner:
Jax softmax as the reference
"""
if mask is not None:
logits += lax.select(mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.).astype(logits.dtype))
logits += lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.0).astype(logits.dtype),
)
return nn.softmax(logits * scale_factor)
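# In the reference above, mask entries greater than zero mark positions to drop: they receive a
# large negative additive bias (-1e10) before the softmax, so their probability collapses to ~0,
# while zero entries pass through unchanged. A self-contained sketch of the same convention using
# jax.nn.softmax (the module's nn alias is assumed to behave equivalently):
#
# import jax.numpy as jnp
# from jax import nn as jax_nn
#
# logits = jnp.array([[1.0, 2.0, 3.0, 4.0]])
# mask = jnp.array([[0, 0, 1, 1]], dtype=jnp.uint8)   # 1 => masked out
# masked_logits = logits + jnp.where(mask > 0, -1e10, 0.0)
# probs = jax_nn.softmax(masked_logits)
# assert float(probs[0, 2]) < 1e-6 and float(probs[0, 3]) < 1e-6
# assert abs(float(probs[0, :2].sum()) - 1.0) < 1e-6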
def _is_support(self):
return is_softmax_kernel_available(self.softmax_type, self.batch_size, self.num_heads,
self.max_seqlen_q, self.max_seqlen_kv, self.dtype)
return is_softmax_kernel_available(
self.softmax_type,
self.batch_size,
self.num_heads,
self.max_seqlen_q,
self.max_seqlen_kv,
self.dtype,
)
def _setup_inputs(self):
key = jax.random.PRNGKey(0)
......@@ -73,7 +82,7 @@ class SoftmaxRunner:
logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.)
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
match self.softmax_type:
case SoftmaxType.SCALED:
......@@ -81,7 +90,7 @@ class SoftmaxRunner:
case SoftmaxType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1. - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
......@@ -108,18 +117,24 @@ class SoftmaxRunner:
args = [self.logits, self.mask]
kwargs = {
'scale_factor': self.scale_factor,
'softmax_type': self.softmax_type,
"scale_factor": self.scale_factor,
"softmax_type": self.softmax_type,
}
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs),
(0,)))
value_and_grad(
lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,)
)
)
jitted_reference = jit(
value_and_grad(
lambda logits, *args: grad_func(__class__.reference_softmax, self.logits, *args, **
kwargs), (0,)))
lambda logits, *args: grad_func(
__class__.reference_softmax, self.logits, *args, **kwargs
),
(0,),
)
)
primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
reference_out, (reference_grad_logits,) = jitted_reference(*args)
......@@ -128,21 +143,30 @@ class SoftmaxRunner:
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
@pytest.mark.parametrize('b, s_q, s_kv, h', [
pytest.param(8, 16, 16, 16, id='8-16-16-16'),
pytest.param(8, 512, 512, 16, id='8-512-512-16'),
pytest.param(2, 8, 16384, 8, id='2-8-16384-8')
])
@pytest.mark.parametrize('scale_factor', [0.125])
@pytest.mark.parametrize('softmax_type', [
pytest.param(SoftmaxType.SCALED, id='SCALED'),
pytest.param(SoftmaxType.SCALED_MASKED, id='SCALED_MASKED'),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id='SCALED_UPPER_TRIANG_MASKED')
])
@pytest.mark.parametrize('dtype', [
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
])
@pytest.mark.parametrize(
"b, s_q, s_kv, h",
[
pytest.param(8, 16, 16, 16, id="8-16-16-16"),
pytest.param(8, 512, 512, 16, id="8-512-512-16"),
pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(SoftmaxType.SCALED, id="SCALED"),
pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
],
)
@pytest.mark.parametrize(
"dtype",
[
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
],
)
class TestSoftmax:
"""
Test transformer_engine.jax.softmax.softmax
......
......@@ -24,8 +24,9 @@ PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
lax.Precision]]
PrecisionLike = Union[
None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
......@@ -56,7 +57,7 @@ def _canonicalize_tuple(x):
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
"""Convert a string to an activation function."""
if fn_or_string == 'linear':
if fn_or_string == "linear":
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
......@@ -68,17 +69,18 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases.
Args:
*masks: set of attention bias arguments to combine, some can be None.
Args:
*masks: set of attention bias arguments to combine, some can be None.
Returns:
Combined mask, reduced by summation, returns None if no masks given.
"""
Returns:
Combined mask, reduced by summation, returns None if no masks given.
"""
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(map(lambda x: x.ndim == masks[0].ndim,
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
assert all(
map(lambda x: x.ndim == masks[0].ndim, masks)
), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
mask, *other_masks = masks
for other_mask in other_masks:
mask = mask + other_mask
......@@ -88,7 +90,7 @@ def combine_biases(*masks: Optional[Array]):
class DotProductAttention(nn.Module):
transpose_batch_sequence: bool = True
scale_attn_logits: bool = True
dropout_rate: float = 0.
dropout_rate: float = 0.0
dtype: DType = jnp.float32
float32_logits: bool = False
"""Computes dot-product attention given query, key, and value.
......@@ -105,12 +107,14 @@ class DotProductAttention(nn.Module):
"""
@nn.compact
def __call__(self,
query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
deterministic: bool = False):
def __call__(
self,
query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
deterministic: bool = False,
):
"""
Args:
query: queries for calculating attention with shape of `[batch, q_length,
......@@ -127,14 +131,15 @@ class DotProductAttention(nn.Module):
Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
"""
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
assert (
query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
), "q, k, v batch dims must match."
sequence_dim = 0 if self.transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
if self.scale_attn_logits:
head_dim = query.shape[-1]
......@@ -153,9 +158,9 @@ class DotProductAttention(nn.Module):
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if self.transpose_batch_sequence:
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
else:
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
......@@ -170,37 +175,37 @@ class DotProductAttention(nn.Module):
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
# Apply attention dropout.
if not deterministic and self.dropout_rate > 0.:
if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate
# T5 broadcasts along the "length" dim, but it is unclear which positional
# dimension that corresponds to here, so assume the query dim.
dropout_shape = list(attn_weights.shape)
dropout_rng = self.make_rng('dropout')
dropout_rng = self.make_rng("dropout")
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = (keep.astype(attn_weights.dtype) /
jnp.asarray(keep_prob, dtype=self.dtype))
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
class DenseGeneral(nn.Module):
"""A linear transformation with flexible axes and FP8 support.
Attributes:
features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on.
dtype: the dtype of the computation (default: float32).
kernel_init: initializer function for the weight matrix.
use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector.
Attributes:
features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on.
dtype: the dtype of the computation (default: float32).
kernel_init: initializer function for the weight matrix.
use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector.
"""
features: Union[Iterable[int], int]
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
......@@ -212,7 +217,7 @@ class DenseGeneral(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__()
@nn.compact
......@@ -233,21 +238,17 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes('kernel',
self.kernel_init,
kernel_param_shape,
jnp.float32,
axes=self.kernel_axes)
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
)
kernel = jnp.asarray(kernel, self.dtype)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
self.bias_init,
self.features,
jnp.float32,
axes=self.bias_axes)
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
else:
bias = None
......@@ -264,18 +265,19 @@ class DenseGeneral(nn.Module):
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block.
Attributes:
intermediate_dim: Shared dimension of hidden layers.
activations: Type of activations for each layer. Each element is either
'linear', a string function name in flax.linen, or a function.
kernel_init: Kernel function, passed to the dense layers.
deterministic: Whether the dropout layers should be deterministic.
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
dtype: Type for the dense layer.
"""
Attributes:
intermediate_dim: Shared dimension of hidden layers.
activations: Type of activations for each layer. Each element is either
'linear', a string function name in flax.linen, or a function.
kernel_init: Kernel function, passed to the dense layers.
deterministic: Whether the dropout layers should be deterministic.
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
dtype: Type for the dense layer.
"""
transpose_batch_sequence: bool
intermediate_dim: int = 2048
activations: Sequence[Union[str, Callable]] = ('relu',)
activations: Sequence[Union[str, Callable]] = ("relu",)
kernel_init: Initializer = None
intermediate_dropout_rate: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
......@@ -285,7 +287,7 @@ class MlpBlock(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__()
@nn.compact
......@@ -296,49 +298,57 @@ class MlpBlock(nn.Module):
activations = []
if self.fuse_wi:
dense_name = 'wi'
dense_name = "wi"
num_activations = len(self.activations)
x = DenseGeneral(self.intermediate_dim * num_activations,
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('embed', 'mlp'),
use_bias=self.use_bias,
bias_axes=('mlp'),
name=dense_name)(inputs)
x = DenseGeneral(
self.intermediate_dim * num_activations,
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=("embed", "mlp"),
use_bias=self.use_bias,
bias_axes="mlp",
name=dense_name,
)(inputs)
x = jnp.split(x, num_activations, axis=-1)
for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
else:
for idx, act_fn in enumerate(self.activations):
dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}'
x = DenseGeneral(self.intermediate_dim,
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('embed', 'mlp'),
use_bias=self.use_bias,
bias_axes=('mlp'),
name=dense_name)(inputs)
dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
x = DenseGeneral(
self.intermediate_dim,
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=("embed", "mlp"),
use_bias=self.use_bias,
bias_axes="mlp",
name=dense_name,
)(inputs)
x = _convert_to_activation_function(act_fn)(x)
activations.append(x)
# Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations)
# Apply dropout and final dense output projection.
x = nn.Dropout(rate=self.intermediate_dropout_rate,
broadcast_dims=self.intermediate_dropout_dims)(
x, deterministic=deterministic) # Broadcast along length.
x = nn.Dropout(
rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
)(
x, deterministic=deterministic
) # Broadcast along length.
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp'))
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
else:
x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'mlp'))
output = DenseGeneral(inputs.shape[-1],
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('mlp', 'embed'),
use_bias=self.use_bias,
bias_axes=('embed'),
name='wo')(x)
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
output = DenseGeneral(
inputs.shape[-1],
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=("mlp", "embed"),
use_bias=self.use_bias,
bias_axes="embed",
name="wo",
)(x)
return output
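As background for the fuse_wi branch reformatted above, a rough sketch of the gated-MLP idea: one projection per activation, an elementwise product, then the output projection. TinyGatedMlp and the gelu/linear pair are illustrative assumptions, not taken from this module.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyGatedMlp(nn.Module):
    hidden: int = 64

    @nn.compact
    def __call__(self, x):
        gate = nn.gelu(nn.Dense(self.hidden, name="wi_0")(x))  # 'gelu' branch
        linear = nn.Dense(self.hidden, name="wi_1")(x)         # 'linear' branch
        h = gate * linear                                      # elementwise product of activations
        return nn.Dense(x.shape[-1], name="wo")(h)             # back to the input width

x = jnp.ones((2, 5, 32))
params = TinyGatedMlp().init(jax.random.PRNGKey(0), x)
print(TinyGatedMlp().apply(params, x).shape)  # (2, 5, 32)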
......@@ -351,7 +361,7 @@ def apply_rotary_pos_emb_alternate(
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
timescale = min_timescale * (max_timescale / min_timescale)**fraction
timescale = min_timescale * (max_timescale / min_timescale) ** fraction
timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale
......@@ -386,7 +396,7 @@ def apply_rotary_pos_emb_consecutive(
inputs_shifted_left,
)
fraction = jnp.repeat(fraction, 2)
timescale = min_timescale * (max_timescale / min_timescale)**fraction
timescale = min_timescale * (max_timescale / min_timescale) ** fraction
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
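A small worked sketch of the rotary timescale/angle computation above; min_timescale=1.0 and max_timescale=10000.0 are assumed defaults (the usual RoPE convention), not values read from this hunk.

import jax.numpy as jnp

embedding_dim, seq_len = 8, 6
min_timescale, max_timescale = 1.0, 10000.0

fraction = 2 * jnp.arange(0, embedding_dim // 2) / embedding_dim
timescale = min_timescale * (max_timescale / min_timescale) ** fraction  # geometric spacing

position = jnp.arange(seq_len)[:, None]       # [seq, 1]
sinusoid_inp = position / timescale[None, :]  # [seq, dim // 2]
sin, cos = jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)
print(sin.shape, cos.shape)  # (6, 4) (6, 4)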
......@@ -415,89 +425,96 @@ class MultiHeadAttention(nn.Module):
kernel_init: initializer for the kernel of the Dense layers.
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
"""
"""
num_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
transpose_batch_sequence: bool = True
dtype: DType = jnp.float32
dropout_rate: float = 0.
dropout_rate: float = 0.0
kernel_init: Initializer = None
float32_logits: bool = False # computes logits in float32 for stability.
float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False
scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv: bool = True
use_bias: bool = False
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
def __call__(self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
decode: bool = False,
deterministic: bool = False) -> Array:
def __call__(
self,
inputs_q: Array,
inputs_kv: Array,
mask: Optional[Array] = None,
bias: Optional[Array] = None,
*,
decode: bool = False,
deterministic: bool = False,
) -> Array:
"""Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and projects the results to an output vector.
Projects the inputs into multi-headed query, key, and value vectors,
applies dot-product attention and projects the results to an output vector.
There are two modes: decoding and non-decoding (e.g., training). The mode is
determined by `decode` argument. For decoding, this method is called twice,
first to initialize the cache and then for an actual decoding process. The
two calls are differentiated by the presence of 'cached_key' in the variable
dict. In the cache initialization stage, the cache variables are initialized
as zeros and will be filled in the subsequent decoding process.
There are two modes: decoding and non-decoding (e.g., training). The mode is
determined by `decode` argument. For decoding, this method is called twice,
first to initialize the cache and then for an actual decoding process. The
two calls are differentiated by the presence of 'cached_key' in the variable
dict. In the cache initialization stage, the cache variables are initialized
as zeros and will be filled in the subsequent decoding process.
In the cache initialization call, `inputs_q` has a shape [batch, length,
q_features] and `inputs_kv`: [batch, length, kv_features]. During the
incremental decoding stage, query, key and value all have the shape [batch,
1, qkv_features] corresponding to a single step.
In the cache initialization call, `inputs_q` has a shape [batch, length,
q_features] and `inputs_kv`: [batch, length, kv_features]. During the
incremental decoding stage, query, key and value all have the shape [batch,
1, qkv_features] corresponding to a single step.
Args:
inputs_q: input queries of shape `[batch, q_length, q_features]`.
inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
decode: Whether to prepare and use an autoregressive cache.
deterministic: Disables dropout if set to True.
Args:
inputs_q: input queries of shape `[batch, q_length, q_features]`.
inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
decode: Whether to prepare and use an autoregressive cache.
deterministic: Disables dropout if set to True.
Returns:
output of shape `[batch, length, q_features]`.
"""
q_projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_axes=('joined_kv'),
dtype=self.dtype)
kv_projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_axes=('joined_kv'),
dtype=self.dtype)
Returns:
output of shape `[batch, length, q_features]`.
"""
q_projection = functools.partial(
DenseGeneral,
axis=-1,
features=self.num_heads * self.head_dim,
kernel_axes=("embed", "joined_kv"),
use_bias=self.use_bias,
bias_axes="joined_kv",
dtype=self.dtype,
)
kv_projection = functools.partial(
DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
kernel_axes=("embed", "joined_kv"),
use_bias=self.use_bias,
bias_axes="joined_kv",
dtype=self.dtype,
)
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
query_init = lambda *args: self.kernel_init(*args) / (depth_scaling
if self.scaled_query_init else 1.0)
query_init = lambda *args: self.kernel_init(*args) / (
depth_scaling if self.scaled_query_init else 1.0
)
# Project inputs_q to multi-headed q/k/v
# dimensions are then [batch, length, num_heads, head_dim]
......@@ -515,39 +532,45 @@ class MultiHeadAttention(nn.Module):
return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)
is_self_attn = (inputs_q is inputs_kv)
is_gqa = (self.num_heads != self.num_gqa_groups)
is_qkvpack = (is_self_attn and not is_gqa)
is_self_attn = inputs_q is inputs_kv
is_gqa = self.num_heads != self.num_gqa_groups
is_qkvpack = is_self_attn and not is_gqa
if self.fuse_qkv:
if is_qkvpack:
qkv_proj = DenseGeneral(axis=-1,
features=self.num_heads * self.head_dim * 3,
kernel_axes=('embed', 'joined_kv'),
kernel_init=qkv_init,
use_bias=self.use_bias,
bias_axes=('joined_kv'),
name='qkv',
dtype=self.dtype)(inputs_kv)
qkv_proj = DenseGeneral(
axis=-1,
features=self.num_heads * self.head_dim * 3,
kernel_axes=("embed", "joined_kv"),
kernel_init=qkv_init,
use_bias=self.use_bias,
bias_axes="joined_kv",
name="qkv",
dtype=self.dtype,
)(inputs_kv)
query, key, value = jnp.split(
qkv_proj, [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1)
qkv_proj,
[self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1,
)
else:
query = q_projection(kernel_init=query_init, name='query')(inputs_q)
kv_proj = DenseGeneral(axis=-1,
features=self.num_gqa_groups * self.head_dim * 2,
kernel_axes=('embed', 'joined_kv'),
kernel_init=self.kernel_init,
use_bias=self.use_bias,
bias_axes=('joined_kv'),
name='kv',
dtype=self.dtype)(inputs_kv)
query = q_projection(kernel_init=query_init, name="query")(inputs_q)
kv_proj = DenseGeneral(
axis=-1,
features=self.num_gqa_groups * self.head_dim * 2,
kernel_axes=("embed", "joined_kv"),
kernel_init=self.kernel_init,
use_bias=self.use_bias,
bias_axes="joined_kv",
name="kv",
dtype=self.dtype,
)(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
else:
query = q_projection(kernel_init=query_init, name='query')(inputs_q)
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
query = q_projection(kernel_init=query_init, name="query")(inputs_q)
key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
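A toy illustration of how the fused projections above are split back apart; the shapes are made up. The self-attention path splits one qkv tensor at [h*d, 2*h*d], while the GQA/cross path splits a kv tensor at [groups*d].

import jax.numpy as jnp

batch, length, num_heads, head_dim, num_gqa_groups = 2, 5, 4, 16, 2
hd, gd = num_heads * head_dim, num_gqa_groups * head_dim

qkv_proj = jnp.zeros((batch, length, 3 * hd))
query, key, value = jnp.split(qkv_proj, [hd, 2 * hd], axis=-1)
print(query.shape, key.shape, value.shape)  # each (2, 5, 64)

kv_proj = jnp.zeros((batch, length, 2 * gd))
key, value = jnp.split(kv_proj, [gd], axis=-1)
print(key.shape, value.shape)  # each (2, 5, 32)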
if self.enable_rotary_pos_emb:
batch_dim = 1 if self.transpose_batch_sequence else 0
......@@ -556,7 +579,7 @@ class MultiHeadAttention(nn.Module):
q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
if self.rotary_pos_emb_group_method == 'alternate':
if self.rotary_pos_emb_group_method == "alternate":
apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
else:
apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive
......@@ -571,33 +594,40 @@ class MultiHeadAttention(nn.Module):
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
if self.transpose_batch_sequence:
query = nn_partitioning.with_sharding_constraint(query,
('length', 'batch', 'heads', 'kv'))
key = nn_partitioning.with_sharding_constraint(key, ('length', 'batch', 'heads', 'kv'))
value = nn_partitioning.with_sharding_constraint(value,
('length', 'batch', 'heads', 'kv'))
query = nn_partitioning.with_sharding_constraint(
query, ("length", "batch", "heads", "kv")
)
key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
value = nn_partitioning.with_sharding_constraint(
value, ("length", "batch", "heads", "kv")
)
else:
query = nn_partitioning.with_sharding_constraint(query,
('batch', 'length', 'heads', 'kv'))
key = nn_partitioning.with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv'))
value = nn_partitioning.with_sharding_constraint(value,
('batch', 'length', 'heads', 'kv'))
query = nn_partitioning.with_sharding_constraint(
query, ("batch", "length", "heads", "kv")
)
key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
value = nn_partitioning.with_sharding_constraint(
value, ("batch", "length", "heads", "kv")
)
if decode:
# Detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable('cache', 'cached_key')
is_initialized = self.has_variable("cache", "cached_key")
# The key and value have dimension [batch, length, num_heads, head_dim],
# but we cache them as [batch, num_heads, head_dim, length] as a TPU
# fusion optimization. This also enables the "scatter via one-hot
# broadcast" trick, which means we do a one-hot broadcast instead of a
# scatter/gather operations, resulting in a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
key.dtype)
cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
value.dtype)
cache_index = self.variable('cache', 'cache_index',
lambda: jnp.array(0, dtype=jnp.int32))
cached_key = self.variable(
"cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
)
cached_value = self.variable(
"cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
)
cache_index = self.variable(
"cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
)
if is_initialized:
batch, num_heads, head_dim, length = cached_key.value.shape
# During fast autoregressive decoding, we feed one position at a time,
......@@ -606,8 +636,9 @@ class MultiHeadAttention(nn.Module):
expected_shape = (batch, 1, num_heads, head_dim)
if expected_shape != query.shape:
raise ValueError(
'Autoregressive cache shape error, '
f"expected query shape {expected_shape} instead got {query.shape}.")
"Autoregressive cache shape error, "
f"expected query shape {expected_shape} instead got {query.shape}."
)
# Create a OHE of the current index. NOTE: the index is increased below.
cur_index = cache_index.value
......@@ -638,11 +669,13 @@ class MultiHeadAttention(nn.Module):
jnp.logical_not(mask),
jnp.broadcast_to(
jnp.arange(length) <= cur_index,
# (1, 1, length) represent (head dim, query length, key length)
# query length is 1 because during decoding we deal with one
# index.
# The same mask is applied to all batch elements and heads.
(batch, 1, 1, length)))
# (1, 1, length) represent (head dim, query length, key length)
# query length is 1 because during decoding we deal with one
# index.
# The same mask is applied to all batch elements and heads.
(batch, 1, 1, length),
),
)
# Grab the correct relative attention bias during decoding. This is
# only required during single step decoding.
......@@ -650,15 +683,18 @@ class MultiHeadAttention(nn.Module):
# The bias is a full attention matrix, but during decoding we only
# have to take a slice of it.
# This is equivalent to bias[..., cur_index:cur_index+1, :].
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2)
bias = dynamic_vector_slice_in_dim(
jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
)
# Convert the boolean attention mask to an attention bias.
if mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(mask > 0,
jnp.full(mask.shape, 0.).astype(self.dtype),
jnp.full(mask.shape, -1e10).astype(self.dtype))
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, 0.0).astype(self.dtype),
jnp.full(mask.shape, -1e10).astype(self.dtype),
)
else:
attention_bias = None
......@@ -667,41 +703,41 @@ class MultiHeadAttention(nn.Module):
attention_bias = combine_biases(attention_bias, bias)
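A minimal sketch of the boolean-mask-to-additive-bias conversion above, with a made-up mask; -1e10 stands in for minus infinity so masked positions vanish after the softmax.

import jax.numpy as jnp
from jax import lax

mask = jnp.array([[[[1, 1, 0, 0]]]])  # [batch=1, heads=1, q_len=1, kv_len=4]
attention_bias = lax.select(
    mask > 0,
    jnp.full(mask.shape, 0.0),
    jnp.full(mask.shape, -1e10),
)
print(attention_bias)  # 0.0 where attended, -1e10 where masked out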
# Apply attention.
x = DotProductAttention(transpose_batch_sequence=self.transpose_batch_sequence,
scale_attn_logits=self.scale_attn_logits,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
float32_logits=self.float32_logits)(query,
key,
value,
bias=attention_bias,
deterministic=deterministic)
x = DotProductAttention(
transpose_batch_sequence=self.transpose_batch_sequence,
scale_attn_logits=self.scale_attn_logits,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
float32_logits=self.float32_logits,
)(query, key, value, bias=attention_bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'joined_kv'))
x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
else:
x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'joined_kv'))
x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions.
out = DenseGeneral(
features=inputs_q.shape[-1], # output dim is set to the input dim.
features=inputs_q.shape[-1], # output dim is set to the input dim.
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=('joined_kv', 'embed'),
kernel_axes=("joined_kv", "embed"),
use_bias=self.use_bias,
bias_axes=('embed'),
bias_axes="embed",
dtype=self.dtype,
name='out')(x)
name="out",
)(x)
return out
class LayerNorm(nn.Module):
"""T5 Layer normalization operating on the last axis of the input data."""
epsilon: float = 1e-6
dtype: Any = jnp.float32
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
zero_centered_gamma: bool = False
scale_init: Initializer = None
bias_init: Initializer = nn.initializers.zeros
......@@ -721,29 +757,27 @@ class LayerNorm(nn.Module):
x = jnp.asarray(x, jnp.float32)
features = x.shape[-1]
scale = nn_partitioning.param_with_axes('scale',
self.scale_init, (features,),
jnp.float32,
axes=('embed',))
scale = nn_partitioning.param_with_axes(
"scale", self.scale_init, (features,), jnp.float32, axes=("embed",)
)
scale = jnp.asarray(scale, self.dtype)
if self.layernorm_type == 'layernorm':
if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
y = (x - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes('ln_bias',
self.bias_init, (features,),
jnp.float32,
axes=('embed',))
bias = nn_partitioning.param_with_axes(
"ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",)
)
bias = jnp.asarray(bias, self.dtype)
if not self.zero_centered_gamma:
z = y * scale + bias
else:
z = y * (scale + 1.) + bias
z = y * (scale + 1.0) + bias
else:
assert self.layernorm_type == 'rmsnorm'
assert self.layernorm_type == "rmsnorm"
assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = x * lax.rsqrt(mean2 + self.epsilon)
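A quick numerical sketch of the two normalizations handled above (plain layernorm vs rmsnorm), leaving out the learned scale and bias; epsilon 1e-6 is assumed from the module default.

import jax.numpy as jnp
from jax import lax

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
eps = 1e-6

mean = jnp.mean(x, axis=-1, keepdims=True)                    # 'layernorm': center, then
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)  # divide by the std deviation
y_ln = (x - mean) * lax.rsqrt(var + eps)

mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)       # 'rmsnorm': divide by the RMS,
y_rms = x * lax.rsqrt(mean2 + eps)                            # no centering
print(y_ln, y_rms)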
......@@ -755,16 +789,17 @@ class LayerNorm(nn.Module):
class RelativePositionBiases(nn.Module):
"""Adds T5-style relative positional embeddings to the attention logits.
Attributes:
num_buckets: Number of buckets to bucket distances between key and query
positions into.
max_distance: Maximum distance before everything is lumped into the last
distance bucket.
num_heads: Number of heads in the attention layer. Each head will get a
different relative position weighting.
dtype: Type of arrays through this module.
embedding_init: initializer for relative embedding table.
"""
Attributes:
num_buckets: Number of buckets to bucket distances between key and query
positions into.
max_distance: Maximum distance before everything is lumped into the last
distance bucket.
num_heads: Number of heads in the attention layer. Each head will get a
different relative position weighting.
dtype: Type of arrays through this module.
embedding_init: initializer for relative embedding table.
"""
num_buckets: int
max_distance: int
num_heads: int
......@@ -772,33 +807,32 @@ class RelativePositionBiases(nn.Module):
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
@staticmethod
def _relative_position_bucket(relative_position,
bidirectional=True,
num_buckets=32,
max_distance=128):
def _relative_position_bucket(
relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
"""Translate relative position to a bucket number for relative attention.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are
invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative
positions <=-max_distance map to the same bucket. This should allow for
more graceful generalization to longer sequences than the model has been
trained on.
The relative position is defined as memory_position - query_position, i.e.
the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are
invalid.
We use smaller buckets for small absolute relative_position and larger
buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative
positions <=-max_distance map to the same bucket. This should allow for
more graceful generalization to longer sequences than the model has been
trained on.
Args:
relative_position: an int32 array
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Args:
relative_position: an int32 array
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
Returns:
a Tensor with the same shape as relative_position, containing int32
values in the range [0, num_buckets)
"""
ret = 0
n = -relative_position
if bidirectional:
......@@ -811,8 +845,10 @@ class RelativePositionBiases(nn.Module):
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
np.log(max_distance / max_exact) * (num_buckets - max_exact)).astype(np.int32)
np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
/ np.log(max_distance / max_exact)
* (num_buckets - max_exact)
).astype(np.int32)
val_if_large = np.minimum(val_if_large, num_buckets - 1)
ret += np.where(is_small, n, val_if_large)
return ret
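A hypothetical usage sketch of the static helper above (assuming the class is in scope); the relative positions are arbitrary and only meant to show the call and the output range.

import numpy as np

rel = np.array([-130, -16, -3, 0, 3, 16, 130], dtype=np.int32)
buckets = RelativePositionBiases._relative_position_bucket(
    rel, bidirectional=True, num_buckets=32, max_distance=128
)
print(buckets)  # int32 bucket ids, one per relative position, each in [0, num_buckets)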
......@@ -821,27 +857,31 @@ class RelativePositionBiases(nn.Module):
def __call__(self, qlen, klen, bidirectional=True):
"""Produce relative position embedding attention biases.
Args:
qlen: attention query length.
klen: attention key length.
bidirectional: whether to allow positive memory-query relative position
embeddings.
Args:
qlen: attention query length.
klen: attention key length.
bidirectional: whether to allow positive memory-query relative position
embeddings.
Returns:
output: `(1, len, q_len, k_len)` attention bias
"""
Returns:
output: `(1, len, q_len, k_len)` attention bias
"""
context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(relative_position,
bidirectional=bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance)
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position,
bidirectional=bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
relative_attention_bias = nn_partitioning.param_with_axes(
'rel_embedding',
self.embedding_init, (self.num_heads, self.num_buckets),
"rel_embedding",
self.embedding_init,
(self.num_heads, self.num_buckets),
jnp.float32,
axes=('heads', 'relpos_buckets'))
axes=("heads", "relpos_buckets"),
)
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
# Instead of using a slow gather, we create a leading-dimension one-hot
......@@ -855,9 +895,8 @@ class RelativePositionBiases(nn.Module):
values = lax.dot_general(
relative_attention_bias,
rp_bucket_one_hot,
(
((1,), (0,)), # rhs, lhs contracting dims
((), ()))) # no batched dims
(((1,), (0,)), ((), ())), # rhs, lhs contracting dims
) # no batched dims
# Add a singleton batch dimension.
# --> shape (1, num_heads, qlen, klen)
return values[jnp.newaxis, ...]
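A sketch of the one-hot-instead-of-gather trick used above: contracting the bias table against a one-hot encoding of the bucket ids matches plain fancy indexing. All shapes here are illustrative.

import jax.numpy as jnp
from jax import lax, nn as jax_nn

num_heads, num_buckets, qlen, klen = 2, 8, 3, 3
table = jnp.arange(num_heads * num_buckets, dtype=jnp.float32).reshape(num_heads, num_buckets)
rp_bucket = jnp.array([[0, 1, 2], [1, 0, 1], [2, 1, 0]])  # (qlen, klen)

gathered = table[:, rp_bucket]  # gather formulation, (num_heads, qlen, klen)

rp_bucket_one_hot = jax_nn.one_hot(rp_bucket, num_buckets, axis=0)  # (num_buckets, qlen, klen)
values = lax.dot_general(table, rp_bucket_one_hot, (((1,), (0,)), ((), ())))
print(jnp.allclose(gathered, values))  # True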
......@@ -865,6 +904,7 @@ class RelativePositionBiases(nn.Module):
class EncoderLayer(nn.Module):
"""Transformer encoder layer."""
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
num_attention_heads: int = 8
......@@ -880,17 +920,17 @@ class EncoderLayer(nn.Module):
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
mlp_activations: Sequence[str] = ('relu',)
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
output_layernorm: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
......@@ -903,20 +943,21 @@ class EncoderLayer(nn.Module):
@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
del self.self_attn_mask_type # dummy, just align to TE's impl
del self.self_attn_mask_type # dummy, just align to TE's impl
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
if self.enable_relative_embedding:
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
rel_emb = RelativePositionBiases(
num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
name="relpos_bias",
)
else:
rel_emb = self.relative_embedding
encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
......@@ -928,11 +969,13 @@ class EncoderLayer(nn.Module):
if not self.output_layernorm:
# Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_attention_layer_norm")(inputs)
x = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_attention_layer_norm",
)(inputs)
if self.apply_residual_connection_post_layernorm:
residual = x
......@@ -940,39 +983,41 @@ class EncoderLayer(nn.Module):
x = inputs
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias,
name='attention')(x,
x,
encoder_mask,
encoder_bias,
deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic)
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias,
name="attention",
)(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
x, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
x, deterministic=deterministic
)
x = x + residual
# MLP block.
residual = x
y = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_mlp_layer_norm')(x)
y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_mlp_layer_norm",
)(x)
if self.apply_residual_connection_post_layernorm:
residual = y
......@@ -987,27 +1032,32 @@ class EncoderLayer(nn.Module):
use_bias=self.use_bias,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
name='mlp',
name="mlp",
)(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
y, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
y = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(y, deterministic=deterministic)
y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
y, deterministic=deterministic
)
y = y + residual
if self.output_layernorm:
y = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layernorm")(y)
y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layernorm",
)(y)
return y
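Condensed pseudo-Python of the residual wiring in the encoder layer above (pre-layernorm path, ignoring dropout, drop path, and the post-layernorm residual option); attn, mlp, and the two layernorms are stand-ins for the modules used above.

def encoder_layer(inputs, attn, mlp, ln_attn, ln_mlp, mask, bias):
    x = ln_attn(inputs)                  # pre_attention_layer_norm
    x = inputs + attn(x, x, mask, bias)  # self-attention + residual
    y = ln_mlp(x)                        # pre_mlp_layer_norm
    return x + mlp(y)                    # MLP + residual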
class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder."""
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
num_attention_heads: int = 8
......@@ -1023,17 +1073,17 @@ class DecoderLayer(nn.Module):
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
mlp_activations: Sequence[str] = ('relu',)
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
layernorm_type: str = 'layernorm'
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
......@@ -1045,15 +1095,17 @@ class DecoderLayer(nn.Module):
super().__post_init__()
@nn.compact
def __call__(self,
inputs,
encoded,
decoder_mask=None,
encoder_decoder_mask=None,
deterministic=False,
decode=False,
max_decode_length=None):
del self.self_attn_mask_type # dummy, just align to TE's impl
def __call__(
self,
inputs,
encoded,
decoder_mask=None,
encoder_decoder_mask=None,
deterministic=False,
decode=False,
max_decode_length=None,
):
del self.self_attn_mask_type # dummy, just align to TE's impl
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
......@@ -1061,13 +1113,14 @@ class DecoderLayer(nn.Module):
if self.enable_relative_embedding:
l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
rel_emb = RelativePositionBiases(
num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
name="relpos_bias",
)
else:
rel_emb = self.relative_embedding
decoder_bias = rel_emb(l, l, False)
......@@ -1079,11 +1132,13 @@ class DecoderLayer(nn.Module):
if not self.output_layernorm:
# Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_self_attention_layer_norm")(inputs)
x = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_self_attention_layer_norm",
)(inputs)
if self.apply_residual_connection_post_layernorm:
residual = x
......@@ -1091,71 +1146,74 @@ class DecoderLayer(nn.Module):
x = inputs
# Self-attention block
x = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
name='self_attention')(x,
x,
decoder_mask,
decoder_bias,
deterministic=deterministic,
decode=decode)
x = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic)
x = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
name="self_attention",
)(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
x, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
x, deterministic=deterministic
)
x = x + residual
# Encoder-Decoder block.
residual = x
y = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_cross_attention_layer_norm')(x)
y = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_cross_attention_layer_norm",
)(x)
if self.apply_residual_connection_post_layernorm:
residual = y
y = MultiHeadAttention(num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
name='encoder_decoder_attention')(y,
encoded,
encoder_decoder_mask,
deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic)
y = MultiHeadAttention(
num_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
name="encoder_decoder_attention",
)(y, encoded, encoder_decoder_mask, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
y, deterministic=deterministic
)
y = y + residual
# MLP block.
residual = y
z = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_mlp_layer_norm')(y)
z = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_mlp_layer_norm",
)(y)
if self.apply_residual_connection_post_layernorm:
residual = z
z = MlpBlock(
......@@ -1167,22 +1225,26 @@ class DecoderLayer(nn.Module):
use_bias=self.use_bias,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
name='mlp',
name="mlp",
)(z, deterministic=deterministic)
z = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(z, deterministic=deterministic)
z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
z, deterministic=deterministic
)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
z = nn.Dropout(rate=self.drop_path,
broadcast_dims=drop_path_shape)(z, deterministic=deterministic)
z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
z, deterministic=deterministic
)
z = z + residual
if self.output_layernorm:
z = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layernorm")(z)
z = LayerNorm(
layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layernorm",
)(z)
return z
......@@ -1261,15 +1323,18 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)
for (expected_path, expected_value), (actual_path,
actual_value) in zip(flatten_expected, flatten_actual):
for (expected_path, expected_value), (actual_path, actual_value) in zip(
flatten_expected, flatten_actual
):
assert expected_path == actual_path
key_str = jax.tree_util.keystr(expected_path)
assert_allclose(expected_value,
actual_value,
rtol=rtol,
atol=atol,
err_msg=f'Value of expected{key_str} and actual{key_str} is not close')
assert_allclose(
expected_value,
actual_value,
rtol=rtol,
atol=atol,
err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
)
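A hypothetical usage note for the helper above: compare two pytrees leaf by leaf, with the failing key path reported in the error message.

import jax.numpy as jnp

expected = {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}
actual = {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}
assert_tree_like_allclose(expected, actual)  # passes; a mismatch raises with the offending path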
def dtype_tols(
......@@ -1323,7 +1388,7 @@ def dtype_tols(
)
def sync_params_values(dst, src, transformations, sep='/'):
def sync_params_values(dst, src, transformations, sep="/"):
"""
This function will reconstruct a tree with dst's tree_def/shape and src's value.
transformations is a map that records the key mappings between dst and src.
......