Unverified Commit 606d9084 authored by Xabier de Zuazo, committed by GitHub

Fix Whisper Conversion Script: Correct decoder_attention_heads and _download function (#26834)

* Fix error in convert_openai_to_hf.py: "_download() missing 1 required positional argument: root"

* Fix error in convert_openai_to_hf.py: "TypeError: byte indices must be integers or slices, not str"

* Fix decoder_attention_heads value in convert_openai_to_hf.py.

Correct the assignment for `decoder_attention_heads` in the conversion script for the Whisper model.

* Reformat convert_openai_to_hf.py with Black.

* Fix Whisper model configuration defaults (for Tiny).

- Correct encoder/decoder layers and attention heads count.
- Update model width (`d_model`) to 384.

* Add docstring to the convert_openai_to_hf.py script with a doctest

* Add shebang and +x permission to convert_openai_to_hf.py

* convert_openai_to_hf.py: reuse the read model_bytes in the _download() function

* Move convert_openai_to_hf.py doctest example to whisper.md

* whisper.md: Add an inference example to the Conversion section.

* whisper.md: remove `model.config.forced_decoder_ids` from examples (deprecated)

* whisper.md: Remove "## Format Conversion" section; not used by users

* whisper.md: Use librispeech_asr_dummy dataset and load_dataset()
parent 90b4adc1
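
The diff hunks below touch the documentation, the `WhisperConfig` defaults, and the conversion script. As a rough illustration of how the fixed script might be driven (a sketch, not part of the commit), the conversion function from the patch can be called directly; the import path and the `"tiny.en"` checkpoint key are assumptions based on the usual transformers layout and the script's `_MODELS` mapping:

```python
# Hypothetical usage sketch (not part of this commit). The import path and the
# "tiny.en" checkpoint key are assumptions; convert_openai_whisper_to_tfms and its
# two arguments come from the patched script shown below.
from transformers.models.whisper.convert_openai_to_hf import convert_openai_whisper_to_tfms

# Passing a name instead of a local .pt path goes through the script's _MODELS mapping,
# so the fixed _download() is exercised: it now receives the required `root` directory
# and hands back a loaded checkpoint dict instead of raw bytes.
convert_openai_whisper_to_tfms(
    checkpoint_path="tiny.en",
    pytorch_dump_folder_path="whisper-tiny.en-hf",
)
```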
@@ -34,6 +34,42 @@ The original code can be found [here](https://github.com/openai/whisper).
 - Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
 - One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.

 This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
 The original code can be found [here](https://github.com/openai/whisper).

+## Inference
+
+Here is a step-by-step guide to transcribing an audio sample using a pre-trained Whisper model:
+
+```python
+>>> from datasets import load_dataset
+>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+>>> # Select an audio file and read it:
+>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+>>> audio_sample = ds[0]["audio"]
+>>> waveform = audio_sample["array"]
+>>> sampling_rate = audio_sample["sampling_rate"]
+
+>>> # Load the Whisper model in Hugging Face format:
+>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
+
+>>> # Use the model and processor to transcribe the audio:
+>>> input_features = processor(
+...     waveform, sampling_rate=sampling_rate, return_tensors="pt"
+... ).input_features
+
+>>> # Generate token ids
+>>> predicted_ids = model.generate(input_features)
+
+>>> # Decode token ids to text
+>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+>>> transcription[0]
+' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+```
+
 ## WhisperConfig

 [[autodoc]] WhisperConfig
@@ -77,13 +77,13 @@ class WhisperConfig(PretrainedConfig):
         num_mel_bins (`int`, *optional*, defaults to 80):
             Number of mel features used per input features. Should correspond to the value used in the
             `WhisperProcessor` class.
-        encoder_layers (`int`, *optional*, defaults to 6):
+        encoder_layers (`int`, *optional*, defaults to 4):
             Number of encoder layers.
-        decoder_layers (`int`, *optional*, defaults to 6):
+        decoder_layers (`int`, *optional*, defaults to 4):
             Number of decoder layers.
-        encoder_attention_heads (`int`, *optional*, defaults to 4):
+        encoder_attention_heads (`int`, *optional*, defaults to 6):
             Number of attention heads for each attention layer in the Transformer encoder.
-        decoder_attention_heads (`int`, *optional*, defaults to 4):
+        decoder_attention_heads (`int`, *optional*, defaults to 6):
             Number of attention heads for each attention layer in the Transformer decoder.
         encoder_ffn_dim (`int`, *optional*, defaults to 1536):
             Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
@@ -106,7 +106,7 @@ class WhisperConfig(PretrainedConfig):
         activation_function (`str`, *optional*, defaults to `"gelu"`):
             The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
             `"relu"`, `"silu"` and `"gelu_new"` are supported.
-        d_model (`int`, *optional*, defaults to 256):
+        d_model (`int`, *optional*, defaults to 384):
             Dimensionality of the layers.
         dropout (`float`, *optional*, defaults to 0.1):
             The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
@@ -197,10 +197,10 @@ class WhisperConfig(PretrainedConfig):
         self,
         vocab_size=51865,
         num_mel_bins=80,
-        encoder_layers=6,
-        encoder_attention_heads=4,
-        decoder_layers=6,
-        decoder_attention_heads=4,
+        encoder_layers=4,
+        encoder_attention_heads=6,
+        decoder_layers=4,
+        decoder_attention_heads=6,
         decoder_ffn_dim=1536,
         encoder_ffn_dim=1536,
         encoder_layerdrop=0.0,
@@ -209,7 +209,7 @@ class WhisperConfig(PretrainedConfig):
         use_cache=True,
         is_encoder_decoder=True,
         activation_function="gelu",
-        d_model=256,
+        d_model=384,
         dropout=0.0,
         attention_dropout=0.0,
         activation_dropout=0.0,
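
A quick way to see the effect of the new defaults (a sanity-check sketch, not part of the commit): an unparameterized `WhisperConfig` should now match the Tiny checkpoint's dimensions, as the commit message indicates.

```python
from transformers import WhisperConfig

# With the corrected defaults, a bare config describes Whisper Tiny:
# 4 encoder/decoder layers, 6 attention heads per layer, and d_model = 384.
config = WhisperConfig()

assert config.encoder_layers == 4 and config.decoder_layers == 4
assert config.encoder_attention_heads == 6 and config.decoder_attention_heads == 6
assert config.d_model == 384
```

The remaining hunks apply to the conversion script itself.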
+#!/usr/bin/env python
+"""Converts a Whisper model in OpenAI format to Hugging Face format."""
 # Copyright 2022 The HuggingFace Inc. team and the OpenAI team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +16,7 @@
 import argparse
 import hashlib
+import io
 import os
 import urllib
 import warnings
@@ -90,7 +93,7 @@ def make_linear_from_emb(emb):
     return lin_layer


-def _download(url: str, root: str) -> bytes:
+def _download(url: str, root: str) -> io.BytesIO:
     os.makedirs(root, exist_ok=True)
     filename = os.path.basename(url)
@@ -103,7 +106,7 @@ def _download(url: str, root: str) -> bytes:
     if os.path.isfile(download_target):
         model_bytes = open(download_target, "rb").read()
         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes
+            return torch.load(io.BytesIO(model_bytes))
         else:
             warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
@@ -125,12 +128,13 @@ def _download(url: str, root: str) -> bytes:
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
         )

-    return model_bytes
+    return torch.load(io.BytesIO(model_bytes))


 def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
     if ".pt" not in checkpoint_path:
-        original_checkpoint = _download(_MODELS[checkpoint_path])
+        root = os.path.dirname(pytorch_dump_folder_path) or "."
+        original_checkpoint = _download(_MODELS[checkpoint_path], root)
     else:
         original_checkpoint = torch.load(checkpoint_path, map_location="cpu")
     dimensions = original_checkpoint["dims"]
@@ -151,7 +155,7 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
         encoder_layers=dimensions["n_audio_layer"],
         encoder_attention_heads=dimensions["n_audio_head"],
         decoder_layers=dimensions["n_text_layer"],
-        decoder_attention_heads=dimensions["n_text_state"],
+        decoder_attention_heads=dimensions["n_text_head"],
         max_source_positions=dimensions["n_audio_ctx"],
     )
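
To make the `_download` change concrete, here is a condensed, self-contained sketch of the corrected behavior (simplified names, not the script verbatim): the bytes already read for the checksum are reused, and the caller receives a checkpoint dict loaded with `torch.load` rather than raw bytes, which is why the earlier "byte indices must be integers or slices, not str" error disappears once `original_checkpoint["dims"]` is indexed.

```python
import hashlib
import io
import os
import urllib.request

import torch


def download_and_load(url: str, root: str, expected_sha256: str) -> dict:
    """Fetch a Whisper checkpoint (or reuse a cached copy) and return torch.load()'s result."""
    os.makedirs(root, exist_ok=True)
    download_target = os.path.join(root, os.path.basename(url))

    if not os.path.isfile(download_target):
        urllib.request.urlretrieve(url, download_target)

    # Read once, verify the checksum, then reuse the same bytes for loading
    # instead of re-opening the file.
    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError("SHA256 checksum does not match; please retry the download.")

    # Returning the loaded checkpoint (a dict containing "dims" and the model weights)
    # lets the converter index it with string keys such as original_checkpoint["dims"].
    return torch.load(io.BytesIO(model_bytes), map_location="cpu")
```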