Unverified Commit 60ad7344 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[T5] Fix init in TF and Flax for pretraining (#17294)



* fix init

* Apply suggestions from code review

* fix

* finish

* Update src/transformers/modeling_tf_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 7ba1d4e5
......@@ -768,6 +768,8 @@ class T5PreTrainedModel(PreTrainedModel):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, T5DenseReluDense):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
......
......@@ -1112,7 +1112,9 @@ num_heads))`.
class TFT5Model(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
......@@ -1259,8 +1261,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model_dim = config.d_model
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
......@@ -1600,7 +1603,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
class TFT5EncoderModel(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment