Unverified Commit ee2a3400 authored by Eran Hirsch's avatar Eran Hirsch Committed by GitHub
Browse files

Fix LongT5ForConditionalGeneration initialization of lm_head (#28873)

parent 1ea0bbd7
...@@ -1301,6 +1301,8 @@ class LongT5PreTrainedModel(PreTrainedModel): ...@@ -1301,6 +1301,8 @@ class LongT5PreTrainedModel(PreTrainedModel):
# Mesh TensorFlow embeddings initialization # Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, LongT5DenseActDense): elif isinstance(module, LongT5DenseActDense):
# Mesh TensorFlow FF initialization # Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment