Commit 39ac40a9 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2747 failed with stages
in 0 seconds
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
'''
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
This means the current step could only do attention with its left steps.
In encoder, fully attention is used when streaming is not necessary and
the sequence is not long. In this case, no attention mask is needed.
When streaming is need, chunk-based attention is used in encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
str device (str): "cpu" or "cuda" or torch.Tensor.device
dtype (torch.device): result dtype
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
ret = torch.ones(size, size, device=device, dtype=torch.bool)
return torch.tril(ret)
'''
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
This means the current step could only do attention with its left steps.
In encoder, fully attention is used when streaming is not necessary and
the sequence is not long. In this case, no attention mask is needed.
When streaming is need, chunk-based attention is used in encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
str device (str): "cpu" or "cuda" or torch.Tensor.device
dtype (torch.device): result dtype
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
arange = torch.arange(size, device=device)
mask = arange.expand(size, size)
arange = arange.unsqueeze(-1)
mask = mask <= arange
return mask
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
def add_optional_chunk_mask(xs: torch.Tensor,
masks: torch.Tensor,
use_dynamic_chunk: bool,
use_dynamic_left_chunk: bool,
decoding_chunk_size: int,
static_chunk_size: int,
num_decoding_left_chunks: int,
enable_full_context: bool = True):
""" Apply optional mask for encoder.
Args:
xs (torch.Tensor): padded input, (B, L, D), L for max length
mask (torch.Tensor): mask for xs, (B, 1, L)
use_dynamic_chunk (bool): whether to use dynamic chunk or not
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
training.
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
static_chunk_size (int): chunk size for static chunk training/decoding
if it's greater than 0, if use_dynamic_chunk is true,
this parameter will be ignored
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
enable_full_context (bool):
True: chunk size is either [1, 25] or full context(max_len)
False: chunk size ~ U[1, 25]
Returns:
torch.Tensor: chunk mask of the input xs.
"""
# Whether to use chunk mask or not
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2 and enable_full_context:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
torch.Tensor: Mask tensor containing indices of padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,
device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Ximalaya Inc (Yuguang Yang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
# NeMo(https://github.com/NVIDIA/NeMo)
from typing import Union
import math
import warnings
import torch
from torch.optim.lr_scheduler import _LRScheduler
class WarmupLR(_LRScheduler):
"""The WarmupLR scheduler
This scheduler is almost same as NoamLR Scheduler except for following
difference:
NoamLR:
lr = optimizer.lr * model_size ** -0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
WarmupLR:
lr = optimizer.lr * warmup_step ** 0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
Note that the maximum lr equals to optimizer.lr in this scheduler.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
self.warmup_steps = warmup_steps
# __init__() must be invoked before setting field
# because step() is also invoked in __init__()
super().__init__(optimizer, last_epoch)
def __repr__(self):
return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
def get_lr(self):
step_num = self.last_epoch + 1
if self.warmup_steps == 0:
return [lr * step_num**-0.5 for lr in self.base_lrs]
else:
return [
lr * self.warmup_steps**0.5 *
min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
for lr in self.base_lrs
]
def set_step(self, step: int):
self.last_epoch = step
class WarmupPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
"""
def __init__(self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1):
assert not (warmup_steps is not None and warmup_ratio is not None),\
"Either use particular number of step or ratio"
assert warmup_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed "
"by the scheduler, please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = self.last_epoch
if step <= self.warmup_steps and self.warmup_steps > 0:
return self._get_warmup_lr(step)
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_warmup_lr(self, step):
lr_val = (step + 1) / (self.warmup_steps + 1)
return [initial_lr * lr_val for initial_lr in self.base_lrs]
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
class SquareRootConstantPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
"""
def __init__(self,
optimizer,
*,
constant_steps=None,
constant_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1):
assert not (constant_steps is not None
and constant_ratio is not None), \
"Either use particular number of step or ratio"
assert constant_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if constant_steps is not None:
self.constant_steps = constant_steps
elif constant_ratio is not None:
self.constant_steps = int(constant_ratio * max_steps)
else:
self.constant_steps = 0
self.constant_lr = 1 / (constant_steps**0.5)
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed "
"by the scheduler, please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = self.last_epoch
if step <= self.constant_steps:
return [self.constant_lr for _ in self.base_lrs]
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
class WarmupHoldPolicy(WarmupPolicy):
"""Variant of WarmupPolicy which maintains high
learning rate for a defined number of steps.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
hold_steps: Number of training steps to
hold the learning rate after warm up
hold_ratio: Ratio of hold steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
"""
def __init__(
self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
hold_steps=None,
hold_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
assert not (hold_steps is not None and hold_ratio is not None), \
"Either use particular number of step or ratio"
assert hold_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
self.min_lr = min_lr
self._last_warmup_lr = 0.0
# Necessary to duplicate as class attributes are hidden in inner class
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
if hold_steps is not None:
self.hold_steps = hold_steps + self.warmup_steps
elif hold_ratio is not None:
self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
else:
self.hold_steps = 0
super().__init__(
optimizer,
warmup_steps=warmup_steps,
warmup_ratio=warmup_ratio,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed by the scheduler,"
" "
"please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = self.last_epoch
# Warmup phase
if step <= self.warmup_steps and self.warmup_steps > 0:
return self._get_warmup_lr(step)
# Hold phase
if (step >= self.warmup_steps) and (step < self.hold_steps):
return self.base_lrs
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
class WarmupAnnealHoldPolicy(_LRScheduler):
"""Adds warmup kwargs and warmup logic to lr policy.
All arguments should be passed as kwargs for clarity,
Args:
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
min_lr: Minimum lr to hold the learning rate after decay at.
constant_steps: Number of steps to keep lr constant at.
constant_ratio: Ratio of steps to keep lr constant.
"""
def __init__(
self,
optimizer,
*,
warmup_steps=None,
warmup_ratio=None,
constant_steps=None,
constant_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1,
):
assert not (warmup_steps is not None
and warmup_ratio is not None), \
"Either use particular number of step or ratio"
assert not (constant_steps is not None
and constant_ratio is not None), \
"Either use constant_steps or constant_ratio"
assert warmup_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
if constant_steps is not None:
self.constant_steps = constant_steps
elif constant_ratio is not None:
self.constant_steps = int(constant_ratio * max_steps)
else:
self.constant_steps = 0
self.decay_steps = max_steps - (self.constant_steps +
self.warmup_steps)
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed "
"by the scheduler, please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = self.last_epoch
# Warmup steps
if self.warmup_steps > 0 and step <= self.warmup_steps:
return self._get_warmup_lr(step)
# Constant steps after warmup and decay
if self.constant_steps > 0 and (
self.warmup_steps + self.decay_steps) < step <= self.max_steps:
return self._get_constant_lr(step)
# Min lr after max steps of updates
if step > self.max_steps:
return [self.min_lr for _ in self.base_lrs]
return self._get_lr(step)
def _get_warmup_lr(self, step):
lr_val = (step + 1) / (self.warmup_steps + 1)
return [initial_lr * lr_val for initial_lr in self.base_lrs]
def _get_constant_lr(self, step):
return [self.min_lr for _ in self.base_lrs]
def _get_lr(self, step):
"""Simple const lr policy"""
return self.base_lrs
def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
mult = ((max_steps - step) / max_steps)**0.5
out_lr = initial_lr * mult
out_lr = max(out_lr, min_lr)
return out_lr
def _square_annealing(initial_lr, step, max_steps, min_lr):
mult = ((max_steps - step) / max_steps)**2
out_lr = initial_lr * mult
out_lr = max(out_lr, min_lr)
return out_lr
def _cosine_annealing(initial_lr, step, max_steps, min_lr):
mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
out_lr = (initial_lr - min_lr) * mult + min_lr
return out_lr
def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
decay_steps, min_lr):
assert max_lr > min_lr
# Use linear warmup for the initial part.
if warmup_steps > 0 and step <= warmup_steps:
return max_lr * float(step) / float(warmup_steps)
# For any steps larger than `decay_steps`, use `min_lr`.
if step > warmup_steps + decay_steps:
return min_lr
# If we are done with the warmup period, use the decay style.
num_steps_ = step - warmup_steps
decay_steps_ = decay_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = max_lr - min_lr
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
return min_lr + coeff * delta_lr
def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
if cycle:
multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
decay_steps *= multiplier
else:
step = min(step, decay_steps)
p = step / decay_steps
lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
lr += min_lr
return lr
def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
decay_rate, min_lr):
# hold_steps = total number of steps
# to hold the LR, not the warmup + hold steps.
T_warmup_decay = max(1, warmup_steps**decay_rate)
T_hold_decay = max(1, (step - hold_steps)**decay_rate)
lr = (initial_lr * T_warmup_decay) / T_hold_decay
lr = max(lr, min_lr)
return lr
class SquareAnnealing(WarmupPolicy):
def __init__(self,
optimizer,
*,
max_steps,
min_lr=1e-5,
last_epoch=-1,
**kwargs):
super().__init__(optimizer=optimizer,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
**kwargs)
def _get_lr(self, step):
new_lrs = [
_square_annealing(
initial_lr=initial_lr,
step=step - self.warmup_steps,
max_steps=self.max_steps - self.warmup_steps,
min_lr=self.min_lr,
) for initial_lr in self.base_lrs
]
return new_lrs
class SquareRootAnnealing(WarmupPolicy):
def __init__(self,
optimizer,
*,
max_steps,
min_lr=0,
last_epoch=-1,
**kwargs):
super().__init__(optimizer=optimizer,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
**kwargs)
def _get_lr(self, step):
new_lrs = [
_squareroot_annealing(initial_lr=initial_lr,
step=step,
max_steps=self.max_steps,
min_lr=self.min_lr)
for initial_lr in self.base_lrs
]
return new_lrs
class CosineAnnealing(WarmupAnnealHoldPolicy):
def __init__(self,
optimizer,
*,
max_steps,
min_lr=0,
last_epoch=-1,
**kwargs):
super().__init__(optimizer=optimizer,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
**kwargs)
def _get_lr(self, step):
for initial_lr in self.base_lrs:
if initial_lr < self.min_lr:
raise ValueError(
f"{self} received an initial learning rate "
f"that was lower than the minimum learning rate.")
if self.constant_steps is None or self.constant_steps == 0:
new_lrs = [
_cosine_annealing(
initial_lr=initial_lr,
step=step - self.warmup_steps,
max_steps=self.max_steps - self.warmup_steps,
min_lr=self.min_lr,
) for initial_lr in self.base_lrs
]
else:
new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
return new_lrs
def _get_warmup_lr(self, step):
if self.constant_steps is None or self.constant_steps == 0:
return super()._get_warmup_lr(step)
else:
# Use linear warmup for the initial part.
return self._get_linear_warmup_with_cosine_annealing_lr(step)
def _get_constant_lr(self, step):
# Only called when `constant_steps` > 0.
return self._get_linear_warmup_with_cosine_annealing_lr(step)
def _get_linear_warmup_with_cosine_annealing_lr(self, step):
# Cosine Schedule for Megatron LM,
# slightly different warmup schedule + constant LR at the end.
new_lrs = [
_linear_warmup_with_cosine_annealing(
max_lr=self.base_lrs[0],
warmup_steps=self.warmup_steps,
step=step,
decay_steps=self.decay_steps,
min_lr=self.min_lr,
) for _ in self.base_lrs
]
return new_lrs
class NoamAnnealing(_LRScheduler):
def __init__(self,
optimizer,
*,
d_model,
warmup_steps=None,
warmup_ratio=None,
max_steps=None,
min_lr=0.0,
last_epoch=-1):
self._normalize = d_model**(-0.5)
assert not (warmup_steps is not None
and warmup_ratio is not None), \
"Either use particular number of step or ratio"
assert warmup_ratio is None or max_steps is not None, \
"If there is a ratio, there should be a total steps"
# It is necessary to assign all attributes *before* __init__,
# as class is wrapped by an inner class.
self.max_steps = max_steps
if warmup_steps is not None:
self.warmup_steps = warmup_steps
elif warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * max_steps)
else:
self.warmup_steps = 0
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn(
"To get the last learning rate computed "
"by the scheduler, please use `get_last_lr()`.",
UserWarning,
stacklevel=2)
step = max(1, self.last_epoch)
for initial_lr in self.base_lrs:
if initial_lr < self.min_lr:
raise ValueError(
f"{self} received an initial learning rate "
f"that was lower than the minimum learning rate.")
new_lrs = [
self._noam_annealing(initial_lr=initial_lr, step=step)
for initial_lr in self.base_lrs
]
return new_lrs
def _noam_annealing(self, initial_lr, step):
if self.warmup_steps > 0:
mult = self._normalize * min(step**(-0.5),
step * (self.warmup_steps**(-1.5)))
else:
mult = self._normalize * step**(-0.5)
out_lr = initial_lr * mult
if step > self.warmup_steps:
out_lr = max(out_lr, self.min_lr)
return out_lr
class NoamHoldAnnealing(WarmupHoldPolicy):
def __init__(self,
optimizer,
*,
max_steps,
decay_rate=0.5,
min_lr=0.0,
last_epoch=-1,
**kwargs):
"""
From Nemo:
Implementation of the Noam Hold Annealing policy
from the SqueezeFormer paper.
Unlike NoamAnnealing, the peak learning rate
can be explicitly set for this scheduler.
The schedule first performs linear warmup,
then holds the peak LR, then decays with some schedule for
the remainder of the steps.
Therefore the min-lr is still dependent
on the hyper parameters selected.
It's schedule is determined by three factors-
Warmup Steps: Initial stage, where linear warmup
occurs uptil the peak LR is reached. Unlike NoamAnnealing,
the peak LR is explicitly stated here instead of a scaling factor.
Hold Steps: Intermediate stage, where the peak LR
is maintained for some number of steps. In this region,
the high peak LR allows the model to converge faster
if training is stable. However the high LR
may also cause instability during training.
Should usually be a significant fraction of training
steps (around 30-40% of the entire training steps).
Decay Steps: Final stage, where the LR rapidly decays
with some scaling rate (set by decay rate).
To attain Noam decay, use 0.5,
for Squeezeformer recommended decay, use 1.0.
The fast decay after prolonged high LR during
hold phase allows for rapid convergence.
References:
- [Squeezeformer:
An Efficient Transformer for Automatic Speech Recognition]
(https://arxiv.org/abs/2206.00888)
Args:
optimizer: Pytorch compatible Optimizer object.
warmup_steps: Number of training steps in warmup stage
warmup_ratio: Ratio of warmup steps to total steps
hold_steps: Number of training steps to
hold the learning rate after warm up
hold_ratio: Ratio of hold steps to total steps
max_steps: Total number of steps while training or `None` for
infinite training
decay_rate: Float value describing the polynomial decay
after the hold period. Default value
of 0.5 corresponds to Noam decay.
min_lr: Minimum learning rate.
"""
self.decay_rate = decay_rate
super().__init__(optimizer=optimizer,
max_steps=max_steps,
last_epoch=last_epoch,
min_lr=min_lr,
**kwargs)
def _get_lr(self, step):
if self.warmup_steps is None or self.warmup_steps == 0:
raise ValueError(
"Noam scheduler cannot be used without warmup steps")
if self.hold_steps > 0:
hold_steps = self.hold_steps - self.warmup_steps
else:
hold_steps = 0
new_lrs = [
_noam_hold_annealing(
initial_lr,
step=step,
warmup_steps=self.warmup_steps,
hold_steps=hold_steps,
decay_rate=self.decay_rate,
min_lr=self.min_lr,
) for initial_lr in self.base_lrs
]
return new_lrs
def set_step(self, step: int):
self.last_epoch = step
class ConstantLR(_LRScheduler):
"""The ConstantLR scheduler
This scheduler keeps a constant lr
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
):
# __init__() must be invoked before setting field
# because step() is also invoked in __init__()
super().__init__(optimizer)
def get_lr(self):
return self.base_lrs
def set_step(self, step: int):
self.last_epoch = step
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Horizon Inc. (authors: Xingchen Song)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import logging
import os
import torch
import json
import re
import datetime
import yaml
# import deepspeed
import torch.optim as optim
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
# from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
from cosyvoice.dataset.dataset import Dataset
from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
def init_distributed(args):
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
', rank {}, world_size {}'.format(rank, world_size))
if args.train_engine == 'torch_ddp':
torch.cuda.set_device(local_rank)
dist.init_process_group(args.dist_backend)
else:
deepspeed.init_distributed(dist_backend=args.dist_backend)
return world_size, local_rank, rank
def init_dataset_and_dataloader(args, configs):
train_dataset = Dataset(args.train_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=True, partition=True)
cv_dataset = Dataset(args.cv_data, data_pipeline=configs['data_pipeline'], mode='train', shuffle=False, partition=False)
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
train_data_loader = DataLoader(train_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
cv_data_loader = DataLoader(cv_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
def check_modify_and_save_config(args, configs):
if args.train_engine == "torch_ddp":
configs['train_conf']["dtype"] = 'fp32'
else:
with open(args.deepspeed_config, 'r') as fin:
ds_configs = json.load(fin)
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
configs['train_conf']["dtype"] = "fp16"
elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
configs['train_conf']["dtype"] = "bf16"
else:
configs['train_conf']["dtype"] = "fp32"
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
# if use deepspeed, override ddp config
configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
return configs
def wrap_cuda_model(args, model):
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
world_size = int(os.environ.get('WORLD_SIZE', 1))
if args.train_engine == "torch_ddp": # native pytorch ddp
assert (torch.cuda.is_available())
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
else:
if int(os.environ.get('RANK', 0)) == 0:
logging.info("Estimating model states memory needs (zero2)...")
estimate_zero2_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size)
return model
def init_optimizer_and_scheduler(args, configs, model):
if configs['train_conf']['optim'] == 'adam':
optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
elif configs['train_conf']['optim'] == 'adamw':
optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
else:
raise ValueError("unknown optimizer: " + configs['train_conf'])
if configs['train_conf']['scheduler'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
elif configs['train_conf']['scheduler'] == 'constantlr':
scheduler_type = ConstantLR
scheduler = ConstantLR(optimizer)
else:
raise ValueError("unknown scheduler: " + configs['train_conf'])
# use deepspeed optimizer for speedup
if args.train_engine == "deepspeed":
def scheduler(opt):
return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,
optimizer=None,
lr_scheduler=scheduler,
model_parameters=model.parameters())
return model, optimizer, scheduler
def init_summarywriter(args):
writer = None
if int(os.environ.get('RANK', 0)) == 0:
os.makedirs(args.model_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
return writer
def save_model(model, model_name, info_dict):
rank = int(os.environ.get('RANK', 0))
model_dir = info_dict["model_dir"]
save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
if info_dict["train_engine"] == "torch_ddp":
if rank == 0:
torch.save(model.module.state_dict(), save_model_path)
else:
with torch.no_grad():
model.save_checkpoint(save_dir=model_dir,
tag=model_name,
client_state=info_dict)
if rank == 0:
info_path = re.sub('.pt$', '.yaml', save_model_path)
info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
with open(info_path, 'w') as fout:
data = yaml.dump(info_dict)
fout.write(data)
logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
def cosyvoice_join(group_join, info_dict):
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
if info_dict["batch_idx"] != 0:
# we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
try:
dist.monitored_barrier(group=group_join,
timeout=group_join.options._timeout)
return False
except RuntimeError as e:
logging.info("Detected uneven workload distribution: {}\n".format(e) +
"Break current worker to manually join all workers, " +
"world_size {}, current rank {}, current local_rank {}\n".
format(world_size, rank, local_rank))
return True
else:
return False
def batch_forward(model, batch, info_dict):
device = int(os.environ.get('LOCAL_RANK', 0))
dtype = info_dict["dtype"]
if dtype == "fp16":
dtype = torch.float16
elif dtype == "bf16":
dtype = torch.bfloat16
else: # fp32
dtype = torch.float32
if info_dict['train_engine'] == 'torch_ddp':
autocast = nullcontext()
else:
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
with autocast:
info_dict['loss_dict'] = model(batch, device)
return info_dict
def batch_backward(model, info_dict):
if info_dict["train_engine"] == "deepspeed":
scaled_loss = model.backward(info_dict['loss_dict']['loss'])
else:
scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
scaled_loss.backward()
info_dict['loss_dict']['loss'] = scaled_loss
return info_dict
def update_parameter_and_lr(model, optimizer, scheduler, info_dict):
grad_norm = 0.0
if info_dict['train_engine'] == "deepspeed":
info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
model.step()
grad_norm = model.get_global_grad_norm()
elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
if torch.isfinite(grad_norm):
optimizer.step()
optimizer.zero_grad()
scheduler.step()
info_dict["lr"] = optimizer.param_groups[0]['lr']
info_dict["grad_norm"] = grad_norm
return info_dict
def log_per_step(writer, info_dict):
tag = info_dict["tag"]
epoch = info_dict.get('epoch', 0)
step = info_dict["step"]
batch_idx = info_dict["batch_idx"]
loss_dict = info_dict['loss_dict']
rank = int(os.environ.get('RANK', 0))
# only rank 0 write to tensorboard to avoid multi-process write
if writer is not None:
if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
(info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
for k in ['epoch', 'lr', 'grad_norm']:
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
for k, v in loss_dict.items():
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
# TRAIN & CV, Shell log (stdout)
if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
for name, value in loss_dict.items():
log_str += '{} {:.6f} '.format(name, value)
if tag == "TRAIN":
log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
info_dict["lr"], info_dict['grad_norm'])
log_str += ' rank {}'.format(rank)
logging.debug(log_str)
def log_per_save(writer, info_dict):
tag = info_dict["tag"]
epoch = info_dict["epoch"]
step = info_dict["step"]
loss_dict = info_dict["loss_dict"]
lr = info_dict['lr']
rank = int(os.environ.get('RANK', 0))
logging.info(
'Epoch {} Step {} CV info lr {} {} rank {}'.format(
epoch, step + 1, lr, rank, ' '.join(['{}_{}'.format(k, v) for k, v in loss_dict.items()])))
if writer is not None:
for k in ['epoch', 'lr']:
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
for k, v in loss_dict.items():
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
import torch
import torchaudio
import numpy as np
import re
from hyperpyyaml import load_hyperpyyaml
import uuid
from collections import defaultdict
def fade_in_out(fade_in_mel, fade_out_mel, window):
device = fade_in_mel.device
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
mel_overlap_len = int(window.shape[0] / 2)
fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel.to(device)
class AudioDecoder:
def __init__(self, config_path, flow_ckpt_path, hift_ckpt_path, device="cuda"):
self.device = device
with open(config_path, 'r') as f:
self.scratch_configs = load_hyperpyyaml(f)
# Load models
self.flow = self.scratch_configs['flow']
self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device))
self.hift = self.scratch_configs['hift']
self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
# Move models to the appropriate device
self.flow.to(self.device)
self.hift.to(self.device)
self.mel_overlap_dict = defaultdict(lambda: None)
self.hift_cache_dict = defaultdict(lambda: None)
self.token_min_hop_len = 2 * self.flow.input_frame_rate
self.token_max_hop_len = 4 * self.flow.input_frame_rate
self.token_overlap_len = 5
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
self.mel_window = np.hamming(2 * self.mel_overlap_len)
# hift cache
self.mel_cache_len = 1
self.source_cache_len = int(self.mel_cache_len * 256)
# speech fade in out
self.speech_window = np.hamming(2 * self.source_cache_len)
def token2wav(self, token, uuid, prompt_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_feat=torch.zeros(1, 0, 80), embedding=torch.zeros(1, 192), finalize=False, option_steps=10):
tts_mel = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(
self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(
self.device),
embedding=embedding.to(self.device),
option_steps=option_steps)
# mel overlap fade in out
if self.mel_overlap_dict[uuid] is not None:
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)
# _tts_mel=tts_mel.contiguous()
# keep overlap mel and hift cache
if finalize is False:
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
# if self.hift_cache_dict[uuid] is not None:
# tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
del self.hift_cache_dict[uuid]
del self.mel_overlap_dict[uuid]
# if uuid in self.hift_cache_dict.keys() and self.hift_cache_dict[uuid] is not None:
# tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech, tts_mel
def offline_inference(self, token):
this_uuid = str(uuid.uuid1())
tts_speech, tts_mel = self.token2wav(token, uuid=this_uuid, finalize=True)
return tts_speech.cpu()
def stream_inference(self, token):
token.to(self.device)
this_uuid = str(uuid.uuid1())
# Prepare other necessary input tensors
llm_embedding = torch.zeros(1, 192).to(self.device)
prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
tts_speechs = []
tts_mels = []
block_size = self.flow.encoder.block_size
prev_mel = None
for idx in range(0, token.size(1), block_size):
# if idx>block_size: break
tts_token = token[:, idx:idx + block_size]
print(tts_token.size())
if prev_mel is not None:
prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
flow_prompt_speech_token = token[:, :idx]
if idx + block_size >= token.size(-1):
is_finalize = True
else:
is_finalize = False
tts_speech, tts_mel = self.token2wav(tts_token, uuid=this_uuid,
prompt_token=flow_prompt_speech_token.to(self.device),
prompt_feat=prompt_speech_feat.to(self.device), finalize=is_finalize)
prev_mel = tts_mel
prev_speech = tts_speech
print(tts_mel.size())
tts_speechs.append(tts_speech)
tts_mels.append(tts_mel)
# Convert Mel spectrogram to audio using HiFi-GAN
tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
return tts_speech.cpu()
"""
A model worker with transformers libs executes the model.
Run BF16 inference with:
python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype bfloat16 --device cuda:0
Run Int4 inference with:
python model_server.py --host localhost --model-path THUDM/glm-4-voice-9b --port 10000 --dtype int4 --device cuda:0
"""
import argparse
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from transformers.generation.streamers import BaseStreamer
import torch
import uvicorn
from threading import Thread
from queue import Queue
class TokenStreamer(BaseStreamer):
def __init__(self, skip_prompt: bool = False, timeout=None):
self.skip_prompt = skip_prompt
# variables used in the streaming process
self.token_queue = Queue()
self.stop_signal = None
self.next_tokens_are_prompt = True
self.timeout = timeout
def put(self, value):
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
for token in value.tolist():
self.token_queue.put(token)
def end(self):
self.token_queue.put(self.stop_signal)
def __iter__(self):
return self
def __next__(self):
value = self.token_queue.get(timeout=self.timeout)
if value == self.stop_signal:
raise StopIteration()
else:
return value
class ModelWorker:
def __init__(self, model_path, dtype="bfloat16", device='cuda'):
self.device = device
self.bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
) if dtype == "int4" else None
self.glm_model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
quantization_config=self.bnb_config if self.bnb_config else None,
device_map={"": 0}
).eval()
self.glm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@torch.inference_mode()
def generate_stream(self, params):
tokenizer, model = self.glm_tokenizer, self.glm_model
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 256))
inputs = tokenizer([prompt], return_tensors="pt")
inputs = inputs.to(self.device)
streamer = TokenStreamer(skip_prompt=True)
thread = Thread(
target=model.generate,
kwargs=dict(
**inputs,
max_new_tokens=int(max_new_tokens),
temperature=float(temperature),
top_p=float(top_p),
streamer=streamer
)
)
thread.start()
for token_id in streamer:
yield (json.dumps({"token_id": token_id, "error_code": 0}) + "\n").encode()
def generate_stream_gate(self, params):
try:
for x in self.generate_stream(params):
yield x
except Exception as e:
print("Caught Unknown Error", e)
ret = {
"text": "Server Error",
"error_code": 1,
}
yield (json.dumps(ret) + "\n").encode()
app = FastAPI()
@app.post("/generate_stream")
async def generate_stream(request: Request):
params = await request.json()
generator = worker.generate_stream_gate(params)
return StreamingResponse(generator)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--dtype", type=str, default="bfloat16")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--port", type=int, default=10000)
parser.add_argument("--model-path", type=str, default="THUDM/glm-4-voice-9b")
args = parser.parse_args()
worker = ModelWorker(args.model_path, args.dtype, args.device)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
conformer==0.3.2
deepspeed==0.14.2; sys_platform == 'linux'
diffusers==0.27.2
fastapi==0.115.3
fastapi-cli==0.0.4
gdown==5.1.0
gradio==5.3.0
grpcio==1.57.0
grpcio-tools==1.57.0
huggingface_hub==0.25.2
hydra-core==1.3.2
HyperPyYAML==1.2.2
inflect==7.3.1
librosa==0.10.2
lightning==2.2.4
matplotlib==3.7.5
modelscope==1.15.0
networkx==3.1
numpy==1.24.4
omegaconf==2.3.0
onnxruntime-gpu==1.16.0; sys_platform == 'linux'
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
openai-whisper==20231117
protobuf==4.25
pydantic==2.7.0
rich==13.7.1
Requests==2.32.3
safetensors==0.4.5
soundfile==0.12.1
tensorboard==2.14.0
transformers==4.44.1
uvicorn==0.32.0
wget==3.2
WeTextProcessing==1.0.3
torch==2.3.0
torchaudio==2.3.0
from transformers import WhisperConfig
class WhisperVQConfig(WhisperConfig):
def __init__(self,
pooling_kernel_size=None,
pooling_type="max",
pooling_position=0,
quantize_vocab_size=None,
quantize_position=16,
quantize_commit_coefficient=0.25,
quantize_loss_scale=1.0,
quantize_ema_decay=None,
quantize_restart_interval=None,
quantize_encoder_only=False,
quantize_causal_encoder=False,
quantize_causal_block_size=None,
skip_language_detection=False,
encoder_causal_attention=False,
encoder_causal_convolution=False,
**kwargs):
self.pooling_kernel_size = pooling_kernel_size
self.pooling_type = pooling_type
self.pooling_position = pooling_position
self.quantize_vocab_size = quantize_vocab_size
self.quantize_position = quantize_position
self.quantize_commit_coefficient = quantize_commit_coefficient
self.quantize_loss_scale = quantize_loss_scale
self.quantize_ema_decay = quantize_ema_decay
self.quantize_restart_interval = quantize_restart_interval
self.quantize_encoder_only = quantize_encoder_only
self.quantize_causal_encoder = quantize_causal_encoder
self.quantize_causal_block_size = quantize_causal_block_size
self.skip_language_detection = skip_language_detection
self.encoder_causal_attention = encoder_causal_attention
self.encoder_causal_convolution = encoder_causal_convolution
super().__init__(**kwargs)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
import warnings
import zlib
from typing import Callable, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import EncoderDecoderCache
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import (
LogitsProcessorList,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
WhisperNoSpeechDetection,
WhisperTimeStampLogitsProcessor,
)
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import logging
from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
logger = logging.get_logger(__name__)
def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
"""
Applies a median filter of width `filter_width` along the last dimension of the input.
The `inputs` tensor is assumed to be 3- or 4-dimensional.
"""
if filter_width <= 0 or filter_width % 2 != 1:
raise ValueError("`filter_width` should be an odd number")
pad_width = filter_width // 2
if inputs.shape[-1] <= pad_width:
return inputs
# Pad the left and right edges.
inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
return result
def _dynamic_time_warping(matrix: np.ndarray):
"""
Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
token-level timestamps.
"""
output_length, input_length = matrix.shape
cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, input_length + 1):
for i in range(1, output_length + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = matrix[i - 1, j - 1] + c
trace[i, j] = t
# backtrace
i = trace.shape[0] - 1
j = trace.shape[1] - 1
trace[0, :] = 2
trace[:, 0] = 1
text_indices = []
time_indices = []
while i > 0 or j > 0:
text_indices.append(i - 1)
time_indices.append(j - 1)
if trace[i, j] == 0:
i -= 1
j -= 1
elif trace[i, j] == 1:
i -= 1
elif trace[i, j] == 2:
j -= 1
else:
raise RuntimeError(
f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
)
text_indices = np.array(text_indices)[::-1]
time_indices = np.array(time_indices)[::-1]
return text_indices, time_indices
def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
if logits_processor is not None:
logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
if logit_processor:
return getattr(logit_processor, attribute_name, None)
return None
def _pad_to_max_length(
current_segments,
pad_token_id,
device,
padding_side="right",
padding="longest",
bos_token_tensor=None,
cut_off_length=None,
):
max_total_length = 0
sequences = []
if padding_side not in ["right", "left"]:
raise ValueError(f"`padding_side` must be either 'right' or 'left', not {padding_side}")
if padding not in ["longest", "max_length"]:
raise ValueError(f"`padding` must be either 'longest' or 'max_length', not {padding}")
elif padding == "max_length" and cut_off_length is None:
raise ValueError("`cut_off_length` must be specified when `padding='max_length'`")
for current_segment_list in current_segments:
if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0:
sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1)
if cut_off_length is not None:
sequence = sequence[-cut_off_length:]
if bos_token_tensor is not None:
sequence = torch.cat([bos_token_tensor, sequence])
sequences.append(sequence)
max_total_length = max(max_total_length, len(sequences[-1]))
elif bos_token_tensor is not None:
sequences.append(bos_token_tensor)
else:
sequences.append(torch.tensor([], device=device))
max_total_length = cut_off_length + 1 if padding == "max_length" else max_total_length
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding_side == "right" else (pad_length, 0)
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)
sequences = torch.stack(sequences, dim=0)
return sequences
class WhisperGenerationMixin:
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
"""
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
cross-attentions will be cropped before applying DTW.
Returns:
tensor containing the timestamps in seconds for each predicted token
"""
# Create a list with `decoder_layers` elements, each a tensor of shape
# (batch size, attention_heads, output length, input length).
cross_attentions = []
for i in range(self.config.decoder_layers):
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
# Select specific cross-attention layers and heads. This is a tensor
# of shape (batch size, num selected, output length, input length).
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
weights = weights.permute([1, 0, 2, 3])
weight_length = None
if "beam_indices" in generate_outputs:
# If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
# since the beam search strategy chooses the most probable sequences at the end of the search.
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
weights = weights[:, :, :weight_length]
# If beam index is still -1, it means that the associated token id is EOS
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
beam_indices = generate_outputs.beam_indices[:, :weight_length]
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
# Select the cross attention from the right beam for each output sequences
weights = torch.stack(
[
torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
for i in range(beam_indices.shape[1])
],
dim=2,
)
# make sure timestamps are as long as weights
input_length = weight_length or cross_attentions[0].shape[2]
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
batch_size = timestamps.shape[0]
if num_frames is not None:
# two cases:
# 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
# 2. num_frames is different, compute the DTW matrix for each sample sequentially
# we're using np.unique because num_frames can be int/list/tuple
if isinstance(num_frames, int):
weights = weights[..., : num_frames // 2]
elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
weights = weights[..., : num_frames[0] // 2]
elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1:
weights = weights[..., : num_frames[0] // 2]
else:
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
num_frames = np.repeat(num_frames, repeat_time)
if num_frames is None or isinstance(num_frames, int):
# Normalize and smoothen the weights.
std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
mean = torch.mean(weights, dim=-2, keepdim=True)
weights = (weights - mean) / std
weights = _median_filter(weights, self.config.median_filter_width)
# Average the different cross-attention heads.
weights = weights.mean(dim=1)
# Perform dynamic time warping on each element of the batch.
for batch_idx in range(batch_size):
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
# Normalize and smoothen the weights.
std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
mean = torch.mean(matrix, dim=-2, keepdim=True)
matrix = (matrix - mean) / std
matrix = _median_filter(matrix, self.config.median_filter_width)
# Average the different cross-attention heads.
matrix = matrix.mean(dim=0)
else:
matrix = weights[batch_idx]
text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] * time_precision
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
return timestamps
def generate(
self,
input_features: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: bool = False,
return_timestamps: Optional[bool] = None,
task: Optional[str] = None,
language: Optional[Union[str, List[str]]] = None,
is_multilingual: Optional[bool] = None,
prompt_ids: Optional[torch.Tensor] = None,
prompt_condition_type: Optional[str] = None, # first-segment, all-segments
condition_on_prev_tokens: Optional[bool] = None,
temperature: Optional[Union[float, Tuple[float, ...]]] = None,
compression_ratio_threshold: Optional[float] = None,
logprob_threshold: Optional[float] = None,
no_speech_threshold: Optional[float] = None,
num_segment_frames: Optional[int] = None,
attention_mask: Optional[torch.Tensor] = None,
time_precision: float = 0.02,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
return_dict_in_generate: Optional[bool] = None,
**kwargs,
):
"""
Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
task (`str`, *optional*):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly.
language (`str` or list of `str`, *optional*):
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
batched generation, a list of language tokens can be passed. You can find all the possible language
tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
prompt_ids (`torch.Tensor`, *optional*):
Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
prompt_condition_type (`str`, *optional*):
Only relevant for long-form transcription. Condition type of `prompt_ids`. 'first-segment' means only the first segment is conditioned on `prompt_ids`. 'all-segments' means each segment is conditioned on `prompt_ids`. Make sure to enable `condition_on_prev_tokens` for 'all-segments'.
Defaults to 'first-segment'. For short-term transcription only 'first-segment' is possible.
condition_on_prev_tokens (`bool`, *optional*):
Only relevant for long-form transcription. Whether to condition each segment on the previous segment.
As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
performance.
temperature (`float` or list of `float`, *optional*):
The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates
generation using sampling. For long-form transcription, temperature fallback can be activated by passing
a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8, 1.0). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
performance.
compression_ratio_threshold (`float`, *optional*):
Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of
a segment is higher than `compression_ratio_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
repeated using a higher temperature. The intuition behind this feature is that segments with very high compression rates
suffer from a lot of repetition. The unwanted repetition can be reduced by injecting more randomness by increasing the temperature. If `compression_ratio_threshold` is defined
make sure that `temperature` is a list of values. A common value for `compression_ratio_threshold` is 1.35.
As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
performance.
logprob_threshold (`float`, *optional*):
Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of
a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is
repeated using a higher temperature. The intuition behind this feature is that segments of low log-probability
can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined
make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0.
As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
performance.
no_speech_threshold (`float`, *optional*):
Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold`
is used to determine whether a segment contains only silence. In this case, the transcription for this segment
is skipped.
As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve
performance.
num_segment_frames (`int`, *optional*):
The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride
times the maximum input length.
attention_mask (`torch.Tensor`, *optional*):
`attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
time_precision (`int`, *optional*, defaults to 0.02):
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
for 20 ms.
return_token_timestamps (`bool`, *optional*):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
words.
return_segments (`bool`, *optional*, defaults to `False`):
Whether to additionally return a list of all segments. Note that this option can only be enabled
when doing long-form transcription.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
`return_segments` is set True. In this case the generation outputs of each segment is added to each
segment.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
else only the generated output sequence ids are returned.
Example:
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset, Audio
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> model.cuda() # doctest: +IGNORE_RESULT
>>> # load audios > 30 seconds
>>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
>>> # resample to 16kHz
>>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
>>> # take first 8 audios and retrieve array
>>> audio = ds[:8]["audio"]
>>> audio = [x["array"] for x in audio]
>>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
>>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
>>> inputs = inputs.to("cuda", torch.float32)
>>> # transcribe audio to ids
>>> generated_ids = model.generate(**inputs)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> transcription[0]
" Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."
```
- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
"""
# 0. deprecate old inputs
if "inputs" in kwargs:
input_features = kwargs.pop("inputs")
warnings.warn(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)
# 1. prepare generation config
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
# 2. set global generate variables
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
num_segment_frames = input_stride * self.config.max_source_positions
batch_size, total_input_frames = self._retrieve_total_input_frames(
input_features=input_features, input_stride=input_stride, kwargs=kwargs
)
is_shortform = total_input_frames <= num_segment_frames
# 3. Make sure generation config is correctly set
# Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
return_dict_in_generate = self._set_return_outputs(
return_dict_in_generate=return_dict_in_generate,
return_token_timestamps=return_token_timestamps,
logprob_threshold=logprob_threshold,
generation_config=generation_config,
)
timestamp_begin = self._set_return_timestamps(
return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config
)
self._set_language_and_task(
language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
)
self._set_num_frames(
return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
)
self._set_thresholds_and_condition(
generation_config=generation_config,
logprob_threshold=logprob_threshold,
compression_ratio_threshold=compression_ratio_threshold,
no_speech_threshold=no_speech_threshold,
condition_on_prev_tokens=condition_on_prev_tokens,
)
self._set_prompt_condition_type(
generation_config=generation_config,
prompt_condition_type=prompt_condition_type,
)
kwargs["attention_mask"] = attention_mask
# pass self.config for backward compatibility
init_tokens = self._retrieve_init_tokens(
input_features,
batch_size=batch_size,
generation_config=generation_config,
config=self.config,
num_segment_frames=num_segment_frames,
kwargs=kwargs,
)
# passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
# where the input ids are handled explicitly by the generate method
self._check_decoder_input_ids(kwargs=kwargs)
# 3. Retrieve logits processors
device = kwargs["encoder_outputs"][0].device if "encoder_outputs" in kwargs else input_features.device
begin_index = init_tokens.shape[1]
logits_processor = self._retrieve_logit_processors(
generation_config=generation_config,
logits_processor=logits_processor,
begin_index=begin_index, # begin index is index of first generated decoder token
num_beams=kwargs.get("num_beams", 1),
device=device,
)
# 4 Set and retrieve global generation variables
self._set_condition_on_prev_tokens(
condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config
)
temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature
temperature = temperatures[0]
max_frames, seek = self._retrieve_max_frames_and_seek(
batch_size=batch_size,
attention_mask=attention_mask,
total_input_frames=total_input_frames,
is_shortform=is_shortform,
)
# 5 Prepare running variables, list for generation
num_return_sequences = generation_config.num_return_sequences
(
batch_idx_map,
cur_bsz,
input_features,
seek,
max_frames,
init_tokens,
do_condition_on_prev_tokens,
) = self._expand_variables_for_generation(
input_features=input_features,
seek=seek,
max_frames=max_frames,
init_tokens=init_tokens,
batch_size=batch_size,
condition_on_prev_tokens=condition_on_prev_tokens,
generation_config=generation_config,
)
current_segments = self._prepare_segments(
prompt_ids=prompt_ids,
batch_size=cur_bsz,
generation_config=generation_config,
)
# 6 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
# 6.1 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
# in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
# to know which original audio is being decoded
# Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch(
input_features=input_features,
seek=seek,
max_frames=max_frames,
cur_bsz=cur_bsz,
batch_idx_map=batch_idx_map,
)
time_offset = seek * time_precision / input_stride
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
# 6.2 cut out next 30s segment from input features
segment_input = self._get_input_segment(
input_features=input_features,
seek=seek,
seek_num_frames=seek_num_frames,
num_segment_frames=num_segment_frames,
cur_bsz=cur_bsz,
batch_idx_map=batch_idx_map,
)
# 6.3 prepare decoder input ids
suppress_tokens = _get_attr_from_logit_processors(
logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
)
decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
cur_bsz=cur_bsz,
init_tokens=init_tokens,
current_segments=current_segments,
batch_idx_map=batch_idx_map,
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
prompt_ids=prompt_ids,
generation_config=generation_config,
config=self.config,
device=init_tokens.device,
suppress_tokens=suppress_tokens,
kwargs=kwargs,
)
# 6.4 set max new tokens or max length
self._set_max_new_tokens_and_length(
config=self.config,
decoder_input_ids=decoder_input_ids,
generation_config=generation_config,
)
# 6.5 Set current `begin_index` for all logit processors
if logits_processor is not None:
for proc in logits_processor:
if hasattr(proc, "set_begin_index"):
proc.set_begin_index(decoder_input_ids.shape[-1])
# 6.6 Run generate with fallback
(
seek_sequences,
seek_outputs,
should_skip,
do_condition_on_prev_tokens,
model_output_type,
) = self.generate_with_fallback(
segment_input=segment_input,
decoder_input_ids=decoder_input_ids,
cur_bsz=cur_bsz,
batch_idx_map=batch_idx_map,
seek=seek,
num_segment_frames=num_segment_frames,
max_frames=max_frames,
temperatures=temperatures,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
return_token_timestamps=return_token_timestamps,
do_condition_on_prev_tokens=do_condition_on_prev_tokens,
is_shortform=is_shortform,
batch_size=batch_size,
kwargs=kwargs,
)
# 6.7 In every generated sequence, split by timestamp tokens and extract segments
for i, seek_sequence in enumerate(seek_sequences):
prev_i = batch_idx_map[i]
if should_skip[i]:
seek[prev_i] += seek_num_frames[prev_i]
continue
segments, segment_offset = self._retrieve_segment(
seek_sequence=seek_sequence,
seek_outputs=seek_outputs,
time_offset=time_offset,
timestamp_begin=timestamp_begin,
seek_num_frames=seek_num_frames,
time_precision=time_precision,
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
return_token_timestamps=return_token_timestamps,
)
current_segments[prev_i] += segments
if is_shortform:
seek[prev_i] += max_frames[i]
else:
seek[prev_i] += segment_offset
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
final_segments = (
[x[1:] for x in current_segments]
if (prompt_ids is not None and generation_config.prompt_condition_type == "first-segment")
else current_segments
)
sequences = _pad_to_max_length(
final_segments, generation_config.pad_token_id, device=self.device, padding_side="right"
)
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": final_segments}
if is_shortform:
# add eos token:
if generation_config.max_new_tokens is None and generation_config.max_length is None:
eos_tokens = torch.full((sequences.shape[0], 1), generation_config.eos_token_id)
sequences = torch.cat([sequences, eos_tokens], dim=-1)
if return_token_timestamps:
outputs = {}
outputs["sequences"] = sequences
outputs["token_timestamps"] = torch.stack([d["token_timestamps"] for d in seek_outputs], dim=0)
else:
outputs = sequences
if return_dict_in_generate and generation_config.return_dict_in_generate:
dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
if num_return_sequences > 1:
if hasattr(dict_outputs, "encoder_attentions") and dict_outputs.encoder_attentions is not None:
dict_outputs.encoder_attentions = tuple(
dict_outputs.encoder_attentions[i][::num_return_sequences]
for i in range(len(dict_outputs.encoder_attentions))
)
if (
hasattr(dict_outputs, "encoder_hidden_states")
and dict_outputs.encoder_hidden_states is not None
):
dict_outputs.encoder_hidden_states = tuple(
dict_outputs.encoder_hidden_states[i][::num_return_sequences]
for i in range(len(dict_outputs.encoder_hidden_states))
)
if return_token_timestamps:
dict_outputs["token_timestamps"] = outputs["token_timestamps"]
return dict_outputs
return outputs
return sequences
def generate_with_fallback(
self,
segment_input,
decoder_input_ids,
cur_bsz,
batch_idx_map,
seek,
num_segment_frames,
max_frames,
temperatures,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_token_timestamps,
do_condition_on_prev_tokens,
is_shortform,
batch_size,
kwargs,
):
kwargs = copy.copy(kwargs)
# 6.6 Batch generate current chunk
seek_sequence_list = [None for _ in range(cur_bsz)]
seek_outputs_list = [None for _ in range(cur_bsz)]
needs_fallback = [False for _ in range(cur_bsz)]
should_skip = [False for _ in range(cur_bsz)]
fallback_index_map = list(range(cur_bsz))
if generation_config.no_speech_threshold is not None:
self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs)
for fallback_idx, temperature in enumerate(temperatures):
generation_config.do_sample = temperature is not None and temperature > 0.0
generation_config.temperature = temperature if generation_config.do_sample else 1.0
if generation_config.do_sample:
generation_config.num_beams = 1
generate_kwargs = copy.copy(kwargs)
for key in ["do_sample", "temperature", "num_beams"]:
if key in generate_kwargs:
del generate_kwargs[key]
cur_bsz = decoder_input_ids.shape[0]
if generation_config.cache_implementation == "static" and cur_bsz < batch_size:
segment_input = F.pad(segment_input, (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0)
decoder_input_ids = F.pad(
decoder_input_ids, (0, 0, 0, batch_size - cur_bsz), value=generation_config.pad_token_id
)
if generate_kwargs.get("decoder_attention_mask") is not None:
generate_kwargs["decoder_attention_mask"] = F.pad(
generate_kwargs["decoder_attention_mask"], (0, 0, 0, batch_size - cur_bsz), value=True
)
if generate_kwargs.get("encoder_outputs") is not None:
generate_kwargs["encoder_outputs"] = F.pad(
generate_kwargs["encoder_outputs"], (0, 0, 0, 0, 0, batch_size - cur_bsz), value=0
)
seek_outputs = super().generate(
segment_input,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
decoder_input_ids=decoder_input_ids,
**generate_kwargs,
)
model_output_type = type(seek_outputs)
# post-process sequence tokens and outputs to be in list form
seek_sequences, seek_outputs = self._postprocess_outputs(
seek_outputs=seek_outputs,
decoder_input_ids=decoder_input_ids,
return_token_timestamps=return_token_timestamps,
generation_config=generation_config,
is_shortform=is_shortform,
)
if cur_bsz < batch_size:
seek_sequences = seek_sequences[:cur_bsz]
seek_outputs = seek_outputs[:cur_bsz]
# 6.7 Extract cut sequences from every sequence and check if fallback should be applied
# Loop over each decoded audio individually as each decoding can be of a different length
new_fallback_index_map = []
new_segment_input = []
new_decoder_input_ids = []
new_decoder_attention_mask = []
for i, seek_sequence in enumerate(seek_sequences):
# make sure we cut a predicted EOS token if we are not finished with the generation yet
prev_i = batch_idx_map[fallback_index_map[i]]
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
# remove eos token id
if is_not_final and seek_sequence[-1] == generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
if return_token_timestamps and not is_shortform:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-1]
# remove all padding tokens
if seek_sequence[-1] == generation_config.pad_token_id:
num_paddings = (seek_sequence == generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings]
if return_token_timestamps and not is_shortform:
seek_outputs[i]["token_timestamps"] = seek_outputs[i]["token_timestamps"][:-num_paddings]
# check which sequences in batch need fallback & which should be skipped
needs_fallback[i], should_skip[i] = self._need_fallback(
seek_sequence,
seek_outputs,
i,
logits_processor,
generation_config,
self.config.vocab_size,
temperature,
)
seek_sequence_list[fallback_index_map[i]] = seek_sequence
seek_outputs_list[fallback_index_map[i]] = seek_outputs[i]
is_low_temperature = temperature is None or temperature < 0.5
do_condition_on_prev_tokens[fallback_index_map[i]] = (
generation_config.condition_on_prev_tokens and is_low_temperature
)
if needs_fallback[i]:
new_fallback_index_map.append(fallback_index_map[i])
new_segment_input.append(segment_input[i])
new_decoder_input_ids.append(decoder_input_ids[i])
if "decoder_attention_mask" in kwargs:
new_decoder_attention_mask.append(kwargs["decoder_attention_mask"][i])
fallback_index_map = new_fallback_index_map
# if no sequence needs to be run with temperature fallback, we're finished
if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1:
seek_sequences = seek_sequence_list
seek_outputs = seek_outputs_list
break
# if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors
decoder_input_ids = torch.stack(new_decoder_input_ids)
segment_input = torch.stack(new_segment_input)
if "decoder_attention_mask" in kwargs:
kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask)
return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens, model_output_type
@staticmethod
def _prepare_segments(prompt_ids, batch_size, generation_config):
if prompt_ids is not None and generation_config.prompt_condition_type == "first-segment":
prev_sot_token_id = getattr(generation_config, "prev_sot_token_id", None)
prompt_ids = prompt_ids[1:] if prompt_ids[0] == prev_sot_token_id else prompt_ids
current_segments = [[{"tokens": prompt_ids}] for _ in range(batch_size)]
else:
current_segments = [[] for _ in range(batch_size)]
return current_segments
def _postprocess_outputs(
self, seek_outputs, decoder_input_ids, return_token_timestamps, generation_config, is_shortform
):
# remove all previously passed decoder input ids
start_idx = decoder_input_ids.shape[-1] if not is_shortform else torch.tensor(0)
if isinstance(seek_outputs, torch.Tensor):
seek_outputs = seek_outputs[:, start_idx:]
return seek_outputs, seek_outputs
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
def split_by_batch_index(values, key, batch_idx, is_shortform):
if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
return [v[batch_idx].cpu() for v in values]
if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
elif key == "past_key_values":
if not is_shortform:
# we don't save `past_key_values` as this is too costly for longform
return None
elif isinstance(values, EncoderDecoderCache):
all_past_key_values = []
for layer_idx in range(self.config.decoder_layers):
layer_past_key_values = []
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
for v in [cache_cls.key_cache, cache_cls.value_cache]:
layer_past_key_values.append(v[layer_idx][batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)
else:
all_past_key_values = []
for v in range(len(values)):
layer_past_key_values = []
for w in values[v]:
layer_past_key_values.append(w[batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return tuple(all_past_key_values)
return values[batch_idx].cpu()
sequence_tokens = seek_outputs["sequences"]
seek_outputs = [
{k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
for i in range(sequence_tokens.shape[0])
]
return sequence_tokens, seek_outputs
def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
# Stack back seek_outputs tensors after splitting them with the split_by_batch_index method
outputs = {}
for key in seek_outputs[0].keys():
if key == "sequences":
outputs[key] = torch.stack([v[key] for v in seek_outputs], dim=0).to(device)
if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
outputs[key] = tuple(
torch.stack([v[key][i] for v in seek_outputs]).to(device) for i in range(len(seek_outputs[0][key]))
)
if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
outputs[key] = tuple(
tuple(
torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
for j in range(len(seek_outputs[0][key][0]))
)
for i in range(len(seek_outputs[0][key]))
)
if key == "past_key_values":
past_key_value_type = kwargs.get("past_key_values")
if seek_outputs[0][key] is not None:
outputs[key] = tuple(
tuple(
torch.stack([v[key][i][j] for v in seek_outputs]).squeeze(1).to(device)
for j in range(len(seek_outputs[0][key][0]))
)
for i in range(len(seek_outputs[0][key]))
)
if past_key_value_type is not None and isinstance(past_key_value_type, EncoderDecoderCache):
outputs[key] = past_key_value_type.from_legacy_cache(outputs[key])
else:
outputs[key] = None
return model_output_type(**outputs)
def _need_fallback(
self,
seek_sequence,
seek_outputs,
index,
logits_processor,
generation_config,
vocab_size,
temperature,
):
needs_fallback = False
should_skip = False
if generation_config.compression_ratio_threshold is not None:
compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size)
if compression_ratio > generation_config.compression_ratio_threshold:
needs_fallback = True
if generation_config.logprob_threshold is not None:
if hasattr(seek_outputs[0], "sequences_scores"):
logprobs = [s["sequences_scores"] for s in seek_outputs][index]
else:
scores = seek_outputs[index]["scores"]
logprobs = self._retrieve_avg_logprobs(
scores, seek_sequence, generation_config.eos_token_id, temperature
)
if logprobs < generation_config.logprob_threshold:
needs_fallback = True
if generation_config.no_speech_threshold is not None:
no_speech_prob = _get_attr_from_logit_processors(
logits_processor, WhisperNoSpeechDetection, "no_speech_prob"
)
if (
logprobs < generation_config.logprob_threshold
and no_speech_prob[index] > generation_config.no_speech_threshold
):
needs_fallback = False
should_skip = True
return needs_fallback, should_skip
def _expand_variables_for_generation(
self, input_features, seek, max_frames, init_tokens, batch_size, condition_on_prev_tokens, generation_config
):
if generation_config.num_return_sequences is not None and generation_config.num_return_sequences > 1:
batch_idx_map = list(range(batch_size * generation_config.num_return_sequences))
cur_bsz = len(batch_idx_map)
do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(len(batch_idx_map))]
input_features = input_features.repeat_interleave(generation_config.num_return_sequences, dim=0)
seek = seek.repeat_interleave(generation_config.num_return_sequences, dim=0)
max_frames = max_frames.repeat_interleave(generation_config.num_return_sequences, dim=0)
init_tokens = init_tokens.repeat_interleave(generation_config.num_return_sequences, dim=0)
generation_config.num_return_sequences = 1
else:
cur_bsz = batch_size
batch_idx_map = list(range(cur_bsz))
do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(cur_bsz)]
return (
batch_idx_map,
cur_bsz,
input_features,
seek,
max_frames,
init_tokens,
do_condition_on_prev_tokens,
)
@staticmethod
def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
@staticmethod
def _retrieve_total_input_frames(input_features, input_stride, kwargs):
if input_features is not None:
return input_features.shape[0], input_features.shape[-1]
if "encoder_outputs" in kwargs:
encoder_outputs_shape = (
kwargs["encoder_outputs"][0].shape
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
else kwargs["encoder_outputs"].shape
)
return encoder_outputs_shape[0], encoder_outputs_shape[1] * input_stride
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
@staticmethod
def _maybe_warn_unused_inputs(
condition_on_prev_tokens,
temperature,
compression_ratio_threshold,
logprob_threshold,
no_speech_threshold,
total_input_frames,
):
warning_prefix = (
f"Audio input consists of only {total_input_frames}. "
"Short-form transcription is activated."
"{}, but will be ignored."
)
if condition_on_prev_tokens is not None:
logger.warning(warning_prefix.format(f"condition_on_prev_tokens is set to {condition_on_prev_tokens}"))
if compression_ratio_threshold is not None:
logger.warning(
warning_prefix.format(f"compression_ratio_threshold is set to {compression_ratio_threshold}")
)
if logprob_threshold is not None:
logger.warning(warning_prefix.format(f"logprob_threshold is set to {logprob_threshold}"))
if no_speech_threshold is not None:
logger.warning(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}"))
# when passing temperature as a list it cannot just be ignored => throw error in this case
if isinstance(temperature, (list, tuple)):
raise ValueError(
f"Audio input consists of only {total_input_frames}. Short-form transcription is activated."
f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation."
)
@staticmethod
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
if return_dict_in_generate is None:
return_dict_in_generate = generation_config.return_dict_in_generate
else:
generation_config.return_dict_in_generate = return_dict_in_generate
generation_config.return_token_timestamps = return_token_timestamps
if return_token_timestamps:
generation_config.return_dict_in_generate = True
generation_config.output_attentions = True
generation_config.output_scores = True
if logprob_threshold is not None:
generation_config.return_dict_in_generate = True
generation_config.output_scores = True
return return_dict_in_generate
def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
if return_timestamps is None and hasattr(generation_config, "return_timestamps"):
return_timestamps = generation_config.return_timestamps
if not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)
logger.info("Setting `return_timestamps=True` for long-form generation.")
return_timestamps = True
if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = return_timestamps
if hasattr(generation_config, "no_timestamps_token_id"):
timestamp_begin = generation_config.no_timestamps_token_id + 1
else:
# BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps
# We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop
timestamp_begin = self.config.vocab_size + 1
return timestamp_begin
@staticmethod
def _set_language_and_task(language, task, is_multilingual, generation_config):
if is_multilingual is not None:
if not hasattr(generation_config, "is_multilingual"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
"to `generate`. Please update the generation config as per the instructions "
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.is_multilingual = is_multilingual
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
if task is not None or language is not None:
raise ValueError(
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
)
if language is not None:
if not hasattr(generation_config, "lang_to_id"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `language` argument "
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.language = language
if task is not None:
if not hasattr(generation_config, "task_to_id"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `task` argument "
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.task = task
def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
"""short function to replace num with a itr in lst"""
found = any(i in lst for i in itr)
if found:
lst = [num if i in itr else i for i in lst]
else:
lst.append(num)
return lst
def language_to_id(language: str) -> int:
language = language.lower()
if language in generation_config.lang_to_id.keys():
language_token = language
elif language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
elif language in TO_LANGUAGE_CODE.values():
language_token = f"<|{language}|>"
else:
is_language_code = len(language) == 2
raise ValueError(
f"Unsupported language: {language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
if language_token not in generation_config.lang_to_id:
raise ValueError(
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
"(You should just add it to the generation config)"
)
return generation_config.lang_to_id[language_token]
task = getattr(generation_config, "task", None)
language = getattr(generation_config, "language", None)
forced_decoder_ids = generation_config.forced_decoder_ids
if forced_decoder_ids is not None:
if language is None and task is None and forced_decoder_ids[0][1] is None:
logger.warning_once(
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
)
elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
forced_decoder_ids = config.forced_decoder_ids
if forced_decoder_ids is not None and task is not None:
logger.warning_once(
f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
)
forced_decoder_ids = None
elif forced_decoder_ids is not None and language is not None:
logger.warning_once(
f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
)
forced_decoder_ids = None
init_tokens = [generation_config.decoder_start_token_id]
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
i = 1
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
init_tokens += [forced_decoder_ids[0][1]]
forced_decoder_ids = forced_decoder_ids[1:]
i += 1
if len(forced_decoder_ids) > 0:
raise ValueError(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
)
# from v4.39 the forced decoder ids are always None in favour of decoder input ids
generation_config.forced_decoder_ids = None
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
# Make sure language is a list of strings of the correct length
if isinstance(language, (list, tuple)):
if any(l is None for l in language):
raise TypeError(
"Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
)
if len(language) != batch_size:
raise ValueError(
"When passing a list of languages, the length of the list must match the batch size. "
f"Expected length of {batch_size}, but got {len(language)} languages."
)
languages = language
elif language is None:
# Language will be detected for each item in batch
languages = [None] * batch_size
else:
languages = [language] # Use a length-1 list now, broadcast later
# Separate init_tokens for each language
init_tokens = [copy.copy(init_tokens) for _ in languages]
# Update init_tokens with languages
lang_ids = None
if language is not None:
lang_ids = [language_to_id(l) for l in languages]
elif hasattr(generation_config, "lang_to_id") and is_lang_id_undefined:
# language is not defined or intentially set to `None` to trigger language detection
lang_ids = self.detect_language(
input_features=input_features,
encoder_outputs=kwargs.get("encoder_outputs", None),
attention_mask=kwargs.get("attention_mask", None),
generation_config=generation_config,
num_segment_frames=num_segment_frames,
).tolist()
if lang_ids is not None:
# append or replace lang_ids to init_tokens
for i in range(len(init_tokens)):
if len(init_tokens[i]) > 1:
init_tokens[i][1] = lang_ids[i]
else:
init_tokens[i].append(lang_ids[i])
del languages
# Update init_tokens with task
for i in range(len(init_tokens)):
if task is not None:
if task in TASK_IDS:
init_tokens[i].append(generation_config.task_to_id[generation_config.task])
task_id = generation_config.task_to_id[generation_config.task]
# if task is defined it'll overwrite task ids that might have already been defined via the generation_config
replace_or_add(init_tokens[i], task_id, generation_config.task_to_id.values())
else:
raise ValueError(f"The `{task}`task is not supported. The task should be one of `{TASK_IDS}`")
elif language is not None and hasattr(generation_config, "task_to_id"):
# if language is defined, but no task id is in `init_tokens`, default to transcribe
if not any(ti in init_tokens[i] for ti in generation_config.task_to_id.values()):
init_tokens[i].append(generation_config.task_to_id["transcribe"])
if (
not generation_config.return_timestamps
and hasattr(generation_config, "no_timestamps_token_id")
and init_tokens[i][-1] != generation_config.no_timestamps_token_id
):
init_tokens[i].append(generation_config.no_timestamps_token_id)
elif (
generation_config.return_timestamps and init_tokens[i][-1] == generation_config.no_timestamps_token_id
):
logger.info(
"<|notimestamps|> prompt token is removed from generation_config since `return_timestamps` is set to `'True'`."
)
init_tokens[i] = init_tokens[i][:-1]
# let's make sure we don't pass `None` tokens as prompt tokens
init_tokens[i] = [t for t in init_tokens[i] if t is not None]
return torch.as_tensor(init_tokens, dtype=torch.long, device=self.device).expand(batch_size, -1)
def detect_language(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,
generation_config: Optional[GenerationConfig] = None,
num_segment_frames: int = 3000,
) -> torch.Tensor:
"""
Detects language from log-mel input features or encoder_outputs
Parameters:
input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):
Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
num_segment_frames (`int`, *optional*, defaults to 3000):
The number of log-mel frames the model expects
Return:
A `torch.LongTensor` representing the detected language ids.
"""
if input_features is None and encoder_outputs is None:
raise ValueError("You have to specify either `input_features` or `encoder_outputs`")
elif input_features is not None and encoder_outputs is not None:
raise ValueError("Make sure to specificy only one of `input_features` or `encoder_outputs` - not both!")
elif input_features is not None:
inputs = {"input_features": input_features[:, :, :num_segment_frames]}
batch_size = input_features.shape[0]
elif encoder_outputs is not None:
inputs = {"encoder_outputs": encoder_outputs}
batch_size = (
encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0]
)
if attention_mask is not None:
inputs["attention_mask"] = attention_mask
generation_config = generation_config or self.generation_config
decoder_input_ids = (
torch.ones((batch_size, 1), device=self.device, dtype=torch.long)
* generation_config.decoder_start_token_id
)
with torch.no_grad():
logits = self(**inputs, decoder_input_ids=decoder_input_ids).logits[:, -1]
non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)
non_lang_mask[list(generation_config.lang_to_id.values())] = False
logits[:, non_lang_mask] = -np.inf
lang_ids = logits.argmax(-1)
return lang_ids
@staticmethod
def _check_decoder_input_ids(kwargs):
decoder_input_ids = kwargs.get("decoder_input_ids", None)
assistant_model = kwargs.get("assistant_model", None)
if decoder_input_ids is not None and assistant_model is not None:
raise ValueError(
"Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
)
@staticmethod
def _set_num_frames(return_token_timestamps, generation_config, kwargs):
if return_token_timestamps:
if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
if not hasattr(generation_config, "alignment_heads"):
raise ValueError(
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)
generation_config.num_frames = kwargs.pop("num_frames", None)
@staticmethod
def _set_thresholds_and_condition(
generation_config,
logprob_threshold,
compression_ratio_threshold,
no_speech_threshold,
condition_on_prev_tokens,
):
generation_config.logprob_threshold = (
logprob_threshold
if logprob_threshold is not None
else getattr(generation_config, "logprob_threshold", None)
)
generation_config.compression_ratio_threshold = (
compression_ratio_threshold
if compression_ratio_threshold is not None
else getattr(generation_config, "compression_ratio_threshold", None)
)
generation_config.no_speech_threshold = (
no_speech_threshold
if no_speech_threshold is not None
else getattr(generation_config, "no_speech_threshold", None)
)
generation_config.condition_on_prev_tokens = (
condition_on_prev_tokens
if condition_on_prev_tokens is not None
else getattr(generation_config, "condition_on_prev_tokens", None)
)
@staticmethod
def _set_prompt_condition_type(generation_config, prompt_condition_type):
allowed_cond_types = ["first-segment", "all-segments"]
# default to "first-segment"
prompt_condition_type = prompt_condition_type or allowed_cond_types[0]
if prompt_condition_type not in allowed_cond_types:
raise ValueError(
f"`prompt_condition_type={prompt_condition_type} does not exist. Make sure to set `prompt_condition_type` to one of {', '.join(allowed_cond_types)}"
)
if generation_config.condition_on_prev_tokens is not True and prompt_condition_type == "all-segments":
raise ValueError(
"Make sure to set `condition_on_prev_tokens=True` when setting `prompt_condition_type='all-segments'`."
)
generation_config.prompt_condition_type = prompt_condition_type
@staticmethod
def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config):
condition_on_prev_tokens = (
condition_on_prev_tokens
if condition_on_prev_tokens is not None
else getattr(generation_config, "condition_on_prev_tokens", False)
)
generation_config.condition_on_prev_tokens = condition_on_prev_tokens
@staticmethod
def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames, is_shortform):
if batch_size > 1 and not is_shortform and attention_mask is None:
raise ValueError(
"When doing batched long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
)
elif batch_size > 1 and not is_shortform:
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
seek = torch.zeros((batch_size,), dtype=torch.long)
else:
max_frames = torch.ones((batch_size,), dtype=torch.long) * total_input_frames
seek = torch.zeros((batch_size,), dtype=torch.long)
return max_frames, seek
def _retrieve_logit_processors(self, generation_config, logits_processor, begin_index, num_beams, device):
if generation_config.return_timestamps is True:
timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index)
logits_processor = (
[timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor
)
if generation_config.suppress_tokens is not None:
suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens, device=device)
logits_processor = (
[suppress_tokens_processor]
if logits_processor is None
else [suppress_tokens_processor] + logits_processor
)
generation_config.suppress_tokens = None
if generation_config.begin_suppress_tokens is not None:
begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor(
generation_config.begin_suppress_tokens, begin_index=begin_index, device=device
)
logits_processor = (
[begin_suppress_processor]
if logits_processor is None
else [begin_suppress_processor] + logits_processor
)
generation_config.begin_suppress_tokens = None
if generation_config.no_speech_threshold is not None:
no_speech_detector = WhisperNoSpeechDetection(
no_speech_token=generation_config.no_timestamps_token_id - 1,
begin_index=begin_index,
scores_is_logprobs=num_beams > 1,
)
logits_processor = (
[no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor
)
no_speech_detector.set_model(self)
return logits_processor
@staticmethod
def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map):
prev_bsz = cur_bsz
new_batch_idx_map = []
for i in range(prev_bsz):
prev_i = batch_idx_map[i]
if seek[prev_i] >= max_frames[prev_i]:
cut_index = i + (cur_bsz - prev_bsz)
cur_bsz -= 1
input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
else:
# cut out index that goes away
new_batch_idx_map.append(prev_i)
return input_features, cur_bsz, new_batch_idx_map
@staticmethod
def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map):
if input_features is None:
return None
segment_input = []
for i in range(cur_bsz):
prev_i = batch_idx_map[i]
segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]]
if segment_input_slice.shape[-1] < num_segment_frames:
# pad to 3000 if necessary
segment_input_slice = F.pad(
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
)
segment_input.append(segment_input_slice)
segment_input = torch.cat(segment_input, dim=0)
return segment_input
@staticmethod
def _prepare_decoder_input_ids(
cur_bsz,
init_tokens,
current_segments,
batch_idx_map,
do_condition_on_prev_tokens,
prompt_ids,
generation_config,
config,
device,
suppress_tokens,
kwargs,
):
if "decoder_input_ids" in kwargs:
decoder_input_ids = kwargs.pop("decoder_input_ids")
return decoder_input_ids, kwargs
cut_off_length = config.max_target_positions // 2 - 1
decoder_input_ids = init_tokens[batch_idx_map]
prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
if prev_start_of_text is None:
prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None
if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0:
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map]
if prompt_ids is not None and generation_config.prompt_condition_type == "all-segments":
prev_ids = prompt_ids
else:
one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
prev_ids = prev_start_of_text * one_tensor[0] if prev_start_of_text is not None else None
padding = "max_length" if generation_config.cache_implementation == "static" else "longest"
prev_tokens = _pad_to_max_length(
active_segments,
generation_config.pad_token_id,
device=device,
padding_side="left",
padding=padding,
bos_token_tensor=prev_ids,
cut_off_length=cut_off_length,
)
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id
elif prompt_ids is not None:
prev_tokens = prompt_ids[None].repeat(decoder_input_ids.shape[0], 1)
decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1)
# make sure `"decoder_attention_mask"` is not passed to forward
kwargs.pop("decoder_attention_mask", None)
else:
# make sure `"decoder_attention_mask"` is not passed to forward
kwargs.pop("decoder_attention_mask", None)
return decoder_input_ids, kwargs
def _set_max_new_tokens_and_length(self, config, decoder_input_ids, generation_config):
max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
raise ValueError(
f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
f"is {max_new_tokens}. Thus, the combined length of "
f"`decoder_input_ids` and `max_new_tokens` is: {max_new_tokens + decoder_input_ids.shape[-1]}. This exceeds the "
f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
f"so that their combined length is less than {self.config.max_target_positions}."
)
num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
# Make sure we don't get larger than `max_length`
if generation_config.max_length is not None and generation_config.max_new_tokens is None:
max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
logger.info(
f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
)
elif (
generation_config.max_new_tokens is not None
and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
):
max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
generation_config.max_new_tokens = max_new_tokens
@staticmethod
def _retrieve_compression_ratio(tokens, vocab_size):
"""Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes"""
length = int(math.log2(vocab_size) / 8) + 1
token_bytes = b"".join([t.to_bytes(length, "little") for t in tokens.tolist()])
compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes))
return compression_ratio
@staticmethod
def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature):
rescale_temperature = temperature if temperature > 0.0 else 1
scores = torch.stack(scores).to(tokens.device)
if scores.shape[0] > tokens.shape[0]:
scores = scores[: tokens.shape[0]]
else:
tokens = tokens[-scores.shape[0] :]
logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype)
# retrieve logprob of selected tokens and sum
sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0]))
length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0]
avg_logprobs = sum_logprobs / (length + 1)
return avg_logprobs
@staticmethod
def _retrieve_segment(
seek_sequence,
seek_outputs,
time_offset,
timestamp_begin,
seek_num_frames,
time_precision,
input_stride,
prev_idx,
idx,
return_token_timestamps,
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
# "end of segment" prediction and slice the decoding into segments accordingly
if len(timestamp_segment_indices) > 0:
# if the output contains two consecutive timestamp tokens
slices = timestamp_segment_indices.tolist()
segments = []
if single_timestamp_ending:
slices.append(len(seek_sequence))
last_slice = 0
# Add each segment to list of all segments
for current_slice in slices:
sliced_tokens = seek_sequence[last_slice:current_slice]
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
segments.append(
{
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
"tokens": sliced_tokens,
"result": seek_outputs[idx],
}
)
if return_token_timestamps:
segments[-1]["token_timestamps"] = (
token_timestamps[last_slice:current_slice] + time_offset[prev_idx]
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
segment_offset = seek_num_frames[prev_idx]
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
# here we throw away all predictions after the last predicted "end of segment"
# since we are cutting right in the middle of an audio
last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin
segment_offset = last_timestamp_pos * input_stride
else:
# If whisper does not predict any "end of segment" token, then
# the whole decoding is considered a segment and we add it to the list of segments
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
last_timestamp_pos = seek_num_frames[prev_idx]
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
segments = [
{
"start": time_offset[prev_idx],
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
"tokens": seek_sequence,
"result": seek_outputs[idx],
}
]
if return_token_timestamps:
segments[-1]["token_timestamps"] = token_timestamps + time_offset[prev_idx]
segment_offset = seek_num_frames[prev_idx]
return segments, segment_offset
# coding=utf-8
# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Whisper model."""
import math
import os.path
import random
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from dataclasses import dataclass
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
SequenceClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_whisper import WhisperVQConfig
from .generation_whisper import WhisperGenerationMixin
if is_flash_attn_2_available():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
_HIDDEN_STATES_START_POSITION = 1
_CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
@dataclass
class QuantizedBaseModelOutput(BaseModelOutput):
quantized_token_ids: Optional[torch.LongTensor] = None
def vector_quantize(inputs, codebook):
embedding_size = codebook.size(1)
inputs_flatten = inputs.reshape(-1, embedding_size)
codebook_sqr = torch.sum(codebook ** 2, dim=1)
inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
# Compute the distances to the codebook
distances = torch.addmm(codebook_sqr + inputs_sqr,
inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
_, indices_flatten = torch.min(distances, dim=1)
codes_flatten = torch.index_select(codebook, dim=0,
index=indices_flatten)
codes = codes_flatten.view_as(inputs)
return codes, indices_flatten, distances
def mse_loss_with_mask(input, target, mask):
loss = torch.nn.functional.mse_loss(input, target, reduction='none')
loss = loss.mean(dim=-1)
loss = loss * mask
return loss.sum() / mask.sum()
class CausalConv1d(nn.Conv1d):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
**kwargs
):
super(CausalConv1d, self).__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs
)
self.left_padding = dilation * (kernel_size - 1)
def forward(self, inp):
x = torch.nn.functional.pad(inp.unsqueeze(2), (self.left_padding, 0, 0, 0)).squeeze(2)
return super(CausalConv1d, self).forward(x)
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
"""Returns sinusoids for positional embedding"""
if channels % 2 != 0:
raise ValueError(
f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
)
log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.LongTensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
CPU as part of the preprocessing during training.
Args:
shape: The shape for which to compute masks. This should be of a tuple of size 2 where
the first element is the batch size and the second element is the length of the axis to span.
mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
independently generated mask spans of length `mask_length` is computed by
`mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
actual percentage will be smaller.
mask_length: size of the mask
min_masks: minimum number of masked spans
attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
each batch dimension.
"""
batch_size, sequence_length = shape
if mask_length < 1:
raise ValueError("`mask_length` has to be bigger than 0.")
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
def compute_num_masked_span(input_length):
"""Given input length, compute how many spans should be masked"""
num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
num_masked_span = max(num_masked_span, min_masks)
# make sure num masked span <= sequence_length
if num_masked_span * mask_length > sequence_length:
num_masked_span = sequence_length // mask_length
# make sure num_masked span is also <= input_length - (mask_length - 1)
if input_length - (mask_length - 1) < num_masked_span:
num_masked_span = max(input_length - (mask_length - 1), 0)
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
spec_aug_mask_idxs = []
max_num_masked_span = compute_num_masked_span(sequence_length)
if max_num_masked_span == 0:
return spec_aug_mask
for input_length in input_lengths:
# compute num of masked spans for this input
num_masked_span = compute_num_masked_span(input_length)
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
# Picking first sample just pads those vectors twice.
if len(spec_aug_mask_idx) == 0:
# this case can only happen if `input_length` is strictly smaller then
# `sequence_length` in which case the last token has to be a padding
# token which we can use as a dummy mask id
dummy_mask_idx = sequence_length - 1
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
batch_size, max_num_masked_span * mask_length
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
return spec_aug_mask
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
super().__init__(num_positions, embedding_dim)
def forward(self, input_ids, past_key_values_length=0, position_ids=None):
if position_ids is None:
return self.weight[past_key_values_length: past_key_values_length + input_ids.shape[1]]
else:
return self.weight[position_ids]
class WhisperAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
is_causal: bool = False,
layer_idx: Optional[int] = None,
config: Optional[WhisperVQConfig] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim ** -0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
if layer_idx is None and is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
self.layer_idx = layer_idx
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
# Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self._shape(self.k_proj(current_states), -1, bsz)
value_states = self._shape(self.v_proj(current_states), -1, bsz)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_probs, value_states)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights, past_key_value
class WhisperFlashAttention2(WhisperAttention):
"""
Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
)
# WhisperFlashAttention2 attention does not support output_attentions
if output_attentions:
raise ValueError("WhisperFlashAttention2 attention does not support output_attentions")
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self._shape(self.k_proj(current_states), -1, bsz)
value_states = self._shape(self.v_proj(current_states), -1, bsz)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
causal_mask,
tgt_len,
dropout=self.dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1)
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class WhisperSdpaAttention(WhisperAttention):
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
if output_attentions or layer_head_mask is not None:
# TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
logger.warning_once(
"WhisperModel is using WhisperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
key_value_states=key_value_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
cache_position=cache_position,
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value.is_updated[self.layer_idx] = True
past_key_value = past_key_value.cross_attention_cache
else:
past_key_value = past_key_value.self_attention_cache
# use key_value_states if cross attention
current_states = key_value_states if key_value_states is not None else hidden_states
if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self._shape(self.k_proj(current_states), -1, bsz)
value_states = self._shape(self.v_proj(current_states), -1, bsz)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
causal_mask = attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.dropout if self.training else 0.0,
is_causal=is_causal,
)
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, None, past_key_value
WHISPER_ATTENTION_CLASSES = {
"eager": WhisperAttention,
# "flash_attention_2": WhisperFlashAttention2,
"sdpa": WhisperSdpaAttention,
}
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
class WhisperVQEncoderLayer(nn.Module):
def __init__(self, config: WhisperVQConfig, is_causal=False):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
config=config,
is_causal=is_causal
)
self.is_causal = is_causal
if self.is_causal:
assert isinstance(self.self_attn, WhisperSdpaAttention), "Causal attention is only supported for SDPA"
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
layer_head_mask: torch.Tensor,
output_attentions: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask if not self.is_causal else None,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class WhisperDecoderLayer(nn.Module):
def __init__(self, config: WhisperVQConfig, layer_idx: int = None):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
is_causal=True,
layer_idx=layer_idx,
config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
layer_idx=layer_idx,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[EncoderDecoderCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
cache_position: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
size `(decoder_attention_heads,)`.
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# add cross-attn to positions 1 of present_key_value tuple
present_key_value = (present_key_value, cross_attn_present_key_value)
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
outputs += (present_key_value,)
return outputs
class WhisperPreTrainedModel(PreTrainedModel):
config_class = WhisperVQConfig
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, WhisperVQEncoder):
with torch.no_grad():
embed_positions = module.embed_positions.weight
embed_positions.copy_(sinusoids(*embed_positions.shape))
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
Computes the output length of the convolutional layers
"""
input_lengths = (input_lengths - 1) // 2 + 1
return input_lengths
WHISPER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`WhisperConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
WHISPER_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing *SpecAugment* data augmentation on padding token indices. Mask values selected in
`[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
`past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
If you want to change padding behavior, you should read
[`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
when `config.use_cache=True`
Two formats are allowed:
- An [`~cache_utils.EncoderDecoderCache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
input (see `past_key_values`). This is useful if you want more control over how to convert
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
in the correct position and to infer the complete sequence length.
"""
WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
hidden-states at the output of the last layer of the encoder.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class WhisperVQEncoder(WhisperPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
[`WhisperEncoderLayer`].
Args:
config: WhisperConfig
"""
def __init__(self, config: WhisperVQConfig):
super().__init__(config)
self.config = config
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
if config.encoder_causal_convolution:
conv_class = CausalConv1d
else:
conv_class = nn.Conv1d
self.conv1 = conv_class(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = conv_class(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.requires_grad_(False)
if config.quantize_encoder_only:
self.layers = nn.ModuleList([WhisperVQEncoderLayer(config,
is_causal=config.encoder_causal_attention or config.quantize_causal_encoder)
for _ in range(config.quantize_position)])
else:
self.layers = nn.ModuleList([WhisperVQEncoderLayer(config, is_causal=config.encoder_causal_attention or (
config.quantize_causal_encoder and layer_id < config.quantize_position)) for layer_id in
range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Parameters related to pooling layer
self.pooling_layer = None
# Parameters related to quantization layer
self.codebook = None
self.embed_positions2 = None
self.quantize_loss = None
self.num_active_codes = None
self.quantize_ema_count = 0
# Save hiddens
self.save_hidden_dir = None
self.save_hidden_position = None
# Initialize weights and apply final processing
self.init_pooling_layer(config)
self.init_quantize_layer(config)
self.post_init()
def init_pooling_layer(self, config: WhisperVQConfig):
if config.pooling_kernel_size is not None:
if config.pooling_type == "max":
self.pooling_layer = nn.MaxPool1d(kernel_size=config.pooling_kernel_size)
elif config.pooling_type == "avg":
self.pooling_layer = nn.AvgPool1d(kernel_size=config.pooling_kernel_size)
else:
raise NotImplementedError(f"Pooling type {config.pooling_type} not implemented")
def init_quantize_layer(self, config: WhisperVQConfig, quantize_load_codebook=None):
if config.quantize_vocab_size is not None:
if config.pooling_position is not None:
assert config.quantize_position >= config.pooling_position
self.codebook = nn.Embedding(config.quantize_vocab_size, self.config.d_model)
if quantize_load_codebook is not None:
init_codes = np.load(quantize_load_codebook)
self.codebook.weight.data.copy_(torch.from_numpy(init_codes))
max_source_positions = self.max_source_positions
if config.pooling_kernel_size is not None:
max_source_positions = math.ceil(max_source_positions / self.config.pooling_kernel_size)
self.embed_positions2 = nn.Embedding(max_source_positions, self.config.d_model)
self.embed_positions2.weight.data.copy_(self.embed_positions.weight.data[:max_source_positions])
if config.quantize_ema_decay is not None:
self.codebook.weight.requires_grad = False
self.register_buffer("ema_count", torch.ones(config.quantize_vocab_size, dtype=torch.float))
self.register_buffer("ema_weight", self.codebook.weight.data.clone().float())
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
def get_input_embeddings(self) -> nn.Module:
return self.conv1
def set_input_embeddings(self, value: nn.Module):
self.conv1 = value
def get_block_causal_attention_mask(self, attention_mask, block_size=50):
dtype = self.dtype
batch_size, seq_length = attention_mask.shape
causal_mask = torch.torch.tril(
torch.ones(1, seq_length, seq_length, dtype=torch.bool, device=attention_mask.device))
block_square_mask = []
for start in range(0, seq_length, block_size):
end = min(start + block_size, seq_length)
length = end - start
block_square_mask.append(causal_mask.new_ones((length, length)))
block_square_mask = torch.block_diag(*block_square_mask)
block_causal_mask = causal_mask | block_square_mask
block_causal_mask = block_causal_mask & attention_mask[:, None, :]
block_causal_mask = block_causal_mask.to(dtype=dtype) # fp16 compatibility
block_causal_mask = (1.0 - block_causal_mask) * torch.finfo(dtype).min
block_causal_mask = block_causal_mask.unsqueeze(1)
return block_causal_mask
def forward(
self,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
quantized_token_ids=None
):
r"""
Args:
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.Tensor`)`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
but it is not used. By default the silence in the input log mel spectrogram are ignored.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
# if input_features.shape[-1] != expected_seq_length:
# raise ValueError(
# f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
# )
batch_size, feature_size, seq_length = input_features.shape
seq_length = seq_length // (self.conv1.stride[0] * self.conv2.stride[0])
attention_mask = attention_mask[:, :: self.conv1.stride[0] * self.conv2.stride[0]]
if self.config.quantize_causal_block_size is not None:
extended_attention_mask = self.get_block_causal_attention_mask(attention_mask,
block_size=self.config.quantize_causal_block_size)
else:
extended_attention_mask = self.get_extended_attention_mask(attention_mask, (batch_size, seq_length))
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight
hidden_states = inputs_embeds + embed_pos[:seq_length]
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
assert attention_mask.shape[-1] == hidden_states.shape[1]
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop: # skip the layer
to_drop = True
if to_drop:
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
extended_attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if idx + 1 == self.config.pooling_position and self.config.pooling_kernel_size is not None:
hidden_states = hidden_states.permute(0, 2, 1)
if hidden_states.shape[-1] % self.config.pooling_kernel_size != 0:
hidden_states = torch.nn.functional.pad(hidden_states, (
0, self.config.pooling_kernel_size - hidden_states.shape[-1] % self.config.pooling_kernel_size))
hidden_states = self.pooling_layer(hidden_states).permute(0, 2, 1)
attention_mask = attention_mask[:, ::self.config.pooling_kernel_size]
if self.config.quantize_causal_block_size is not None:
extended_attention_mask = self.get_block_causal_attention_mask(attention_mask, block_size=self.config.quantize_causal_block_size // self.config.pooling_kernel_size)
else:
extended_attention_mask = self.get_extended_attention_mask(attention_mask, (
batch_size, seq_length // self.config.pooling_kernel_size))
if idx + 1 == self.config.quantize_position and self.config.quantize_vocab_size is not None:
if quantized_token_ids is not None:
hidden_states = self.codebook(quantized_token_ids)
else:
hidden_quantized, indices_flat, distances = vector_quantize(hidden_states, self.codebook.weight)
quantized_token_ids = indices_flat.reshape(batch_size, hidden_quantized.shape[1])
if self.training:
encodings = torch.nn.functional.one_hot(indices_flat, self.config.quantize_vocab_size).float()
encodings = encodings * attention_mask.reshape(-1, 1)
n = torch.sum(encodings, dim=0)
torch.distributed.all_reduce(n, op=torch.distributed.ReduceOp.SUM)
self.num_active_codes = n.nonzero().shape[0]
if self.config.quantize_ema_decay:
hidden_flat = hidden_states.detach().float().reshape(-1, hidden_states.shape[-1])
with torch.autocast(device_type='cuda', dtype=torch.float32):
dw = torch.matmul(encodings.t(), hidden_flat)
torch.distributed.all_reduce(dw, op=torch.distributed.ReduceOp.SUM)
self.ema_count = self.ema_count * self.config.quantize_ema_decay + (
1 - self.config.quantize_ema_decay) * n
total_count = torch.sum(self.ema_count)
self.ema_count = (self.ema_count + 1e-5) / (
total_count + self.config.quantize_vocab_size * 1e-5) * total_count
self.ema_weight = self.ema_weight * self.config.quantize_ema_decay + (
1 - self.config.quantize_ema_decay) * dw
self.codebook.weight.data = self.ema_weight / self.ema_count.unsqueeze(1)
self.quantize_loss = self.config.quantize_loss_scale * self.config.quantize_commit_coefficient * mse_loss_with_mask(
hidden_states, hidden_quantized.detach(), attention_mask)
self.quantize_ema_count += 1
if self.config.quantize_restart_interval is not None and self.quantize_ema_count % self.config.quantize_restart_interval == 0:
rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()
segment_vocab_size = self.config.quantize_vocab_size // world_size
start_idx = segment_vocab_size * rank
ema_count_segment = self.ema_count[start_idx: start_idx + segment_vocab_size]
threshold = 1 * (
self.config.quantize_ema_decay ** self.config.quantize_restart_interval)
update_indices = (ema_count_segment < threshold).nonzero()[:, 0] + start_idx
num_update = update_indices.shape[0]
mask_flat = attention_mask.reshape(-1) > 0
hidden_selected = hidden_flat[mask_flat]
hidden_update = hidden_selected[random.sample(range(len(hidden_selected)), num_update)]
num_update = torch.as_tensor([num_update], dtype=torch.long,
device=hidden_states.device)
num_update_list = [torch.as_tensor([0], dtype=torch.long, device=hidden_states.device)
for _
in range(world_size)]
torch.distributed.all_gather(num_update_list, num_update)
update_indices_list = [
torch.zeros(num.item(), dtype=torch.long, device=hidden_states.device) for num in
num_update_list]
torch.distributed.all_gather(update_indices_list, update_indices)
update_indices = torch.cat(update_indices_list)
hidden_update_list = [
torch.zeros(num.item(), hidden_flat.shape[-1], dtype=hidden_update.dtype,
device=hidden_states.device) for num in num_update_list]
torch.distributed.all_gather(hidden_update_list, hidden_update)
hidden_update = torch.cat(hidden_update_list)
self.codebook.weight.data[update_indices] = hidden_update
self.ema_count[update_indices] = 1
self.ema_weight[update_indices] = hidden_update
if torch.distributed.get_rank() == 0:
print(f"restart {len(update_indices)} tokens")
else:
loss = self.config.quantize_loss_scale * (
self.config.quantize_commit_coefficient * mse_loss_with_mask(hidden_states,
hidden_quantized.detach(),
attention_mask) + mse_loss_with_mask(
hidden_quantized, hidden_states.detach(), attention_mask))
self.quantize_loss = loss
hidden_states = hidden_states + (hidden_quantized - hidden_states).detach()
else:
hidden_states = hidden_quantized
hidden_states = hidden_states + self.embed_positions2.weight[:hidden_states.shape[1]]
if idx + 1 == self.save_hidden_position:
import numpy as np
import uuid
to_save = []
for batch_idx, hidden_state in enumerate(hidden_states):
for seq_idx, hidden in enumerate(hidden_state):
if attention_mask[batch_idx, seq_idx]:
to_save.append(hidden.detach().cpu().numpy())
np.save(os.path.join(self.save_hidden_dir, f"{str(uuid.uuid4())}.npy"), to_save)
if not self.config.quantize_encoder_only:
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return QuantizedBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions,
quantized_token_ids=quantized_token_ids,
)
class WhisperVQDecoder(WhisperPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`]
Args:
config: WhisperConfig
"""
main_input_name = "input_ids"
def __init__(self, config: WhisperVQConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_target_positions
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
self.layers = nn.ModuleList(
[WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
position_ids=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
cache_position=None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.]
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
when `config.use_cache=True`
Two formats are allowed:
- An [`~cache_utils.EncoderDecoderCache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
control over how to convert `input_ids` indices into associated vectors than the model's internal
embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert encoder_attention_mask.shape[-1] == encoder_hidden_states.shape[1]
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
return_legacy_cache = False
return_self_attention_cache = False
if use_cache or past_key_values is not None:
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
past_key_values_length = 0
if cache_position is not None:
past_key_values_length = cache_position[0]
elif past_key_values is not None:
past_key_values_length = past_key_values.get_seq_length()
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
# embed positions
if input_ids is not None:
positions = self.embed_positions(
input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
)
else:
positions = self.embed_positions(
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
)
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (len(self.layers)), (
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop:
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
encoder_hidden_states,
encoder_extended_attention_mask, # encoder attention mask
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, # past_key_value
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
past_key_value=past_key_values if use_cache else None,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = past_key_values if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@add_start_docstrings(
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
WHISPER_START_DOCSTRING,
)
class WhisperVQModel(WhisperPreTrainedModel):
def __init__(self, config: WhisperVQConfig):
super().__init__(config)
self.encoder = WhisperVQEncoder(config)
self.decoder = WhisperVQDecoder(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.decoder.embed_tokens
def set_input_embeddings(self, value):
self.decoder.embed_tokens = value
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def freeze_encoder(self):
"""
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
not be updated during training.
"""
self.encoder._freeze_parameters()
def _mask_input_features(
self,
input_features: torch.FloatTensor,
attention_mask: Optional[torch.LongTensor] = None,
):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
"""
# `config.apply_spec_augment` can set masking to False
if not getattr(self.config, "apply_spec_augment", True):
return input_features
# generate indices & apply SpecAugment along time axis
batch_size, hidden_size, sequence_length = input_features.size()
if self.config.mask_time_prob > 0 and self.training:
# generate indices & apply SpecAugment along time axis
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
input_features[mask_time_indices] = 0
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
input_features[mask_feature_indices] = 0
return input_features
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
quantized_token_ids: Optional[torch.LongTensor] = None
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, WhisperModel
>>> from datasets import load_dataset
>>> model = WhisperVQModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
encoder_outputs = self.encoder(
input_features,
attention_mask=attention_mask,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
quantized_token_ids=quantized_token_ids
)
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
attention_mask = attention_mask[:, ::self.encoder.conv1.stride[0] * self.encoder.conv2.stride[0]]
if self.encoder.config.pooling_kernel_size is not None:
attention_mask = attention_mask[:, ::self.encoder.config.pooling_kernel_size]
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_attention_mask=attention_mask,
encoder_hidden_states=encoder_outputs[0],
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
WHISPER_START_DOCSTRING,
)
class WhisperVQForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["proj_out.weight"]
def __init__(self, config: WhisperVQConfig):
super().__init__(config)
self.model = WhisperVQModel(config)
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
self.quantize_loss = None
# Initialize weights and apply final processing
self.post_init()
def get_encoder(self):
return self.model.get_encoder()
def get_decoder(self):
return self.model.get_decoder()
def get_output_embeddings(self):
return self.proj_out
def set_output_embeddings(self, new_embeddings):
self.proj_out = new_embeddings
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()
def freeze_encoder(self):
"""
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
not be updated during training.
"""
self.model.encoder._freeze_parameters()
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
quantized_token_ids: Optional[torch.LongTensor] = None
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperVQForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.model(
input_features,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
quantized_token_ids=quantized_token_ids
)
lm_logits = self.proj_out(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
if self.training and self.model.encoder.quantize_loss is not None:
loss = loss + self.model.encoder.quantize_loss
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
decoder_attention_mask=None,
cache_position=None,
quantized_token_ids=None,
**kwargs,
):
decoder_position_ids = None
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, EncoderDecoderCache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
else:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
if decoder_position_ids is not None:
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
)
elif use_cache:
cache_position = cache_position[-decoder_input_ids.shape[1]:]
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
decoder_input_ids = decoder_input_ids.contiguous()
if (
isinstance(past_key_values, EncoderDecoderCache)
and (
isinstance(past_key_values.self_attention_cache, StaticCache)
or isinstance(past_key_values.cross_attention_cache, StaticCache)
)
and decoder_attention_mask is not None
and decoder_attention_mask.ndim == 2
):
batch_size, sequence_length = decoder_input_ids.shape
device = decoder_input_ids.device
dtype = self.proj_out.weight.dtype
min_dtype = torch.finfo(dtype).min
decoder_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
decoder_attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.self_attention_cache.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
return {
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"cache_position": cache_position,
"quantized_token_ids": quantized_token_ids
}
def _retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs):
if self.config.skip_language_detection:
return torch.as_tensor([[generation_config.decoder_start_token_id] for _ in range(batch_size)],
dtype=torch.long, device=self.device).expand(batch_size, -1)
else:
return super()._retrieve_init_tokens(input_features, batch_size, generation_config, config,
num_segment_frames, kwargs)
class WhisperDecoderWrapper(WhisperPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the [`EncoderDecoderModel`] framework.
"""
def __init__(self, config):
super().__init__(config)
config.is_encoder_decoder = False
self.decoder = WhisperVQDecoder(config)
def get_input_embeddings(self):
return self.decoder.embed_tokens
def set_input_embeddings(self, value):
self.decoder.embed_tokens = value
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
@add_start_docstrings(
"""
Whisper decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
""",
WHISPER_START_DOCSTRING,
)
class WhisperForCausalLM(WhisperPreTrainedModel):
_tied_weights_keys = ["proj_out.weight"]
main_input_name = "input_ids"
def __init__(self, config):
super().__init__(config)
config.is_encoder_decoder = False
self.model = WhisperDecoderWrapper(config)
self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.proj_out
def set_output_embeddings(self, new_embeddings):
self.proj_out = new_embeddings
def get_input_embeddings(self) -> nn.Module:
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
tensors are only required when the model is used as a decoder in a Sequence to Sequence model. Contains
pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If
`past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
in the correct position and to infer the complete sequence length.
Returns:
Example:
```python
>>> from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor
>>> import torch
>>> from datasets import load_dataset
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
>>> model = WhisperVQForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
>>> assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> sample = ds[0]["audio"]
>>> input_features = processor(
... sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
... ).input_features
>>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)
>>> # decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# If the user passed a tuple or `BaseModelOutput` for encoder_outputs, we extract only the hidden states
if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)):
encoder_outputs = encoder_outputs[0]
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_outputs,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
logits = self.proj_out(outputs[0])
loss = None
if labels is not None:
labels = labels.to(logits.device)
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
cache_position=None,
**kwargs,
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, (Cache, EncoderDecoderCache)):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
else:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
if cache_position is None:
cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device)
elif use_cache:
cache_position = cache_position[-input_ids.shape[1]:]
return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"input_ids": input_ids,
"use_cache": use_cache,
"attention_mask": attention_mask,
"cache_position": cache_position,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
@add_start_docstrings(
"""
Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
like SUPERB Keyword Spotting.
""",
WHISPER_ENCODER_INPUTS_DOCSTRING,
)
class WhisperForAudioClassification(WhisperPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.encoder = WhisperVQEncoder(config)
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
if config.use_weighted_layer_sum:
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
def freeze_encoder(self):
"""
Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
not be updated during training. Only the projection layers and classification head will be updated.
"""
self.encoder._freeze_parameters()
def get_input_embeddings(self) -> nn.Module:
return self.encoder.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module):
self.encoder.set_input_embeddings(value)
@add_start_docstrings_to_model_forward(WHISPER_ENCODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_features: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
>>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
>>> sample = next(iter(ds))
>>> inputs = feature_extractor(
... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
... )
>>> input_features = inputs.input_features
>>> with torch.no_grad():
... logits = model(input_features).logits
>>> predicted_class_ids = torch.argmax(logits).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
'Afrikaans'
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if self.config.use_weighted_layer_sum:
output_hidden_states = True
elif output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.config.use_weighted_layer_sum:
hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = encoder_outputs[0]
hidden_states = self.projector(hidden_states)
pooled_output = hidden_states.mean(dim=1)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + encoder_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
import os
import io
import glob
import math
import tarfile
import torch
import torchaudio
import safetensors
from .configuration_whisper import WhisperVQConfig
from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
def load_quantize_encoder(model_path):
config = WhisperVQConfig.from_pretrained(model_path)
config.quantize_encoder_only = True
model = WhisperVQEncoder(config)
state_dict = {}
for path in glob.glob(os.path.join(model_path, "model*.safetensors")):
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("model.encoder."):
new_key = key[len("model.encoder."):]
if new_key.startswith("layer_norm"):
continue
if new_key.startswith("layers"):
layer_id = int(new_key.split(".")[1])
if layer_id >= config.quantize_position:
continue
state_dict[new_key] = f.get_tensor(key)
model.load_state_dict(state_dict)
model.eval()
model.cuda()
return model
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}
def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts):
with torch.no_grad():
audios, indices = [], []
for idx, utt in enumerate(utts):
if isinstance(utt, tuple):
audio, sample_rate = utt
else:
audio, sample_rate = torchaudio.load(utt)
audio = audio.cuda()
if sample_rate != 16000:
if sample_rate not in _resample_buffer:
_resample_buffer[sample_rate] = torchaudio.transforms.Resample(
orig_freq=sample_rate,
new_freq=16000
).to('cuda')
audio = _resample_buffer[sample_rate](audio)
# if audio.shape[0] > 1:
# audio = audio[:1]
audio = audio[0]
audio = audio.cpu().numpy()
time_step = 0
while time_step * 16000 < audio.shape[0]:
audio_segment = audio[time_step * 16000: (time_step + 30) * 16000]
audios.append(audio_segment)
indices.append(idx)
time_step += 30
pooling_kernel_size = model.config.pooling_kernel_size or 1
stride = model.conv1.stride[0] * model.conv2.stride[0] * pooling_kernel_size * feature_extractor.hop_length
all_speech_tokens = [[] for _ in range(len(utts))]
batch_size = 128
for start in range(0, len(audios), batch_size):
features = feature_extractor(audios[start: start + batch_size], sampling_rate=16000,
return_attention_mask=True, return_tensors="pt", device='cuda',
padding="longest", pad_to_multiple_of=stride)
features = features.to(device="cuda")
outputs = model(**features)
speech_tokens = outputs.quantized_token_ids
attention_mask = features.attention_mask[:, ::model.conv1.stride[0] * model.conv2.stride[0]]
attention_mask = attention_mask[:, ::model.config.pooling_kernel_size]
assert attention_mask.shape == speech_tokens.shape
for i in range(len(speech_tokens)):
idx = indices[start + i]
speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
all_speech_tokens[idx].extend(speech_token)
return all_speech_tokens
# example of file for storing private and user specific environment variables, like keys or system paths
# rename it to ".env" (excluded from version control by default)
# .env is loaded by train.py automatically
# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
MY_VAR="/home/user/my/system/path"
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
### VisualStudioCode
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
**/.vscode
# JetBrains
.idea/
# Data & Models
*.h5
*.tar
*.tar.gz
# Lightning-Hydra-Template
configs/local/default.yaml
/data/
/logs/
.env
# Aim logging
.aim
# Cython complied files
matcha/utils/monotonic_align/core.c
# Ignoring hifigan checkpoint
generator_v1
g_02500000
gradio_cached_examples/
synth_output/
/data
default_language_version:
python: python3.11
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
# list of supported hooks: https://pre-commit.com/hooks.html
- id: trailing-whitespace
- id: end-of-file-fixer
# - id: check-docstring-first
- id: check-yaml
- id: debug-statements
- id: detect-private-key
- id: check-toml
- id: check-case-conflict
- id: check-added-large-files
# python code formatting
- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
- id: black
args: [--line-length, "120"]
# python import sorting
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
# python upgrading syntax to newer version
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py38-plus]
# python check (PEP8), programming errors and code complexity
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
args:
[
"--max-line-length", "120",
"--extend-ignore",
"E203,E402,E501,F401,F841,RST2,RST301",
"--exclude",
"logs/*,data/*,matcha/hifigan/*",
]
additional_dependencies: [flake8-rst-docstrings==0.3.0]
# pylint
- repo: https://github.com/pycqa/pylint
rev: v3.0.3
hooks:
- id: pylint
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment