Unverified Commit 6673f165 authored by Phuong Nguyen, committed by GitHub

[JAX] Flax with compute dtype inferred from input dtype. (#1485)



Flax modules now infer their compute dtype from the inputs.
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent eb9857d6
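As context for the diffs below, here is a minimal usage sketch (not part of this commit, assuming `transformer_engine.jax.flax` is imported as `te_flax`): after this change the compute dtype of TE Flax modules follows the dtype of the input, while the module `dtype` attribute only controls how parameters are initialized, so the explicit `dtype=jnp.bfloat16` arguments removed below are no longer needed.

```python
# Hedged sketch of the new behavior; not code from this commit.
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax

layer = te_flax.DenseGeneral(features=256)       # dtype defaults to float32 (params only)
x = jnp.ones((8, 128), dtype=jnp.bfloat16)       # compute dtype is taken from the input

variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)
assert y.dtype == jnp.bfloat16                   # output dtype matches the input dtype
```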
@@ -56,7 +56,6 @@ class Net(nn.Module):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
- dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
@@ -72,17 +71,15 @@ class Net(nn.Module):
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
- dtype=jnp.bfloat16,
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
- dtype=jnp.bfloat16,
)(x)
- x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
+ x = nn.Dense(features=2)(x)
return x
@@ -91,7 +88,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
@@ -136,7 +133,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
...
@@ -51,17 +51,16 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
- dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
- x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
- x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
- x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
+ x = te_flax.DenseGeneral(features=256)(x)
+ x = te_flax.DenseGeneral(features=256)(x)
+ x = nn.Dense(features=2)(x)
return x
@@ -70,7 +69,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
@@ -115,7 +114,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
...
@@ -57,7 +57,6 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
- dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
@@ -67,17 +66,15 @@ class Net(nn.Module):
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
- dtype=jnp.bfloat16,
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
- dtype=jnp.bfloat16,
)(x)
- x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
+ x = nn.Dense(features=2)(x)
return x
...
@@ -46,17 +46,16 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
- dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
- x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
- x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
- x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
+ x = te_flax.DenseGeneral(features=256)(x)
+ x = te_flax.DenseGeneral(features=256)(x)
+ x = nn.Dense(features=2)(x)
return x
@@ -66,7 +65,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
@@ -112,7 +111,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
- one_hot = jax.nn.one_hot(labels, 2)
+ one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
@@ -217,6 +216,7 @@ def train_and_evaluate(args):
with te.fp8_autocast(enabled=args.use_fp8):
encoder = Net(num_embed)
+ # We use nn.Embed, thus inputs need to be in int
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks)
...
@@ -36,6 +36,8 @@ class Net(nn.Module):
nn_Dense = te_flax.DenseGeneral
else:
nn_Dense = nn.Dense
+ # dtype is used for param init in TE but computation in Linen.nn
+ dtype = jnp.float32 if self.use_te else jnp.bfloat16
x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
@@ -44,11 +46,13 @@ class Net(nn.Module):
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
- x = nn_Dense(features=128, dtype=jnp.bfloat16)(x)
+ assert x.dtype == jnp.bfloat16
+ x = nn_Dense(features=128, dtype=dtype)(x)
x = nn.relu(x)
x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
- x = nn_Dense(features=16, dtype=jnp.bfloat16)(x)
- x = nn.Dense(features=10, dtype=jnp.bfloat16)(x)
+ x = nn_Dense(features=16, dtype=dtype)(x)
+ x = nn_Dense(features=10, dtype=dtype)(x)
+ assert x.dtype == jnp.bfloat16
return x
...
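The comment added in the hunk above distinguishes the two `dtype` semantics. A hedged side-by-side sketch of that difference (illustrative only, not code from this commit):

```python
# flax.linen: `dtype` selects the computation dtype.
# TE Flax (after this change): `dtype` selects the parameter dtype; compute follows the input.
import flax.linen as nn
import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax

x = jnp.ones((4, 32), dtype=jnp.bfloat16)
rng = jax.random.PRNGKey(0)

linen_out, _ = nn.Dense(features=16, dtype=jnp.bfloat16).init_with_output(rng, x)
te_out, _ = te_flax.DenseGeneral(features=16, dtype=jnp.float32).init_with_output(rng, x)

assert linen_out.dtype == te_out.dtype == jnp.bfloat16
```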
@@ -271,7 +271,6 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
- dtype=dtype,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
@@ -289,7 +288,6 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
- dtype=dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
...
@@ -265,8 +265,8 @@ class BaseRunner:
"""Test only the forward"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
- ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
- layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
+ ref_layer_cls = partial(self.reference_layer, **self.attrs)
+ layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)
ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
@@ -288,8 +288,8 @@ class BaseRunner:
"""Test forward and backward through value_and_grad()"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
- ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
- layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
+ ref_layer_cls = partial(self.reference_layer, **self.attrs)
+ layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)
ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
...
@@ -110,7 +110,7 @@ class DotProductAttention(nn.Module):
Args:
dropout_rate: dropout rate
- dtype: the dtype of the computation (default: float32)
+ dtype: the data type used to allocate the initial parameters (default: float32).
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
"""
@@ -195,6 +195,7 @@ class DotProductAttention(nn.Module):
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
+ attn_weights = attn_weights.astype(value.dtype)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
@@ -209,7 +210,7 @@ class DenseGeneral(nn.Module):
Attributes:
features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on.
- dtype: the dtype of the computation (default: float32).
+ dtype: the data type used to allocate the initial parameters (default: float32).
kernel_init: initializer function for the weight matrix.
use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector.
@@ -226,7 +227,9 @@ class DenseGeneral(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
- self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
+ self.kernel_init = nn.initializers.variance_scaling(
+ 1.0, "fan_in", "truncated_normal", dtype=self.dtype
+ )
super().__post_init__()
@nn.compact
@@ -239,6 +242,7 @@ class DenseGeneral(nn.Module):
Returns:
The transformed input.
"""
+ input_dtype = inputs.dtype
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
@@ -248,23 +252,24 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes(
- "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
+ "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
)
- kernel = jnp.asarray(kernel, self.dtype)
+ kernel = jnp.asarray(kernel, input_dtype)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
- "bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes
+ "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
)
- bias = bias.astype(self.dtype)
+ bias = bias.astype(input_dtype)
else:
bias = None
contract_ind = tuple(range(0, len(axis)))
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
+ y = y.astype(input_dtype)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
@@ -281,7 +286,7 @@ class MlpBlock(nn.Module):
kernel_init: Kernel function, passed to the dense layers.
deterministic: Whether the dropout layers should be deterministic.
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
- dtype: Type for the dense layer.
+ dtype: the data type used to allocate the initial parameters (default: float32).
"""
transpose_batch_sequence: bool
@@ -296,7 +301,9 @@ class MlpBlock(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
- self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
+ self.kernel_init = nn.initializers.variance_scaling(
+ 1.0, "fan_in", "truncated_normal", dtype=self.dtype
+ )
super().__post_init__()
@nn.compact
@@ -358,6 +365,9 @@ class MlpBlock(nn.Module):
bias_axes="embed",
name="wo",
)(x)
+ assert (
+ output.dtype == inputs.dtype
+ ), f"input.dtype={input.dtype}, output.dtype={output.dtype}"
return output
@@ -429,7 +439,7 @@ class MultiHeadAttention(nn.Module):
should be divisible by the number of heads.
num_gqa_groups: number of kv attention heads
head_dim: dimension of each head.
- dtype: the dtype of the computation.
+ dtype: the data type used to allocate the initial parameters (default: float32).
dropout_rate: dropout rate
kernel_init: initializer for the kernel of the Dense layers.
float32_logits: bool, if True then compute logits in float32 to avoid
@@ -453,7 +463,9 @@ class MultiHeadAttention(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
- self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
+ self.kernel_init = nn.initializers.variance_scaling(
+ 1.0, "fan_in", "normal", dtype=self.dtype
+ )
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@@ -738,6 +750,9 @@ class MultiHeadAttention(nn.Module):
dtype=self.dtype,
name="out",
)(x)
+ assert (
+ inputs_q.dtype == inputs_kv.dtype == out.dtype
+ ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
return out
@@ -763,13 +778,13 @@ class LayerNorm(nn.Module):
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Applies layer normalization on the input."""
- x = jnp.asarray(x, jnp.float32)
+ input_dtype = x.dtype
features = x.shape[-1]
scale = nn_partitioning.param_with_axes(
- "scale", self.scale_init, (features,), jnp.float32, axes=("embed",)
+ "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
)
- scale = jnp.asarray(scale, self.dtype)
+ scale = jnp.asarray(scale, input_dtype)
if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True)
@@ -777,9 +792,9 @@ class LayerNorm(nn.Module):
y = (x - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes(
- "ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",)
+ "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
)
- bias = jnp.asarray(bias, self.dtype)
+ bias = jnp.asarray(bias, input_dtype)
if not self.zero_centered_gamma:
z = y * scale + bias
@@ -792,7 +807,8 @@ class LayerNorm(nn.Module):
y = x * lax.rsqrt(mean2 + self.epsilon)
z = y * scale
- return jnp.asarray(z, self.dtype)
+ assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
+ return z
class RelativePositionBiases(nn.Module):
@@ -805,7 +821,7 @@ class RelativePositionBiases(nn.Module):
distance bucket.
num_heads: Number of heads in the attention layer. Each head will get a
different relative position weighting.
- dtype: Type of arrays through this module.
+ dtype: the data type used to allocate the initial parameters (default: float32).
embedding_init: initializer for relative embedding table.
"""
@@ -1087,6 +1103,7 @@ class EncoderLayer(nn.Module):
dtype=self.dtype,
name="output_layernorm",
)(y)
+ assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
return y
@@ -1293,6 +1310,7 @@ class DecoderLayer(nn.Module):
name="output_layernorm",
)(z)
+ assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}"
return z
...
@@ -57,19 +57,15 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
def _create_layernorm_parameters(
- layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype
+ layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, input_dtype, dtype
):
- scale = nn_partitioning.param_with_axes(
- "scale", scale_init, shape, weight_dtype, axes=scale_axes
- )
- scale = scale.astype(dtype)
+ scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
+ scale = scale.astype(input_dtype)
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == "layernorm":
- bias = nn_partitioning.param_with_axes(
- "ln_bias", bias_init, shape, weight_dtype, axes=bias_axes
- )
- bias = bias.astype(dtype)
+ bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
+ bias = bias.astype(input_dtype)
else:
assert layernorm_type == "rmsnorm"
bias = None
@@ -158,15 +154,15 @@ class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
heads = inputs.shape[1]
q_seqlen = inputs.shape[2]
k_seqlen = inputs.shape[3]
- dtype = inputs.dtype
+ input_dtype = inputs.dtype
logits = inputs
if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available(
- self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype
+ self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype
):
if bias is not None:
- logits = logits + bias.astype(dtype)
+ logits = logits + bias.astype(input_dtype)
mask_ = mask
if self.softmax_type is not SoftmaxType.SCALED_MASKED:
@@ -178,25 +174,27 @@ class Softmax(nn.Module):  # pylint: disable=too-few-public-methods
if mask is not None:
attention_bias = lax.select(
mask > 0,
- jnp.full(mask.shape, -1e10).astype(dtype),
- jnp.full(mask.shape, 0.0).astype(dtype),
+ jnp.full(mask.shape, -1e10),
+ jnp.full(mask.shape, 0.0),
)
+ attention_bias = attention_bias.astype(input_dtype)
if bias is not None:
attention_bias = _combine_biases(attention_bias, bias)
if attention_bias is not None:
- logits = logits + attention_bias.astype(dtype)
+ logits = logits + attention_bias.astype(input_dtype)
# For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED
# and kernel is unavailable, then try on pure scaled softmax custom calls.
if is_softmax_kernel_available(
- SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype
+ SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype
):
outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED)
else:
outputs = jax_nn.softmax(logits * self.scale_factor)
+ assert input_dtype == outputs.dtype
return outputs
@@ -261,9 +259,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
- The data type used for computation.
- weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
- The data type of the module parameters.
+ The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = False
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -278,7 +274,6 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ("embed",)
dtype: DType = jnp.float32
- weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
@@ -303,7 +298,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
outputs : jax.numpy.ndarray
Output tensors.
"""
- x = x.astype(self.dtype)
+ input_dtype = x.dtype
features = x.shape[-1]
scale, ln_bias = _create_layernorm_parameters(
@@ -313,10 +308,10 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
self.scale_axes,
self.bias_init,
self.bias_axes,
+ input_dtype,
self.dtype,
- self.weight_dtype,
)
- return layernorm(
+ out = layernorm(
x,
scale,
ln_bias,
@@ -324,6 +319,8 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
zero_centered_gamma=self.zero_centered_gamma,
epsilon=self.epsilon,
)
+ assert out.dtype == input_dtype
+ return out
class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-methods
@@ -408,9 +405,7 @@ class DenseGeneral(TransformerEngineBase):
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
- The data type used for computation.
- weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
- The data type of the module parameters.
+ The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -428,13 +423,12 @@ class DenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
- weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
- 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
+ 1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()
@@ -454,24 +448,25 @@ class DenseGeneral(TransformerEngineBase):
Output tensors.
"""
+ input_dtype = inputs.dtype
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
- inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim)
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
- "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
+ "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
- kernel = kernel.astype(self.dtype)
+ if not FP8Helper.is_fp8_enabled():
+ kernel = kernel.astype(input_dtype)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
- "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
+ "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
- bias = bias.astype(self.dtype)
+ bias = bias.astype(input_dtype)
else:
bias = None
@@ -500,11 +495,11 @@ class DenseGeneral(TransformerEngineBase):
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
- self.weight_dtype,
+ self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
- lora_a_kernel = lora_a_kernel.astype(self.dtype)
+ lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
@@ -512,10 +507,10 @@ class DenseGeneral(TransformerEngineBase):
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
- self.weight_dtype,
+ self.dtype,
axes=lora_b_kernel_axes,
)
- lora_b_kernel = lora_b_kernel.astype(self.dtype)
+ lora_b_kernel = lora_b_kernel.astype(input_dtype)
y += _apply_low_rank_adaptation(
inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
@@ -524,6 +519,8 @@ class DenseGeneral(TransformerEngineBase):
if bias is not None:
bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
y += jnp.reshape(bias, bias_shape)
+ assert y.dtype == input_dtype
return y
@@ -606,9 +603,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
- The data type used for computation.
- weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
- The data type of the module parameters.
+ The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -638,7 +633,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
- weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_input_axes: Tuple[str, ...] = None
@@ -650,7 +644,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
1.0,
"fan_in",
"truncated_normal",
- dtype=self.weight_dtype,
+ dtype=self.dtype,
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init,
@@ -677,6 +671,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
If :attr:`return_layernorm_output=False`, then this would be None.
"""
+ input_dtype = inputs.dtype
ln_output = None
fuse_layernorm = (
@@ -684,7 +679,6 @@ class LayerNormDenseGeneral(TransformerEngineBase):
and not self.return_layernorm_output
and self.enable_layernorm
)
- inputs = inputs.astype(self.dtype)
if self.enable_layernorm:
inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
@@ -699,8 +693,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self.scale_axes,
self.ln_bias_init,
self.ln_bias_axes,
+ input_dtype,
self.dtype,
- self.weight_dtype,
)
if not fuse_layernorm:
@@ -730,9 +724,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
kernel_shape = tuple(y.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
kernel = nn_partitioning.param_with_axes(
- "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
+ "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
)
- kernel = kernel.astype(self.dtype)
+ if not FP8Helper.is_fp8_enabled():
+ kernel = kernel.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
@@ -775,11 +770,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
"lora_a_kernel",
self.kernel_init,
lora_a_kernel_init_shape,
- self.weight_dtype,
+ self.dtype,
axes=lora_a_kernel_axes,
)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
- lora_a_kernel = lora_a_kernel.astype(self.dtype)
+ lora_a_kernel = lora_a_kernel.astype(input_dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
@@ -787,10 +782,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
"lora_b_kernel",
nn.initializers.zeros,
lora_b_kernel_shape,
- self.weight_dtype,
+ self.dtype,
axes=lora_b_kernel_axes,
)
- lora_b_kernel = lora_b_kernel.astype(self.dtype)
+ lora_b_kernel = lora_b_kernel.astype(input_dtype)
z += _apply_low_rank_adaptation(
y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
@@ -799,9 +794,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes(
- "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
+ "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
)
- bias = bias.astype(self.dtype)
+ bias = bias.astype(input_dtype)
if bias is not None:
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
@@ -810,6 +805,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if self.depth_scaling is not None:
z = z / self.depth_scaling
+ assert z.dtype == input_dtype
return z, ln_output  # dense_output, layer_norm_output
@@ -915,9 +911,7 @@ class LayerNormMLP(TransformerEngineBase):
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
- The data type used for computation.
- weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
- The data type of the module parameters.
+ The data type used to allocate the initial parameters.
transpose_batch_sequence : bool, default = True
Indicate whether the input tensors were switched axis of batch
and sequence length dimension. If set to True, the input tensors
@@ -950,7 +944,6 @@ class LayerNormMLP(TransformerEngineBase):
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
- weight_dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
layernorm_input_axes: Tuple[str, ...] = None
dot_1_input_axes: Tuple[str, ...] = None
@@ -959,7 +952,7 @@ class LayerNormMLP(TransformerEngineBase):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
- 1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
+ 1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
self.scale_init = _obtain_default_layernorm_scale_init_if_need(
self.scale_init,
@@ -988,6 +981,7 @@ class LayerNormMLP(TransformerEngineBase):
If :attr:`return_layernorm_output=False`, then this would be None.
"""
+ input_dtype = inputs.dtype
ln_output = None
fuse_layernorm = (
@@ -996,8 +990,6 @@ class LayerNormMLP(TransformerEngineBase):
and self.enable_layernorm
)
- inputs = inputs.astype(self.dtype)
gated_act_pool = [
("gelu", "linear"),
("silu", "linear"),
@@ -1035,8 +1027,8 @@ class LayerNormMLP(TransformerEngineBase):
self.scale_axes,
self.ln_bias_init,
self.ln_bias_axes,
+ input_dtype,
self.dtype,
- self.weight_dtype,
)
if not fuse_layernorm:
@@ -1083,11 +1075,12 @@ class LayerNormMLP(TransformerEngineBase):
num_activations,
-2,
kernel_1_each_shape,
- self.weight_dtype,
+ self.dtype,
axes=self.kernel_axes_1,
)
kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
- kernel_1 = kernel_1.astype(self.dtype)
+ if not FP8Helper.is_fp8_enabled():
+ kernel_1 = kernel_1.astype(input_dtype)
hidden_size = inputs.shape[-1]
hidden_size_tuple = _canonicalize_tuple(hidden_size)
kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
@@ -1096,11 +1089,12 @@ class LayerNormMLP(TransformerEngineBase):
"wo_kernel",
self.kernel_init,
kernel_2_param_shape,
- self.weight_dtype,
+ self.dtype,
axes=self.kernel_axes_2,
)
kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
- kernel_2 = kernel_2.astype(self.dtype)
+ if not FP8Helper.is_fp8_enabled():
+ kernel_2 = kernel_2.astype(input_dtype)
contract_ind = tuple(range(0, len(axis)))
ffn1_ckpt_name = "ffn1"
@@ -1115,20 +1109,20 @@ class LayerNormMLP(TransformerEngineBase):
"wi_bias",
self.bias_init,
bias_1_shape,
- self.weight_dtype,
+ self.dtype,
axes=self.bias_axes_1,
)
- bias_1 = bias_1.astype(self.dtype)
+ bias_1 = bias_1.astype(input_dtype)
bias_2_shape = (hidden_size,)
bias_2 = nn_partitioning.param_with_axes(
"wo_bias",
self.bias_init,
bias_2_shape,
- self.weight_dtype,
+ self.dtype,
axes=self.bias_axes_2,
)
- bias_2 = bias_2.astype(self.dtype)
+ bias_2 = bias_2.astype(input_dtype)
else:
bias_1 = None
bias_2 = None
@@ -1195,11 +1189,11 @@ class LayerNormMLP(TransformerEngineBase):
num_activations,
-2,
wi_lora_a_kernel_init_each_shape,
- self.weight_dtype,
+ self.dtype,
axes=wi_lora_a_kernel_axes,
)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
- wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)
+ wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
wi_lora_b_kernel_shape = (
num_activations,
@@ -1211,10 +1205,10 @@ class LayerNormMLP(TransformerEngineBase):
"wi_lora_b_kernel",
nn.initializers.zeros,
wi_lora_b_kernel_shape,
- self.weight_dtype,
+ self.dtype,
axes=wi_lora_b_kernel_axes,
)
- wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
+ wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)
x += _apply_low_rank_adaptation(
y,
@@ -1231,11 +1225,11 @@ class LayerNormMLP(TransformerEngineBase):
"wi_bias",
self.bias_init,
intermediate_dim,
- self.weight_dtype,
+ self.dtype,
axes=self.bias_axes_1,
)
bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
- bias_1 = bias_1.astype(self.dtype)
+ bias_1 = bias_1.astype(input_dtype)
x += jnp.reshape(bias_1, bias_1_shape)
x = checkpoint_name(x, ffn1_ckpt_name)
@@ -1250,7 +1244,7 @@ class LayerNormMLP(TransformerEngineBase):
z = functools.reduce(operator.mul, activations)
# Remove act axis
z = jnp.reshape(z, (*z.shape[:-2], -1))
- z = z.astype(self.dtype)
+ z = z.astype(input_dtype)
z = nn.Dropout(
rate=self.intermediate_dropout_rate,
@@ -1259,7 +1253,7 @@ class LayerNormMLP(TransformerEngineBase):
)(z, deterministic=deterministic)
z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes)
- z = z.astype(self.dtype)
+ z = z.astype(input_dtype)
# DenseGeneral 2
out = type_safe_dot_general(
@@ -1273,10 +1267,10 @@ class LayerNormMLP(TransformerEngineBase):
"wo_lora_a_kernel",
self.kernel_init,
wo_lora_a_kernel_shape,
- self.weight_dtype,
+ self.dtype,
axes=wo_lora_a_kernel_axes,
)
- wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
+ wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)
wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
@@ -1284,10 +1278,10 @@ class LayerNormMLP(TransformerEngineBase):
"wo_lora_b_kernel",
nn.initializers.zeros,
wo_lora_b_kernel_shape,
- self.weight_dtype,
+ self.dtype,
axes=wo_lora_b_kernel_axes,
)
- wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
+ wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)
out += _apply_low_rank_adaptation(
z,
@@ -1304,12 +1298,13 @@ class LayerNormMLP(TransformerEngineBase):
"wo_bias",
self.bias_init,
(hidden_size,),
- self.weight_dtype,
+ self.dtype,
axes=self.bias_axes_2,
)
- bias_2 = bias_2.astype(self.dtype)
+ bias_2 = bias_2.astype(input_dtype)
out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
out = checkpoint_name(out, ffn2_ckpt_name)
+ assert out.dtype == input_dtype
return out, ln_output  # Output, layner_norm_output
...@@ -115,7 +115,6 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -115,7 +115,6 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32 dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
float32_logits: bool = False float32_logits: bool = False
scale_factor: Optional[float] = None scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
...@@ -143,6 +142,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -143,6 +142,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match." assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match." assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
input_dtype = query.dtype
if self.scale_factor is None: if self.scale_factor is None:
scale_factor = 1.0 / sqrt(query.shape[-1]) scale_factor = 1.0 / sqrt(query.shape[-1])
else: else:
...@@ -150,8 +151,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -150,8 +151,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
del self.scale_factor del self.scale_factor
if self.float32_logits: if self.float32_logits:
query = query.astype(self.dtype) query = query.astype(jnp.float32)
key = key.astype(self.dtype) key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2] h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths. # Therefore, we have to maintain two code paths.
...@@ -234,7 +235,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -234,7 +235,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)( attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
attn_weights, mask, bias attn_weights, mask, bias
).astype(self.dtype) ).astype(input_dtype)
if is_gqa: if is_gqa:
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
...@@ -244,9 +245,12 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -244,9 +245,12 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
dropout_shape = list(attn_weights.shape) dropout_shape = list(attn_weights.shape)
# TODO(rewang): add attention dropout broadcast dimension arguments for users # TODO(rewang): add attention dropout broadcast dimension arguments for users
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype) multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
attn_weights = attn_weights * multiplier attn_weights = attn_weights * multiplier
assert (
attn_weights.dtype == input_dtype
), f"output={attn_weights.dtype}, input={input_dtype}"
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
if is_gqa: if is_gqa:
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape) return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
...@@ -254,6 +258,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public- ...@@ -254,6 +258,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if is_gqa: if is_gqa:
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape) return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value) return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
...@@ -262,7 +267,6 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -262,7 +267,6 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32 dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
scale_factor: Optional[float] = None scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
...@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me ...@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
if self.transpose_batch_sequence: if self.transpose_batch_sequence:
x = x.transpose([1, 0, 2, 3]) x = x.transpose([1, 0, 2, 3])
assert x.dtype == query.dtype
return x return x
...@@ -492,9 +497,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -492,9 +497,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation. The data type used to allocate the initial parameters.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
""" """
head_dim: int head_dim: int
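The new contract, parameters allocated in dtype while compute and output follow the input's dtype, can be summarized with a toy Flax module. The module below is purely illustrative and not part of Transformer Engine:

import jax
import jax.numpy as jnp
from flax import linen as nn

class DtypeFromInput(nn.Module):
    # Hypothetical module: dtype only controls parameter allocation; the
    # compute path and the returned activation follow the input dtype.
    features: int = 16
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        input_dtype = x.dtype
        kernel = self.param("kernel", nn.initializers.lecun_normal(),
                            (x.shape[-1], self.features), self.dtype)
        y = x @ kernel.astype(input_dtype)
        assert y.dtype == input_dtype
        return y

x = jnp.ones((2, 8), dtype=jnp.bfloat16)
variables = DtypeFromInput().init(jax.random.PRNGKey(0), x)
y = DtypeFromInput().apply(variables, x)   # bfloat16 output, float32 parameters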
...@@ -504,7 +507,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -504,7 +507,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: AttnMaskType = "causal" attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None attn_bias_type: AttnBiasType = None
dtype: DType = jnp.float32 dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
dropout_rng_name: str = "dropout" dropout_rng_name: str = "dropout"
float32_logits: bool = False float32_logits: bool = False
qkv_layout: str = "bshd_bshd_bshd" qkv_layout: str = "bshd_bshd_bshd"
...@@ -552,6 +554,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -552,6 +554,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray outputs: jax.numpy.ndarray
Output tensors. Output tensors.
""" """
input_dtype = query.dtype
if mask is not None: if mask is not None:
if sequence_descriptor is not None: if sequence_descriptor is not None:
...@@ -642,7 +645,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -642,7 +645,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
float32_logits=self.float32_logits, float32_logits=self.float32_logits,
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -662,7 +664,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -662,7 +664,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
scale_factor=scale_factor, scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
...@@ -679,7 +680,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -679,7 +680,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
dropout_rng=dropout_rng, dropout_rng=dropout_rng,
deterministic=deterministic, deterministic=deterministic,
) )
assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
return x return x
...@@ -720,10 +721,10 @@ def rotary_pos_emb( ...@@ -720,10 +721,10 @@ def rotary_pos_emb(
sin, cos = generate_sin_cos(time_scales) sin, cos = generate_sin_cos(time_scales)
x1, x2 = jnp.split(x, 2, axis=-1) x1, x2 = jnp.split(x, 2, axis=-1)
part_1 = (x1 * cos - x2 * sin).astype(x.dtype) part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(x.dtype) part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)
output = jnp.concatenate([part_1, part_2], axis=-1) output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)
return output return output
def consecutive_impl(): def consecutive_impl():
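A tiny self-contained check of the split-half rotary application above; it confirms the rotation keeps the input dtype (shapes and angles are arbitrary):

import jax.numpy as jnp

def rotate_half(x, sin, cos):
    # same formula as above: rotate each pair (x1[i], x2[i]) by angle i
    x1, x2 = jnp.split(x, 2, axis=-1)
    part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
    part_2 = (x2 * cos + x1 * sin).astype(x.dtype)
    return jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)

x = jnp.arange(8, dtype=jnp.bfloat16).reshape(1, 8)
angles = jnp.linspace(0.0, 1.0, 4)                 # float32 sin/cos tables
out = rotate_half(x, jnp.sin(angles), jnp.cos(angles))
assert out.dtype == x.dtype                        # bfloat16 in, bfloat16 out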
...@@ -928,8 +929,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -928,8 +929,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation. The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
fuse_qkv_params: bool, default = True fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for parameter for query-key-value for self-attention and key-value for
...@@ -975,7 +974,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -975,7 +974,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim: int = 32 low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32 dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False enable_sequence_parallel: bool = False
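For readers unfamiliar with fuse_qkv_params, the rough idea is sketched below with plain arrays (purely hypothetical shapes; the real module uses DenseGeneral with sharding constraints): a single fused parameter produces query, key and value in one projection, which is then split along a packing axis.

import jax.numpy as jnp

hidden, heads, head_dim = 64, 4, 16
x = jnp.ones((2, 8, hidden), dtype=jnp.bfloat16)
# one fused parameter of shape (hidden, 3, heads * head_dim) instead of three
w_qkv = jnp.zeros((hidden, 3, heads * head_dim), dtype=jnp.float32)
qkv = jnp.einsum("bsh,hkf->bskf", x, w_qkv.astype(x.dtype))    # (2, 8, 3, 64)
q, k, v = [t.squeeze(-2) for t in jnp.split(qkv, 3, axis=-2)]  # each (2, 8, 64)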
...@@ -1026,7 +1024,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1026,7 +1024,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
if self.kernel_init is None: if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling( self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.weight_dtype 1.0, "fan_in", "normal", dtype=self.dtype
) )
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads self.num_gqa_groups = self.num_attention_heads
...@@ -1071,6 +1069,11 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1071,6 +1069,11 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Output tensors. Output tensors.
""" """
assert (
inputs_q.dtype == inputs_kv.dtype
), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
input_dtype = inputs_q.dtype
def query_init(*args): def query_init(*args):
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)
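The scaled query initializer above can be reproduced standalone (head_dim is hypothetical): the base variance-scaling initializer is divided by sqrt(head_dim) when scaled_query_init is enabled.

import jax
import jax.numpy as jnp
from flax import linen as nn

head_dim = 64
base_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")

def query_init(key, shape, dtype=jnp.float32):
    # divide the query kernel by sqrt(head_dim), matching the code above
    return base_init(key, shape, dtype) / jnp.sqrt(head_dim).astype(dtype)

kernel = query_init(jax.random.PRNGKey(0), (128, head_dim))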
...@@ -1154,7 +1157,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1154,7 +1157,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
name="qkv", name="qkv",
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
)(inputs_q) )(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj") qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
qkv_layout = QKVLayout.BS3HD qkv_layout = QKVLayout.BS3HD
...@@ -1178,7 +1180,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1178,7 +1180,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
...@@ -1203,7 +1204,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1203,7 +1204,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
name="kv", name="kv",
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
)(inputs_kv) )(inputs_kv)
kv_proj = checkpoint_name(kv_proj, "combined_kv_proj") kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
qkv_layout = QKVLayout.BSHD_BS2HD qkv_layout = QKVLayout.BSHD_BS2HD
...@@ -1221,7 +1221,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1221,7 +1221,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
) )
query, ln_out = LayerNormDenseGeneral( query, ln_out = LayerNormDenseGeneral(
enable_layernorm=self.input_layernorm, enable_layernorm=self.input_layernorm,
...@@ -1242,7 +1241,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1242,7 +1241,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
kernel_init=query_init, kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp, layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp, dot_input_axes=inputs_logical_axes_no_sp,
...@@ -1253,9 +1251,11 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1253,9 +1251,11 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
assert ln_out is not None assert ln_out is not None
inputs_kv = ln_out inputs_kv = ln_out
query = query.astype(input_dtype)
key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv) key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
key = key.astype(self.dtype) key = key.astype(input_dtype)
value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv) value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
value = value.astype(input_dtype)
query = checkpoint_name(query, "query_proj") query = checkpoint_name(query, "query_proj")
key = checkpoint_name(key, "key_proj") key = checkpoint_name(key, "key_proj")
value = checkpoint_name(value, "value_proj") value = checkpoint_name(value, "value_proj")
...@@ -1380,7 +1380,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1380,7 +1380,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_bias_type=self.attn_bias_type, attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout, attention_dropout=self.attention_dropout,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
dropout_rng_name=self.dropout_rng_name, dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits, float32_logits=self.float32_logits,
qkv_layout=qkv_layout.name, qkv_layout=qkv_layout.name,
...@@ -1406,11 +1405,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1406,11 +1405,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="out", name="out",
)(x) )(x)
out = checkpoint_name(out, "out_proj") out = checkpoint_name(out, "out_proj")
assert (
inputs_q.dtype == out.dtype
), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
return out, ln_out return out, ln_out
...@@ -1435,9 +1436,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho ...@@ -1435,9 +1436,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation. The data type used to allocate the initial parameters.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
""" """
num_buckets: int num_buckets: int
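A hedged usage sketch, assuming RelativePositionBiases is exposed by the same Flax module, taking only the constructor arguments visible in this diff, and following standard Flax init/apply semantics; the returned tensor is a relative-position bias to be added to the attention logits:

import jax
import jax.numpy as jnp
from flax import linen as nn
from transformer_engine.jax.flax import RelativePositionBiases  # assumed import path

relpos = RelativePositionBiases(
    num_buckets=32,
    max_distance=128,
    num_attention_heads=8,
    dtype=jnp.float32,  # only affects the rel_embedding parameter
    embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
)
variables = relpos.init(jax.random.PRNGKey(0), 16, 16, bidirectional=True)
bias = relpos.apply(variables, 16, 16, bidirectional=True)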
...@@ -1446,7 +1445,6 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho ...@@ -1446,7 +1445,6 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
embedding_init: Callable[..., Array] = nn.linear.default_embed_init embedding_init: Callable[..., Array] = nn.linear.default_embed_init
embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets") embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
dtype: DType = jnp.float32 dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
@nn.compact @nn.compact
def __call__(self, q_seqlen, k_seqlen, bidirectional=True): def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
...@@ -1499,7 +1497,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho ...@@ -1499,7 +1497,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
"rel_embedding", "rel_embedding",
self.embedding_init, self.embedding_init,
(self.num_attention_heads, self.num_buckets), (self.num_attention_heads, self.num_buckets),
self.weight_dtype, self.dtype,
axes=self.embedding_axes, axes=self.embedding_axes,
) )
...@@ -1672,9 +1670,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1672,9 +1670,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters Optimization parameters
----------------------- -----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32 dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation. The data type used to allocate the initial parameters.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
drop_path: float, default = 0.0 drop_path: float, default = 0.0
When > 0.0, applies stochastic depth per sample in the main When > 0.0, applies stochastic depth per sample in the main
path of the residual block. path of the residual block.
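Since drop_path (stochastic depth) is described here, a minimal hypothetical sketch follows; it assumes the leading axis is the batch axis and is not the layer's actual implementation:

import jax
import jax.numpy as jnp

def drop_path(rng, branch_out, drop_rate):
    # zero the residual branch for a whole sample with probability drop_rate,
    # and rescale the survivors so the expected value is unchanged
    if drop_rate == 0.0:
        return branch_out
    keep_prob = 1.0 - drop_rate
    shape = (branch_out.shape[0],) + (1,) * (branch_out.ndim - 1)
    keep = jax.random.bernoulli(rng, keep_prob, shape)
    return branch_out * keep.astype(branch_out.dtype) / keep_prob

x = jnp.ones((4, 16, 32), dtype=jnp.bfloat16)
y = drop_path(jax.random.PRNGKey(0), x, drop_rate=0.1)
assert y.dtype == x.dtype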
...@@ -1727,7 +1723,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1727,7 +1723,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim: int = 32 low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32 dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
drop_path: float = 0.0 drop_path: float = 0.0
fuse_qkv_params: bool = True fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
...@@ -1739,11 +1734,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1739,11 +1734,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self): def __post_init__(self):
if self.mha_kernel_init is None: if self.mha_kernel_init is None:
self.mha_kernel_init = nn.initializers.variance_scaling( self.mha_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.weight_dtype 1.0, "fan_in", "normal", dtype=self.dtype
) )
if self.mlp_kernel_init is None: if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling( self.mlp_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype 1.0, "fan_in", "truncated_normal", dtype=self.dtype
) )
if self.num_gqa_groups is None: if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads self.num_gqa_groups = self.num_attention_heads
...@@ -1793,9 +1788,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1793,9 +1788,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray outputs: jax.numpy.ndarray
Output tensors. Output tensors.
""" """
input_dtype = inputs.dtype
inputs = inputs.astype(self.dtype)
assert ( assert (
self.layer_type in TransformerLayerType self.layer_type in TransformerLayerType
), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}." ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
...@@ -1833,8 +1826,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1833,8 +1826,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
max_distance=128, max_distance=128,
num_attention_heads=self.num_attention_heads, num_attention_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype, embedding_init=nn.initializers.variance_scaling(
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"), 1.0, "fan_avg", "uniform", dtype=self.dtype
),
name="relpos_bias", name="relpos_bias",
) )
else: else:
...@@ -1867,7 +1861,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1867,7 +1861,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
x, ln_out = MultiHeadAttention( x, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads, num_attention_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -1946,7 +1939,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -1946,7 +1939,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
y, ln_out = MultiHeadAttention( y, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads, num_attention_heads=self.num_attention_heads,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
head_dim=head_dim, head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups, num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
...@@ -2012,7 +2004,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2012,7 +2004,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
intermediate_dropout_rate=self.intermediate_dropout, intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
scale_axes=(W_NO_SHARD_AXES,), scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_init=self.mlp_kernel_init, kernel_init=self.mlp_kernel_init,
...@@ -2062,8 +2053,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods ...@@ -2062,8 +2053,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,), bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence, transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype, dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="output_layernorm", name="output_layernorm",
)(z) )(z)
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
return z return z