Unverified Commit 45e14038 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

Add WhisperModel to transformers (#19166)



* simplify loop

* add featur extractor

* add model

* start conversion

* add dropout

* initial commit of test files

* copnversion for all models

* update processor for correct padding

* update feature extraction

* update integration test logits match

* fmnt: off for the logits

* on the fly mel bank

* small nit

* update test

* update tokenizer

* nit feature extraction

* update

* update tokenizer test

* adds logit processor and update tokenizer to get supress tokens

* style

* clean convert

* revert to original modeling tf utils

* Update

* update

* nit

* clean convert file

* update tests and nits

* quality

* slow generation test

* ffn_dim to allow customization

* update readme

* add to toctreee

* start fixing integration tests

* update tests and code

* fix feature extractor

* fix config tests common

* update code to fix tests

* fix feature exctractor

* nit feature extraction

* update test for new feature extractor

* style

* add absrtact

* large logits wioth custom decoder input ids

* wraap around is otrch available

* fix feature extractor

* correct logits for whisper small.en

* nit

* fix encoder_attentino_mask

* some fixes

* remove unnecessary inputs

* nits

* add normalizer file

* update etst tokenization

* fix attention mask not defined

* Add model to README

* Fix doc tests

* fix generate

* remove uncoder attention mask useless

* update test modeling whisper

* update condfig to add second non supress tokens

* nits on feature exrtactor

* nit for test tokenizers

* update etsts

* update tests

* update tokenization test

* fixup

* invalidated hf token. Clean convert openai to whisper

* fix logit tests

* fixup

* clean merge

* revert toc_tree changes

* remove useless LogitProcessor

* Update whisper .mdx

* update config file doc

* update configuration docstring

* update test tokenization

* update test tokenization

* update tokenization whisper
Added copied from where needed

* update feature extraction

* nit test name

* style

* quality

* remove get suppress tokens and update non_speech tokens global variables

* Update src/transformers/models/whisper/feature_extraction_whisper.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* clean modeling whisper and test
Removed the attention mask arguments that are deprecated

* fix large test

* Add multilingual audio test, and translate test

* style

* fix larg multilingual test

* nits

* Update docs/source/en/model_doc/whisper.mdx
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* add copied from for attention layer

* remove attention masks in doc

* add english normalizer

* update tokenization test

* remove copied from in whisper attention : no bias in k_proj only

* wrap around dependencies in english normalizer

* style

* correct import generation logits

* for now, wrap feature extractor with torch

* Update src/transformers/models/whisper/convert_openai_whisper_to_tfms.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/whisper/configuration_whisper.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/whisper.mdx
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* remove torch depencies for feature extraction and style

* fixup

* nit

* update logitds

* style

* nit

* nits and fix final tests

* add `is_more_itertools_available` to utils

* quality

* add begin supress tokens, supress tokens to generate args and config

* clean supressTokensLogitProcessor in generation logits

* Nit naming

* add supressTokensAtBegin

* udpate tests, supress tokens to None or correct values

* nit and style

* update RAG to fit test and generate_logit

* add copy pasted statment on english normalizer

* add arguments to config_common_kwargs

* Update src/transformers/generation_utils.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/generation_logits_process.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/whisper/configuration_whisper.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* revert changes based on reviews

* update doc and nits

* more nits

* last nits

* update test configuration common

* add BART name in decoder attention mask documentation

* Update src/transformers/models/whisper/modeling_whisper.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* style

* nit

* nit

* add english.json file to git

* nits on documentation

* nit

* nits

* last styling

* add main toctree file

* remove sentence piece dependency

* clean init file

* fix tokenizer that has no dependencies on sentencepiece

* update whisper init file, nit

* remove english.json file

* add get decoder prompt id

* revert changes and add forced logit processor

* nit

* clean normalizer

* remove protected

* update

* Update src/transformers/models/whisper/configuration_whisper.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* update based on review

* Update src/transformers/models/whisper/configuration_whisper.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add batched tests
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarNielsRogge <niels.rogge1@gmail.com>
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 7598791c
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Feature extractor class for Whisper
"""
from typing import List, Optional, Union
import numpy as np
from numpy.fft import fft
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
logger = logging.get_logger(__name__)
class WhisperFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Whisper feature extractor.
This feature extractor inherits from [`WhisperFeatureExtractor`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
Fourier Transform` which should match pytorch's `torch.stft` equivalent.
Args:
feature_size (`int`, defaults to 80):
The feature dimension of the extracted features.
sampling_rate (`int`, defaults to 16000):
The sampling rate at which the audio files should be digitalized expressed in Hertz per second (Hz).
hop_length (`int`, defaults to 160):
Length of the overlaping windows for the STFT used to obtain the Mel Frequency coefficients.
chunk_length (`int`, defaults to 30):
The maximum number of chuncks of `sampling_rate` samples used to trim and pad longer or shorter audio
sequences.
n_fft (`int`, defaults to 400):
Size of the Fourier transform.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
"""
model_input_names = ["input_features"]
def __init__(
self,
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
padding_value=0.0,
**kwargs
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.return_attention_mask = True
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(sampling_rate, n_fft, n_mels=feature_size)
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
# Initialize the weights
n_mels = int(n_mels)
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
# Center freqs of each FFT bin
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mels = np.asanyarray(mels)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = np.log(6.4) / 27.0 # step size for log region
# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
mel_f = freqs
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
# lower and upper slopes for all bins
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
# .. then intersect them with each other and zero
weights[i] = np.maximum(0, np.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm[:, np.newaxis]
return weights
def fram_wave(self, waveform, center=True):
"""
Transform a raw waveform into a list of smaller waveforms. The window length defines how much of the signal is
contain in each frame (smalle waveform), while the hope length defines the step between the beginning of each
new frame.
Centering is done by reflecting the waveform which is first centered around `frame_idx * hop_length`.
"""
frames = []
for i in range(0, waveform.shape[0] + 1, self.hop_length):
half_window = (self.n_fft - 1) // 2 + 1
if center:
start = i - half_window if i > half_window else 0
end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
frame = waveform[start:end]
if start == 0:
padd_width = (-i + half_window, 0)
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
elif end == waveform.shape[0]:
padd_width = (0, (i - waveform.shape[0] + half_window))
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
else:
frame = waveform[i : i + self.n_fft]
frame_width = frame.shape[0]
if frame_width < waveform.shape[0]:
frame = np.lib.pad(
frame, pad_width=(0, self.n_fft - frame_width), mode="constant", constant_values=0
)
frames.append(frame)
return np.stack(frames, 0)
def stft(self, frames, window):
"""
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same
results as `torch.stft`.
"""
frame_size = frames.shape[1]
fft_size = self.n_fft
if fft_size is None:
fft_size = frame_size
if fft_size < frame_size:
raise ValueError("FFT size must greater or equal the frame size")
# number of FFT bins to store
num_fft_bins = (fft_size >> 1) + 1
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
fft_signal = np.zeros(fft_size)
for f, frame in enumerate(frames):
if window is not None:
np.multiply(frame, window, out=fft_signal[:frame_size])
else:
fft_signal[:frame_size] = frame
data[f] = fft(fft_signal, axis=0)[:num_fft_bins]
return data.T
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
"""
Compute the log-Mel spectrogram of the provided audio, gives similar results whisper's original torch
implementation with 1e-5 tolerance.
"""
window = np.hanning(self.n_fft + 1)[:-1]
frames = self.fram_wave(waveform)
stft = self.stft(frames, window=window)
magnitudes = np.abs(stft[:, :-1]) ** 2
filters = self.mel_filters
mel_spec = filters @ magnitudes
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
truncation: bool = True,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: Optional[bool] = None,
padding: Optional[str] = "max_length",
max_length: Optional[int] = None,
**kwargs
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s). sequences.
Args:
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values.
truncation (`bool`, *optional*, default to `True`):
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
pad_to_multiple_of (`int`, *optional*, defaults to None):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask (`bool`, *optional*):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific feature_extractor's default.
[What are attention masks?](../glossary#attention-mask)
<Tip>
For WhisperTransoformer models, `attention_mask` should alwys be passed for batched inference, to avoid
subtle bugs.
</Tip>
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
sampling_rate (`int`, *optional*):
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
padding_value (`float`, defaults to 0.0):
The value that is used to fill the padding values / vectors.
"""
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
)
if is_batched:
raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
raw_speech = raw_speech.astype(np.float32)
# always return batch
if not is_batched:
raw_speech = [np.asarray([raw_speech]).T]
batched_speech = BatchFeature({"input_features": raw_speech})
# convert into correct format for padding
padded_inputs = self.pad(
batched_speech,
padding=padding,
max_length=max_length if max_length else self.n_samples,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=False,
**kwargs,
)
# make sure list is in array format
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
else:
padded_inputs["input_features"] = input_features
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs
# coding=utf-8
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Whisper model."""
import math
import random
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_whisper import WhisperConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "WhisperConfig"
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"openai/whisper-base",
# See all Whisper models at https://huggingface.co/models?filter=whisper
]
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
def forward(self, input_ids, past_key_values_length=0):
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[-1]]
class WhisperAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
# Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextEncoderLayer with Speech2Text->Whisper
class WhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = WhisperAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextDecoderLayer with Speech2Text->Whisper
class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = WhisperAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = WhisperAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size *(decoder_attention_heads,)*.
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
class WhisperPreTrainedModel(PreTrainedModel):
config_class = WhisperConfig
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (WhisperDecoder, WhisperEncoder)):
module.gradient_checkpointing = value
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
input_lengths = (input_lengths - 1) // 2 + 1
return input_lengths
WHISPER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`WhisperConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WHISPER_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
`past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
If you want to change padding behavior, you should read
[`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`. decoder_inputs_embeds (`torch.FloatTensor` of
shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`decoder_input_ids` you can choose to directly pass an embedded representation. If `past_key_values` is
used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`). This is
useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class WhisperEncoder(WhisperPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`WhisperEncoderLayer`].
Args:
config: WhisperConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: WhisperConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_features,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the mel features,
padding and conversion into a tensor of type `torch.FloatTensor`. See
[`~WhisperFeatureExtractor.__call__`]
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
None,
(head_mask[idx] if head_mask is not None else None),
)
else:
layer_outputs = encoder_layer(
hidden_states,
None,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class WhisperDecoder(WhisperPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`]
Args:
config: WhisperConfig
"""
def __init__(self, config: WhisperConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_target_positions
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(inputs_embeds.device)
if attention_mask is not None:
if attention_mask.shape[-1] > input_shape[-1] > 0:
attention_mask = attention_mask[:, : input_shape[-1] + past_key_values_length]
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
control over how to convert `input_ids` indices into associated vectors than the model's internal
embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# embed positions
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (len(self.layers)), (
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache ="
" False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
@add_start_docstrings(
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
WHISPER_START_DOCSTRING,
)
class WhisperModel(WhisperPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"proj_out.weight"]
def __init__(self, config: WhisperConfig):
super().__init__(config)
self.encoder = WhisperEncoder(config)
self.decoder = WhisperDecoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.decoder.embed_tokens
def set_input_embeddings(self, value):
self.decoder.embed_tokens = value
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import WhisperModel, WhisperFeatureExtractor
>>> from datasets import load_dataset
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
WHISPER_START_DOCSTRING,
)
class WhisperForConditionalGeneration(WhisperPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = [
r"encoder.version",
r"decoder.version",
r"proj_out.weight",
]
_keys_to_ignore_on_save = [
r"proj_out.weight",
]
def __init__(self, config: WhisperConfig):
super().__init__(config)
self.model = WhisperModel(config)
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
return new_embeddings
def get_output_embeddings(self):
return self.proj_out
def set_output_embeddings(self, new_embeddings):
self.proj_out = new_embeddings
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> import torch
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to'
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
input_features,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
lm_logits = self.proj_out(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(
self, decoder_input_ids, past=None, use_cache=None, encoder_outputs=None, attention_mask=None, **kwargs
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache,
"decoder_attention_mask": None,
}
#
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Speech processor class for Whisper
"""
from ...processing_utils import ProcessorMixin
class WhisperProcessor(ProcessorMixin):
r"""
Constructs a Whisper processor which wraps a Whisper feature extractor and a Whisper tokenizer into a single
processor.
[`WhisperProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`] and [`WhisperTokenizer`]. See
the [`~WhisperProcessor.__call__`] and [`~WhisperProcessor.decode`] for more information.
Args:
feature_extractor (`WhisperFeatureExtractor`):
An instance of [`WhisperFeatureExtractor`]. The feature extractor is a required input.
tokenizer (`WhisperTokenizer`):
An instance of [`WhisperTokenizer`]. The tokenizer is a required input.
"""
feature_extractor_class = "WhisperFeatureExtractor"
tokenizer_class = "WhisperTokenizer"
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
self.current_processor = self.feature_extractor
self._in_target_context_manager = False
def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
forced_decoder_tokens = ""
if language is not None:
if f"<|{language}|>" not in self.tokenizer.additional_special_tokens:
raise ValueError(
f"{language} is not supported. The language should be one of the following: '<|en|>',"
" '<|zh|>', '<|de|>', '<|es|>', '<|ru|>', '<|ko|>', '<|fr|>', '<|ja|>', '<|pt|>', '<|tr|>',"
" '<|pl|>', '<|ca|>', '<|nl|>', '<|ar|>', '<|sv|>', '<|it|>', '<|id|>', '<|hi|>', '<|fi|>',"
" '<|vi|>', '<|iw|>', '<|uk|>', '<|el|>', '<|ms|>', '<|cs|>', '<|ro|>', '<|da|>', '<|hu|>',"
" '<|ta|>', '<|no|>', '<|th|>', '<|ur|>', '<|hr|>', '<|bg|>', '<|lt|>', '<|la|>', '<|mi|>',"
" '<|ml|>', '<|cy|>', '<|sk|>', '<|te|>', '<|fa|>', '<|lv|>', '<|bn|>', '<|sr|>', '<|az|>',"
" '<|sl|>', '<|kn|>', '<|et|>', '<|mk|>', '<|br|>', '<|eu|>', '<|is|>', '<|hy|>', '<|ne|>',"
" '<|mn|>', '<|bs|>', '<|kk|>', '<|sq|>', '<|sw|>', '<|gl|>', '<|mr|>', '<|pa|>', '<|si|>',"
" '<|km|>', '<|sn|>', '<|yo|>', '<|so|>', '<|af|>', '<|oc|>', '<|ka|>', '<|be|>', '<|tg|>',"
" '<|sd|>', '<|gu|>', '<|am|>', '<|yi|>', '<|lo|>', '<|uz|>', '<|fo|>', '<|ht|>', '<|ps|>',"
" '<|tk|>', '<|nn|>', '<|mt|>', '<|sa|>', '<|lb|>', '<|my|>', '<|bo|>', '<|tl|>', '<|mg|>',"
" '<|as|>', '<|tt|>', '<|haw|>', '<|ln|>', '<|ha|>', '<|ba|>', '<|jw|>', '<|su|>'"
)
forced_decoder_tokens += f"<|{language}|>"
if task is not None:
if f"<|{task}|>" not in self.tokenizer.additional_special_tokens:
raise ValueError(
f"'{task}' is not supported. The language should be in : {{'transcribe', 'translate'}}"
)
forced_decoder_tokens += f"<|{task}|>"
forced_decoder_tokens += "<|notimestamps|>" if no_timestamps else ""
ids = self.tokenizer.encode(forced_decoder_tokens)
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(ids)]
return forced_decoder_ids
def __call__(self, *args, **kwargs):
"""
When used in normal mode, this method forwards all its arguments to WhisperFeatureExtractor's
[`~WhisperFeatureExtractor.__call__`] and returns its output. If used in the context
[`~WhisperProcessor.as_target_processor`] this method forwards all its arguments to WhisperTokenizer's
[`~WhisperTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information.
"""
# For backward compatibility
if self._in_target_context_manager:
return self.current_processor(*args, **kwargs)
audio = kwargs.pop("audio", None)
text = kwargs.pop("text", None)
if len(args) > 0:
audio = args[0]
args = args[1:]
if audio is None and text is None:
raise ValueError("You need to specify either an `audio` or `text` input to process.")
if audio is not None:
inputs = self.feature_extractor(audio, *args, **kwargs)
if text is not None:
encodings = self.tokenizer(text, **kwargs)
if text is None:
return inputs
elif audio is None:
return encodings
else:
inputs["labels"] = encodings["input_ids"]
return inputs
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to WhisperTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Whisper."""
import json
import os
from typing import List, Optional, Tuple, Union
import regex as re
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from .english_normalizer import EnglishTextNormalizer
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"tokenizer_file": "tokenizer.json",
"merges_file": "merges.txt",
"normalizer_file": "normalizer.json",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"openai/whisper-base": "https://huggingface.co/openai/whisper-base/resolve/main/vocab.json",
},
"merges_file": {"openai/whisper-base": "https://huggingface.co/openai/whisper-base/resolve/main/merges_file.txt"},
"normalizer_file": {
"openai/whisper-base": "https://huggingface.co/openai/whisper-base/resolve/main/normalizer.json"
},
}
MAX_MODEL_INPUT_SIZES = {
"openai/whisper-base": 448,
}
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
characters the bpe code barfs on.
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
logger = logging.get_logger(__name__)
# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
def get_pairs(word):
"""
Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class WhisperTokenizer(PreTrainedTokenizer):
"""
Construct an Whisper tokenizer.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
the superclass for more information regarding such methods.
Args:
vocab_file (`str`):
Path to the vocabulary file.
merges_file (`str`):
Path to the merges file.
normalizer_file (`str`, *optional*, defaults to `None`):
Path to the normalizer_file file.
errors (`str`, *optional*, defaults to `"replace"`):
Paradigm to follow when decoding bytes to UTF-8. See
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
The end of sequence token.
add_prefix_space (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
other word.
add_bos_token (`bool`, *optional*, defaults to `False`):
Whether or not to add an initial <|endoftext|> to the input. This allows to treat the leading word just as
any other word.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = MAX_MODEL_INPUT_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
merges_file,
normalizer_file=None,
errors="replace",
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
pad_token=None,
add_prefix_space=False,
add_bos_token=False,
**kwargs
):
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
super().__init__(
errors=errors,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
pad_token=pad_token,
add_prefix_space=add_prefix_space,
add_bos_token=add_bos_token,
**kwargs,
)
self.add_bos_token = add_bos_token
with open(vocab_file, encoding="utf-8") as vocab_handle:
self.encoder = json.load(vocab_handle)
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
with open(merges_file, encoding="utf-8") as merges_handle:
bpe_merges = merges_handle.read().split("\n")[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
self.add_prefix_space = add_prefix_space
if normalizer_file is not None:
with open(normalizer_file, encoding="utf-8") as vocab_handle:
self.english_spelling_normalizer = json.load(vocab_handle)
else:
self.english_spelling_normalizer = None
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
@property
def vocab_size(self) -> int:
return len(self.encoder)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe with GPT2 -> Whisper
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
except ValueError:
new_word.extend(word[i:])
break
else:
new_word.extend(word[i:j])
i = j
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.build_inputs_with_special_tokens with GPT2 -> Whisper
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
if self.add_bos_token:
bos_token_ids = [self.bos_token_id]
else:
bos_token_ids = []
output = bos_token_ids + token_ids_0
if token_ids_1 is None:
return output
return output + bos_token_ids + token_ids_1
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_special_tokens_mask with GPT2 -> Whisper
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
)
if not self.add_bos_token:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
)
if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0))
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper
def _tokenize(self, text):
"""Tokenize a string."""
bpe_tokens = []
for token in re.findall(self.pat, text):
token = "".join(
self.byte_encoder[b] for b in token.encode("utf-8")
) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id with GPT2 -> Whisper
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.decoder.get(index, self.decoder.get(self.unk_token_id))
def _normalize(self, text):
"""
Normalize a given string using the `EnglishTextNormalizer` class, which preforms commons transformation on
english text.
"""
normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
return normalizer(text)
def _decode(
self, token_ids: Union[int, List[int]], skip_special_tokens: bool = False, normalize: bool = False, **kwargs
) -> str:
self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
# To avoid mixing byte-level and unicode for byte-level BPT
# we need to build string separately for added tokens and byte-level tokens
# cf. https://github.com/huggingface/transformers/issues/1133
sub_texts = []
current_sub_text = []
for token in filtered_tokens:
if skip_special_tokens and token in self.all_special_ids:
continue
if token in self.added_tokens_encoder:
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
text = "".join(sub_texts)
if normalize:
clean_text = self._normalize(text)
return clean_text
else:
return text
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string with GPT2 -> Whisper
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
text = "".join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
return text
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)
merge_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
)
normalizer_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["normalizer_file"]
)
with open(vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write("#version: 0.2\n")
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!"
)
index = token_index
writer.write(" ".join(bpe_tokens) + "\n")
index += 1
if self.english_spelling_normalizer is not None:
with open(normalizer_file, "w", encoding="utf-8") as f:
f.write(
json.dumps(self.english_spelling_normalizer, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
)
return vocab_file, merge_file, normalizer_file
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.prepare_for_tokenization with GPT2 -> Whisper
def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
if is_split_into_words or add_prefix_space:
text = " " + text
return (text, kwargs)
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._build_conversation_input_ids with GPT2 -> Whisper
def _build_conversation_input_ids(self, conversation) -> List[int]:
input_ids = []
for is_user, text in conversation.iter_texts():
input_ids.extend(self.encode(text, add_special_tokens=False) + [self.eos_token_id])
if len(input_ids) > self.model_max_length:
input_ids = input_ids[-self.model_max_length :]
return input_ids
...@@ -100,6 +100,7 @@ from .import_utils import ( ...@@ -100,6 +100,7 @@ from .import_utils import (
is_ipex_available, is_ipex_available,
is_jumanpp_available, is_jumanpp_available,
is_librosa_available, is_librosa_available,
is_more_itertools_available,
is_ninja_available, is_ninja_available,
is_onnx_available, is_onnx_available,
is_pandas_available, is_pandas_available,
......
...@@ -5444,6 +5444,30 @@ class WavLMPreTrainedModel(metaclass=DummyObject): ...@@ -5444,6 +5444,30 @@ class WavLMPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class WhisperForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class WhisperModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class WhisperPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None XCLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
...@@ -456,6 +456,10 @@ def is_detectron2_available(): ...@@ -456,6 +456,10 @@ def is_detectron2_available():
return _detectron2_available return _detectron2_available
def is_more_itertools_available():
return importlib.util.find_spec("more_itertools") is not None
def is_rjieba_available(): def is_rjieba_available():
return importlib.util.find_spec("rjieba") is not None return importlib.util.find_spec("rjieba") is not None
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
import random
import tempfile
import unittest
import numpy as np
from transformers import is_speech_available
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torchaudio
from transformers.utils.import_utils import is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_speech_available():
from transformers import WhisperFeatureExtractor
if is_torch_available():
import torch
global_rng = random.Random()
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
@require_torch
@require_torchaudio
class WhisperFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
min_seq_length=400,
max_seq_length=2000,
feature_size=10,
hop_length=160,
chunk_length=8,
padding_value=0.0,
sampling_rate=4_000,
return_attention_mask=True,
do_normalize=True,
):
self.parent = parent
self.batch_size = batch_size
self.min_seq_length = min_seq_length
self.max_seq_length = max_seq_length
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
self.padding_value = padding_value
self.sampling_rate = sampling_rate
self.return_attention_mask = return_attention_mask
self.do_normalize = do_normalize
self.feature_size = feature_size
self.chunk_length = chunk_length
self.hop_length = hop_length
def prepare_feat_extract_dict(self):
return {
"feature_size": self.feature_size,
"hop_length": self.hop_length,
"chunk_length": self.chunk_length,
"padding_value": self.padding_value,
"sampling_rate": self.sampling_rate,
"return_attention_mask": self.return_attention_mask,
"do_normalize": self.do_normalize,
}
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
def _flatten(list_of_lists):
return list(itertools.chain(*list_of_lists))
if equal_length:
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
else:
# make sure that inputs increase in size
speech_inputs = [
floats_list((x, self.feature_size))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
]
if numpify:
speech_inputs = [np.asarray(x) for x in speech_inputs]
return speech_inputs
@require_torch
@require_torchaudio
class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = WhisperFeatureExtractor if is_speech_available() else None
def setUp(self):
self.feat_extract_tester = WhisperFeatureExtractionTester(self)
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = dict_first.pop("mel_filters")
mel_2 = dict_second.pop("mel_filters")
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = dict_first.pop("mel_filters")
mel_2 = dict_second.pop("mel_filters")
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_call(self):
# Tests that all call wrap to encode_plus and batch_encode_plus
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
# Test feature size
input_features = feature_extractor(np_speech_inputs, padding="max_length", return_tensors="np").input_features
self.assertTrue(input_features.ndim == 3)
self.assertTrue(input_features.shape[-1] == feature_extractor.nb_max_frames)
self.assertTrue(input_features.shape[-2] == feature_extractor.feature_size)
# Test not batched input
encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features
self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))
# Test batched
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
# Test truncation required
speech_inputs = [floats_list((1, x))[0] for x in range(200, (feature_extractor.n_samples + 500), 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
speech_inputs_truncated = [x[: feature_extractor.n_samples] for x in speech_inputs]
np_speech_inputs_truncated = [np.asarray(speech_input) for speech_input in speech_inputs_truncated]
encoded_sequences_1 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
encoded_sequences_2 = feature_extractor(np_speech_inputs_truncated, return_tensors="np").input_features
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
def test_double_precision_pad(self):
import torch
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
py_speech_inputs = np_speech_inputs.tolist()
for inputs in [py_speech_inputs, np_speech_inputs]:
np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
self.assertTrue(np_processed.input_features.dtype == np.float32)
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
def _load_datasamples(self, num_samples):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
def test_integration(self):
# fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor(
[
0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
-0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
]
)
# fmt: on
input_speech = self._load_datasamples(1)
feaure_extractor = WhisperFeatureExtractor()
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Whisper model. """
import copy
import inspect
import os
import tempfile
import unittest
from transformers import WhisperConfig
from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device
from transformers.utils import cached_property
from transformers.utils.import_utils import is_datasets_available
from ...generation.test_generation_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
if is_datasets_available():
import datasets
from datasets import load_dataset
if is_torch_available():
import torch
from transformers import (
WhisperFeatureExtractor,
WhisperForConditionalGeneration,
WhisperModel,
WhisperProcessor,
set_seed,
)
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
def prepare_whisper_inputs_dict(
config,
input_features,
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
if cross_attn_head_mask is None:
cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
# "input_ids": input_features,
"input_features": input_features,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}
@require_torch
class WhisperModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=60,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
input_channels=1,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=20,
max_source_positions=30,
max_target_positions=40,
bos_token_id=98,
eos_token_id=98,
pad_token_id=0,
num_mel_bins=80,
decoder_start_token_id=85,
num_conv_layers=1,
suppress_tokens=None,
begin_suppress_tokens=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.input_channels = input_channels
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_mel_bins = num_mel_bins
self.max_position_embeddings = max_position_embeddings
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.decoder_start_token_id = decoder_start_token_id
self.num_conv_layers = num_conv_layers
self.suppress_tokens = suppress_tokens
self.begin_suppress_tokens = begin_suppress_tokens
def prepare_config_and_inputs(self):
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)
decoder_input_ids = torch.tensor(self.batch_size * [[self.decoder_start_token_id]], device=torch_device)
config = self.get_config()
inputs_dict = prepare_whisper_inputs_dict(
config,
attention_mask=None,
input_features=input_features,
decoder_input_ids=decoder_input_ids,
)
return config, inputs_dict
def get_config(self):
return WhisperConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
input_channels=self.input_channels,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
max_source_positions=self.max_source_positions,
max_target_positions=self.max_target_positions,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
decoder_ffn_dim=self.hidden_size,
encoder_ffn_dim=self.hidden_size,
decoder_start_token_id=self.decoder_start_token_id,
suppress_tokens=self.suppress_tokens,
begin_suppress_tokens=self.begin_suppress_tokens,
)
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def get_subsampled_output_lengths(self, input_lengths):
"""
Computes the output length of the convolutional layers
"""
for i in range(self.num_conv_layers):
input_lengths = (input_lengths - 1) // 2 + 1
return input_lengths
def create_and_check_model_forward(self, config, inputs_dict):
model = WhisperModel(config=config).to(torch_device).eval()
input_features = inputs_dict["input_features"]
decoder_input_ids = inputs_dict["decoder_input_ids"]
# first forward pass
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
self.parent.assertTrue(last_hidden_state.shape, (13, 7, 16))
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
model = WhisperModel(config=config).get_decoder().to(torch_device).eval()
input_ids = inputs_dict["decoder_input_ids"]
attention_mask = inputs_dict["decoder_attention_mask"]
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size).clamp(2)
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = WhisperModel(config=config).to(torch_device).eval()
outputs = model(**inputs_dict)
encoder_last_hidden_state = outputs.encoder_last_hidden_state
last_hidden_state = outputs.last_hidden_state
with tempfile.TemporaryDirectory() as tmpdirname:
encoder = model.get_encoder()
encoder.save_pretrained(tmpdirname)
encoder = WhisperEncoder.from_pretrained(tmpdirname).to(torch_device)
encoder_last_hidden_state_2 = encoder(inputs_dict["input_features"])[0]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
decoder = model.get_decoder()
decoder.save_pretrained(tmpdirname)
decoder = WhisperDecoder.from_pretrained(tmpdirname).to(torch_device)
last_hidden_state_2 = decoder(
input_ids=inputs_dict["decoder_input_ids"],
attention_mask=inputs_dict["decoder_attention_mask"],
encoder_hidden_states=encoder_last_hidden_state,
)[0]
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)
@require_torch
class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (WhisperModel, WhisperForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (WhisperForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True
fx_compatible = False
test_pruning = False
test_missing_keys = False
input_name = "input_features"
def setUp(self):
self.model_tester = WhisperModelTester(self)
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
self.maxDiff = 3000
def test_config(self):
self.config_tester.run_common_tests()
def test_save_load_strict(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
self.assertEqual(info["missing_keys"], [])
def test_model_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_forward(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_encoder_decoder_model_standalone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)
def _get_input_ids_and_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict[self.input_name]
# cut to half length & take max batch_size 3
max_batch_size = 3
input_ids = input_ids[:max_batch_size, :, :]
# generate max 3 tokens
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
return config, input_ids, None, max_length
# not implemented currently
def test_inputs_embeds(self):
pass
# training is not supported yet
def test_training(self):
pass
def test_training_gradient_checkpointing(self):
pass
def test_generate_with_head_masking(self):
pass
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
config.max_target_positions = 400
input_features = input_dict["input_features"]
model = WhisperForConditionalGeneration(config).eval().to(torch_device)
if torch_device == "cuda":
input_features = input_features.half()
model.half()
model.generate(input_features)
model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = [
"input_features",
"decoder_input_ids",
"decoder_attention_mask",
]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.encoder_seq_length
else:
seq_length = self.model_tester.seq_length
subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[subsampled_seq_length, self.model_tester.hidden_size],
)
if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", 1)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[decoder_seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", 1)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
decoder_key_length = getattr(self.model_tester, "decoder_key_length", 1)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
)
out_len = len(outputs)
correct_outlen = 5
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
subsampled_encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
added_hidden_states = 2
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
)
def test_resize_tokens_embeddings(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
return
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
if self.model_tester.is_training is False:
model.eval()
model_vocab_size = config.vocab_size
# Retrieve the embeddings and clone theme
model_embed = model.resize_token_embeddings(model_vocab_size)
cloned_embeddings = model_embed.weight.clone()
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
# make sure that decoder_input_ids are resized
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
models_equal = True
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
def test_resize_embeddings_untied(self):
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
return
original_config.tie_word_embeddings = False
# if model cannot untied embeddings -> leave test
if original_config.tie_word_embeddings:
return
for model_class in self.all_model_classes:
config = copy.deepcopy(original_config)
model = model_class(config).to(torch_device)
# if no output embeddings -> leave test
if model.get_output_embeddings() is None:
continue
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
model_vocab_size = config.vocab_size
model.resize_token_embeddings(model_vocab_size + 10)
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
# Check bias if present
if output_embeds.bias is not None:
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
model.resize_token_embeddings(model_vocab_size - 15)
self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
# Check that it actually resizes the embeddings matrix
output_embeds = model.get_output_embeddings()
self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
# Check bias if present
if output_embeds.bias is not None:
self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
def test_generate_without_input_ids(self):
pass
@staticmethod
def _get_encoder_outputs(
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
):
encoder = model.get_encoder()
encoder_outputs = encoder(
input_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
num_interleave, dim=0
)
input_ids = input_ids[:, :, 0]
input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + torch.tensor(
[model._get_decoder_start_token_id()], device=input_ids.device
)
attention_mask = None
return encoder_outputs, input_ids, attention_mask
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, mel, seq_length = input_ids.shape
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
# Attentions
# encoder
self._check_encoder_attention_for_generate(
output.encoder_attentions, batch_size, config, subsampled_seq_length
)
# decoder
self._check_attentions_for_generate(
num_sequences_in_output,
output.decoder_attentions,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
# Hidden States
# encoder
self._check_encoder_hidden_states_for_generate(
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
)
# decoder
self._check_hidden_states_for_generate(
num_sequences_in_output,
output.decoder_hidden_states,
min_length=1,
max_length=output.sequences.shape[-1],
config=config,
use_cache=use_cache,
)
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
try:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
input_features = inputs["input_features"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
traced_model = torch.jit.trace(model, (input_features, decoder_input_ids, decoder_attention_mask))
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
models_equal = True
for layer_name, p1 in model_state_dict.items():
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
@require_torch
@require_torchaudio
class WhisperModelIntegrationTests(unittest.TestCase):
@cached_property
def default_processor(self):
return WhisperProcessor.from_pretrained("openai/whisper-base")
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
@slow
def test_tiny_logits_librispeech(self):
torch_device = "cpu"
set_seed(0)
model = WhisperModel.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
with torch.no_grad():
logits = model(
input_features,
decoder_input_ids=torch.tensor([[50258, 50259, 50359]]),
output_hidden_states=False,
output_attentions=False,
return_dict=False,
use_cache=False,
)
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
]
)
# fmt: on
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
# fmt: off
EXPECTED_GENERATION = torch.tensor(
[
-1.4651, -2.6944, 2.7821, 2.3793, 4.0738, 0.0188, -3.3203, 1.9836,
0.0520, 0.7095, 1.1063, 0.2952, -3.6786, -0.5249, 0.3105, 4.7691,
1.1562, 1.3046, 0.5810, -0.3624, 1.7006, 1.3424, 0.9817, 2.1958,
1.8775, -5.7046, -0.7679, 4.0113, 2.6848, 2.8609
]
)
# fmt: on
head_logits = logits[0] @ model.decoder.embed_tokens.weight.T
self.assertTrue(torch.allclose(head_logits[0, 0, :30].cpu(), EXPECTED_GENERATION, atol=1e-4))
@slow
def test_small_en_logits_librispeech(self):
set_seed(0)
torch_device = "cpu"
model = WhisperModel.from_pretrained("openai/whisper-small.en")
model.to(torch_device)
input_speech = self._load_datasamples(1)
feaure_extractor = WhisperFeatureExtractor()
input_features = feaure_extractor(input_speech, return_tensors="pt").input_features.to(torch_device)
logits = model(
input_features,
decoder_input_ids=torch.tensor([[model.config.decoder_start_token_id]]),
output_hidden_states=False,
output_attentions=False,
use_cache=False,
)
logits = logits.last_hidden_state @ model.decoder.embed_tokens.weight.T
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
-3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
-8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
-6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
-10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
-11.1146, -8.1918
]
)
# fmt: on
self.assertTrue(torch.allclose(logits[0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
@slow
def test_large_logits_librispeech(self):
set_seed(0)
torch_device = "cpu"
model = WhisperModel.from_pretrained("openai/whisper-large")
model.to(torch_device)
input_speech = self._load_datasamples(1)
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="pt")
input_features = processed_inputs.input_features.to(torch_device)
labels = processed_inputs.labels.to(torch_device)
logits = model(
input_features,
decoder_input_ids=labels,
output_hidden_states=False,
output_attentions=False,
use_cache=False,
)
logits = logits.last_hidden_state @ model.decoder.embed_tokens.weight.T
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
]
)
# fmt: on
self.assertTrue(torch.allclose(logits[0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
@slow
def test_tiny_en_generation(self):
torch_device = "cpu"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
model.config.decoder_start_token_id = 50257
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
generated_ids = model.generate(input_features, num_beams=5)
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
EXPECTED_TRANSCRIPT = (
"<|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle"
" classes, and we are glad to"
)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_generation(self):
torch_device = "cpu"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
generated_ids = model.generate(input_features, num_beams=5)
transcript = processor.tokenizer.decode(generated_ids[0])
EXPECTED_TRANSCRIPT = (
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
" classes and we are glad"
)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_large_generation(self):
torch_device = "cpu"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
model.to(torch_device)
input_speech = self._load_datasamples(1)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
generated_ids = model.generate(
input_features,
do_sample=False,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_large_generation_multilingual(self):
torch_device = "cpu"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
model.to(torch_device)
ds = load_dataset("common_voice", "ja", split="test", streaming=True)
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
input_speech = next(iter(ds))["audio"]["array"]
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
generated_ids = model.generate(input_features, do_sample=False)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
generated_ids = model.generate(
input_features,
do_sample=False,
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " Kimura san ni denwa wo kaite moraimashita"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
generated_ids = model.generate(input_features, do_sample=False)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_large_batched_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
generated_ids = model.generate(input_features)
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
]
)
# fmt: on
self.assertTrue(torch.allclose(generated_ids, EXPECTED_LOGITS))
# fmt: off
EXPECTED_TRANSCRIPT = [
' Mr. Quilter is the apostle of the middle classes, and we are glad to',
" Nor is Mr. Quilter's manner less interesting than his matter.",
" He tells us that at this festive season of the year, with Christmas and roast beef",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
]
# fmt: on
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_en_batched_generation(self):
torch_device = "cuda"
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device)
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
generated_ids = model.generate(input_features).to("cpu")
# fmt: off
EXPECTED_LOGITS = torch.tensor(
[
[50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
[50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
[50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
[50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
]
)
# fmt: on
self.assertTrue(torch.allclose(generated_ids, EXPECTED_LOGITS))
# fmt: off
EXPECTED_TRANSCRIPT = [
" Mr. Quilter is the apostle of the middle classes, and we are glad to",
" Nor is Mr. Quilter's manner less interesting than his matter.",
" He tells us that at this festive season of the year, with Christmas and roast beef looming",
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
]
# fmt: on
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from transformers import WhisperTokenizer, is_speech_available
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
from .test_feature_extraction_whisper import floats_list
if is_speech_available():
from transformers import WhisperFeatureExtractor, WhisperProcessor
@require_torch
@require_torchaudio
@require_sentencepiece
class WhisperProcessorTest(unittest.TestCase):
def setUp(self):
self.checkpoint = "openai/whisper-small.en"
self.tmpdirname = tempfile.mkdtemp()
def get_tokenizer(self, **kwargs):
return WhisperTokenizer.from_pretrained(self.checkpoint, **kwargs)
def get_feature_extractor(self, **kwargs):
return WhisperFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
feature_extractor = self.get_feature_extractor()
processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
processor.save_pretrained(self.tmpdirname)
processor = WhisperProcessor.from_pretrained(self.tmpdirname)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.tokenizer, WhisperTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
self.assertIsInstance(processor.feature_extractor, WhisperFeatureExtractor)
def test_save_load_pretrained_additional_features(self):
processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)
processor = WhisperProcessor.from_pretrained(
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, WhisperTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, WhisperFeatureExtractor)
def test_feature_extractor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
raw_speech = floats_list((3, 1000))
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
input_processor = processor(raw_speech, return_tensors="np")
for key in input_feat_extract.keys():
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_tokenizer(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
input_str = "This is a test string"
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]
decoded_processor = processor.batch_decode(predicted_ids)
decoded_tok = tokenizer.batch_decode(predicted_ids)
self.assertListEqual(decoded_tok, decoded_processor)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.models.whisper import WhisperTokenizer
from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
EN_CODE = 50258
ES_CODE = 50256
class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = WhisperTokenizer
test_rust_tokenizer = False
test_sentencepiece = False
def setUp(self):
super().setUp()
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
tokenizer.pad_token_id = 50256
tokenizer.pad_token = "<|endoftext|>"
tokenizer.save_pretrained(self.tmpdirname)
def test_convert_token_and_id(self):
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
token = "Where"
token_id = 14436
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
self.assertEqual(vocab_keys[0], "!")
self.assertEqual(vocab_keys[1], '"')
self.assertEqual(vocab_keys[-1], "<|notimestamps|>")
self.assertEqual(len(vocab_keys), 50364)
def test_vocab_size(self):
self.assertEqual(self.get_tokenizer().vocab_size, 50257)
def test_full_tokenizer(self):
tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġ", "test"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[5723, 307, 257, 220, 31636],
)
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
# fmt: off
['I', 'Ġwas', 'Ġborn', 'Ġin', 'Ġ9', '2000', ',', 'Ġand', 'Ġ', 'this', 'Ġis', 'Ġfals', 'é', '.'],
# fmt: on
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 220, 11176, 307, 16720, 526, 13])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
# fmt: off
['I', 'Ġwas', 'Ġborn', 'Ġin', 'Ġ9', '2000', ',', 'Ġand', 'Ġ', 'this', 'Ġis', 'Ġfals', 'é', '.'],
# fmt: on
)
def test_tokenizer_slow_store_full_signature(self):
pass
@slow
def test_tokenizer_integration(self):
# fmt: off
expected_encoding = {'input_ids': [[41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13], [13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13], [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # noqa: E501
# fmt: on
self.tokenizer_integration_test_util(
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
)
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"
transcript = (
"'<|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|> Nor is Mr. Quilters manner less interesting"
" than his matter.<|endoftext|>'"
)
clean_transcript = " Nor is Mr. Quilters manner less interesting than his matter."
french_text = "Bonjour! Il me semble que Mrs Quilters n'était pas présente"
@classmethod
def setUpClass(cls):
cls.tokenizer: WhisperTokenizer = WhisperTokenizer.from_pretrained(cls.checkpoint_name)
return cls
def test_tokenizer_equivalence(self):
text = "다람쥐 헌 쳇바퀴에 타고파"
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="ko")
gpt2_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
gpt2_tokens = gpt2_tokenizer.encode(text)
multilingual_tokens = multilingual_tokenizer.encode(text)
assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens)
# fmt: off
EXPECTED_ENG = [
46695, 97, 167, 252, 234, 168, 98, 238, 220, 169,
245, 234, 23821, 111, 229, 167, 108, 242, 169, 222,
112, 168, 245, 238, 220, 169, 225, 222, 166, 111,
254, 169, 234, 234
]
EXPECTED_MULTI = [
9835, 22855, 168, 98, 238, 13431, 234, 43517, 229, 47053,
169, 222, 19086, 19840, 1313, 17974
]
# fmt: on
self.assertListEqual(gpt2_tokens, EXPECTED_ENG)
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
def test_tokenizer_special(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
text = "<|startoftranscript|>Hey! How are you feeling? J'ai l'impression que 郷さん est prêt<|endoftext|>"
multilingual_tokens = multilingual_tokenizer.encode(text)
# fmt: off
EXPECTED_MULTI = [
50257, 10814, 0, 1374, 389, 345, 4203, 30, 449, 6,
1872, 300, 6, 11011, 2234, 8358, 16268, 225, 115, 43357,
22174, 1556, 778, 25792, 83, 50256
]
# fmt: on
self.assertListEqual(multilingual_tokens, EXPECTED_MULTI)
self.assertEqual(text, multilingual_tokenizer.decode(multilingual_tokens))
transcript = multilingual_tokenizer.decode(multilingual_tokens, skip_special_tokens=True)
EXPECTED_JAP = "Hey! How are you feeling? J'ai l'impression que 郷さん est prêt"
self.assertEqual(transcript, EXPECTED_JAP)
def test_vocab_size(self):
self.assertEqual(self.tokenizer.vocab_size, 50257)
def test_tokenizer_decode_ignores_language_codes(self):
self.assertIn(ES_CODE, self.tokenizer.all_special_ids)
generated_ids = [ES_CODE, 4, 1601, 47, 7647, 2]
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
expected_spanish = self.tokenizer.decode(generated_ids[1:], skip_special_tokens=True)
self.assertEqual(result, expected_spanish)
self.assertNotIn(self.tokenizer.eos_token, result)
def test_batch_encoding(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
batch = ["<|en|><|notimestamps|>", "<|en|><|notimestamps|>I am sure that"]
batch_output = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
# fmt: off
EXPECTED_MULTI = [
[50258, 50362, 50256, 50256, 50256, 50256],
[50258, 50362, 40, 716, 1654, 326]
]
# fmt: on
self.assertListEqual(batch_output, EXPECTED_MULTI)
...@@ -84,6 +84,8 @@ config_common_kwargs = { ...@@ -84,6 +84,8 @@ config_common_kwargs = {
"sep_token_id": 9, "sep_token_id": 9,
"decoder_start_token_id": 10, "decoder_start_token_id": 10,
"exponential_decay_length_penalty": (5, 1.01), "exponential_decay_length_penalty": (5, 1.01),
"suppress_tokens": [0, 1],
"begin_suppress_tokens": 2,
"task_specific_params": {"translation": "some_params"}, "task_specific_params": {"translation": "some_params"},
"problem_type": "regression", "problem_type": "regression",
} }
......
...@@ -51,6 +51,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ ...@@ -51,6 +51,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"DeformableDetrEncoder", # Building part of bigger (tested) model. "DeformableDetrEncoder", # Building part of bigger (tested) model.
"DeformableDetrDecoder", # Building part of bigger (tested) model. "DeformableDetrDecoder", # Building part of bigger (tested) model.
"OPTDecoder", # Building part of bigger (tested) model. "OPTDecoder", # Building part of bigger (tested) model.
"WhisperDecoder", # Building part of bigger (tested) model.
"WhisperEncoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model. "DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
"SegformerDecodeHead", # Building part of bigger (tested) model. "SegformerDecodeHead", # Building part of bigger (tested) model.
"PLBartEncoder", # Building part of bigger (tested) model. "PLBartEncoder", # Building part of bigger (tested) model.
......
...@@ -96,4 +96,5 @@ src/transformers/models/wav2vec2/tokenization_wav2vec2.py ...@@ -96,4 +96,5 @@ src/transformers/models/wav2vec2/tokenization_wav2vec2.py
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
src/transformers/models/wavlm/modeling_wavlm.py src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/whisper/modeling_whisper.py
src/transformers/models/yolos/modeling_yolos.py src/transformers/models/yolos/modeling_yolos.py
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment