Unverified Commit 6673f165 authored by Phuong Nguyen, committed by GitHub

[JAX] Flax with compute dtype inferred from input dtype. (#1485)



Flax modules now infer their compute dtype from the dtype of their inputs.
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent eb9857d6
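For orientation before the per-file hunks, here is a minimal sketch of the user-facing effect of this change. The `te_flax` import path and `DenseGeneral` usage mirror the example scripts touched below; the exact call sequence is an illustration, not part of the commit.

# Sketch only: assumes transformer_engine.jax.flax is importable as `te_flax`,
# matching the example scripts in this diff.
import jax
import jax.numpy as jnp
from transformer_engine.jax import flax as te_flax

x = jnp.ones((8, 128), dtype=jnp.bfloat16)  # the compute dtype now travels with the input

# Before this change: the compute dtype was requested via the module attribute, e.g.
#   te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)
# After this change: `dtype` only affects parameter initialization; computation
# follows x.dtype.
layer = te_flax.DenseGeneral(features=256)
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)
assert y.dtype == x.dtype  # bfloat16 in, bfloat16 out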
......@@ -56,7 +56,6 @@ class Net(nn.Module):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......@@ -72,17 +71,15 @@ class Net(nn.Module):
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2)(x)
return x
......@@ -91,7 +88,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -136,7 +133,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......
......@@ -51,17 +51,16 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2)(x)
return x
......@@ -70,7 +69,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -115,7 +114,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......
......@@ -57,7 +57,6 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......@@ -67,17 +66,15 @@ class Net(nn.Module):
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2)(x)
return x
......
......@@ -46,17 +46,16 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2)(x)
return x
......@@ -66,7 +65,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -112,7 +111,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -217,6 +216,7 @@ def train_and_evaluate(args):
with te.fp8_autocast(enabled=args.use_fp8):
encoder = Net(num_embed)
# We use nn.Embed, so the inputs need to be integers
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks)
......
......@@ -36,6 +36,8 @@ class Net(nn.Module):
nn_Dense = te_flax.DenseGeneral
else:
nn_Dense = nn.Dense
# In TE modules, dtype sets the parameter-init dtype; in flax.linen it sets the compute dtype
dtype = jnp.float32 if self.use_te else jnp.bfloat16
x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
......@@ -44,11 +46,13 @@ class Net(nn.Module):
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = nn_Dense(features=128, dtype=jnp.bfloat16)(x)
assert x.dtype == jnp.bfloat16
x = nn_Dense(features=128, dtype=dtype)(x)
x = nn.relu(x)
x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
x = nn_Dense(features=16, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=10, dtype=jnp.bfloat16)(x)
x = nn_Dense(features=16, dtype=dtype)(x)
x = nn_Dense(features=10, dtype=dtype)(x)
assert x.dtype == jnp.bfloat16
return x
......
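The comment above on TE vs. Linen dtype semantics is the crux of the mixed `use_te` path in this example. Below is a hedged side-by-side sketch; it assumes the same `te_flax` import as the example and is not part of the diff.

import jax
import jax.numpy as jnp
import flax.linen as nn
from transformer_engine.jax import flax as te_flax

x = jnp.ones((4, 32), dtype=jnp.bfloat16)
rng = jax.random.PRNGKey(0)

# flax.linen: `dtype` is the computation dtype, so bf16 must be requested explicitly.
linen_out, _ = nn.Dense(features=16, dtype=jnp.bfloat16).init_with_output(rng, x)

# TE Flax (after this change): the computation dtype is inferred from `x`; `dtype`
# only controls how parameters are initialized (float32 by default).
te_out, _ = te_flax.DenseGeneral(features=16).init_with_output(rng, x)

assert linen_out.dtype == te_out.dtype == jnp.bfloat16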
......@@ -271,7 +271,6 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
dtype=dtype,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
......@@ -289,7 +288,6 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
dtype=dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
......
......@@ -265,8 +265,8 @@ class BaseRunner:
"""Test only the forward"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
ref_layer_cls = partial(self.reference_layer, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)
ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
......@@ -288,8 +288,8 @@ class BaseRunner:
"""Test forward and backward through value_and_grad()"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
ref_layer_cls = partial(self.reference_layer, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)
ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
......
......@@ -110,7 +110,7 @@ class DotProductAttention(nn.Module):
Args:
dropout_rate: dropout rate
dtype: the dtype of the computation (default: float32)
dtype: the data type used to allocate the initial parameters (default: float32).
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
"""
......@@ -195,6 +195,7 @@ class DotProductAttention(nn.Module):
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
attn_weights = attn_weights.astype(value.dtype)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
......@@ -209,7 +210,7 @@ class DenseGeneral(nn.Module):
Attributes:
features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on.
dtype: the dtype of the computation (default: float32).
dtype: the data type used to allocate the initial parameters (default: float32).
kernel_init: initializer function for the weight matrix.
use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector.
......@@ -226,7 +227,9 @@ class DenseGeneral(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()
@nn.compact
......@@ -239,6 +242,7 @@ class DenseGeneral(nn.Module):
Returns:
The transformed input.
"""
input_dtype = inputs.dtype
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
......@@ -248,23 +252,24 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
)
kernel = jnp.asarray(kernel, self.dtype)
kernel = jnp.asarray(kernel, input_dtype)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes
"bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
bias = bias.astype(input_dtype)
else:
bias = None
contract_ind = tuple(range(0, len(axis)))
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
y = y.astype(input_dtype)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
......@@ -281,7 +286,7 @@ class MlpBlock(nn.Module):
kernel_init: Kernel function, passed to the dense layers.
deterministic: Whether the dropout layers should be deterministic.
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
dtype: Type for the dense layer.
dtype: the data type used to allocate the initial parameters (default: float32).
"""
transpose_batch_sequence: bool
......@@ -296,7 +301,9 @@ class MlpBlock(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()
@nn.compact
......@@ -358,6 +365,9 @@ class MlpBlock(nn.Module):
bias_axes="embed",
name="wo",
)(x)
assert (
output.dtype == inputs.dtype
), f"input.dtype={input.dtype}, output.dtype={output.dtype}"
return output
......@@ -429,7 +439,7 @@ class MultiHeadAttention(nn.Module):
should be divisible by the number of heads.
num_gqa_groups: number of kv attention heads
head_dim: dimension of each head.
dtype: the dtype of the computation.
dtype: the data type used to allocate the initial parameters (default: float32).
dropout_rate: dropout rate
kernel_init: initializer for the kernel of the Dense layers.
float32_logits: bool, if True then compute logits in float32 to avoid
......@@ -453,7 +463,9 @@ class MultiHeadAttention(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
......@@ -738,6 +750,9 @@ class MultiHeadAttention(nn.Module):
dtype=self.dtype,
name="out",
)(x)
assert (
inputs_q.dtype == inputs_kv.dtype == out.dtype
), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
return out
......@@ -763,13 +778,13 @@ class LayerNorm(nn.Module):
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Applies layer normalization on the input."""
x = jnp.asarray(x, jnp.float32)
input_dtype = x.dtype
features = x.shape[-1]
scale = nn_partitioning.param_with_axes(
"scale", self.scale_init, (features,), jnp.float32, axes=("embed",)
"scale", self.scale_init, (features,), self.dtype, axes=("embed",)
)
scale = jnp.asarray(scale, self.dtype)
scale = jnp.asarray(scale, input_dtype)
if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True)
......@@ -777,9 +792,9 @@ class LayerNorm(nn.Module):
y = (x - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes(
"ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",)
"ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
)
bias = jnp.asarray(bias, self.dtype)
bias = jnp.asarray(bias, input_dtype)
if not self.zero_centered_gamma:
z = y * scale + bias
......@@ -792,7 +807,8 @@ class LayerNorm(nn.Module):
y = x * lax.rsqrt(mean2 + self.epsilon)
z = y * scale
return jnp.asarray(z, self.dtype)
assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
return z
class RelativePositionBiases(nn.Module):
......@@ -805,7 +821,7 @@ class RelativePositionBiases(nn.Module):
distance bucket.
num_heads: Number of heads in the attention layer. Each head will get a
different relative position weighting.
dtype: Type of arrays through this module.
dtype: the data type used to allocate the initial parameters (default: float32).
embedding_init: initializer for relative embedding table.
"""
......@@ -1087,6 +1103,7 @@ class EncoderLayer(nn.Module):
dtype=self.dtype,
name="output_layernorm",
)(y)
assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
return y
......@@ -1293,6 +1310,7 @@ class DecoderLayer(nn.Module):
name="output_layernorm",
)(z)
assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}"
return z
......
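The TE Flax module hunks that follow repeat one pattern: parameters are allocated with `self.dtype` (now an initialization dtype only) and cast to the runtime input dtype before compute, followed by an assert that the output dtype matches the input. A distilled, hypothetical mini-module illustrating that pattern (not the TE implementation):

from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyDense(nn.Module):
    features: int
    dtype: Any = jnp.float32  # parameter-init dtype, no longer the compute dtype

    @nn.compact
    def __call__(self, inputs):
        input_dtype = inputs.dtype
        kernel = self.param(
            "kernel",
            nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal", dtype=self.dtype),
            (inputs.shape[-1], self.features),
            self.dtype,
        )
        bias = self.param("bias", nn.initializers.zeros, (self.features,), self.dtype)
        # Cast parameters to the input dtype so the whole forward pass runs in it.
        y = inputs @ kernel.astype(input_dtype) + bias.astype(input_dtype)
        assert y.dtype == input_dtype
        return y


x = jnp.ones((2, 8), dtype=jnp.bfloat16)
out, variables = TinyDense(features=4).init_with_output(jax.random.PRNGKey(0), x)
assert out.dtype == jnp.bfloat16                           # compute follows the input
assert variables["params"]["kernel"].dtype == jnp.float32  # params keep the init dtype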
......@@ -57,19 +57,15 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
def _create_layernorm_parameters(
layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype
layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, input_dtype, dtype
):
scale = nn_partitioning.param_with_axes(
"scale", scale_init, shape, weight_dtype, axes=scale_axes
)
scale = scale.astype(dtype)
scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
scale = scale.astype(input_dtype)
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == "layernorm":
bias = nn_partitioning.param_with_axes(
"ln_bias", bias_init, shape, weight_dtype, axes=bias_axes
)
bias = bias.astype(dtype)
bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
bias = bias.astype(input_dtype)
else:
assert layernorm_type == "rmsnorm"
bias = None
......@@ -158,15 +154,15 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
heads = inputs.shape[1]
q_seqlen = inputs.shape[2]
k_seqlen = inputs.shape[3]
dtype = inputs.dtype
input_dtype = inputs.dtype
logits = inputs
if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype
self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
):
if bias is not None:
logits = logits + bias.astype(dtype)
logits = logits + bias.astype(input_dtype)
mask_ = mask
if self.softmax_type is not SoftmaxType.SCALED_MASKED:
......@@ -178,25 +174,27 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
if mask is not None:
attention_bias = lax.select(
mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.0).astype(dtype),
jnp.full(mask.shape, -1e10),
jnp.full(mask.shape, 0.0),
)
attention_bias = attention_bias.astype(input_dtype)
if bias is not None:
attention_bias = _combine_biases(attention_bias, bias)
if attention_bias is not None:
logits = logits + attention_bias.astype(dtype)
logits = logits + attention_bias.astype(input_dtype)
# If self.softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED but its kernel is
# unavailable, fall back to the pure scaled softmax custom call.
if is_softmax_kernel_available(
SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype
SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype
):
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
else:
outputs = jax_nn.softmax(logits * self.scale_factor)
assert input_dtype == outputs.dtype
return outputs
......@@ -261,9 +259,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
......@@ -278,7 +274,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
......@@ -303,7 +298,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
outputs : jax.numpy.ndarray
Output tensors.
"""
x = x.astype(self.dtype)
input_dtype = x.dtype
features = x.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
......@@ -313,10 +308,10 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
self.scale_axes,
self.bias_init,
self.bias_axes,
input_dtype,
self.dtype,
self.weight_dtype,
)
return layernorm(
out = layernorm(
x,
scale,
ln_bias,
......@@ -324,6 +319,8 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
)
assert out.dtype == input_dtype
return out
class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-methods
......@@ -408,9 +405,7 @@ class DenseGeneral(TransformerEngineBase):
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
......@@ -428,13 +423,12 @@ class DenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()
......@@ -454,24 +448,25 @@ class DenseGeneral(TransformerEngineBase):
Output tensors.
"""
input_dtype = inputs.dtype
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
kernel = kernel.astype(self.dtype)
if not FP8Helper.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
bias = bias.astype(input_dtype)
else:
bias = None
......@@ -500,11 +495,11 @@ class DenseGeneral(TransformerEngineBase):
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
self.weight_dtype,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(self.dtype)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
......@@ -512,10 +507,10 @@ class DenseGeneral(TransformerEngineBase):
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
self.weight_dtype,
self.dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
lora_b_kernel = lora_b_kernel.astype(input_dtype)
y += _apply_low_rank_adaptation(
inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
......@@ -524,6 +519,8 @@ class DenseGeneral(TransformerEngineBase):
if bias is not None:
bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
y += jnp.reshape(bias, bias_shape)
assert y.dtype == input_dtype
return y
......@@ -606,9 +603,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
......@@ -638,7 +633,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
......@@ -650,7 +644,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
1.0,
"fan_in",
"truncated_normal",
dtype=self.weight_dtype,
dtype=self.dtype,
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init,
......@@ -677,6 +671,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
If :attr:`return_layernorm_output=False`, then this would be None.
"""
input_dtype = inputs.dtype
ln_output = None
fuse_layernorm = (
......@@ -684,7 +679,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
and not self.return_layernorm_output
and self.enable_layernorm
)
inputs = inputs.astype(self.dtype)
if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
......@@ -699,8 +693,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self.scale_axes,
self.ln_bias_init,
self.ln_bias_axes,
input_dtype,
self.dtype,
self.weight_dtype,
)
if not fuse_layernorm:
......@@ -730,9 +724,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
kernel = kernel.astype(self.dtype)
if not FP8Helper.is_fp8_enabled():
kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
......@@ -775,11 +770,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
self.weight_dtype,
self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(self.dtype)
lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
......@@ -787,10 +782,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
self.weight_dtype,
self.dtype,
axes=lora_b_kernel_axes,
)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
lora_b_kernel = lora_b_kernel.astype(input_dtype)
z += _apply_low_rank_adaptation(
y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
......@@ -799,9 +794,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
"bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
bias = bias.astype(input_dtype)
if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
......@@ -810,6 +805,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.depth_scaling is not None:
z = z / self.depth_scaling
assert z.dtype == input_dtype
return z, ln_output # dense_output, layer_norm_output
......@@ -915,9 +911,7 @@ class LayerNormMLP(TransformerEngineBase):
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
......@@ -950,7 +944,6 @@ class LayerNormMLP(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
......@@ -959,7 +952,7 @@ class LayerNormMLP(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init,
......@@ -988,6 +981,7 @@ class LayerNormMLP(TransformerEngineBase):
If :attr:`return_layernorm_output=False`, then this would be None.
"""
input_dtype = inputs.dtype
ln_output = None
fuse_layernorm = (
......@@ -996,8 +990,6 @@ class LayerNormMLP(TransformerEngineBase):
and self.enable_layernorm
)
inputs = inputs.astype(self.dtype)
gated_act_pool = [
("gelu", "linear"),
("silu", "linear"),
......@@ -1035,8 +1027,8 @@ class LayerNormMLP(TransformerEngineBase):
self.scale_axes,
self.ln_bias_init,
self.ln_bias_axes,
input_dtype,
self.dtype,
self.weight_dtype,
)
if not fuse_layernorm:
......@@ -1083,11 +1075,12 @@ class LayerNormMLP(TransformerEngineBase):
num_activations,
-2,
kernel_1_each_shape,
self.weight_dtype,
self.dtype,
axes=self.kernel_axes_1,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
kernel_1 = kernel_1.astype(self.dtype)
if not FP8Helper.is_fp8_enabled():
kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
......@@ -1096,11 +1089,12 @@ class LayerNormMLP(TransformerEngineBase):
"wo_kernel",
self.kernel_init,
kernel_2_param_shape,
self.weight_dtype,
self.dtype,
axes=self.kernel_axes_2,
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
kernel_2 = kernel_2.astype(self.dtype)
if not FP8Helper.is_fp8_enabled():
kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
ffn1_ckpt_name = "ffn1"
......@@ -1115,20 +1109,20 @@ class LayerNormMLP(TransformerEngineBase):
"wi_bias",
self.bias_init,
bias_1_shape,
self.weight_dtype,
self.dtype,
axes=self.bias_axes_1,
)
bias_1 = bias_1.astype(self.dtype)
bias_1 = bias_1.astype(input_dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias",
self.bias_init,
bias_2_shape,
self.weight_dtype,
self.dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.astype(self.dtype)
bias_2 = bias_2.astype(input_dtype)
else:
bias_1 = None
bias_2 = None
......@@ -1195,11 +1189,11 @@ class LayerNormMLP(TransformerEngineBase):
num_activations,
-2,
wi_lora_a_kernel_init_each_shape,
self.weight_dtype,
self.dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)
wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
wi_lora_b_kernel_shape = (
num_activations,
......@@ -1211,10 +1205,10 @@ class LayerNormMLP(TransformerEngineBase):
"wi_lora_b_kernel",
nn.initializers.zeros,
wi_lora_b_kernel_shape,
self.weight_dtype,
self.dtype,
axes=wi_lora_b_kernel_axes,
)
wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)
x += _apply_low_rank_adaptation(
y,
......@@ -1231,11 +1225,11 @@ class LayerNormMLP(TransformerEngineBase):
"wi_bias",
self.bias_init,
intermediate_dim,
self.weight_dtype,
self.dtype,
axes=self.bias_axes_1,
)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
bias_1 = bias_1.astype(self.dtype)
bias_1 = bias_1.astype(input_dtype)
x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
......@@ -1250,7 +1244,7 @@ class LayerNormMLP(TransformerEngineBase):
z = functools.reduce(operator.mul, activations)
# Remove act axis
z = jnp.reshape(z, (*z.shape[:-2], -1))
z = z.astype(self.dtype)
z = z.astype(input_dtype)
z = nn.Dropout(
rate=self.intermediate_dropout_rate,
......@@ -1259,7 +1253,7 @@ class LayerNormMLP(TransformerEngineBase):
)(z, deterministic=deterministic)
z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
z = z.astype(self.dtype)
z = z.astype(input_dtype)
# DenseGeneral 2
out = type_safe_dot_general(
......@@ -1273,10 +1267,10 @@ class LayerNormMLP(TransformerEngineBase):
"wo_lora_a_kernel",
self.kernel_init,
wo_lora_a_kernel_shape,
self.weight_dtype,
self.dtype,
axes=wo_lora_a_kernel_axes,
)
wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)
wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
......@@ -1284,10 +1278,10 @@ class LayerNormMLP(TransformerEngineBase):
"wo_lora_b_kernel",
nn.initializers.zeros,
wo_lora_b_kernel_shape,
self.weight_dtype,
self.dtype,
axes=wo_lora_b_kernel_axes,
)
wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)
out += _apply_low_rank_adaptation(
z,
......@@ -1304,12 +1298,13 @@ class LayerNormMLP(TransformerEngineBase):
"wo_bias",
self.bias_init,
(hidden_size,),
self.weight_dtype,
self.dtype,
axes=self.bias_axes_2,
)
bias_2 = bias_2.astype(self.dtype)
bias_2 = bias_2.astype(input_dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name)
assert out.dtype == input_dtype
return out, ln_output # Output, layer_norm_output
......@@ -115,7 +115,6 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
float32_logits: bool = False
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
......@@ -143,6 +142,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
input_dtype = query.dtype
if self.scale_factor is None:
scale_factor = 1.0 / sqrt(query.shape[-1])
else:
......@@ -150,8 +151,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
del self.scale_factor
if self.float32_logits:
query = query.astype(self.dtype)
key = key.astype(self.dtype)
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths.
......@@ -234,7 +235,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
attn_weights, mask, bias
).astype(self.dtype)
).astype(input_dtype)
if is_gqa:
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
......@@ -244,9 +245,12 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
dropout_shape = list(attn_weights.shape)
# TODO(rewang): add attention dropout broadcast dimension arguments for users
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
attn_weights = attn_weights * multiplier
assert (
attn_weights.dtype == input_dtype
), f"output={attn_weights.dtype}, input={input_dtype}"
if self.transpose_batch_sequence:
if is_gqa:
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
......@@ -254,6 +258,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if is_gqa:
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
......@@ -262,7 +267,6 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False
......@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
if self.transpose_batch_sequence:
x = x.transpose([1, 0, 2, 3])
assert x.dtype == query.dtype
return x
......@@ -492,9 +497,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
"""
head_dim: int
......@@ -504,7 +507,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
dropout_rng_name: str = "dropout"
float32_logits: bool = False
qkv_layout: str = "bshd_bshd_bshd"
......@@ -552,6 +554,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray
Output tensors.
"""
input_dtype = query.dtype
if mask is not None:
if sequence_descriptor is not None:
......@@ -642,7 +645,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
float32_logits=self.float32_logits,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -662,7 +664,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
......@@ -679,7 +680,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
dropout_rng=dropout_rng,
deterministic=deterministic,
)
assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
return x
......@@ -720,10 +721,10 @@ def rotary_pos_emb(
sin, cos = generate_sin_cos(time_scales)
x1, x2 = jnp.split(x, 2, axis=-1)
part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(x.dtype)
part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)
output = jnp.concatenate([part_1, part_2], axis=-1)
output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)
return output
def consecutive_impl():
......@@ -928,8 +929,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
......@@ -975,7 +974,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
......@@ -1026,7 +1024,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.weight_dtype
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
......@@ -1071,6 +1069,11 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Output tensors.
"""
assert (
inputs_q.dtype == inputs_kv.dtype
), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
input_dtype = inputs_q.dtype
def query_init(*args):
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)
......@@ -1154,7 +1157,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dot_input_axes=inputs_logical_axes_no_sp,
name="qkv",
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
qkv_layout = QKVLayout.BS3HD
......@@ -1178,7 +1180,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
......@@ -1203,7 +1204,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
name="kv",
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
qkv_layout = QKVLayout.BSHD_BS2HD
......@@ -1221,7 +1221,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=self.input_layernorm,
......@@ -1242,7 +1241,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
......@@ -1253,9 +1251,11 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
assert ln_out is not None
inputs_kv = ln_out
query = query.astype(input_dtype)
key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
key = key.astype(self.dtype)
key = key.astype(input_dtype)
value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
value = value.astype(input_dtype)
query = checkpoint_name(query, "query_proj")
key = checkpoint_name(key, "key_proj")
value = checkpoint_name(value, "value_proj")
......@@ -1380,7 +1380,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout.name,
......@@ -1406,11 +1405,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="out",
)(x)
out = checkpoint_name(out, "out_proj")
assert (
inputs_q.dtype == out.dtype
), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
return out, ln_out
......@@ -1435,9 +1436,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
"""
num_buckets: int
......@@ -1446,7 +1445,6 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
@nn.compact
def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
......@@ -1499,7 +1497,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
"rel_embedding",
self.embedding_init,
(self.num_attention_heads, self.num_buckets),
self.weight_dtype,
self.dtype,
axes=self.embedding_axes,
)
......@@ -1672,9 +1670,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
drop_path: float, default = 0.0
When > 0.0, applies stochastic depth per sample in the main
path of the residual block.
......@@ -1727,7 +1723,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
......@@ -1739,11 +1734,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self):
if self.mha_kernel_init is None:
self.mha_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.weight_dtype
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
......@@ -1793,9 +1788,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray
Output tensors.
"""
inputs = inputs.astype(self.dtype)
input_dtype = inputs.dtype
assert (
self.layer_type in TransformerLayerType
), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
......@@ -1833,8 +1826,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
max_distance=128,
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
embedding_init=nn.initializers.variance_scaling(
1.0, "fan_avg", "uniform", dtype=self.dtype
),
name="relpos_bias",
)
else:
......@@ -1867,7 +1861,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
x, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -1946,7 +1939,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
y, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -2012,7 +2004,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_init=self.mlp_kernel_init,
......@@ -2062,8 +2053,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="output_layernorm",
)(z)
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
return z