Unverified commit 49bf5698, authored by Steven Anton, committed by GitHub
Browse files

Add doctests to Perceiver examples (#19129)



* Fix bug in example and add to tests

* Fix failing tests

* Check the size of logits

* Code style

* Try again...

* Add expected loss for PerceiverForMaskedLM doctest
Co-authored-by: Steven Anton <antonstv@amazon.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent fe01ec34
@@ -801,6 +801,8 @@ class PerceiverModel(PerceiverPreTrainedModel):
>>> with torch.no_grad():
...     outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
>>> # to train, one can train the model using standard cross-entropy:
>>> criterion = torch.nn.CrossEntropyLoss()
@@ -810,6 +812,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
>>> # EXAMPLE 2: using the Perceiver to classify images
>>> # - we define an ImagePreprocessor, which can be used to embed images
>>> config = PerceiverConfig(image_size=224)
>>> preprocessor = PerceiverImagePreprocessor(
...     config,
...     prep_type="conv1x1",
@@ -844,6 +847,8 @@ class PerceiverModel(PerceiverPreTrainedModel):
>>> with torch.no_grad():
...     outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
>>> # to train, one can train the model using standard cross-entropy:
>>> criterion = torch.nn.CrossEntropyLoss()
@@ -1017,7 +1022,12 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> round(loss.item(), 2)
19.87
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2048, 262]
>>> # inference
>>> text = "This is an incomplete sentence where some words are missing."
@@ -1030,6 +1040,8 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
>>> with torch.no_grad():
...     outputs = model(**encoding)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2048, 262]
>>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
>>> tokenizer.decode(masked_tokens_predictions)
@@ -1128,6 +1140,8 @@ class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
>>> inputs = tokenizer(text, return_tensors="pt").input_ids
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
```"""
if inputs is not None and input_ids is not None:
    raise ValueError("You cannot use both `inputs` and `input_ids`")
@@ -1265,9 +1279,13 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
    raise ValueError("You cannot use both `inputs` and `pixel_values`")
@@ -1402,9 +1420,13 @@ class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
    raise ValueError("You cannot use both `inputs` and `pixel_values`")
@@ -1539,9 +1561,13 @@ class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
    raise ValueError("You cannot use both `inputs` and `pixel_values`")
@@ -1689,6 +1715,8 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
>>> patches = torch.randn(1, 2, 27, 368, 496)
>>> outputs = model(inputs=patches)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 368, 496, 2]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1915,6 +1943,14 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
>>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
>>> logits = outputs.logits
>>> list(logits["audio"].shape)
[1, 240]
>>> list(logits["image"].shape)
[1, 6272, 3]
>>> list(logits["label"].shape)
[1, 700]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -2925,7 +2961,6 @@ class PerceiverAudioPostprocessor(nn.Module):
self.classifier = nn.Linear(in_channels, config.samples_per_patch)

def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
    logits = self.classifier(inputs)
    return torch.reshape(logits, [inputs.shape[0], -1])
...
@@ -58,6 +58,7 @@ (documentation_tests.txt)
src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_tf_opt.py
src/transformers/models/owlvit/modeling_owlvit.py
src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/perceiver/modeling_perceiver.py
src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/reformer/modeling_reformer.py
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment