Unverified commit 607acd4f, authored by DanielHesslow, committed by GitHub

Add Gated-SiLU to T5 (#17420)



* Add gated-silu to t5 architecture to support UL2

* Fix error message

* formatting

* formatting again

* refactor

* fix classnames in _init_weights

* remove is_gated

* add test

* fix test

* Try without the test?

* Add back the test.

* Improve error message.
Co-authored-by: default avatarDaniel Hesslow <daniel@lighton.ai>
parent 1c220ced
@@ -116,6 +116,22 @@ class T5Config(PretrainedConfig):
         self.initializer_factor = initializer_factor
         self.feed_forward_proj = feed_forward_proj
         self.use_cache = use_cache
+
+        act_info = self.feed_forward_proj.split("-")
+        self.dense_act_fn = act_info[-1]
+        self.is_gated_act = act_info[0] == "gated"
+
+        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+            raise ValueError(
+                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
+                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+                "'gated-gelu' or 'relu'"
+            )
+
+        # for backwards compatibility
+        if feed_forward_proj == "gated-gelu":
+            self.dense_act_fn = "gelu_new"
+
         super().__init__(
             pad_token_id=pad_token_id,
             eos_token_id=eos_token_id,
......
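
To make the new config behaviour concrete, here is a minimal standalone sketch of the parsing above. It is illustrative only: parse_feed_forward_proj is not a function in the library, it simply mirrors the dense_act_fn and is_gated_act attributes set in T5Config.__init__ in the hunk above.

    # Standalone sketch of the feed_forward_proj parsing added to T5Config above
    # (illustrative only; parse_feed_forward_proj is not part of the library).
    def parse_feed_forward_proj(feed_forward_proj):
        act_info = feed_forward_proj.split("-")
        dense_act_fn = act_info[-1]
        is_gated_act = act_info[0] == "gated"
        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
            raise ValueError(f"{feed_forward_proj} must be of the form `gated-{{ACT_FN}}` or `{{ACT_FN}}`")
        # backwards compatibility: "gated-gelu" has always meant the "gelu_new" activation
        if feed_forward_proj == "gated-gelu":
            dense_act_fn = "gelu_new"
        return dense_act_fn, is_gated_act

    assert parse_feed_forward_proj("relu") == ("relu", False)
    assert parse_feed_forward_proj("gated-silu") == ("silu", True)
    assert parse_feed_forward_proj("gated-gelu") == ("gelu_new", True)
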
@@ -87,7 +87,7 @@ class FlaxT5LayerNorm(nn.Module):
         return self.weight * hidden_states


-class FlaxT5DenseReluDense(nn.Module):
+class FlaxT5DenseActDense(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32
@@ -108,16 +108,17 @@ class FlaxT5DenseReluDense(nn.Module):
             dtype=self.dtype,
         )
         self.dropout = nn.Dropout(self.config.dropout_rate)
+        self.act = ACT2FN[self.config.dense_act_fn]

     def __call__(self, hidden_states, deterministic=True):
         hidden_states = self.wi(hidden_states)
-        hidden_states = jax.nn.relu(hidden_states)
+        hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states, deterministic=deterministic)
         hidden_states = self.wo(hidden_states)
         return hidden_states


-class FlaxT5DenseGatedGeluDense(nn.Module):
+class FlaxT5DenseGatedActDense(nn.Module):
     config: T5Config
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
@@ -144,10 +145,10 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
             dtype=self.dtype,
         )
         self.dropout = nn.Dropout(self.config.dropout_rate)
-        self.gelu_act = ACT2FN["gelu_new"]
+        self.act = ACT2FN[self.config.dense_act_fn]

     def __call__(self, hidden_states, deterministic):
-        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
+        hidden_gelu = self.act(self.wi_0(hidden_states))
         hidden_linear = self.wi_1(hidden_states)
         hidden_states = hidden_gelu * hidden_linear
         hidden_states = self.dropout(hidden_states, deterministic=deterministic)
@@ -160,14 +161,10 @@ class FlaxT5LayerFF(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
-        if self.config.feed_forward_proj == "relu":
-            self.DenseReluDense = FlaxT5DenseReluDense(self.config, dtype=self.dtype)
-        elif self.config.feed_forward_proj == "gated-gelu":
-            self.DenseReluDense = FlaxT5DenseGatedGeluDense(self.config, dtype=self.dtype)
+        if self.config.is_gated_act:
+            self.DenseReluDense = FlaxT5DenseGatedActDense(self.config, dtype=self.dtype)
         else:
-            raise ValueError(
-                f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
-            )
+            self.DenseReluDense = FlaxT5DenseActDense(self.config, dtype=self.dtype)

         self.layer_norm = FlaxT5LayerNorm(self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype)
         self.dropout = nn.Dropout(self.config.dropout_rate)
......
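
For reference, the gated block above computes wo(act(wi_0(x)) * wi_1(x)), with dropout between the gate and the output projection. Below is a minimal NumPy sketch of that computation, assuming SiLU as the activation and omitting dropout; the shapes and values are arbitrary and only illustrate the data flow.

    # NumPy sketch of the gated feed-forward computed by FlaxT5DenseGatedActDense
    # (dropout omitted; SiLU chosen as an example of config.dense_act_fn).
    import numpy as np

    def silu(x):
        return x / (1.0 + np.exp(-x))  # SiLU(x) = x * sigmoid(x)

    d_model, d_ff = 8, 32
    rng = np.random.default_rng(0)
    wi_0 = rng.normal(size=(d_model, d_ff))  # gate projection
    wi_1 = rng.normal(size=(d_model, d_ff))  # linear projection
    wo = rng.normal(size=(d_ff, d_model))    # output projection

    x = rng.normal(size=(2, d_model))      # (batch, d_model)
    hidden = silu(x @ wi_0) * (x @ wi_1)   # gated activation
    out = hidden @ wo                      # project back to d_model
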
@@ -276,33 +276,33 @@ except Exception:
     pass


-class T5DenseReluDense(nn.Module):
+class T5DenseActDense(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
         self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
         self.dropout = nn.Dropout(config.dropout_rate)
-        self.relu_act = ACT2FN["relu"]
+        self.act = ACT2FN[config.dense_act_fn]

     def forward(self, hidden_states):
         hidden_states = self.wi(hidden_states)
-        hidden_states = self.relu_act(hidden_states)
+        hidden_states = self.act(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.wo(hidden_states)
         return hidden_states


-class T5DenseGatedGeluDense(nn.Module):
+class T5DenseGatedActDense(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
         self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
         self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
         self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
         self.dropout = nn.Dropout(config.dropout_rate)
-        self.gelu_act = ACT2FN["gelu_new"]
+        self.act = ACT2FN[config.dense_act_fn]

     def forward(self, hidden_states):
-        hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
+        hidden_gelu = self.act(self.wi_0(hidden_states))
         hidden_linear = self.wi_1(hidden_states)
         hidden_states = hidden_gelu * hidden_linear
         hidden_states = self.dropout(hidden_states)
@@ -313,14 +313,10 @@ class T5DenseGatedGeluDense(nn.Module):

 class T5LayerFF(nn.Module):
     def __init__(self, config: T5Config):
         super().__init__()
-        if config.feed_forward_proj == "relu":
-            self.DenseReluDense = T5DenseReluDense(config)
-        elif config.feed_forward_proj == "gated-gelu":
-            self.DenseReluDense = T5DenseGatedGeluDense(config)
+        if config.is_gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(config)
         else:
-            raise ValueError(
-                f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
-            )
+            self.DenseReluDense = T5DenseActDense(config)
         self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout_rate)
@@ -769,7 +765,7 @@ class T5PreTrainedModel(PreTrainedModel):
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
             if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                 module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
-        elif isinstance(module, T5DenseReluDense):
+        elif isinstance(module, T5DenseActDense):
             # Mesh TensorFlow FF initialization
             # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
             # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
@@ -779,7 +775,7 @@ class T5PreTrainedModel(PreTrainedModel):
             module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
             if hasattr(module.wo, "bias") and module.wo.bias is not None:
                 module.wo.bias.data.zero_()
-        elif isinstance(module, T5DenseGatedGeluDense):
+        elif isinstance(module, T5DenseGatedActDense):
             module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
......
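
The _init_weights branches above keep the Mesh TensorFlow scheme for the renamed classes: input projections are drawn with std = factor * d_model ** -0.5 and the output projection with std = factor * d_ff ** -0.5. Below is a minimal PyTorch sketch of that scheme for the gated block; note that the wi_1 line falls outside the visible hunk, so the same d_model-based scale for it is an assumption here.

    # PyTorch sketch of the Mesh-TensorFlow-style init applied to T5DenseGatedActDense
    # (illustrative; the wi_1 scale is an assumption, its diff line is truncated above).
    import torch.nn as nn

    d_model, d_ff, factor = 512, 2048, 1.0

    wi_0 = nn.Linear(d_model, d_ff, bias=False)
    wi_1 = nn.Linear(d_model, d_ff, bias=False)
    wo = nn.Linear(d_ff, d_model, bias=False)

    wi_0.weight.data.normal_(mean=0.0, std=factor * d_model ** -0.5)
    wi_1.weight.data.normal_(mean=0.0, std=factor * d_model ** -0.5)  # assumed, see note above
    wo.weight.data.normal_(mean=0.0, std=factor * d_ff ** -0.5)
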
@@ -93,7 +93,7 @@ class TFT5LayerNorm(tf.keras.layers.Layer):
         return self.weight * hidden_states


-class TFT5DenseReluDense(tf.keras.layers.Layer):
+class TFT5DenseActDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         wi_initializer = tf.keras.initializers.RandomNormal(
@@ -109,7 +109,7 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
             config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
         )  # Update init weights as in flax
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
-        self.act = tf.keras.activations.relu
+        self.act = get_tf_activation(config.dense_act_fn)

     def call(self, hidden_states, training=False):
         hidden_states = self.wi(hidden_states)
@@ -119,7 +119,7 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
         return hidden_states


-class TFT5GatedGeluDense(tf.keras.layers.Layer):
+class TFT5DenseGatedActDense(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         wi_initializer = tf.keras.initializers.RandomNormal(
@@ -138,7 +138,7 @@ class TFT5GatedGeluDense(tf.keras.layers.Layer):
             config.d_model, use_bias=False, name="wo", kernel_initializer=wo_initializer
         )  # Update init weights as in flax
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
-        self.act = get_tf_activation("gelu_new")
+        self.act = get_tf_activation(config.dense_act_fn)

     def call(self, hidden_states, training=False):
         hidden_gelu = self.act(self.wi_0(hidden_states))
@@ -152,14 +152,11 @@ class TFT5GatedGeluDense(tf.keras.layers.Layer):

 class TFT5LayerFF(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
-        if config.feed_forward_proj == "relu":
-            self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
-        elif config.feed_forward_proj == "gated-gelu":
-            self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense")
+        if config.is_gated_act:
+            self.DenseReluDense = TFT5DenseGatedActDense(config, name="DenseReluDense")
         else:
-            raise ValueError(
-                f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
-            )
+            self.DenseReluDense = TFT5DenseActDense(config, name="DenseReluDense")

         self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
......
@@ -539,6 +539,12 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         config.feed_forward_proj = "gated-gelu"
         self.model_tester.create_and_check_model(config, *config_and_inputs[1:])

+    def test_config_and_model_silu_gated(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        config = config_and_inputs[0]
+        config.feed_forward_proj = "gated-silu"
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
     def test_with_lm_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
......
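
Putting the pieces together, here is a hedged end-to-end usage sketch of the new option. The tiny config values are arbitrary, and the expected outputs in the comments follow from the diff above: a gated-silu feed_forward_proj should set dense_act_fn to "silu", mark is_gated_act, and make T5LayerFF pick the T5DenseGatedActDense block.

    # Usage sketch: a tiny T5 with the gated-SiLU feed-forward added in this commit.
    from transformers import T5Config, T5Model

    config = T5Config(
        vocab_size=128,
        d_model=32,
        d_kv=8,
        d_ff=64,
        num_layers=2,
        num_heads=4,
        feed_forward_proj="gated-silu",
    )
    print(config.dense_act_fn, config.is_gated_act)  # expected: silu True

    model = T5Model(config)
    ff_block = model.encoder.block[0].layer[-1].DenseReluDense
    print(type(ff_block).__name__)  # expected: T5DenseGatedActDense
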