"...composable_kernel_rocm.git" did not exist on "a3c80265185ae1a8489d605675a71af206c5b5eb"
Unverified commit a2b7d19b, authored by Suraj Patil and committed by GitHub

Fix seq2seq doc tests (#16606)

* fix bart and mbart

* add ckpt names as variables

* fix mbart

* fix plbart

* use variable for ckpt name
parent 0bf18643
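
Editor's note: the module-level variables edited below feed transformers' `add_code_sample_docstrings` decorator, which renders a runnable doctest into each `forward` docstring; the doc-test CI then executes that example and compares against `expected_output` and `expected_loss`. A minimal sketch of the sequence-classification example the BART docstring expands to — only the checkpoint name and the 'POSITIVE' label come from this diff; the input sentence and surrounding scaffolding are assumptions:

import torch
from transformers import BartForSequenceClassification, BartTokenizer

# Checkpoint pinned by _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION in this commit
ckpt = "valhalla/bart-large-sst2"
tokenizer = BartTokenizer.from_pretrained(ckpt)
model = BartForSequenceClassification.from_pretrained(ckpt)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")  # illustrative input
with torch.no_grad():
    logits = model(**inputs).logits

# The doctest pins this to 'POSITIVE' (_SEQ_CLASS_EXPECTED_OUTPUT)
predicted_class_id = int(logits.argmax(-1))
print(model.config.id2label[predicted_class_id])
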
src/transformers/models/bart/modeling_bart.py
@@ -56,11 +56,14 @@ _TOKENIZER_FOR_DOC = "BartTokenizer"
 _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
 
 # SequenceClassification docstring
-_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2]
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2"
+_SEQ_CLASS_EXPECTED_LOSS = 0.0
+_SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"
 
 # QuestionAsnwering docstring
-_QA_EXPECTED_LOSS = 2.98
-_QA_EXPECTED_OUTPUT_SHAPE = [1, 17]
+_CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1"
+_QA_EXPECTED_LOSS = 0.59
+_QA_EXPECTED_OUTPUT = "' nice puppet'"
 
 BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1447,10 +1450,11 @@ class BartForSequenceClassification(BartPretrainedModel):
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1572,11 +1576,11 @@ class BartForQuestionAnswering(BartPretrainedModel):
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_QA,
         output_type=Seq2SeqQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
         expected_loss=_QA_EXPECTED_LOSS,
-        expected_output=_QA_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_QA_EXPECTED_OUTPUT,
     )
     def forward(
         self,
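Editor's note: the new BART QA values line up with the library's stock extractive-QA doc example. A sketch of what the rendered docstring runs — the question/context pair is the standard doc-test string pair assumed here (it is what makes the pinned answer ' nice puppet' plausible), not something read from the diff:

import torch
from transformers import BartForQuestionAnswering, BartTokenizer

ckpt = "valhalla/bart-large-finetuned-squadv1"  # _CHECKPOINT_FOR_QA above
tokenizer = BartTokenizer.from_pretrained(ckpt)
model = BartForQuestionAnswering.from_pretrained(ckpt)

question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, context, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Decode the highest-scoring start..end span; the doctest pins ' nice puppet'
start = int(outputs.start_logits.argmax())
end = int(outputs.end_logits.argmax())
print(tokenizer.decode(inputs.input_ids[0, start : end + 1]))
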
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -51,17 +51,20 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "google/bigbird-pegasus-large-arxiv"
 _CONFIG_FOR_DOC = "BigBirdPegasusConfig"
-_TOKENIZER_FOR_DOC = "PegasusTokenizer"
+_TOKENIZER_FOR_DOC = "PegasusTokenizerFast"
 
 # Base model docstring
 _EXPECTED_OUTPUT_SHAPE = [1, 7, 1024]
 
 # SequenceClassification docstring
-_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2]
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "hf-internal-testing/tiny-random-bigbird_pegasus"
+_SEQ_CLASS_EXPECTED_LOSS = 0.69
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
 
 # QuestionAsnwering docstring
-_QA_EXPECTED_LOSS = 2.56
-_QA_EXPECTED_OUTPUT_SHAPE = [1, 12]
+_CHECKPOINT_FOR_QA = "hf-internal-testing/tiny-random-bigbird_pegasus"
+_QA_EXPECTED_LOSS = 3.96
+_QA_EXPECTED_OUTPUT = "''"
 
 BIGBIRD_PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -2645,10 +2648,11 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
     @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -2771,11 +2775,11 @@ class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
     @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_QA,
         output_type=Seq2SeqQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
         expected_loss=_QA_EXPECTED_LOSS,
-        expected_output=_QA_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_QA_EXPECTED_OUTPUT,
     )
     def forward(
         self,
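Editor's note on the values chosen for the hf-internal-testing/tiny-random-* checkpoints used here and below: those models are randomly initialized, so the sequence-classification doctests pin a chance-level loss. For two labels, chance-level cross-entropy is ln 2 ≈ 0.69, which is exactly the _SEQ_CLASS_EXPECTED_LOSS used for BigBirdPegasus, MBart, and PLBart. A quick sanity check, illustrative only:

import math

import torch
import torch.nn.functional as F

# Uniform logits over two labels give cross-entropy ln(2) ~= 0.69,
# the chance-level _SEQ_CLASS_EXPECTED_LOSS pinned in these diffs.
logits = torch.zeros(1, 2)
loss = F.cross_entropy(logits, torch.tensor([1]))
print(round(loss.item(), 2), round(math.log(2), 2))  # 0.69 0.69
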
src/transformers/models/mbart/modeling_mbart.py
@@ -55,11 +55,14 @@ _TOKENIZER_FOR_DOC = "MBartTokenizer"
 _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
 
 # SequenceClassification docstring
-_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2]
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "hf-internal-testing/tiny-random-mbart"
+_SEQ_CLASS_EXPECTED_LOSS = 0.69
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
 
 # QuestionAsnwering docstring
-_QA_EXPECTED_LOSS = 3.04
-_QA_EXPECTED_OUTPUT_SHAPE = [1, 16]
+_CHECKPOINT_FOR_QA = "hf-internal-testing/tiny-random-mbart"
+_QA_EXPECTED_LOSS = 3.55
+_QA_EXPECTED_OUTPUT = "'? Jim Henson was a'"
 
 MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1437,10 +1440,11 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
@@ -1563,11 +1567,11 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_QA,
         output_type=Seq2SeqQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
         expected_loss=_QA_EXPECTED_LOSS,
-        expected_output=_QA_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_QA_EXPECTED_OUTPUT,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
     def forward(
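Editor's note on the "add ckpt names as variables" item in the commit message: `add_code_sample_docstrings` builds the docstring at import time from its keyword arguments, so pointing `checkpoint=` at a task-specific module variable changes only which checkpoint the rendered sample loads. A stripped-down, hypothetical reduction of that mechanism (`add_example` stands in for the real decorator, which formats a full runnable code sample):

# Hypothetical reduction of add_code_sample_docstrings.
def add_example(checkpoint, expected_output, expected_loss):
    def wrap(fn):
        fn.__doc__ = (
            f"Example (checkpoint={checkpoint}):\n"
            f"    expected output: {expected_output}\n"
            f"    expected loss:   {expected_loss}\n"
        )
        return fn
    return wrap

_CHECKPOINT_FOR_QA = "hf-internal-testing/tiny-random-mbart"

@add_example(_CHECKPOINT_FOR_QA, "'? Jim Henson was a'", 3.55)
def forward():
    ...

print(forward.__doc__)
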
src/transformers/models/plbart/modeling_plbart.py
@@ -54,7 +54,9 @@ _TOKENIZER_FOR_DOC = "PLBartTokenizer"
 _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
 
 # SequenceClassification docstring
-_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE = [1, 2]
+_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "hf-internal-testing/tiny-plbart"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.69
 
 PLBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1408,10 +1410,11 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT_SHAPE,
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
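Editor's note, last: the pinned values matter because the docstring examples are executable. A generic way to exercise one module's embedded examples with the standard library — the project's CI uses its own doc-test runner, so treat this invocation as an illustration only, and note it downloads the referenced checkpoints:

import doctest

import transformers.models.plbart.modeling_plbart as modeling_plbart

# Run every doctest embedded in the module's docstrings; ELLIPSIS lets
# examples elide long output the way the transformers samples do.
results = doctest.testmod(modeling_plbart, optionflags=doctest.ELLIPSIS)
print(results)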