Unverified Commit 49bf5698 authored by Steven Anton's avatar Steven Anton Committed by GitHub
Browse files

Add doctests to Perceiver examples (#19129)



* Fix bug in example and add to tests

* Fix failing tests

* Check the size of logits

* Code style

* Try again...

* Add expected loss for PerceiverForMaskedLM doctest
Co-authored-by: default avatarSteven Anton <antonstv@amazon.com>
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent fe01ec34
......@@ -801,6 +801,8 @@ class PerceiverModel(PerceiverPreTrainedModel):
>>> with torch.no_grad():
... outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
>>> # to train, one can train the model using standard cross-entropy:
>>> criterion = torch.nn.CrossEntropyLoss()
......@@ -810,6 +812,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
>>> # EXAMPLE 2: using the Perceiver to classify images
>>> # - we define an ImagePreprocessor, which can be used to embed images
>>> config = PerceiverConfig(image_size=224)
>>> preprocessor = PerceiverImagePreprocessor(
... config,
... prep_type="conv1x1",
......@@ -844,6 +847,8 @@ class PerceiverModel(PerceiverPreTrainedModel):
>>> with torch.no_grad():
... outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
>>> # to train, one can train the model using standard cross-entropy:
>>> criterion = torch.nn.CrossEntropyLoss()
......@@ -1017,7 +1022,12 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
>>> outputs = model(**inputs, labels=labels)
>>> loss = outputs.loss
>>> round(loss.item(), 2)
19.87
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2048, 262]
>>> # inference
>>> text = "This is an incomplete sentence where some words are missing."
......@@ -1030,6 +1040,8 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
>>> with torch.no_grad():
... outputs = model(**encoding)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2048, 262]
>>> masked_tokens_predictions = logits[0, 52:61].argmax(dim=-1).tolist()
>>> tokenizer.decode(masked_tokens_predictions)
......@@ -1128,6 +1140,8 @@ class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
>>> inputs = tokenizer(text, return_tensors="pt").input_ids
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 2]
```"""
if inputs is not None and input_ids is not None:
raise ValueError("You cannot use both `inputs` and `input_ids`")
......@@ -1265,9 +1279,13 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
raise ValueError("You cannot use both `inputs` and `pixel_values`")
......@@ -1402,9 +1420,13 @@ class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
raise ValueError("You cannot use both `inputs` and `pixel_values`")
......@@ -1539,9 +1561,13 @@ class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
>>> inputs = feature_extractor(images=image, return_tensors="pt").pixel_values
>>> outputs = model(inputs=inputs)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 1000]
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = logits.argmax(-1).item()
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
Predicted class: tabby, tabby cat
```"""
if inputs is not None and pixel_values is not None:
raise ValueError("You cannot use both `inputs` and `pixel_values`")
......@@ -1689,6 +1715,8 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
>>> patches = torch.randn(1, 2, 27, 368, 496)
>>> outputs = model(inputs=patches)
>>> logits = outputs.logits
>>> list(logits.shape)
[1, 368, 496, 2]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......@@ -1915,6 +1943,14 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
>>> outputs = model(inputs=inputs, subsampled_output_points=subsampling)
>>> logits = outputs.logits
>>> list(logits["audio"].shape)
[1, 240]
>>> list(logits["image"].shape)
[1, 6272, 3]
>>> list(logits["label"].shape)
[1, 700]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......@@ -2925,7 +2961,6 @@ class PerceiverAudioPostprocessor(nn.Module):
self.classifier = nn.Linear(in_channels, config.samples_per_patch)
def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, modality_sizes=None) -> torch.Tensor:
logits = self.classifier(inputs)
return torch.reshape(logits, [inputs.shape[0], -1])
......
......@@ -58,6 +58,7 @@ src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_tf_opt.py
src/transformers/models/owlvit/modeling_owlvit.py
src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/perceiver/modeling_perceiver.py
src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/reformer/modeling_reformer.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment