Unverified Commit 22ccf9b1, authored by Ming-Xu Huang, committed by GitHub

[JAX] Fix missing axes parameters in TransformerLayer and the wrong shape of bias in LayerNormMLP (#196)

Fixed missing axes and wrong shape of bias in LayerNormMLP
Signed-off-by: Ming Huang <mingh@nvidia.com>
parent 496b8fdd
@@ -451,8 +451,8 @@ class DenseGeneral(TransformerEngineBase):
         y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
         if bias is not None:
-            bais_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
-            y += jnp.reshape(bias, bais_shape)
+            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
+            y += jnp.reshape(bias, bias_shape)
         return y
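For context on the reshape above: the bias stored by DenseGeneral can have more than one trailing feature axis, so the code prepends exactly enough singleton axes for NumPy-style broadcasting. A minimal sketch of the mechanism (shapes are illustrative, not taken from the diff):

    import jax.numpy as jnp

    # Hypothetical dense output with two feature axes (e.g. heads x head_dim).
    y = jnp.ones((4, 16, 8, 64))
    bias = jnp.zeros((8, 64))        # one bias entry per feature element

    # Prepend singleton axes so the bias broadcasts over batch and sequence.
    bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape   # (1, 1, 8, 64)
    y = y + jnp.reshape(bias, bias_shape)
    assert y.shape == (4, 16, 8, 64)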
@@ -660,8 +660,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                                                    axes=self.bias_axes)
         if bias is not None:
-            bais_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
-            z += jnp.reshape(bias, bais_shape)
+            bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape
+            z += jnp.reshape(bias, bias_shape)
         if self.depth_scaling is not None:
             z = z / self.depth_scaling
@@ -772,7 +772,10 @@ class LayerNormMLP(TransformerEngineBase):
     kernel_axes_2: Tuple[str, ...] = ('mlp', 'embed')
     use_bias: bool = False
     bias_init: Initializer = nn.initializers.zeros
-    bias_axes_1: Tuple[str, ...] = ('mlp',)
+    bias_axes_1: Tuple[str, ...] = (
+        'act',
+        'mlp',
+    )
     bias_axes_2: Tuple[str, ...] = ('embed',)
     return_layernorm_output: bool = True
     activations: Sequence[Union[str, Callable]] = ('relu',)
@@ -961,10 +964,12 @@ class LayerNormMLP(TransformerEngineBase):
         bias = None
         if self.use_bias:
             bias = nn_partitioning.param_with_axes('wi_bias',
-                                                   self.bias_init, (self.intermediate_dim,),
+                                                   self.bias_init,
+                                                   intermediate_dim,
                                                    self.dtype,
                                                    axes=self.bias_axes_1)
-            x += jnp.reshape(bias, (1,) * (x.ndim - 1) + (-1,))
+            bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape
+            x += jnp.reshape(bias, bias_shape)
         if self.activations == ('gelu', 'linear'):
             z = geglu(x,
...
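Why the removed line was wrong: with bias_axes_1 = ('act', 'mlp') the wi bias is now 2-D, shape (num_activations, intermediate_dim), and the old reshape (1,) * (x.ndim - 1) + (-1,) flattened it into a single trailing axis that no longer matches the layout of x. A hedged sketch, assuming (for illustration only) that the fused wi output keeps separate activation and hidden axes:

    import jax.numpy as jnp

    x = jnp.ones((4, 16, 2, 128))   # (batch, seq, act, mlp); illustrative shapes
    bias = jnp.zeros((2, 128))      # matches bias_axes_1 = ('act', 'mlp')

    # Old code: jnp.reshape(bias, (1, 1, 1, -1)) yields shape (1, 1, 1, 256),
    # which cannot broadcast against (..., 2, 128).
    bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape   # (1, 1, 2, 128)
    x = x + jnp.reshape(bias, bias_shape)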
@@ -411,6 +411,10 @@ class MultiHeadAttention(nn.Module):
                     kernel_init=qkv_init,
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
+                    bias_axes=(
+                        'qkv_dim',
+                        'joined_kv',
+                    ),
                     name='qkv',
                     dtype=self.dtype)(inputs_q)
             if not use_fused_attn:
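The bias_axes values added in this and the following hunks are logical axis names; DenseGeneral forwards them to flax's param_with_axes so each bias parameter carries sharding metadata alongside its kernel. A minimal sketch of that mechanism (the module and shapes below are hypothetical, not TransformerEngine code):

    import jax.numpy as jnp
    import flax.linen as nn
    from flax.linen import partitioning as nn_partitioning

    class BiasedDense(nn.Module):
        features: int = 64

        @nn.compact
        def __call__(self, x):
            # param_with_axes records the logical names ('joined_kv',) in a
            # separate 'params_axes' collection next to the parameter itself.
            bias = nn_partitioning.param_with_axes('bias', nn.initializers.zeros,
                                                   (self.features,), jnp.float32,
                                                   axes=('joined_kv',))
            return x + bias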
@@ -430,6 +434,7 @@ class MultiHeadAttention(nn.Module):
                     kernel_axes=('embed', 'joined_kv'),
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
+                    bias_axes=('joined_kv',),
                     dtype=self.dtype,
                     kernel_init=query_init,
                     name='query')(inputs_q)
@@ -441,6 +446,10 @@ class MultiHeadAttention(nn.Module):
                     kernel_init=kv_init,
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
+                    bias_axes=(
+                        'kv_dim',
+                        'joined_kv',
+                    ),
                     name='kv',
                     dtype=self.dtype)(inputs_kv)
             if not use_fused_attn:
@@ -455,6 +464,7 @@ class MultiHeadAttention(nn.Module):
                     kernel_axes=('embed', 'joined_kv'),
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
+                    bias_axes=('joined_kv',),
                     dtype=self.dtype)
             query, ln_out = LayerNormDenseGeneral(
                 enable_layernorm=not self.output_layernorm,
@@ -470,6 +480,7 @@ class MultiHeadAttention(nn.Module):
                     kernel_axes=('embed', 'joined_kv'),
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
+                    bias_axes=('joined_kv',),
                     dtype=self.dtype,
                     kernel_init=query_init,
                     name='query')(inputs_q)
@@ -622,6 +633,7 @@ class MultiHeadAttention(nn.Module):
                 kernel_axes=('joined_kv', 'embed'),
                 use_bias=self.use_bias,
                 bias_init=self.bias_init,
+                bias_axes=('embed',),
                 dtype=self.dtype,
                 name='out')(x)
         return out, residual
@@ -1059,6 +1071,11 @@ class TransformerLayer(nn.Module):
                 kernel_axes_2=('mlp', 'embed'),
                 use_bias=self.use_bias,
                 bias_init=self.bias_init,
+                bias_axes_1=(
+                    'act',
+                    'mlp',
+                ),
+                bias_axes_2=('embed',),
                 name='mlp',
             )(mlp_input, deterministic=deterministic)
...
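Per the commit title, these axes parameters were previously missing, so the bias parameters created inside TransformerLayer lacked the logical-axis annotations that sharding rules resolve against. A sketch of how such names resolve to mesh axes; the rules below are hypothetical and depend on the user's device mesh:

    from flax.linen import partitioning as nn_partitioning

    # Hypothetical mapping of the logical names used in this diff.
    rules = (('act', None), ('mlp', 'tp'), ('embed', None), ('joined_kv', 'tp'))

    with nn_partitioning.axis_rules(rules):
        spec = nn_partitioning.logical_to_mesh_axes(('act', 'mlp'))
        # spec == PartitionSpec(None, 'tp'): the 'mlp' axis is sharded across
        # the 'tp' mesh axis, while 'act' is replicated.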