Unverified commit 614f7d28, authored by Arthur, committed by GitHub
Browse files

Fix whisper doc (#19608)



* update feature extractor params

* update attention mask handling

* fix doc and pipeline test

* add warning when skipping test

* add whisper translation and transcription test

* fix build doc test

* Correct whisper processor

* make fix copies

* remove sample docstring as it does not fit whisper model

* Update src/transformers/models/whisper/modeling_whisper.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix, doctests are passing

* Nit

* last nit
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 66dd8021
@@ -32,13 +32,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_whisper import WhisperConfig
@@ -46,8 +40,6 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "WhisperConfig"
 _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
-_PROCESSOR_FOR_DOC = "WhisperProcessor"
-_EXPECTED_OUTPUT_SHAPE = [1, 2, 512]
 WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1005,14 +997,7 @@ class WhisperModel(WhisperPreTrainedModel):
         self.encoder._freeze_parameters()

     @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=Seq2SeqModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-        expected_output=_EXPECTED_OUTPUT_SHAPE,
-        modality="audio",
-    )
+    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_features=None,
@@ -1029,7 +1014,25 @@ class WhisperModel(WhisperPreTrainedModel):
         output_hidden_states=None,
         return_dict=None,
     ):
+        r"""
+        Returns:
+
+        Example:
+         ```python
+         >>> import torch
+         >>> from transformers import WhisperFeatureExtractor, WhisperModel
+         >>> from datasets import load_dataset
+
+         >>> model = WhisperModel.from_pretrained("openai/whisper-base")
+         >>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
+         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+         >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
+         >>> input_features = inputs.input_features
+         >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+         >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+         >>> list(last_hidden_state.shape)
+         [1, 2, 512]
+         ```"""
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment