Unverified Commit 7a585983 authored by Teddy Do, committed by GitHub

[JAX] Fused layers argument default values changed (#2347)



* Change the default activations in MLP and TransformerLayer to ("gelu",), the dropout rate after FC1 to 0.0, and return_layernorm_output to False
Signed-off-by: tdophung <tdophung@nvidia.com>

* Fix the failing tests by hard-coding arguments to the previous values instead of relying on the new defaults
Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 5ea83432
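Callers that depended on the previous defaults can keep the old behaviour by passing the arguments explicitly, as the updated tests do for return_layernorm_output. A minimal sketch below, assuming the modules are importable as transformer_engine.jax.flax.LayerNormMLP and TransformerLayer (the import path, and that the remaining constructor arguments keep their defaults, are assumptions):

    # Sketch: keep the pre-change defaults explicitly (import path assumed).
    from transformer_engine.jax.flax import LayerNormMLP, TransformerLayer

    # Previous LayerNormMLP behaviour: relu activation, 0.1 dropout after FC1,
    # and the layernorm output returned as the second output tensor.
    ln_mlp = LayerNormMLP(
        activations=("relu",),
        intermediate_dropout_rate=0.1,
        return_layernorm_output=True,
    )

    # Previous TransformerLayer behaviour: relu MLP activation and
    # 0.1 dropout after the FC1 layer.
    layer = TransformerLayer(
        mlp_activations=("relu",),
        intermediate_dropout=0.1,
    )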
@@ -389,6 +389,7 @@ class TestDistributedLayernormMLP:
             intermediate_dim=INTERMEDIATE,
             activations=activation_type,
             use_bias=use_bias,
+            return_layernorm_output=True,
         )
         params_single = ln_mlp_single.init(init_rngs, x, deterministic=True)
         mlp_out_single, ln_out_single = ln_mlp_single.apply(
@@ -417,6 +418,7 @@ class TestDistributedLayernormMLP:
             dot_1_input_axes=DOT_1_INPUT_AXES,
             dot_2_input_axes=DOT_2_INPUT_AXES,
             name="mlp",
+            return_layernorm_output=True,
         )
         params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True)
         mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
......
@@ -364,9 +364,9 @@ class MlpBlock(nn.Module):
     transpose_batch_sequence: bool
     intermediate_dim: int = 2048
-    activations: Sequence[Union[str, Callable]] = ("relu",)
+    activations: Sequence[Union[str, Callable]] = ("gelu",)
     kernel_init: Initializer = None
-    intermediate_dropout_rate: float = 0.1
+    intermediate_dropout_rate: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     use_bias: bool = False
     dtype: Any = jnp.float32
@@ -1035,14 +1035,14 @@ class EncoderLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
-    intermediate_dropout: float = 0.1
+    intermediate_dropout: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     transpose_batch_sequence: bool = True
     float32_attention_logits: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
     mlp_dim: int = 2048
-    mlp_activations: Sequence[str] = ("relu",)
+    mlp_activations: Sequence[str] = ("gelu",)
     use_bias: bool = False
     dtype: Any = jnp.float32
     apply_residual_connection_post_layernorm: bool = False
@@ -1199,14 +1199,14 @@ class DecoderLayer(nn.Module):
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
-    intermediate_dropout: float = 0.1
+    intermediate_dropout: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     transpose_batch_sequence: bool = True
     float32_attention_logits: bool = False
     scale_attn_logits: bool = False
     scaled_query_init: bool = True
     mlp_dim: int = 2048
-    mlp_activations: Sequence[str] = ("relu",)
+    mlp_activations: Sequence[str] = ("gelu",)
     use_bias: bool = False
     dtype: Any = jnp.float32
     apply_residual_connection_post_layernorm: bool = False
......
@@ -597,7 +597,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     bias_axes: Tuple[str, ...], default = ()
         The name of axes used to shard bias with a corresponding mesh,
         only used when :attr:`use_bias=True`.
-    return_layernorm_output: bool, default = True
+    return_layernorm_output: bool, default = False
         Indicate whether to return the output of layer normalization.
         If set False, return None as the second tensor in outputs.
     enable_low_rank_adaptation: bool, default = False
@@ -644,7 +644,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     use_bias: bool = False
     bias_init: Initializer = nn.initializers.zeros
     bias_axes: Tuple[str, ...] = ()
-    return_layernorm_output: bool = True
+    return_layernorm_output: bool = False
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
     low_rank_adaptation_alpha: float = None
@@ -891,10 +891,10 @@ class LayerNormMLP(TransformerEngineBase):
         The name of axes used to shard bias with a corresponding mesh for
         the weight of the second dense layer transformation.
         Only used when :attr:`use_bias=True`.
-    return_layernorm_output: bool, default = True
+    return_layernorm_output: bool, default = False
         Indicate whether to return the output of layer normalization.
         If set False, return None as the second tensor in outputs.
-    activations: Sequence[Union[str, Callable]], default = ('relu',)
+    activations: Sequence[Union[str, Callable]], default = ('gelu',)
         The sequence of activation functions to apply after the first dense layer transformation.
         Each activation has its own transformation layer.
     activation_params: dict, default = None
@@ -903,7 +903,7 @@ class LayerNormMLP(TransformerEngineBase):
         need additional parameters.
     intermediate_dropout_rng_name: str, default = 'dropout'
         The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
-    intermediate_dropout_rate: float, default = 0.1
+    intermediate_dropout_rate: float, default = 0.0
         Dropout probability for the dropout op after the :attr:`activations`.
     intermediate_hidden_dropout_dims: Sequence[int], default = ()
         Dimensions that will share the same dropout mask for hidden
@@ -959,11 +959,11 @@ class LayerNormMLP(TransformerEngineBase):
     bias_init: Initializer = nn.initializers.zeros
     bias_axes_1: Tuple[str, ...] = ("act", "mlp")
     bias_axes_2: Tuple[str, ...] = ("embed",)
-    return_layernorm_output: bool = True
-    activations: Sequence[Union[str, Callable]] = ("relu",)
+    return_layernorm_output: bool = False
+    activations: Sequence[Union[str, Callable]] = ("gelu",)
     activation_params: dict = None
     intermediate_dropout_rng_name: str = "dropout"
-    intermediate_dropout_rate: float = 0.1
+    intermediate_dropout_rate: float = 0.0
     intermediate_hidden_dropout_dims: Sequence[int] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
......
@@ -1620,7 +1620,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
         Dimensions that will share the same dropout mask for hidden
     attention_dropout: float, default = 0.1
         Dropout probability for the dropout op during multi-head attention.
-    intermediate_dropout: float, default = 0.1
+    intermediate_dropout: float, default = 0.0
         Dropout probability for the dropout op after FC1 layer.
     intermediate_dropout_dims: Sequence[int], default = ()
         Dimensions that will share the same dropout mask for hidden after FC1 layer.
@@ -1635,7 +1635,7 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
        flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
         Used for initializing weights of FC1 and FC2 layers.
         It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
-    mlp_activations: Sequence[str], default = ('relu', )
+    mlp_activations: Sequence[str], default = ('gelu', )
         The sequence of activation functions to apply after the first linear transformation.
         Each activation has its own transformation layer.
     mlp_activation_params: dict = None
@@ -1755,12 +1755,12 @@ class TransformerLayer(nn.Module):  # pylint: disable=too-few-public-methods
     hidden_dropout: float = 0.1
     hidden_dropout_dims: Sequence[int] = ()
     attention_dropout: float = 0.1
-    intermediate_dropout: float = 0.1
+    intermediate_dropout: float = 0.0
     intermediate_dropout_dims: Sequence[int] = ()
     dropout_rng_name: str = "dropout"
     mha_kernel_init: Initializer = None
     mlp_kernel_init: Initializer = None
-    mlp_activations: Sequence[str] = ("relu",)
+    mlp_activations: Sequence[str] = ("gelu",)
     mlp_activation_params: dict = None
     use_bias: bool = False
     bias_init: Initializer = nn.initializers.zeros
......