Project: chenpangpang/transformers

Unverified commit 50bc57ce, authored Dec 15, 2021 by NielsRogge, committed via GitHub on Dec 15, 2021
Parent: 48d48276

Update Perceiver code examples (#14783)

* Fix code examples
* Fix code example
Changes: 2 changed files, with 122 additions and 19 deletions (+122, -19)

docs/source/model_doc/perceiver.mdx (+5, -7)
src/transformers/models/perceiver/modeling_perceiver.py (+117, -12)
docs/source/model_doc/perceiver.mdx
...
@@ -81,9 +81,10 @@ Tips:

 - The quickest way to get started with the Perceiver is by checking the [tutorial
   notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/Perceiver).
-- Note that the models available in the library only showcase some examples of what you can do with the Perceiver.
-  There are many more use cases, including question answering,
-  named-entity recognition, object detection, audio classification, video classification, etc.
+- Refer to the [blog post](https://huggingface.co/blog/perceiver) if you want to fully understand how the model works and
+  is implemented in the library. Note that the models available in the library only showcase some examples of what you can do
+  with the Perceiver. There are many more use cases, including question answering, named-entity recognition, object detection,
+  audio classification, video classification, etc.
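One reason a single backbone stretches across so many modalities, which the linked blog post covers in depth: the Perceiver cross-attends a small, fixed latent array to the inputs, so attention cost grows linearly with input length instead of quadratically. A back-of-the-envelope sketch (the input lengths are illustrative assumptions; 256 latents matches the `PerceiverConfig` default):

```python
# Illustrative cost comparison, not library code.
def self_attention_cost(seq_len):
    # standard Transformer: every input position attends to every other
    return seq_len * seq_len

def perceiver_cross_attention_cost(seq_len, num_latents=256):
    # Perceiver: a fixed latent array attends to the inputs
    return seq_len * num_latents

# e.g. 2048 UTF-8 bytes of text, or 224*224 = 50176 pixels of an image
for seq_len in (2048, 50176):
    assert perceiver_cross_attention_cost(seq_len) < self_attention_cost(seq_len)
```

Because the latent array's size is independent of the input, the same architecture can ingest bytes, pixels, or audio samples without modality-specific attention tricks.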
## Perceiver specific outputs
...
@@ -102,10 +103,7 @@ named-entity recognition, object detection, audio classification, video classification, etc.

 ## PerceiverTokenizer

 [[autodoc]] PerceiverTokenizer
-    - build_inputs_with_special_tokens
-    - get_special_tokens_mask
-    - create_token_type_ids_from_sequences
-    - save_vocabulary
+    - __call__
## PerceiverFeatureExtractor
...
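For context on why the tokenizer's autodoc list shrinks to `__call__` alone: PerceiverTokenizer operates directly on raw UTF-8 bytes, so most subword-vocabulary machinery is irrelevant to users. A minimal round-trip sketch in that spirit; the offset of 6 (reserving low IDs for special tokens such as [PAD], [BOS], [EOS], [MASK], [CLS], [SEP]) mirrors the idea but is an assumption for illustration, not the library's exact mapping:

```python
# Sketch of byte-level tokenization in the spirit of PerceiverTokenizer.
# OFFSET reserves low IDs for special tokens; the value 6 is an assumption.
OFFSET = 6

def encode(text):
    # one ID per UTF-8 byte, shifted past the special-token IDs
    return [b + OFFSET for b in text.encode("utf-8")]

def decode(ids):
    return bytes(i - OFFSET for i in ids).decode("utf-8")

assert decode(encode("hello world")) == "hello world"
assert len(encode("hello world")) == 11  # one ID per byte, no subwords
```

With no learned vocabulary, there is nothing to save and no subword merge rules to document, which is why `__call__` is the only method worth surfacing.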
src/transformers/models/perceiver/modeling_perceiver.py
...
@@ -757,12 +757,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)

     @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=PerceiverModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @replace_return_docstrings(output_type=PerceiverModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         inputs,
...
@@ -773,6 +768,85 @@ class PerceiverModel(PerceiverPreTrainedModel):
    output_hidden_states=None,
    return_dict=None,
):
    r"""
Returns:
Examples::
>>> from transformers import PerceiverConfig, PerceiverTokenizer, PerceiverFeatureExtractor, PerceiverModel
>>> from transformers.models.perceiver.modeling_perceiver import PerceiverTextPreprocessor, PerceiverImagePreprocessor, PerceiverClassificationDecoder
>>> import torch
>>> import requests
>>> from PIL import Image
>>> # EXAMPLE 1: using the Perceiver to classify texts
>>> # - we define a TextPreprocessor, which can be used to embed tokens
>>> # - we define a ClassificationDecoder, which can be used to decode the
>>> # final hidden states of the latents to classification logits
>>> # using trainable position embeddings
>>> config = PerceiverConfig()
>>> preprocessor = PerceiverTextPreprocessor(config)
>>> decoder = PerceiverClassificationDecoder(config,
... num_channels=config.d_latents,
... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
... use_query_residual=True)
>>> model = PerceiverModel(config, input_preprocessor=preprocessor, decoder=decoder)
>>> # you can then do a forward pass as follows:
>>> tokenizer = PerceiverTokenizer()
>>> text = "hello world"
>>> inputs = tokenizer(text, return_tensors="pt").input_ids
>>> with torch.no_grad():
...     outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> # to train the model, one can use standard cross-entropy:
>>> criterion = torch.nn.CrossEntropyLoss()
>>> labels = torch.tensor([1])
>>> loss = criterion(logits, labels)
>>> # EXAMPLE 2: using the Perceiver to classify images
>>> # - we define an ImagePreprocessor, which can be used to embed images
>>> preprocessor = PerceiverImagePreprocessor(
...     config,
...     prep_type="conv1x1",
...     spatial_downsample=1,
...     out_channels=256,
...     position_encoding_type="trainable",
...     concat_or_add_pos="concat",
...     project_pos_dim=256,
...     trainable_position_encoding_kwargs=dict(num_channels=256, index_dims=config.image_size ** 2),
... )
>>> model = PerceiverModel(
... config,
... input_preprocessor=preprocessor,
... decoder=PerceiverClassificationDecoder(
... config,
... num_channels=config.d_latents,
... trainable_position_encoding_kwargs=dict(num_channels=config.d_latents, index_dims=1),
... use_query_residual=True,
... ),
... )
>>> # you can then do a forward pass as follows:
>>> feature_extractor = PerceiverFeatureExtractor()
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = feature_extractor(image, return_tensors="pt").pixel_values
>>> with torch.no_grad():
...     outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> # to train the model, one can use standard cross-entropy:
>>> criterion = torch.nn.CrossEntropyLoss()
>>> labels = torch.tensor([1])
>>> loss = criterion(logits, labels)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...
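The two assignments at the top of `forward` follow a pattern used throughout the library: a per-call argument overrides the config default only when it is explicitly provided, with `None` meaning "not specified". A standalone sketch of the pattern (the `Config` class here is a stand-in, not the real `PerceiverConfig`):

```python
# Stand-in config; the real model reads these flags from PerceiverConfig.
class Config:
    output_attentions = False

def resolve(output_attentions, config):
    # None means "not specified": fall back to the config default
    return output_attentions if output_attentions is not None else config.output_attentions

assert resolve(None, Config()) is False   # falls back to the config default
assert resolve(True, Config()) is True    # explicit value wins
assert resolve(False, Config()) is False  # explicit False is respected too
```

Testing `is not None` rather than truthiness is the key detail: it lets a caller pass an explicit `False` to disable a feature that the config enables.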
@@ -901,12 +975,7 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
         self.post_init()

     @add_start_docstrings_to_model_forward(PERCEIVER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=PerceiverMaskedLMOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @replace_return_docstrings(output_type=PerceiverMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         inputs=None,
...
@@ -923,6 +992,42 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see the ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for tokens with labels in ``[0, ..., config.vocab_size]``.
Returns:
Examples::
>>> from transformers import PerceiverTokenizer, PerceiverForMaskedLM
>>> import torch
>>> tokenizer = PerceiverTokenizer.from_pretrained('deepmind/language-perceiver')
>>> model = PerceiverForMaskedLM.from_pretrained('deepmind/language-perceiver')
>>> # training
>>> text = "This is an incomplete sentence where some words are missing."
>>> inputs = tokenizer(text, padding="max_length", return_tensors="pt")
>>> # mask " missing."
>>> inputs['input_ids'][0, 52:61] = tokenizer.mask_token_id
>>> labels = tokenizer(text, padding="max_length", return_tensors="pt").input_ids
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> # inference
>>> text = "This is an incomplete sentence where some words are missing."
>>> encoding = tokenizer(text, padding="max_length", return_tensors="pt")
>>> # mask bytes corresponding to " missing.". Note that the model performs much better if the masked span starts with a space.
>>> encoding['input_ids'][0, 52:61] = tokenizer.mask_token_id
>>> # forward pass
>>> with torch.no_grad():
...     outputs = model(**encoding)
>>> logits = outputs.logits
>>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
>>> tokenizer.decode(masked_tokens_predictions)
' missing.'
"""
if inputs is not None and input_ids is not None:
    raise ValueError("You cannot use both `inputs` and `input_ids`")
...
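A note on the `52:61` slice in the masked-LM example above: because the tokenizer is byte-level, each character of the ASCII-only sentence maps to one position, so the span covering " missing." can be located with plain string arithmetic. The shift of one position for a prepended special token is an assumption for illustration:

```python
# Locate the " missing." span of the masked-LM example by byte offset.
text = "This is an incomplete sentence where some words are missing."
span = " missing."

start = text.index(span)  # character/byte offset in the raw ASCII text
assert start == 51 and len(span) == 9

# With one special token prepended by the tokenizer (an assumption here),
# every position in input_ids shifts right by 1, giving the slice [52:61).
shift = 1
assert (start + shift, start + shift + len(span)) == (52, 61)
```

This also explains the tip in the example that the masked span should start with a space: the leading space byte is part of the span the model is asked to reconstruct.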