Unverified commit 2045a426 authored by Reese Wang, committed by GitHub

[JAX] Enhance JAX unit tests (#796)



* Add layernorm_fp8_dot unit test
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update the softmax primitives support conditions
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add tests for the softmax primitives
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Round 1 refactor of test_layer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Split the dropout arguments of the reference code and add hidden/intermediate dropout elementwise comparisons
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add dropout_broadcast_dim and self_attn_mask tests and clean up some code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Abstract the test layer and fix a RoPE reference code diff
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add bias tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add epsilon and float32 tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add relpos_bias and attention dropout tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Loosen the atol
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move common fixtures to conftest.py
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add docstring for test_layer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add docstring for test_layer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix conflicts in test_layer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Avoid leaving bias parameters in the graph when use_bias=False
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 6459fd85
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""conftest for tests/jax"""
import jax
import pytest
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
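For reference, because the fixture above is autouse and function-scoped, its post-`yield` teardown runs after every test under tests/jax, deleting any still-live device buffers so one test's leftovers cannot inflate the memory footprint of the next. A minimal standalone sketch of the mechanism it relies on (illustrative only):

import jax
import jax.numpy as jnp

x = jnp.ones((4, 4))              # allocates a live device buffer
print(len(jax.live_arrays()))     # `x` is among the live arrays
for arr in jax.live_arrays():     # same loop the fixture teardown performs
    arr.delete()                  # frees the underlying device memory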
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
from contextlib import nullcontext
import functools
import operator
from typing import Callable, Sequence, Union
@@ -10,7 +11,6 @@ import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn
@@ -18,7 +18,7 @@ from utils import assert_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp
@@ -45,16 +45,6 @@ def _convert_to_activation_function(fn_or_string):
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class TestFP8Dot:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@@ -416,88 +406,150 @@ class TestActivationLuFP8(TestActivationLu):
dtype=FP8Helper.BWD_DTYPE)
class TestRMSNorm:
@pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
def test_forward_backward(self, n, hidden, dtype):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1)
scale = jnp.asarray(scale, dtype)
epsilon = 1e-6
def reference_rmsnorm(x, scale):
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * lax.rsqrt(mean2 + epsilon), dtype)
return y * scale
jitted_primitive = jit(
value_and_grad(lambda x, scale: jnp.mean(layernorm(x, scale, None, "rmsnorm")), (0, 1)))
jitted_reference = jit(
value_and_grad(lambda x, scale: jnp.mean(reference_rmsnorm(x, scale)), (0, 1)))
primitive_out, (primitive_dx, primitive_dgamma) = jitted_primitive(x, scale)
reference_out, (reference_dx, reference_dgamma) = jitted_reference(x, scale)
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
class TestNorm:
"""
Test transformer_engine.jax.layernorm APIs
"""
class TestLayerNorm:
def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
"""
JAX native layernorm implementations
- bias is not None: layernorm
- bias is None: rmsnorm
"""
x_ = jnp.asarray(x, jnp.float32)
if bias is None:
mean = 0.
else:
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
scale += 1.
if bias is None:
bias = 0.
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
@pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
scale_range = (-1, 1) if zero_centered_gamma else (0, 2)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range)
scale = jnp.asarray(scale, dtype)
bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
bias = jnp.asarray(bias, dtype)
epsilon = 1e-6
def reference_layernorm(x, scale, bias, zero_centered_gamma, eps):
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
# Align TE implementation
if zero_centered_gamma:
return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype)
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
def compute_loss(x):
# Higher precision to compute the loss
x_ = x.astype(jnp.float32)
return jnp.mean(jnp.square(x_)).astype(x.dtype)
jitted_primitive = jit(
value_and_grad(
lambda x, scale, bias: compute_loss(
layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)),
(0, 1, 2)))
jitted_reference = jit(
value_and_grad(
lambda x, scale, bias: compute_loss(
reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2)))
primitive_out, (primitive_dx, primitive_dgamma,
primitive_dbeta) = jitted_primitive(x, scale, bias)
reference_out, (reference_dx, reference_dgamma,
reference_dbeta) = jitted_reference(x, scale, bias)
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
@pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon,
dtype):
"""
Test transformer_engine.jax.layernorm.layernorm
"""
expect_assert = False
if ln_type == 'rmsnorm' and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True
with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
) if expect_assert else nullcontext():
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
gamma = jnp.asarray(gamma, dtype)
if ln_type == 'layernorm':
beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
beta = jnp.asarray(beta, dtype)
else:
beta = None
def compute_loss(x):
# Higher precision to compute the loss
x_ = x.astype(jnp.float32)
return jnp.mean(jnp.square(x_)).astype(x.dtype)
jitted_primitive = jit(
value_and_grad(
lambda x, gamma, beta: compute_loss(
layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)),
(0, 1, 2)))
jitted_reference = jit(
value_and_grad(
lambda x, gamma, beta: compute_loss(
self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)),
(0, 1, 2)))
primitive_out, (primitive_dx, primitive_dgamma,
primitive_dbeta) = jitted_primitive(x, gamma, beta)
reference_out, (reference_dx, reference_dgamma,
reference_dbeta) = jitted_reference(x, gamma, beta)
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
if beta is not None:
assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
@pytest.mark.parametrize('zero_centered_gamma', [True, False])
@pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
"""
Test transformer_engine.jax.layernorm.layernorm_fp8_dot
"""
expect_assert = False
if ln_type == 'rmsnorm' and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True
with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
) if expect_assert else nullcontext():
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
if ln_type == 'layernorm':
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
else:
beta = None
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
def primitive_func(x, y, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type,
zero_centered_gamma)
return jnp.mean(primitive_out)
def ref_func(x, y, gamma, beta, zero_centered_gamma):
x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
return jnp.mean(jnp.dot(x, y))
value_n_grad_primitive_func = value_and_grad(primitive_func, range(8))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad,
ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad,
primitive_beta_grad, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, b, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
if beta is not None:
assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)
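For context on the three-iteration loop above: the FP8 metadata (amax history, scale, scale_inv) returned alongside the gradients is fed back into the next call because delayed scaling derives quantization scales from previously observed amax values, so the scales only stabilize after a few steps. A rough standalone sketch of that idea; the exact update recipe inside FP8Helper may differ, and `update_scale` is a hypothetical helper, not a Transformer Engine API:

import jax.numpy as jnp

def update_scale(amax_history, fp8_max, margin=0.0):
    # Hypothetical delayed-scaling rule: pick a scale so that the largest
    # observed value maps just below the FP8 representable maximum
    # (448 for the E4M3 format).
    amax = jnp.max(amax_history)
    scale = jnp.where(amax > 0, (fp8_max / amax) / (2.0 ** margin), 1.0)
    return scale, 1.0 / scale

amax_history = jnp.array([0.0, 3.2, 2.5, 0.0])          # rolling window of observed amax
scale, scale_inv = update_scale(amax_history, fp8_max=448.0)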
@@ -27,16 +27,14 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
@pytest.fixture(autouse=True, scope='module')
def init():
"""
Clear all live arrays to keep the resource clean
WAR for CUDA uninitialized error
"""
# Calling custom calls before JAX may cause a CUDA uninitialized error
_ = jnp.zeros(0)
yield
for arr in jax.live_arrays():
arr.delete()
def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Test transformer_engine.jax.flax.TransformerLayer"""
import os
from functools import partial
from typing import Dict
import flax
import jax
import jax.numpy as jnp
import pytest
from utils import assert_allclose
from utils import assert_allclose, assert_tree_like_allclose, sync_params_values
from utils import DecoderLayer as RefDecoderLayer
from utils import EncoderLayer as RefEncoderLayer
@@ -21,68 +22,18 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available()
@pytest.fixture(autouse=True, scope='module')
@pytest.fixture(autouse=True, scope='function')
def enable_fused_attn():
"""
Enable fused attention
"""
"""Enable fused attention"""
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
def loss_fn(diff_xs, no_diff_xs, params, others, model, rngs):
output = model.apply({"params": params, **others}, *diff_xs, *no_diff_xs, rngs=rngs)
return jnp.mean(output)
def generate_test_rngs():
data_rng = jax.random.PRNGKey(0)
init_rng = {'params': jax.random.PRNGKey(1), 'dropout': jax.random.PRNGKey(2)}
apply_rng = {'dropout': jax.random.PRNGKey(3)}
return data_rng, init_rng, apply_rng
def generate_layer(layer_cls, init_rng, diff_inputs, no_diff_inputs):
layer = layer_cls()
variables = layer.init(init_rng, *diff_inputs, *no_diff_inputs)
others, params = flax.core.pop(variables, 'params')
del variables
return layer, params, others
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
# To be compatible with both Flax>=0.7.1 or <0.7.1
# since Flax 0.7.1 removed FrozenDict.
ref_fd = flax.core.unfreeze(ref_fd)
test_fd = flax.core.unfreeze(test_fd)
for key in ref_fd:
assert key in test_fd, \
f"{key} not found in test dict {test_fd}"
assert isinstance(test_fd[key], type(ref_fd[key])), \
f"The data type is not match between ref and test " \
f"dict on {key=}"
if isinstance(ref_fd[key], dict):
compare_dict(ref_fd[key], test_fd[key], rtol, atol)
else:
assert_allclose(ref_fd[key],
test_fd[key],
rtol=rtol,
atol=atol,
err_msg=f"{key=} is not close")
DATA_SHAPE = [(32, 128, 1024), (32, 512, 1024)] # (batch, seqlen, emb_dim)
DATA_SHAPE = [ # (batch, seqlen, emb_dim)
pytest.param((32, 128, 1024), id='32-128-1024'),
pytest.param((32, 512, 1024), id='32-512-1024'),
]
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@@ -90,31 +41,42 @@ _KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
_KEY_OF_DROP_PATH = "drop_path"
_KEY_OF_FUSE_QKV_PARAMS = "fuse_qkv_params"
_KEY_OF_DROPOUT_RATE = "dropout_rate"
_KEY_OF_HIDDEN_DROPOUT = "hidden_dropout"
_KEY_OF_ATTENTION_DROPOUT = "attention_dropout"
_KEY_OF_INTERMEDIATE_DROPOUT = "intermediate_dropout"
_KEY_OF_HIDDEN_DROPOUT_DIMS = "hidden_dropout_dims"
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS = "intermediate_dropout_dims"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi"
_KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
_KEY_OF_LAYERNORM_TYPE = "layernorm_type"
_KEY_OF_LAYERNORM_EPS = "layernorm_epsilon"
_KEY_OF_ZERO_CENTERED_GAMMA = "zero_centered_gamma"
_KEY_OF_TRANSPOSE_BS = "transpose_batch_sequence"
_KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
_KEY_OF_NUM_HEADS = "num_attention_heads"
_KEY_OF_NUM_GQA_GROUPS = "num_gqa_groups"
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
_KEY_OF_SELF_ATTN_BIAS_TYPE = "self_attn_bias_type"
_KEY_OF_SELF_ATTN_MASK_TYPE = "self_attn_mask_type"
_KEY_OF_FLOAT32_ATTENTION_LOGITS = "float32_attention_logits"
_KEY_OF_USE_BIAS = "use_bias"
_KEY_OF_RELATIVE_EMBEDDING = "enable_relative_embedding"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_DROPOUT_RATE: 0,
_KEY_OF_HIDDEN_DROPOUT: 0,
_KEY_OF_ATTENTION_DROPOUT: 0,
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
}
ATTRS = [{
ATTRS = [{}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
}, {
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
}, {
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_ZERO_CENTERED_GAMMA: True
_KEY_OF_ZERO_CENTERED_GAMMA: True,
_KEY_OF_LAYERNORM_EPS: 1e-2,
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_RESIDUAL_POST_LAYERNORM: True
@@ -133,518 +95,323 @@ ATTRS = [{
_KEY_OF_FUSE_QKV_PARAMS: False
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.8,
_KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
_KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
_KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('gelu',)),
_KEY_OF_FUSE_MLP_WI: True
_KEY_OF_MLP_ACTIVATIONS: ('gelu',),
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.8,
_KEY_OF_HIDDEN_DROPOUT: 0.8,
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
_KEY_OF_FUSE_MLP_WI: True
}, {
_KEY_OF_NUM_HEADS: 8,
_KEY_OF_NUM_GQA_GROUPS: 4,
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_SCALE_ATTN_LOGITS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_MLP_ACTIVATIONS: (('silu',)),
_KEY_OF_FUSE_MLP_WI: True
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_NUM_GQA_GROUPS: 1,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive"
_KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive"
_KEY_OF_ROPE_GROUP_METHOD: "consecutive",
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_NUM_GQA_GROUPS: 2,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate"
_KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True,
_KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate"
_KEY_OF_ROPE_GROUP_METHOD: "alternate",
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_HIDDEN_DROPOUT: 0.3,
_KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
_KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
}, {
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
_KEY_OF_USE_BIAS: True,
}, {
_KEY_OF_RELATIVE_EMBEDDING: False,
_KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
}, {
_KEY_OF_ATTENTION_DROPOUT: 0.3,
}]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
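The list comprehension above is plain dict unpacking: every test case starts from BASE_ATTRS and the per-case dict overrides only the keys it mentions. A tiny self-contained example of the pattern:

base = {'hidden_dropout': 0, 'layernorm_type': 'layernorm'}
override = {'layernorm_type': 'rmsnorm'}
merged = {**base, **override}                 # later unpacking wins on key clashes
assert merged == {'hidden_dropout': 0, 'layernorm_type': 'rmsnorm'}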
class TestEncoderLayer:
@staticmethod
def sync_params(ref, target):
unfreeze_target = flax.core.unfreeze(target)
unfreeze_attn_scope = unfreeze_target['attention']
ref_attn_scope = ref['attention']
for key in ref_attn_scope.keys():
unfreeze_attn_scope[key]['kernel'] = \
ref_attn_scope[key]['kernel'].reshape(unfreeze_attn_scope[key]['kernel'].shape)
unfreeze_target['mlp']['wi_kernel'] = \
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \
ref['mlp']['wo']['kernel']
return ref, unfreeze_target
def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
batch, seqlen = data_shape[:2]
if transpose_batch_sequence:
data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
sequence_dim = 0 if transpose_batch_sequence else 1
class BaseRunner:
"""Base runner to define forward and backward tests"""
layer_type: TransformerLayerType = None
reference_layer: flax.linen.Module = None
transformations: Dict[str, str] = None
data_rng, init_rng, apply_rng = generate_test_rngs()
inputs = (jax.random.normal(data_rng, data_shape, dtype),)
def __init__(self, attrs):
self.attrs = attrs
self._generate_test_rngs()
# Disable fused attention when attention dropout is enabled, because the dropout implementations differ
if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'):
os.environ['NVTE_FUSED_ATTN'] = "0"
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
ref_masks = (1 - padded_mask,)
test_masks = (None, padded_mask) # The second arg of Transformer is encoded tokens.
te_layer_attrs = {}
for k, v in attrs.items():
if k == 'dropout_rate':
te_layer_attrs['attention_dropout'] = v
te_layer_attrs['hidden_dropout'] = v
te_layer_attrs['intermediate_dropout'] = v
elif k == 'fuse_mlp_wi':
continue
else:
te_layer_attrs[k] = v
ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
ref_masks)
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
del data_rng, init_rng, apply_rng
def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
batch, seqlen = data_shape[:2]
if transpose_batch_sequence:
data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
sequence_dim = 0 if transpose_batch_sequence else 1
def _generate_test_rngs(self):
root_rng = jax.random.PRNGKey(0)
params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng}
self.apply_rng = {'dropout': apply_dropout_rng}
data_rng, init_rng, apply_rng = generate_test_rngs()
inputs = (jax.random.normal(data_rng, data_shape, dtype),)
def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
layer = layer_cls()
variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
others, params = flax.core.pop(variables, 'params')
del variables
return layer, params, others
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
ref_masks = (1 - padded_mask,)
test_masks = (None, padded_mask) # The second arg of Transformer is encoded tokens.
te_layer_attrs = {}
for k, v in attrs.items():
if k == 'dropout_rate':
te_layer_attrs['attention_dropout'] = v
te_layer_attrs['hidden_dropout'] = v
te_layer_attrs['intermediate_dropout'] = v
elif k == 'fuse_mlp_wi':
continue
else:
te_layer_attrs[k] = v
ref_layer_cls = partial(RefEncoderLayer, dtype=dtype, **attrs)
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.ENCODER,
self_attn_mask_type='padding',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
ref_masks)
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestEncoderLayer.sync_params(ref_params, test_params)
def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
variables = {'params': params, **others}
output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
def _sync_params(self, ref, target):
"""Copy the reference params to target"""
target = sync_params_values(target, ref, self.transformations)
return ref, target
def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08):
"""Test only the forward"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
ref_params, test_params = self._sync_params(ref_params, test_params)
ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer)
test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08):
"""Test forward and backward through value_and_grad()"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
ref_params, test_params = self._sync_params(ref_params, test_params)
if FP8Helper.is_fp8_enabled():
for _ in range(4):
_, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
has_aux=False)(inputs, test_masks, test_params,
test_others, test_layer, apply_rng)
_, tmp_grad = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
inputs,
test_masks,
test_params,
test_others,
test_layer,
)
_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
test_others = FP8Helper.update_fp8_metas(test_others)
del tmp_grad, fp8_meta_grad
grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False)
ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer,
apply_rng)
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng)
def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
num_heads == num_gqa_groups
attn_name = 'attention'
unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
if "output_layernorm" not in attrs:
unfreeze_test_wgrad['pre_attention_layer_norm'] = {}
pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
unfreeze_test_wgrad['pre_attention_layer_norm']['scale'] = \
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]:
unfreeze_test_wgrad['pre_attention_layer_norm']['ln_bias'] = \
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
for key in unfreeze_test_wgrad[attn_name].keys():
unfreeze_test_wgrad[attn_name][key]['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[attn_name][key]['kernel'],
(unfreeze_test_wgrad[attn_name][key]['kernel'].shape[0], -1))
unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
unfreeze_test_wgrad['mlp']['scale']
del unfreeze_test_wgrad['mlp']['scale']
if 'ln_bias' in unfreeze_test_wgrad['mlp']:
unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \
unfreeze_test_wgrad['mlp']['ln_bias']
del unfreeze_test_wgrad['mlp']['ln_bias']
unfreeze_test_wgrad['mlp']['wi'] = {}
unfreeze_test_wgrad['mlp']['wi']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'],
(unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1))
del unfreeze_test_wgrad['mlp']['wi_kernel']
unfreeze_test_wgrad['mlp']['wo'] = {}
unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
unfreeze_test_wgrad['mlp']['wo_kernel']
del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
atol=atol) # wgrad
del data_rng, init_rng, apply_rng
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward(self, data_shape, dtype, attrs):
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
FP8Helper.initialize(fp8_format=fp8_format)
self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03)
FP8Helper.finalize()
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs):
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
FP8Helper.initialize(fp8_format=fp8_format)
self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=1e-03)
FP8Helper.finalize()
class TestDecoderLayer:
@staticmethod
def sync_params(ref, target):
unfreeze_target = flax.core.unfreeze(target)
for scope in ['self_attention', 'encoder_decoder_attention']:
unfreeze_scope = unfreeze_target[scope]
ref_scope = ref[scope]
for key in unfreeze_scope.keys():
unfreeze_scope[key]['kernel'] = \
ref_scope[key]['kernel'].reshape(unfreeze_scope[key]['kernel'].shape)
unfreeze_target['mlp']['wi_kernel'] = \
jnp.reshape(ref['mlp']['wi']['kernel'], unfreeze_target['mlp']['wi_kernel'].shape)
unfreeze_target['mlp']['wo_kernel'] = \
ref['mlp']['wo']['kernel']
return ref, unfreeze_target
def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others,
ref_layer)
test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others,
test_layer)
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol)
_, restructured_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads)
assert_tree_like_allclose(restructured_ref_wgrads, test_wgrads, rtol=rtol, atol=atol)
class EncoderRunner(BaseRunner):
"""Encoder runner implementations"""
layer_type = TransformerLayerType.ENCODER
reference_layer = RefEncoderLayer
transformations = {
'attention/qkv/scale': 'pre_attention_layer_norm/scale',
'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias',
'attention/query/scale': 'pre_attention_layer_norm/scale',
'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias',
'mlp/wi_kernel': 'mlp/wi/kernel',
'mlp/wi_bias': 'mlp/wi/bias',
'mlp/wo_kernel': 'mlp/wo/kernel',
'mlp/wo_bias': 'mlp/wo/bias',
'mlp/scale': 'pre_mlp_layer_norm/scale',
'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
}
def generate_inputs(self, data_shape, dtype):
"""
Return inputs, (ref_masks, test_masks)
"""
transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
batch, seqlen = data_shape[:2]
if transpose_batch_sequence:
data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
sequence_dim = 0 if transpose_batch_sequence else 1
data_rng, init_rng, apply_rng = generate_test_rngs()
inputs = (jax.random.normal(data_rng, data_shape,
dtype), jax.random.normal(data_rng, data_shape, dtype))
data_rng = jax.random.PRNGKey(2024)
inputs = (jax.random.normal(data_rng, data_shape, dtype),)
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
ref_masks = (1 - causal_mask, 1 - padded_mask)
test_masks = (causal_mask, padded_mask)
te_layer_attrs = {}
for k, v in attrs.items():
if k == 'dropout_rate':
te_layer_attrs['attention_dropout'] = v
te_layer_attrs['hidden_dropout'] = v
te_layer_attrs['intermediate_dropout'] = v
elif k == 'fuse_mlp_wi':
continue
else:
te_layer_attrs[k] = v
ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
self_attn_mask_type='padding_causal',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
ref_masks)
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)
ref_out = loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer, apply_rng)
test_out = loss_fn(inputs, test_masks, test_params, test_others, test_layer, apply_rng)
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
del data_rng, init_rng, apply_rng
def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
transpose_batch_sequence = _KEY_OF_TRANSPOSE_BS in attrs and attrs[_KEY_OF_TRANSPOSE_BS]
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['causal', 'padding_causal']:
mask = causal_mask
else:
mask = padded_mask
ref_masks = (1 - mask,)
test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
return inputs, (ref_masks, test_masks)
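`sync_params_values` comes from utils and is not shown in this diff; judging from the removed `sync_params` helpers and the `transformations` maps, it presumably copies each mapped reference leaf into the test parameter tree, reshaping to the target shape and skipping paths absent in a given configuration. A hypothetical sketch of that behavior (names and details are assumptions, not the actual utils implementation):

def sync_params_values_sketch(target, ref, transformations):
    # Hypothetical: walk 'target/path' -> 'ref/path' entries over nested dicts
    # and copy the reference value into the target, reshaped to the target leaf.
    def get(tree, path):
        for key in path.split('/'):
            tree = tree[key]
        return tree

    for target_path, ref_path in transformations.items():
        try:
            value = get(ref, ref_path)
        except KeyError:
            continue                      # this configuration has no such parameter
        keys = target_path.split('/')
        node = target
        for key in keys[:-1]:
            node = node[key]
        node[keys[-1]] = value.reshape(node[keys[-1]].shape)
    return target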
class DecoderRunner(BaseRunner):
"""
Decoder runner implementations
"""
layer_type = TransformerLayerType.DECODER
reference_layer = RefDecoderLayer
transformations = {
'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale',
'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale',
'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale',
'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
'self_attention/query/scale': 'pre_self_attention_layer_norm/scale',
'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
'mlp/wi_kernel': 'mlp/wi/kernel',
'mlp/wi_bias': 'mlp/wi/bias',
'mlp/wo_kernel': 'mlp/wo/kernel',
'mlp/wo_bias': 'mlp/wo/bias',
'mlp/scale': 'pre_mlp_layer_norm/scale',
'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
}
def generate_inputs(self, data_shape, dtype):
"""
Return inputs, (ref_masks, test_masks)
"""
transpose_batch_sequence = self.attrs[_KEY_OF_TRANSPOSE_BS]
batch, seqlen = data_shape[:2]
if transpose_batch_sequence:
data_shape = (data_shape[1], data_shape[0], *data_shape[2:])
sequence_dim = 0 if transpose_batch_sequence else 1
data_rng, init_rng, apply_rng = generate_test_rngs()
inputs = (jax.random.normal(data_rng, data_shape,
dtype), jax.random.normal(data_rng, data_shape, dtype))
data_rng = jax.random.PRNGKey(0)
data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
inputs = (jax.random.normal(data_rng_0, data_shape,
dtype), jax.random.normal(data_rng_1, data_shape, dtype))
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
ref_masks = (1 - causal_mask, 1 - padded_mask)
test_masks = (causal_mask, padded_mask)
te_layer_attrs = {}
for k, v in attrs.items():
if k == 'dropout_rate':
te_layer_attrs['attention_dropout'] = v
te_layer_attrs['hidden_dropout'] = v
te_layer_attrs['intermediate_dropout'] = v
elif k == 'fuse_mlp_wi':
continue
else:
te_layer_attrs[k] = v
ref_layer_cls = partial(RefDecoderLayer, dtype=dtype, **attrs)
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
self_attn_mask_type='padding_causal',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
ref_masks)
test_layer, test_params, test_others = generate_layer(layer_cls, init_rng, inputs,
test_masks)
ref_params, test_params = TestDecoderLayer.sync_params(ref_params, test_params)
if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['causal', 'padding_causal']:
self_mask = causal_mask
else:
self_mask = padded_mask
if FP8Helper.is_fp8_enabled():
for _ in range(4):
_, tmp_grad = jax.value_and_grad(loss_fn, argnums=(3,),
has_aux=False)(inputs, test_masks, test_params,
test_others, test_layer, apply_rng)
_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections(
{FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
test_others = FP8Helper.update_fp8_metas(test_others)
del tmp_grad, fp8_meta_grad
ref_masks = (1 - self_mask, 1 - padded_mask)
test_masks = (self_mask, padded_mask)
return inputs, (ref_masks, test_masks)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', ATTRS)
class BaseTester():
"""
Pytest interface to invoke the runner
"""
runner = BaseRunner
grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 2), has_aux=False)
ref_out, ref_grads = grad_fn(inputs, ref_masks, ref_params, ref_others, ref_layer,
apply_rng)
test_out, test_grads = grad_fn(inputs, test_masks, test_params, test_others, test_layer,
apply_rng)
def reorganize_test_wgrad(test_wgrad, attrs):
num_heads = attrs.get(_KEY_OF_NUM_HEADS)
num_gqa_groups = attrs.get(_KEY_OF_NUM_GQA_GROUPS, num_heads)
fuse_qkv = attrs.get(_KEY_OF_FUSE_QKV_PARAMS, True) and \
num_heads == num_gqa_groups
unfreeze_test_wgrad = flax.core.unfreeze(test_wgrad)
if "output_layernorm" not in attrs:
attn_name = 'self_attention'
unfreeze_test_wgrad['pre_self_attention_layer_norm'] = {}
pre_attn_layer_key = 'qkv' if fuse_qkv else 'query'
unfreeze_test_wgrad['pre_self_attention_layer_norm']['scale'] = \
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['scale']
if 'ln_bias' in unfreeze_test_wgrad[attn_name][pre_attn_layer_key]:
unfreeze_test_wgrad['pre_self_attention_layer_norm']['ln_bias'] = \
unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
del unfreeze_test_wgrad[attn_name][pre_attn_layer_key]['ln_bias']
for scope in ['self_attention', 'encoder_decoder_attention']:
for key in unfreeze_test_wgrad[scope].keys():
unfreeze_test_wgrad[scope][key]['kernel'] = \
jnp.reshape(unfreeze_test_wgrad[scope][key]['kernel'],
(unfreeze_test_wgrad[scope][key]['kernel'].shape[0], -1))
unfreeze_test_wgrad['pre_cross_attention_layer_norm'] = {}
unfreeze_test_wgrad['pre_cross_attention_layer_norm']['scale'] = \
unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale']
del unfreeze_test_wgrad['encoder_decoder_attention']['query']['scale']
if 'ln_bias' in unfreeze_test_wgrad['encoder_decoder_attention']['query']:
unfreeze_test_wgrad['pre_cross_attention_layer_norm']['ln_bias'] = \
unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias']
del unfreeze_test_wgrad['encoder_decoder_attention']['query']['ln_bias']
unfreeze_test_wgrad['pre_mlp_layer_norm'] = {}
unfreeze_test_wgrad['pre_mlp_layer_norm']['scale'] = \
unfreeze_test_wgrad['mlp']['scale']
del unfreeze_test_wgrad['mlp']['scale']
if 'ln_bias' in unfreeze_test_wgrad['mlp']:
unfreeze_test_wgrad['pre_mlp_layer_norm']['ln_bias'] = \
unfreeze_test_wgrad['mlp']['ln_bias']
del unfreeze_test_wgrad['mlp']['ln_bias']
unfreeze_test_wgrad['mlp']['wi'] = {}
unfreeze_test_wgrad['mlp']['wi']['kernel'] = \
jnp.reshape(unfreeze_test_wgrad['mlp']['wi_kernel'],
(unfreeze_test_wgrad['mlp']['wi_kernel'].shape[0], -1))
del unfreeze_test_wgrad['mlp']['wi_kernel']
unfreeze_test_wgrad['mlp']['wo'] = {}
unfreeze_test_wgrad['mlp']['wo']['kernel'] = \
unfreeze_test_wgrad['mlp']['wo_kernel']
del unfreeze_test_wgrad['mlp']['wo_kernel']
return unfreeze_test_wgrad
if attrs[_KEY_OF_DROPOUT_RATE] == 0.: # Skip elementwise checking for dropout
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_allclose(ref_grads[0][0], test_grads[0][0], rtol=rtol, atol=atol) # dgrad
compare_dict(ref_grads[1],
reorganize_test_wgrad(test_grads[1], attrs),
rtol=rtol,
atol=atol) # wgrad
del data_rng, init_rng, apply_rng
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward"""
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test forward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format)
self.forward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02)
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
FP8Helper.finalize()
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs):
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=3e-04)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('fp8_format', FP8_FORMATS)
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward_backward_with_fp8(self, data_shape, dtype, fp8_format, attrs):
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test backward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format)
self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-04, atol=3e-02)
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
FP8Helper.finalize()
class TestEncoderLayer(BaseTester):
"""
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
"""
runner = EncoderRunner
class TestDecoderLayer(BaseTester):
"""
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
"""
runner = DecoderRunner
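With the runner/tester split, TestEncoderLayer and TestDecoderLayer share every parametrization from BaseTester; a typical way to run just a subset is standard pytest keyword selection, for example (illustrative invocation only):

import pytest

# Run only the non-FP8 encoder-layer tests from tests/jax/test_layer.py.
pytest.main(['tests/jax/test_layer.py', '-k', 'TestEncoderLayer and not fp8', '-q'])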
@@ -56,16 +56,6 @@ def enable_fused_attn():
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, \
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the softmax primitives"""
from contextlib import nullcontext
from dataclasses import dataclass
from functools import wraps
import jax
import jax.numpy as jnp
import pytest
from jax import lax
from jax import nn
from jax import value_and_grad, jit
from jax.typing import DTypeLike
from utils import assert_allclose
from transformer_engine.jax.softmax import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
def catch_unsupported(method):
"""
Unsupported cases should raise an error instead of silently running incorrectly.
This helper checks that an unsupported case raises the expected AssertionError.
"""
@wraps(method)
def wrapper(self, *args, **kwargs):
if not self._is_support():
assertion_checker = pytest.raises(AssertionError)
else:
assertion_checker = nullcontext()
with assertion_checker:
return method(self, *args, **kwargs)
return wrapper
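The decorator above applies the same conditional-context pattern used by the layernorm tests: wrap the body in pytest.raises(AssertionError) when the configuration is unsupported, otherwise in a nullcontext. A minimal standalone form of the pattern:

from contextlib import nullcontext

import pytest

def run_case(supported):
    ctx = nullcontext() if supported else pytest.raises(AssertionError)
    with ctx:
        assert supported, "unsupported configuration"

run_case(True)     # runs normally
run_case(False)    # the AssertionError is expected and absorbed by pytest.raises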
@dataclass
class SoftmaxRunner:
"""
Softmax runner
"""
batch_size: int
max_seqlen_q: int
max_seqlen_kv: int
num_heads: int
scale_factor: float
softmax_type: SoftmaxType
dtype: DTypeLike
@staticmethod
def reference_softmax(logits, mask, scale_factor, **_):
"""
JAX softmax as the reference
"""
if mask is not None:
logits += lax.select(mask > 0,
jnp.full(mask.shape, -1e10).astype(logits.dtype),
jnp.full(mask.shape, 0.).astype(logits.dtype))
return nn.softmax(logits * scale_factor)
def _is_support(self):
return is_softmax_kernel_available(self.softmax_type, self.batch_size, self.num_heads,
self.max_seqlen_q, self.max_seqlen_kv, self.dtype)
def _setup_inputs(self):
key = jax.random.PRNGKey(0)
logits_key, mask_key = jax.random.split(key, 2)
logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.)
match self.softmax_type:
case SoftmaxType.SCALED:
self.mask = None
case SoftmaxType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
self.mask = (1. - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
@catch_unsupported
def test_forward(self):
"""
Test transformer_engine.jax.softmax.softmax fwd rule
"""
self._setup_inputs()
primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type)
reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
@catch_unsupported
def test_backward(self):
"""
Test transformer_engine.jax.softmax.softmax bwd rule
"""
self._setup_inputs()
def grad_func(func, *args, **kwargs):
fwd_out = func(*args, **kwargs)
return jnp.mean(fwd_out, dtype=jnp.float32).astype(self.dtype)
args = [self.logits, self.mask]
kwargs = {
'scale_factor': self.scale_factor,
'softmax_type': self.softmax_type,
}
# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,)))
jitted_reference = jit(
value_and_grad(
lambda logits, *args: grad_func(__class__.reference_softmax, logits, *args, **kwargs),
(0,)))
primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
reference_out, (reference_grad_logits,) = jitted_reference(*args)
assert_allclose(primitive_out, reference_out, dtype=self.dtype)
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
@pytest.mark.parametrize('b, s_q, s_kv, h', [
pytest.param(8, 16, 16, 16, id='8-16-16-16'),
pytest.param(8, 512, 512, 16, id='8-512-512-16'),
pytest.param(2, 8, 16384, 8, id='2-8-16384-8')
])
@pytest.mark.parametrize('scale_factor', [0.125])
@pytest.mark.parametrize('softmax_type', [
pytest.param(SoftmaxType.SCALED, id='SCALED'),
pytest.param(SoftmaxType.SCALED_MASKED, id='SCALED_MASKED'),
pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id='SCALED_UPPER_TRIANG_MASKED')
])
@pytest.mark.parametrize('dtype', [
pytest.param(jnp.bfloat16, id="BF16"),
pytest.param(jnp.float16, id="FP16"),
])
class TestSoftmax:
"""
Test transformer_engine.jax.softmax.softmax
"""
@staticmethod
def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
"""
Test forward with parameterized configs
"""
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_forward()
@staticmethod
def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
"""
Test backward with parameterized configs
"""
runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
runner.test_backward()
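Note the mask convention exercised here: in `reference_softmax` a mask value greater than zero marks a position to exclude (it gets -1e10 added before the softmax), matching the uint8 masks built in `_setup_inputs`. A tiny standalone illustration:

import jax.numpy as jnp
from jax import nn

logits = jnp.array([[1.0, 2.0, 3.0]])
mask = jnp.array([[0, 0, 1]], dtype=jnp.uint8)    # last position is masked out
masked = jnp.where(mask > 0, -1e10, logits)
print(nn.softmax(masked))                         # ~zero probability on the masked slot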
@@ -13,6 +13,7 @@ import jax.numpy as jnp
import numpy as np
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import combine_masks
from jax import lax, vmap
from jax import nn as jax_nn
from jax import random as jax_random
@@ -64,27 +65,6 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
"""Combine attention masks.
Args:
*masks: set of attention mask arguments to combine, some can be None.
dtype: final mask dtype
Returns:
Combined mask, reduced by logical and, returns None if no masks given.
"""
masks = [m for m in masks if m is not None]
if not masks:
return None
assert all(map(lambda x: x.ndim == masks[0].ndim,
masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
mask, *other_masks = masks
for other_mask in other_masks:
mask = jnp.logical_and(mask, other_mask)
return mask.astype(dtype)
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases.
@@ -105,96 +85,109 @@ def combine_biases(*masks: Optional[Array]):
return mask
def dot_product_attention(query: Array,
key: Array,
value: Array,
transpose_batch_sequence: bool,
bias: Optional[Array] = None,
dropout_rng: Optional[PRNGKey] = None,
dropout_rate: float = 0.,
deterministic: bool = False,
dtype: DType = jnp.float32,
float32_logits: bool = False):
class DotProductAttention(nn.Module):
transpose_batch_sequence: bool = True
scale_attn_logits: bool = True
dropout_rate: float = 0.
dtype: DType = jnp.float32
float32_logits: bool = False
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
This is the core function for applying attention based on
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
query and key and combines the values using the attention weights.
Args:
query: queries for calculating attention with shape of `[batch, q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch, kv_length,
num_gqa_groups, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch, kv_length,
num_gqa_groups, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch, num_heads, q_length, kv_length]` This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
dtype: the dtype of the computation (default: float32)
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
Args:
dropout_rate: dropout rate
dtype: the dtype of the computation (default: float32)
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
"""
Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
"""
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
batch_dim = 1 if transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
sequence_dim = 0 if transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
# Casting logits and softmax computation for float32 for model stability.
if float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
# `attn_weights`: [batch, num_heads, groups, q_length, kv_length]
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if transpose_batch_sequence:
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else:
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# Apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(dtype)
# Apply attention dropout.
if not deterministic and dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
# T5 broadcasts along the "length" dim, but unclear which one that
# corresponds to in positional dimensions here, assuming query dim.
dropout_shape = list(attn_weights.shape)
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = (keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype))
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`.
if transpose_batch_sequence:
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
@nn.compact
def __call__(self,
query: Array,
key: Array,
value: Array,
bias: Optional[Array] = None,
deterministic: bool = False):
"""
Args:
query: queries for calculating attention with shape of `[batch, q_length,
num_heads, qk_depth_per_head]`.
key: keys for calculating attention with shape of `[batch, kv_length,
num_gqa_groups, qk_depth_per_head]`.
value: values to be used in attention with shape of `[batch, kv_length,
num_gqa_groups, v_depth_per_head]`.
bias: bias for the attention weights. This should be broadcastable to the
shape `[batch, num_heads, q_length, kv_length]`. This can be used for
incorporating causal masks, padding masks, proximity bias, etc.
deterministic: bool, deterministic or not (to apply dropout)
Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
"""
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
batch_dim = 1 if self.transpose_batch_sequence else 0
assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
'q, k, v batch dims must match.')
sequence_dim = 0 if self.transpose_batch_sequence else 1
assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
if self.scale_attn_logits:
head_dim = query.shape[-1]
depth_scaling = jnp.sqrt(head_dim).astype(self.dtype)
query = query / depth_scaling
# Cast logits and softmax computation to float32 for model stability.
if self.float32_logits:
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
# `attn_weights`: [batch, num_gqa_groups, group_size, q_length, kv_length]
h_q, h_kv = query.shape[-2], key.shape[-2]
assert (h_q % h_kv == 0) and (h_q >= h_kv)
group_size = h_q // h_kv
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if self.transpose_batch_sequence:
attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
else:
attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
attn_weights_without_groups_shape = (b, h * g, q, k)
attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
# Apply attention bias: masking, dropout, proximity bias, etc.
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
# Normalize the attention weights across `kv_length` dimension.
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
# Apply attention dropout.
if not deterministic and self.dropout_rate > 0.:
keep_prob = 1.0 - self.dropout_rate
# T5 broadcasts along the "length" dim, but it is unclear which positional
# dimension that corresponds to here, so the query dim is assumed.
dropout_shape = list(attn_weights.shape)
dropout_rng = self.make_rng('dropout')
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = (keep.astype(attn_weights.dtype) /
jnp.asarray(keep_prob, dtype=self.dtype))
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
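# Illustrative sketch (not part of the reference layer above; all shapes are made-up toy
# values): how the grouped-query reshape used in DotProductAttention maps h_q query heads
# onto h_kv shared KV heads before the einsum, and back to the usual (b, h, q, k) layout.
def _grouped_query_attention_shapes_sketch():
    import jax.numpy as jnp
    b, q_len, kv_len, h_q, h_kv, d = 2, 4, 6, 8, 2, 16
    query = jnp.zeros((b, q_len, h_q, d))
    key = jnp.zeros((b, kv_len, h_kv, d))
    group_size = h_q // h_kv    # 4 query heads share each KV head
    grouped_query = query.reshape((b, q_len, h_kv, group_size, d))
    attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
    assert attn_weights.shape == (b, h_kv, group_size, q_len, kv_len)
    # Collapsing (h_kv, group_size) restores the standard DPA layout for bias/softmax/dropout.
    assert attn_weights.reshape((b, h_kv * group_size, q_len, kv_len)).shape == (b, h_q, q_len, kv_len)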
class DenseGeneral(nn.Module):
......@@ -253,8 +246,9 @@ class DenseGeneral(nn.Module):
bias = nn_partitioning.param_with_axes('bias',
self.bias_init,
self.features,
self.dtype,
jnp.float32,
axes=self.bias_axes)
bias = bias.astype(self.dtype)
else:
bias = None
......@@ -284,8 +278,10 @@ class MlpBlock(nn.Module):
activations: Sequence[Union[str, Callable]] = ('relu',)
kernel_init: Initializer = None
intermediate_dropout_rate: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
use_bias: bool = False
dtype: Any = jnp.float32
fuse_wi: bool = False
fuse_wi: bool = True
def __post_init__(self):
if self.kernel_init is None:
......@@ -306,6 +302,8 @@ class MlpBlock(nn.Module):
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('embed', 'mlp'),
use_bias=self.use_bias,
bias_axes=('mlp'),
name=dense_name)(inputs)
x = jnp.split(x, num_activations, axis=-1)
for idx, act_fn in enumerate(self.activations):
......@@ -318,16 +316,18 @@ class MlpBlock(nn.Module):
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('embed', 'mlp'),
use_bias=self.use_bias,
bias_axes=('mlp'),
name=dense_name)(inputs)
x = _convert_to_activation_function(act_fn)(x)
activations.append(x)
# Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations)
dropout_broadcast_dims = (0,) if self.transpose_batch_sequence else (1,)
# Apply dropout and final dense output projection.
x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=dropout_broadcast_dims)(
x, deterministic=deterministic) # Broadcast along length.
x = nn.Dropout(rate=self.intermediate_dropout_rate,
broadcast_dims=self.intermediate_dropout_dims)(
x, deterministic=deterministic) # Broadcast along length.
if self.transpose_batch_sequence:
x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp'))
else:
......@@ -336,6 +336,8 @@ class MlpBlock(nn.Module):
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_axes=('mlp', 'embed'),
use_bias=self.use_bias,
bias_axes=('embed'),
name='wo')(x)
return output
......@@ -369,7 +371,6 @@ def apply_rotary_pos_emb_consecutive(
min_timescale: int = 1,
max_timescale: int = 10000,
):
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
......@@ -429,6 +430,7 @@ class MultiHeadAttention(nn.Module):
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
fuse_qkv: bool = True
use_bias: bool = False
def __post_init__(self):
if self.kernel_init is None:
......@@ -478,12 +480,16 @@ class MultiHeadAttention(nn.Module):
axis=-1,
features=self.num_heads * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_axes=('joined_kv'),
dtype=self.dtype)
kv_projection = functools.partial(DenseGeneral,
axis=-1,
features=self.num_gqa_groups * self.head_dim,
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_axes=('joined_kv'),
dtype=self.dtype)
# NOTE: T5 does not explicitly rescale the attention logits by
......@@ -519,26 +525,27 @@ class MultiHeadAttention(nn.Module):
features=self.num_heads * self.head_dim * 3,
kernel_axes=('embed', 'joined_kv'),
kernel_init=qkv_init,
use_bias=self.use_bias,
bias_axes=('joined_kv'),
name='qkv',
dtype=self.dtype)(inputs_kv)
query, key, value = jnp.split(
qkv_proj, [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
axis=-1)
if self.scale_attn_logits:
query = query / depth_scaling
else:
query = q_projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
query = q_projection(kernel_init=query_init, name='query')(inputs_q)
kv_proj = DenseGeneral(axis=-1,
features=self.num_gqa_groups * self.head_dim * 2,
kernel_axes=('embed', 'joined_kv'),
kernel_init=self.kernel_init,
use_bias=self.use_bias,
bias_axes=('joined_kv'),
name='kv',
dtype=self.dtype)(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
else:
query = q_projection(kernel_init=query_init, name='query')( \
(inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q)
query = q_projection(kernel_init=query_init, name='query')(inputs_q)
key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
......@@ -546,15 +553,18 @@ class MultiHeadAttention(nn.Module):
batch_dim = 1 if self.transpose_batch_sequence else 0
seq_dim = 1 - batch_dim
position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
if self.rotary_pos_emb_group_method == 'alternate':
apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
else:
apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive
query = apply_rotary_pos_emb(query, position)
key = apply_rotary_pos_emb(key, position)
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
query = apply_rotary_pos_emb(query, q_position)
key = apply_rotary_pos_emb(key, k_position)
query = query.reshape((*query.shape[:2], self.num_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
......@@ -656,21 +666,16 @@ class MultiHeadAttention(nn.Module):
if bias is not None:
attention_bias = combine_biases(attention_bias, bias)
dropout_rng = None
if not deterministic and self.dropout_rate > 0.:
dropout_rng = self.make_rng('dropout')
# Apply attention.
x = dot_product_attention(query,
key,
value,
transpose_batch_sequence=self.transpose_batch_sequence,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout_rate,
deterministic=deterministic,
dtype=self.dtype,
float32_logits=self.float32_logits)
x = DotProductAttention(transpose_batch_sequence=self.transpose_batch_sequence,
scale_attn_logits=self.scale_attn_logits,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
float32_logits=self.float32_logits)(query,
key,
value,
bias=attention_bias,
deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
......@@ -685,6 +690,8 @@ class MultiHeadAttention(nn.Module):
axis=-1,
kernel_init=self.kernel_init,
kernel_axes=('joined_kv', 'embed'),
use_bias=self.use_bias,
bias_axes=('embed'),
dtype=self.dtype,
name='out')(x)
return out
......@@ -858,27 +865,36 @@ class RelativePositionBiases(nn.Module):
class EncoderLayer(nn.Module):
"""Transformer encoder layer."""
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
dropout_rate: float = 0.1
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
transpose_batch_sequence: bool = True
float32_attention_logits: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
mlp_activations: Sequence[str] = ('relu',)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
output_layernorm: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
self_attn_mask_type: Any = None
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -887,21 +903,25 @@ class EncoderLayer(nn.Module):
@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
del self.self_attn_mask_type # unused here; kept only to align with TE's implementation
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
if self.enable_relative_embedding:
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
else:
rel_emb = self.relative_embedding
encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
else:
rel_emb = self.relative_embedding
encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
encoder_bias = None
# Attention block.
residual = inputs
......@@ -909,6 +929,7 @@ class EncoderLayer(nn.Module):
if not self.output_layernorm:
# Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_attention_layer_norm")(inputs)
......@@ -924,20 +945,21 @@ class EncoderLayer(nn.Module):
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.dropout_rate,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
use_bias=self.use_bias,
name='attention')(x,
x,
encoder_mask,
encoder_bias,
deterministic=deterministic)
x = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(x, deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
......@@ -947,6 +969,7 @@ class EncoderLayer(nn.Module):
# MLP block.
residual = x
y = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_mlp_layer_norm')(x)
......@@ -959,13 +982,15 @@ class EncoderLayer(nn.Module):
transpose_batch_sequence=self.transpose_batch_sequence,
intermediate_dim=self.mlp_dim,
activations=self.mlp_activations,
intermediate_dropout_rate=self.dropout_rate,
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_dropout_dims=self.intermediate_dropout_dims,
use_bias=self.use_bias,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
name='mlp',
)(y, deterministic=deterministic)
y = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
y = nn.Dropout(rate=self.drop_path,
......@@ -974,6 +999,7 @@ class EncoderLayer(nn.Module):
if self.output_layernorm:
y = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layernorm")(y)
......@@ -982,27 +1008,36 @@ class EncoderLayer(nn.Module):
class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder."""
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
num_attention_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
dropout_rate: float = 0.1
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
transpose_batch_sequence: bool = True
float32_attention_logits: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
mlp_activations: Sequence[str] = ('relu',)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
self_attn_mask_type: Any = None
def __post_init__(self):
if self.num_gqa_groups is None:
......@@ -1018,22 +1053,26 @@ class DecoderLayer(nn.Module):
deterministic=False,
decode=False,
max_decode_length=None):
del self.self_attn_mask_type # unused here; kept only to align with TE's implementation
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
if self.enable_relative_embedding:
l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
if self.relative_embedding is None:
rel_emb = RelativePositionBiases(num_buckets=32,
max_distance=128,
num_heads=self.num_attention_heads,
dtype=self.dtype,
embedding_init=nn.initializers.variance_scaling(
1.0, 'fan_avg', 'uniform'),
name='relpos_bias')
else:
rel_emb = self.relative_embedding
decoder_bias = rel_emb(l, l, False)
else:
rel_emb = self.relative_embedding
decoder_bias = rel_emb(l, l, False)
decoder_bias = None
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
residual = inputs
......@@ -1041,6 +1080,7 @@ class DecoderLayer(nn.Module):
if not self.output_layernorm:
# Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_self_attention_layer_norm")(inputs)
......@@ -1056,21 +1096,22 @@ class DecoderLayer(nn.Module):
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.dropout_rate,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
name='self_attention')(x,
x,
decoder_mask,
decoder_bias,
deterministic=deterministic,
decode=decode)
x = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(x, deterministic=deterministic)
x = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
x = nn.Dropout(rate=self.drop_path,
......@@ -1080,6 +1121,7 @@ class DecoderLayer(nn.Module):
# Encoder-Decoder block.
residual = x
y = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_cross_attention_layer_norm')(x)
......@@ -1091,24 +1133,26 @@ class DecoderLayer(nn.Module):
dtype=self.dtype,
head_dim=self.head_dim,
transpose_batch_sequence=self.transpose_batch_sequence,
dropout_rate=self.dropout_rate,
dropout_rate=self.attention_dropout,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
use_bias=self.use_bias,
name='encoder_decoder_attention')(y,
encoded,
encoder_decoder_mask,
deterministic=deterministic)
y = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(y, deterministic=deterministic)
y = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic)
y = y + residual
# MLP block.
residual = y
z = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_mlp_layer_norm')(y)
......@@ -1118,13 +1162,15 @@ class DecoderLayer(nn.Module):
transpose_batch_sequence=self.transpose_batch_sequence,
intermediate_dim=self.mlp_dim,
activations=self.mlp_activations,
intermediate_dropout_rate=self.dropout_rate,
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_dropout_dims=self.intermediate_dropout_dims,
use_bias=self.use_bias,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
name='mlp',
)(z, deterministic=deterministic)
z = nn.Dropout(rate=self.dropout_rate,
broadcast_dims=(sequence_dim,))(z, deterministic=deterministic)
z = nn.Dropout(rate=self.hidden_dropout,
broadcast_dims=self.hidden_dropout_dims)(z, deterministic=deterministic)
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
z = nn.Dropout(rate=self.drop_path,
......@@ -1133,6 +1179,7 @@ class DecoderLayer(nn.Module):
if self.output_layernorm:
z = LayerNorm(layernorm_type=self.layernorm_type,
epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layernorm")(z)
......@@ -1210,6 +1257,21 @@ def assert_allclose(
np.testing.assert_allclose(actual, desired, **tols, **kwargs)
def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)
for (expected_path, expected_value), (actual_path,
actual_value) in zip(flatten_expected, flatten_actual):
assert expected_path == actual_path
key_str = jax.tree_util.keystr(expected_path)
assert_allclose(expected_value,
actual_value,
rtol=rtol,
atol=atol,
err_msg=f'Value of expected{key_str} and actual{key_str} is not close')
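# Illustrative usage sketch (not part of the original utilities): assert_tree_like_allclose
# walks two pytrees with identical structure leaf by leaf and reports the key path of any
# mismatching leaf; the dictionary layout below is a made-up example.
def _assert_tree_like_allclose_example():
    expected = {'mlp': {'wi_kernel': np.ones((4, 8), dtype=np.float32)}}
    actual = {'mlp': {'wi_kernel': np.ones((4, 8), dtype=np.float32) + 1e-7}}
    assert_tree_like_allclose(expected, actual)    # passes under the default tolerances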
def dtype_tols(
dtype: Union[DType, TEDType, np.dtype],
reference_value: float = 1.0,
......@@ -1259,3 +1321,36 @@ def dtype_tols(
rtol=eps_relaxed,
atol=max(ulp, eps_relaxed),
)
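# Illustrative usage sketch (not part of the original utilities): `dtype_tols` is assumed
# here to return a mapping with `rtol`/`atol`, matching how `assert_allclose` above unpacks
# its tolerances; the arrays and the bfloat16 choice below are made-up examples.
def _dtype_tols_usage_example():
    import jax.numpy as jnp
    x = np.asarray(1.0, dtype=np.float32)
    y = np.asarray(1.0 + 1e-3, dtype=np.float32)
    np.testing.assert_allclose(x, y, **dtype_tols(jnp.bfloat16))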
def sync_params_values(dst, src, transformations, sep='/'):
"""
This function reconstructs a tree with dst's tree_def/shape and src's values.
transformations is a map that records the key mappings between dst and src.
If a dst key is not found in transformations, it falls back to src key = dst key.
transformations = {
dst key map 0: src key map 0,
dst key map 1: src key map 1,
...
# if dst key = src key, we don't need to add it
}
"""
src_values = {}
for key, value in jax.tree_util.tree_leaves_with_path(src):
normalized_key = sep.join(x.key for x in key)
src_values[normalized_key] = value
flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
synced_dst_values = []
for key, value in flatten_dst:
normalized_key = sep.join(x.key for x in key)
if normalized_key in transformations:
corresponding_src_key = transformations[normalized_key]
else:
corresponding_src_key = normalized_key
synced_dst_values.append(src_values[corresponding_src_key])
synced_dst = jax.tree_util.tree_unflatten(dst_tree_def, synced_dst_values)
return jax.tree_util.tree_map(lambda x, y: x.reshape(y.shape), synced_dst, dst)
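# Illustrative usage sketch (keys and shapes below are made up): copy values from a source
# params tree into a destination tree whose leaves are named differently, using
# `transformations` to map destination key paths to source key paths.
def _sync_params_values_example():
    src = {'mlp': {'wi': {'kernel': np.ones((4, 8), dtype=np.float32)}}}
    dst = {'mlp': {'wi_kernel': np.zeros((4, 8), dtype=np.float32)}}
    synced = sync_params_values(dst, src, transformations={'mlp/wi_kernel': 'mlp/wi/kernel'})
    assert (synced['mlp']['wi_kernel'] == 1.0).all()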
......@@ -1069,7 +1069,7 @@ class SoftmaxPrimitive(BasePrimitive):
"""
Softmax Primitive
"""
max_k_seqlen_supported = 4096
max_k_seqlen_supported = 16384
name = "te_softmax_internal_placeholder"
@staticmethod
......@@ -1324,8 +1324,7 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be in [16, max_k_seqlen_supported]
and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
and q_seqlen % 4 == 0 # q_seqlen must be divisible by 4
and attn_batches % 4 == 0 # batch * heads must be divisible by 4
):
......@@ -1483,8 +1482,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be in [16, max_k_seqlen_supported]
and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
and q_seqlen % 4 == 0 # q_seqlen must be divisible by 4
and attn_batches % 4 == 0 # batch * heads must be divisible by 4
):
......@@ -1695,11 +1693,10 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be in [16, max_k_seqlen_supported]
and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
and q_seqlen % 4 == 0 # q_seqlen must be divisible by 4
and attn_batches % 4 == 0 # batch * heads must be divisible by 4
):
and k_seqlen == q_seqlen):
if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
return attn_batches % batch_per_block == 0
......
......@@ -1035,21 +1035,25 @@ class LayerNormMLP(TransformerEngineBase):
if use_fused_layernorm_mlp:
assert self.axis == -1 # Only axis == -1 is supported at this moment
bias_1_shape = intermediate_dim if self.use_bias else 0
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
bias_1_shape,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,) if self.use_bias else (0,)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init,
bias_2_shape,
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
if self.use_bias:
bias_1_shape = intermediate_dim
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
bias_1_shape,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init,
bias_2_shape,
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
else:
bias_1 = jnp.empty(0, self.dtype)
bias_2 = jnp.empty(0, self.dtype)
out = fused_layernorm_fp8_mlp(y,
scale,
......
......@@ -1103,7 +1103,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
# No changes to memory layout, should trigger bicast only (Ideally no Perf impact)
# No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
......@@ -1161,8 +1161,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2)
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
if self.transpose_batch_sequence:
LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)
......@@ -1192,6 +1190,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
dpa_args = [query, key, value]
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
x = DotProductAttention(head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
......