Unverified Commit 2f2b19ff authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Change doc example for `BigBirdForQuestionAnswering` (#21723)



Change doc example for BigBirdForQuestionAnswering
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 354b3383
...@@ -3042,8 +3042,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): ...@@ -3042,8 +3042,8 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
>>> from transformers import AutoTokenizer, BigBirdForQuestionAnswering >>> from transformers import AutoTokenizer, BigBirdForQuestionAnswering
>>> from datasets import load_dataset >>> from datasets import load_dataset
>>> tokenizer = AutoTokenizer.from_pretrained("abhinavkulkarni/bigbird-roberta-base-finetuned-squad") >>> tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
>>> model = BigBirdForQuestionAnswering.from_pretrained("abhinavkulkarni/bigbird-roberta-base-finetuned-squad") >>> model = BigBirdForQuestionAnswering.from_pretrained("google/bigbird-roberta-base")
>>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT >>> squad_ds = load_dataset("squad_v2", split="train") # doctest: +IGNORE_RESULT
>>> # select random article and question >>> # select random article and question
...@@ -3062,17 +3062,14 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): ...@@ -3062,17 +3062,14 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel):
>>> answer_start_index = outputs.start_logits.argmax() >>> answer_start_index = outputs.start_logits.argmax()
>>> answer_end_index = outputs.end_logits.argmax() >>> answer_end_index = outputs.end_logits.argmax()
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1] >>> predict_answer_token_ids = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
>>> tokenizer.decode(predict_answer_tokens) >>> predict_answer_token = tokenizer.decode(predict_answer_token_ids)
'80 °C (176 °F) or more'
``` ```
```python ```python
>>> target_start_index, target_end_index = torch.tensor([130]), torch.tensor([132]) >>> target_start_index, target_end_index = torch.tensor([130]), torch.tensor([132])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> loss = outputs.loss >>> loss = outputs.loss
>>> round(outputs.loss.item(), 2)
7.63
``` ```
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment