Unverified Commit 24e4f955 authored by Phuong Nguyen, committed by GitHub

[JAX] Flax params initialization with weight_dtype (#1481)



* initialization with weight_dtype
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent f0d22ca1
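In effect, this change lets parameters be stored in one precision while the forward pass runs in another: each parameter is now allocated with `weight_dtype` and cast to the computation `dtype` before use. A minimal usage sketch, assuming the modules are imported from `transformer_engine.jax.flax` as laid out in this repository (the layer size and dtypes below are illustrative, not taken from the diff):

```python
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax

# Keep parameters in float32 (weight_dtype) while computing in bfloat16 (dtype).
layer = te_flax.DenseGeneral(
    features=128,
    dtype=jnp.bfloat16,        # computation dtype
    weight_dtype=jnp.float32,  # parameter allocation/storage dtype (added by this commit)
)

x = jnp.ones((4, 64), dtype=jnp.bfloat16)
variables = layer.init(jax.random.PRNGKey(0), x)  # kernel/bias are created in float32
```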
@@ -8,8 +8,8 @@ import functools
import operator
from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
import jax.numpy as jnp
import numpy as np
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax import lax
@@ -57,14 +57,18 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
def _create_layernorm_parameters(
layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype
):
scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
scale = nn_partitioning.param_with_axes(
"scale", scale_init, shape, weight_dtype, axes=scale_axes
)
scale = scale.astype(dtype)
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == "layernorm":
bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
bias = nn_partitioning.param_with_axes(
"ln_bias", bias_init, shape, weight_dtype, axes=bias_axes
)
bias = bias.astype(dtype)
else:
assert layernorm_type == "rmsnorm"
@@ -256,8 +260,10 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
the data type used to allocate the initial parameters.
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -272,6 +278,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
@@ -307,6 +314,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
self.bias_init,
self.bias_axes,
self.dtype,
self.weight_dtype,
)
return layernorm(
x,
@@ -399,8 +407,10 @@ class DenseGeneral(TransformerEngineBase):
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -418,12 +428,13 @@ class DenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
)
super().__post_init__()
@@ -452,13 +463,13 @@ class DenseGeneral(TransformerEngineBase):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
)
kernel = kernel.astype(self.dtype)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
"bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
else:
@@ -489,7 +500,7 @@ class DenseGeneral(TransformerEngineBase):
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
self.dtype,
self.weight_dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -501,7 +512,7 @@ class DenseGeneral(TransformerEngineBase):
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
self.dtype,
self.weight_dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -594,8 +605,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -625,6 +638,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
@@ -633,7 +647,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
1.0,
"fan_in",
"truncated_normal",
dtype=self.weight_dtype,
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init,
@@ -683,6 +700,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self.ln_bias_init,
self.ln_bias_axes,
self.dtype,
self.weight_dtype,
)
if not fuse_layernorm:
@@ -712,7 +730,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
)
kernel = kernel.astype(self.dtype)
@@ -757,7 +775,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
self.dtype,
self.weight_dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -769,7 +787,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
self.dtype,
self.weight_dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -781,7 +799,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
"bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
@@ -896,8 +914,10 @@ class LayerNormMLP(TransformerEngineBase):
Optimization parameters
-----------------------
dtype : jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -930,6 +950,7 @@ class LayerNormMLP(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
@@ -938,7 +959,7 @@ class LayerNormMLP(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init,
@@ -1015,6 +1036,7 @@ class LayerNormMLP(TransformerEngineBase):
self.ln_bias_init,
self.ln_bias_axes,
self.dtype,
self.weight_dtype,
)
if not fuse_layernorm:
@@ -1061,7 +1083,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations,
-2,
kernel_1_each_shape,
self.dtype,
self.weight_dtype,
axes=self.kernel_axes_1,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
@@ -1074,7 +1096,7 @@ class LayerNormMLP(TransformerEngineBase):
"wo_kernel",
self.kernel_init,
kernel_2_param_shape,
self.dtype,
self.weight_dtype,
axes=self.kernel_axes_2,
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
@@ -1090,13 +1112,21 @@ class LayerNormMLP(TransformerEngineBase):
if self.use_bias:
bias_1_shape = intermediate_dim
bias_1 = nn_partitioning.param_with_axes(
"wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1
"wi_bias",
self.bias_init,
bias_1_shape,
self.weight_dtype,
axes=self.bias_axes_1,
)
bias_1 = bias_1.astype(self.dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2
"wo_bias",
self.bias_init,
bias_2_shape,
self.weight_dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.astype(self.dtype)
else:
@@ -1165,7 +1195,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations,
-2,
wi_lora_a_kernel_init_each_shape,
self.dtype,
self.weight_dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
@@ -1181,7 +1211,7 @@ class LayerNormMLP(TransformerEngineBase):
"wi_lora_b_kernel",
nn.initializers.zeros,
wi_lora_b_kernel_shape,
self.dtype,
self.weight_dtype,
axes=wi_lora_b_kernel_axes,
)
wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
@@ -1198,7 +1228,11 @@ class LayerNormMLP(TransformerEngineBase):
bias_1 = None
if self.use_bias:
bias_1 = nn_partitioning.param_with_axes(
"wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1
"wi_bias",
self.bias_init,
intermediate_dim,
self.weight_dtype,
axes=self.bias_axes_1,
)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
bias_1 = bias_1.astype(self.dtype)
@@ -1240,7 +1274,7 @@ class LayerNormMLP(TransformerEngineBase):
"wo_lora_a_kernel",
self.kernel_init,
wo_lora_a_kernel_shape,
self.dtype,
self.weight_dtype,
axes=wo_lora_a_kernel_axes,
)
wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
@@ -1251,7 +1285,7 @@ class LayerNormMLP(TransformerEngineBase):
"wo_lora_b_kernel",
nn.initializers.zeros,
wo_lora_b_kernel_shape,
self.dtype,
self.weight_dtype,
axes=wo_lora_b_kernel_axes,
)
wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
@@ -1268,7 +1302,11 @@ class LayerNormMLP(TransformerEngineBase):
bias_2 = None
if self.use_bias:
bias_2 = nn_partitioning.param_with_axes(
"wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2
"wo_bias",
self.bias_init,
(hidden_size,),
self.weight_dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.astype(self.dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
@@ -115,6 +115,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
float32_logits: bool = False
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
@@ -261,6 +262,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False
@@ -481,7 +483,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
"""
head_dim: int
@@ -491,6 +495,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
dropout_rng_name: str = "dropout"
float32_logits: bool = False
qkv_layout: str = "bshd_bshd_bshd"
@@ -615,6 +620,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
float32_logits=self.float32_logits,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
@@ -626,6 +632,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
@@ -881,7 +888,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
@@ -927,6 +936,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
@@ -977,7 +987,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.dtype
1.0, "fan_in", "normal", self.weight_dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
@@ -1105,6 +1115,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dot_input_axes=inputs_logical_axes_no_sp,
name="qkv",
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
qkv_layout = QKVLayout.BS3HD
@@ -1128,6 +1139,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
@@ -1152,6 +1164,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
name="kv",
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
qkv_layout = QKVLayout.BSHD_BS2HD
@@ -1169,6 +1182,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=self.input_layernorm,
@@ -1189,6 +1203,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
@@ -1326,6 +1341,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout.name,
@@ -1351,6 +1367,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="out",
)(x)
out = checkpoint_name(out, "out_proj")
@@ -1379,7 +1396,9 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
"""
num_buckets: int
@@ -1388,6 +1407,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
@nn.compact
def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
@@ -1440,7 +1460,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
"rel_embedding",
self.embedding_init,
(self.num_attention_heads, self.num_buckets),
self.dtype,
self.weight_dtype,
axes=self.embedding_axes,
)
@@ -1613,7 +1633,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
drop_path: float, default = 0.0
When > 0.0, applies stochastic depth per sample in the main
path of the residual block.
@@ -1666,6 +1688,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
@@ -1677,11 +1700,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self):
if self.mha_kernel_init is None:
self.mha_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.dtype
1.0, "fan_in", "normal", dtype=self.weight_dtype
)
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
@@ -1771,6 +1794,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
max_distance=128,
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
name="relpos_bias",
)
@@ -1804,6 +1828,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
x, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1882,6 +1907,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
y, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1947,6 +1973,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_init=self.mlp_kernel_init,
@@ -1996,6 +2023,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="output_layernorm",
)(z)
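The recurring pattern in the hunks above: allocate each parameter via `param_with_axes(..., weight_dtype, ...)`, then cast it to the computation `dtype` before it is used. A self-contained Flax sketch of that pattern with a toy module (not the library's actual code):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning


class TinyDense(nn.Module):
    """Toy module illustrating the weight_dtype / dtype split."""

    features: int
    dtype: jnp.dtype = jnp.bfloat16        # computation dtype
    weight_dtype: jnp.dtype = jnp.float32  # parameter storage dtype

    @nn.compact
    def __call__(self, x):
        # Allocate the parameter in weight_dtype ...
        kernel = nn_partitioning.param_with_axes(
            "kernel",
            nn.initializers.lecun_normal(),
            (x.shape[-1], self.features),
            self.weight_dtype,
            axes=("embed", "mlp"),
        )
        # ... then cast it to the computation dtype before the matmul,
        # mirroring `kernel = kernel.astype(self.dtype)` in the diff.
        kernel = kernel.astype(self.dtype)
        return jnp.dot(x.astype(self.dtype), kernel)


x = jnp.ones((2, 8))
params = TinyDense(features=4).init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(lambda p: p.dtype, params["params"]))  # kernel stays float32
```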