Commit c2eff804 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new project

parents
Pipeline #1648 failed with stages
in 0 seconds
import torch
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
class FeatureExtractor:
def __init__(
self,
device: str = "auto",
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
):
if device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.time_per_frame = hop_length / sampling_rate
self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(
sampling_rate, n_fft, n_mels=feature_size
)
@staticmethod
def get_mel_filters(sr, n_fft, n_mels=128):
"""
Implementation of librosa.filters.mel in Pytorch
"""
# Initialize the weights
n_mels = int(n_mels)
# Center freqs of each FFT bin
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965
mels = torch.linspace(min_mel, max_mel, n_mels + 2)
# Fill in the linear scale
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region
# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
mel_f = freqs
fdiff = torch.diff(mel_f)
ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
upper = ramps[2:] / fdiff[1:].unsqueeze(1)
# Intersect them with each other and zero, vectorized across all i
weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm.unsqueeze(1)
return weights
def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
"""
Compute the log-Mel spectrogram of the provided audio.
"""
if chunk_length is not None:
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length
if waveform.dtype is not torch.float32:
waveform = waveform.to(torch.float32)
waveform = (
waveform.to(self.device)
if self.device == "cuda" and not waveform.is_cuda
else waveform
)
if padding:
waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))
window = torch.hann_window(self.n_fft).to(waveform.device)
stft = torch.stft(
waveform, self.n_fft, self.hop_length, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2
mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
# When the model is running on multiple GPUs, the output should be moved
# to the CPU since we don't know which GPU will handle the next job.
return log_spec.cpu() if to_cpu else log_spec
import string
from functools import cached_property
from typing import List, Optional, Tuple
import tokenizers
class Tokenizer:
"""Simple wrapper around a tokenizers.Tokenizer."""
def __init__(
self,
tokenizer: tokenizers.Tokenizer,
multilingual: bool,
task: Optional[str] = None,
language: Optional[str] = None,
):
self.tokenizer = tokenizer
if multilingual:
if task not in _TASKS:
raise ValueError(
"'%s' is not a valid task (accepted tasks: %s)"
% (task, ", ".join(_TASKS))
)
if language not in _LANGUAGE_CODES:
raise ValueError(
"'%s' is not a valid language code (accepted language codes: %s)"
% (language, ", ".join(_LANGUAGE_CODES))
)
self.task = self.tokenizer.token_to_id("<|%s|>" % task)
self.language = self.tokenizer.token_to_id("<|%s|>" % language)
self.language_code = language
else:
self.task = None
self.language = None
self.language_code = "en"
@cached_property
def transcribe(self) -> int:
return self.tokenizer.token_to_id("<|transcribe|>")
@cached_property
def translate(self) -> int:
return self.tokenizer.token_to_id("<|translate|>")
@cached_property
def sot(self) -> int:
return self.tokenizer.token_to_id("<|startoftranscript|>")
@cached_property
def sot_lm(self) -> int:
return self.tokenizer.token_to_id("<|startoflm|>")
@cached_property
def sot_prev(self) -> int:
return self.tokenizer.token_to_id("<|startofprev|>")
@cached_property
def eot(self) -> int:
return self.tokenizer.token_to_id("<|endoftext|>")
@cached_property
def no_timestamps(self) -> int:
return self.tokenizer.token_to_id("<|notimestamps|>")
@property
def timestamp_begin(self) -> int:
return self.no_timestamps + 1
@property
def sot_sequence(self) -> List[int]:
sequence = [self.sot]
if self.language is not None:
sequence.append(self.language)
if self.task is not None:
sequence.append(self.task)
return sequence
def encode(self, text: str) -> List[int]:
return self.tokenizer.encode(text, add_special_tokens=False).ids
def decode(self, tokens: List[int]) -> str:
text_tokens = [token for token in tokens if token < self.eot]
return self.tokenizer.decode(text_tokens)
def decode_with_timestamps(self, tokens: List[int]) -> str:
outputs = [[]]
for token in tokens:
if token >= self.timestamp_begin:
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
outputs.append(timestamp)
outputs.append([])
else:
outputs[-1].append(token)
return "".join(
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
)
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.encode(" -")[0], self.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.encode(symbol),
self.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def split_to_word_tokens(
self, tokens: List[int]
) -> Tuple[List[str], List[List[int]]]:
if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
return self.split_tokens_on_unicode(tokens)
return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(
self, tokens: List[int]
) -> Tuple[List[str], List[List[int]]]:
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = []
word_tokens = []
current_tokens = []
unicode_offset = 0
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
try:
replacement_char_index = decoded.index(replacement_char)
replacement_char_index += unicode_offset
except ValueError:
replacement_char_index = None
if replacement_char_index is None or (
replacement_char_index < len(decoded_full)
and decoded_full[replacement_char_index] == replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens
def split_tokens_on_spaces(
self, tokens: List[int]
) -> Tuple[List[str], List[List[int]]]:
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []
for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
return words, word_tokens
_TASKS = (
"transcribe",
"translate",
)
_LANGUAGE_CODES = (
"af",
"am",
"ar",
"as",
"az",
"ba",
"be",
"bg",
"bn",
"bo",
"br",
"bs",
"ca",
"cs",
"cy",
"da",
"de",
"el",
"en",
"es",
"et",
"eu",
"fa",
"fi",
"fo",
"fr",
"gl",
"gu",
"ha",
"haw",
"he",
"hi",
"hr",
"ht",
"hu",
"hy",
"id",
"is",
"it",
"ja",
"jw",
"ka",
"kk",
"km",
"kn",
"ko",
"la",
"lb",
"ln",
"lo",
"lt",
"lv",
"mg",
"mi",
"mk",
"ml",
"mn",
"mr",
"ms",
"mt",
"my",
"ne",
"nl",
"nn",
"no",
"oc",
"pa",
"pl",
"ps",
"pt",
"ro",
"ru",
"sa",
"sd",
"si",
"sk",
"sl",
"sn",
"so",
"sq",
"sr",
"su",
"sv",
"sw",
"ta",
"te",
"tg",
"th",
"tk",
"tl",
"tr",
"tt",
"uk",
"ur",
"uz",
"vi",
"yi",
"yo",
"zh",
"yue",
)
This diff is collapsed.
import logging
import os
import re
from typing import List, Optional
import huggingface_hub
import requests
from tqdm.auto import tqdm
_MODELS = {
"tiny.en": "Systran/faster-whisper-tiny.en",
"tiny": "Systran/faster-whisper-tiny",
"base.en": "Systran/faster-whisper-base.en",
"base": "Systran/faster-whisper-base",
"small.en": "Systran/faster-whisper-small.en",
"small": "Systran/faster-whisper-small",
"medium.en": "Systran/faster-whisper-medium.en",
"medium": "Systran/faster-whisper-medium",
"large-v1": "Systran/faster-whisper-large-v1",
"large-v2": "Systran/faster-whisper-large-v2",
"large-v3": "Systran/faster-whisper-large-v3",
"large": "Systran/faster-whisper-large-v3",
"distil-large-v2": "Systran/faster-distil-whisper-large-v2",
"distil-medium.en": "Systran/faster-distil-whisper-medium.en",
"distil-small.en": "Systran/faster-distil-whisper-small.en",
"distil-large-v3": "Systran/faster-distil-whisper-large-v3",
}
def available_models() -> List[str]:
"""Returns the names of available models."""
return list(_MODELS.keys())
def get_assets_path():
"""Returns the path to the assets directory."""
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
def get_logger():
"""Returns the module logger."""
return logging.getLogger("faster_whisper")
def download_model(
size_or_id: str,
output_dir: Optional[str] = None,
local_files_only: bool = False,
cache_dir: Optional[str] = None,
):
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
Args:
size_or_id: Size of the model to download from https://huggingface.co/Systran
(tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en,
distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2,
distil-large-v3), or a CTranslate2-converted model ID from the Hugging Face Hub
(e.g. Systran/faster-whisper-large-v3).
output_dir: Directory where the model should be saved. If not set, the model is saved in
the cache directory.
local_files_only: If True, avoid downloading the file and return the path to the local
cached file if it exists.
cache_dir: Path to the folder where cached files are stored.
Returns:
The path to the downloaded model.
Raises:
ValueError: if the model size is invalid.
"""
if re.match(r".*/.*", size_or_id):
repo_id = size_or_id
else:
repo_id = _MODELS.get(size_or_id)
if repo_id is None:
raise ValueError(
"Invalid model size '%s', expected one of: %s"
% (size_or_id, ", ".join(_MODELS.keys()))
)
allow_patterns = [
"config.json",
"preprocessor_config.json",
"model.bin",
"tokenizer.json",
"vocabulary.*",
]
kwargs = {
"local_files_only": local_files_only,
"allow_patterns": allow_patterns,
"tqdm_class": disabled_tqdm,
}
if output_dir is not None:
kwargs["local_dir"] = output_dir
kwargs["local_dir_use_symlinks"] = False
if cache_dir is not None:
kwargs["cache_dir"] = cache_dir
try:
return huggingface_hub.snapshot_download(repo_id, **kwargs)
except (
huggingface_hub.utils.HfHubHTTPError,
requests.exceptions.ConnectionError,
) as exception:
logger = get_logger()
logger.warning(
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
repo_id,
exception,
)
logger.warning(
"Trying to load the model directly from the local cache, if it exists."
)
kwargs["local_files_only"] = True
return huggingface_hub.snapshot_download(repo_id, **kwargs)
def format_timestamp(
seconds: float,
always_include_hours: bool = False,
decimal_marker: str = ".",
) -> str:
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
class disabled_tqdm(tqdm):
def __init__(self, *args, **kwargs):
kwargs["disable"] = True
super().__init__(*args, **kwargs)
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
This diff is collapsed.
"""Version information."""
__version__ = "1.0.3"
transformers[torch]>=4.23
ctranslate2>=4.0,<5
huggingface_hub>=0.13
tokenizers>=0.13,<1
onnxruntime>=1.14,<2
pyannote-audio>=3.1.1
torch>=2.1.1
torchaudio>=2.1.2
tqdm
\ No newline at end of file
[flake8]
max-line-length = 100
ignore =
E203,
W503,
[isort]
profile=black
lines_between_types=1
import os
from setuptools import find_packages, setup
base_dir = os.path.dirname(os.path.abspath(__file__))
def get_long_description():
readme_path = os.path.join(base_dir, "README.md")
with open(readme_path, encoding="utf-8") as readme_file:
return readme_file.read()
def get_project_version():
version_path = os.path.join(base_dir, "faster_whisper", "version.py")
version = {}
with open(version_path, encoding="utf-8") as fp:
exec(fp.read(), version)
return version["__version__"]
def get_requirements(path):
with open(path, encoding="utf-8") as requirements:
return [requirement.strip() for requirement in requirements]
install_requires = get_requirements(os.path.join(base_dir, "requirements.txt"))
conversion_requires = get_requirements(
os.path.join(base_dir, "requirements.conversion.txt")
)
setup(
name="faster-whisper",
version=get_project_version(),
license="MIT",
description="Faster Whisper transcription with CTranslate2",
long_description=get_long_description(),
long_description_content_type="text/markdown",
author="Guillaume Klein",
url="https://github.com/SYSTRAN/faster-whisper",
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
keywords="openai whisper speech ctranslate2 inference quantization transformer",
python_requires=">=3.8",
install_requires=install_requires,
extras_require={
"conversion": conversion_requires,
"dev": [
"black==23.*",
"flake8==6.*",
"isort==5.*",
"pytest==7.*",
],
},
packages=find_packages(),
include_package_data=True,
)
import os
import pytest
@pytest.fixture
def data_dir():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
@pytest.fixture
def jfk_path(data_dir):
return os.path.join(data_dir, "jfk.flac")
@pytest.fixture
def physcisworks_path(data_dir):
return os.path.join(data_dir, "physicsworks.wav")
import os
from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
from faster_whisper.tokenizer import Tokenizer
from faster_whisper.transcribe import get_suppressed_tokens
def test_supported_languages():
model = WhisperModel("tiny.en")
assert model.supported_languages == ["en"]
def test_transcribe(jfk_path):
model = WhisperModel("tiny")
segments, info = model.transcribe(jfk_path, word_timestamps=True)
assert info.all_language_probs is not None
assert info.language == "en"
assert info.language_probability > 0.9
assert info.duration == 11
# Get top language info from all results, which should match the
# already existing metadata
top_lang, top_lang_score = info.all_language_probs[0]
assert info.language == top_lang
assert abs(info.language_probability - top_lang_score) < 1e-16
segments = list(segments)
assert len(segments) == 1
segment = segments[0]
assert segment.text == (
" And so my fellow Americans ask not what your country can do for you, "
"ask what you can do for your country."
)
assert segment.text == "".join(word.word for word in segment.words)
assert segment.start == segment.words[0].start
assert segment.end == segment.words[-1].end
batched_model = BatchedInferencePipeline(model=model, use_vad_model=False)
result, info = batched_model.transcribe(jfk_path, word_timestamps=True)
assert info.language == "en"
assert info.language_probability > 0.7
segments = []
for segment in result:
segments.append(
{"start": segment.start, "end": segment.end, "text": segment.text}
)
assert len(segments) == 1
assert segment.text == (
" And so my fellow Americans ask not what your country can do for you, "
"ask what you can do for your country."
)
def test_batched_transcribe(physcisworks_path):
model = WhisperModel("tiny")
batched_model = BatchedInferencePipeline(model=model)
result, info = batched_model.transcribe(physcisworks_path, batch_size=16)
assert info.language == "en"
assert info.language_probability > 0.7
segments = []
for segment in result:
segments.append(
{"start": segment.start, "end": segment.end, "text": segment.text}
)
# number of near 30 sec segments
assert len(segments) == 8
result, info = batched_model.transcribe(
physcisworks_path,
batch_size=16,
without_timestamps=False,
word_timestamps=True,
)
segments = []
for segment in result:
assert segment.words is not None
segments.append(
{"start": segment.start, "end": segment.end, "text": segment.text}
)
assert len(segments) > 8
def test_prefix_with_timestamps(jfk_path):
model = WhisperModel("tiny")
segments, _ = model.transcribe(jfk_path, prefix="And so my fellow Americans")
segments = list(segments)
assert len(segments) == 1
segment = segments[0]
assert segment.text == (
" And so my fellow Americans ask not what your country can do for you, "
"ask what you can do for your country."
)
assert segment.start == 0
assert 10 < segment.end < 11
def test_vad(jfk_path):
model = WhisperModel("tiny")
segments, info = model.transcribe(
jfk_path,
vad_filter=True,
vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200),
)
segments = list(segments)
assert len(segments) == 1
segment = segments[0]
assert segment.text == (
" And so my fellow Americans ask not what your country can do for you, "
"ask what you can do for your country."
)
assert 0 < segment.start < 1
assert 10 < segment.end < 11
assert info.vad_options.min_silence_duration_ms == 500
assert info.vad_options.speech_pad_ms == 200
def test_stereo_diarization(data_dir):
model = WhisperModel("tiny")
audio_path = os.path.join(data_dir, "stereo_diarization.wav")
left, right = decode_audio(audio_path, split_stereo=True)
segments, _ = model.transcribe(left)
transcription = "".join(segment.text for segment in segments).strip()
assert transcription == (
"He began a confused complaint against the wizard, "
"who had vanished behind the curtain on the left."
)
segments, _ = model.transcribe(right)
transcription = "".join(segment.text for segment in segments).strip()
assert transcription == "The horizon seems extremely distant."
def test_multisegment_lang_id(physcisworks_path):
model = WhisperModel("tiny")
language_info = model.detect_language_multi_segment(physcisworks_path)
assert language_info["language_code"] == "en"
assert language_info["language_confidence"] > 0.8
def test_suppressed_tokens_minus_1():
model = WhisperModel("tiny.en")
tokenizer = Tokenizer(model.hf_tokenizer, False)
tokens = get_suppressed_tokens(tokenizer, [-1])
assert tokens == (
1,
2,
7,
8,
9,
10,
14,
25,
26,
27,
28,
29,
31,
58,
59,
60,
61,
62,
63,
90,
91,
92,
93,
357,
366,
438,
532,
685,
705,
796,
930,
1058,
1220,
1267,
1279,
1303,
1343,
1377,
1391,
1635,
1782,
1875,
2162,
2361,
2488,
3467,
4008,
4211,
4600,
4808,
5299,
5855,
6329,
7203,
9609,
9959,
10563,
10786,
11420,
11709,
11907,
13163,
13697,
13700,
14808,
15306,
16410,
16791,
17992,
19203,
19510,
20724,
22305,
22935,
27007,
30109,
30420,
33409,
34949,
40283,
40493,
40549,
47282,
49146,
50257,
50357,
50358,
50359,
50360,
)
def test_suppressed_tokens_minus_value():
model = WhisperModel("tiny.en")
tokenizer = Tokenizer(model.hf_tokenizer, False)
tokens = get_suppressed_tokens(tokenizer, [13])
assert tokens == (13, 50257, 50357, 50358, 50359, 50360)
import os
from faster_whisper import available_models, download_model
def test_available_models():
models = available_models()
assert isinstance(models, list)
assert "tiny" in models
def test_download_model(tmpdir):
output_dir = str(tmpdir.join("model"))
model_dir = download_model("tiny", output_dir=output_dir)
assert model_dir == output_dir
assert os.path.isdir(model_dir)
assert not os.path.islink(model_dir)
for filename in os.listdir(model_dir):
path = os.path.join(model_dir, filename)
assert not os.path.islink(path)
def test_download_model_in_cache(tmpdir):
cache_dir = str(tmpdir.join("model"))
download_model("tiny", cache_dir=cache_dir)
assert os.path.isdir(cache_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment