Unverified commit 22ccf9b1 authored by Ming-Xu Huang, committed by GitHub
Browse files

[JAX] Fix missing axes parameters in TransformerLayer and the wrong shape of...


[JAX] Fix missing axes parameters in TransformerLayer and the wrong shape of bias in LayerNormMLP (#196)

Fixed missing axes and wrong shape of bias in LayerNormMLP
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent 496b8fdd
......@@ -451,8 +451,8 @@ class DenseGeneral(TransformerEngineBase):
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
if bias is not None:
bais_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
y += jnp.reshape(bias, bais_shape)
bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
y += jnp.reshape(bias, bias_shape)
return y
......@@ -660,8 +660,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axes=self.bias_axes)
if bias is not None:
bais_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
z += jnp.reshape(bias, bais_shape)
bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
z += jnp.reshape(bias, bias_shape)
if self.depth_scaling is not None:
z = z / self.depth_scaling
......@@ -772,7 +772,10 @@ class LayerNormMLP(TransformerEngineBase):
kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes_1: Tuple[str, ...] = ('mlp',)
bias_axes_1: Tuple[str, ...] = (
'act',
'mlp',
)
bias_axes_2: Tuple[str, ...] = ('embed',)
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
......@@ -961,10 +964,12 @@ class LayerNormMLP(TransformerEngineBase):
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wi_bias',
self.bias_init, (self.intermediate_dim,),
self.bias_init,
intermediate_dim,
self.dtype,
axes=self.bias_axes_1)
x += jnp.reshape(bias, (1,) * (x.ndim - 1) + (-1,))
bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
x += jnp.reshape(bias, bias_shape)
if self.activations == ('gelu', 'linear'):
z = geglu(x,
......
......@@ -411,6 +411,10 @@ class MultiHeadAttention(nn.Module):
kernel_init=qkv_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(
'qkv_dim',
'joined_kv',
),
name='qkv',
dtype=self.dtype)(inputs_q)
if not use_fused_attn:
......@@ -430,6 +434,7 @@ class MultiHeadAttention(nn.Module):
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=('joined_kv',),
dtype=self.dtype,
kernel_init=query_init,
name='query')(inputs_q)
......@@ -441,6 +446,10 @@ class MultiHeadAttention(nn.Module):
kernel_init=kv_init,
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(
'kv_dim',
'joined_kv',
),
name='kv',
dtype=self.dtype)(inputs_kv)
if not use_fused_attn:
......@@ -455,6 +464,7 @@ class MultiHeadAttention(nn.Module):
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=('joined_kv',),
dtype=self.dtype)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=not self.output_layernorm,
......@@ -470,6 +480,7 @@ class MultiHeadAttention(nn.Module):
kernel_axes=('embed', 'joined_kv'),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=('joined_kv',),
dtype=self.dtype,
kernel_init=query_init,
name='query')(inputs_q)
......@@ -622,6 +633,7 @@ class MultiHeadAttention(nn.Module):
kernel_axes=('joined_kv', 'embed'),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=('embed',),
dtype=self.dtype,
name='out')(x)
return out, residual
......@@ -1059,6 +1071,11 @@ class TransformerLayer(nn.Module):
kernel_axes_2=('mlp', 'embed'),
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes_1=(
'act',
'mlp',
),
bias_axes_2=('embed',),
name='mlp',
)(mlp_input, deterministic=deterministic)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment