chenpangpang / transformers · Commits

Commit 515ed3ad (unverified)
Authored Jan 20, 2022 by NielsRogge; committed by GitHub on Jan 20, 2022
Parent: ad739063

Fix doc examples (#15257)

Showing 2 changed files with 39 additions and 10 deletions (+39 -10)
docs/source/model_doc/trocr.mdx (+2 -1)
src/transformers/models/vilt/modeling_vilt.py (+37 -9)
docs/source/model_doc/trocr.mdx
@@ -70,7 +70,8 @@ into a single instance to both extract the input features and decode the predict
 >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
 >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
->>> # load image from the IAM dataset url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
+>>> # load image from the IAM dataset
+>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
 >>> pixel_values = processor(image, return_tensors="pt").pixel_values
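
The fix here splits a comment that had swallowed the `url` assignment, so the example now actually defines `url` before using it. For context, a self-contained sketch of the corrected example; the final generation step is not part of this hunk and follows the standard `generate`/`batch_decode` API (running it assumes network access to the checkpoint and the IAM image):

```python
import requests
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# load image from the IAM dataset (the URL the fix moves onto its own line)
url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")

pixel_values = processor(image, return_tensors="pt").pixel_values

# transcribe the handwriting (this step sits outside the hunk)
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```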
src/transformers/models/vilt/modeling_vilt.py
@@ -42,10 +42,10 @@ from .configuration_vilt import ViltConfig
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "ViltConfig"
-_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm-itm"
+_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm"

 VILT_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "dandelin/vilt-b32-mlm-itm",
+    "dandelin/vilt-b32-mlm",
     # See all ViLT models at https://huggingface.co/models?filter=vilt
 ]
@@ -775,17 +775,19 @@ class ViltModel(ViltPreTrainedModel):
         Examples:

         ```python
->>> from transformers import ViltFeatureExtractor, ViltModel
+>>> from transformers import ViltProcessor, ViltModel
 >>> from PIL import Image
 >>> import requests

+>>> # prepare image and text
 >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
+>>> text = "hello world"

->>> feature_extractor = ViltFeatureExtractor.from_pretrained("dandelin/vilt-b32-mlm-itm")
->>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm-itm")
+>>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
+>>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

->>> inputs = feature_extractor(images=image, return_tensors="pt")
+>>> inputs = processor(image, text, return_tensors="pt")
 >>> outputs = model(**inputs)
 >>> last_hidden_states = outputs.last_hidden_state
 ```"""

@@ -930,10 +932,11 @@ class ViltForMaskedLM(ViltPreTrainedModel):
 >>> from transformers import ViltProcessor, ViltForMaskedLM
 >>> import requests
 >>> from PIL import Image
+>>> import re

 >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
->>> text = "How many cats are there?"
+>>> text = "a bunch of [MASK] laying on a [MASK]."

 >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
 >>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")
@@ -943,7 +946,31 @@ class ViltForMaskedLM(ViltPreTrainedModel):
 >>> # forward pass
 >>> outputs = model(**encoding)
->>> logits = outputs.logits
+
+>>> tl = len(re.findall("\[MASK\]", text))
+>>> inferred_token = [text]
+
+>>> # gradually fill in the MASK tokens, one by one
+>>> with torch.no_grad():
+...     for i in range(tl):
+...         encoded = processor.tokenizer(inferred_token)
+...         input_ids = torch.tensor(encoded.input_ids).to(device)
+...         encoded = encoded["input_ids"][0][1:-1]
+...         outputs = model(input_ids=input_ids, pixel_values=pixel_values)
+...         mlm_logits = outputs.logits[0]  # shape (seq_len, vocab_size)
+...         # only take into account text features (minus CLS and SEP token)
+...         mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
+...         mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
+...         # only take into account text
+...         mlm_values[torch.tensor(encoded) != 103] = 0
+...         select = mlm_values.argmax().item()
+...         encoded[select] = mlm_ids[select].item()
+...         inferred_token = [processor.decode(encoded)]
+
+>>> selected_token = ""
+>>> encoded = processor.tokenizer(inferred_token)
+>>> processor.decode(encoded.input_ids[0], skip_special_tokens=True)
+a bunch of cats laying on a couch.
 ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
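
The new example implements greedy iterative decoding: each pass fills whichever remaining `[MASK]` the model is most confident about. The hunk leans on `device`, `pixel_values`, and `encoding` defined earlier in the docstring; below is a self-contained sketch with those pieces filled in (keeping everything on CPU is an assumption for simplicity, and 103 is hard-coded as BERT's `[MASK]` token id, exactly as in the diff):

```python
import re
import requests
import torch
from PIL import Image
from transformers import ViltProcessor, ViltForMaskedLM

processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "a bunch of [MASK] laying on a [MASK]."

# assumption: the hunk does not show where `device` and `pixel_values` are set
device = "cpu"
pixel_values = processor(image, text, return_tensors="pt").pixel_values.to(device)

tl = len(re.findall(r"\[MASK\]", text))  # number of masks still to fill
inferred_token = [text]

# gradually fill in the [MASK] tokens, most confident prediction first
with torch.no_grad():
    for _ in range(tl):
        encoded = processor.tokenizer(inferred_token)
        input_ids = torch.tensor(encoded.input_ids).to(device)
        encoded = encoded["input_ids"][0][1:-1]  # token ids minus [CLS] / [SEP]
        outputs = model(input_ids=input_ids, pixel_values=pixel_values)
        mlm_logits = outputs.logits[0]  # (seq_len, vocab_size)
        # restrict to the text tokens between [CLS] and [SEP], as in the diff
        mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
        mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
        # zero out every position that is not a [MASK] (id 103 in BERT's vocab)
        mlm_values[torch.tensor(encoded) != 103] = 0
        select = mlm_values.argmax().item()
        encoded[select] = mlm_ids[select].item()
        inferred_token = [processor.decode(encoded)]

encoded = processor.tokenizer(inferred_token)
print(processor.decode(encoded.input_ids[0], skip_special_tokens=True))
# the diff records the doc-test output as: a bunch of cats laying on a couch.
```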

@@ -1093,6 +1120,7 @@ class ViltForQuestionAnswering(ViltPreTrainedModel):
 >>> logits = outputs.logits
 >>> idx = logits.argmax(-1).item()
 >>> print("Predicted answer:", model.config.id2label[idx])
+Predicted answer: 2
 ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
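
This hunk only adds the expected-output line for the doc test. The setup sits outside the hunk, so the sketch below is a plausible reconstruction, not the file's actual surrounding code: the `dandelin/vilt-b32-finetuned-vqa` checkpoint is an assumption, and the question is the one the MLM example above dropped ("How many cats are there?"), which matches the documented answer of 2:

```python
import requests
from PIL import Image
from transformers import ViltProcessor, ViltForQuestionAnswering

# assumption: checkpoint and question are not visible in this hunk
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
text = "How many cats are there?"

encoding = processor(image, text, return_tensors="pt")
outputs = model(**encoding)
idx = outputs.logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[idx])  # per the diff: 2
```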

@@ -1297,13 +1325,13 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel):
 >>> # prepare inputs
 >>> encoding = processor([image1, image2], text, return_tensors="pt")
->>> pixel_values = torch.stack([encoding_1.pixel_values, encoding_2.pixel_values], dim=1)

 >>> # forward pass
 >>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
 >>> logits = outputs.logits
 >>> idx = logits.argmax(-1).item()
 >>> print("Predicted answer:", model.config.id2label[idx])
+Predicted answer: True
 ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
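
The removed line was dead code: it referenced `encoding_1` and `encoding_2`, which are never defined in the example, while the processor already stacks both images into a single `pixel_values` tensor. A sketch of the full flow; the checkpoint and the inputs are assumptions, since the hunk does not show them (the COCO image from the other examples is reused here and the sentence is illustrative only):

```python
import requests
from PIL import Image
from transformers import ViltProcessor, ViltForImagesAndTextClassification

# assumption: checkpoint, images, and sentence sit outside this hunk
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image1 = Image.open(requests.get(url, stream=True).raw)
image2 = Image.open(requests.get(url, stream=True).raw)
text = "The left image is identical to the right image."

# the processor stacks both images, so pixel_values is (num_images, C, H, W);
# unsqueeze(0) adds the batch axis the model expects: (batch, num_images, C, H, W)
encoding = processor([image1, image2], text, return_tensors="pt")
outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
idx = outputs.logits.argmax(-1).item()
print("Predicted answer:", model.config.id2label[idx])
```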