"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "28e616c4e3ef24d3763de5c5210f2ee20be56f5e"
Unverified commit b87e539d, authored by Phuong Nguyen, committed by GitHub

[JAX] Flax module init with a given dtype (#1472)



* flax module to init params with given dtype
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* all tests passed
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* remove unnecessary reshape for kernel
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* remove casting output of dot
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* clean up
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
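
In effect, parameters now come out of `init` in the module's `dtype` instead of always float32. A minimal sketch of the new behavior (the `transformer_engine.jax.flax` import path and the `DenseGeneral` arguments are assumed for illustration, not taken from this diff):

    import jax
    import jax.numpy as jnp
    import transformer_engine.jax.flax as te_flax  # assumed import path

    # After this commit, params are created directly in the module dtype
    # instead of being created in float32 and cast afterwards.
    layer = te_flax.DenseGeneral(features=128, dtype=jnp.bfloat16)
    params = layer.init(jax.random.PRNGKey(0), jnp.ones((4, 256), jnp.bfloat16))
    print(jax.tree_util.tree_map(lambda p: p.dtype, params))  # kernel (and bias): bfloat16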
parent 544dd14b
@@ -252,8 +252,13 @@ class FusedAttnFwdPrimitive(BasePrimitive):
         k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
         v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
         bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
-        assert q_dtype == k_dtype == v_dtype == bias_dtype
-        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
+        assert (
+            q_dtype == k_dtype == v_dtype == bias_dtype
+        ), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}"
+        assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
+            f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
+            f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
+        )
         batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = (
             FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
...
@@ -25,7 +25,7 @@ def type_safe_dot_general(
     """
     if fp8_meta_pkg is None:
-        kernel = jnp.asarray(kernel, x.dtype)
+        assert x.dtype == kernel.dtype, f"lhs dtype = {x.dtype}, rhs dtype = {kernel.dtype}"
         return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ())))
     amax_list = fp8_meta_pkg.amax_list
...
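
Note the contract change here: the non-FP8 path no longer silently casts `kernel` to `x.dtype`; mismatched operands now fail the assert. A minimal sketch of what callers must now guarantee, with made-up shapes:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((4, 8), jnp.bfloat16)
    kernel = jnp.ones((8, 16), jnp.bfloat16)  # a float32 kernel would now trip the assert
    assert x.dtype == kernel.dtype, f"lhs dtype = {x.dtype}, rhs dtype = {kernel.dtype}"
    # Contract x's dim 1 with kernel's dim 0, no batch dims: result shape (4, 16)
    out = jax.lax.dot_general(x, kernel, (((1,), (0,)), ((), ())))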
@@ -59,17 +59,13 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
 def _create_layernorm_parameters(
     layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
 ):
-    scale = nn_partitioning.param_with_axes(
-        "scale", scale_init, shape, jnp.float32, axes=scale_axes
-    )
-    scale = jnp.asarray(scale, dtype)
+    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
+    scale = scale.astype(dtype)
 
     layernorm_type = canonicalize_layernorm_type(layernorm_type)
     if layernorm_type == "layernorm":
-        bias = nn_partitioning.param_with_axes(
-            "ln_bias", bias_init, shape, jnp.float32, axes=bias_axes
-        )
-        bias = jnp.asarray(bias, dtype)
+        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
+        bias = bias.astype(dtype)
     else:
         assert layernorm_type == "rmsnorm"
         bias = None
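
The scale/bias parameters are now requested from the initializer directly in the target dtype, which makes the follow-up cast a no-op. A standalone Flax sketch of the same pattern (the `TinyScale` module and its axis name are hypothetical, not part of this repo):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from flax.linen import partitioning as nn_partitioning

    class TinyScale(nn.Module):
        dtype: jnp.dtype = jnp.bfloat16

        @nn.compact
        def __call__(self, x):
            # Param is created in self.dtype (the new pattern), not float32.
            scale = nn_partitioning.param_with_axes(
                "scale", nn.initializers.ones, (x.shape[-1],), self.dtype, axes=("embed",)
            )
            return x * scale.astype(self.dtype)  # no-op cast, mirroring the diff

    variables = TinyScale().init(jax.random.PRNGKey(0), jnp.ones((2, 4), jnp.bfloat16))
    assert variables["params"]["scale"].dtype == jnp.bfloat16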
@@ -280,7 +276,8 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
     def __post_init__(self):
         self.scale_init = _obtain_default_layernorm_scale_init_if_need(
-            self.scale_init, self.zero_centered_gamma
+            self.scale_init,
+            self.zero_centered_gamma,
         )
         super().__post_init__()
@@ -299,6 +296,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
         outputs : jax.numpy.ndarray
             Output tensors.
         """
+        x = x.astype(self.dtype)
         features = x.shape[-1]
 
         scale, ln_bias = _create_layernorm_parameters(
@@ -424,7 +422,9 @@ class DenseGeneral(TransformerEngineBase):
     def __post_init__(self):
         if self.kernel_init is None:
-            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
+            self.kernel_init = nn.initializers.variance_scaling(
+                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+            )
         super().__post_init__()
 
     @nn.compact
@@ -452,14 +452,13 @@ class DenseGeneral(TransformerEngineBase):
         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
         kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
+            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
         )
-        kernel = kernel.astype(self.dtype)
-        kernel = jnp.reshape(kernel, kernel_shape)
 
         if self.use_bias:
             bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes
+                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
             )
             bias = bias.astype(self.dtype)
         else:
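
Since the kernel is now created at its final `kernel_shape` and in `self.dtype` (rather than at the flattened `kernel_param_shape` in float32), the extra cast and reshape fall away. A minimal sketch of the new initialization path, with made-up dimensions:

    import jax
    import jax.numpy as jnp

    # Before: param created as (prod(in_dims),) + features in fp32, then cast + reshaped.
    # After: param created directly at its final shape and dtype.
    kernel_shape = (256, 128)  # tuple(inputs.shape[ax] for ax in axis) + features
    init = jax.nn.initializers.variance_scaling(
        1.0, "fan_in", "truncated_normal", dtype=jnp.bfloat16
    )
    kernel = init(jax.random.PRNGKey(0), kernel_shape, jnp.bfloat16)
    assert kernel.shape == kernel_shape and kernel.dtype == jnp.bfloat16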
@@ -490,7 +489,7 @@ class DenseGeneral(TransformerEngineBase):
                 "lora_a_kernel",
                 self.kernel_init,
                 lora_a_kernel_init_shape,
-                jnp.float32,
+                self.dtype,
                 axes=lora_a_kernel_axes,
             )
             lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -502,7 +501,7 @@ class DenseGeneral(TransformerEngineBase):
                 "lora_b_kernel",
                 nn.initializers.zeros,
                 lora_b_kernel_shape,
-                jnp.float32,
+                self.dtype,
                 axes=lora_b_kernel_axes,
             )
             lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -633,9 +632,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     def __post_init__(self):
         if self.kernel_init is None:
-            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
+            self.kernel_init = nn.initializers.variance_scaling(
+                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+            )
         self.scale_init = _obtain_default_layernorm_scale_init_if_need(
-            self.scale_init, self.zero_centered_gamma
+            self.scale_init,
+            self.zero_centered_gamma,
         )
         super().__post_init__()
@@ -665,6 +667,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
             and not self.return_layernorm_output
             and self.enable_layernorm
         )
+        inputs = inputs.astype(self.dtype)
 
         if self.enable_layernorm:
             inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
@@ -709,10 +712,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
             kernel_shape = tuple(y.shape[ax] for ax in axis) + features
             kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
             kernel = nn_partitioning.param_with_axes(
-                "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
+                "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
             )
-            kernel = kernel.astype(self.dtype)
-            kernel = jnp.reshape(kernel, kernel_shape)
 
             contract_ind = tuple(range(0, len(axis)))
@@ -755,7 +757,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                 "lora_a_kernel",
                 self.kernel_init,
                 lora_a_kernel_init_shape,
-                jnp.float32,
+                self.dtype,
                 axes=lora_a_kernel_axes,
             )
             lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -767,7 +769,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                 "lora_b_kernel",
                 nn.initializers.zeros,
                 lora_b_kernel_shape,
-                jnp.float32,
+                self.dtype,
                 axes=lora_b_kernel_axes,
             )
             lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -779,7 +781,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         bias = None
         if self.use_bias:
             bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes
+                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
             )
             bias = bias.astype(self.dtype)
@@ -935,9 +937,12 @@ class LayerNormMLP(TransformerEngineBase):
     def __post_init__(self):
         if self.kernel_init is None:
-            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
+            self.kernel_init = nn.initializers.variance_scaling(
+                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+            )
         self.scale_init = _obtain_default_layernorm_scale_init_if_need(
-            self.scale_init, self.zero_centered_gamma
+            self.scale_init,
+            self.zero_centered_gamma,
         )
         super().__post_init__()
@@ -970,6 +975,8 @@ class LayerNormMLP(TransformerEngineBase):
             and self.enable_layernorm
         )
 
+        inputs = inputs.astype(self.dtype)
+
         gated_act_pool = [
             ("gelu", "linear"),
             ("silu", "linear"),
@@ -1033,7 +1040,7 @@ class LayerNormMLP(TransformerEngineBase):
                 for _ in range(num_kernels):
                     key, init_key = jax_random.split(key)
                     kernels.append(self.kernel_init(init_key, *init_args))
-                return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32)
+                return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype)
 
         wi_fp8_meta_pkg = None
         wo_fp8_meta_pkg = None
@@ -1054,10 +1061,11 @@ class LayerNormMLP(TransformerEngineBase):
             num_activations,
             -2,
             kernel_1_each_shape,
-            jnp.float32,
+            self.dtype,
             axes=self.kernel_axes_1,
         )
         kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
+        kernel_1 = kernel_1.astype(self.dtype)
         hidden_size = inputs.shape[-1]
         hidden_size_tuple = _canonicalize_tuple(hidden_size)
         kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
@@ -1066,10 +1074,11 @@ class LayerNormMLP(TransformerEngineBase):
             "wo_kernel",
             self.kernel_init,
             kernel_2_param_shape,
-            jnp.float32,
+            self.dtype,
             axes=self.kernel_axes_2,
         )
         kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
+        kernel_2 = kernel_2.astype(self.dtype)
 
         contract_ind = tuple(range(0, len(axis)))
         ffn1_ckpt_name = "ffn1"
@@ -1081,13 +1090,13 @@ class LayerNormMLP(TransformerEngineBase):
             if self.use_bias:
                 bias_1_shape = intermediate_dim
                 bias_1 = nn_partitioning.param_with_axes(
-                    "wi_bias", self.bias_init, bias_1_shape, jnp.float32, axes=self.bias_axes_1
+                    "wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1
                 )
                 bias_1 = bias_1.astype(self.dtype)
 
                 bias_2_shape = (hidden_size,)
                 bias_2 = nn_partitioning.param_with_axes(
-                    "wo_bias", self.bias_init, bias_2_shape, jnp.float32, axes=self.bias_axes_2
+                    "wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2
                 )
                 bias_2 = bias_2.astype(self.dtype)
             else:
@@ -1156,7 +1165,7 @@ class LayerNormMLP(TransformerEngineBase):
                 num_activations,
                 -2,
                 wi_lora_a_kernel_init_each_shape,
-                jnp.float32,
+                self.dtype,
                 axes=wi_lora_a_kernel_axes,
             )
             wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
@@ -1172,7 +1181,7 @@ class LayerNormMLP(TransformerEngineBase):
                 "wi_lora_b_kernel",
                 nn.initializers.zeros,
                 wi_lora_b_kernel_shape,
-                jnp.float32,
+                self.dtype,
                 axes=wi_lora_b_kernel_axes,
             )
             wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
@@ -1189,10 +1198,10 @@ class LayerNormMLP(TransformerEngineBase):
         bias_1 = None
         if self.use_bias:
             bias_1 = nn_partitioning.param_with_axes(
-                "wi_bias", self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1
+                "wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1
            )
-            bias_1 = bias_1.astype(self.dtype)
             bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
+            bias_1 = bias_1.astype(self.dtype)
             x += jnp.reshape(bias_1, bias_1_shape)
         x = checkpoint_name(x, ffn1_ckpt_name)
@@ -1207,6 +1216,8 @@ class LayerNormMLP(TransformerEngineBase):
             z = functools.reduce(operator.mul, activations)
             # Remove act axis
             z = jnp.reshape(z, (*z.shape[:-2], -1))
+            z = z.astype(self.dtype)
+            # import pdb; pdb.set_trace()
 
         z = nn.Dropout(
             rate=self.intermediate_dropout_rate,
@@ -1215,6 +1226,7 @@ class LayerNormMLP(TransformerEngineBase):
         )(z, deterministic=deterministic)
 
         z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
+        z = z.astype(self.dtype)
 
         # DenseGeneral 2
         out = type_safe_dot_general(
@@ -1228,7 +1240,7 @@ class LayerNormMLP(TransformerEngineBase):
                 "wo_lora_a_kernel",
                 self.kernel_init,
                 wo_lora_a_kernel_shape,
-                jnp.float32,
+                self.dtype,
                 axes=wo_lora_a_kernel_axes,
             )
             wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
@@ -1239,7 +1251,7 @@ class LayerNormMLP(TransformerEngineBase):
                 "wo_lora_b_kernel",
                 nn.initializers.zeros,
                 wo_lora_b_kernel_shape,
-                jnp.float32,
+                self.dtype,
                 axes=wo_lora_b_kernel_axes,
             )
             wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
@@ -1256,7 +1268,7 @@ class LayerNormMLP(TransformerEngineBase):
         bias_2 = None
         if self.use_bias:
             bias_2 = nn_partitioning.param_with_axes(
-                "wo_bias", self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2
+                "wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2
             )
             bias_2 = bias_2.astype(self.dtype)
             out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
...
@@ -976,7 +976,9 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
         )
         if self.kernel_init is None:
-            self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
+            self.kernel_init = nn.initializers.variance_scaling(
+                1.0, "fan_in", "normal", dtype=self.dtype
+            )
         if self.num_gqa_groups is None:
             self.num_gqa_groups = self.num_attention_heads
         super().__post_init__()
@@ -1198,6 +1200,7 @@ class MultiHeadAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 inputs_kv = ln_out
 
             key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
+            key = key.astype(self.dtype)
             value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
             query = checkpoint_name(query, "query_proj")
             key = checkpoint_name(key, "key_proj")
@@ -1437,7 +1440,7 @@ class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
             "rel_embedding",
             self.embedding_init,
             (self.num_attention_heads, self.num_buckets),
-            jnp.float32,
+            self.dtype,
             axes=self.embedding_axes,
         )
@@ -1673,10 +1676,12 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
     def __post_init__(self):
         if self.mha_kernel_init is None:
-            self.mha_kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
+            self.mha_kernel_init = nn.initializers.variance_scaling(
+                1.0, "fan_in", "normal", dtype=self.dtype
+            )
         if self.mlp_kernel_init is None:
             self.mlp_kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal"
+                1.0, "fan_in", "truncated_normal", dtype=self.dtype
             )
         if self.num_gqa_groups is None:
             self.num_gqa_groups = self.num_attention_heads
@@ -1726,6 +1731,9 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
         outputs: jax.numpy.ndarray
             Output tensors.
         """
+        inputs = inputs.astype(self.dtype)
+
         assert (
             self.layer_type in TransformerLayerType
         ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
...
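
Taken together, the pattern this commit applies across `TransformerLayer`, `LayerNormDenseGeneral`, and `LayerNormMLP` is: cast inputs at module entry and create every parameter in `self.dtype`. A minimal standalone Flax sketch of that pattern (the `Entry` module is hypothetical, not part of this repo):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Entry(nn.Module):
        dtype: jnp.dtype = jnp.bfloat16

        @nn.compact
        def __call__(self, x):
            x = x.astype(self.dtype)  # mirrors: inputs = inputs.astype(self.dtype)
            w = self.param("w", nn.initializers.ones, (x.shape[-1],), self.dtype)
            return x * w

    model = Entry()
    params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 4), jnp.float32))
    out = model.apply(params, jnp.ones((2, 4), jnp.float32))
    assert out.dtype == jnp.bfloat16  # float32 inputs are cast down at module entry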