Unverified Commit 9efad4ef authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[Swin2SR] Add doc tests (#20829)



* Fix doc tests

* Use Auto API

* Apply suggestion

* Revert "Apply suggestion"

This reverts commit cd9507a86644b4877c3e4a3d6c2d5919d9272dd7.
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MBP.localdomain>
parent 0d284bd5
...@@ -43,11 +43,11 @@ logger = logging.get_logger(__name__) ...@@ -43,11 +43,11 @@ logger = logging.get_logger(__name__)
# General docstring # General docstring
_CONFIG_FOR_DOC = "Swin2SRConfig" _CONFIG_FOR_DOC = "Swin2SRConfig"
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" _FEAT_EXTRACTOR_FOR_DOC = "AutoImageProcessor"
# Base docstring # Base docstring
_CHECKPOINT_FOR_DOC = "caidas/swin2sr-classicalsr-x2-64" _CHECKPOINT_FOR_DOC = "caidas/swin2SR-classical-sr-x2-64"
_EXPECTED_OUTPUT_SHAPE = [1, 64, 768] _EXPECTED_OUTPUT_SHAPE = [1, 180, 488, 648]
SWIN2SR_PRETRAINED_MODEL_ARCHIVE_LIST = [ SWIN2SR_PRETRAINED_MODEL_ARCHIVE_LIST = [
...@@ -1141,19 +1141,28 @@ class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel): ...@@ -1141,19 +1141,28 @@ class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel):
Example: Example:
```python ```python
>>> import torch >>> import torch
>>> from transformers import Swin2SRFeatureExtractor, Swin2SRForImageSuperResolution >>> import numpy as np
>>> from datasets import load_dataset >>> from PIL import Image
>>> import requests
>>> feature_extractor = Swin2SRFeatureExtractor.from_pretrained("openai/whisper-base")
>>> model = Swin2SRForImageSuperResolution.from_pretrained("openai/whisper-base") >>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") >>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id >>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state >>> image = Image.open(requests.get(url, stream=True).raw)
>>> list(last_hidden_state.shape) >>> # prepare image for the model
[1, 2, 512] >>> inputs = processor(image, return_tensors="pt")
>>> # forward pass
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
>>> output = np.moveaxis(output, source=0, destination=-1)
>>> output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
>>> # you can visualize `output` with `Image.fromarray`
```""" ```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
...@@ -166,6 +166,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py ...@@ -166,6 +166,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py
src/transformers/models/squeezebert/configuration_squeezebert.py src/transformers/models/squeezebert/configuration_squeezebert.py
src/transformers/models/swin/configuration_swin.py src/transformers/models/swin/configuration_swin.py
src/transformers/models/swin/modeling_swin.py src/transformers/models/swin/modeling_swin.py
src/transformers/models/swin2sr/modeling_swin2sr.py
src/transformers/models/swinv2/configuration_swinv2.py src/transformers/models/swinv2/configuration_swinv2.py
src/transformers/models/table_transformer/modeling_table_transformer.py src/transformers/models/table_transformer/modeling_table_transformer.py
src/transformers/models/time_series_transformer/configuration_time_series_transformer.py src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment