Unverified Commit 2045a426 authored by Reese Wang, committed by GitHub

[JAX] Enhance JAX unit tests (#796)



* Add layernorm_fp8_dot unit test
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update the softmax primitives' support conditions
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add tests for the softmax primitives
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Round 1 refactor of test_layer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Split the dropout arguments of the reference code and add hidden/intermediate dropout elementwise comparisons
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add dropout_broadcast_dim and self_attn_mask tests and clean up some code
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Abstract the test layer and fix a RoPE reference code discrepancy
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add bias tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add epsilon and float32 tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add relpos_bias and attention dropout tests
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Loosen the atol
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Move common fixtures to conftest.py
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add docstring for test_layer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add docstring for test_layer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix conflicts in test_layer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Avoid leaving bias parameters in the graph when use_bias=False
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 6459fd85
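For reference, the enhanced suites can be launched directly through pytest. The snippet below is a minimal sketch: the test file paths are assumptions based on the usual tests/jax layout of the repository and may differ from the actual paths in this commit.

# Minimal sketch: run the enhanced JAX test suites from Python.
# The test file paths below are assumptions, not taken from this commit.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main([
            "tests/jax/test_custom_call_compute.py",    # layernorm / layernorm_fp8_dot tests
            "tests/jax/test_softmax.py",                 # softmax primitive tests
            "tests/jax/test_layer.py",                   # refactored layer tests
            "-v",
        ]))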
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""conftest for tests/jax"""
import jax
import pytest


@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
    """
    Clear all live arrays to keep the resource clean
    """
    yield
    for arr in jax.live_arrays():
        arr.delete()
......@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
from contextlib import nullcontext
import functools
import operator
from typing import Callable, Sequence, Union
......@@ -10,7 +11,6 @@ import jax
import jax.numpy as jnp
import numpy as np
import pytest
from jax import lax
from jax import jit, value_and_grad
from flax import linen as nn
......@@ -18,7 +18,7 @@ from utils import assert_allclose
from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
from transformer_engine.jax.fp8 import is_fp8_available
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.mlp import activation_lu, activation_lu_fp8, fused_layernorm_fp8_mlp
......@@ -45,16 +45,6 @@ def _convert_to_activation_function(fn_or_string):
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
class TestFP8Dot:
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -416,88 +406,150 @@ class TestActivationLuFP8(TestActivationLu):
dtype=FP8Helper.BWD_DTYPE)
class TestRMSNorm:
@pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
def test_forward_backward(self, n, hidden, dtype):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1)
scale = jnp.asarray(scale, dtype)
epsilon = 1e-6
def reference_rmsnorm(x, scale):
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * lax.rsqrt(mean2 + epsilon), dtype)
return y * scale
jitted_primitive = jit(
value_and_grad(lambda x, scale: jnp.mean(layernorm(x, scale, None, "rmsnorm")), (0, 1)))
jitted_reference = jit(
value_and_grad(lambda x, scale: jnp.mean(reference_rmsnorm(x, scale)), (0, 1)))
primitive_out, (primitive_dx, primitive_dgamma) = jitted_primitive(x, scale)
reference_out, (reference_dx, reference_dgamma) = jitted_reference(x, scale)
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
class TestNorm:
"""
Test transformer_engine.jax.layernorm APIs
"""
class TestLayerNorm:
def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):
"""
JAX native layernorm implementations
- bias is not None: layernorm
- bias is None: rmsnorm
"""
x_ = jnp.asarray(x, jnp.float32)
if bias is None:
mean = 0.
else:
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
scale += 1.
if bias is None:
bias = 0.
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
@pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
scale_range = (-1, 1) if zero_centered_gamma else (0, 2)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range)
scale = jnp.asarray(scale, dtype)
bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
bias = jnp.asarray(bias, dtype)
epsilon = 1e-6
def reference_layernorm(x, scale, bias, zero_centered_gamma, eps):
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
# Align TE implementation
if zero_centered_gamma:
return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype)
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
def compute_loss(x):
# Higher precision to compute the loss
x_ = x.astype(jnp.float32)
return jnp.mean(jnp.square(x_)).astype(x.dtype)
jitted_primitive = jit(
value_and_grad(
lambda x, scale, bias: compute_loss(
layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)),
(0, 1, 2)))
jitted_reference = jit(
value_and_grad(
lambda x, scale, bias: compute_loss(
reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2)))
primitive_out, (primitive_dx, primitive_dgamma,
primitive_dbeta) = jitted_primitive(x, scale, bias)
reference_out, (reference_dx, reference_dgamma,
reference_dbeta) = jitted_reference(x, scale, bias)
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
@pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon,
dtype):
"""
Test transformer_engine.jax.layernorm.layernorm
"""
expect_assert = False
if ln_type == 'rmsnorm' and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True
with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
) if expect_assert else nullcontext():
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
gamma = jnp.asarray(gamma, dtype)
if ln_type == 'layernorm':
beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
beta = jnp.asarray(beta, dtype)
else:
beta = None
def compute_loss(x):
# Higher precision to compute the loss
x_ = x.astype(jnp.float32)
return jnp.mean(jnp.square(x_)).astype(x.dtype)
jitted_primitive = jit(
value_and_grad(
lambda x, gamma, beta: compute_loss(
layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)),
(0, 1, 2)))
jitted_reference = jit(
value_and_grad(
lambda x, gamma, beta: compute_loss(
self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)),
(0, 1, 2)))
primitive_out, (primitive_dx, primitive_dgamma,
primitive_dbeta) = jitted_primitive(x, gamma, beta)
reference_out, (reference_dx, reference_dgamma,
reference_dbeta) = jitted_reference(x, gamma, beta)
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
if beta is not None:
assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('m,n,k', GEMM_CASES)
@pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
@pytest.mark.parametrize('zero_centered_gamma', [True, False])
@pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
"""
Test transformer_engine.jax.layernorm.layernorm_fp8_dot
"""
expect_assert = False
if ln_type == 'rmsnorm' and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True
with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
) if expect_assert else nullcontext():
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
if ln_type == 'layernorm':
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
else:
beta = None
fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM)
fp8_metas_amax = jnp.zeros((FP8Helper.NUM_META_PER_GEMM, FP8Helper.AMAX_HISTORY_LEN),
jnp.float32)
fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM, 1), jnp.float32)
def primitive_func(x, y, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv):
fp8_meta_pkg = FP8MetaPackage(1, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type,
zero_centered_gamma)
return jnp.mean(primitive_out)
def ref_func(x, y, gamma, beta, zero_centered_gamma):
x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
return jnp.mean(jnp.dot(x, y))
value_n_grad_primitive_func = value_and_grad(primitive_func, range(8))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad,
ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
for _ in range(3):
primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad,
primitive_beta_grad, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv) = value_n_grad_primitive_func(
a, b, gamma, beta, fp8_max, fp8_metas_amax, fp8_metas_scale,
fp8_metas_scale_inv)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
if beta is not None:
assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)
......@@ -27,16 +27,14 @@ from transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
@pytest.fixture(autouse=True, scope='module')
def init():
"""
Clear all live arrays to keep the resource clean
WAR for a CUDA uninitialized error
"""
# Calling custom calls before JAX may cause a CUDA uninitialized error
_ = jnp.zeros(0)
yield
for arr in jax.live_arrays():
arr.delete()
def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
......
This diff is collapsed.
......@@ -56,16 +56,6 @@ def enable_fused_attn():
del os.environ["NVTE_FUSED_ATTN"]
@pytest.fixture(autouse=True, scope='function')
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
"""
yield
for arr in jax.live_arrays():
arr.delete()
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, \
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tests for the softmax primitives"""
from contextlib import nullcontext
from dataclasses import dataclass
from functools import wraps
import jax
import jax.numpy as jnp
import pytest
from jax import lax
from jax import nn
from jax import value_and_grad, jit
from jax.typing import DTypeLike
from utils import assert_allclose
from transformer_engine.jax.softmax import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax


def catch_unsupported(method):
    """
    Unsupported cases should raise an error instead of running incorrectly.
    This helper checks that unsupported cases raise the expected assertion error.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self._is_support():
            assertion_checker = pytest.raises(AssertionError)
        else:
            assertion_checker = nullcontext()

        with assertion_checker:
            return method(self, *args, **kwargs)

    return wrapper


@dataclass
class SoftmaxRunner:
    """
    Softmax runner
    """
    batch_size: int
    max_seqlen_q: int
    max_seqlen_kv: int
    num_heads: int
    scale_factor: float
    softmax_type: SoftmaxType
    dtype: DTypeLike

    @staticmethod
    def reference_softmax(logits, mask, scale_factor, **_):
        """
        Jax softmax as the reference
        """
        if mask is not None:
            logits += lax.select(mask > 0,
                                 jnp.full(mask.shape, -1e10).astype(logits.dtype),
                                 jnp.full(mask.shape, 0.).astype(logits.dtype))
        return nn.softmax(logits * scale_factor)

    def _is_support(self):
        return is_softmax_kernel_available(self.softmax_type, self.batch_size, self.num_heads,
                                           self.max_seqlen_q, self.max_seqlen_kv, self.dtype)

    def _setup_inputs(self):
        key = jax.random.PRNGKey(0)
        logits_key, mask_key = jax.random.split(key, 2)
        logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
        mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
        self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.)
        match self.softmax_type:
            case SoftmaxType.SCALED:
                self.mask = None
            case SoftmaxType.SCALED_MASKED:
                self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
            case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
                self.mask = (1. - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
            case _:
                raise ValueError(f"Unknown {self.softmax_type=}")

    @catch_unsupported
    def test_forward(self):
        """
        Test transformer_engine.jax.softmax.softmax fwd rule
        """
        self._setup_inputs()
        primitive_out = softmax(self.logits, self.mask, self.scale_factor, self.softmax_type)
        reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

    @catch_unsupported
    def test_backward(self):
        """
        Test transformer_engine.jax.softmax.softmax bwd rule
        """
        self._setup_inputs()

        def grad_func(func, *args, **kwargs):
            fwd_out = func(*args, **kwargs)
            return jnp.mean(fwd_out, dtype=jnp.float32).astype(self.dtype)

        args = [self.logits, self.mask]
        kwargs = {
            'scale_factor': self.scale_factor,
            'softmax_type': self.softmax_type,
        }

        # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation.
        # Differentiate with respect to the traced `logits` argument.
        jitted_primitive = jit(
            value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs),
                           (0,)))
        jitted_reference = jit(
            value_and_grad(
                lambda logits, *args: grad_func(__class__.reference_softmax, logits, *args,
                                                **kwargs), (0,)))

        primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
        reference_out, (reference_grad_logits,) = jitted_reference(*args)

        assert_allclose(primitive_out, reference_out, dtype=self.dtype)
        assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)


@pytest.mark.parametrize('b, s_q, s_kv, h', [
    pytest.param(8, 16, 16, 16, id='8-16-16-16'),
    pytest.param(8, 512, 512, 16, id='8-512-512-16'),
    pytest.param(2, 8, 16384, 8, id='2-8-16384-8')
])
@pytest.mark.parametrize('scale_factor', [0.125])
@pytest.mark.parametrize('softmax_type', [
    pytest.param(SoftmaxType.SCALED, id='SCALED'),
    pytest.param(SoftmaxType.SCALED_MASKED, id='SCALED_MASKED'),
    pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id='SCALED_UPPER_TRIANG_MASKED')
])
@pytest.mark.parametrize('dtype', [
    pytest.param(jnp.bfloat16, id="BF16"),
    pytest.param(jnp.float16, id="FP16"),
])
class TestSoftmax:
    """
    Test transformer_engine.jax.softmax.softmax
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
        """
        Test forward with parameterized configs
        """
        runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        runner.test_forward()

    @staticmethod
    def test_backward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
        """
        Test backward with parameterized configs
        """
        runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        runner.test_backward()
This diff is collapsed.
......@@ -1069,7 +1069,7 @@ class SoftmaxPrimitive(BasePrimitive):
"""
Softmax Primitive
"""
max_k_seqlen_supported = 4096
max_k_seqlen_supported = 16384
name = "te_softmax_internal_placeholder"
@staticmethod
......@@ -1324,8 +1324,7 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
and q_seqlen % 4 == 0 # q_seqlen must be a multiple of 4
and attn_batches % 4 == 0 # batch * heads must be a multiple of 4
):
......@@ -1483,8 +1482,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
and q_seqlen % 4 == 0 # q_seqlen must be a multiple of 4
and attn_batches % 4 == 0 # batch * heads must be a multiple of 4
):
......@@ -1695,11 +1693,10 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096
and 16 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
and q_seqlen % 4 == 0 # q_seqlen must be a multiple of 4
and attn_batches % 4 == 0 # batch * heads must be a multiple of 4
):
and k_seqlen == q_seqlen):
if 0 <= k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported:
batch_per_block = SoftmaxPrimitive.get_batch_per_block(k_seqlen)
return attn_batches % batch_per_block == 0
......
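For context, the widened limits above are surfaced to callers through is_softmax_kernel_available, the same helper the softmax tests use. Below is a minimal sketch of guarding a fused softmax call with it; the shapes, dtype, and scale factor are illustrative, and the fallback is the plain JAX softmax.

# Minimal sketch: query kernel availability before calling the fused softmax,
# falling back to the plain JAX softmax otherwise. Shapes and dtype are illustrative.
import jax
import jax.numpy as jnp
from transformer_engine.jax.softmax import SoftmaxType, softmax, is_softmax_kernel_available

batch, heads, q_seqlen, k_seqlen = 8, 16, 512, 512
dtype = jnp.bfloat16
scale_factor = 0.125
logits = jax.random.uniform(jax.random.PRNGKey(0), (batch, heads, q_seqlen, k_seqlen), dtype, -1.)

if is_softmax_kernel_available(SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype):
    out = softmax(logits, None, scale_factor, SoftmaxType.SCALED)
else:
    out = jax.nn.softmax(logits * scale_factor)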
......@@ -1035,21 +1035,25 @@ class LayerNormMLP(TransformerEngineBase):
if use_fused_layernorm_mlp:
assert self.axis == -1 # Only support axis = -1 at this moment
bias_1_shape = intermediate_dim if self.use_bias else 0
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
bias_1_shape,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,) if self.use_bias else (0,)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init,
bias_2_shape,
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
if self.use_bias:
bias_1_shape = intermediate_dim
bias_1 = nn_partitioning.param_with_axes('wi_bias',
self.bias_init,
bias_1_shape,
jnp.float32,
axes=self.bias_axes_1)
bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes('wo_bias',
self.bias_init,
bias_2_shape,
jnp.float32,
axes=self.bias_axes_2)
bias_2 = bias_2.astype(self.dtype)
else:
bias_1 = jnp.empty(0, self.dtype)
bias_2 = jnp.empty(0, self.dtype)
out = fused_layernorm_fp8_mlp(y,
scale,
......
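The bias handling above follows a common Flax pattern: create a parameter only on the branch where it is actually used, so that use_bias=False leaves no unused bias variables in the parameter tree. Below is a minimal, self-contained sketch of that pattern; the module and parameter names are illustrative, not Transformer Engine's.

# Minimal sketch of the "create the parameter only when it is used" pattern.
# DenseWithOptionalBias is an illustrative module, not part of Transformer Engine.
import jax
import jax.numpy as jnp
from flax import linen as nn

class DenseWithOptionalBias(nn.Module):
    features: int
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        kernel = self.param('kernel', nn.initializers.lecun_normal(),
                            (x.shape[-1], self.features))
        y = x @ kernel
        if self.use_bias:
            # Registered only on this branch, so use_bias=False leaves no bias in the tree.
            bias = self.param('bias', nn.initializers.zeros, (self.features,))
            y = y + bias
        return y

variables = DenseWithOptionalBias(4, use_bias=False).init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
# variables['params'] contains only 'kernel'; no 'bias' entry is created.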
......@@ -1103,7 +1103,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
# No changes to memory layout, should trigger bicast only (Ideally no Perf impact)
# No changes to memory layout, should trigger bitcast only (Ideally no Perf impact)
query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
......@@ -1161,8 +1161,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
jnp.reshape(cur_index, (-1)), 1, -2)
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
LEADING_AXES = (BATCH_AXES, SEQLEN_AXES)
if self.transpose_batch_sequence:
LEADING_AXES = (SEQLEN_AXES, BATCH_AXES)
......@@ -1192,6 +1190,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint)
dpa_args = [query, key, value]
scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0
x = DotProductAttention(head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
......