Unverified commit 82e5b4d2, authored by Ming-Xu Huang, committed by GitHub

[JAX] Fixed the shape mismatching issue in MLP. (#859)



* Fixed the shape mismatching issue in MLP.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Add a corresponding test
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 01801633
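For context on the diff below: in LayerNormMLP, the first projection produces one intermediate slab per activation function, each activation is applied to its slice, and the branches are combined with an elementwise product. The following is a minimal standalone sketch of that shape bookkeeping, assuming a hypothetical (batch, seq, num_activations, intermediate) layout; it is an illustration, not TransformerEngine's actual implementation.

```python
import functools
import operator

import jax
import jax.numpy as jnp

# Hypothetical shapes, for illustration only.
batch, seq, num_activations, intermediate = 2, 8, 2, 32

# The first projection is assumed to emit one intermediate slab per
# activation function: (batch, seq, num_activations, intermediate).
x = jax.random.normal(jax.random.PRNGKey(0),
                      (batch, seq, num_activations, intermediate))

# Split along the activation axis, apply each activation, and combine
# the branches with an elementwise product (mirroring the diff below,
# with ('relu', 'relu') as in the new test attribute).
parts = jnp.split(x, num_activations, axis=-2)   # each (batch, seq, 1, intermediate)
acted = [jax.nn.relu(p) for p in parts]
z = functools.reduce(operator.mul, acted)        # (batch, seq, 1, intermediate)
print(z.shape)                                   # (2, 8, 1, 32)
```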
@@ -177,6 +177,8 @@ ATTRS = [{}, {
     _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
 }, {
     _KEY_OF_ATTENTION_DROPOUT: 0.3,
+}, {
+    _KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')),
 }]
 ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
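The new entry extends the parameterized test attributes with a two-activation configuration, (('relu', 'relu')), so the sweep over ATTRS now covers the multi-activation MLP path that triggered the shape mismatch fixed in the hunk below.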
@@ -1148,7 +1148,7 @@ class LayerNormMLP(TransformerEngineBase):
    x_i = _convert_to_activation_function(act_fn)(x[idx])
    activations.append(x_i)
z = functools.reduce(operator.mul, activations)
if num_activations == 1:
    # Remove act axis
    z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate,
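Read in isolation, the fixed block behaves as in the sketch below. This is a hedged, self-contained re-creation of just the lines visible in the hunk; the function wrapper, the input shape, and the use of plain callables in place of the module's _convert_to_activation_function(...) are assumptions for illustration.

```python
import functools
import operator

import jax
import jax.numpy as jnp


def mlp_activation_block(x, activation_fns):
    # x is assumed to be (..., num_activations, intermediate); the
    # callables stand in for _convert_to_activation_function(...) results.
    num_activations = len(activation_fns)
    x = jnp.split(x, num_activations, axis=-2)
    activations = []
    for idx, act_fn in enumerate(activation_fns):
        x_i = act_fn(x[idx])
        activations.append(x_i)
    z = functools.reduce(operator.mul, activations)
    if num_activations == 1:
        # Remove act axis: per the fix, only the single-activation path
        # flattens the axis here.
        z = jnp.reshape(z, (*z.shape[:-2], -1))
    return z


x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 2, 32))
print(mlp_activation_block(x, [jax.nn.relu, jax.nn.relu]).shape)  # (2, 8, 1, 32)
print(mlp_activation_block(x[..., :1, :], [jax.nn.relu]).shape)   # (2, 8, 32)
```

With more than one activation, the act axis is now left in place after the product; how downstream layers consume it lies outside the lines shown in this hunk.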