Commit 4620caa8 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix if use lang embeddings in tf xlm

parent fbd02d46
......@@ -408,7 +408,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
inputs_embeds = self.embeddings(input_ids)
tensor = inputs_embeds + self.position_embeddings(position_ids)
if langs is not None and self.use_lang_emb:
if langs is not None and self.use_lang_emb and self.n_langs > 1:
tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids)
......
......@@ -342,4 +342,4 @@ class TFXLMModelLanguageGenerationTest(unittest.TestCase):
] # the president the president the president the president the president the president the president the president the president the president
# TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids, do_sample=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment