Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
515ed3ad
Unverified
Commit
515ed3ad
authored
Jan 20, 2022
by
NielsRogge
Committed by
GitHub
Jan 20, 2022
Browse files
Fix doc examples (#15257)
parent
ad739063
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
10 deletions
+39
-10
docs/source/model_doc/trocr.mdx
docs/source/model_doc/trocr.mdx
+2
-1
src/transformers/models/vilt/modeling_vilt.py
src/transformers/models/vilt/modeling_vilt.py
+37
-9
No files found.
docs/source/model_doc/trocr.mdx
View file @
515ed3ad
...
...
@@ -70,7 +70,8 @@ into a single instance to both extract the input features and decode the predict
>>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
>>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
>>> # load image from the IAM dataset url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
>>> # load image from the IAM dataset
>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
>>> pixel_values = processor(image, return_tensors="pt").pixel_values
...
...
src/transformers/models/vilt/modeling_vilt.py
View file @
515ed3ad
...
...
@@ -42,10 +42,10 @@ from .configuration_vilt import ViltConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViltConfig"
_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm-itm"
_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm"
VILT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "dandelin/vilt-b32-mlm-itm",
    "dandelin/vilt-b32-mlm",
    # See all ViLT models at https://huggingface.co/models?filter=vilt
]
...
...
@@ -775,17 +775,19 @@ class ViltModel(ViltPreTrainedModel):
Examples:
```python
>>> from transformers import Vilt
FeatureExtract
or, ViltModel
>>> from transformers import Vilt
Process
or, ViltModel
>>> from PIL import Image
>>> import requests
>>> # prepare image and text
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "hello world"
>>>
feature_extractor = ViltFeatureExtract
or.from_pretrained("dandelin/vilt-b32-mlm
-itm
")
>>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm
-itm
")
>>>
processor = ViltProcess
or.from_pretrained("dandelin/vilt-b32-mlm")
>>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")
>>> inputs =
feature_extract
or(image
s=image
, return_tensors="pt")
>>> inputs =
process
or(image
, text
, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
...
...
@@ -930,10 +932,11 @@ class ViltForMaskedLM(ViltPreTrainedModel):
>>> from transformers import ViltProcessor, ViltForMaskedLM
>>> import requests
>>> from PIL import Image
>>> import re
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "
How many cats are there?
"
>>> text = "
a bunch of [MASK] laying on a [MASK].
"
>>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
>>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")
...
...
@@ -943,7 +946,31 @@ class ViltForMaskedLM(ViltPreTrainedModel):
>>> # forward pass
>>> outputs = model(**encoding)
>>> logits = outputs.logits
>>> tl = len(re.findall("\[MASK\]", text))
>>> inferred_token = [text]
>>> # gradually fill in the MASK tokens, one by one
>>> with torch.no_grad():
... for i in range(tl):
... encoded = processor.tokenizer(inferred_token)
... input_ids = torch.tensor(encoded.input_ids).to(device)
... encoded = encoded["input_ids"][0][1:-1]
... outputs = model(input_ids=input_ids, pixel_values=pixel_values)
... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
... # only take into account text features (minus CLS and SEP token)
... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
... mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
... # only take into account text
... mlm_values[torch.tensor(encoded) != 103] = 0
... select = mlm_values.argmax().item()
... encoded[select] = mlm_ids[select].item()
... inferred_token = [processor.decode(encoded)]
>>> selected_token = ""
>>> encoded = processor.tokenizer(inferred_token)
>>> processor.decode(encoded.input_ids[0], skip_special_tokens=True)
a bunch of cats laying on a couch.
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...
...
@@ -1093,6 +1120,7 @@ class ViltForQuestionAnswering(ViltPreTrainedModel):
>>> logits = outputs.logits
>>> idx = logits.argmax(-1).item()
>>> print("Predicted answer:", model.config.id2label[idx])
Predicted answer: 2
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...
...
@@ -1297,13 +1325,13 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
>>> # prepare inputs
>>> encoding = processor([image1, image2], text, return_tensors="pt")
>>> pixel_values = torch.stack([encoding_1.pixel_values, encoding_2.pixel_values], dim=1)
>>> # forward pass
>>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
>>> logits = outputs.logits
>>> idx = logits.argmax(-1).item()
>>> print("Predicted answer:", model.config.id2label[idx])
Predicted answer: True
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment