Unverified commit fa9e18c6 authored by Yih-Dar, committed by GitHub

Fix `OPTForQuestionAnswering` doctest (#19479)



* Fix doc example for OPTForQuestionAnswering
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 957ce646
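
The repaired example can be lifted out of the docstring and run as a plain script. The sketch below is a convenience reproduction only, assuming a local `transformers` install with PyTorch; as the example itself notes, `facebook/opt-350m` ships no trained question-answering head, so the decoded span is effectively random and is reproducible only because the seed is pinned.

```python
# Standalone sketch of the fixed doctest (not library code). The QA head of
# facebook/opt-350m is freshly initialized, so torch.manual_seed(4) is what makes
# the decoded answer reproducible rather than meaningful.
import torch
from transformers import GPT2Tokenizer, OPTForQuestionAnswering

torch.manual_seed(4)
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Greedily pick the highest-scoring start/end positions and decode the span between them.
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
print(tokenizer.decode(predict_answer_tokens))  # ' Henson?' with this seed, per the doctest
```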
@@ -53,12 +53,6 @@ _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
 _SEQ_CLASS_EXPECTED_LOSS = 1.71
 _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"

-# QuestionAnswering docstring
-_QA_EXPECTED_OUTPUT = "'a nice puppet'"
-_QA_EXPECTED_LOSS = 7.41
-_QA_TARGET_START_INDEX = 14
-_QA_TARGET_END_INDEX = 15
-
 OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/opt-125m",
     "facebook/opt-350m",
@@ -1140,16 +1134,7 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
         self.post_init()

     @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=QuestionAnsweringModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-        qa_target_start_index=_QA_TARGET_START_INDEX,
-        qa_target_end_index=_QA_TARGET_END_INDEX,
-        expected_output=_QA_EXPECTED_OUTPUT,
-        expected_loss=_QA_EXPECTED_LOSS,
-    )
+    @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1173,7 +1158,36 @@ class OPTForQuestionAnswering(OPTPreTrainedModel):
             Labels for position (index) of the end of the labelled span for computing the token classification loss.
             Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
             are not taken into account for computing the loss.
-        """
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import GPT2Tokenizer, OPTForQuestionAnswering
+        >>> import torch
+
+        >>> torch.manual_seed(4)  # doctest: +IGNORE_RESULT
+        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+
+        >>> # note: we are loading a OPTForQuestionAnswering from the hub here,
+        >>> # so the head will be randomly initialized, hence the predictions will be random
+        >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")
+
+        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+
+        >>> inputs = tokenizer(question, text, return_tensors="pt")
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+
+        >>> answer_start_index = outputs.start_logits.argmax()
+        >>> answer_end_index = outputs.end_logits.argmax()
+
+        >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
+        >>> predicted = tokenizer.decode(predict_answer_tokens)
+        >>> predicted
+        ' Henson?'
+        ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         transformer_outputs = self.model(
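
For context on the decorator swap: `add_code_sample_docstrings` used the removed `_QA_*` constants to append an auto-generated question-answering sample to the docstring, whereas `replace_return_docstrings` only expands the bare `Returns:` marker and leaves the hand-written example above untouched. The sketch below illustrates that behaviour under the assumption of a recent `transformers` install; the decorator and output class are the real utilities used in the diff, but `forward_stub` is a hypothetical stand-in, not library code.

```python
# Illustrative sketch only: forward_stub is hypothetical; the decorator and output
# class are the transformers utilities referenced in the diff above.
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from transformers.utils import replace_return_docstrings


@replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class="OPTConfig")
def forward_stub():
    """
    Toy stand-in for OPTForQuestionAnswering.forward.

    Returns:

    Example: the hand-written snippet stays here verbatim.
    """


# The decorator has replaced the bare "Returns:" line with a generated description of
# QuestionAnsweringModelOutput; the rest of the docstring is left exactly as written.
print(forward_stub.__doc__)
```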