Unverified commit a2b7d19b, authored by Suraj Patil, committed by GitHub

Fix seq2seq doc tests (#16606)

* fix bart and mbart

* add ckpt names as variables

* fix mbart

* fix plbart

* use variable for ckpt name
parent 0bf18643
src/transformers/models/bart/modeling_bart.py

```diff
@@ -56,11 +56,14 @@ _TOKENIZER_FOR_DOC = "BartTokenizer"
 _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
 
 # SequenceClassification docstring
-_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2]
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2"
+_SEQ_CLASS_EXPECTED_LOSS = 0.0
+_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"
 
 # QuestionAsnwering docstring
-_QA_EXPECTED_LOSS = 2.98
-_QA_EXPECTED_OUTPUT_SHAPE = [1, 17]
+_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1"
+_QA_EXPECTED_LOSS = 0.59
+_QA_EXPECTED_OUTPUT = "' nice puppet'"
 
 BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1447,10 +1450,11 @@ class BartForSequenceClassification(BartPretrainedModel):
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1572,11 +1576,11 @@ class BartForQuestionAnswering(BartPretrainedModel):
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_QA,
         output_type=Seq2SeqQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
         expected_loss=_QA_EXPECTED_LOSS,
-        expected_output=_QA_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_QA_EXPECTED_OUTPUT,
     )
     def forward(
         self,
```
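These `_CHECKPOINT_*` / `_EXPECTED_*` variables are consumed by the `add_code_sample_docstrings` decorator, which renders a runnable usage sample into each `forward` docstring and doc-tests it against `expected_output` and `expected_loss`. A minimal sketch of what the sequence-classification sample now checks with the pinned checkpoint (the exact wording is owned by the decorator's template; the input sentence here is illustrative):

```python
import torch
from transformers import BartForSequenceClassification, BartTokenizer

checkpoint = "valhalla/bart-large-sst2"  # _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION above
tokenizer = BartTokenizer.from_pretrained(checkpoint)
model = BartForSequenceClassification.from_pretrained(checkpoint)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")  # illustrative input
with torch.no_grad():
    logits = model(**inputs).logits

# the doc test compares this label string against _SEQ_CLASS_EXPECTED_OUTPUT ("'POSITIVE'")
predicted_class_id = logits.argmax().item()
print(model.config.id2label[predicted_class_id])
```

Pinning a checkpoint fine-tuned on SST-2 is what makes the expected label ('POSITIVE') and a near-zero expected loss deterministic enough to doc-test, which the generic `_CHECKPOINT_FOR_DOC` base model could not guarantee.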
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

```diff
@@ -51,17 +51,20 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "google/bigbird-pegasus-large-arxiv"
 _CONFIG_FOR_DOC = "BigBirdPegasusConfig"
-_TOKENIZER_FOR_DOC = "PegasusTokenizer"
+_TOKENIZER_FOR_DOC = "PegasusTokenizerFast"
 
 # Base model docstring
 _EXPECTED_OUTPUT_SHAPE = [1, 7, 1024]
 
 # SequenceClassification docstring
-_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2]
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "hf-internal-testing/tiny-random-bigbird_pegasus"
+_SEQ_CLASS_EXPECTED_LOSS = 0.69
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
 
 # QuestionAsnwering docstring
-_QA_EXPECTED_LOSS = 2.56
-_QA_EXPECTED_OUTPUT_SHAPE = [1, 12]
+_CHECKPOINT_FOR_QA = "hf-internal-testing/tiny-random-bigbird_pegasus"
+_QA_EXPECTED_LOSS = 3.96
+_QA_EXPECTED_OUTPUT = "''"
 
 BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -2645,10 +2648,11 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
     @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -2771,11 +2775,11 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
     @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_QA,
         output_type=Seq2SeqQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
         expected_loss=_QA_EXPECTED_LOSS,
-        expected_output=_QA_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_QA_EXPECTED_OUTPUT,
    )
     def forward(
         self,
```
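The question-answering sample follows the same shape. A sketch against the tiny random checkpoint, assuming the standard QA doctest inputs (the "Jim Henson" question/context pair implied by the expected outputs in this PR):

```python
import torch
from transformers import AutoTokenizer, BigBirdPegasusForQuestionAnswering

checkpoint = "hf-internal-testing/tiny-random-bigbird_pegasus"  # _CHECKPOINT_FOR_QA above
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BigBirdPegasusForQuestionAnswering.from_pretrained(checkpoint)

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# the sample decodes the span between the argmax start/end logits
answer_start = outputs.start_logits.argmax()
answer_end = outputs.end_logits.argmax()
answer = tokenizer.decode(inputs.input_ids[0, answer_start : answer_end + 1])
print(repr(answer))
```

With random weights the argmax end index can land before the start index, so the decoded slice is empty; that is exactly what `_QA_EXPECTED_OUTPUT = "''"` pins down for this checkpoint.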
src/transformers/models/mbart/modeling_mbart.py

```diff
@@ -55,11 +55,14 @@ _TOKENIZER_FOR_DOC = "MBartTokenizer"
 _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
 
 # SequenceClassification docstring
-_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2]
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "hf-internal-testing/tiny-random-mbart"
+_SEQ_CLASS_EXPECTED_LOSS = 0.69
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
 
 # QuestionAsnwering docstring
-_QA_EXPECTED_LOSS = 3.04
-_QA_EXPECTED_OUTPUT_SHAPE = [1, 16]
+_CHECKPOINT_FOR_QA = "hf-internal-testing/tiny-random-mbart"
+_QA_EXPECTED_LOSS = 3.55
+_QA_EXPECTED_OUTPUT = "'? Jim Henson was a'"
 
 MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1437,10 +1440,11 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
@@ -1563,11 +1567,11 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_QA,
         output_type=Seq2SeqQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
         expected_loss=_QA_EXPECTED_LOSS,
-        expected_output=_QA_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_QA_EXPECTED_OUTPUT,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
     def forward(
```
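MBart reuses the same pattern with its own tiny random checkpoint. The loss half of the QA sample supplies gold span positions and doc-tests the rounded loss against `_QA_EXPECTED_LOSS`; a sketch, with illustrative target indices (not necessarily the template's exact values):

```python
import torch
from transformers import AutoTokenizer, MBartForQuestionAnswering

checkpoint = "hf-internal-testing/tiny-random-mbart"  # _CHECKPOINT_FOR_QA above
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = MBartForQuestionAnswering.from_pretrained(checkpoint)

inputs = tokenizer("Who was Jim Henson?", "Jim Henson was a nice puppet", return_tensors="pt")

# gold start/end token positions make the forward pass return a loss;
# the doc test rounds it and compares against _QA_EXPECTED_LOSS (3.55 here)
start_positions, end_positions = torch.tensor([14]), torch.tensor([15])  # illustrative indices
outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
print(round(outputs.loss.item(), 2))
```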
src/transformers/models/plbart/modeling_plbart.py

```diff
@@ -54,7 +54,9 @@ _TOKENIZER_FOR_DOC = "PLBartTokenizer"
 _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
 
 # SequenceClassification docstring
-_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2]
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "hf-internal-testing/tiny-plbart"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.69
 
 PLBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1408,10 +1410,11 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
```
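PLBart only wires up the sequence-classification sample. Note that `_SEQ_CLASS_EXPECTED_LOSS = 0.69` for the tiny random checkpoints is no coincidence: an untrained two-label head produces near-uniform probabilities, and the cross-entropy of a uniform 2-way distribution is ln 2 ≈ 0.693, so the rounded value is stable across reruns. A sketch of that loss check (label value illustrative):

```python
import torch
from transformers import AutoTokenizer, PLBartForSequenceClassification

checkpoint = "hf-internal-testing/tiny-plbart"  # _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION above
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = PLBartForSequenceClassification.from_pretrained(checkpoint)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
labels = torch.tensor([1])  # illustrative single-label target

# cross-entropy of a ~uniform 2-way classifier is ln(2) ~= 0.69,
# matching _SEQ_CLASS_EXPECTED_LOSS for this tiny random checkpoint
loss = model(**inputs, labels=labels).loss
print(round(loss.item(), 2))
```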