Unverified commit 7a585983 authored by Teddy Do, committed by GitHub

[JAX] Fused layers argument default values changed (#2347)



* Change the default activations in MLP and TransformerLayer to GELU, the dropout rate after FC1 to 0, and return_layernorm_output to False
Signed-off-by: tdophung <tdophung@nvidia.com>

* Fix the failing tests by hard-coding arguments to the previous values instead of relying on the new defaults
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5ea83432
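
For downstream code that relied on the old defaults, the same remedy used for the tests applies: pass the previous values explicitly instead of depending on the defaults. A minimal sketch for LayerNormMLP follows (the import path, sizes, and input layout are illustrative assumptions, not taken from this commit):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP  # assumed import path

# Pin the pre-change defaults explicitly: ReLU activation, dropout after FC1,
# and the layernorm output returned as the second tensor.
ln_mlp = LayerNormMLP(
    intermediate_dim=2048,
    activations=("relu",),          # new default: ("gelu",)
    intermediate_dropout_rate=0.1,  # new default: 0.0
    return_layernorm_output=True,   # new default: False
)

x = jnp.zeros((32, 128, 512), jnp.float32)  # illustrative input; layout depends on transpose_batch_sequence
params = ln_mlp.init(jax.random.PRNGKey(0), x, deterministic=True)
mlp_out, ln_out = ln_mlp.apply(params, x, deterministic=True)
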
@@ -389,6 +389,7 @@ class TestDistributedLayernormMLP:
intermediate_dim=INTERMEDIATE,
activations=activation_type,
use_bias=use_bias,
+ return_layernorm_output=True,
)
params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
mlp_out_single, ln_out_single = ln_mlp_single.apply(
@@ -417,6 +418,7 @@ class TestDistributedLayernormMLP:
dot_1_input_axes=DOT_1_INPUT_AXES,
dot_2_input_axes=DOT_2_INPUT_AXES,
name="mlp",
+ return_layernorm_output=True,
)
params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
@@ -364,9 +364,9 @@ class MlpBlock(nn.Module):
transpose_batch_sequence: bool
intermediate_dim: int = 2048
- activations: Sequence[Union[str, Callable]] = ("relu",)
+ activations: Sequence[Union[str, Callable]] = ("gelu",)
kernel_init: Initializer = None
- intermediate_dropout_rate: float = 0.1
+ intermediate_dropout_rate: float = 0.0
intermediate_dropout_dims: Sequence[int] = ()
use_bias: bool = False
dtype: Any = jnp.float32
@@ -1035,14 +1035,14 @@ class EncoderLayer(nn.Module):
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
- intermediate_dropout: float = 0.1
+ intermediate_dropout: float = 0.0
intermediate_dropout_dims: Sequence[int] = ()
transpose_batch_sequence: bool = True
float32_attention_logits: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
- mlp_activations: Sequence[str] = ("relu",)
+ mlp_activations: Sequence[str] = ("gelu",)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
@@ -1199,14 +1199,14 @@ class DecoderLayer(nn.Module):
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
- intermediate_dropout: float = 0.1
+ intermediate_dropout: float = 0.0
intermediate_dropout_dims: Sequence[int] = ()
transpose_batch_sequence: bool = True
float32_attention_logits: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
- mlp_activations: Sequence[str] = ("relu",)
+ mlp_activations: Sequence[str] = ("gelu",)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
@@ -597,7 +597,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
- return_layernorm_output: bool, default = True
+ return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False
@@ -644,7 +644,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
- return_layernorm_output: bool = True
+ return_layernorm_output: bool = False
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
@@ -891,10 +891,10 @@ class LayerNormMLP(TransformerEngineBase):
The name of axes used to shard bias with a corresponding mesh for
the weight of the second dense layer transformation.
Only used when :attr:`use_bias=True`.
- return_layernorm_output: bool, default = True
+ return_layernorm_output: bool, default = False
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
- activations: Sequence[Union[str, Callable]], default = ('relu',)
+ activations: Sequence[Union[str, Callable]], default = ('gelu',)
The sequence of activation functions to apply after the first dense layer transformation.
Each activation has its own transformation layer.
activation_params: dict, default = None
@@ -903,7 +903,7 @@ class LayerNormMLP(TransformerEngineBase):
need additional parameters.
intermediate_dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
- intermediate_dropout_rate: float, default = 0.1
+ intermediate_dropout_rate: float, default = 0.0
Dropout probability for the dropout op after the :attr:`activations`.
intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden
@@ -959,11 +959,11 @@ class LayerNormMLP(TransformerEngineBase):
bias_init: Initializer = nn.initializers.zeros
bias_axes_1: Tuple[str, ...] = ("act", "mlp")
bias_axes_2: Tuple[str, ...] = ("embed",)
- return_layernorm_output: bool = True
- activations: Sequence[Union[str, Callable]] = ("relu",)
+ return_layernorm_output: bool = False
+ activations: Sequence[Union[str, Callable]] = ("gelu",)
activation_params: dict = None
intermediate_dropout_rng_name: str = "dropout"
- intermediate_dropout_rate: float = 0.1
+ intermediate_dropout_rate: float = 0.0
intermediate_hidden_dropout_dims: Sequence[int] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
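
Because return_layernorm_output now defaults to False, callers that unpack a second output from LayerNormMLP (or LayerNormDenseGeneral) should expect None unless they opt back in, as the docstrings above state. A hedged sketch of the two cases (import path and shapes are assumptions):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import LayerNormMLP  # assumed import path

x = jnp.zeros((32, 128, 512), jnp.float32)
rng = jax.random.PRNGKey(0)

# With the new default (return_layernorm_output=False) the second output is None.
mlp = LayerNormMLP(intermediate_dim=2048)
params = mlp.init(rng, x, deterministic=True)
out, ln_out = mlp.apply(params, x, deterministic=True)
assert ln_out is None

# Opt back in to also receive the layernorm output, as before.
mlp_ln = LayerNormMLP(intermediate_dim=2048, return_layernorm_output=True)
params_ln = mlp_ln.init(rng, x, deterministic=True)
out, ln_out = mlp_ln.apply(params_ln, x, deterministic=True)
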
@@ -1620,7 +1620,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Dimensions that will share the same dropout mask for hidden
attention_dropout: float, default = 0.1
Dropout probability for the dropout op during multi-head attention.
- intermediate_dropout: float, default = 0.1
+ intermediate_dropout: float, default = 0.0
Dropout probability for the dropout op after FC1 layer.
intermediate_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden after FC1 layer.
@@ -1635,7 +1635,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing weights of FC1 and FC2 layers.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
- mlp_activations: Sequence[str], default = ('relu', )
+ mlp_activations: Sequence[str], default = ('gelu', )
The sequence of activation functions to apply after the first linear transformation.
Each activation has its own transformation layer.
mlp_activation_params: dict = None
@@ -1755,12 +1755,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
- intermediate_dropout: float = 0.1
+ intermediate_dropout: float = 0.0
intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = "dropout"
mha_kernel_init: Initializer = None
mlp_kernel_init: Initializer = None
- mlp_activations: Sequence[str] = ("relu",)
+ mlp_activations: Sequence[str] = ("gelu",)
mlp_activation_params: dict = None
use_bias: bool = False
bias_init: Initializer = nn.initializers.zeros
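
The same pattern applies at the TransformerLayer level: the defaults now select GELU and disable dropout after FC1, so the previous behaviour needs explicit arguments. A small, hedged illustration using only the fields shown above (import path is an assumption, and all other fields are assumed to keep their defaults):

from transformer_engine.jax.flax import TransformerLayer  # assumed import path

# Defaults after this change: mlp_activations=("gelu",) and intermediate_dropout=0.0.
layer = TransformerLayer()

# To reproduce the previous defaults, pass the old values explicitly:
legacy_layer = TransformerLayer(
    mlp_activations=("relu",),
    intermediate_dropout=0.1,
)
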