Commit 5c023842 authored by chenpangpang

feat: add LatentSync

parent 822b66ca
# Adapted from https://github.com/Rudrabha/Wav2Lip/blob/master/audio.py
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
from omegaconf import OmegaConf
import torch
audio_config_path = "configs/audio.yaml"
config = OmegaConf.load(audio_config_path)
def load_wav(path, sr):
return librosa.core.load(path, sr=sr)[0]
def save_wav(wav, path, sr):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
# proposed by @dsmiller
wavfile.write(path, sr, wav.astype(np.int16))
def save_wavenet_wav(wav, path, sr):
librosa.output.write_wav(path, wav, sr=sr)
def preemphasis(wav, k, preemphasize=True):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k, inv_preemphasize=True):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def get_hop_size():
hop_size = config.audio.hop_size
if hop_size is None:
assert config.audio.frame_shift_ms is not None
hop_size = int(config.audio.frame_shift_ms / 1000 * config.audio.sample_rate)
return hop_size
def linearspectrogram(wav):
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
S = _amp_to_db(np.abs(D)) - config.audio.ref_level_db
if config.audio.signal_normalization:
return _normalize(S)
return S
def melspectrogram(wav):
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
S = _amp_to_db(_linear_to_mel(np.abs(D))) - config.audio.ref_level_db
if config.audio.signal_normalization:
return _normalize(S)
return S
def _lws_processor():
import lws
return lws.lws(config.audio.n_fft, get_hop_size(), fftsize=config.audio.win_size, mode="speech")
def _stft(y):
if config.audio.use_lws:
return _lws_processor().stft(y).T
else:
return librosa.stft(y=y, n_fft=config.audio.n_fft, hop_length=get_hop_size(), win_length=config.audio.win_size)
##########################################################
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
"""Compute number of time frames of spectrogram"""
pad = fsize - fshift
if length % fshift == 0:
M = (length + pad * 2 - fsize) // fshift + 1
else:
M = (length + pad * 2 - fsize) // fshift + 2
return M
def pad_lr(x, fsize, fshift):
"""Compute left and right padding"""
M = num_frames(len(x), fsize, fshift)
pad = fsize - fshift
T = len(x) + 2 * pad
r = (M - 1) * fshift + fsize - T
return pad, pad + r
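# Worked example of the two helpers above (illustration only, assuming fsize=800 and fshift=200):
# num_frames(16000, 800, 200) -> 83 frames, and pad_lr(x, 800, 200) for a 16000-sample signal
# returns (600, 600), i.e. symmetric padding of fsize - fshift samples on each side.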
##########################################################
# Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
# Conversions
_mel_basis = None
def _linear_to_mel(spectogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectogram)
def _build_mel_basis():
assert config.audio.fmax <= config.audio.sample_rate // 2
return librosa.filters.mel(
sr=config.audio.sample_rate,
n_fft=config.audio.n_fft,
n_mels=config.audio.num_mels,
fmin=config.audio.fmin,
fmax=config.audio.fmax,
)
def _amp_to_db(x):
min_level = np.exp(config.audio.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
def _normalize(S):
if config.audio.allow_clipping_in_normalization:
if config.audio.symmetric_mels:
return np.clip(
(2 * config.audio.max_abs_value) * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
- config.audio.max_abs_value,
-config.audio.max_abs_value,
config.audio.max_abs_value,
)
else:
return np.clip(
config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)),
0,
config.audio.max_abs_value,
)
assert S.max() <= 0 and S.min() - config.audio.min_level_db >= 0
if config.audio.symmetric_mels:
return (2 * config.audio.max_abs_value) * (
(S - config.audio.min_level_db) / (-config.audio.min_level_db)
) - config.audio.max_abs_value
else:
return config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
def _denormalize(D):
if config.audio.allow_clipping_in_normalization:
if config.audio.symmetric_mels:
return (
(np.clip(D, -config.audio.max_abs_value, config.audio.max_abs_value) + config.audio.max_abs_value)
* -config.audio.min_level_db
/ (2 * config.audio.max_abs_value)
) + config.audio.min_level_db
else:
return (
np.clip(D, 0, config.audio.max_abs_value) * -config.audio.min_level_db / config.audio.max_abs_value
) + config.audio.min_level_db
if config.audio.symmetric_mels:
return (
(D + config.audio.max_abs_value) * -config.audio.min_level_db / (2 * config.audio.max_abs_value)
) + config.audio.min_level_db
else:
return (D * -config.audio.min_level_db / config.audio.max_abs_value) + config.audio.min_level_db
def get_melspec_overlap(audio_samples, melspec_length=52):
mel_spec_overlap = melspectrogram(audio_samples.numpy())
mel_spec_overlap = torch.from_numpy(mel_spec_overlap)
i = 0
mel_spec_overlap_list = []
while i + melspec_length < mel_spec_overlap.shape[1] - 3:
mel_spec_overlap_list.append(mel_spec_overlap[:, i : i + melspec_length].unsqueeze(0))
i += 3
mel_spec_overlap = torch.stack(mel_spec_overlap_list)
return mel_spec_overlap
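# Hedged usage sketch (illustration only): run a synthetic 440 Hz tone through
# melspectrogram() above; all hyperparameters come from configs/audio.yaml.
if __name__ == "__main__":
    _sr = config.audio.sample_rate
    _t = np.linspace(0, 1.0, _sr, endpoint=False)
    _wav = (0.5 * np.sin(2 * np.pi * 440.0 * _t)).astype(np.float32)  # 1 second of audio
    print("mel spectrogram shape:", melspectrogram(_wav).shape)  # (num_mels, num_frames)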
# We modified the original AVReader class of decord to fix a memory leak.
# For more details, refer to: https://github.com/dmlc/decord/issues/208
import numpy as np
from decord.video_reader import VideoReader
from decord.audio_reader import AudioReader
from decord.ndarray import cpu
from decord import ndarray as _nd
from decord.bridge import bridge_out
class AVReader(object):
"""Individual audio video reader with convenient indexing function.
Parameters
----------
uri: str
Path of file.
ctx: decord.Context
The context to decode the file, can be decord.cpu() or decord.gpu().
sample_rate: int, default is -1
Desired output sample rate of the audio, unchanged if `-1` is specified.
mono: bool, default is True
Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
width : int, default is -1
Desired output width of the video, unchanged if `-1` is specified.
height : int, default is -1
Desired output height of the video, unchanged if `-1` is specified.
num_threads : int, default is 0
Number of decoding threads, auto if `0` is specified.
fault_tol : int, default is -1
The threshold of corrupted and recovered frames. This is to prevent silent fault
tolerance when, for example, 50% of the frames of a video cannot be decoded and
duplicate frames are returned. You may find the fault-tolerance feature handy in many
cases, but not for training models. Say `N = # recovered frames`
If `fault_tol` < 0, nothing will happen.
If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
"""
def __init__(
self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
):
self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
self.__audio_reader.add_padding()
if hasattr(uri, "read"):
uri.seek(0)
self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
self.__video_reader.seek(0)
def __len__(self):
"""Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
we always follow what FFMPEG reports.
Returns
-------
int
The number of frames in the video file.
"""
return len(self.__video_reader)
def __getitem__(self, idx):
"""Get audio samples and video frame at `idx`.
Parameters
----------
idx : int or slice
The frame index, can be negative which means it will index backwards,
or slice of frame indices.
Returns
-------
(ndarray/list of ndarray, ndarray)
First element is samples of shape CxS or a list of length N containing samples of shape CxS,
where N is the number of frames, C is the number of channels,
S is the number of samples of the corresponding frame.
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
where N is the length of the slice.
"""
assert self.__video_reader is not None and self.__audio_reader is not None
if isinstance(idx, slice):
return self.get_batch(range(*idx.indices(len(self.__video_reader))))
if idx < 0:
idx += len(self.__video_reader)
if idx >= len(self.__video_reader) or idx < 0:
raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
self.__video_reader.seek(0)
return results
def get_batch(self, indices):
"""Get entire batch of audio samples and video frames.
Parameters
----------
indices : list of integers
A list of frame indices. If negative indices detected, the indices will be indexed from backward
Returns
-------
(list of ndarray, ndarray)
First element is a list of length N containing samples of shape CxS,
where N is the number of frames, C is the number of channels,
S is the number of samples of the corresponding frame.
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
where N is the length of the slice.
"""
assert self.__video_reader is not None and self.__audio_reader is not None
indices = self._validate_indices(indices)
audio_arr = []
prev_video_idx = None
prev_audio_end_idx = None
for idx in list(indices):
frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
# timestamp and sample conversion could have some error that could cause non-continuous audio
# we detect if retrieving continuous frame and make the audio continuous
if prev_video_idx and idx == prev_video_idx + 1:
audio_start_idx = prev_audio_end_idx
else:
audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
prev_video_idx = idx
prev_audio_end_idx = audio_end_idx
results = (audio_arr, self.__video_reader.get_batch(indices))
self.__video_reader.seek(0)
return results
def _get_slice(self, sl):
audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
for idx in list(sl):
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
audio_arr = np.concatenate(
(audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
)
results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
self.__video_reader.seek(0)
return results
def _validate_indices(self, indices):
"""Validate int64 integers and convert negative integers to positive by backward search"""
assert self.__video_reader is not None and self.__audio_reader is not None
indices = np.array(indices, dtype=np.int64)
# process negative indices
indices[indices < 0] += len(self.__video_reader)
if not (indices >= 0).all():
raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader)))
if not (indices < len(self.__video_reader)).all():
raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
return indices
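# Hedged usage sketch (illustration only; "example.mp4" is a placeholder path):
if __name__ == "__main__":
    av = AVReader("example.mp4", sample_rate=16000)
    audio_chunks, frames = av.get_batch(range(8))  # per-frame audio samples + 8 video frames
    print(len(audio_chunks), frames.shape)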
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchvision import transforms
import cv2
from einops import rearrange
import mediapipe as mp
import torch
import numpy as np
from typing import Union
from .affine_transform import AlignRestore, laplacianSmooth
import face_alignment
"""
If you are enlarging the image, you should prefer to use INTER_LINEAR or INTER_CUBIC interpolation. If you are shrinking the image, you should prefer to use INTER_AREA interpolation.
https://stackoverflow.com/questions/23853632/which-kind-of-interpolation-best-for-resizing-image
"""
def load_fixed_mask(resolution: int) -> torch.Tensor:
mask_image = cv2.imread("latentsync/utils/mask.png")
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_AREA) / 255.0
mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
return mask_image
class ImageProcessor:
def __init__(self, resolution: int = 512, mask: str = "fix_mask", device: str = "cpu", mask_image=None):
self.resolution = resolution
self.resize = transforms.Resize(
(resolution, resolution), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True
)
self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
self.mask = mask
if mask in ["mouth", "face", "eye"]:
self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
if mask == "fix_mask":
self.face_mesh = None
self.smoother = laplacianSmooth()
self.restorer = AlignRestore()
if mask_image is None:
self.mask_image = load_fixed_mask(resolution)
else:
self.mask_image = mask_image
if device != "cpu":
self.fa = face_alignment.FaceAlignment(
face_alignment.LandmarksType.TWO_D, flip_input=False, device=device
)
self.face_mesh = None
else:
# self.face_mesh = mp.solutions.face_mesh.FaceMesh(static_image_mode=True) # Process single image
self.face_mesh = None
self.fa = None
def detect_facial_landmarks(self, image: np.ndarray):
height, width, _ = image.shape
results = self.face_mesh.process(image)
if not results.multi_face_landmarks: # Face not detected
raise RuntimeError("Face not detected")
face_landmarks = results.multi_face_landmarks[0] # Only use the first face in the image
landmark_coordinates = [
(int(landmark.x * width), int(landmark.y * height)) for landmark in face_landmarks.landmark
] # x means width, y means height
return landmark_coordinates
def preprocess_one_masked_image(self, image: torch.Tensor) -> np.ndarray:
image = self.resize(image)
if self.mask == "mouth" or self.mask == "face":
landmark_coordinates = self.detect_facial_landmarks(image)
if self.mask == "mouth":
surround_landmarks = mouth_surround_landmarks
else:
surround_landmarks = face_surround_landmarks
points = [landmark_coordinates[landmark] for landmark in surround_landmarks]
points = np.array(points)
mask = np.ones((self.resolution, self.resolution))
mask = cv2.fillPoly(mask, pts=[points], color=(0, 0, 0))
mask = torch.from_numpy(mask)
mask = mask.unsqueeze(0)
elif self.mask == "half":
mask = torch.ones((self.resolution, self.resolution))
height = mask.shape[0]
mask[height // 2 :, :] = 0
mask = mask.unsqueeze(0)
elif self.mask == "eye":
mask = torch.ones((self.resolution, self.resolution))
landmark_coordinates = self.detect_facial_landmarks(image)
y = landmark_coordinates[195][1]
mask[y:, :] = 0
mask = mask.unsqueeze(0)
else:
raise ValueError("Invalid mask type")
image = image.to(dtype=torch.float32)
pixel_values = self.normalize(image / 255.0)
masked_pixel_values = pixel_values * mask
mask = 1 - mask
return pixel_values, masked_pixel_values, mask
def affine_transform(self, image: torch.Tensor) -> np.ndarray:
# image = rearrange(image, "c h w-> h w c").numpy()
if self.fa is None:
landmark_coordinates = np.array(self.detect_facial_landmarks(image))
lm68 = mediapipe_lm478_to_face_alignment_lm68(landmark_coordinates)
else:
detected_faces = self.fa.get_landmarks(image)
if detected_faces is None:
raise RuntimeError("Face not detected")
lm68 = detected_faces[0]
points = self.smoother.smooth(lm68)
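# The three anchor points below follow the standard 68-landmark layout (an assumption,
# matching face_alignment): indices 17-21 left eyebrow, 22-26 right eyebrow, 27-35 nose;
# their means drive the similarity transform computed by align_warp_face.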
lmk3_ = np.zeros((3, 2))
lmk3_[0] = points[17:22].mean(0)
lmk3_[1] = points[22:27].mean(0)
lmk3_[2] = points[27:36].mean(0)
# print(lmk3_)
face, affine_matrix = self.restorer.align_warp_face(
image.copy(), lmks3=lmk3_, smooth=True, border_mode="constant"
)
box = [0, 0, face.shape[1], face.shape[0]] # x1, y1, x2, y2
face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_CUBIC)
face = rearrange(torch.from_numpy(face), "h w c -> c h w")
return face, box, affine_matrix
def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
if affine_transform:
image, _, _ = self.affine_transform(image)
else:
image = self.resize(image)
pixel_values = self.normalize(image / 255.0)
masked_pixel_values = pixel_values * self.mask_image
return pixel_values, masked_pixel_values, self.mask_image[0:1]
def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
if isinstance(images, np.ndarray):
images = torch.from_numpy(images)
if images.shape[3] == 3:
images = rearrange(images, "b h w c -> b c h w")
if self.mask == "fix_mask":
results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
else:
results = [self.preprocess_one_masked_image(image) for image in images]
pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)
def process_images(self, images: Union[torch.Tensor, np.ndarray]):
if isinstance(images, np.ndarray):
images = torch.from_numpy(images)
if images.shape[3] == 3:
images = rearrange(images, "b h w c -> b c h w")
images = self.resize(images)
pixel_values = self.normalize(images / 255.0)
return pixel_values
def close(self):
if self.face_mesh is not None:
self.face_mesh.close()
def mediapipe_lm478_to_face_alignment_lm68(lm478, return_2d=True):
"""
lm478: [B, 478, 3] or [478,3]
"""
# lm478[..., 0] *= W
# lm478[..., 1] *= H
landmarks_extracted = []
for index in landmark_points_68:
x = lm478[index][0]
y = lm478[index][1]
landmarks_extracted.append((x, y))
return np.array(landmarks_extracted)
landmark_points_68 = [
162,
234,
93,
58,
172,
136,
149,
148,
152,
377,
378,
365,
397,
288,
323,
454,
389,
71,
63,
105,
66,
107,
336,
296,
334,
293,
301,
168,
197,
5,
4,
75,
97,
2,
326,
305,
33,
160,
158,
133,
153,
144,
362,
385,
387,
263,
373,
380,
61,
39,
37,
0,
267,
269,
291,
405,
314,
17,
84,
181,
78,
82,
13,
312,
308,
317,
14,
87,
]
# Refer to https://storage.googleapis.com/mediapipe-assets/documentation/mediapipe_face_landmark_fullsize.png
mouth_surround_landmarks = [
164,
165,
167,
92,
186,
57,
43,
106,
182,
83,
18,
313,
406,
335,
273,
287,
410,
322,
391,
393,
]
face_surround_landmarks = [
152,
377,
400,
378,
379,
365,
397,
288,
435,
433,
411,
425,
423,
327,
326,
94,
97,
98,
203,
205,
187,
213,
215,
58,
172,
136,
150,
149,
176,
148,
]
if __name__ == "__main__":
image_processor = ImageProcessor(512, mask="fix_mask")
video = cv2.VideoCapture("/mnt/bn/maliva-gen-ai-v2/chunyu.li/HDTF/original/val/RD_Radio57_000.mp4")
while True:
ret, frame = video.read()
# if not ret:
# break
# cv2.imwrite("image.jpg", frame)
frame = rearrange(torch.Tensor(frame).type(torch.uint8), "h w c -> c h w")
# face, masked_face, _ = image_processor.preprocess_fixed_mask_image(frame, affine_transform=True)
face, _, _ = image_processor.affine_transform(frame)
break
face = (rearrange(face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
cv2.imwrite("face.jpg", face)
# masked_face = (rearrange(masked_face, "c h w -> h w c").detach().cpu().numpy()).astype(np.uint8)
# cv2.imwrite("masked_face.jpg", masked_face)
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import imageio
import numpy as np
import json
from typing import Union
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.distributed as dist
from torchvision import transforms
from tqdm import tqdm
from einops import rearrange
import cv2
from decord import AudioReader, VideoReader
import shutil
import subprocess
# Machine epsilon for a float32 (single precision)
eps = np.finfo(np.float32).eps
def read_json(filepath: str):
with open(filepath) as f:
json_dict = json.load(f)
return json_dict
def read_video(video_path: str, change_fps=True, use_decord=True):
if change_fps:
temp_dir = "temp"
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
os.makedirs(temp_dir, exist_ok=True)
command = (
f"ffmpeg -loglevel error -y -nostdin -i {video_path} -r 25 -crf 18 {os.path.join(temp_dir, 'video.mp4')}"
)
subprocess.run(command, shell=True)
target_video_path = os.path.join(temp_dir, "video.mp4")
else:
target_video_path = video_path
if use_decord:
return read_video_decord(target_video_path)
else:
return read_video_cv2(target_video_path)
def read_video_decord(video_path: str):
vr = VideoReader(video_path)
video_frames = vr[:].asnumpy()
vr.seek(0)
return video_frames
def read_video_cv2(video_path: str):
# Open the video file
cap = cv2.VideoCapture(video_path)
# Check if the video was opened successfully
if not cap.isOpened():
print("Error: Could not open video.")
return np.array([])
frames = []
while True:
# Read a frame
ret, frame = cap.read()
# If frame is read correctly ret is True
if not ret:
break
# Convert BGR to RGB
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame_rgb)
# Release the video capture object
cap.release()
return np.array(frames)
def read_audio(audio_path: str, audio_sample_rate: int = 16000):
if audio_path is None:
raise ValueError("Audio path is required.")
ar = AudioReader(audio_path, sample_rate=audio_sample_rate, mono=True)
# To access the audio samples
audio_samples = torch.from_numpy(ar[:].asnumpy())
audio_samples = audio_samples.squeeze(0)
return audio_samples
def write_video(video_output_path: str, video_frames: np.ndarray, fps: int):
height, width = video_frames[0].shape[:2]
out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
# out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"vp09"), fps, (width, height))
for frame in video_frames:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
out.write(frame)
out.release()
def init_dist(backend="nccl", **kwargs):
"""Initializes distributed environment."""
rank = int(os.environ["RANK"])
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
raise RuntimeError("No GPUs available for training.")
local_rank = rank % num_gpus
torch.cuda.set_device(local_rank)
dist.init_process_group(backend=backend, **kwargs)
return local_rank
def zero_rank_print(s):
if dist.is_initialized() and dist.get_rank() == 0:
print("### " + s)
def zero_rank_log(logger, message: str):
if dist.is_initialized() and dist.get_rank() == 0:
logger.info(message)
def make_audio_window(audio_embeddings: torch.Tensor, window_size: int):
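# Shapes (read off the code below): audio_embeddings is (b, f, d); the result is
# (b, f - window_size + 1, window_size, d), a sliding window over the frame axis.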
audio_window = []
end_idx = audio_embeddings.shape[1] - window_size + 1
for i in range(end_idx):
audio_window.append(audio_embeddings[:, i : i + window_size, :])
audio_window = torch.stack(audio_window)
audio_window = rearrange(audio_window, "f b w d -> b f w d")
return audio_window
def check_video_fps(video_path: str):
cam = cv2.VideoCapture(video_path)
fps = cam.get(cv2.CAP_PROP_FPS)
if fps != 25:
raise ValueError(f"Video FPS is not 25, it is {fps}. Please convert the video to 25 FPS.")
def tailor_tensor_to_length(tensor: torch.Tensor, length: int):
if len(tensor) == length:
return tensor
elif len(tensor) > length:
return tensor[:length]
else:
return torch.cat([tensor, tensor[-1].repeat(length - len(tensor))])
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
videos = rearrange(videos, "b c f h w -> f b c h w")
outputs = []
for x in videos:
x = torchvision.utils.make_grid(x, nrow=n_rows)
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
if rescale:
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
x = (x * 255).numpy().astype(np.uint8)
outputs.append(x)
os.makedirs(os.path.dirname(path), exist_ok=True)
imageio.mimsave(path, outputs, fps=fps)
def interpolate_features(features: torch.Tensor, output_len: int) -> torch.Tensor:
features = features.cpu().numpy()
input_len, num_features = features.shape
input_timesteps = np.linspace(0, 10, input_len)
output_timesteps = np.linspace(0, 10, output_len)
output_features = np.zeros((output_len, num_features))
for feat in range(num_features):
output_features[:, feat] = np.interp(output_timesteps, input_timesteps, features[:, feat])
return torch.from_numpy(output_features)
# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
uncond_input = pipeline.tokenizer(
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt"
)
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
text_input = pipeline.tokenizer(
[prompt],
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
context = torch.cat([uncond_embeddings, text_embeddings])
return context
def reversed_forward(ddim_scheduler, pred_noise, timesteps, x_t):
# Compute alphas, betas
alpha_prod_t = ddim_scheduler.alphas_cumprod[timesteps]
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if ddim_scheduler.config.prediction_type == "epsilon":
beta_prod_t = beta_prod_t[:, None, None, None, None]
alpha_prod_t = alpha_prod_t[:, None, None, None, None]
pred_original_sample = (x_t - beta_prod_t ** (0.5) * pred_noise) / alpha_prod_t ** (0.5)
else:
raise NotImplementedError("This prediction type is not implemented yet")
# Clip "predicted x_0"
if ddim_scheduler.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
return pred_original_sample
def next_step(
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
ddim_scheduler,
):
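# Inverse DDIM update (a sketch of what the code below computes): predict x0 from the
# current sample using the previous alpha_prod, then re-noise it at the next timestep:
# next_sample = sqrt(alpha_next) * x0_pred + sqrt(1 - alpha_next) * model_output.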
timestep, next_timestep = (
min(timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999),
timestep,
)
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
beta_prod_t = 1 - alpha_prod_t
next_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
return next_sample
def get_noise_pred_single(latents, t, context, unet):
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
return noise_pred
@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
context = init_prompt(prompt, pipeline)
uncond_embeddings, cond_embeddings = context.chunk(2)
all_latent = [latent]
latent = latent.clone().detach()
for i in tqdm(range(num_inv_steps)):
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
latent = next_step(noise_pred, t, latent, ddim_scheduler)
all_latent.append(latent)
return all_latent
@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
return ddim_latents
def plot_loss_chart(save_path: str, *args):
# Creating the plot
plt.figure()
for loss_line in args:
plt.plot(loss_line[1], loss_line[2], label=loss_line[0])
plt.xlabel("Step")
plt.ylabel("Loss")
plt.legend()
# Save the figure to a file
plt.savefig(save_path)
# Close the figure to free memory
plt.close()
CRED = "\033[91m"
CEND = "\033[0m"
def red_text(text: str):
return f"{CRED}{text}{CEND}"
log_loss = nn.BCELoss(reduction="none")
def cosine_loss(vision_embeds, audio_embeds, y):
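# Per-sample BCE between the cosine similarity of vision/audio embeddings and the
# label y (1 = in sync, 0 = out of sync).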
sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
# sims[sims!=sims] = 0 # remove nan
# sims = sims.clamp(0, 1)
loss = log_loss(sims.unsqueeze(1), y).squeeze()
return loss
def save_image(image, save_path):
# input size (C, H, W)
image = (image / 2 + 0.5).clamp(0, 1)
image = (image * 255).to(torch.uint8)
image = transforms.ToPILImage()(image)
# Save the image copy
image.save(save_path)
# Close the image file
image.close()
def gather_loss(loss, device):
# Sum the local loss across all processes
local_loss = loss.item()
global_loss = torch.tensor(local_loss, dtype=torch.float32).to(device)
dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
# Calculate the average loss across all processes
global_average_loss = global_loss.item() / dist.get_world_size()
return global_average_loss
def gather_video_paths_recursively(input_dir):
print(f"Recursively gathering video paths of {input_dir} ...")
paths = []
gather_video_paths(input_dir, paths)
return paths
def gather_video_paths(input_dir, paths):
for file in sorted(os.listdir(input_dir)):
if file.endswith(".mp4"):
filepath = os.path.join(input_dir, file)
paths.append(filepath)
elif os.path.isdir(os.path.join(input_dir, file)):
gather_video_paths(os.path.join(input_dir, file), paths)
def count_video_time(video_path):
video = cv2.VideoCapture(video_path)
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
fps = video.get(cv2.CAP_PROP_FPS)
return frame_count / fps
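# Hedged usage sketch (illustration only; the file paths below are placeholders):
if __name__ == "__main__":
    frames = read_video("example.mp4", change_fps=False)  # (N, H, W, 3) RGB uint8 frames
    audio_samples = read_audio("example.wav", audio_sample_rate=16000)  # 1-D float tensor
    print(frames.shape, audio_samples.shape)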
# Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py
from .whisper import load_model
import numpy as np
import torch
import os
class Audio2Feature:
def __init__(
self,
model_path="checkpoints/whisper/tiny.pt",
device=None,
audio_embeds_cache_dir=None,
num_frames=16,
):
self.model = load_model(model_path, device)
self.audio_embeds_cache_dir = audio_embeds_cache_dir
self.num_frames = num_frames
self.embedding_dim = self.model.dims.n_audio_state
def get_sliced_feature(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
"""
Get sliced features centered on a given video frame index
:param feature_array: the whisper encoder feature array
:param vid_idx: the index of the video frame
:param audio_feat_length: the [left, right] amount of audio context around the center index
:return: the sliced features and the corresponding feature indices
"""
length = len(feature_array)
selected_feature = []
selected_idx = []
center_idx = int(vid_idx * 50 / fps)
left_idx = center_idx - audio_feat_length[0] * 2
right_idx = center_idx + (audio_feat_length[1] + 1) * 2
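# e.g. with fps=25 and audio_feat_length=[2, 2]: center_idx = 2 * vid_idx and the window
# covers indices [center_idx - 4, center_idx + 6), i.e. 10 consecutive 50 Hz features.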
for idx in range(left_idx, right_idx):
idx = max(0, idx)
idx = min(length - 1, idx)
x = feature_array[idx]
selected_feature.append(x)
selected_idx.append(idx)
selected_feature = torch.cat(selected_feature, dim=0)
selected_feature = selected_feature.reshape(-1, self.embedding_dim) # 50*384
return selected_feature, selected_idx
def get_sliced_feature_sparse(self, feature_array, vid_idx, audio_feat_length=[2, 2], fps=25):
"""
Get sparsely sliced features centered on a given video frame index
:param feature_array: the whisper encoder feature array
:param vid_idx: the index of the video frame
:param audio_feat_length: the [left, right] amount of audio context around the center index
:return: the sliced features and the corresponding feature indices
"""
length = len(feature_array)
selected_feature = []
selected_idx = []
for dt in range(-audio_feat_length[0], audio_feat_length[1] + 1):
left_idx = int((vid_idx + dt) * 50 / fps)
if left_idx < 1 or left_idx > length - 1:
left_idx = max(0, left_idx)
left_idx = min(length - 1, left_idx)
x = feature_array[left_idx]
x = x[np.newaxis, :, :]
x = np.repeat(x, 2, axis=0)
selected_feature.append(x)
selected_idx.append(left_idx)
selected_idx.append(left_idx)
else:
x = feature_array[left_idx - 1 : left_idx + 1]
selected_feature.append(x)
selected_idx.append(left_idx - 1)
selected_idx.append(left_idx)
selected_feature = np.concatenate(selected_feature, axis=0)
selected_feature = selected_feature.reshape(-1, self.embedding_dim) # 50*384
selected_feature = torch.from_numpy(selected_feature)
return selected_feature, selected_idx
def feature2chunks(self, feature_array, fps, audio_feat_length=[2, 2]):
whisper_chunks = []
whisper_idx_multiplier = 50.0 / fps
i = 0
print(f"video in {fps} FPS, audio idx in 50FPS")
while True:
start_idx = int(i * whisper_idx_multiplier)
selected_feature, selected_idx = self.get_sliced_feature(
feature_array=feature_array, vid_idx=i, audio_feat_length=audio_feat_length, fps=fps
)
# print(f"i:{i},selected_idx {selected_idx}")
whisper_chunks.append(selected_feature)
i += 1
if start_idx > len(feature_array):
break
return whisper_chunks
def _audio2feat(self, audio_path: str):
# get the sample rate of the audio
result = self.model.transcribe(audio_path)
embed_list = []
for emb in result["segments"]:
encoder_embeddings = emb["encoder_embeddings"]
encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
encoder_embeddings = encoder_embeddings.squeeze(0)
start_idx = int(emb["start"])
end_idx = int(emb["end"])
emb_end_idx = int((end_idx - start_idx) / 2)
embed_list.append(encoder_embeddings[:emb_end_idx])
concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0))
return concatenated_array
def audio2feat(self, audio_path):
if self.audio_embeds_cache_dir == "" or self.audio_embeds_cache_dir is None:
return self._audio2feat(audio_path)
audio_embeds_cache_path = os.path.join(self.audio_embeds_cache_dir, os.path.basename(audio_path) + ".pt")
if os.path.isfile(audio_embeds_cache_path):
try:
audio_feat = torch.load(audio_embeds_cache_path)
except Exception as e:
print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
os.remove(audio_embeds_cache_path)
audio_feat = self._audio2feat(audio_path)
torch.save(audio_feat, audio_embeds_cache_path)
else:
audio_feat = self._audio2feat(audio_path)
torch.save(audio_feat, audio_embeds_cache_path)
return audio_feat
def crop_overlap_audio_window(self, audio_feat, start_index):
selected_feature_list = []
for i in range(start_index, start_index + self.num_frames):
selected_feature, selected_idx = self.get_sliced_feature(
feature_array=audio_feat, vid_idx=i, audio_feat_length=[2, 2], fps=25
)
selected_feature_list.append(selected_feature)
mel_overlap = torch.stack(selected_feature_list)
return mel_overlap
if __name__ == "__main__":
audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
audio_path = "assets/demo1_audio.wav"
array = audio_encoder.audio2feat(audio_path)
print(array.shape)
fps = 25
whisper_idx_multiplier = 50.0 / fps
i = 0
print(f"video in {fps} FPS, audio idx in 50FPS")
while True:
start_idx = int(i * whisper_idx_multiplier)
selected_feature, selected_idx = audio_encoder.get_sliced_feature(
feature_array=array, vid_idx=i, audio_feat_length=[2, 2], fps=fps
)
print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
i += 1
if start_idx > len(array):
break
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_model(
name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False
) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
return model.to(device)
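# Hedged usage sketch (illustration only): load the "tiny" checkpoint the same way
# Audio2Feature does; it is downloaded to ~/.cache/whisper if not already cached.
if __name__ == "__main__":
    model = load_model("tiny")
    print("encoder width:", model.dims.n_audio_state)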
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
import os
from functools import lru_cache
from typing import Union
import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F
from .utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
N_MELS = 80
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
try:
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
out, _ = (
ffmpeg.input(file, threads=0)
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
)
except ffmpeg.Error as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(dim=axis, index=torch.arange(length))
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
)
"""
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
"""
Compute the log-Mel spectrogram of the given audio
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
Returns
-------
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[:, :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
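# Hedged sanity-check sketch (illustration only, assuming the bundled mel_filters.npz
# asset is present): one second of silence at 16 kHz yields an (80, 100) log-Mel tensor.
if __name__ == "__main__":
    silence = np.zeros(SAMPLE_RATE, dtype=np.float32)
    print(log_mel_spectrogram(silence).shape)  # torch.Size([80, 100])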
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
if TYPE_CHECKING:
from .model import Whisper
@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return it as a list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
This is performed outside the main decode loop in order to not interfere with kv-caching.
Returns
-------
language_tokens : Tensor, shape = (n_audio,)
ids of the most probable language tokens, which appear after the startoftranscript token.
language_probs : List[Dict[str, float]], length = n_audio
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(model.is_multilingual)
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript
n_audio = mel.shape[0]
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
}
for i in range(n_audio)
]
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
language: Optional[str] = None # language that the audio is in; uses detected language if None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
# options for ranking generations (either beams or best-of-N samples)
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
# prompt, prefix, and token suppression
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
suppress_blank: bool = True # this will suppress blank outputs
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
# implementation details
fp16: bool = True # use fp16 for most of the calculation
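# Hedged example (illustration only): a typical greedy, timestamp-free configuration would be
# DecodingOptions(task="transcribe", language="en", temperature=0.0, without_timestamps=True).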
@dataclass(frozen=True)
class DecodingResult:
audio_features: Tensor
language: str
encoder_embeddings: np.ndarray
decoder_embeddings: np.ndarray
language_probs: Optional[Dict[str, float]] = None
tokens: List[int] = field(default_factory=list)
text: str = ""
avg_logprob: float = np.nan
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
class Inference:
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
"""Perform a forward pass on the decoder and return per-token logits"""
raise NotImplementedError
def rearrange_kv_cache(self, source_indices) -> None:
"""Update the key-value cache according to the updated beams"""
raise NotImplementedError
def cleanup_caching(self) -> None:
"""Clean up any resources or hooks after decoding is finished"""
pass
class PyTorchInference(Inference):
def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = {}
self.hooks = []
def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
if not self.kv_cache:
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
return_val = self.model.decoder(tokens, audio_features,
kv_cache=self.kv_cache, include_embeddings=include_embeddings)
return return_val
def cleanup_caching(self):
for hook in self.hooks:
hook.remove()
self.kv_cache = {}
self.hooks = []
def rearrange_kv_cache(self, source_indices):
for module, tensor in self.kv_cache.items():
# update the key/value cache to contain the selected sequences
self.kv_cache[module] = tensor[source_indices].detach()
class SequenceRanker:
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
"""
raise NotImplementedError
class MaximumLikelihoodRanker(SequenceRanker):
"""
Select the sample with the highest log probabilities, penalized using either
a simple length normalization or Google NMT paper's length penalty
"""
def __init__(self, length_penalty: Optional[float]):
self.length_penalty = length_penalty
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
def scores(logprobs, lengths):
result = []
for logprob, length in zip(logprobs, lengths):
if self.length_penalty is None:
penalty = length
else:
# from the Google NMT paper
penalty = ((5 + length) / 6) ** self.length_penalty
result.append(logprob / penalty)
return result
# get the sequence with the highest score
lengths = [[len(t) for t in s] for s in tokens]
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
----------
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
sum_logprobs : Tensor, shape = (n_batch)
cumulative log probabilities for each sequence
Returns
-------
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
the tokens, appended with the selected next token
completed : bool
True if all sequences have reached the end of text
"""
raise NotImplementedError
def finalize(
self, tokens: Tensor, sum_logprobs: Tensor
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
"""Finalize search and return the final candidate sequences
Parameters
----------
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence
sum_logprobs : Tensor, shape = (n_audio, n_group)
cumulative log probabilities for each sequence
Returns
-------
tokens : Sequence[Sequence[Tensor]], length = n_audio
sequence of Tensors containing candidate token sequences, for each audio input
sum_logprobs : List[List[float]], length = n_audio
sequence of cumulative log probabilities corresponding to the above
"""
raise NotImplementedError
class GreedyDecoder(TokenDecoder):
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
temperature = self.temperature
if temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
completed = (tokens[:, -1] == self.eot).all()
return tokens, completed
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
# make sure each sequence has at least one EOT token at the end
tokens = F.pad(tokens, (0, 1), value=self.eot)
return tokens, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None: # for the first update
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = F.log_softmax(logits.float(), dim=-1)
next_tokens, source_indices, finished_sequences = [], [], []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
# STEP 1: calculate the cumulative log probabilities for possible candidates
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens[idx].tolist()
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
new_logprob = (sum_logprobs[idx] + logprob).item()
sequence = tuple(prefix + [token.item()])
scores[sequence] = new_logprob
sources[sequence] = idx
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
sum_logprobs[len(next_tokens)] = scores[sequence]
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = torch.tensor(next_tokens, device=tokens.device)
self.inference.rearrange_kv_cache(source_indices)
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
previously_finished[seq] = newly_finished[seq]
        # mark as completed once every audio input has collected enough finished samples
completed = all(
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if len(sequences) < self.beam_size: # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
if len(sequences) >= self.beam_size:
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
]
return tokens, sum_logprobs
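# Illustrative note: DecodingTask (below) constructs this decoder as
#   BeamSearchDecoder(options.beam_size, tokenizer.eot, inference, options.patience)
# and drives it exactly like GreedyDecoder; the only difference is that `tokens` and
# `sum_logprobs` carry n_audio * beam_size rows during sampling.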
class LogitFilter:
def apply(self, logits: Tensor, tokens: Tensor) -> None:
"""Apply any filtering or masking to logits in-place
Parameters
----------
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
"""
raise NotImplementedError
class SuppressBlank(LogitFilter):
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
def apply(self, logits: Tensor, tokens: Tensor):
if tokens.shape[1] == self.sample_begin:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
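# Effect: at the very first sampled position (tokens.shape[1] == sample_begin), the
# token(s) for a leading space and the EOT token are masked, so decoding can neither
# start with a blank nor immediately produce an empty transcript.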
class SuppressTokens(LogitFilter):
def __init__(self, suppress_tokens: Sequence[int]):
self.suppress_tokens = list(suppress_tokens)
def apply(self, logits: Tensor, tokens: Tensor):
logits[:, self.suppress_tokens] = -np.inf
class ApplyTimestampRules(LogitFilter):
def __init__(
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
self.max_initial_timestamp_index = max_initial_timestamp_index
def apply(self, logits: Tensor, tokens: Tensor):
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
logits[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
            seq = tokens[k, self.sample_begin :].tolist()
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
# apply the `max_initial_timestamp` option
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
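# Illustrative example of the pairing rule enforced above: after "... text <|1.20|>"
# (a single trailing timestamp) the next token must be another timestamp or EOT,
# while after a consecutive pair "<|1.20|><|1.20|>" the next token has to be ordinary text.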
class DecodingTask:
inference: Inference
sequence_ranker: SequenceRanker
decoder: TokenDecoder
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
if self.options.without_timestamps:
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
self.sample_begin: int = len(self.initial_tokens)
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
# inference: implements the forward pass through the decoder, including kv caching
self.inference = PyTorchInference(model, len(self.initial_tokens))
# sequence ranker: implements how to rank a group of sampled sequences
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
# decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None:
self.decoder = BeamSearchDecoder(
options.beam_size, tokenizer.eot, self.inference, options.patience
)
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
# logit filters: applies various rules to suppress or penalize certain tokens
self.logit_filters = []
if self.options.suppress_blank:
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
if self.options.suppress_tokens:
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
self.logit_filters.append(
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
)
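        # Worked example of the timestamp options above: with the usual CHUNK_LENGTH = 30 and
        # n_audio_ctx = 1500, precision is 30 / 1500 = 0.02 s per timestamp token, so
        # max_initial_timestamp = 1.0 maps to max_initial_timestamp_index = round(1.0 / 0.02) = 50.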
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.beam_size is not None and options.best_of is not None:
raise ValueError("beam_size and best_of can't be given together")
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
prefix = self.options.prefix
prompt = self.options.prompt
if prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
)
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
return tuple(tokens)
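    # Illustrative layout (token ids omitted): with prompt="Hello there" and no prefix,
    # the initial tokens look like
    #   (sot_prev, *encode(" Hello there"), sot, <language>, <task>[, no_timestamps])
    # and self.sample_begin points just past this tuple, so sampled text starts after it.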
def _get_suppress_tokens(self) -> Tuple[int]:
suppress_tokens = self.options.suppress_tokens
if isinstance(suppress_tokens, str):
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_speech)
return tuple(sorted(set(suppress_tokens)))
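    # For example, the default suppress_tokens="-1" expands to the tokenizer's
    # non_speech_tokens plus sot, sot_prev and sot_lm (and no_speech when it is defined),
    # returned as a sorted, de-duplicated tuple.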
def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
result = self.model.encoder(mel, include_embeddings)
if include_embeddings:
audio_features, embeddings = result
else:
audio_features = result
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
            raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
if include_embeddings:
return audio_features, embeddings
else:
return audio_features
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
languages = [self.options.language] * audio_features.shape[0]
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
return languages, lang_probs
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
assert audio_features.shape[0] == tokens.shape[0]
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
        completed = False  # ensure the name is bound for the finally block even if the loop exits early
        try:
embeddings = []
for i in range(self.sample_len):
logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
token_embeddings = token_embeddings[:, :, -1]
# Append embeddings together
embeddings.append(token_embeddings)
            # apply the logit filters, e.g. to suppress blanks/special tokens or enforce timestamp rules
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
if completed:
embeddings = embeddings[:-1]
embeddings = np.stack(embeddings, 2)
self.inference.cleanup_caching()
return tokens, sum_logprobs, no_speech_probs, embeddings
@torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]:
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
# encoder forward pass
forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
audio_features, encoder_embeddings = forward_pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(audio_features=features, language=language, language_probs=probs)
for features, language, probs in zip(audio_features, languages, language_probs)
]
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
return [
DecodingResult(
audio_features=features,
language=language,
tokens=tokens,
text=text,
avg_logprob=avg_logprob,
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
encoder_embeddings=encoder_embeddings,
decoder_embeddings=decoder_embeddings
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
]
@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
Parameters
----------
model: Whisper
the Whisper model instance
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
A dataclass that contains all necessary options for decoding 30-second segments
Returns
-------
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
result = DecodingTask(model, options).run(mel)
if single:
result = result[0]
return result