Unverified Commit ec1030b5 authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

Zero-centered gamma (Layernorm1p) support for JAX (#139)



* Add zero_center_gamma/functional pass
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add zero_centered_gamma for fp8_ln_mlp
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add zero_centered_gamma to modules
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add zero_centered_gamma to TransformerLayer
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Refactored code style for improved readability and consistency
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Docs enhancement for zero_centered_gamma
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Add escape for line break and remove some bad if conditions
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Revise scale_init docs
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 7f5e4cb9
...@@ -523,32 +523,44 @@ class TestLayerNorm: ...@@ -523,32 +523,44 @@ class TestLayerNorm:
@pytest.mark.parametrize('n, hidden', LN_CASES) @pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('dtype', DTYPES)
def test_forward_backward(self, n, hidden, dtype): @pytest.mark.parametrize('zero_centered_gamma', [False, True])
def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype):
key = jax.random.PRNGKey(0) key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3) subkeys = jax.random.split(key, 3)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1) x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1) scale_range = (-1, 1) if zero_centered_gamma else (0, 2)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range)
scale = jnp.asarray(scale, dtype) scale = jnp.asarray(scale, dtype)
bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -2, 1) bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
bias = jnp.asarray(bias, dtype) bias = jnp.asarray(bias, dtype)
epsilon = 1e-6 epsilon = 1e-6
def reference_layernorm(x, scale, bias): def reference_layernorm(x, scale, bias, zero_centered_gamma, eps):
x = jnp.asarray(x, jnp.float32) x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x, axis=-1, keepdims=True) mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x - mean) * jax.lax.rsqrt(var + epsilon) normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
# Align TE implementation # Align TE implementation
return jnp.asarray(normed_input * scale + bias) if zero_centered_gamma:
return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype)
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
def compute_loss(x):
# Higher precision to compute the loss
x_ = x.astype(jnp.float32)
return jnp.mean(jnp.square(x_)).astype(x.dtype)
jitted_primitive = jit( jitted_primitive = jit(
value_and_grad(lambda x, scale, bias: jnp.mean(layernorm(x, scale, bias, "layernorm")), value_and_grad(
(0, 1, 2))) lambda x, scale, bias: compute_loss(
layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)),
(0, 1, 2)))
jitted_reference = jit( jitted_reference = jit(
value_and_grad(lambda x, scale, bias: jnp.mean(reference_layernorm(x, scale, bias)), value_and_grad(
(0, 1, 2))) lambda x, scale, bias: compute_loss(
reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2)))
primitive_out, (primitive_dx, primitive_dgamma, primitive_out, (primitive_dx, primitive_dgamma,
primitive_dbeta) = jitted_primitive(x, scale, bias) primitive_dbeta) = jitted_primitive(x, scale, bias)
...@@ -561,7 +573,7 @@ class TestLayerNorm: ...@@ -561,7 +573,7 @@ class TestLayerNorm:
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-7) assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-7)
assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-7) assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-7)
else: else:
assert_allclose(primitive_out, reference_out, rtol=1e-3) assert_allclose(primitive_out, reference_out, rtol=1e-7)
assert_allclose(primitive_dx, reference_dx, rtol=1e-4, atol=5e-8) assert_allclose(primitive_dx, reference_dx, rtol=1e-5, atol=1e-6)
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-4, atol=5e-8) assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-5, atol=3e-5)
assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-4, atol=5e-8) assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-5, atol=3e-5)
...@@ -66,6 +66,7 @@ _KEY_OF_DROPOUT_RATE = "dropout_rate" ...@@ -66,6 +66,7 @@ _KEY_OF_DROPOUT_RATE = "dropout_rate"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations" _KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi" _KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi"
_KEY_OF_LAYERNORM_TYPE = 'layernorm_type' _KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence' _KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True} BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True}
...@@ -74,6 +75,9 @@ ATTRS = [{ ...@@ -74,6 +75,9 @@ ATTRS = [{
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
}, { }, {
_KEY_OF_LAYERNORM_TYPE: 'layernorm', _KEY_OF_LAYERNORM_TYPE: 'layernorm',
}, {
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_ZERO_CENTERED_GAMMA: True
}, { }, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm', _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_RESIDUAL_POST_LAYERNORM: True _KEY_OF_RESIDUAL_POST_LAYERNORM: True
......
...@@ -604,9 +604,18 @@ class LayerNorm(nn.Module): ...@@ -604,9 +604,18 @@ class LayerNorm(nn.Module):
epsilon: float = 1e-6 epsilon: float = 1e-6
dtype: Any = jnp.float32 dtype: Any = jnp.float32
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
scale_init: Initializer = nn.initializers.ones zero_centered_gamma: bool = False
scale_init: Initializer = None
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
def __post_init__(self):
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__()
@nn.compact @nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray: def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Applies layer normalization on the input.""" """Applies layer normalization on the input."""
...@@ -632,9 +641,13 @@ class LayerNorm(nn.Module): ...@@ -632,9 +641,13 @@ class LayerNorm(nn.Module):
bias = jnp.asarray(bias, self.dtype) bias = jnp.asarray(bias, self.dtype)
y = jnp.asarray(y, self.dtype) y = jnp.asarray(y, self.dtype)
z = y * scale + bias if not self.zero_centered_gamma:
z = y * scale + bias
else:
z = y * (scale + 1) + bias
else: else:
assert self.layernorm_type == 'rmsnorm' assert self.layernorm_type == 'rmsnorm'
assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
z = y * scale z = y * scale
...@@ -768,6 +781,7 @@ class EncoderLayer(nn.Module): ...@@ -768,6 +781,7 @@ class EncoderLayer(nn.Module):
dtype: Any = jnp.float32 dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False
output_layernorm: bool = False output_layernorm: bool = False
drop_path: float = 0.0 drop_path: float = 0.0
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
...@@ -797,6 +811,7 @@ class EncoderLayer(nn.Module): ...@@ -797,6 +811,7 @@ class EncoderLayer(nn.Module):
if not self.output_layernorm: if not self.output_layernorm:
# Attention block. # Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type, x = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="pre_attention_layer_norm")(inputs) name="pre_attention_layer_norm")(inputs)
...@@ -831,6 +846,7 @@ class EncoderLayer(nn.Module): ...@@ -831,6 +846,7 @@ class EncoderLayer(nn.Module):
# MLP block. # MLP block.
residual = x residual = x
y = LayerNorm(layernorm_type=self.layernorm_type, y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name='pre_mlp_layer_norm')(x) name='pre_mlp_layer_norm')(x)
...@@ -857,6 +873,7 @@ class EncoderLayer(nn.Module): ...@@ -857,6 +873,7 @@ class EncoderLayer(nn.Module):
if self.output_layernorm: if self.output_layernorm:
y = LayerNorm(layernorm_type=self.layernorm_type, y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="output_layer_norm")(y) name="output_layer_norm")(y)
return y return y
...@@ -878,6 +895,7 @@ class DecoderLayer(nn.Module): ...@@ -878,6 +895,7 @@ class DecoderLayer(nn.Module):
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False output_layernorm: bool = False
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False
drop_path: float = 0.0 drop_path: float = 0.0
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False fuse_mlp_wi: bool = False
...@@ -914,6 +932,7 @@ class DecoderLayer(nn.Module): ...@@ -914,6 +932,7 @@ class DecoderLayer(nn.Module):
if not self.output_layernorm: if not self.output_layernorm:
# Attention block. # Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type, x = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="pre_self_attention_layer_norm")(inputs) name="pre_self_attention_layer_norm")(inputs)
...@@ -949,6 +968,7 @@ class DecoderLayer(nn.Module): ...@@ -949,6 +968,7 @@ class DecoderLayer(nn.Module):
# Encoder-Decoder block. # Encoder-Decoder block.
residual = x residual = x
y = LayerNorm(layernorm_type=self.layernorm_type, y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name='pre_cross_attention_layer_norm')(x) name='pre_cross_attention_layer_norm')(x)
...@@ -974,6 +994,7 @@ class DecoderLayer(nn.Module): ...@@ -974,6 +994,7 @@ class DecoderLayer(nn.Module):
# MLP block. # MLP block.
residual = y residual = y
z = LayerNorm(layernorm_type=self.layernorm_type, z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name='pre_mlp_layer_norm')(y) name='pre_mlp_layer_norm')(y)
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
...@@ -997,6 +1018,7 @@ class DecoderLayer(nn.Module): ...@@ -997,6 +1018,7 @@ class DecoderLayer(nn.Module):
if self.output_layernorm: if self.output_layernorm:
z = LayerNorm(layernorm_type=self.layernorm_type, z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype, dtype=self.dtype,
name="output_layer_norm")(z) name="output_layer_norm")(z)
......
...@@ -745,13 +745,7 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -745,13 +745,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
multiple_results = True multiple_results = True
@staticmethod @staticmethod
def abstract( def abstract(x, gamma, beta, **kwargs): # pylint: disable=unused-argument
x,
gamma,
beta,
*,
epsilon # pylint: disable=unused-argument
):
""" """
LayerNorm fwd abstract LayerNorm fwd abstract
""" """
...@@ -774,7 +768,7 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -774,7 +768,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
) )
@staticmethod @staticmethod
def lowering(ctx, x, gamma, beta, *, epsilon): def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
""" """
LayerNorm fwd lowering rules LayerNorm fwd lowering rules
""" """
...@@ -815,6 +809,7 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -815,6 +809,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma,
epsilon, epsilon,
) )
...@@ -826,11 +821,16 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -826,11 +821,16 @@ class LayerNormFwdPrimitive(BasePrimitive):
_layernorm_fwd_p = register_primitive(LayerNormFwdPrimitive) _layernorm_fwd_p = register_primitive(LayerNormFwdPrimitive)
def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, epsilon: float): def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool,
epsilon: float):
""" """
Wrapper for TE layernorm fwd Wrapper for TE layernorm fwd
""" """
return _layernorm_fwd_p.bind(x, gamma, beta, epsilon=epsilon) return _layernorm_fwd_p.bind(x,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
class LayerNormFwdFp8Primitive(BasePrimitive): class LayerNormFwdFp8Primitive(BasePrimitive):
...@@ -848,8 +848,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -848,8 +848,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
amax, amax,
scale, scale,
scale_inv, scale_inv,
*, **kwargs # pylint: disable=unused-argument
epsilon # pylint: disable=unused-argument
): ):
""" """
LayerNorm fwd (fp8 out) abstract LayerNorm fwd (fp8 out) abstract
...@@ -879,7 +878,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -879,7 +878,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
) )
@staticmethod @staticmethod
def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, epsilon): def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, zero_centered_gamma, epsilon):
""" """
LayerNorm fwd (fp8 out) lowering rules LayerNorm fwd (fp8 out) lowering rules
""" """
...@@ -928,6 +927,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -928,6 +927,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma,
epsilon, epsilon,
) )
...@@ -944,11 +944,19 @@ _layernorm_fwd_fp8_p = register_primitive(LayerNormFwdFp8Primitive) ...@@ -944,11 +944,19 @@ _layernorm_fwd_fp8_p = register_primitive(LayerNormFwdFp8Primitive)
def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray, def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, epsilon: float): scale: jnp.ndarray, scale_inv: jnp.ndarray, zero_centered_gamma: bool,
epsilon: float):
""" """
Wrapper for TE layernorm fwd (fp8 out) Wrapper for TE layernorm fwd (fp8 out)
""" """
return _layernorm_fwd_fp8_p.bind(x, gamma, beta, amax, scale, scale_inv, epsilon=epsilon) return _layernorm_fwd_fp8_p.bind(x,
gamma,
beta,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
class LayerNormBwdPrimitive(BasePrimitive): class LayerNormBwdPrimitive(BasePrimitive):
...@@ -959,15 +967,7 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -959,15 +967,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
multiple_results = True multiple_results = True
@staticmethod @staticmethod
def abstract( def abstract(grad_output, mu, rsigma, x, gamma, **kwargs): # pylint: disable=unused-argument
grad_output,
mu,
rsigma,
x,
gamma,
*,
epsilon # pylint: disable=unused-argument
):
""" """
Layernorm bwd abstract Layernorm bwd abstract
""" """
...@@ -993,7 +993,7 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -993,7 +993,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
) )
@staticmethod @staticmethod
def lowering(ctx, grad_output, mu, rsigma, x, gamma, *, epsilon): def lowering(ctx, grad_output, mu, rsigma, x, gamma, *, zero_centered_gamma, epsilon):
""" """
Layernorm bwd lowering rules Layernorm bwd lowering rules
""" """
...@@ -1029,6 +1029,7 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -1029,6 +1029,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma,
epsilon, epsilon,
) )
...@@ -1041,11 +1042,17 @@ _layernorm_bwd_p = register_primitive(LayerNormBwdPrimitive) ...@@ -1041,11 +1042,17 @@ _layernorm_bwd_p = register_primitive(LayerNormBwdPrimitive)
def layernorm_bwd(g: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray, x: jnp.ndarray, def layernorm_bwd(g: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, epsilon: float): gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float):
""" """
Wrapper for TE layernorm bwd Wrapper for TE layernorm bwd
""" """
return _layernorm_bwd_p.bind(g, mu, rsigma, x, gamma, epsilon=epsilon) return _layernorm_bwd_p.bind(g,
mu,
rsigma,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
class RmsNormFwdPrimitive(BasePrimitive): class RmsNormFwdPrimitive(BasePrimitive):
...@@ -1056,12 +1063,7 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -1056,12 +1063,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
multiple_results = True multiple_results = True
@staticmethod @staticmethod
def abstract( def abstract(x, gamma, **kwargs): # pylint: disable=unused-argument
x,
gamma,
*,
epsilon # pylint: disable=unused-argument
):
""" """
RMSNorm fwd abstract RMSNorm fwd abstract
""" """
...@@ -1106,6 +1108,7 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -1106,6 +1108,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
) )
...@@ -1138,8 +1141,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -1138,8 +1141,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
amax, amax,
scale, scale,
scale_inv, scale_inv,
*, **kwargs # pylint: disable=unused-argument
epsilon # pylint: disable=unused-argument
): ):
""" """
RMSNorm fwd (fp8 out) abstract RMSNorm fwd (fp8 out) abstract
...@@ -1207,6 +1209,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -1207,6 +1209,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
) )
...@@ -1243,8 +1246,7 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -1243,8 +1246,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
rsigma, rsigma,
x, x,
gamma, gamma,
*, **kwargs # pylint: disable=unused-argument
epsilon # pylint: disable=unused-argument
): ):
""" """
RMSNorm bwd abstract RMSNorm bwd abstract
...@@ -1298,6 +1300,7 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -1298,6 +1300,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
hidden_size, hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
) )
......
...@@ -66,8 +66,9 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType ...@@ -66,8 +66,9 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType
} }
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype, pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
float eps) { bool zero_centered_gamma, float eps) {
return PackOpaque(CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, eps}); return PackOpaque(
CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps});
} }
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads, pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
...@@ -269,10 +270,10 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque ...@@ -269,10 +270,10 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
desc.use_split_accumulator, stream); desc.use_split_accumulator, stream);
} }
void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, void *weight, void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void *input,
DType w_dtype, void *bias, float eps, void *output, DType out_dtype, DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
void *mu, void *rsigma, float *amax, float *scale, float *scale_inv, DType out_dtype, void *mu, void *rsigma, float *amax, float *scale,
cudaStream_t stream) { float *scale_inv, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden}; auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden}; auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n}; auto intermediates_shape = std::vector<size_t>{n};
...@@ -291,12 +292,17 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, ...@@ -291,12 +292,17 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor; TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
if (!is_layer_norm) {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
}
// The first call is to query the required workspace // The first call is to query the required workspace
if (is_layer_norm) { if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data()); num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
} else { } else {
...@@ -322,7 +328,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, ...@@ -322,7 +328,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, workspace_tensor.data(), barrier_tensor.data()); num_sm, workspace_tensor.data(), barrier_tensor.data());
} else { } else {
...@@ -332,9 +338,10 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, ...@@ -332,9 +338,10 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
} }
} }
void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, void *weight, void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
DType w_dtype, void *ograd, void *mu, void *rsigma, float eps, void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd,
void *xgrad, void *wgrad, void *dbeta, cudaStream_t stream) { void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta,
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden}; auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden}; auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n}; auto intermediates_shape = std::vector<size_t>{n};
...@@ -360,12 +367,17 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, ...@@ -360,12 +367,17 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
size_t dbeta_part_size{}; size_t dbeta_part_size{};
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
if (!is_layer_norm) {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
}
// The first call is to query the workspace // The first call is to query the workspace
if (is_layer_norm) { if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(), wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), stream, dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), stream,
...@@ -411,7 +423,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, ...@@ -411,7 +423,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dummy_dbeta_part_tensor.shape(), auto dbeta_part_tensor = TensorWrapper(dbeta_part, dummy_dbeta_part_tensor.shape(),
dummy_dbeta_part_tensor.dtype()); dummy_dbeta_part_tensor.dtype());
nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(), wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(), dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
...@@ -444,11 +456,12 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -444,11 +456,12 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype, LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
mu, rsigma, amax, scale, scale_inv, stream); bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
} }
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -470,9 +483,10 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -470,9 +483,10 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype, LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
mu, rsigma, amax, scale, scale_inv, stream); bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
} }
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -483,6 +497,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -483,6 +497,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto *ograd = buffers[0]; auto *ograd = buffers[0];
auto *mu = buffers[1]; auto *mu = buffers[1];
...@@ -493,8 +508,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -493,8 +508,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *wgrad = buffers[6]; auto *wgrad = buffers[6];
auto *dbeta = buffers[7]; auto *dbeta = buffers[7];
LayerNormBackwardImpl(n, hidden, input, in_dtype, weight, w_dtype, ograd, mu, rsigma, eps, LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
xgrad, wgrad, dbeta, stream); ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
} }
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -517,10 +532,11 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -517,10 +532,11 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype, LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
mu, rsigma, amax, scale, scale_inv, stream); bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
} }
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -541,10 +557,11 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz ...@@ -541,10 +557,11 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype, LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
mu, rsigma, amax, scale, scale_inv, stream); bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
} }
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -561,12 +578,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si ...@@ -561,12 +578,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto in_dtype = desc.x_dtype; auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
void *mu = nullptr; void *mu = nullptr;
void *dbeta = nullptr; void *dbeta = nullptr;
LayerNormBackwardImpl(n, hidden, input, in_dtype, weight, w_dtype, ograd, mu, rsigma, eps, LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
xgrad, wgrad, dbeta, stream); ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
} }
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......
...@@ -73,11 +73,12 @@ struct CustomCallNormDescriptor { ...@@ -73,11 +73,12 @@ struct CustomCallNormDescriptor {
size_t hidden; size_t hidden;
DType x_dtype; DType x_dtype;
DType w_dtype; DType w_dtype;
bool zero_centered_gamma;
float eps; float eps;
}; };
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype, pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
float eps); bool zero_centered_gamma, float eps);
struct SoftmaxDescriptor { struct SoftmaxDescriptor {
size_t batch; size_t batch;
......
...@@ -37,6 +37,7 @@ def layernorm(inputs: jnp.ndarray, ...@@ -37,6 +37,7 @@ def layernorm(inputs: jnp.ndarray,
gamma: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray, beta: jnp.ndarray,
layernorm_type: str, layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
sharding_type: ShardingType = ShardingType.SINGLE, sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0): dp_dim_index: int = 0):
...@@ -49,15 +50,18 @@ def layernorm(inputs: jnp.ndarray, ...@@ -49,15 +50,18 @@ def layernorm(inputs: jnp.ndarray,
layernorm_type = canonicalize_layernorm_type(layernorm_type) layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm': if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'" assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
if sharding_type is ShardingType.SINGLE: if sharding_type is ShardingType.SINGLE:
output = _layernorm(inputs, output = _layernorm(inputs,
gamma, gamma,
beta, beta,
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name="", dp_axis_name="")
epsilon=epsilon)
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
...@@ -75,9 +79,10 @@ def layernorm(inputs: jnp.ndarray, ...@@ -75,9 +79,10 @@ def layernorm(inputs: jnp.ndarray,
partial_ln = partial(_layernorm, partial_ln = partial(_layernorm,
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name)
epsilon=epsilon)
output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes, output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, gamma_, beta_)) sharding_meta.axis_resources, (inputs_, gamma_, beta_))
...@@ -87,9 +92,11 @@ def layernorm(inputs: jnp.ndarray, ...@@ -87,9 +92,11 @@ def layernorm(inputs: jnp.ndarray,
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
def _layernorm(x, gamma, beta, layernorm_type, sharding_type, dp_axis_name, epsilon=1e-6): def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type,
output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, sharding_type, dp_axis_name, epsilon) dp_axis_name):
output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name)
return output return output
...@@ -98,23 +105,36 @@ def _layernorm_fwd( ...@@ -98,23 +105,36 @@ def _layernorm_fwd(
gamma, gamma,
beta, beta,
layernorm_type, layernorm_type,
zero_centered_gamma,
epsilon,
sharding_type, # pylint: disable=unused-argument sharding_type, # pylint: disable=unused-argument
dp_axis_name, # pylint: disable=unused-argument dp_axis_name # pylint: disable=unused-argument
epsilon): ):
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
output, mu, rsigma = layernorm_fwd(x, gamma, beta, epsilon) output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
else: else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output, rsigma = rmsnorm_fwd(x, gamma, epsilon) output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
mu = None mu = None
return output, (mu, rsigma, x, gamma) return output, (mu, rsigma, x, gamma)
def _layernorm_bwd(layernorm_type, sharding_type, dp_axis_name, epsilon, ctx, g): def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name, ctx,
g):
mu, rsigma, x, gamma = ctx mu, rsigma, x, gamma = ctx
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(g, mu, rsigma, x, gamma, epsilon=epsilon) grad_input, grad_gamma, grad_beta = layernorm_bwd(g,
mu,
rsigma,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else: else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(g, rsigma, x, gamma, epsilon=epsilon) grad_input, grad_gamma = rmsnorm_bwd(g, rsigma, x, gamma, epsilon=epsilon)
grad_beta = None grad_beta = None
...@@ -135,9 +155,10 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -135,9 +155,10 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype: TEDType, fwd_dtype: TEDType,
bwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
sharding_type: ShardingType = ShardingType.SINGLE, sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0, dp_dim_index: int = 0) -> jnp.ndarray:
epsilon: float = 1e-6) -> jnp.ndarray:
""" """
LN + fp8 dot fusion wrapper LN + fp8 dot fusion wrapper
""" """
...@@ -147,6 +168,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -147,6 +168,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
layernorm_type = canonicalize_layernorm_type(layernorm_type) layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm': if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'" assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
assert fp8_gemm_pkg.num_of_gemm == 1 assert fp8_gemm_pkg.num_of_gemm == 1
inputs = fp8_gemm_pkg.inputs inputs = fp8_gemm_pkg.inputs
...@@ -169,10 +192,11 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -169,10 +192,11 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype, fwd_dtype,
bwd_dtype, bwd_dtype,
contracting_dims, contracting_dims,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name="", dp_axis_name="",
tp_axis_name="", tp_axis_name="")
epsilon=epsilon)
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
...@@ -214,10 +238,11 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -214,10 +238,11 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype=fwd_dtype, fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype, bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type, sharding_type=sharding_type,
dp_axis_name=dp_axis_name, dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name, tp_axis_name=tp_axis_name)
epsilon=epsilon)
# input, kernel, gamma, beta, fp8_metas # input, kernel, gamma, beta, fp8_metas
in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1], in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1],
...@@ -230,27 +255,18 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage, ...@@ -230,27 +255,18 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return output return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15)) @partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
def _layernorm_fp8_dot(inputs: jnp.ndarray, def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray,
kernel: jnp.ndarray, beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
gamma: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
beta: jnp.ndarray, fwd_dtype: TEDType, bwd_dtype: TEDType,
fp8_maxs: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
layernorm_type: str,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]], contracting_dims: Tuple[Sequence[int], Sequence[int]],
sharding_type: ShardingType, zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType,
dp_axis_name: str, dp_axis_name: str, tp_axis_name: str) -> jnp.ndarray:
tp_axis_name: str,
epsilon: float = 1e-6) -> jnp.ndarray:
output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale, output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
scale_inv, layernorm_type, fwd_dtype, bwd_dtype, scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
contracting_dims, sharding_type, dp_axis_name, tp_axis_name, contracting_dims, zero_centered_gamma, epsilon,
epsilon) sharding_type, dp_axis_name, tp_axis_name)
return output return output
...@@ -267,10 +283,11 @@ def _layernorm_fp8_dot_fwd( ...@@ -267,10 +283,11 @@ def _layernorm_fp8_dot_fwd(
fwd_dtype, fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument bwd_dtype, # pylint: disable=unused-argument
contracting_dims, contracting_dims,
zero_centered_gamma,
epsilon,
sharding_type, sharding_type,
dp_axis_name, # pylint: disable=unused-argument dp_axis_name, # pylint: disable=unused-argument
tp_axis_name, tp_axis_name):
epsilon):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)] input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
...@@ -295,8 +312,11 @@ def _layernorm_fp8_dot_fwd( ...@@ -295,8 +312,11 @@ def _layernorm_fp8_dot_fwd(
input_amax, input_amax,
input_scale, input_scale,
input_scale_inv, input_scale_inv,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon) epsilon=epsilon)
else: else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, input_amax = rmsnorm_fwd_fp8(inputs, ln_out, rsigma, input_amax = rmsnorm_fwd_fp8(inputs,
gamma, gamma,
input_amax, input_amax,
...@@ -337,10 +357,11 @@ def _layernorm_fp8_dot_bwd( ...@@ -337,10 +357,11 @@ def _layernorm_fp8_dot_bwd(
fwd_dtype, fwd_dtype,
bwd_dtype, bwd_dtype,
contracting_dims, # pylint: disable=unused-argument contracting_dims, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon,
sharding_type, sharding_type,
dp_axis_name, dp_axis_name,
tp_axis_name, tp_axis_name,
epsilon,
ctx, ctx,
g): g):
ln_out_, kernel_cast, \ ln_out_, kernel_cast, \
...@@ -377,8 +398,16 @@ def _layernorm_fp8_dot_bwd( ...@@ -377,8 +398,16 @@ def _layernorm_fp8_dot_bwd(
dgrad = jax.lax.psum(dgrad, tp_axis_name) dgrad = jax.lax.psum(dgrad, tp_axis_name)
if layernorm_type == 'layernorm': if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad, mu, rsigma, inputs, gamma, epsilon) grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad,
mu,
rsigma,
inputs,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else: else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon) grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
grad_beta = None grad_beta = None
......
...@@ -103,6 +103,7 @@ def fp8_ln_mlp( ...@@ -103,6 +103,7 @@ def fp8_ln_mlp(
layernorm_type: str, layernorm_type: str,
fwd_dtype: TEDType, fwd_dtype: TEDType,
bwd_dtype: TEDType, bwd_dtype: TEDType,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6, epsilon: float = 1e-6,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
major_sharding_type: MajorShardingType = MajorShardingType.SINGLE, major_sharding_type: MajorShardingType = MajorShardingType.SINGLE,
...@@ -125,12 +126,14 @@ def fp8_ln_mlp( ...@@ -125,12 +126,14 @@ def fp8_ln_mlp(
layernorm_type = canonicalize_layernorm_type(layernorm_type) layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm': if layernorm_type == 'rmsnorm':
assert ln_bias is None, "ln_bias should be None if layernorm_type is 'rmsnorm'" assert ln_bias is None, "ln_bias should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
assert activations == ('gelu', 'linear') assert activations == ('gelu', 'linear')
if major_sharding_type is MajorShardingType.SINGLE: if major_sharding_type is MajorShardingType.SINGLE:
res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale, res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, layernorm_type, activations, epsilon, fwd_dtype, bwd_dtype, scale_inv, layernorm_type, activations, zero_centered_gamma, epsilon,
contracting_dims, major_sharding_type, "", "") fwd_dtype, bwd_dtype, contracting_dims, major_sharding_type, "", "")
else: else:
dp_axis_name = "batch" dp_axis_name = "batch"
tp_axis_name = "model" tp_axis_name = "model"
...@@ -177,6 +180,7 @@ def fp8_ln_mlp( ...@@ -177,6 +180,7 @@ def fp8_ln_mlp(
partial_fp8_mlp = partial(_fp8_mlp, partial_fp8_mlp = partial(_fp8_mlp,
layernorm_type=layernorm_type, layernorm_type=layernorm_type,
activations=activations, activations=activations,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
fwd_dtype=fwd_dtype, fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype, bwd_dtype=bwd_dtype,
...@@ -196,12 +200,13 @@ def fp8_ln_mlp( ...@@ -196,12 +200,13 @@ def fp8_ln_mlp(
return res return res
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17)) @partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray, def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str, scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
activations: Sequence[Union[str, Callable]], epsilon: float, fwd_dtype: TEDType, activations: Sequence[Union[str, Callable]], zero_centered_gamma: bool, epsilon: float,
bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]], fwd_dtype: TEDType, bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int],
Sequence[int]],
major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str): major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str):
res, _ = _fp8_mlp_fwd(inputs, res, _ = _fp8_mlp_fwd(inputs,
ln_scale, ln_scale,
...@@ -214,6 +219,7 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray, ...@@ -214,6 +219,7 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
scale_inv, scale_inv,
layernorm_type, layernorm_type,
activations, activations,
zero_centered_gamma,
epsilon, epsilon,
fwd_dtype, fwd_dtype,
bwd_dtype, bwd_dtype,
...@@ -236,6 +242,7 @@ def _fp8_mlp_fwd( ...@@ -236,6 +242,7 @@ def _fp8_mlp_fwd(
scale_inv, scale_inv,
layernorm_type, layernorm_type,
activations, activations,
zero_centered_gamma,
epsilon, epsilon,
fwd_dtype, fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument bwd_dtype, # pylint: disable=unused-argument
...@@ -276,8 +283,11 @@ def _fp8_mlp_fwd( ...@@ -276,8 +283,11 @@ def _fp8_mlp_fwd(
input_amax, input_amax,
input_scale, input_scale,
input_scale_inv, input_scale_inv,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon) epsilon=epsilon)
else: else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, ln_out_amax = rmsnorm_fwd_fp8(inputs_, ln_out, rsigma, ln_out_amax = rmsnorm_fwd_fp8(inputs_,
gamma, gamma,
input_amax, input_amax,
...@@ -334,6 +344,7 @@ def _fp8_mlp_fwd( ...@@ -334,6 +344,7 @@ def _fp8_mlp_fwd(
def _fp8_mlp_bwd( def _fp8_mlp_bwd(
layernorm_type, layernorm_type,
activations, # pylint: disable=unused-argument activations, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon, epsilon,
fwd_dtype, fwd_dtype,
bwd_dtype, bwd_dtype,
...@@ -397,8 +408,11 @@ def _fp8_mlp_bwd( ...@@ -397,8 +408,11 @@ def _fp8_mlp_bwd(
rsigma, rsigma,
inputs_, inputs_,
gamma, gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon) epsilon=epsilon)
else: else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon) grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon)
grad_beta = None grad_beta = None
......
...@@ -202,8 +202,20 @@ class LayerNorm(nn.Module): ...@@ -202,8 +202,20 @@ class LayerNorm(nn.Module):
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm' layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization. Indicate the type of layer normalization.
scale_init : Initializer, default = flax.linen.initializers.ones zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`.
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`. Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', ) scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh. The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
...@@ -228,7 +240,8 @@ class LayerNorm(nn.Module): ...@@ -228,7 +240,8 @@ class LayerNorm(nn.Module):
""" """
epsilon: float = 1e-6 epsilon: float = 1e-6
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
scale_init: Initializer = nn.initializers.ones zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',) scale_axes: Tuple[str, ...] = ('embed',)
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ('embed',) bias_axes: Tuple[str, ...] = ('embed',)
...@@ -236,6 +249,14 @@ class LayerNorm(nn.Module): ...@@ -236,6 +249,14 @@ class LayerNorm(nn.Module):
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE sharding_type: ShardingType = ShardingType.SINGLE
def __post_init__(self):
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__()
@nn.compact @nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray: def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
""" """
...@@ -256,14 +277,14 @@ class LayerNorm(nn.Module): ...@@ -256,14 +277,14 @@ class LayerNorm(nn.Module):
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,), scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
self.scale_init, self.scale_axes, self.scale_init, self.scale_axes,
self.bias_init, self.bias_axes, self.dtype) self.bias_init, self.bias_axes, self.dtype)
return layernorm(x, return layernorm(x,
scale, scale,
ln_bias, ln_bias,
self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type, sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0, dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
class TransformerEngineBase(nn.Module): class TransformerEngineBase(nn.Module):
...@@ -443,8 +464,20 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -443,8 +464,20 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Indicate the type of layer normalization. Indicate the type of layer normalization.
epsilon : float, default = 1e-6 epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
scale_init : Initializer, default = flax.linen.initializers.ones zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`. Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', ) scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh, The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
...@@ -496,7 +529,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -496,7 +529,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
enable_layernorm: bool = True enable_layernorm: bool = True
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
epsilon: float = 1e-6 epsilon: float = 1e-6
scale_init: Initializer = nn.initializers.ones zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',) scale_axes: Tuple[str, ...] = ('embed',)
ln_bias_init: Initializer = nn.initializers.zeros ln_bias_init: Initializer = nn.initializers.zeros
ln_bias_axes: Tuple[str, ...] = ('embed',) ln_bias_axes: Tuple[str, ...] = ('embed',)
...@@ -515,6 +549,11 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -515,6 +549,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -553,9 +592,10 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -553,9 +592,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
scale, scale,
ln_bias, ln_bias,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type, sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0, dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
else: else:
assert not self.return_layernorm_output assert not self.return_layernorm_output
y = inputs y = inputs
...@@ -600,9 +640,10 @@ class LayerNormDenseGeneral(TransformerEngineBase): ...@@ -600,9 +640,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self.layernorm_type, self.layernorm_type,
FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind), FP8Helper.BWD_DTYPE, (axis, contract_ind),
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type, sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0, dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
else: else:
kernel = jnp.asarray(kernel, self.dtype) kernel = jnp.asarray(kernel, self.dtype)
z = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ()))) z = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
...@@ -638,8 +679,20 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -638,8 +679,20 @@ class LayerNormMLP(TransformerEngineBase):
Indicate the type of layer normalization. Indicate the type of layer normalization.
epsilon : float, default = 1e-6 epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
scale_init : Initializer, default = flax.linen.initializers.ones zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`.
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`. Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', ) scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh, The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
...@@ -704,7 +757,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -704,7 +757,8 @@ class LayerNormMLP(TransformerEngineBase):
enable_layernorm: bool = True enable_layernorm: bool = True
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
epsilon: float = 1e-6 epsilon: float = 1e-6
scale_init: Initializer = nn.initializers.ones zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',) scale_axes: Tuple[str, ...] = ('embed',)
ln_bias_init: Initializer = nn.initializers.zeros ln_bias_init: Initializer = nn.initializers.zeros
ln_bias_axes: Tuple[str, ...] = ('embed',) ln_bias_axes: Tuple[str, ...] = ('embed',)
...@@ -727,6 +781,11 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -727,6 +781,11 @@ class LayerNormMLP(TransformerEngineBase):
def __post_init__(self): def __post_init__(self):
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__() super().__post_init__()
@nn.compact @nn.compact
...@@ -774,9 +833,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -774,9 +833,10 @@ class LayerNormMLP(TransformerEngineBase):
scale, scale,
ln_bias, ln_bias,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=first_sharding_type, sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0, dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
else: else:
assert not self.return_layernorm_output assert not self.return_layernorm_output
y = inputs y = inputs
...@@ -830,6 +890,7 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -830,6 +890,7 @@ class LayerNormMLP(TransformerEngineBase):
self.layernorm_type, self.layernorm_type,
FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, FP8Helper.BWD_DTYPE,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon, epsilon=self.epsilon,
contracting_dims=(axis, contract_ind), contracting_dims=(axis, contract_ind),
major_sharding_type=self.major_sharding_type, major_sharding_type=self.major_sharding_type,
...@@ -887,9 +948,10 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -887,9 +948,10 @@ class LayerNormMLP(TransformerEngineBase):
self.layernorm_type, self.layernorm_type,
FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind), FP8Helper.BWD_DTYPE, (axis, contract_ind),
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=first_sharding_type, sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0, dp_dim_index=1 if self.transpose_batch_sequence else 0)
epsilon=self.epsilon)
else: # not enable fp8 else: # not enable fp8
kernel = jnp.asarray(kernel, self.dtype) kernel = jnp.asarray(kernel, self.dtype)
x = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ()))) x = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
......
...@@ -205,6 +205,14 @@ class MultiHeadAttention(nn.Module): ...@@ -205,6 +205,14 @@ class MultiHeadAttention(nn.Module):
Indicate the type of layer normalization. Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6 layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
kernel_init: Initializer, default = kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal') flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
Used for initializing the QKV and Output projection weights. Used for initializing the QKV and Output projection weights.
...@@ -250,6 +258,7 @@ class MultiHeadAttention(nn.Module): ...@@ -250,6 +258,7 @@ class MultiHeadAttention(nn.Module):
dropout_rng_name: str = 'dropout' dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm" layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
kernel_init: Initializer = None kernel_init: Initializer = None
use_bias: bool = False use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros bias_init: Initializer = nn.initializers.zeros
...@@ -346,6 +355,7 @@ class MultiHeadAttention(nn.Module): ...@@ -346,6 +355,7 @@ class MultiHeadAttention(nn.Module):
qkv_proj, ln_out = LayerNormDenseGeneral( qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=(3, self.num_heads * self.head_dim), features=(3, self.num_heads * self.head_dim),
...@@ -369,6 +379,7 @@ class MultiHeadAttention(nn.Module): ...@@ -369,6 +379,7 @@ class MultiHeadAttention(nn.Module):
query, ln_out = LayerNormDenseGeneral( query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
...@@ -412,6 +423,7 @@ class MultiHeadAttention(nn.Module): ...@@ -412,6 +423,7 @@ class MultiHeadAttention(nn.Module):
query, ln_out = LayerNormDenseGeneral( query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm, enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
axis=-1, axis=-1,
features=self.num_heads * self.head_dim, features=self.num_heads * self.head_dim,
...@@ -661,6 +673,14 @@ class TransformerLayer(nn.Module): ...@@ -661,6 +673,14 @@ class TransformerLayer(nn.Module):
Indicate the type of layer normalization. Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6 layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability. A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
hidden_dropout: float, default = 0.1 hidden_dropout: float, default = 0.1
Dropout probability for the dropout op after FC2 layer. Dropout probability for the dropout op after FC2 layer.
hidden_dropout_dims: Sequence[int], default = () hidden_dropout_dims: Sequence[int], default = ()
...@@ -740,6 +760,7 @@ class TransformerLayer(nn.Module): ...@@ -740,6 +760,7 @@ class TransformerLayer(nn.Module):
num_attention_heads: int = 8 num_attention_heads: int = 8
layernorm_type: str = 'layernorm' layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6 layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
hidden_dropout: float = 0.1 hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = () hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1 attention_dropout: float = 0.1
...@@ -874,6 +895,7 @@ class TransformerLayer(nn.Module): ...@@ -874,6 +895,7 @@ class TransformerLayer(nn.Module):
scaled_query_init=self.scaled_query_init, scaled_query_init=self.scaled_query_init,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon, layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm, apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm, output_layernorm=self.output_layernorm,
attn_type=self_attn_type, attn_type=self_attn_type,
...@@ -919,6 +941,7 @@ class TransformerLayer(nn.Module): ...@@ -919,6 +941,7 @@ class TransformerLayer(nn.Module):
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon, layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self. apply_residual_connection_post_layernorm=self.
apply_residual_connection_post_layernorm, apply_residual_connection_post_layernorm,
output_layernorm=False, # Must do LayerNorm before MHA. output_layernorm=False, # Must do LayerNorm before MHA.
...@@ -941,6 +964,7 @@ class TransformerLayer(nn.Module): ...@@ -941,6 +964,7 @@ class TransformerLayer(nn.Module):
residual = mlp_input residual = mlp_input
z, ln_out = LayerNormMLP( z, ln_out = LayerNormMLP(
layernorm_type=self.layernorm_type, layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon, epsilon=self.layernorm_epsilon,
major_sharding_type=infer_major_sharding_type(), major_sharding_type=infer_major_sharding_type(),
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -973,11 +997,12 @@ class TransformerLayer(nn.Module): ...@@ -973,11 +997,12 @@ class TransformerLayer(nn.Module):
if self.output_layernorm: if self.output_layernorm:
ln_sharding_type, _ = infer_sharding_type() ln_sharding_type, _ = infer_sharding_type()
z = LayerNorm(layernorm_type=self.layernorm_type, z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
scale_axes=('embed',), scale_axes=('embed',),
bias_axes=('embed',), bias_axes=('embed',),
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype, dtype=self.dtype,
epsilon=self.layernorm_epsilon,
sharding_type=ln_sharding_type, sharding_type=ln_sharding_type,
name="output_layer_norm")(z) name="output_layer_norm")(z)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment