Unverified Commit 0a144b8c authored by Arthur, committed by GitHub

[DOCTEST] Fix the documentation of RoCBert (#20142)

* update part of the doc

* add temp values, fix part of the doc

* add template outputs

* add correct models and outputs

* style

* fixup
parent 441811ec
@@ -52,6 +52,29 @@ _CHECKPOINT_FOR_DOC = "weiweishi/roc-bert-base-zh"
_CONFIG_FOR_DOC = "RoCBertConfig"
_TOKENIZER_FOR_DOC = "RoCBertTokenizer"
# Base model docstring
_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
# Token Classification output
_CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "ArthurZ/dummy-rocbert-ner"
# fmt: off
_TOKEN_CLASS_EXPECTED_OUTPUT = ["S-EVENT", "S-FAC", "I-ORDINAL", "I-ORDINAL", "E-ORG", "E-LANGUAGE", "E-ORG", "E-ORG", "E-ORG", "E-ORG", "I-EVENT", "S-TIME", "S-TIME", "E-LANGUAGE", "S-TIME", "E-DATE", "I-ORDINAL", "E-QUANTITY", "E-LANGUAGE", "S-TIME", "B-ORDINAL", "S-PRODUCT", "E-LANGUAGE", "E-LANGUAGE", "E-ORG", "E-LOC", "S-TIME", "I-ORDINAL", "S-FAC", "O", "S-GPE", "I-EVENT", "S-GPE", "E-LANGUAGE", "E-ORG", "S-EVENT", "S-FAC", "S-FAC", "S-FAC", "E-ORG", "S-FAC", "E-ORG", "S-GPE"]
# fmt: on
_TOKEN_CLASS_EXPECTED_LOSS = 3.62
# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/dummy-rocbert-seq"
_SEQ_CLASS_EXPECTED_OUTPUT = "'financial news'"
_SEQ_CLASS_EXPECTED_LOSS = 2.31
# QuestionAnswering docstring
_CHECKPOINT_FOR_QA = "ArthurZ/dummy-rocbert-qa"
_QA_EXPECTED_OUTPUT = "''"
_QA_EXPECTED_LOSS = 3.75
_QA_TARGET_START_INDEX = 14
_QA_TARGET_END_INDEX = 15
# Masked language modeling
ROC_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"weiweishi/roc-bert-base-zh",
# See all RoCBert models at https://huggingface.co/models?filter=roc_bert
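For context, `_EXPECTED_OUTPUT_SHAPE` is the `[batch_size, sequence_length, hidden_size]` value that the generated base-model doctest checks against. A minimal sketch of that check, assuming the public `weiweishi/roc-bert-base-zh` checkpoint and a placeholder input sentence (the middle dimension depends on how many tokens the text encodes to):

```python
>>> from transformers import RoCBertTokenizer, RoCBertModel
>>> import torch

>>> tokenizer = RoCBertTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
>>> model = RoCBertModel.from_pretrained("weiweishi/roc-bert-base-zh")

>>> inputs = tokenizer("你好,很高兴认识你", return_tensors="pt")  # placeholder sentence, not the template's
>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> # [batch_size, sequence_length, hidden_size]; an 8-token input would give torch.Size([1, 8, 768])
>>> outputs.last_hidden_state.shape  # doctest: +SKIP
```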
@@ -917,6 +940,7 @@ class RoCBertModel(RoCBertPreTrainedModel):
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
@@ -1146,20 +1170,20 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
>>> model = RoCBertForPreTraining.from_pretrained("weiweishi/roc-bert-base-zh")
>>> inputs = tokenizer("你好,很高兴认识你", return_tensors="pt")
>>> attack_inputs = tokenizer("你号,很高兴认识你", return_tensors="pt")
>>> attack_keys = list(attack_inputs.keys())
>>> for key in attack_keys:
... attack_inputs[f"attack_{key}"] = attack_inputs.pop(key)
>>> label_inputs = tokenizer("你好,很高兴认识你", return_tensors="pt")
>>> label_keys = list(attack_inputs.keys())
>>> for key in label_keys:
... label_inputs[f"labels_{key}"] = label_inputs.pop(key)
>>> attack_inputs = {}
>>> for key in list(inputs.keys()):
... attack_inputs[f"attack_{key}"] = inputs[key]
>>> label_inputs = {}
>>> for key in list(inputs.keys()):
... label_inputs[f"labels_{key}"] = inputs[key]
>>> inputs.update(label_inputs)
>>> inputs.update(attack_inputs)
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> logits.shape
torch.Size([1, 11, 21128])
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
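The renaming loops in the updated example simply copy every tokenizer field under an `attack_`/`labels_` prefix, so a single `model(**inputs)` call also receives `attack_input_ids`, `labels_input_ids`, and the other prefixed arguments of `RoCBertForPreTraining.forward`. An equivalent sketch with dict comprehensions, assuming `inputs` is the freshly tokenized `BatchEncoding` from the example:

```python
>>> # build the prefixed copies before mutating `inputs`, exactly as the loops above do
>>> attack_inputs = {f"attack_{k}": v for k, v in inputs.items()}
>>> label_inputs = {f"labels_{k}": v for k, v in inputs.items()}
>>> inputs.update(label_inputs)
>>> inputs.update(attack_inputs)
```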
@@ -1271,12 +1295,6 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
self.cls.predictions.decoder = new_embeddings
@add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
@@ -1299,6 +1317,27 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Example:
```python
>>> from transformers import RoCBertTokenizer, RoCBertForMaskedLM
>>> import torch
>>> tokenizer = RoCBertTokenizer.from_pretrained("weiweishi/roc-bert-base-zh")
>>> model = RoCBertForMaskedLM.from_pretrained("weiweishi/roc-bert-base-zh")
>>> inputs = tokenizer("法国是首都[MASK].", return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # retrieve index of [MASK]
>>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
>>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
>>> tokenizer.decode(predicted_token_id)
'.'
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
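As a hedged follow-up to the new masked-LM example (reusing `logits`, `mask_token_index`, and `tokenizer` from the snippet above), the top-5 candidates for the masked position can be inspected instead of only the argmax:

```python
>>> top5_ids = logits[0, mask_token_index].topk(5, dim=-1).indices  # shape (num_masked, 5)
>>> tokenizer.decode(top5_ids[0])  # doctest: +SKIP
```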
@@ -1461,7 +1500,8 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roc_bert(
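For the causal-LM example just above, a small sketch of reading off the most likely next token from `prediction_logits`, assuming the `tokenizer` loaded earlier in that example (the decoded string is not asserted because it depends on the checkpoint):

```python
>>> next_token_id = prediction_logits[0, -1].argmax(-1).item()
>>> tokenizer.decode([next_token_id])  # doctest: +SKIP
```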
@@ -1570,9 +1610,11 @@ class RoCBertForSequenceClassification(RoCBertPreTrainedModel):
@add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
output_type=SequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
)
def forward(
self,
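With `expected_output` and `expected_loss` wired in, `add_code_sample_docstrings` renders a sequence-classification doctest for the `ArthurZ/dummy-rocbert-seq` checkpoint. Roughly, and only as a sketch (the input sentence here is the library's generic sample text, which may not match the generated docstring word for word), it checks something like:

```python
>>> from transformers import RoCBertTokenizer, RoCBertForSequenceClassification
>>> import torch

>>> tokenizer = RoCBertTokenizer.from_pretrained("ArthurZ/dummy-rocbert-seq")
>>> model = RoCBertForSequenceClassification.from_pretrained("ArthurZ/dummy-rocbert-seq")

>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> predicted_class_id = logits.argmax().item()
>>> model.config.id2label[predicted_class_id]
'financial news'
```

The template's companion loss check then compares against `_SEQ_CLASS_EXPECTED_LOSS` (2.31).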
@@ -1782,9 +1824,11 @@ class RoCBertForTokenClassification(RoCBertPreTrainedModel):
@add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_TOKEN_CLASSIFICATION,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_TOKEN_CLASS_EXPECTED_OUTPUT,
expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
)
def forward(
self,
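Similarly, the token-classification decorator now points at the `ArthurZ/dummy-rocbert-ner` checkpoint, with the long `_TOKEN_CLASS_EXPECTED_OUTPUT` label list as the expected output and a loss near 3.62. A sketch of the shape of that doctest, assuming the library's generic NER sample sentence:

```python
>>> from transformers import RoCBertTokenizer, RoCBertForTokenClassification
>>> import torch

>>> tokenizer = RoCBertTokenizer.from_pretrained("ArthurZ/dummy-rocbert-ner")
>>> model = RoCBertForTokenClassification.from_pretrained("ArthurZ/dummy-rocbert-ner")

>>> inputs = tokenizer(
...     "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
... )
>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> predicted_token_class_ids = logits.argmax(-1)
>>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
>>> # expected to match _TOKEN_CLASS_EXPECTED_OUTPUT for this dummy checkpoint
>>> predicted_tokens_classes  # doctest: +SKIP
```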
@@ -1865,9 +1909,13 @@ class RoCBertForQuestionAnswering(RoCBertPreTrainedModel):
@add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_QA,
output_type=QuestionAnsweringModelOutput,
config_class=_CONFIG_FOR_DOC,
qa_target_start_index=_QA_TARGET_START_INDEX,
qa_target_end_index=_QA_TARGET_END_INDEX,
expected_output=_QA_EXPECTED_OUTPUT,
expected_loss=_QA_EXPECTED_LOSS,
)
def forward(
self,
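Finally, the question-answering decorator gets its own dummy checkpoint plus a hard-coded target span and expected loss. A sketch of the doctest these values correspond to, assuming the library's generic QA sample pair; the expected values below are simply the `_QA_*` constants added at the top of the file:

```python
>>> from transformers import RoCBertTokenizer, RoCBertForQuestionAnswering
>>> import torch

>>> tokenizer = RoCBertTokenizer.from_pretrained("ArthurZ/dummy-rocbert-qa")
>>> model = RoCBertForQuestionAnswering.from_pretrained("ArthurZ/dummy-rocbert-qa")

>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> answer_start_index = outputs.start_logits.argmax()
>>> answer_end_index = outputs.end_logits.argmax()
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
>>> # _QA_EXPECTED_OUTPUT: the dummy model predicts an empty span
>>> tokenizer.decode(predict_answer_tokens)
''

>>> # loss against the hard-coded span (_QA_TARGET_START_INDEX=14, _QA_TARGET_END_INDEX=15)
>>> target_start_index, target_end_index = torch.tensor([14]), torch.tensor([15])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> round(outputs.loss.item(), 2)
3.75
```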