Commit cfb7d108 authored by Lysandre's avatar Lysandre Committed by Lysandre Debut
Browse files

FlauBERT lang embeddings only when n_langs > 1

parent b4691a43
...@@ -231,7 +231,7 @@ class FlaubertModel(XLMModel): ...@@ -231,7 +231,7 @@ class FlaubertModel(XLMModel):
inputs_embeds = self.embeddings(input_ids) inputs_embeds = self.embeddings(input_ids)
tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds) tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
if langs is not None and self.use_lang_emb: if langs is not None and self.use_lang_emb and self.config.n_langs > 1:
tensor = tensor + self.lang_embeddings(langs) tensor = tensor + self.lang_embeddings(langs)
if token_type_ids is not None: if token_type_ids is not None:
tensor = tensor + self.embeddings(token_type_ids) tensor = tensor + self.embeddings(token_type_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment