Unverified Commit 24e4f955 authored by Phuong Nguyen, committed by GitHub

[JAX] Flax params initialization with weight_dtype (#1481)



* initialization with weight_dtype
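A minimal sketch of the usage this enables (illustrative, not part of the diff; assumes the public `transformer_engine.jax.flax` API shown in the changes below): parameters are initialized and stored in `weight_dtype`, while the forward computation runs in `dtype`.

```python
import jax.numpy as jnp
from transformer_engine.jax.flax import DenseGeneral

# Hypothetical mixed-precision setup: float32 master weights, bfloat16 compute.
layer = DenseGeneral(
    features=1024,             # illustrative size
    dtype=jnp.bfloat16,        # computation dtype
    weight_dtype=jnp.float32,  # parameter initialization/storage dtype
)
```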
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent f0d22ca1
@@ -8,8 +8,8 @@ import functools
 import operator
 from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
 
+import jax.numpy as jnp
 import numpy as np
-import jax.numpy as jnp
 from flax import linen as nn
 from flax.linen import partitioning as nn_partitioning
 from jax import lax
@@ -57,14 +57,18 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
 def _create_layernorm_parameters(
-    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
+    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype
 ):
-    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
+    scale = nn_partitioning.param_with_axes(
+        "scale", scale_init, shape, weight_dtype, axes=scale_axes
+    )
     scale = scale.astype(dtype)
 
     layernorm_type = canonicalize_layernorm_type(layernorm_type)
     if layernorm_type == "layernorm":
-        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
+        bias = nn_partitioning.param_with_axes(
+            "ln_bias", bias_init, shape, weight_dtype, axes=bias_axes
+        )
         bias = bias.astype(dtype)
     else:
         assert layernorm_type == "rmsnorm"
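For context, the pattern these hunks apply throughout is "allocate in `weight_dtype`, cast to `dtype` on use". A self-contained toy module demonstrating the same pattern (not library code; `ToyScale` and its shapes are invented for illustration):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning


class ToyScale(nn.Module):
    """Illustrative module: parameter stored in weight_dtype, compute in dtype."""

    features: int = 16
    dtype: jnp.dtype = jnp.bfloat16        # compute dtype
    weight_dtype: jnp.dtype = jnp.float32  # parameter storage dtype

    @nn.compact
    def __call__(self, x):
        # Allocate the parameter in weight_dtype (what this PR changes)...
        scale = nn_partitioning.param_with_axes(
            "scale", nn.initializers.ones, (self.features,), self.weight_dtype, axes=("embed",)
        )
        # ...then cast it to the compute dtype for the forward pass.
        return x * scale.astype(self.dtype)


variables = ToyScale().init(jax.random.PRNGKey(0), jnp.ones((2, 16), jnp.bfloat16))
print(variables["params"]["scale"].dtype)  # float32, even though compute is bfloat16
```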
@@ -256,8 +260,10 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
     Optimization parameters
     -----------------------
-    dtype : jax.numpy.dtype, default = jax.numpy.float32
-        the data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     transpose_batch_sequence : bool, default = False
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
@@ -272,6 +278,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
     bias_init: Initializer = nn.initializers.zeros
     bias_axes: Tuple[str, ...] = ("embed",)
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     transpose_batch_sequence: bool = False
 
     def __post_init__(self):
@@ -307,6 +314,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
             self.bias_init,
             self.bias_axes,
             self.dtype,
+            self.weight_dtype,
         )
         return layernorm(
             x,
@@ -399,8 +407,10 @@ class DenseGeneral(TransformerEngineBase):
     Optimization parameters
     -----------------------
-    dtype : jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     transpose_batch_sequence : bool, default = True
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
@@ -418,12 +428,13 @@ class DenseGeneral(TransformerEngineBase):
     low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     transpose_batch_sequence: bool = False
 
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
             )
         super().__post_init__()
@@ -452,13 +463,13 @@ class DenseGeneral(TransformerEngineBase):
         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
         kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+            "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
         )
         kernel = kernel.astype(self.dtype)
 
         if self.use_bias:
             bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+                "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
             )
             bias = bias.astype(self.dtype)
         else:
@@ -489,7 +500,7 @@ class DenseGeneral(TransformerEngineBase):
                 "lora_a_kernel",
                 self.kernel_init,
                 lora_a_kernel_init_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=lora_a_kernel_axes,
             )
             lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -501,7 +512,7 @@ class DenseGeneral(TransformerEngineBase):
                 "lora_b_kernel",
                 nn.initializers.zeros,
                 lora_b_kernel_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=lora_b_kernel_axes,
             )
             lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -594,8 +605,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     Optimization parameters
     -----------------------
-    dtype : jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     transpose_batch_sequence : bool, default = True
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
@@ -625,6 +638,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     transpose_batch_sequence: bool = True
     layernorm_input_axes: Tuple[str, ...] = None
     dot_input_axes: Tuple[str, ...] = None
@@ -633,7 +647,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+                1.0,
+                "fan_in",
+                "truncated_normal",
+                dtype=self.weight_dtype,
             )
         self.scale_init = _obtain_default_layernorm_scale_init_if_need(
             self.scale_init,
@@ -683,6 +700,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                 self.ln_bias_init,
                 self.ln_bias_axes,
                 self.dtype,
+                self.weight_dtype,
             )
 
         if not fuse_layernorm:
@@ -712,7 +730,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         kernel_shape = tuple(y.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
         kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+            "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
         )
         kernel = kernel.astype(self.dtype)
@@ -757,7 +775,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                 "lora_a_kernel",
                 self.kernel_init,
                 lora_a_kernel_init_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=lora_a_kernel_axes,
             )
             lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -769,7 +787,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                 "lora_b_kernel",
                 nn.initializers.zeros,
                 lora_b_kernel_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=lora_b_kernel_axes,
             )
             lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -781,7 +799,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         bias = None
         if self.use_bias:
             bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+                "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
             )
             bias = bias.astype(self.dtype)
@@ -896,8 +914,10 @@ class LayerNormMLP(TransformerEngineBase):
     Optimization parameters
     -----------------------
-    dtype : jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     transpose_batch_sequence : bool, default = True
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
@@ -930,6 +950,7 @@ class LayerNormMLP(TransformerEngineBase):
     low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     transpose_batch_sequence: bool = True
     layernorm_input_axes: Tuple[str, ...] = None
     dot_1_input_axes: Tuple[str, ...] = None
@@ -938,7 +959,7 @@ class LayerNormMLP(TransformerEngineBase):
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
             )
         self.scale_init = _obtain_default_layernorm_scale_init_if_need(
             self.scale_init,
@@ -1015,6 +1036,7 @@ class LayerNormMLP(TransformerEngineBase):
                 self.ln_bias_init,
                 self.ln_bias_axes,
                 self.dtype,
+                self.weight_dtype,
             )
 
         if not fuse_layernorm:
@@ -1061,7 +1083,7 @@ class LayerNormMLP(TransformerEngineBase):
                 num_activations,
                 -2,
                 kernel_1_each_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=self.kernel_axes_1,
             )
             kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
@@ -1074,7 +1096,7 @@ class LayerNormMLP(TransformerEngineBase):
                 "wo_kernel",
                 self.kernel_init,
                 kernel_2_param_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=self.kernel_axes_2,
             )
             kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
@@ -1090,13 +1112,21 @@ class LayerNormMLP(TransformerEngineBase):
             if self.use_bias:
                 bias_1_shape = intermediate_dim
                 bias_1 = nn_partitioning.param_with_axes(
-                    "wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1
+                    "wi_bias",
+                    self.bias_init,
+                    bias_1_shape,
+                    self.weight_dtype,
+                    axes=self.bias_axes_1,
                 )
                 bias_1 = bias_1.astype(self.dtype)
 
                 bias_2_shape = (hidden_size,)
                 bias_2 = nn_partitioning.param_with_axes(
-                    "wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2
+                    "wo_bias",
+                    self.bias_init,
+                    bias_2_shape,
+                    self.weight_dtype,
+                    axes=self.bias_axes_2,
                 )
                 bias_2 = bias_2.astype(self.dtype)
             else:
@@ -1165,7 +1195,7 @@ class LayerNormMLP(TransformerEngineBase):
                 num_activations,
                 -2,
                 wi_lora_a_kernel_init_each_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=wi_lora_a_kernel_axes,
             )
             wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
@@ -1181,7 +1211,7 @@ class LayerNormMLP(TransformerEngineBase):
                 "wi_lora_b_kernel",
                 nn.initializers.zeros,
                 wi_lora_b_kernel_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=wi_lora_b_kernel_axes,
             )
             wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
@@ -1198,7 +1228,11 @@ class LayerNormMLP(TransformerEngineBase):
         bias_1 = None
         if self.use_bias:
             bias_1 = nn_partitioning.param_with_axes(
-                "wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1
+                "wi_bias",
+                self.bias_init,
+                intermediate_dim,
+                self.weight_dtype,
+                axes=self.bias_axes_1,
             )
             bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
             bias_1 = bias_1.astype(self.dtype)
@@ -1240,7 +1274,7 @@ class LayerNormMLP(TransformerEngineBase):
                 "wo_lora_a_kernel",
                 self.kernel_init,
                 wo_lora_a_kernel_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=wo_lora_a_kernel_axes,
             )
             wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
@@ -1251,7 +1285,7 @@ class LayerNormMLP(TransformerEngineBase):
                 "wo_lora_b_kernel",
                 nn.initializers.zeros,
                 wo_lora_b_kernel_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=wo_lora_b_kernel_axes,
             )
             wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
@@ -1268,7 +1302,11 @@ class LayerNormMLP(TransformerEngineBase):
         bias_2 = None
         if self.use_bias:
             bias_2 = nn_partitioning.param_with_axes(
-                "wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2
+                "wo_bias",
+                self.bias_init,
+                (hidden_size,),
+                self.weight_dtype,
+                axes=self.bias_axes_2,
             )
             bias_2 = bias_2.astype(self.dtype)
             out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
...
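A quick way to see the effect of the module changes above (illustrative check, not from the commit; assumes `LayerNormDenseGeneral` can be initialized outside an `fp8_autocast` context, with invented shapes):

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormDenseGeneral

mod = LayerNormDenseGeneral(features=128, dtype=jnp.bfloat16, weight_dtype=jnp.float32)
# transpose_batch_sequence defaults to True here, so inputs are (seq, batch, hidden).
variables = mod.init(jax.random.PRNGKey(0), jnp.ones((8, 4, 64), jnp.bfloat16))
# Every leaf should report float32: parameters live in weight_dtype and are
# only cast to the compute dtype inside the forward pass.
print(jax.tree_util.tree_map(lambda p: p.dtype, variables["params"]))
```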
@@ -115,6 +115,7 @@ class _UnfusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-
     attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
     attn_bias_type: Optional[AttnBiasType] = None
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     float32_logits: bool = False
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = True
@@ -261,6 +262,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
     attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
     attn_bias_type: Optional[AttnBiasType] = None
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
     scale_factor: Optional[float] = None
     transpose_batch_sequence: bool = False
@@ -481,7 +483,9 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     Optimization parameters
     -----------------------
     dtype: jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     """
 
     head_dim: int
@@ -491,6 +495,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     attn_mask_type: AttnMaskType = "causal"
     attn_bias_type: AttnBiasType = None
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     dropout_rng_name: str = "dropout"
     float32_logits: bool = False
     qkv_layout: str = "bshd_bshd_bshd"
@@ -615,6 +620,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 attn_mask_type=attn_mask_type,
                 attn_bias_type=attn_bias_type,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                 float32_logits=self.float32_logits,
                 scale_factor=scale_factor,
                 transpose_batch_sequence=self.transpose_batch_sequence,
@@ -626,6 +632,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 attn_mask_type=attn_mask_type,
                 attn_bias_type=attn_bias_type,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                 scale_factor=scale_factor,
                 transpose_batch_sequence=self.transpose_batch_sequence,
                 qkv_layout=qkv_layout,
@@ -881,7 +888,9 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
     Optimization parameters
     -----------------------
     dtype: jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     fuse_qkv_params: bool, default = True
         If set to True, this module exposes a single fused
         parameter for query-key-value for self-attention and key-value for
@@ -927,6 +936,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
     low_rank_adaptation_dim: int = 32
     low_rank_adaptation_alpha: float = None
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     fuse_qkv_params: bool = True
     transpose_batch_sequence: bool = True
     enable_sequence_parallel: bool = False
@@ -977,7 +987,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "normal", dtype=self.dtype
+                1.0, "fan_in", "normal", dtype=self.weight_dtype
             )
         if self.num_gqa_groups is None:
             self.num_gqa_groups = self.num_attention_heads
@@ -1105,6 +1115,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 dot_input_axes=inputs_logical_axes_no_sp,
                 name="qkv",
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
             )(inputs_q)
             qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
             qkv_layout = QKVLayout.BS3HD
@@ -1128,6 +1139,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                 low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                 kernel_init=query_init,
                 layernorm_input_axes=inputs_logical_axes_maybe_sp,
                 dot_input_axes=inputs_logical_axes_no_sp,
@@ -1152,6 +1164,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                 name="kv",
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
             )(inputs_kv)
             kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
             qkv_layout = QKVLayout.BSHD_BS2HD
@@ -1169,6 +1182,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                 low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
             )
             query, ln_out = LayerNormDenseGeneral(
                 enable_layernorm=self.input_layernorm,
@@ -1189,6 +1203,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                 low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                 kernel_init=query_init,
                 layernorm_input_axes=inputs_logical_axes_maybe_sp,
                 dot_input_axes=inputs_logical_axes_no_sp,
@@ -1326,6 +1341,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
             attn_bias_type=self.attn_bias_type,
             attention_dropout=self.attention_dropout,
             dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
             dropout_rng_name=self.dropout_rng_name,
             float32_logits=self.float32_logits,
             qkv_layout=qkv_layout.name,
@@ -1351,6 +1367,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
             low_rank_adaptation_dim=self.low_rank_adaptation_dim,
             low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
             name="out",
         )(x)
         out = checkpoint_name(out, "out_proj")
@@ -1379,7 +1396,9 @@ class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-metho
     Optimization parameters
     -----------------------
     dtype: jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     """
 
     num_buckets: int
@@ -1388,6 +1407,7 @@ class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-metho
     embedding_init: Callable[..., Array] = nn.linear.default_embed_init
     embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
 
     @nn.compact
     def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
@@ -1440,7 +1460,7 @@ class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-metho
             "rel_embedding",
             self.embedding_init,
             (self.num_attention_heads, self.num_buckets),
-            self.dtype,
+            self.weight_dtype,
             axes=self.embedding_axes,
         )
@@ -1613,7 +1633,9 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
     Optimization parameters
     -----------------------
     dtype: jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     drop_path: float, default = 0.0
         When > 0.0, applies stochastic depth per sample in the main
         path of the residual block.
@@ -1666,6 +1688,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
     low_rank_adaptation_dim: int = 32
     low_rank_adaptation_alpha: float = None
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     drop_path: float = 0.0
     fuse_qkv_params: bool = True
     transpose_batch_sequence: bool = False
@@ -1677,11 +1700,11 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
     def __post_init__(self):
         if self.mha_kernel_init is None:
             self.mha_kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "normal", dtype=self.dtype
+                1.0, "fan_in", "normal", dtype=self.weight_dtype
             )
         if self.mlp_kernel_init is None:
             self.mlp_kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
             )
         if self.num_gqa_groups is None:
             self.num_gqa_groups = self.num_attention_heads
@@ -1771,6 +1794,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
                 max_distance=128,
                 num_attention_heads=self.num_attention_heads,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                 embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                 name="relpos_bias",
             )
@@ -1804,6 +1828,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
             x, ln_out = MultiHeadAttention(
                 num_attention_heads=self.num_attention_heads,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                 head_dim=head_dim,
                 num_gqa_groups=self.num_gqa_groups,
                 transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1882,6 +1907,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
             y, ln_out = MultiHeadAttention(
                 num_attention_heads=self.num_attention_heads,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                 head_dim=head_dim,
                 num_gqa_groups=self.num_gqa_groups,
                 transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1947,6 +1973,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
             intermediate_dropout_rate=self.intermediate_dropout,
             intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
             dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
             scale_axes=(W_NO_SHARD_AXES,),
             ln_bias_axes=(W_NO_SHARD_AXES,),
             kernel_init=self.mlp_kernel_init,
@@ -1996,6 +2023,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
                 bias_axes=(W_NO_SHARD_AXES,),
                 transpose_batch_sequence=self.transpose_batch_sequence,
                 dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                 name="output_layernorm",
             )(z)
...
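At the top level, a hypothetical `TransformerLayer` configuration (sizes invented for illustration) shows the intended split: float32 parameters for optimizer/master-weight precision, bfloat16 compute:

```python
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer

layer = TransformerLayer(
    hidden_size=512,
    mlp_hidden_size=2048,
    num_attention_heads=8,
    dtype=jnp.bfloat16,        # forward/backward computation dtype
    weight_dtype=jnp.float32,  # dtype of every kernel, bias, and scale created above
)
```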