"...ui/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "0d9009c96ea5a1922a1c187deef27c7f133ee946"
Unverified Commit fa876aee authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix TF Flaubert and XLM (#9661)

* Fix Flaubert and XLM

* Fix Flaubert and XLM

* Apply style
parent 11ec7490
......@@ -214,10 +214,13 @@ class TFFlaubertPreTrainedModel(TFPreTrainedModel):
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
if self.config.use_lang_emb and self.config.n_langs > 1:
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
return {
"input_ids": inputs_list,
"attention_mask": attns_list,
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
}
else:
langs_list = None
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
return {"input_ids": inputs_list, "attention_mask": attns_list}
@add_start_docstrings(
......
......@@ -536,10 +536,13 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
inputs_list = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
attns_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
if self.config.use_lang_emb and self.config.n_langs > 1:
langs_list = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
return {
"input_ids": inputs_list,
"attention_mask": attns_list,
"langs": tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]]),
}
else:
langs_list = None
return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
return {"input_ids": inputs_list, "attention_mask": attns_list}
# Remove when XLMWithLMHead computes loss like other LM models
......@@ -1045,10 +1048,16 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
Returns:
tf.Tensor with dummy inputs
"""
return {
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
}
# Sometimes XLM has language embeddings so don't forget to build them as well if needed
if self.config.use_lang_emb and self.config.n_langs > 1:
return {
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
"langs": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
}
else:
return {
"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS),
}
@add_start_docstrings_to_model_forward(XLM_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment