Unverified commit 6673f165, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Flax with compute dtype inferred from input dtype. (#1485)



flax module with compute dtype inferred from the inputs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent eb9857d6
......@@ -56,7 +56,6 @@ class Net(nn.Module):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......@@ -72,17 +71,15 @@ class Net(nn.Module):
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2)(x)
return x
......@@ -91,7 +88,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -136,7 +133,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......
......@@ -51,17 +51,16 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2)(x)
return x
......@@ -70,7 +69,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -115,7 +114,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......
......@@ -57,7 +57,6 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......@@ -67,17 +66,15 @@ class Net(nn.Module):
features=256,
kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
bias_axes=(NAMED_TP_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = te_flax.DenseGeneral(
features=256,
kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
bias_axes=(NAMED_BROADCAST_AXIS,),
dtype=jnp.bfloat16,
)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2)(x)
return x
......
......@@ -46,17 +46,16 @@ class Net(nn.Module):
layer_type=te_flax.TransformerLayerType.ENCODER,
self_attn_mask_type="padding",
enable_relative_embedding=False,
dtype=jnp.bfloat16,
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x)
x = te_flax.DenseGeneral(features=256)(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=2)(x)
return x
......@@ -66,7 +65,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -112,7 +111,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def loss_fn(var_collect, disable_dropout=False):
logits = state.apply_fn(var_collect, inputs, masks, disable_dropout)
one_hot = jax.nn.one_hot(labels, 2)
one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
......@@ -217,6 +216,7 @@ def train_and_evaluate(args):
with te.fp8_autocast(enabled=args.use_fp8):
encoder = Net(num_embed)
# We use nn.Embed, so the inputs must have an integer dtype
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks)
......
......@@ -36,6 +36,8 @@ class Net(nn.Module):
nn_Dense = te_flax.DenseGeneral
else:
nn_Dense = nn.Dense
# In TE modules `dtype` only sets the parameter-init dtype; in flax.linen (nn) it sets the computation dtype
dtype = jnp.float32 if self.use_te else jnp.bfloat16
x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
x = nn.relu(x)
......@@ -44,11 +46,13 @@ class Net(nn.Module):
x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
x = nn_Dense(features=128, dtype=jnp.bfloat16)(x)
assert x.dtype == jnp.bfloat16
x = nn_Dense(features=128, dtype=dtype)(x)
x = nn.relu(x)
x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
x = nn_Dense(features=16, dtype=jnp.bfloat16)(x)
x = nn.Dense(features=10, dtype=jnp.bfloat16)(x)
x = nn_Dense(features=16, dtype=dtype)(x)
x = nn_Dense(features=10, dtype=dtype)(x)
assert x.dtype == jnp.bfloat16
return x
......
......@@ -271,7 +271,6 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
dtype=dtype,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
......@@ -289,7 +288,6 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence=False,
intermediate_dim=INTERMEDIATE,
activations=activation_type,
dtype=dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
......
......@@ -265,8 +265,8 @@ class BaseRunner:
"""Test only the forward"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
ref_layer_cls = partial(self.reference_layer, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)
ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
......@@ -288,8 +288,8 @@ class BaseRunner:
"""Test forward and backward through value_and_grad()"""
inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype)
ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs)
ref_layer_cls = partial(self.reference_layer, **self.attrs)
layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs)
ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks)
test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks)
......
......@@ -110,7 +110,7 @@ class DotProductAttention(nn.Module):
Args:
dropout_rate: dropout rate
dtype: the dtype of the computation (default: float32)
dtype: the data type used to allocate the initial parameters (default: float32).
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
"""
......@@ -195,6 +195,7 @@ class DotProductAttention(nn.Module):
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
attn_weights = attn_weights.astype(value.dtype)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
......@@ -209,7 +210,7 @@ class DenseGeneral(nn.Module):
Attributes:
features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on.
dtype: the dtype of the computation (default: float32).
dtype: the data type used to allocate the initial parameters (default: float32).
kernel_init: initializer function for the weight matrix.
use_bias: whether to add a bias to the output (default: False).
bias_init: initializer function for the bias vector.
......@@ -226,7 +227,9 @@ class DenseGeneral(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()
@nn.compact
......@@ -239,6 +242,7 @@ class DenseGeneral(nn.Module):
Returns:
The transformed input.
"""
input_dtype = inputs.dtype
features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)
......@@ -248,23 +252,24 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
kernel = nn_partitioning.param_with_axes(
"kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
"kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
)
kernel = jnp.asarray(kernel, self.dtype)
kernel = jnp.asarray(kernel, input_dtype)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
bias = nn_partitioning.param_with_axes(
"bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes
"bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
)
bias = bias.astype(self.dtype)
bias = bias.astype(input_dtype)
else:
bias = None
contract_ind = tuple(range(0, len(axis)))
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
y = y.astype(input_dtype)
if bias is not None:
y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
......@@ -281,7 +286,7 @@ class MlpBlock(nn.Module):
kernel_init: Kernel function, passed to the dense layers.
deterministic: Whether the dropout layers should be deterministic.
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
dtype: Type for the dense layer.
dtype: the data type used to allocate the initial parameters (default: float32).
"""
transpose_batch_sequence: bool
......@@ -296,7 +301,9 @@ class MlpBlock(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
super().__post_init__()
@nn.compact
......@@ -358,6 +365,9 @@ class MlpBlock(nn.Module):
bias_axes="embed",
name="wo",
)(x)
# Sanity check: the MLP block must hand back the dtype it received.
# The message reports the `inputs` argument; the builtin `input` has no
# `.dtype`, so formatting it would raise AttributeError and mask the mismatch.
assert (
output.dtype == inputs.dtype
), f"input.dtype={inputs.dtype}, output.dtype={output.dtype}"
return output
......@@ -429,7 +439,7 @@ class MultiHeadAttention(nn.Module):
should be divisible by the number of heads.
num_gqa_groups: number of kv attention heads
head_dim: dimension of each head.
dtype: the dtype of the computation.
dtype: the data type used to allocate the initial parameters (default: float32).
dropout_rate: dropout rate
kernel_init: initializer for the kernel of the Dense layers.
float32_logits: bool, if True then compute logits in float32 to avoid
......@@ -453,7 +463,9 @@ class MultiHeadAttention(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
......@@ -738,6 +750,9 @@ class MultiHeadAttention(nn.Module):
dtype=self.dtype,
name="out",
)(x)
assert (
inputs_q.dtype == inputs_kv.dtype == out.dtype
), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
return out
......@@ -763,13 +778,13 @@ class LayerNorm(nn.Module):
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
"""Applies layer normalization on the input."""
x = jnp.asarray(x, jnp.float32)
input_dtype = x.dtype
features = x.shape[-1]
scale = nn_partitioning.param_with_axes(
"scale", self.scale_init, (features,), jnp.float32, axes=("embed",)
"scale", self.scale_init, (features,), self.dtype, axes=("embed",)
)
scale = jnp.asarray(scale, self.dtype)
scale = jnp.asarray(scale, input_dtype)
if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True)
......@@ -777,9 +792,9 @@ class LayerNorm(nn.Module):
y = (x - mean) * lax.rsqrt(var + self.epsilon)
bias = nn_partitioning.param_with_axes(
"ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",)
"ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
)
bias = jnp.asarray(bias, self.dtype)
bias = jnp.asarray(bias, input_dtype)
if not self.zero_centered_gamma:
z = y * scale + bias
......@@ -792,7 +807,8 @@ class LayerNorm(nn.Module):
y = x * lax.rsqrt(mean2 + self.epsilon)
z = y * scale
return jnp.asarray(z, self.dtype)
assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
return z
class RelativePositionBiases(nn.Module):
......@@ -805,7 +821,7 @@ class RelativePositionBiases(nn.Module):
distance bucket.
num_heads: Number of heads in the attention layer. Each head will get a
different relative position weighting.
dtype: Type of arrays through this module.
dtype: the data type used to allocate the initial parameters (default: float32).
embedding_init: initializer for relative embedding table.
"""
......@@ -1087,6 +1103,7 @@ class EncoderLayer(nn.Module):
dtype=self.dtype,
name="output_layernorm",
)(y)
assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
return y
......@@ -1293,6 +1310,7 @@ class DecoderLayer(nn.Module):
name="output_layernorm",
)(z)
assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}"
return z
......
This diff is collapsed.
......@@ -115,7 +115,6 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
float32_logits: bool = False
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
......@@ -143,6 +142,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match."
assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
input_dtype = query.dtype
if self.scale_factor is None:
scale_factor = 1.0 / sqrt(query.shape[-1])
else:
......@@ -150,8 +151,8 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
del self.scale_factor
if self.float32_logits:
query = query.astype(self.dtype)
key = key.astype(self.dtype)
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths.
......@@ -234,7 +235,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(
attn_weights, mask, bias
).astype(self.dtype)
).astype(input_dtype)
if is_gqa:
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
......@@ -244,9 +245,12 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
dropout_shape = list(attn_weights.shape)
# TODO(rewang): add attention dropout broadcast dimension arguments for users
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
attn_weights = attn_weights * multiplier
assert (
attn_weights.dtype == input_dtype
), f"output={attn_weights.dtype}, input={input_dtype}"
if self.transpose_batch_sequence:
if is_gqa:
return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
......@@ -254,6 +258,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if is_gqa:
return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
......@@ -262,7 +267,6 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
attn_bias_type: Optional[AttnBiasType] = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = False
......@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
if self.transpose_batch_sequence:
x = x.transpose([1, 0, 2, 3])
assert x.dtype == query.dtype
return x
......@@ -492,9 +497,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
"""
head_dim: int
......@@ -504,7 +507,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
dropout_rng_name: str = "dropout"
float32_logits: bool = False
qkv_layout: str = "bshd_bshd_bshd"
......@@ -552,6 +554,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray
Output tensors.
"""
input_dtype = query.dtype
if mask is not None:
if sequence_descriptor is not None:
......@@ -642,7 +645,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
float32_logits=self.float32_logits,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -662,7 +664,6 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_mask_type=attn_mask_type,
attn_bias_type=attn_bias_type,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
scale_factor=scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
qkv_layout=qkv_layout,
......@@ -679,7 +680,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
dropout_rng=dropout_rng,
deterministic=deterministic,
)
assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}"
return x
......@@ -720,10 +721,10 @@ def rotary_pos_emb(
sin, cos = generate_sin_cos(time_scales)
x1, x2 = jnp.split(x, 2, axis=-1)
part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(x.dtype)
part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype)
output = jnp.concatenate([part_1, part_2], axis=-1)
output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype)
return output
def consecutive_impl():
......@@ -928,8 +929,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used to allocate the initial parameters.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
......@@ -975,7 +974,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
......@@ -1026,7 +1024,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.weight_dtype
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
......@@ -1071,6 +1069,11 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Output tensors.
"""
assert (
inputs_q.dtype == inputs_kv.dtype
), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}"
input_dtype = inputs_q.dtype
def query_init(*args):
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0)
......@@ -1154,7 +1157,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
dot_input_axes=inputs_logical_axes_no_sp,
name="qkv",
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)(inputs_q)
qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
qkv_layout = QKVLayout.BS3HD
......@@ -1178,7 +1180,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
......@@ -1203,7 +1204,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
name="kv",
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
qkv_layout = QKVLayout.BSHD_BS2HD
......@@ -1221,7 +1221,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=self.input_layernorm,
......@@ -1242,7 +1241,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
......@@ -1253,9 +1251,11 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
assert ln_out is not None
inputs_kv = ln_out
query = query.astype(input_dtype)
key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
key = key.astype(self.dtype)
key = key.astype(input_dtype)
value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
value = value.astype(input_dtype)
query = checkpoint_name(query, "query_proj")
key = checkpoint_name(key, "key_proj")
value = checkpoint_name(value, "value_proj")
......@@ -1380,7 +1380,6 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=qkv_layout.name,
......@@ -1406,11 +1405,13 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="out",
)(x)
out = checkpoint_name(out, "out_proj")
assert (
inputs_q.dtype == out.dtype
), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}"
return out, ln_out
......@@ -1435,9 +1436,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
"""
num_buckets: int
......@@ -1446,7 +1445,6 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
@nn.compact
def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
......@@ -1499,7 +1497,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
"rel_embedding",
self.embedding_init,
(self.num_attention_heads, self.num_buckets),
self.weight_dtype,
self.dtype,
axes=self.embedding_axes,
)
......@@ -1672,9 +1670,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used for computation.
weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type of the module parameters.
The data type used to allocate the initial parameters.
drop_path: float, default = 0.0
When > 0.0, applies stochastic depth per sample in the main
path of the residual block.
......@@ -1727,7 +1723,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
......@@ -1739,11 +1734,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
def __post_init__(self):
if self.mha_kernel_init is None:
self.mha_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "normal", dtype=self.weight_dtype
1.0, "fan_in", "normal", dtype=self.dtype
)
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(
1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
1.0, "fan_in", "truncated_normal", dtype=self.dtype
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
......@@ -1793,9 +1788,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray
Output tensors.
"""
inputs = inputs.astype(self.dtype)
input_dtype = inputs.dtype
assert (
self.layer_type in TransformerLayerType
), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}."
......@@ -1833,8 +1826,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
max_distance=128,
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
embedding_init=nn.initializers.variance_scaling(
1.0, "fan_avg", "uniform", dtype=self.dtype
),
name="relpos_bias",
)
else:
......@@ -1867,7 +1861,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
x, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -1946,7 +1939,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
y, ln_out = MultiHeadAttention(
num_attention_heads=self.num_attention_heads,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
head_dim=head_dim,
num_gqa_groups=self.num_gqa_groups,
transpose_batch_sequence=self.transpose_batch_sequence,
......@@ -2012,7 +2004,6 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
intermediate_dropout_rate=self.intermediate_dropout,
intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
scale_axes=(W_NO_SHARD_AXES,),
ln_bias_axes=(W_NO_SHARD_AXES,),
kernel_init=self.mlp_kernel_init,
......@@ -2062,8 +2053,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_axes=(W_NO_SHARD_AXES,),
transpose_batch_sequence=self.transpose_batch_sequence,
dtype=self.dtype,
weight_dtype=self.weight_dtype,
name="output_layernorm",
)(z)
assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}"
return z
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment