# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import io
import logging
import re
import sys
import inspect
import random
import typing as tp
from functools import partial
import omegaconf
import torch
import torchaudio
import numpy as np
from typing_extensions import Literal
from typing import (
Any,
Union,
Iterable,
List,
Dict,
Optional,
Tuple,
)
from librosa.filters import mel as librosa_mel_fn
from librosa.util.exceptions import ParameterError  # raised by normalize() below
from scipy.io.wavfile import read
_BoolLike_co = Union[bool, np.bool_]
_IntLike_co = Union[_BoolLike_co, int, "np.integer[Any]"]
_FloatLike_co = Union[_IntLike_co, float, "np.floating[Any]"]
def process_audio(file_path, target_sample_rate=24000):
audio, sample_rate = torchaudio.load(file_path)
# Check if the audio needs to be resampled
if sample_rate != target_sample_rate:
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)(audio)
# Convert stereo to mono (if necessary)
audio = audio.mean(dim=0, keepdim=True) if audio.size(0) == 2 else audio
return audio, target_sample_rate
def load_wav(full_path):
sampling_rate, data = read(full_path)
return data, sampling_rate
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
output = dynamic_range_compression_torch(magnitudes)
return output
def spectral_de_normalize_torch(magnitudes):
output = dynamic_range_decompression_torch(magnitudes)
return output
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.0:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("max value is ", torch.max(y))
# global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned
mel_basis = {}
hann_window = {}
if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
spec = torch.view_as_real(
torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[str(y.device)],
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
spec = spectral_normalize_torch(spec)
return spec
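
# Illustrative usage sketch (not part of the original file): compute a log-mel spectrogram
# for a mono 24 kHz clip loaded with process_audio above. The STFT/mel parameter values
# below are assumptions for demonstration, not the project's configured values.
def _example_mel_spectrogram(path="example.wav"):
    wav, sr = process_audio(path, target_sample_rate=24000)  # (1, T)
    mel = mel_spectrogram(wav, n_fft=1024, num_mels=128, sampling_rate=sr,
                          hop_size=256, win_size=1024, fmin=0, fmax=None)
    return mel  # (1, num_mels, frames), log-compressed magnitudes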
def fade_out(audio: torch.Tensor, sample_rate: int,
fade_duration: float) -> torch.Tensor:
"""
Apply a linear fade-out effect to the given audio waveform.
Parameters:
audio (torch.Tensor): The audio waveform tensor.
sample_rate (int): Sample rate of the audio.
fade_duration (float): Duration of the fade-out effect in seconds.
Returns:
torch.Tensor: The audio with the fade-out effect applied.
"""
fade_samples = int(fade_duration * sample_rate)
if fade_samples > audio.shape[1]:
fade_samples = audio.shape[
1] # use the whole length of audio if necessary
fade_out_envelope = torch.linspace(1.0, 0.0, fade_samples,
dtype=audio.dtype, device=audio.device)
fade_section = audio[:, -fade_samples:].clone()
fade_section *= fade_out_envelope
faded_audio = audio.clone()
faded_audio[:, -fade_samples:] = fade_section
return faded_audio
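
# Usage sketch for fade_out (illustrative file name and duration): apply a 0.5 s linear
# fade to the tail of a mono clip; the input is expected as (channels, samples).
def _example_fade_out(path="example.wav"):
    wav, sr = process_audio(path, target_sample_rate=24000)  # (1, T)
    return fade_out(wav, sr, fade_duration=0.5)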
def split_wav_into_chunks(num_samples, wav, max_chunk_size, minimum_chunk_size=720):
num_chunks = (num_samples + max_chunk_size - 1) // max_chunk_size # Ceiling division
wav_chunks = []
for i in range(num_chunks):
start_idx = i * max_chunk_size
end_idx = min(start_idx + max_chunk_size, num_samples)
if (end_idx - start_idx) >= minimum_chunk_size:
if len(wav.shape) == 2:
chunk = wav[:,start_idx:end_idx]
else:
chunk = wav[start_idx:end_idx]
wav_chunks.append(chunk)
else:
print(f"{num_samples}:{num_chunks}, chunk size={(end_idx - start_idx)} is lower then minimum_chunk_size!")
return wav_chunks
def tiny(x: Union[float, np.ndarray]) -> _FloatLike_co:
"""Compute the tiny-value corresponding to an input's data type.
"""
# Make sure we have an array view
x = np.asarray(x)
# Only floating types generate a tiny
if np.issubdtype(x.dtype, np.floating) or np.issubdtype(
x.dtype, np.complexfloating
):
dtype = x.dtype
else:
dtype = np.dtype(np.float32)
return np.finfo(dtype).tiny
def detect_silence(audio, sample_rate, threshold=0.05, min_silence_duration=1):
"""
Detects the first occurrence of silence in the audio.
Parameters:
audio (Tensor): The audio waveform.
sample_rate (int): The sample rate of the audio.
threshold (float): The threshold below which the signal is considered silent.
min_silence_duration (float): The minimum duration of silence in seconds.
Returns:
int: The timestamp (in samples) where the silence starts.
"""
# Convert the audio to a numpy array for easier manipulation
audio_np = audio.numpy().flatten()
# Calculate the energy of the signal
energy = np.abs(audio_np)
# Find the indices where the energy is below the threshold
silent_indices = np.where(energy < threshold)[0]
# Find the start and end of contiguous silent regions
silent_regions = np.split(silent_indices, np.where(np.diff(silent_indices) != 1)[0] + 1)
# Filter out regions that are too short
min_silence_samples = int(min_silence_duration * sample_rate)
for region in silent_regions:
if len(region) >= min_silence_samples:
return region[0]
# If no silence is found, return the length of the audio
return len(audio_np)
def trim_audio(waveform, sample_rate=24000, threshold=0.05, min_silence_duration=1, minimum_silence_start_sample=24000):
"""
Trims the audio from the beginning to the first occurrence of silence.
Parameters:
        waveform (Tensor): The waveform data of the input audio file.
sample_rate (int): Sample rate of the input audio file.
threshold (float): The threshold below which the signal is considered silent.
min_silence_duration (float): The minimum duration of silence in seconds.
"""
# Detect the first occurrence of silence
silence_start_sample = detect_silence(waveform, sample_rate, threshold, min_silence_duration)
if silence_start_sample > minimum_silence_start_sample :
trimmed_waveform = waveform[:silence_start_sample]
else:
trimmed_waveform = waveform[:minimum_silence_start_sample]
if isinstance(trimmed_waveform, torch.Tensor):
return trimmed_waveform
else:
        return trimmed_waveform.unsqueeze(0)
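
# Usage sketch for trim_audio (assumed 24 kHz mono input): cut the waveform at the first
# detected silence using the default threshold and minimum silence duration above.
def _example_trim_audio(path="example.wav"):
    wav, sr = process_audio(path, target_sample_rate=24000)  # (1, T)
    return trim_audio(wav.squeeze(0), sample_rate=sr)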
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
loudness_compressor: bool = False, energy_floor: float = 2e-3):
"""Normalize an input signal to a user loudness in dB LKFS.
Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
Args:
wav (torch.Tensor): Input multichannel audio data.
sample_rate (int): Sample rate.
loudness_headroom_db (float): Target loudness of the output in dB LUFS.
loudness_compressor (bool): Uses tanh for soft clipping.
energy_floor (float): anything below that RMS level will not be rescaled.
Returns:
torch.Tensor: Loudness normalized output data.
"""
energy = wav.pow(2).mean().sqrt().item()
if energy < energy_floor:
return wav
transform = torchaudio.transforms.Loudness(sample_rate)
input_loudness_db = transform(wav).item()
# calculate the gain needed to scale to the desired loudness level
delta_loudness = -loudness_headroom_db - input_loudness_db
gain = 10.0 ** (delta_loudness / 20.0)
output = gain * wav
if loudness_compressor:
output = torch.tanh(output)
assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
return output
def normalize(
S: np.ndarray,
*,
norm: Optional[float] = np.inf,
axis: Optional[int] = 0,
threshold: Optional[_FloatLike_co] = None,
fill: Optional[bool] = None,
) -> np.ndarray:
"""Normalize an array along a chosen axis.
"""
# Avoid div-by-zero
if threshold is None:
threshold = tiny(S)
elif threshold <= 0:
raise ParameterError(f"threshold={threshold} must be strictly positive")
if fill not in [None, False, True]:
raise ParameterError(f"fill={fill} must be None or boolean")
if not np.isfinite(S).all():
raise ParameterError("Input must be finite")
# All norms only depend on magnitude, let's do that first
    # Accept torch tensors as well as numpy arrays
    if isinstance(S, torch.Tensor):
        S = S.numpy()
mag = np.abs(S).astype(float)
# For max/min norms, filling with 1 works
fill_norm = 1
if norm is None:
return S
elif norm == np.inf:
length = np.max(mag, axis=axis, keepdims=True)
elif norm == -np.inf:
length = np.min(mag, axis=axis, keepdims=True)
elif norm == 0:
if fill is True:
raise ParameterError("Cannot normalize with norm=0 and fill=True")
length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype)
elif np.issubdtype(type(norm), np.number) and norm > 0:
length = np.sum(mag**norm, axis=axis, keepdims=True) ** (1.0 / norm)
if axis is None:
fill_norm = mag.size ** (-1.0 / norm)
else:
fill_norm = mag.shape[axis] ** (-1.0 / norm)
else:
raise ParameterError(f"Unsupported norm: {repr(norm)}")
# indices where norm is below the threshold
small_idx = length < threshold
Snorm = np.empty_like(S)
if fill is None:
# Leave small indices un-normalized
length[small_idx] = 1.0
Snorm[:] = S / length
elif fill:
# If we have a non-zero fill value, we locate those entries by
# doing a nan-divide.
# If S was finite, then length is finite (except for small positions)
length[small_idx] = np.nan
Snorm[:] = S / length
Snorm[np.isnan(Snorm)] = fill_norm
else:
# Set small values to zero by doing an inf-divide.
# This is safe (by IEEE-754) as long as S is finite.
length[small_idx] = np.inf
Snorm[:] = S / length
return Snorm
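
# `_clip_wav` is called by `normalize_audio` below but is not defined in this file.
# The following is a minimal stand-in consistent with its call sites (clamp to [-1, 1]
# and optionally log when clipping occurs); the upstream helper may differ in detail.
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False,
              stem_name: tp.Optional[str] = None) -> None:
    max_scale = wav.abs().max()
    if log_clipping and max_scale > 1:
        clamp_prob = (wav.abs() > 1).float().mean().item()
        print(f"CLIPPING {stem_name or ''} happening with probability {clamp_prob} "
              f"(a bit of clipping is okay), maximum scale: {max_scale.item()}",
              file=sys.stderr)
    wav.clamp_(-1, 1)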
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
loudness_compressor: bool = False, log_clipping: bool = False,
sample_rate: tp.Optional[int] = None,
stem_name: tp.Optional[str] = None) -> torch.Tensor:
"""Normalize the audio according to the prescribed strategy (see after).
Args:
wav (torch.Tensor): Audio data.
normalize (bool): if `True` (default), normalizes according to the prescribed
strategy (see after). If `False`, the strategy is only used in case clipping
would happen.
strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
with extra headroom to avoid clipping. 'clip' just clips.
peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
than the `peak_clip` one to avoid further clipping.
loudness_headroom_db (float): Target loudness for loudness normalization.
loudness_compressor (bool): If True, uses tanh based soft clipping.
log_clipping (bool): If True, basic logging on stderr when clipping still
occurs despite strategy (only for 'rms').
sample_rate (int): Sample rate for the audio data (required for loudness).
stem_name (str, optional): Stem name for clipping logging.
Returns:
torch.Tensor: Normalized audio.
"""
scale_peak = 10 ** (-peak_clip_headroom_db / 20)
scale_rms = 10 ** (-rms_headroom_db / 20)
if strategy == 'peak':
rescaling = (scale_peak / wav.abs().max())
if normalize or rescaling < 1:
wav = wav * rescaling
elif strategy == 'clip':
wav = wav.clamp(-scale_peak, scale_peak)
elif strategy == 'rms':
mono = wav.mean(dim=0)
rescaling = scale_rms / mono.pow(2).mean().sqrt()
if normalize or rescaling < 1:
wav = wav * rescaling
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
elif strategy == 'loudness':
assert sample_rate is not None, "Loudness normalization requires sample rate."
wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
_clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
else:
assert wav.abs().max() < 1
assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
return wav
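
# Usage sketch (illustrative values): peak-normalize a clip with 1 dB of headroom, and
# loudness-normalize the same clip, which additionally requires the sample rate.
def _example_normalize_audio(path="example.wav"):
    wav, sr = process_audio(path, target_sample_rate=24000)
    peak_wav = normalize_audio(wav, strategy='peak', peak_clip_headroom_db=1)
    loud_wav = normalize_audio(wav, strategy='loudness', sample_rate=sr)
    return peak_wav, loud_wav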
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""
Convert audio to float 32 bits PCM format.
Args:
wav (torch.tensor): Input wav tensor
Returns:
same wav in float32 PCM format
"""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / 2**15
elif wav.dtype == torch.int32:
return wav.float() / 2**31
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to int 16 bits PCM format.
    ..Warning:: There exist many formulas for doing this conversion. None are perfect
    due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
    it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
    Args:
        wav (torch.tensor): Input wav tensor
    Returns:
        same wav in int16 PCM format
"""
if wav.dtype.is_floating_point:
assert wav.abs().max() <= 1
candidate = (wav * 2 ** 15).round()
if candidate.max() >= 2 ** 15: # clipping would occur
candidate = (wav * (2 ** 15 - 1)).round()
return candidate.short()
else:
assert wav.dtype == torch.int16
return wav
def compress(wav: torch.Tensor, sr: int,
target_format: tp.Literal["mp3", "ogg", "flac"] = "mp3",
bitrate: str = "128k") -> tp.Tuple[torch.Tensor, int]:
"""Convert audio wave form to a specified lossy format: mp3, ogg, flac
Args:
wav (torch.Tensor): Input wav tensor.
sr (int): Sampling rate.
target_format (str): Compression format (e.g., 'mp3').
bitrate (str): Bitrate for compression.
Returns:
Tuple of compressed WAV tensor and sampling rate.
"""
# Extract the bit rate from string (e.g., '128k')
match = re.search(r"\d+(\.\d+)?", str(bitrate))
parsed_bitrate = float(match.group()) if match else None
    assert parsed_bitrate, f"Invalid bitrate specified (got {bitrate})"
try:
# Create a virtual file instead of saving to disk
buffer = io.BytesIO()
torchaudio.save(
buffer, wav, sr, format=target_format, bits_per_sample=parsed_bitrate,
)
# Move to the beginning of the file
buffer.seek(0)
compressed_wav, sr = torchaudio.load(buffer)
return compressed_wav, sr
except RuntimeError:
        logging.warning(
            f"compression failed, skipping compression: {target_format} {parsed_bitrate}"
        )
return wav, sr
def get_mp3(wav_tensor: torch.Tensor, sr: int, bitrate: str = "128k") -> torch.Tensor:
"""Convert a batch of audio files to MP3 format, maintaining the original shape.
This function takes a batch of audio files represented as a PyTorch tensor, converts
them to MP3 format using the specified bitrate, and returns the batch in the same
shape as the input.
Args:
wav_tensor (torch.Tensor): Batch of audio files represented as a tensor.
Shape should be (batch_size, channels, length).
sr (int): Sampling rate of the audio.
bitrate (str): Bitrate for MP3 conversion, default is '128k'.
Returns:
torch.Tensor: Batch of audio files converted to MP3 format, with the same
shape as the input tensor.
"""
device = wav_tensor.device
batch_size, channels, original_length = wav_tensor.shape
# Flatten tensor for conversion and move to CPU
wav_tensor_flat = wav_tensor.view(1, -1).cpu()
# Convert to MP3 format with specified bitrate
wav_tensor_flat, _ = compress(wav_tensor_flat, sr, bitrate=bitrate)
# Reshape back to original batch format and trim or pad if necessary
wav_tensor = wav_tensor_flat.view(batch_size, channels, -1)
compressed_length = wav_tensor.shape[-1]
if compressed_length > original_length:
wav_tensor = wav_tensor[:, :, :original_length] # Trim excess frames
elif compressed_length < original_length:
padding = torch.zeros(
batch_size, channels, original_length - compressed_length, device=device
)
wav_tensor = torch.cat((wav_tensor, padding), dim=-1) # Pad with zeros
# Move tensor back to the original device
return wav_tensor.to(device)
def get_aac(
wav_tensor: torch.Tensor,
sr: int,
bitrate: str = "128k",
lowpass_freq: tp.Optional[int] = None,
) -> torch.Tensor:
"""Converts a batch of audio tensors to AAC format and then back to tensors.
This function first saves the input tensor batch as WAV files, then uses FFmpeg to convert
these WAV files to AAC format. Finally, it loads the AAC files back into tensors.
Args:
wav_tensor (torch.Tensor): A batch of audio files represented as a tensor.
Shape should be (batch_size, channels, length).
sr (int): Sampling rate of the audio.
bitrate (str): Bitrate for AAC conversion, default is '128k'.
lowpass_freq (Optional[int]): Frequency for a low-pass filter. If None, no filter is applied.
Returns:
torch.Tensor: Batch of audio files converted to AAC and back, with the same
shape as the input tensor.
"""
import tempfile
import subprocess
device = wav_tensor.device
batch_size, channels, original_length = wav_tensor.shape
# Parse the bitrate value from the string
match = re.search(r"\d+(\.\d+)?", bitrate)
parsed_bitrate = (
match.group() if match else "128"
) # Default to 128 if parsing fails
# Flatten tensor for conversion and move to CPU
wav_tensor_flat = wav_tensor.view(1, -1).cpu()
with tempfile.NamedTemporaryFile(
suffix=".wav"
) as f_in, tempfile.NamedTemporaryFile(suffix=".aac") as f_out:
input_path, output_path = f_in.name, f_out.name
# Save the tensor as a WAV file
torchaudio.save(input_path, wav_tensor_flat, sr, backend="ffmpeg")
# Prepare FFmpeg command for AAC conversion
command = [
"ffmpeg",
"-y",
"-i",
input_path,
"-ar",
str(sr),
"-b:a",
f"{parsed_bitrate}k",
"-c:a",
"aac",
]
if lowpass_freq is not None:
command += ["-cutoff", str(lowpass_freq)]
command.append(output_path)
try:
# Run FFmpeg and suppress output
subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
# Load the AAC audio back into a tensor
aac_tensor, _ = torchaudio.load(output_path, backend="ffmpeg")
except Exception as exc:
raise RuntimeError(
"Failed to run command " ".join(command)} "
"(Often this means ffmpeg is not installed or the encoder is not supported, "
"make sure you installed an older version ffmpeg<5)"
) from exc
original_length_flat = batch_size * channels * original_length
compressed_length_flat = aac_tensor.shape[-1]
# Trim excess frames
if compressed_length_flat > original_length_flat:
aac_tensor = aac_tensor[:, :original_length_flat]
    # Pad the shortened frames
elif compressed_length_flat < original_length_flat:
padding = torch.zeros(
1, original_length_flat - compressed_length_flat, device=device
)
aac_tensor = torch.cat((aac_tensor, padding), dim=-1)
# Reshape and adjust length to match original tensor
wav_tensor = aac_tensor.view(batch_size, channels, -1)
compressed_length = wav_tensor.shape[-1]
    assert compressed_length == original_length, (
        "AAC-compressed audio does not have the same number of frames as the original. "
        "One reason can be that ffmpeg is not installed or not used as the proper backend "
        "for torchaudio, or that the AAC encoder is not correct. Run "
        "`torchaudio.utils.ffmpeg_utils.get_audio_encoders()` and make sure there is an entry for "
        "AAC in the output."
    )
return wav_tensor.to(device)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`."""
import io
import json
import struct
import typing as tp
# format is `ECDC` magic code, followed by the header size as uint32.
# Then an uint8 indicates the protocol version (0.)
# The header is then provided as json and should contain all required
# information for decoding. A raw stream of bytes is then provided
# and should be interpretable using the json header.
_encodec_header_struct = struct.Struct('!4sBI')
_ENCODEC_MAGIC = b'ECDC'
def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
meta_dumped = json.dumps(metadata).encode('utf-8')
version = 0
header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version,
len(meta_dumped))
fo.write(header)
fo.write(meta_dumped)
fo.flush()
def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
buf = b""
while len(buf) < size:
new_buf = fo.read(size)
if not new_buf:
raise EOFError("Impossible to read enough data from the stream, "
f"{size} bytes remaining.")
buf += new_buf
size -= len(new_buf)
return buf
def read_ecdc_header(fo: tp.IO[bytes]):
header_bytes = _read_exactly(fo, _encodec_header_struct.size)
magic, version, meta_size = _encodec_header_struct.unpack(header_bytes)
if magic != _ENCODEC_MAGIC:
raise ValueError("File is not in ECDC format.")
if version != 0:
raise ValueError("Version not supported.")
meta_bytes = _read_exactly(fo, meta_size)
return json.loads(meta_bytes.decode('utf-8'))
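
# Round-trip sketch for the ECDC header layout described above: write a small JSON
# metadata block to an in-memory stream, then parse it back. The metadata keys are
# illustrative, not a required schema.
def _example_ecdc_header_roundtrip():
    buf = io.BytesIO()
    write_ecdc_header(buf, {'model': 'encodec_24khz', 'audio_length': 24000})
    buf.seek(0)
    metadata = read_ecdc_header(buf)
    assert metadata == {'model': 'encodec_24khz', 'audio_length': 24000}
    return metadata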
class BitPacker:
"""Simple bit packer to handle ints with a non standard width, e.g. 10 bits.
Note that for some bandwidth (1.5, 3), the codebook representation
will not cover an integer number of bytes.
Args:
bits (int): number of bits per value that will be pushed.
fo (IO[bytes]): file-object to push the bytes to.
"""
def __init__(self, bits: int, fo: tp.IO[bytes]):
self._current_value = 0
self._current_bits = 0
self.bits = bits
self.fo = fo
def push(self, value: int):
"""Push a new value to the stream. This will immediately
write as many uint8 as possible to the underlying file-object."""
self._current_value += (value << self._current_bits)
self._current_bits += self.bits
while self._current_bits >= 8:
lower_8bits = self._current_value & 0xff
self._current_bits -= 8
self._current_value >>= 8
self.fo.write(bytes([lower_8bits]))
def flush(self):
"""Flushes the remaining partial uint8, call this at the end
of the stream to encode."""
if self._current_bits:
self.fo.write(bytes([self._current_value]))
self._current_value = 0
self._current_bits = 0
self.fo.flush()
class BitUnpacker:
"""BitUnpacker does the opposite of `BitPacker`.
Args:
bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to pull the bytes from.
"""
def __init__(self, bits: int, fo: tp.IO[bytes]):
self.bits = bits
self.fo = fo
self._mask = (1 << bits) - 1
self._current_value = 0
self._current_bits = 0
def pull(self) -> tp.Optional[int]:
"""
Pull a single value from the stream, potentially reading some
extra bytes from the underlying file-object.
Returns `None` when reaching the end of the stream.
"""
while self._current_bits < self.bits:
buf = self.fo.read(1)
if not buf:
return None
character = buf[0]
self._current_value += character << self._current_bits
self._current_bits += 8
out = self._current_value & self._mask
self._current_value >>= self.bits
self._current_bits -= self.bits
return out
def test():
import torch
torch.manual_seed(1234)
for rep in range(4):
length: int = torch.randint(10, 2_000, (1, )).item()
bits: int = torch.randint(1, 16, (1, )).item()
tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist()
rebuilt: tp.List[int] = []
buf = io.BytesIO()
packer = BitPacker(bits, buf)
for token in tokens:
packer.push(token)
packer.flush()
buf.seek(0)
unpacker = BitUnpacker(bits, buf)
while True:
value = unpacker.pull()
if value is None:
break
rebuilt.append(value)
assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens))
# The flushing mechanism might lead to "ghost" values at the end of the stream.
assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt),
len(tokens), bits)
for idx, (a, b) in enumerate(zip(tokens, rebuilt)):
assert a == b, (idx, a, b)
if __name__ == '__main__':
test()
# Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from inspiremusic.transformer.activation import Swish
from inspiremusic.transformer.subsampling import (
LinearNoSubsampling,
EmbedinigNoSubsampling,
Conv1dSubsampling2,
Conv2dSubsampling4,
Conv2dSubsampling6,
Conv2dSubsampling8,
)
from inspiremusic.transformer.embedding import (PositionalEncoding,
RelPositionalEncoding,
WhisperPositionalEncoding,
LearnablePositionalEncoding,
NoPositionalEncoding)
from inspiremusic.transformer.attention import (MultiHeadedAttention,
RelPositionMultiHeadedAttention)
from inspiremusic.transformer.embedding import EspnetRelPositionalEncoding
from inspiremusic.transformer.subsampling import LegacyLinearNoSubsampling
INSPIREMUSIC_ACTIVATION_CLASSES = {
"hardtanh": torch.nn.Hardtanh,
"tanh": torch.nn.Tanh,
"relu": torch.nn.ReLU,
"selu": torch.nn.SELU,
"swish": getattr(torch.nn, "SiLU", Swish),
"gelu": torch.nn.GELU,
}
INSPIREMUSIC_SUBSAMPLE_CLASSES = {
"linear": LinearNoSubsampling,
"linear_legacy": LegacyLinearNoSubsampling,
"embed": EmbedinigNoSubsampling,
"conv1d2": Conv1dSubsampling2,
"conv2d": Conv2dSubsampling4,
"conv2d6": Conv2dSubsampling6,
"conv2d8": Conv2dSubsampling8,
'paraformer_dummy': torch.nn.Identity
}
INSPIREMUSIC_EMB_CLASSES = {
"embed": PositionalEncoding,
"abs_pos": PositionalEncoding,
"rel_pos": RelPositionalEncoding,
"rel_pos_espnet": EspnetRelPositionalEncoding,
"no_pos": NoPositionalEncoding,
"abs_pos_whisper": WhisperPositionalEncoding,
"embed_learnable_pe": LearnablePositionalEncoding,
}
INSPIREMUSIC_ATTENTION_CLASSES = {
"selfattn": MultiHeadedAttention,
"rel_selfattn": RelPositionMultiHeadedAttention,
}
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Unility functions for Transformer."""
from typing import List
import torch
IGNORE_ID = -1
MUSIC_STRUCTURE_LABELS = ["intro", "verse1", "chorus", "verse2", "outro"]
DTYPES = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
}
def pad_list(xs: List[torch.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
max_len = max([len(item) for item in xs])
batchs = len(xs)
ndim = xs[0].ndim
if ndim == 1:
pad_res = torch.zeros(batchs,
max_len,
dtype=xs[0].dtype,
device=xs[0].device)
elif ndim == 2:
pad_res = torch.zeros(batchs,
max_len,
xs[0].shape[1],
dtype=xs[0].dtype,
device=xs[0].device)
elif ndim == 3:
pad_res = torch.zeros(batchs,
max_len,
xs[0].shape[1],
xs[0].shape[2],
dtype=xs[0].dtype,
device=xs[0].device)
else:
raise ValueError(f"Unsupported ndim: {ndim}")
pad_res.fill_(pad_value)
for i in range(batchs):
pad_res[i, :len(xs[i])] = xs[i]
return pad_res
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
ignore_label: int) -> torch.Tensor:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
torch.Tensor: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
return (numerator / denominator).detach()
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def topk_sampling(weighted_scores, decoded_tokens, top_k=25):
zeros = weighted_scores.new_ones(weighted_scores.shape) * float('-inf')
    values, indices = torch.topk(weighted_scores, top_k)
    zeros.scatter_(-1, indices, values)
    return random_sampling(zeros, decoded_tokens)
# Repetition Aware Sampling in VALL-E 2
def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
if rep_num >= win_size * tau_r:
top_ids = random_sampling(weighted_scores, decoded_tokens)
return top_ids
def caras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
weighted_scores, cfg_weighted_scores = weighted_scores
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
if rep_num >= win_size * tau_r:
top_ids = random_sampling(cfg_weighted_scores, decoded_tokens)
return top_ids
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
prob, indices = [], []
cum_prob = 0.0
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
for i in range(len(sorted_idx)):
        # keep tokens until either the top-p mass or the top-k count is reached
if cum_prob < top_p and len(prob) < top_k:
cum_prob += sorted_value[i]
prob.append(sorted_value[i])
indices.append(sorted_idx[i])
else:
break
prob = torch.tensor(prob).to(weighted_scores)
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
top_ids = indices[prob.multinomial(1, replacement=True)]
return top_ids
def random_sampling(weighted_scores, decoded_tokens):
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
return top_ids
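
# Usage sketch for the samplers above: draw the next token id from a 1-D logits vector
# with repetition-aware sampling. The vocabulary size, history, and hyper-parameters
# are illustrative assumptions.
def _example_ras_sampling():
    logits = torch.randn(4096)       # unnormalized scores over the vocabulary
    history = [11, 27, 27, 27, 42]   # previously decoded token ids
    top_ids = ras_sampling(logits, history, top_p=0.8, top_k=25, win_size=10, tau_r=0.1)
    return top_ids                   # tensor containing a single sampled token id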
def fade_in_out(fade_in_mel, fade_out_mel, window):
device = fade_in_mel.device
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
mel_overlap_len = int(window.shape[0] / 2)
fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
return fade_in_mel.to(device)
def set_all_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
assert mask.dtype == torch.bool
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
mask = mask.to(dtype)
# attention mask bias
# NOTE(Mddct): torch.finfo jit issues
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
mask = (1.0 - mask) * torch.finfo(dtype).min
return mask
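
# Usage sketch: turn a boolean attention mask into an additive bias, where allowed
# positions become 0.0 and masked positions become the most negative finite value.
def _example_mask_to_bias():
    bool_mask = torch.tensor([[True, True, False, False]])
    bias = mask_to_bias(bool_mask, torch.float32)
    # bias is [[0.0, 0.0, min_float, min_float]], ready to be added to attention scores
    return bias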
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.utils.data import DataLoader
from inspiremusic.dataset.dataset import Dataset
import numpy as np
import librosa
def audio_process_dataset_and_dataloader(args, configs):
input_dataset = Dataset(args.input_data, data_pipeline=configs['data_pipeline'], mode='processing', shuffle=True, partition=True)
    # do not use persistent_workers=True, as the whisper tokenizer opens the tiktoken file each time the for loop starts
input_data_loader = DataLoader(input_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)
return input_dataset, input_data_loader
def is_silent(wav_path, threshold=0.01, frame_length=2048, hop_length=512):
y, sr = librosa.load(wav_path, sr=None)
rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0]
silent_frames = np.sum(rms < threshold) / len(rms)
silence_fraction_threshold = 0.95
return silent_frames >= silence_fraction_threshold
def rich_captions(text=None, tags=None, lyrics=None, chorus="verse", start_time=0.0, end_time=30.0):
if text is None and tags is None and lyrics is None:
return None
else:
if start_time is None:
start_time = 0.0
if end_time is None:
end_time = 30.0
if chorus is None:
chorus = "verse"
captions = f"<|{start_time:.1f}|><|{chorus}|>"
if tags is not None:
captions += f"<|{tags}|>"
if text is not None:
captions += f"<|{text}|>"
if lyrics is not None:
captions += f"<|lyrics|><|{lyrics}|>"
captions += f"<|{end_time:.1f}|>"
return captions
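
# Illustrative output of rich_captions (the tag/text values are made up):
#   rich_captions("uplifting piano pop", tags="pop", end_time=25.0)
#   -> "<|0.0|><|verse|><|pop|><|uplifting piano pop|><|25.0|>"
def _example_rich_captions():
    return rich_captions("uplifting piano pop", tags="pop", end_time=25.0)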
def process_tags(infile, outfile, timefile = None):
key_list = []
with open(infile, "r") as f:
for line in f:
sec = line.strip()
key_list.append(sec)
f.close()
if timefile is None:
with open(outfile, 'w') as f:
for k in key_list:
parts = k.rsplit('_', 1)
text = parts[0].replace('_', ' ') + ', ' + parts[1]
caption = rich_captions(text, None, None)
if caption is not None:
f.write("%s\t%s\n" %(k, caption))
f.close()
else:
times = {}
with open(timefile, "r") as f:
for line in f:
sec = line.strip().split("\t")
if len(sec) == 2 :
times[sec[0]] = sec[1]
f.close()
with open(outfile, 'w') as f:
for k in key_list:
parts = k.rsplit('_', 1)
text = parts[0].replace('_', ' ') + ', ' + parts[1]
if k in times.keys():
caption = rich_captions(text, None, None, "verse", 0.0, float(times[k]))
if caption is not None:
f.write("%s\t%s\n" %(k, caption))
f.close()
def process_trans(infile, outfile):
trans = {}
with open(infile, "r") as f:
for line in f:
sec = line.strip().split("\t")
if len(sec) == 2:
trans[sec[0]] = sec[1]
else:
print(line)
f.close()
with open(outfile, 'w') as f:
for k, v in trans.items():
f.write("%s\t%s\n" %(k, rich_captions(v)))
f.close()
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
import os
import torch
import torch.distributed as dist
from inspiremusic.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, inspiremusic_join
from torch.cuda.amp import GradScaler, autocast
class Executor:
def __init__(self):
self.step = 0
self.epoch = 0
self.rank = int(os.environ.get('RANK', 0))
self.device = torch.device('cuda:{}'.format(self.rank))
def train_one_epoch(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=None):
''' Train one epoch
'''
lr = optimizer.param_groups[0]['lr']
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
logging.info('using accumulate grad, new batch size is {} times'
' larger than before'.format(info_dict['accum_grad']))
# A context manager to be used in conjunction with an instance of
# torch.nn.parallel.DistributedDataParallel to be able to train
# with uneven inputs across participating processes.
model.train()
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
with model_context():
for batch_idx, batch_dict in enumerate(train_data_loader):
info_dict["tag"] = "TRAIN"
info_dict["step"] = self.step
info_dict["epoch"] = self.epoch
info_dict["batch_idx"] = batch_idx
if inspiremusic_join(group_join, info_dict):
break
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
context = model.no_sync
# Used for single gpu training and DDP gradient synchronization
# processes.
else:
context = nullcontext
with context():
with autocast(enabled=scaler is not None):
info_dict = batch_forward(model, batch_dict, info_dict, scaler)
info_dict = batch_backward(model, info_dict, scaler)
info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler)
log_per_step(writer, info_dict)
# NOTE specify save_per_step in inspiremusic.yaml if you want to enable step save
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
(batch_idx + 1) % info_dict["accum_grad"] == 0:
dist.barrier()
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, scaler=scaler)
model.train()
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
self.step += 1
dist.barrier()
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, scaler=scaler)
@torch.inference_mode()
def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, capped_at=5, scaler=None):
        ''' Run cross validation on the CV data loader.
        '''
logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
model.eval()
total_num_utts, total_loss_dict = 0, {} # avoid division by 0
stop = capped_at
for batch_idx, batch_dict in enumerate(cv_data_loader):
info_dict["tag"] = "CV"
info_dict["step"] = self.step
info_dict["epoch"] = self.epoch
info_dict["batch_idx"] = batch_idx
num_utts = len(batch_dict["utts"])
total_num_utts += num_utts
if capped_at>0:
if stop <= 0:
continue
else:
stop -= 1
with autocast(enabled=scaler is not None):
info_dict = batch_forward(model, batch_dict, info_dict, scaler)
for k, v in info_dict['loss_dict'].items():
if k not in total_loss_dict:
total_loss_dict[k] = []
total_loss_dict[k].append(v.item() * num_utts)
log_per_step(None, info_dict)
for k, v in total_loss_dict.items():
total_loss_dict[k] = sum(v) / total_num_utts
info_dict['loss_dict'] = total_loss_dict
log_per_save(writer, info_dict)
model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
save_model(model, model_name, info_dict)
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import torchaudio
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
def read_trans(list_file):
trans = {}
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
sec = line.strip().split("\t")
if len(sec) > 1:
if sec[0] not in trans.keys():
trans[sec[0]] = sec[1]
return trans
def read_scp(list_file):
scp = {}
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
sec = line.strip().split(" ")
if len(sec) > 1:
if sec[0] not in scp.keys():
scp[sec[0]] = sec[1]
return scp
def read_lists(list_file):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
lists.append(line.strip())
return lists
def read_json_lists(list_file):
lists = read_lists(list_file)
results = {}
for fn in lists:
with open(fn, 'r', encoding='utf8') as fin:
results.update(json.load(fin))
return results
def load_wav(wav, target_sr):
audio, sample_rate = torchaudio.load(wav)
audio = audio.mean(dim=0, keepdim=True)
if sample_rate != target_sr:
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
        audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(audio)
return audio
def speed_change(waveform, sample_rate, speed_factor: str):
effects = [
["tempo", speed_factor], # speed_factor
["rate", f"{sample_rate}"]
]
augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
waveform,
sample_rate,
effects
)
return augmented_waveform, new_sample_rate
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
# whether the text contains Chinese characters
def contains_chinese(text):
return bool(chinese_char_pattern.search(text))
# replace special symbol
def replace_corner_mark(text):
text = text.replace('²', '平方')
text = text.replace('³', '立方')
return text
# remove meaningless symbol
def remove_bracket(text):
text = text.replace('(', '').replace(')', '')
text = text.replace('【', '').replace('】', '')
text = text.replace('`', '').replace('`', '')
text = text.replace("——", " ")
return text
# spell out Arabic numerals
def spell_out_number(text: str, inflect_parser):
new_text = []
st = None
for i, c in enumerate(text):
if not c.isdigit():
if st is not None:
num_str = inflect_parser.number_to_words(text[st: i])
new_text.append(num_str)
st = None
new_text.append(c)
else:
if st is None:
st = i
if st is not None and st < len(text):
num_str = inflect_parser.number_to_words(text[st:])
new_text.append(num_str)
return ''.join(new_text)
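
# Usage sketch: spell_out_number expects an inflect engine as its parser, e.g.
# "track 42" -> "track forty-two". Requiring the `inflect` package is an assumption
# of this example, not something imported by this module.
def _example_spell_out_number():
    import inflect
    parser = inflect.engine()
    return spell_out_number("track 42", parser)  # "track forty-two"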
# split paragraph logic:
# 1. per sentence max len token_max_n, min len token_min_n, merge if the last sentence len is less than merge_len
# 2. calculate sentence length according to lang
# 3. split sentences according to punctuation
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
def calc_utt_length(_text: str):
if lang == "zh":
return len(_text)
else:
return len(tokenize(_text))
def should_merge(_text: str):
if lang == "zh":
return len(_text) < merge_len
else:
return len(tokenize(_text)) < merge_len
if lang == "zh":
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
else:
pounc = ['.', '?', '!', ';', ':']
if comma_split:
pounc.extend([',', ','])
st = 0
utts = []
for i, c in enumerate(text):
if c in pounc:
if len(text[st: i]) > 0:
utts.append(text[st: i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
tmp = utts.pop(-1)
utts.append(tmp + text[i + 1])
st = i + 2
else:
st = i + 1
if len(utts) == 0:
if lang == "zh":
utts.append(text + '。')
else:
utts.append(text + '.')
final_utts = []
cur_utt = ""
for utt in utts:
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
final_utts.append(cur_utt)
cur_utt = ""
cur_utt = cur_utt + utt
if len(cur_utt) > 0:
if should_merge(cur_utt) and len(final_utts) != 0:
final_utts[-1] = final_utts[-1] + cur_utt
else:
final_utts.append(cur_utt)
return final_utts
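
# Usage sketch for split_paragraph on English text, using a plain whitespace tokenizer
# as a stand-in for the real subword tokenizer; the length limits are illustrative.
def _example_split_paragraph():
    text = "This is the first sentence. Here is another one! And a final question?"
    return split_paragraph(text, tokenize=str.split, lang="en",
                           token_max_n=20, token_min_n=10, merge_len=5)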
# remove blanks between Chinese characters
def replace_blank(text: str):
out_str = []
for i, c in enumerate(text):
if c == " ":
if ((text[i + 1].isascii() and text[i + 1] != " ") and
(text[i - 1].isascii() and text[i - 1] != " ")):
out_str.append(c)
else:
out_str.append(c)
return "".join(out_str)
import sys
import torch.distributed
import logging
HINTED = set()
def hint_once(content, uid, rank=None):
if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank:
if uid not in HINTED:
logging.info(content, stacklevel=3)
HINTED.add(uid)
import torch
import torch.nn.functional as F
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
loss = 0
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
m_DG = torch.median((dr - dg))
L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
loss += tau - F.relu(tau - L_rel)
return loss
def mel_loss(real_speech, generated_speech, mel_transforms):
loss = 0
for transform in mel_transforms:
mel_r = transform(real_speech)
mel_g = transform(generated_speech)
loss += F.l1_loss(mel_g, mel_r)
return loss
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
'''
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
This means the current step could only do attention with its left steps.
In encoder, fully attention is used when streaming is not necessary and
the sequence is not long. In this case, no attention mask is needed.
When streaming is need, chunk-based attention is used in encoder. See
subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
str device (str): "cpu" or "cuda" or torch.Tensor.device
dtype (torch.device): result dtype
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
ret = torch.ones(size, size, device=device, dtype=torch.bool)
return torch.tril(ret)
'''
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
This mask is used only in decoder which works in an auto-regressive mode.
This means the current step could only do attention with its left steps.
    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.
    When streaming is needed, chunk-based attention is used in the encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
arange = torch.arange(size, device=device)
mask = arange.expand(size, size)
arange = arange.unsqueeze(-1)
mask = mask <= arange
return mask
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
def add_optional_chunk_mask(xs: torch.Tensor,
masks: torch.Tensor,
use_dynamic_chunk: bool,
use_dynamic_left_chunk: bool,
decoding_chunk_size: int,
static_chunk_size: int,
num_decoding_left_chunks: int,
enable_full_context: bool = True):
""" Apply optional mask for encoder.
Args:
xs (torch.Tensor): padded input, (B, L, D), L for max length
mask (torch.Tensor): mask for xs, (B, 1, L)
use_dynamic_chunk (bool): whether to use dynamic chunk or not
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
training.
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
static_chunk_size (int): chunk size for static chunk training/decoding
if it's greater than 0, if use_dynamic_chunk is true,
this parameter will be ignored
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
enable_full_context (bool):
True: chunk size is either [1, 25] or full context(max_len)
False: chunk size ~ U[1, 25]
Returns:
torch.Tensor: chunk mask of the input xs.
"""
# Whether to use chunk mask or not
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2 and enable_full_context:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
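
# Usage sketch (illustrative shapes and chunk size): build a static-chunk attention mask
# for a batch of padded encoder inputs, using make_pad_mask defined below.
def _example_add_optional_chunk_mask():
    xs = torch.zeros(2, 8, 16)                               # (B, L, D)
    masks = ~make_pad_mask(torch.tensor([8, 6]), max_len=8)  # (B, L), True on valid frames
    masks = masks.unsqueeze(1)                               # (B, 1, L)
    return add_optional_chunk_mask(xs, masks,
                                   use_dynamic_chunk=False,
                                   use_dynamic_left_chunk=False,
                                   decoding_chunk_size=0,
                                   static_chunk_size=2,
                                   num_decoding_left_chunks=-1)  # (B, L, L) bool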
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
torch.Tensor: Mask tensor containing indices of padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
batch_size = lengths.size(0)
max_len = max_len if max_len > 0 else lengths.max().item()
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,
device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask