Unverified Commit 9efad4ef authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[Swin2SR] Add doc tests (#20829)



* Fix doc tests

* Use Auto API

* Apply suggestion

* Revert "Apply suggestion"

This reverts commit cd9507a86644b4877c3e4a3d6c2d5919d9272dd7.
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MBP.localdomain>
parent 0d284bd5
......@@ -43,11 +43,11 @@ logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "Swin2SRConfig"
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
_FEAT_EXTRACTOR_FOR_DOC = "AutoImageProcessor"
# Base docstring
_CHECKPOINT_FOR_DOC = "caidas/swin2sr-classicalsr-x2-64"
_EXPECTED_OUTPUT_SHAPE = [1, 64, 768]
_CHECKPOINT_FOR_DOC = "caidas/swin2SR-classical-sr-x2-64"
_EXPECTED_OUTPUT_SHAPE = [1, 180, 488, 648]
SWIN2SR_PRETRAINED_MODEL_ARCHIVE_LIST = [
......@@ -1141,19 +1141,28 @@ class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel):
Example:
```python
>>> import torch
>>> from transformers import Swin2SRFeatureExtractor, Swin2SRForImageSuperResolution
>>> from datasets import load_dataset
>>> feature_extractor = Swin2SRFeatureExtractor.from_pretrained("openai/whisper-base")
>>> model = Swin2SRForImageSuperResolution.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
>>> import numpy as np
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
>>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
>>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
>>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # prepare image for the model
>>> inputs = processor(image, return_tensors="pt")
>>> # forward pass
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
>>> output = np.moveaxis(output, source=0, destination=-1)
>>> output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
>>> # you can visualize `output` with `Image.fromarray`
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
......@@ -166,6 +166,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py
src/transformers/models/squeezebert/configuration_squeezebert.py
src/transformers/models/swin/configuration_swin.py
src/transformers/models/swin/modeling_swin.py
src/transformers/models/swin2sr/modeling_swin2sr.py
src/transformers/models/swinv2/configuration_swinv2.py
src/transformers/models/table_transformer/modeling_table_transformer.py
src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment