Unverified Commit 82e5b4d2 authored by Ming-Xu Huang's avatar Ming-Xu Huang Committed by GitHub
Browse files

[JAX] Fixed the shape mismatching issue in MLP. (#859)



* Fixed the shape mismatching issue in MLP.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Add a corresponding test
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 01801633
...@@ -177,6 +177,8 @@ ATTRS = [{}, { ...@@ -177,6 +177,8 @@ ATTRS = [{}, {
_KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias", _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
}, { }, {
_KEY_OF_ATTENTION_DROPOUT: 0.3, _KEY_OF_ATTENTION_DROPOUT: 0.3,
}, {
_KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')),
}] }]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......
...@@ -1148,8 +1148,8 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1148,8 +1148,8 @@ class LayerNormMLP(TransformerEngineBase):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i) activations.append(x_i)
z = functools.reduce(operator.mul, activations) z = functools.reduce(operator.mul, activations)
if num_activations == 1: # Remove act axis
z = jnp.reshape(z, (*z.shape[:-2], -1)) z = jnp.reshape(z, (*z.shape[:-2], -1))
z = nn.Dropout(rate=self.intermediate_dropout_rate, z = nn.Dropout(rate=self.intermediate_dropout_rate,
broadcast_dims=self.intermediate_hidden_dropout_dims, broadcast_dims=self.intermediate_hidden_dropout_dims,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment