Unverified commit ec1030b5 authored by zlsh80826, committed by GitHub

Zero-centered gamma (Layernorm1p) support for JAX (#139)



* Add zero_center_gamma/functional pass
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add zero_centered_gamma for fp8_ln_mlp
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add zero_centered_gamma to modules
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add zero_centered_gamma to TransformerLayer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refactored code style for improved readability and consistency
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Docs enhancement for zero_centered_gamma
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add escape for line break and remove some bad if conditions
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Revise scale_init docs
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 7f5e4cb9
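For reference, "zero-centered gamma" (LayerNorm1p) stores the learned scale centered around zero and adds 1 when applying it, so a freshly zero-initialized gamma acts as the identity scale. A minimal pure-JAX sketch of the two conventions this commit wires through the stack (illustrative only, not the fused TE kernel):

```python
import jax
import jax.numpy as jnp

def layernorm_ref(x, gamma, beta, eps=1e-6, zero_centered_gamma=False):
    """Plain-JAX reference for both gamma conventions (trailing feature axis)."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
    normed = (x - mean) * jax.lax.rsqrt(var + eps)
    # LayerNorm1p: the stored gamma is shifted by 1 at apply time.
    scale = gamma + 1 if zero_centered_gamma else gamma
    return normed * scale + beta
```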
......@@ -523,32 +523,44 @@ class TestLayerNorm:
@pytest.mark.parametrize('n, hidden', LN_CASES)
@pytest.mark.parametrize('dtype', DTYPES)
def test_forward_backward(self, n, hidden, dtype):
@pytest.mark.parametrize('zero_centered_gamma', [False, True])
def test_forward_backward(self, n, hidden, zero_centered_gamma, dtype):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -2, 1)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, -2, 1)
x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
scale_range = (-1, 1) if zero_centered_gamma else (0, 2)
scale = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *scale_range)
scale = jnp.asarray(scale, dtype)
bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -2, 1)
bias = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
bias = jnp.asarray(bias, dtype)
epsilon = 1e-6
def reference_layernorm(x, scale, bias):
x = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
normed_input = (x - mean) * jax.lax.rsqrt(var + epsilon)
def reference_layernorm(x, scale, bias, zero_centered_gamma, eps):
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
# Align TE implementation
return jnp.asarray(normed_input * scale + bias)
if zero_centered_gamma:
return jnp.asarray(normed_input * (scale + 1) + bias).astype(x.dtype)
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
def compute_loss(x):
# Higher precision to compute the loss
x_ = x.astype(jnp.float32)
return jnp.mean(jnp.square(x_)).astype(x.dtype)
jitted_primitive = jit(
value_and_grad(lambda x, scale, bias: jnp.mean(layernorm(x, scale, bias, "layernorm")),
value_and_grad(
lambda x, scale, bias: compute_loss(
layernorm(x, scale, bias, "layernorm", zero_centered_gamma, epsilon)),
(0, 1, 2)))
jitted_reference = jit(
value_and_grad(lambda x, scale, bias: jnp.mean(reference_layernorm(x, scale, bias)),
(0, 1, 2)))
value_and_grad(
lambda x, scale, bias: compute_loss(
reference_layernorm(x, scale, bias, zero_centered_gamma, epsilon)), (0, 1, 2)))
primitive_out, (primitive_dx, primitive_dgamma,
primitive_dbeta) = jitted_primitive(x, scale, bias)
......@@ -561,7 +573,7 @@ class TestLayerNorm:
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-7)
assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-7)
else:
assert_allclose(primitive_out, reference_out, rtol=1e-3)
assert_allclose(primitive_dx, reference_dx, rtol=1e-4, atol=5e-8)
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-4, atol=5e-8)
assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-4, atol=5e-8)
assert_allclose(primitive_out, reference_out, rtol=1e-7)
assert_allclose(primitive_dx, reference_dx, rtol=1e-5, atol=1e-6)
assert_allclose(primitive_dgamma, reference_dgamma, rtol=1e-5, atol=3e-5)
assert_allclose(primitive_dbeta, reference_dbeta, rtol=1e-5, atol=3e-5)
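The updated test draws the scale from (-1, 1) when `zero_centered_gamma` is on and from (0, 2) otherwise, so the effective multiplier `gamma + 1` spans the same range in both modes, and it compares values and all three gradients in one shot via `value_and_grad`. A small sketch of that checking pattern against a pure-JAX reference (illustrative; the tolerances in the real test depend on dtype):

```python
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

def ref_ln1p(x, scale, bias, eps=1e-6):
    # Zero-centered-gamma reference in float32, cast back to the input dtype.
    x32 = x.astype(jnp.float32)
    mean = jnp.mean(x32, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x32 - mean), axis=-1, keepdims=True)
    normed = (x32 - mean) * jax.lax.rsqrt(var + eps)
    return (normed * (scale + 1) + bias).astype(x.dtype)

loss = lambda x, s, b: jnp.mean(jnp.square(ref_ln1p(x, s, b).astype(jnp.float32)))
grad_fn = jit(value_and_grad(loss, argnums=(0, 1, 2)))

x = jax.random.uniform(jax.random.PRNGKey(0), (16, 32), jnp.float32, -1, 1)
scale = jnp.zeros((32,), jnp.float32)   # zero-centered: zeros means identity scale
bias = jnp.zeros((32,), jnp.float32)
out, (dx, dgamma, dbeta) = grad_fn(x, scale, bias)
```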
......@@ -66,6 +66,7 @@ _KEY_OF_DROPOUT_RATE = "dropout_rate"
_KEY_OF_MLP_ACTIVATIONS = "mlp_activations"
_KEY_OF_FUSE_MLP_WI = "fuse_mlp_wi"
_KEY_OF_LAYERNORM_TYPE = 'layernorm_type'
_KEY_OF_ZERO_CENTERED_GAMMA = 'zero_centered_gamma'
_KEY_OF_TRANSPOSE_BS = 'transpose_batch_sequence'
BASE_ATTRS = {_KEY_OF_TRANSPOSE_BS: True}
......@@ -74,6 +75,9 @@ ATTRS = [{
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
}, {
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
}, {
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_ZERO_CENTERED_GAMMA: True
}, {
_KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
_KEY_OF_RESIDUAL_POST_LAYERNORM: True
......
......@@ -604,9 +604,18 @@ class LayerNorm(nn.Module):
epsilon: float = 1e-6
dtype: Any = jnp.float32
layernorm_type: str = 'layernorm'
scale_init: Initializer = nn.initializers.ones
zero_centered_gamma: bool = False
scale_init: Initializer = None
bias_init: Initializer = nn.initializers.zeros
def __post_init__(self):
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__()
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Applies layer normalization on the input."""
......@@ -632,9 +641,13 @@ class LayerNorm(nn.Module):
bias = jnp.asarray(bias, self.dtype)
y = jnp.asarray(y, self.dtype)
if not self.zero_centered_gamma:
z = y * scale + bias
else:
z = y * (scale + 1) + bias
else:
assert self.layernorm_type == 'rmsnorm'
assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
z = y * scale
......@@ -768,6 +781,7 @@ class EncoderLayer(nn.Module):
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False
output_layernorm: bool = False
drop_path: float = 0.0
fuse_qkv_params: bool = True
......@@ -797,6 +811,7 @@ class EncoderLayer(nn.Module):
if not self.output_layernorm:
# Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_attention_layer_norm")(inputs)
......@@ -831,6 +846,7 @@ class EncoderLayer(nn.Module):
# MLP block.
residual = x
y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_mlp_layer_norm')(x)
......@@ -857,6 +873,7 @@ class EncoderLayer(nn.Module):
if self.output_layernorm:
y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layer_norm")(y)
return y
......@@ -878,6 +895,7 @@ class DecoderLayer(nn.Module):
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
layernorm_type: str = 'layernorm'
zero_centered_gamma: bool = False
drop_path: float = 0.0
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
......@@ -914,6 +932,7 @@ class DecoderLayer(nn.Module):
if not self.output_layernorm:
# Attention block.
x = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="pre_self_attention_layer_norm")(inputs)
......@@ -949,6 +968,7 @@ class DecoderLayer(nn.Module):
# Encoder-Decoder block.
residual = x
y = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_cross_attention_layer_norm')(x)
......@@ -974,6 +994,7 @@ class DecoderLayer(nn.Module):
# MLP block.
residual = y
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name='pre_mlp_layer_norm')(y)
if self.apply_residual_connection_post_layernorm:
......@@ -997,6 +1018,7 @@ class DecoderLayer(nn.Module):
if self.output_layernorm:
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
dtype=self.dtype,
name="output_layer_norm")(z)
......
......@@ -745,13 +745,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
multiple_results = True
@staticmethod
def abstract(
x,
gamma,
beta,
*,
epsilon # pylint: disable=unused-argument
):
def abstract(x, gamma, beta, **kwargs): # pylint: disable=unused-argument
"""
LayerNorm fwd abstract
"""
......@@ -774,7 +768,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
)
@staticmethod
def lowering(ctx, x, gamma, beta, *, epsilon):
def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
"""
LayerNorm fwd lowering rules
"""
......@@ -815,6 +809,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma,
epsilon,
)
......@@ -826,11 +821,16 @@ class LayerNormFwdPrimitive(BasePrimitive):
_layernorm_fwd_p = register_primitive(LayerNormFwdPrimitive)
def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, epsilon: float):
def layernorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool,
epsilon: float):
"""
Wrapper for TE layernorm fwd
"""
return _layernorm_fwd_p.bind(x, gamma, beta, epsilon=epsilon)
return _layernorm_fwd_p.bind(x,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
class LayerNormFwdFp8Primitive(BasePrimitive):
......@@ -848,8 +848,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
amax,
scale,
scale_inv,
*,
epsilon # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
"""
LayerNorm fwd (fp8 out) abstract
......@@ -879,7 +878,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
)
@staticmethod
def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, epsilon):
def lowering(ctx, x, gamma, beta, amax, scale, scale_inv, *, zero_centered_gamma, epsilon):
"""
LayerNorm fwd (fp8 out) lowering rules
"""
......@@ -928,6 +927,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma,
epsilon,
)
......@@ -944,11 +944,19 @@ _layernorm_fwd_fp8_p = register_primitive(LayerNormFwdFp8Primitive)
def layernorm_fwd_fp8(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, epsilon: float):
scale: jnp.ndarray, scale_inv: jnp.ndarray, zero_centered_gamma: bool,
epsilon: float):
"""
Wrapper for TE layernorm fwd (fp8 out)
"""
return _layernorm_fwd_fp8_p.bind(x, gamma, beta, amax, scale, scale_inv, epsilon=epsilon)
return _layernorm_fwd_fp8_p.bind(x,
gamma,
beta,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
class LayerNormBwdPrimitive(BasePrimitive):
......@@ -959,15 +967,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
multiple_results = True
@staticmethod
def abstract(
grad_output,
mu,
rsigma,
x,
gamma,
*,
epsilon # pylint: disable=unused-argument
):
def abstract(grad_output, mu, rsigma, x, gamma, **kwargs): # pylint: disable=unused-argument
"""
Layernorm bwd abstract
"""
......@@ -993,7 +993,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
)
@staticmethod
def lowering(ctx, grad_output, mu, rsigma, x, gamma, *, epsilon):
def lowering(ctx, grad_output, mu, rsigma, x, gamma, *, zero_centered_gamma, epsilon):
"""
Layernorm bwd lowering rules
"""
......@@ -1029,6 +1029,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma,
epsilon,
)
......@@ -1041,11 +1042,17 @@ _layernorm_bwd_p = register_primitive(LayerNormBwdPrimitive)
def layernorm_bwd(g: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray, x: jnp.ndarray,
gamma: jnp.ndarray, epsilon: float):
gamma: jnp.ndarray, zero_centered_gamma: bool, epsilon: float):
"""
Wrapper for TE layernorm bwd
"""
return _layernorm_bwd_p.bind(g, mu, rsigma, x, gamma, epsilon=epsilon)
return _layernorm_bwd_p.bind(g,
mu,
rsigma,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
class RmsNormFwdPrimitive(BasePrimitive):
......@@ -1056,12 +1063,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
multiple_results = True
@staticmethod
def abstract(
x,
gamma,
*,
epsilon # pylint: disable=unused-argument
):
def abstract(x, gamma, **kwargs): # pylint: disable=unused-argument
"""
RMSNorm fwd abstract
"""
......@@ -1106,6 +1108,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
)
......@@ -1138,8 +1141,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
amax,
scale,
scale_inv,
*,
epsilon # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
"""
RMSNorm fwd (fp8 out) abstract
......@@ -1207,6 +1209,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
)
......@@ -1243,8 +1246,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
rsigma,
x,
gamma,
*,
epsilon # pylint: disable=unused-argument
**kwargs # pylint: disable=unused-argument
):
"""
RMSNorm bwd abstract
......@@ -1298,6 +1300,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma
epsilon,
)
......
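At the primitive level, `zero_centered_gamma` is now a static attribute of the forward/backward custom calls and is packed into the descriptor next to `epsilon` (the RMSNorm primitives hard-code it to `False`). A rough sketch of driving the updated wrappers directly; the `transformer_engine.jax.cpp_extensions` import path is assumed from the file layout, and executing this needs a CUDA build of Transformer Engine:

```python
import jax.numpy as jnp
# Import path assumed from the file layout of this diff.
from transformer_engine.jax.cpp_extensions import layernorm_fwd, layernorm_bwd

x = jnp.ones((128, 1024), jnp.float32)
gamma = jnp.zeros((1024,), jnp.float32)   # zero-centered: zeros acts as gamma = 1
beta = jnp.zeros((1024,), jnp.float32)

out, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma=True, epsilon=1e-6)

# The backward wrapper reuses the saved statistics and returns dx, dgamma, dbeta.
g = jnp.ones_like(out)
dx, dgamma, dbeta = layernorm_bwd(g, mu, rsigma, x, gamma,
                                  zero_centered_gamma=True, epsilon=1e-6)
```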
......@@ -66,8 +66,9 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType
}
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
float eps) {
return PackOpaque(CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, eps});
bool zero_centered_gamma, float eps) {
return PackOpaque(
CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps});
}
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
......@@ -269,10 +270,10 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
desc.use_split_accumulator, stream);
}
void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, void *weight,
DType w_dtype, void *bias, float eps, void *output, DType out_dtype,
void *mu, void *rsigma, float *amax, float *scale, float *scale_inv,
cudaStream_t stream) {
void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void *input,
DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
DType out_dtype, void *mu, void *rsigma, float *amax, float *scale,
float *scale_inv, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n};
......@@ -291,12 +292,17 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
if (!is_layer_norm) {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
}
// The first call is to query the required workspace
if (is_layer_norm) {
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, dummy_workspace_tensor.data(), dummy_barrier_tensor.data());
} else {
......@@ -322,7 +328,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype);
auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps,
output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream,
num_sm, workspace_tensor.data(), barrier_tensor.data());
} else {
......@@ -332,9 +338,10 @@ void LayerNormForwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
}
}
void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype, void *weight,
DType w_dtype, void *ograd, void *mu, void *rsigma, float eps,
void *xgrad, void *wgrad, void *dbeta, cudaStream_t stream) {
void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd,
void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta,
cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n};
......@@ -360,12 +367,17 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
size_t dbeta_part_size{};
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
if (!is_layer_norm) {
NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma.");
}
// The first call is to query the workspace
if (is_layer_norm) {
auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype);
auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype);
nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(),
dummy_dgamma_part_tensor.data(), dummy_dbeta_part_tensor.data(), stream,
......@@ -411,7 +423,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, void *input, DType in_dtype,
auto dbeta_part_tensor = TensorWrapper(dbeta_part, dummy_dbeta_part_tensor.shape(),
dummy_dbeta_part_tensor.dtype());
nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(),
wgrad_tensor.data(), dbeta_tensor.data(), dgamma_part_tensor.data(),
dbeta_part_tensor.data(), stream, num_sm, workspace_tensor.data(),
......@@ -444,11 +456,12 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
}
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -470,9 +483,10 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
}
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -483,6 +497,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto *ograd = buffers[0];
auto *mu = buffers[1];
......@@ -493,8 +508,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *wgrad = buffers[6];
auto *dbeta = buffers[7];
LayerNormBackwardImpl(n, hidden, input, in_dtype, weight, w_dtype, ograd, mu, rsigma, eps,
xgrad, wgrad, dbeta, stream);
LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
}
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -517,10 +532,11 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
}
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -541,10 +557,11 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
auto out_dtype = in_dtype;
LayerNormForwardImpl(n, hidden, input, in_dtype, weight, w_dtype, bias, eps, output, out_dtype,
mu, rsigma, amax, scale, scale_inv, stream);
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
}
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......@@ -561,12 +578,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto in_dtype = desc.x_dtype;
auto w_dtype = desc.w_dtype;
auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma;
void *mu = nullptr;
void *dbeta = nullptr;
LayerNormBackwardImpl(n, hidden, input, in_dtype, weight, w_dtype, ograd, mu, rsigma, eps,
xgrad, wgrad, dbeta, stream);
LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
}
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......
......@@ -73,11 +73,12 @@ struct CustomCallNormDescriptor {
size_t hidden;
DType x_dtype;
DType w_dtype;
bool zero_centered_gamma;
float eps;
};
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
float eps);
bool zero_centered_gamma, float eps);
struct SoftmaxDescriptor {
size_t batch;
......
......@@ -37,6 +37,7 @@ def layernorm(inputs: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
layernorm_type: str,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0):
......@@ -49,15 +50,18 @@ def layernorm(inputs: jnp.ndarray,
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
if sharding_type is ShardingType.SINGLE:
output = _layernorm(inputs,
gamma,
beta,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name="",
epsilon=epsilon)
dp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
......@@ -75,9 +79,10 @@ def layernorm(inputs: jnp.ndarray,
partial_ln = partial(_layernorm,
layernorm_type=layernorm_type,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
epsilon=epsilon)
dp_axis_name=dp_axis_name)
output = xmap_runner(partial_ln, in_axes, sharding_meta.out_axes,
sharding_meta.axis_resources, (inputs_, gamma_, beta_))
......@@ -87,9 +92,11 @@ def layernorm(inputs: jnp.ndarray,
return output
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6))
def _layernorm(x, gamma, beta, layernorm_type, sharding_type, dp_axis_name, epsilon=1e-6):
output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, sharding_type, dp_axis_name, epsilon)
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7))
def _layernorm(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon, sharding_type,
dp_axis_name):
output, _ = _layernorm_fwd(x, gamma, beta, layernorm_type, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name)
return output
......@@ -98,23 +105,36 @@ def _layernorm_fwd(
gamma,
beta,
layernorm_type,
zero_centered_gamma,
epsilon,
sharding_type, # pylint: disable=unused-argument
dp_axis_name, # pylint: disable=unused-argument
epsilon):
dp_axis_name # pylint: disable=unused-argument
):
if layernorm_type == 'layernorm':
output, mu, rsigma = layernorm_fwd(x, gamma, beta, epsilon)
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
mu = None
return output, (mu, rsigma, x, gamma)
def _layernorm_bwd(layernorm_type, sharding_type, dp_axis_name, epsilon, ctx, g):
def _layernorm_bwd(layernorm_type, zero_centered_gamma, epsilon, sharding_type, dp_axis_name, ctx,
g):
mu, rsigma, x, gamma = ctx
if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(g, mu, rsigma, x, gamma, epsilon=epsilon)
grad_input, grad_gamma, grad_beta = layernorm_bwd(g,
mu,
rsigma,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(g, rsigma, x, gamma, epsilon=epsilon)
grad_beta = None
......@@ -135,9 +155,10 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
sharding_type: ShardingType = ShardingType.SINGLE,
dp_dim_index: int = 0,
epsilon: float = 1e-6) -> jnp.ndarray:
dp_dim_index: int = 0) -> jnp.ndarray:
"""
LN + fp8 dot fusion wrapper
"""
......@@ -147,6 +168,8 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
assert fp8_gemm_pkg.num_of_gemm == 1
inputs = fp8_gemm_pkg.inputs
......@@ -169,10 +192,11 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype,
bwd_dtype,
contracting_dims,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name="",
tp_axis_name="",
epsilon=epsilon)
tp_axis_name="")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
......@@ -214,10 +238,11 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
contracting_dims=contracting_dims,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
sharding_type=sharding_type,
dp_axis_name=dp_axis_name,
tp_axis_name=tp_axis_name,
epsilon=epsilon)
tp_axis_name=tp_axis_name)
# input, kernel, gamma, beta, fp8_metas
in_axes = (ln_sharding_meta.in_axes[0], dot_sharding_meta.in_axes[1],
......@@ -230,27 +255,18 @@ def layernorm_fp8_dot(fp8_gemm_pkg: FP8GemmPackage,
return output
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15))
def _layernorm_fp8_dot(inputs: jnp.ndarray,
kernel: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
fp8_maxs: jnp.ndarray,
amax: jnp.ndarray,
scale: jnp.ndarray,
scale_inv: jnp.ndarray,
layernorm_type: str,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
@partial(jax.custom_vjp, nondiff_argnums=(8, 9, 10, 11, 12, 13, 14, 15, 16))
def _layernorm_fp8_dot(inputs: jnp.ndarray, kernel: jnp.ndarray, gamma: jnp.ndarray,
beta: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
fwd_dtype: TEDType, bwd_dtype: TEDType,
contracting_dims: Tuple[Sequence[int], Sequence[int]],
sharding_type: ShardingType,
dp_axis_name: str,
tp_axis_name: str,
epsilon: float = 1e-6) -> jnp.ndarray:
zero_centered_gamma: bool, epsilon: float, sharding_type: ShardingType,
dp_axis_name: str, tp_axis_name: str) -> jnp.ndarray:
output, _ = _layernorm_fp8_dot_fwd(inputs, kernel, gamma, beta, fp8_maxs, amax, scale,
scale_inv, layernorm_type, fwd_dtype, bwd_dtype,
contracting_dims, sharding_type, dp_axis_name, tp_axis_name,
epsilon)
contracting_dims, zero_centered_gamma, epsilon,
sharding_type, dp_axis_name, tp_axis_name)
return output
......@@ -267,10 +283,11 @@ def _layernorm_fp8_dot_fwd(
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
contracting_dims,
zero_centered_gamma,
epsilon,
sharding_type,
dp_axis_name, # pylint: disable=unused-argument
tp_axis_name,
epsilon):
tp_axis_name):
lhs_contracting_dims, rhs_contracting_dims = contracting_dims
input_shape_pre = inputs.shape[:min(lhs_contracting_dims)]
......@@ -295,8 +312,11 @@ def _layernorm_fp8_dot_fwd(
input_amax,
input_scale,
input_scale_inv,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, input_amax = rmsnorm_fwd_fp8(inputs,
gamma,
input_amax,
......@@ -337,10 +357,11 @@ def _layernorm_fp8_dot_bwd(
fwd_dtype,
bwd_dtype,
contracting_dims, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon,
sharding_type,
dp_axis_name,
tp_axis_name,
epsilon,
ctx,
g):
ln_out_, kernel_cast, \
......@@ -377,8 +398,16 @@ def _layernorm_fp8_dot_bwd(
dgrad = jax.lax.psum(dgrad, tp_axis_name)
if layernorm_type == 'layernorm':
grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad, mu, rsigma, inputs, gamma, epsilon)
grad_input, grad_gamma, grad_beta = layernorm_bwd(dgrad,
mu,
rsigma,
inputs,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(dgrad, rsigma, inputs, gamma, epsilon=epsilon)
grad_beta = None
......
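Putting the functional pieces together, the public `layernorm` wrapper now takes `zero_centered_gamma` ahead of `epsilon` and asserts it is off for 'rmsnorm'. A hedged usage sketch (the import path is assumed from this file; executing it needs Transformer Engine's compiled extensions and a GPU):

```python
import jax.numpy as jnp
from transformer_engine.jax.layernorm import layernorm   # import path assumed

x = jnp.ones((8, 512, 1024), jnp.float32)
gamma = jnp.zeros((1024,), jnp.float32)   # zero-centered: start from the identity scale
beta = jnp.zeros((1024,), jnp.float32)

y = layernorm(x, gamma, beta, 'layernorm', zero_centered_gamma=True, epsilon=1e-6)

# 'rmsnorm' takes beta=None and rejects zero_centered_gamma, per the asserts above.
```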
......@@ -103,6 +103,7 @@ def fp8_ln_mlp(
layernorm_type: str,
fwd_dtype: TEDType,
bwd_dtype: TEDType,
zero_centered_gamma: bool = False,
epsilon: float = 1e-6,
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
major_sharding_type: MajorShardingType = MajorShardingType.SINGLE,
......@@ -125,12 +126,14 @@ def fp8_ln_mlp(
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'rmsnorm':
assert ln_bias is None, "ln_bias should be None if layernorm_type is 'rmsnorm'"
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
assert activations == ('gelu', 'linear')
if major_sharding_type is MajorShardingType.SINGLE:
res = _fp8_mlp(inputs, ln_scale, ln_bias, kernel_1, kernel_2, fp8_max, amax, scale,
scale_inv, layernorm_type, activations, epsilon, fwd_dtype, bwd_dtype,
contracting_dims, major_sharding_type, "", "")
scale_inv, layernorm_type, activations, zero_centered_gamma, epsilon,
fwd_dtype, bwd_dtype, contracting_dims, major_sharding_type, "", "")
else:
dp_axis_name = "batch"
tp_axis_name = "model"
......@@ -177,6 +180,7 @@ def fp8_ln_mlp(
partial_fp8_mlp = partial(_fp8_mlp,
layernorm_type=layernorm_type,
activations=activations,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
fwd_dtype=fwd_dtype,
bwd_dtype=bwd_dtype,
......@@ -196,12 +200,13 @@ def fp8_ln_mlp(
return res
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17))
@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18))
def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray,
scale: jnp.ndarray, scale_inv: jnp.ndarray, layernorm_type: str,
activations: Sequence[Union[str, Callable]], epsilon: float, fwd_dtype: TEDType,
bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int], Sequence[int]],
activations: Sequence[Union[str, Callable]], zero_centered_gamma: bool, epsilon: float,
fwd_dtype: TEDType, bwd_dtype: TEDType, contracting_dims: Tuple[Sequence[int],
Sequence[int]],
major_sharding_type: MajorShardingType, dp_axis_name: str, tp_axis_name: str):
res, _ = _fp8_mlp_fwd(inputs,
ln_scale,
......@@ -214,6 +219,7 @@ def _fp8_mlp(inputs: jnp.ndarray, ln_scale: jnp.ndarray, ln_bias: jnp.ndarray,
scale_inv,
layernorm_type,
activations,
zero_centered_gamma,
epsilon,
fwd_dtype,
bwd_dtype,
......@@ -236,6 +242,7 @@ def _fp8_mlp_fwd(
scale_inv,
layernorm_type,
activations,
zero_centered_gamma,
epsilon,
fwd_dtype,
bwd_dtype, # pylint: disable=unused-argument
......@@ -276,8 +283,11 @@ def _fp8_mlp_fwd(
input_amax,
input_scale,
input_scale_inv,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
ln_out, rsigma, ln_out_amax = rmsnorm_fwd_fp8(inputs_,
gamma,
input_amax,
......@@ -334,6 +344,7 @@ def _fp8_mlp_fwd(
def _fp8_mlp_bwd(
layernorm_type,
activations, # pylint: disable=unused-argument
zero_centered_gamma,
epsilon,
fwd_dtype,
bwd_dtype,
......@@ -397,8 +408,11 @@ def _fp8_mlp_bwd(
rsigma,
inputs_,
gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon)
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
grad_input, grad_gamma = rmsnorm_bwd(dgrad_1, rsigma, inputs_, gamma, epsilon=epsilon)
grad_beta = None
......
......@@ -202,8 +202,20 @@ class LayerNorm(nn.Module):
A value added to the denominator of layer normalization for numerical stability.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
scale_init : Initializer, default = flax.linen.initializers.ones
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`.
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh.
......@@ -228,7 +240,8 @@ class LayerNorm(nn.Module):
"""
epsilon: float = 1e-6
layernorm_type: str = 'layernorm'
scale_init: Initializer = nn.initializers.ones
zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',)
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ('embed',)
......@@ -236,6 +249,14 @@ class LayerNorm(nn.Module):
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
def __post_init__(self):
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__()
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""
......@@ -256,14 +277,14 @@ class LayerNorm(nn.Module):
scale, ln_bias = _create_layernorm_parameters(self.layernorm_type, (features,),
self.scale_init, self.scale_axes,
self.bias_init, self.bias_axes, self.dtype)
return layernorm(x,
scale,
ln_bias,
self.layernorm_type,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
dp_dim_index=1 if self.transpose_batch_sequence else 0)
class TransformerEngineBase(nn.Module):
......@@ -443,8 +464,20 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Indicate the type of layer normalization.
epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
scale_init : Initializer, default = flax.linen.initializers.ones
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
......@@ -496,7 +529,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
epsilon: float = 1e-6
scale_init: Initializer = nn.initializers.ones
zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',)
ln_bias_init: Initializer = nn.initializers.zeros
ln_bias_axes: Tuple[str, ...] = ('embed',)
......@@ -515,6 +549,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__()
@nn.compact
......@@ -553,9 +592,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
scale,
ln_bias,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
assert not self.return_layernorm_output
y = inputs
......@@ -600,9 +640,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=self.sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
kernel = jnp.asarray(kernel, self.dtype)
z = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
......@@ -638,8 +679,20 @@ class LayerNormMLP(TransformerEngineBase):
Indicate the type of layer normalization.
epsilon : float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
scale_init : Initializer, default = flax.linen.initializers.ones
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
The default of `scale_init` will also be changed. See `scale_init`.
scale_init : Initializer, default = None
Used for initializing scale factors :math:`\gamma`.
If `None` is provided, scale_init is set according to the value of zero_centered_gamma.
If zero_centered_gamma is set to `True`, then scale_init is `flax.linen.initializers.zeros`.
Otherwise, scale_init is `flax.linen.initializers.ones`.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
scale_axes : Tuple[str, ...], default = ('embed', )
The name of axes used to shard the scale factors :math:`\gamma` with a corresponding mesh,
......@@ -704,7 +757,8 @@ class LayerNormMLP(TransformerEngineBase):
enable_layernorm: bool = True
layernorm_type: str = 'layernorm'
epsilon: float = 1e-6
scale_init: Initializer = nn.initializers.ones
zero_centered_gamma: bool = False
scale_init: Initializer = None
scale_axes: Tuple[str, ...] = ('embed',)
ln_bias_init: Initializer = nn.initializers.zeros
ln_bias_axes: Tuple[str, ...] = ('embed',)
......@@ -727,6 +781,11 @@ class LayerNormMLP(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
if self.scale_init is None:
if not self.zero_centered_gamma:
self.scale_init = nn.initializers.ones
else:
self.scale_init = nn.initializers.zeros
super().__post_init__()
@nn.compact
......@@ -774,9 +833,10 @@ class LayerNormMLP(TransformerEngineBase):
scale,
ln_bias,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else:
assert not self.return_layernorm_output
y = inputs
......@@ -830,6 +890,7 @@ class LayerNormMLP(TransformerEngineBase):
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
contracting_dims=(axis, contract_ind),
major_sharding_type=self.major_sharding_type,
......@@ -887,9 +948,10 @@ class LayerNormMLP(TransformerEngineBase):
self.layernorm_type,
FP8Helper.FWD_DTYPE,
FP8Helper.BWD_DTYPE, (axis, contract_ind),
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
sharding_type=first_sharding_type,
dp_dim_index=1 if self.transpose_batch_sequence else 0,
epsilon=self.epsilon)
dp_dim_index=1 if self.transpose_batch_sequence else 0)
else: # not enable fp8
kernel = jnp.asarray(kernel, self.dtype)
x = lax.dot_general(y, kernel, ((axis, contract_ind), ((), ())))
......
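On the Flax side, setting `zero_centered_gamma=True` flips the default `scale_init` from ones to zeros in `__post_init__`, while an explicit `scale_init` still takes precedence. A construction sketch, assuming the modules are importable as `transformer_engine.jax.flax` and that a CUDA build of Transformer Engine is available for `init`/`apply`:

```python
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax   # import path assumed

ln = te_flax.LayerNorm(layernorm_type='layernorm', zero_centered_gamma=True)
x = jnp.ones((4, 128, 1024), jnp.float32)
params = ln.init(jax.random.PRNGKey(0), x)
# scale_init defaults to zeros here, so the fresh module applies an effective gamma of 1.
y = ln.apply(params, x)
```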
......@@ -205,6 +205,14 @@ class MultiHeadAttention(nn.Module):
Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
kernel_init: Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal')
Used for initializing the QKV and Output projection weights.
......@@ -250,6 +258,7 @@ class MultiHeadAttention(nn.Module):
dropout_rng_name: str = 'dropout'
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
kernel_init: Initializer = None
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
......@@ -346,6 +355,7 @@ class MultiHeadAttention(nn.Module):
qkv_proj, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
axis=-1,
features=(3, self.num_heads * self.head_dim),
......@@ -369,6 +379,7 @@ class MultiHeadAttention(nn.Module):
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
......@@ -412,6 +423,7 @@ class MultiHeadAttention(nn.Module):
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
axis=-1,
features=self.num_heads * self.head_dim,
......@@ -661,6 +673,14 @@ class TransformerLayer(nn.Module):
Indicate the type of layer normalization.
layernorm_epsilon: float, default = 1e-6
A value added to the denominator of layer normalization for numerical stability.
zero_centered_gamma : bool, default = False
If set to `True`, the LayerNorm formula changes to
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} *
(1 + \gamma) + \beta
This parameter is only applicable for 'layernorm'.
hidden_dropout: float, default = 0.1
Dropout probability for the dropout op after FC2 layer.
hidden_dropout_dims: Sequence[int], default = ()
......@@ -740,6 +760,7 @@ class TransformerLayer(nn.Module):
num_attention_heads: int = 8
layernorm_type: str = 'layernorm'
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
......@@ -874,6 +895,7 @@ class TransformerLayer(nn.Module):
scaled_query_init=self.scaled_query_init,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_type=self_attn_type,
......@@ -919,6 +941,7 @@ class TransformerLayer(nn.Module):
dropout_rng_name=self.dropout_rng_name,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self.
apply_residual_connection_post_layernorm,
output_layernorm=False, # Must do LayerNorm before MHA.
......@@ -941,6 +964,7 @@ class TransformerLayer(nn.Module):
residual = mlp_input
z, ln_out = LayerNormMLP(
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
major_sharding_type=infer_major_sharding_type(),
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -973,11 +997,12 @@ class TransformerLayer(nn.Module):
if self.output_layernorm:
ln_sharding_type, _ = infer_sharding_type()
z = LayerNorm(layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.layernorm_epsilon,
scale_axes=('embed',),
bias_axes=('embed',),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
epsilon=self.layernorm_epsilon,
sharding_type=ln_sharding_type,
name="output_layer_norm")(z)
......