Unverified Commit 1e3c3622 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix template (#9512)

parent d415882b
...@@ -462,7 +462,7 @@ class TF{{cookiecutter.camelcase_modelname}}LMPredictionHead(tf.keras.layers.Lay ...@@ -462,7 +462,7 @@ class TF{{cookiecutter.camelcase_modelname}}LMPredictionHead(tf.keras.layers.Lay
super().build(input_shape) super().build(input_shape)
def get_output_embeddings(self): def get_output_embeddings(self):
return self.input_embeddings.word_embeddings return self.input_embeddings
def set_output_embeddings(self, value): def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value self.input_embeddings.word_embeddings = value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment