# Commit b75857fb authored by chenzk — v1.0
# (web-UI "Browse files" / "parents" header retained here as a comment so the file stays valid Python)
import gc
import queue
from typing import Generator
import numpy as np
import torch
from loguru import logger
from fish_speech.inference_engine.reference_loader import ReferenceLoader
from fish_speech.inference_engine.utils import InferenceResult, wav_chunk_header
from fish_speech.inference_engine.vq_manager import VQManager
from fish_speech.models.text2semantic.inference import (
GenerateRequest,
GenerateResponse,
WrappedGenerateResponse,
)
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps, set_seed
from fish_speech.utils.schema import ServeTTSRequest
class TTSInferenceEngine(ReferenceLoader, VQManager):
    """High-level TTS engine: text -> LLAMA semantic tokens -> VQ-decoded audio.

    Combines reference-audio loading (ReferenceLoader) with VQ decoding
    (VQManager). The LLAMA text2semantic model runs in a separate worker
    thread and is reached through ``llama_queue``.
    """

    def __init__(
        self,
        llama_queue: queue.Queue,
        decoder_model: FireflyArchitecture,
        precision: torch.dtype,
        compile: bool,
    ) -> None:
        super().__init__()

        self.llama_queue = llama_queue  # requests to the LLAMA worker thread
        self.decoder_model = decoder_model  # vocoder used to decode VQ tokens
        self.precision = precision  # autocast dtype used while decoding
        self.compile = compile  # whether the LLAMA decode fn is torch.compile'd

    @torch.inference_mode()
    def inference(self, req: ServeTTSRequest) -> Generator[InferenceResult, None, None]:
        """
        Main inference function:
        - Loads the reference audio and text.
        - Calls the LLAMA model for inference.
        - Decodes the VQ tokens to audio.

        Yields InferenceResult items: an optional "header" and per-chunk
        "segment"s (streaming only), then exactly one "final" or "error".
        """
        ref_id: str | None = req.reference_id
        prompt_tokens, prompt_texts = [], []

        # Load the reference audio and text based on id or hash
        if ref_id is not None:
            prompt_tokens, prompt_texts = self.load_by_id(ref_id, req.use_memory_cache)
        elif req.references:
            prompt_tokens, prompt_texts = self.load_by_hash(
                req.references, req.use_memory_cache
            )

        # Set the random seed if provided (makes sampling reproducible)
        if req.seed is not None:
            set_seed(req.seed)
            logger.warning(f"set seed: {req.seed}")

        # Get the symbolic tokens from the LLAMA model
        response_queue = self.send_Llama_request(req, prompt_tokens, prompt_texts)

        # Get the sample rate from the decoder model
        sample_rate = self.decoder_model.spec_transform.sample_rate

        # If streaming, send the WAV header first so clients can start playback
        if req.streaming:
            yield InferenceResult(
                code="header",
                audio=(
                    sample_rate,
                    np.array(wav_chunk_header(sample_rate=sample_rate)),
                ),
                error=None,
            )

        segments = []

        while True:
            # Get the response from the LLAMA model
            wrapped_result: WrappedGenerateResponse = response_queue.get()

            if wrapped_result.status == "error":
                yield InferenceResult(
                    code="error",
                    audio=None,
                    error=(
                        wrapped_result.response
                        if isinstance(wrapped_result.response, Exception)
                        else Exception("Unknown error")
                    ),
                )
                break

            # Check the response type
            if not isinstance(wrapped_result.response, GenerateResponse):
                # BUG FIX: this message was a plain string literal, so the
                # {type(...)} placeholder was never interpolated; use an f-string.
                raise TypeError(
                    f"Expected GenerateResponse, got {type(wrapped_result.response).__name__}"
                )

            result: GenerateResponse = wrapped_result.response
            if result.action != "next":
                segment = self.get_audio_segment(result)

                if req.streaming:  # Used only by the API server
                    yield InferenceResult(
                        code="segment",
                        audio=(sample_rate, segment),
                        error=None,
                    )
                segments.append(segment)
            else:
                # "next" marks the end of the current sample
                break

        # Clean up the memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        # Edge case: no audio generated
        if len(segments) == 0:
            yield InferenceResult(
                code="error",
                audio=None,
                error=RuntimeError("No audio generated, please check the input text."),
            )
        else:
            # Streaming or not, return the final audio
            audio = np.concatenate(segments, axis=0)
            yield InferenceResult(
                code="final",
                audio=(sample_rate, audio),
                error=None,
            )

        return None

    def send_Llama_request(
        self, req: ServeTTSRequest, prompt_tokens: list, prompt_texts: list
    ) -> queue.Queue:
        """
        Send a request to the LLAMA model to generate the symbolic tokens.

        Returns the per-request queue on which WrappedGenerateResponse items
        will arrive.
        """
        # Prepare the request
        request = dict(
            device=self.decoder_model.device,
            max_new_tokens=req.max_new_tokens,
            text=(
                req.text
                if not req.normalize
                else ChnNormedText(raw_text=req.text).normalize()
            ),
            top_p=req.top_p,
            repetition_penalty=req.repetition_penalty,
            temperature=req.temperature,
            compile=self.compile,
            iterative_prompt=req.chunk_length > 0,
            chunk_length=req.chunk_length,
            max_length=4096,
            prompt_tokens=prompt_tokens,
            prompt_text=prompt_texts,
        )

        # Create a queue to get the response
        response_queue = queue.Queue()

        # Send the request to the LLAMA model
        self.llama_queue.put(
            GenerateRequest(
                request=request,
                response_queue=response_queue,
            )
        )

        return response_queue

    def get_audio_segment(self, result: GenerateResponse) -> np.ndarray:
        """
        Decode the VQ tokens to audio.
        """
        # Don't use autocast on MPS devices
        with autocast_exclude_mps(
            device_type=self.decoder_model.device.type, dtype=self.precision
        ):
            # Decode the symbolic tokens to audio
            segment = self.decode_vq_tokens(codes=result.codes)

        # Convert the audio to numpy
        return segment.float().cpu().numpy()
import io
from hashlib import sha256
from pathlib import Path
from typing import Callable, Literal, Tuple
import torch
import torchaudio
from loguru import logger
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.utils.file import (
AUDIO_EXTENSIONS,
audio_to_bytes,
list_files,
read_ref_text,
)
from fish_speech.utils.schema import ServeReferenceAudio
class ReferenceLoader:

    def __init__(self) -> None:
        """
        Component of the TTSInferenceEngine class.
        Loads and manages the cache for the reference audio and text.
        """
        self.ref_by_id: dict = {}
        # Maps sha256(audio bytes) -> (prompt_token, prompt_text) for ONE reference.
        self.ref_by_hash: dict = {}

        # Make Pylance happy (attribute/method not defined...)
        self.decoder_model: FireflyArchitecture
        self.encode_reference: Callable

        # Define the torchaudio backend
        backends = torchaudio.list_audio_backends()
        if "ffmpeg" in backends:
            self.backend = "ffmpeg"
        else:
            self.backend = "soundfile"

    def load_by_id(
        self,
        id: str,
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """Load (and cache) the reference audio/text stored under references/<id>."""
        ref_folder = Path("references") / id
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False)

        if use_cache == "off" or id not in self.ref_by_id:
            # If the references are not already loaded, encode them
            prompt_tokens = [
                self.encode_reference(
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            # Each audio file must have a sibling ".lab" transcript
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
            self.ref_by_id[id] = (prompt_tokens, prompt_texts)
        else:
            # Reuse already encoded references
            logger.info("Use same references")
            prompt_tokens, prompt_texts = self.ref_by_id[id]

        return prompt_tokens, prompt_texts

    def load_by_hash(
        self,
        references: list[ServeReferenceAudio],
        use_cache: Literal["on", "off"],
    ) -> Tuple:
        """Load the reference audio/text keyed by the sha256 of the audio bytes.

        Each reference is cached individually under its own hash, so a mix of
        cached and uncached references in a single request stays correct.
        """
        audio_hashes = [sha256(ref.audio).hexdigest() for ref in references]

        cache_used = False
        prompt_tokens, prompt_texts = [], []
        for i, ref in enumerate(references):
            if use_cache == "off" or audio_hashes[i] not in self.ref_by_hash:
                # Not cached yet: encode this reference
                token = self.encode_reference(
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                text = ref.text
                # BUG FIX: cache only this reference's (token, text) pair.
                # Previously the whole accumulated lists were stored under the
                # hash and later reused (and mutated) wholesale, corrupting the
                # cache whenever hits and misses were mixed in one request.
                self.ref_by_hash[audio_hashes[i]] = (token, text)
            else:
                # Reuse already encoded reference
                token, text = self.ref_by_hash[audio_hashes[i]]
                cache_used = True

            prompt_tokens.append(token)
            prompt_texts.append(text)

        if cache_used:
            logger.info("Use same references")

        return prompt_tokens, prompt_texts

    def load_audio(self, reference_audio, sr):
        """
        Load the audio data from a file or bytes.

        Returns a mono float waveform resampled to ``sr`` as a numpy array.
        """
        # Heuristic: very long values or non-existent paths are raw audio bytes
        if len(reference_audio) > 255 or not Path(reference_audio).exists():
            audio_data = reference_audio
            reference_audio = io.BytesIO(audio_data)

        waveform, original_sr = torchaudio.load(reference_audio, backend=self.backend)

        # Downmix multi-channel audio to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if original_sr != sr:
            resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
            waveform = resampler(waveform)

        audio = waveform.squeeze().numpy()
        return audio
import io
import wave
from dataclasses import dataclass
from typing import Literal, Optional, Tuple
import numpy as np
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
@dataclass
class InferenceResult:
    """One chunk of TTS engine output."""

    # "header"/"segment" are emitted only in streaming mode; every run ends
    # with exactly one "final" or "error".
    code: Literal["header", "segment", "error", "final"]
    # (sample_rate, samples) pair; None for "error" results.
    audio: Optional[Tuple[int, np.ndarray]]
    # Set only when code == "error".
    error: Optional[Exception]
def normalize_text(user_input: str, use_normalization: bool) -> str:
    """Return *user_input*, run through Chinese text normalization when requested."""
    if not use_normalization:
        return user_input
    return ChnNormedText(raw_text=user_input).normalize()
def wav_chunk_header(
    sample_rate: int = 44100, bit_depth: int = 16, channels: int = 1
) -> bytes:
    """Build a RIFF/WAVE header for a PCM stream with the given format.

    The wave writer is closed before the buffer is read, so the header bytes
    are fully flushed (frame count 0, suitable as a streaming preamble).
    """
    with io.BytesIO() as stream:
        with wave.open(stream, "wb") as writer:
            writer.setnchannels(channels)
            writer.setsampwidth(bit_depth // 8)
            writer.setframerate(sample_rate)
        header = stream.getvalue()
    return header
from typing import Callable
import torch
from loguru import logger
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
class VQManager:
    """Component of TTSInferenceEngine: VQ encode/decode via the decoder model."""

    def __init__(self):
        # Make Pylance happy (attribute/method not defined...)
        self.decoder_model: FireflyArchitecture
        self.load_audio: Callable

    def decode_vq_tokens(self, codes):
        """Decode VQ codes to a waveform tensor.

        ``codes`` is indexed as codes[codebook, frame] below; assumes a
        (num_codebooks, frames) layout — TODO confirm against callers.
        """
        feature_lengths = torch.tensor(
            [codes.shape[1]], device=self.decoder_model.device
        )
        logger.info(f"VQ features: {codes.shape}")

        if isinstance(self.decoder_model, FireflyArchitecture):
            # decode() works on a batch: add a batch dim, strip it afterwards
            return self.decoder_model.decode(
                indices=codes[None],
                feature_lengths=feature_lengths,
            )[0].squeeze()

        raise ValueError(f"Unknown model type: {type(self.decoder_model)}")

    def encode_reference(self, reference_audio, enable_reference_audio):
        """Encode reference audio (path or bytes) into VQ prompt tokens.

        Returns None when disabled or no audio is provided.
        """
        if enable_reference_audio and reference_audio is not None:
            # Load audios, and prepare basic info here
            reference_audio_content = self.load_audio(
                reference_audio, self.decoder_model.spec_transform.sample_rate
            )

            # Shape to (batch=1, channel=1, samples) for the encoder
            audios = torch.from_numpy(reference_audio_content).to(
                self.decoder_model.device
            )[None, None, :]
            audio_lengths = torch.tensor(
                [audios.shape[2]], device=self.decoder_model.device, dtype=torch.long
            )
            logger.info(
                f"Loaded audio with {audios.shape[2] / self.decoder_model.spec_transform.sample_rate:.2f} seconds"
            )

            # VQ Encoder
            if isinstance(self.decoder_model, FireflyArchitecture):
                prompt_tokens = self.decoder_model.encode(audios, audio_lengths)[0][0]
                logger.info(f"Encoded prompt: {prompt_tokens.shape}")
            else:
                raise ValueError(f"Unknown model type: {type(self.decoder_model)}")
        else:
            prompt_tokens = None
            logger.info("No reference audio provided")

        return prompt_tokens
import os
import queue
import threading
import time
from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
import click
import numpy as np
import torch
import torch._dynamo.config
import torch._inductor.config
from loguru import logger
from tqdm import tqdm
from transformers import AutoTokenizer
from fish_speech.conversation import (
CODEBOOK_PAD_TOKEN_ID,
Conversation,
Message,
TextPart,
VQPart,
)
from fish_speech.models.text2semantic.llama import BaseModelArgs
from fish_speech.text import clean_text, split_text
from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
# Avoid HF tokenizers fork/parallelism warnings in the worker threads
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Inductor tuning knobs for the compiled decode function
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True

if hasattr(torch._inductor.config, "fx_graph_cache"):
    # Experimental feature to reduce compilation times, will be on by default in future
    torch._inductor.config.fx_graph_cache = True
from torch.nn.attention import SDPBackend, sdpa_kernel
from fish_speech.models.text2semantic.llama import (
BaseTransformer,
DualARTransformer,
NaiveTransformer,
)
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    # Exponential-race trick: argmax of p_i / E_i (E_i ~ Exp(1)) is a
    # multinomial draw, avoiding the device sync torch.multinomial incurs.
    exp_noise = torch.empty_like(probs_sort).exponential_(1)
    return (probs_sort / exp_noise).argmax(dim=-1, keepdim=True).to(dtype=torch.int)
def logits_to_probs(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Turn a 1-D logits vector into sampling probabilities.

    Applies, in order: a repetition penalty over ``previous_tokens`` (this
    step mutates ``logits`` in place), nucleus (top-p) filtering, and
    temperature scaling.
    """
    if previous_tokens is not None:
        prev = previous_tokens.long()
        penalized = torch.gather(logits, dim=0, index=prev)
        penalized = torch.where(
            penalized < 0,
            penalized * repetition_penalty,
            penalized / repetition_penalty,
        )
        logits.scatter_(dim=0, index=prev, src=penalized)

    # Nucleus filtering: drop tokens outside the top-p cumulative mass
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove_sorted = cum_probs > top_p
    remove_sorted[0] = False  # keep at least one option
    remove_mask = remove_sorted.scatter(dim=0, index=sorted_indices, src=remove_sorted)
    filtered = logits.masked_fill(remove_mask, -float("Inf"))

    # Temperature scaling (floored to avoid division by zero)
    scaled = filtered / max(temperature, 1e-5)
    return torch.softmax(scaled, dim=-1)
def multinomial_sample_one_no_sync_agent(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    # Same exponential-race trick as multinomial_sample_one_no_sync; kept as
    # a separate symbol for the batched "agent" path.
    exp_noise = torch.empty_like(probs_sort).exponential_(1)
    return (probs_sort / exp_noise).argmax(dim=-1, keepdim=True).to(dtype=torch.int)
def logits_to_probs_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    temperature: torch.Tensor = 1.0,
    top_p: torch.Tensor = 1.0,
    repetition_penalty: torch.Tensor = 1.0,
) -> torch.Tensor:
    """Batched variant of logits_to_probs operating on the last dimension.

    Applies a repetition penalty (mutating ``logits`` in place), nucleus
    (top-p) filtering, then temperature scaling, all along dim=-1.
    """
    if previous_tokens is not None:
        prev = previous_tokens.long()
        penalized = torch.gather(logits, dim=-1, index=prev)
        penalized = torch.where(
            penalized < 0,
            penalized * repetition_penalty,
            penalized / repetition_penalty,
        )
        logits.scatter_(dim=-1, index=prev, src=penalized)

    # Nucleus filtering per row: drop tokens outside the top-p cumulative mass
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove_sorted = cum_probs > top_p
    remove_sorted[..., 0] = False  # keep at least one option
    remove_mask = remove_sorted.scatter(dim=-1, index=sorted_indices, src=remove_sorted)
    filtered = logits.masked_fill(remove_mask, -float("Inf"))

    # Temperature scaling (floored to avoid division by zero)
    scaled = filtered / max(temperature, 1e-5)
    return torch.softmax(scaled, dim=-1)
def sample(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample one token from the last position of a [1, T, V] logits tensor.

    Returns (sampled_token, probabilities).
    """
    last_step_probs = logits_to_probs(
        logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    next_token = multinomial_sample_one_no_sync(last_step_probs)
    return next_token, last_step_probs
def sample_agent(
    logits,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batched sampling from the last position of a [B, T, V] logits tensor.

    Returns (sampled_tokens, probabilities).
    """
    last_step_probs = logits_to_probs_agent(
        logits=logits[:, -1], previous_tokens=previous_tokens, **sampling_kwargs
    )
    next_tokens = multinomial_sample_one_no_sync_agent(last_step_probs)
    return next_tokens, last_step_probs
def decode_one_token_ar_agent(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step with the dual-AR (fast codebook head) model, batched.

    Returns a (batch, num_codebooks + 1, 1) tensor: the main token followed
    by one token per codebook; codebook slots are padded where the main token
    is not a semantic id.
    """
    # print(x, input_pos)
    x = model.forward_generate(x, input_pos)
    logits = x.logits  # [:, -1:]
    hidden_states = x.hidden_states  # [:, -1:]

    # The main token is sampled near-greedily, overriding caller parameters
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample_agent(
            logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    # Cleanup the cache of the fast (codebook) transformer between steps
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    # Sample each codebook autoregressively through the fast transformer
    for codebook_idx in range(model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample_agent(
            logits,
            previous_tokens=(
                previous_tokens[:, codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=1)
    # Mask codebook tokens to PAD wherever the main token is not a semantic id
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )

    return codebooks
def decode_one_token_naive_agent(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step with the naive (single-pass) model, batched.

    Returns a (batch, num_codebooks + 1, 1) tensor of sampled tokens with
    non-semantic codebook slots padded.
    """
    x = model.forward_generate(x, input_pos)

    # NOTE(review): the main token uses `sample` (single-sample path) while the
    # codebooks use `sample_agent` (batched) — looks inconsistent with
    # decode_one_token_ar_agent; confirm intended for batch sizes > 1.
    codebooks = [
        sample(
            x.token_logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample_agent(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[:, i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    codebooks = torch.stack(codebooks, dim=1)
    # Mask codebook tokens to PAD wherever the main token is not a semantic id
    semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    codebooks[:, 1:, :] = torch.masked_fill(
        codebooks[:, 1:, :],
        ~torch.isin(codebooks[:, :1, :], semantic_ids_tensor),
        CODEBOOK_PAD_TOKEN_ID,
    )

    return codebooks
def decode_one_token_ar(
    model: DualARTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    semantic_ids: list,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step with the dual-AR model (single-sample path).

    Returns a (num_codebooks + 1, 1) tensor: the main token followed by one
    token per codebook.
    """
    x = model.forward_generate(x, input_pos)

    sampling_kwargs_main = sampling_kwargs.copy()
    # sampling_kwargs_main["temperature"] = 0.1
    # sampling_kwargs_main["top_p"] = 0.1
    # sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=(
                previous_tokens[0] if previous_tokens is not None else None
            ),  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    hidden_states = x.hidden_states

    # Cleanup the cache of the fast (codebook) transformer between steps
    for layer in model.fast_layers:
        layer.attention.kv_cache.k_cache.fill_(0)
        layer.attention.kv_cache.v_cache.fill_(0)

    # Prime the fast transformer with the main token; the first codebook token
    # is derived directly from the main token id (offset by semantic_begin_id,
    # clamped at 0 for non-semantic tokens).
    input_pos = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
    model.forward_generate_fast(hidden_states, input_pos)
    a = codebooks[0] - model.tokenizer.semantic_begin_id
    a[a < 0] = 0
    hidden_states = model.fast_embeddings(a)
    codebooks.append(a)

    for codebook_idx in range(1, model.config.num_codebooks):
        input_pos = torch.tensor(
            [codebook_idx], device=hidden_states.device, dtype=torch.long
        )
        logits = model.forward_generate_fast(hidden_states, input_pos)
        a = sample(
            logits,
            previous_tokens=(
                previous_tokens[codebook_idx + 1]
                if previous_tokens is not None
                else None
            ),
            **sampling_kwargs,
        )[0]
        hidden_states = model.fast_embeddings(a)
        codebooks.append(a)

    codebooks = torch.stack(codebooks, dim=0)
    # semantic_ids_tensor = torch.tensor(semantic_ids, device=codebooks.device)
    # codebooks[1:, :] = torch.masked_fill(
    #     codebooks[1:, :], ~torch.isin(codebooks[:1, :], semantic_ids_tensor), CODEBOOK_PAD_TOKEN_ID
    # )

    # print(codebooks)
    return codebooks
def decode_one_token_naive(
    model: NaiveTransformer,
    x: torch.Tensor,
    input_pos: torch.Tensor,
    previous_tokens: Optional[torch.Tensor] = None,
    **sampling_kwargs,
) -> torch.Tensor:
    """Decode one step with the naive model (single-sample path).

    Returns a (num_codebooks + 1, 1) tensor of sampled tokens.
    """
    x = model.forward_generate(x, input_pos)

    # Main token is sampled near-greedily regardless of caller parameters
    sampling_kwargs_main = sampling_kwargs.copy()
    sampling_kwargs_main["temperature"] = 0.1
    sampling_kwargs_main["top_p"] = 0.1
    sampling_kwargs_main["repetition_penalty"] = 1.0

    codebooks = [
        sample(
            x.logits,
            previous_tokens=None,  # Disable repetition penalty for the token codebook
            **sampling_kwargs_main,
        )[0]
    ]

    for i in range(model.config.num_codebooks):
        codebooks.append(
            sample(
                x.codebook_logits[:, :, i],
                previous_tokens=(
                    previous_tokens[i + 1] if previous_tokens is not None else None
                ),
                **sampling_kwargs,
            )[0]
        )

    return torch.stack(codebooks, dim=0)
def decode_n_tokens(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    semantic_ids: list,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
):
    """Autoregressively decode up to ``num_new_tokens`` tokens.

    Stops early when the <|im_end|> token is produced. Returns the generated
    tokens as a (num_codebooks + 1, n_generated) tensor.
    """
    previous_tokens = torch.zeros(
        (model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )

    for i in tqdm(range(num_new_tokens)):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :win_size]
        else:
            window = previous_tokens[:, i - win_size : i]

        with (
            # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in
            # newer torch in favor of torch.nn.attention.sdpa_kernel (used by
            # decode_n_tokens_agent below) — confirm target torch version.
            torch.backends.cuda.sdp_kernel(
                enable_flash=False, enable_mem_efficient=False, enable_math=True
            )
            if torch.cuda.is_available()
            else nullcontext()
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_ids=semantic_ids,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(1, model.config.num_codebooks + 1, -1)
        previous_tokens[:, i : i + 1] = next_token.view(
            model.config.num_codebooks + 1, -1
        )

        # Stop at the end-of-message token
        if cur_token[0, 0, -1] == model.tokenizer.get_token_id(IM_END_TOKEN):
            break

    return previous_tokens[:, : i + 1]
@torch.no_grad()
@torch.inference_mode()
def generate(
    *,
    model: NaiveTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    decode_one_token=decode_one_token_naive,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    ``prompt`` is (num_codebooks + 1, T); returns the prompt extended with the
    generated tokens. Despite the annotation, DualARTransformer models are
    also handled (see prefill_decode below).
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    # semantic_id = model.tokenizer.convert_tokens_to_ids("<|semantic|>")
    semantic_ids = [
        model.tokenizer.get_token_id(f"<|semantic:{i}|>") for i in range(1024)
    ]

    # Clamp the generation budget to the model's max sequence length
    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(
        (codebook_dim, model.config.max_seq_len), dtype=dtype, device=device
    )
    empty[:, :T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    # NOTE(review): the naive prefill path receives semantic_ids, which
    # decode_one_token_naive does not accept — verify the naive path is used.
    prefill_decode = (
        decode_one_token_naive
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar
    )

    # Prefill: consume the whole prompt and sample the first new token
    next_token = prefill_decode(
        model,
        prompt.view(1, codebook_dim, -1),
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    )
    seq[:, T : T + 1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    x = decode_n_tokens(
        model,
        next_token.view(1, codebook_dim, -1),
        input_pos,
        max_new_tokens - 1,
        decode_one_token=decode_one_token,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    )
    # x = torch.cat(generated_tokens, dim=1)
    seq = seq[:, : T + 1 + x.size(1)]
    seq[:, T + 1 :] = x

    return seq
def decode_n_tokens_agent(
    model: NaiveTransformer,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """Batched autoregressive decoding generator.

    Yields each step's tokens as a (batch, num_codebooks + 1, 1) CPU tensor.
    Stops when all rows emit im_end_id, or when at least
    ``early_stop_threshold`` of the batch has finished (if in (0, 1)).
    """
    batch_size = cur_token.size(0)
    previous_tokens = torch.zeros(
        (batch_size, model.config.num_codebooks + 1, model.config.max_seq_len),
        dtype=torch.int,
        device=cur_token.device,
    )
    finished = torch.zeros(batch_size, dtype=torch.bool, device=cur_token.device)
    # Rows whose prompt already ended with im_end start out finished
    finished = finished | (cur_token[:, 0, -1] == im_end_id)
    start_time = time.time()

    # NOTE(review): if num_new_tokens == 0 the loop body never runs and the
    # stats below would hit an unbound `i` — callers must pass >= 1.
    for i in tqdm(range(num_new_tokens), desc="Decoding: ", total=num_new_tokens):
        # We need to get windowed repeat penalty
        win_size = 16
        if i < win_size:
            window = previous_tokens[:, :, :win_size]
        else:
            window = previous_tokens[:, :, i - win_size : i]

        with sdpa_kernel(
            SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token = decode_one_token(
                model=model,
                x=cur_token,
                input_pos=input_pos,
                previous_tokens=window,
                semantic_ids=semantic_ids,
                **sampling_kwargs,
            )

        input_pos += 1
        cur_token = next_token.view(batch_size, model.config.num_codebooks + 1, -1)
        previous_tokens[:, :, i : i + 1] = next_token.view(
            batch_size, model.config.num_codebooks + 1, -1
        )

        yield cur_token.cpu()

        finished = finished | (cur_token[:, 0, -1] == im_end_id)
        if finished.all() or (
            0 < early_stop_threshold < 1
            and finished.sum() >= round(batch_size * early_stop_threshold)
        ):
            break

    # Throughput statistics for logging only
    total_time = time.time() - start_time
    generated_tokens = i + 1
    tokens_per_second = (generated_tokens / total_time) * batch_size
    logger.info(
        f"Decoded {generated_tokens} x {batch_size} tokens in {total_time:.2f}s ({tokens_per_second:.2f} tokens/s)"
    )
@torch.no_grad()
@torch.inference_mode()
def generate_agent(
    *,
    model: BaseTransformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    semantic_ids: list,
    im_end_id: int = 4,
    decode_one_token=decode_one_token_naive_agent,
    num_samples: int = 1,
    early_stop_threshold: float = 0.6,
    **sampling_kwargs,
):
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    Batched generator variant: the prompt is replicated ``num_samples`` times
    and each decoded step is yielded as a CPU tensor (see decode_n_tokens_agent).
    """
    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(1)
    prompt = prompt[None].repeat(num_samples, 1, 1)

    if T >= model.config.max_seq_len:
        raise ValueError(
            f"Input sequence length {T} exceeds max_seq_len {model.config.max_seq_len}"
        )

    # Clamp the generation budget to the model's max sequence length
    if max_new_tokens:
        if T + max_new_tokens > model.config.max_seq_len:
            max_new_tokens = model.config.max_seq_len - T
            logger.info(f"Truncating max_new_tokens to {max_new_tokens}")

        T_new = T + max_new_tokens
    else:
        T_new = model.config.max_seq_len
        max_new_tokens = T_new - T

    device, dtype = prompt.device, prompt.dtype

    codebook_dim = 1 + model.config.num_codebooks
    input_pos = torch.arange(0, T, device=device)

    # Use non-accelerated version for now, to avoid compilation overhead
    prefill_decode = (
        decode_one_token_naive_agent
        if isinstance(model, NaiveTransformer)
        else decode_one_token_ar_agent
    )

    # Prefill: consume the whole prompt and sample the first new token batch
    next_token = prefill_decode(
        model,
        prompt,
        input_pos,
        semantic_ids=semantic_ids,
        **sampling_kwargs,
    ).view(num_samples, codebook_dim, -1)
    yield next_token.cpu()

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    yield from decode_n_tokens_agent(
        model,
        next_token,
        input_pos,
        max_new_tokens - 1,
        im_end_id=im_end_id,
        semantic_ids=semantic_ids,
        decode_one_token=decode_one_token,
        early_stop_threshold=early_stop_threshold,
        **sampling_kwargs,
    )
def encode_tokens(
    tokenizer,
    string,
    device="cuda",
    prompt_tokens=None,
    num_codebooks=4,
):
    """Encode one user turn (plus optional reference VQ prompt) for inference.

    Builds a user message with the cleaned text; if ``prompt_tokens`` is
    given, appends a completed assistant voice turn carrying the VQ codes,
    otherwise appends an open assistant voice turn for the model to continue.
    Returns the encoded tensor moved to ``device``.
    """
    string = clean_text(string)

    messages = []
    messages.append(
        Message(
            role="user",
            parts=[TextPart(text=string)],
            cal_loss=False,
        )
    )

    if prompt_tokens is not None:
        # Normalize the codes to a 2D (num_codebooks, seq_len) tensor
        if prompt_tokens.ndim == 3:
            assert (
                prompt_tokens.shape[0] == 1
            ), "3D prompt tokens should have shape (1, num_codebooks, seq_len)"
            prompt_tokens = prompt_tokens[0]

        assert prompt_tokens.ndim == 2, "Prompt tokens should be 2D tensor"

        if prompt_tokens.shape[0] > num_codebooks:
            logger.warning(
                f"Prompt tokens shape {prompt_tokens.shape} is larger than num_codebooks {num_codebooks}, getting first {num_codebooks} codebooks"
            )
            prompt_tokens = prompt_tokens[:num_codebooks]

        vq_part = VQPart(codes=prompt_tokens.to(device))

        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>"), vq_part],
                cal_loss=False,
            )
        )
    else:
        # Open-ended assistant turn: no <|im_end|> so the model continues it
        messages.append(
            Message(
                role="assistant",
                parts=[TextPart(text="<|voice|>")],
                cal_loss=False,
                add_im_end=False,
            )
        )

    conversation = Conversation(messages=messages)
    # conversation.visualize(tokenizer)
    encoded = conversation.encode_for_inference(
        tokenizer=tokenizer,
        num_codebooks=num_codebooks,
    )

    return encoded.to(device)
def load_model(checkpoint_path, device, precision, compile=False, is_agent=False):
    """Load a text2semantic checkpoint and select its decode-one-token function.

    Args:
        checkpoint_path: Location understood by BaseTransformer.from_pretrained.
        device: Target device for the weights.
        precision: Target dtype (e.g. torch.bfloat16).
        compile: If True, torch.compile the decode function.
        is_agent: If True, use the batched "agent" decode variants.

    Returns:
        (model, decode_one_token): the model in eval mode and the matching
        decoding callable.
    """
    model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
        checkpoint_path, load_weights=True, is_agent=is_agent
    )

    model = model.to(device=device, dtype=precision)
    # fixed: was a pointless f-string with no placeholders (ruff F541)
    logger.info("Restored model from checkpoint")

    # Pick the decode function matching the architecture and agent mode
    if isinstance(model, DualARTransformer):
        decode_one_token = (
            decode_one_token_ar_agent if is_agent else decode_one_token_ar
        )
        logger.info("Using DualARTransformer")
    else:
        decode_one_token = (
            decode_one_token_naive_agent if is_agent else decode_one_token_naive
        )
        logger.info("Using NaiveTransformer")

    if compile:
        logger.info("Compiling function...")
        decode_one_token = torch.compile(
            decode_one_token,
            fullgraph=True,
            # Inductor needs CUDA; fall back to aot_eager on CPU/MPS
            backend="inductor" if torch.cuda.is_available() else "aot_eager",
            mode="reduce-overhead" if torch.cuda.is_available() else None,
        )

    return model.eval(), decode_one_token
@dataclass
class GenerateResponse:
    """One item streamed back from generate_long."""

    # "sample" carries codes/text for one chunk; "next" ends the current sample.
    action: Literal["sample", "next"]
    # VQ codes for the chunk (set when action == "sample").
    codes: Optional[torch.Tensor] = None
    # Source text of the chunk (set when action == "sample").
    text: Optional[str] = None
def generate_long(
    *,
    model,
    device: str | torch.device,
    decode_one_token: callable,
    text: str,
    num_samples: int = 1,
    max_new_tokens: int = 0,
    top_p: float = 0.7,
    repetition_penalty: float = 1.5,
    temperature: float = 0.7,
    compile: bool = False,
    iterative_prompt: bool = True,
    max_length: int = 2048,
    chunk_length: int = 150,
    prompt_text: Optional[str | list[str]] = None,
    prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
):
    """Generate semantic tokens for (possibly long) text, chunk by chunk.

    Splits ``text`` into chunks (when iterative_prompt), keeps a sliding
    context window of previously generated segments, and yields a
    GenerateResponse per chunk, then one with action="next" per sample.
    """
    assert 0 < top_p <= 1, "top_p must be in (0, 1]"
    assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
    assert 0 < temperature < 2, "temperature must be in (0, 2)"

    # Reference prompts are only used when both text and tokens are provided
    use_prompt = prompt_text is not None and prompt_tokens is not None
    if use_prompt and isinstance(prompt_text, str):
        prompt_text = [prompt_text]
        prompt_tokens = [prompt_tokens]

    assert use_prompt is False or len(prompt_text) == len(
        prompt_tokens
    ), "Prompt text and tokens must have the same length"

    model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tokenizer = model.tokenizer
    im_end_id = tokenizer.get_token_id("<|im_end|>")

    encoded = []
    texts = split_text(text, chunk_length) if iterative_prompt else [text]

    # System message is always the first prompt segment
    encoded_prompts = [
        Conversation(
            messages=[
                Message(
                    role="system",
                    parts=[TextPart(text="Speak out the provided text.")],
                    cal_loss=False,
                )
            ]
        )
        .encode_for_inference(
            tokenizer=tokenizer,
            num_codebooks=model.config.num_codebooks,
        )
        .to(device)
    ]

    if use_prompt:
        # Encode each reference (text + VQ codes) as a completed voice turn
        for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
            encoded_prompts.append(
                encode_tokens(
                    tokenizer,
                    string=t,
                    device=device,
                    prompt_tokens=c,
                    num_codebooks=model.config.num_codebooks,
                )
            )

    for idx, text in enumerate(texts):
        encoded.append(
            encode_tokens(
                tokenizer,
                string=text,
                device=device,
                num_codebooks=model.config.num_codebooks,
            )
        )
        logger.info(f"Encoded text: {text}")

    # Move temperature, top_p, repetition_penalty to device
    # This is important so that changing params doesn't trigger recompile
    temperature = torch.tensor(temperature, device=device, dtype=torch.float)
    top_p = torch.tensor(top_p, device=device, dtype=torch.float)
    repetition_penalty = torch.tensor(
        repetition_penalty, device=device, dtype=torch.float
    )

    for sample_idx in range(num_samples):
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        global_encoded = []
        seg_idx = 0

        while seg_idx < len(encoded):
            logger.info(
                f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
            )

            seg = encoded[seg_idx]
            global_encoded.append(seg)

            # Walk segments newest-first to find how many fit in the budget
            lengths = reversed([seg.size(1) for seg in global_encoded])

            # Pick last 2000 tokens
            count = 0
            for i, length in enumerate(lengths):
                count += length
                if count + length > max_length - 1024 - sum(
                    t.shape[1] for t in encoded_prompts
                ):
                    break

            if i != 0 and i % 2 == 0:
                i -= 1

            # Rotate the list, always make sure first segment is included to avoid drift
            if i < len(global_encoded) - 2:
                partial_encoded = global_encoded[:2] + global_encoded[-i:]
            else:
                partial_encoded = global_encoded

            if use_prompt:
                partial_encoded = encoded_prompts + partial_encoded

            cat_encoded = torch.cat(partial_encoded, dim=1)
            prompt_length = cat_encoded.size(1)

            t0 = time.perf_counter()
            y = generate(
                model=model,
                prompt=cat_encoded,
                max_new_tokens=max_new_tokens,
                decode_one_token=decode_one_token,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            if sample_idx == 0 and seg_idx == 0 and compile:
                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")

            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Throughput statistics for logging only
            t = time.perf_counter() - t0
            tokens_generated = y.size(1) - prompt_length
            tokens_sec = tokens_generated / t
            logger.info(
                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
            )
            logger.info(
                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
            )

            if torch.cuda.is_available():
                logger.info(
                    f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
                )

            # Put the generated tokens
            # since there is <im_end>, we remove last token
            codes = y[1:, prompt_length + 1 :].clone()
            assert (codes >= 0).all(), f"Negative code found"

            decoded = y[:, prompt_length:].clone()
            # But for global encoding, we should keep the <im_end> token
            global_encoded.append(decoded)
            assert (codes >= 0).all(), f"Negative code found: {codes}"
            yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
            seg_idx += 1

        # This indicates the end of the current sample
        yield GenerateResponse(action="next")
@dataclass
class WrappedGenerateResponse:
    """Envelope placed on a response queue by the generation worker thread."""

    # "success" carries a GenerateResponse; "error" carries the raised Exception.
    status: Literal["success", "error"]
    response: Optional[GenerateResponse | Exception] = None
@dataclass
class GenerateRequest:
    """A unit of work submitted to the generation worker thread."""

    # Keyword arguments forwarded verbatim to the generation function.
    request: dict
    # Queue on which the worker delivers results for this request.
    response_queue: queue.Queue
def launch_thread_safe_queue(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker that serves `generate_long` requests.

    Loads the model on a background thread, blocks until it is ready, and
    returns the queue on which `GenerateRequest` items (or `None` to stop)
    should be submitted. Responses arrive on each request's own queue as
    `WrappedGenerateResponse` objects.
    """
    request_queue = queue.Queue()
    ready = threading.Event()

    def _serve():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        # Signal the caller that the model is loaded and caches are allocated.
        ready.set()

        while True:
            job: GenerateRequest | None = request_queue.get()
            if job is None:  # sentinel: shut the worker down
                break
            try:
                for chunk in generate_long(
                    model=model, decode_one_token=decode_one_token, **job.request
                ):
                    job.response_queue.put(
                        WrappedGenerateResponse(status="success", response=chunk)
                    )
            except Exception as e:
                # Report the failure to the requester instead of killing the thread.
                job.response_queue.put(
                    WrappedGenerateResponse(status="error", response=e)
                )

    threading.Thread(target=_serve, daemon=True).start()
    ready.wait()
    return request_queue
def launch_thread_safe_queue_agent(
    checkpoint_path,
    device,
    precision,
    compile: bool = False,
):
    """Start a daemon worker that serves `generate_agent` requests.

    Returns `(request_queue, tokenizer, config)`. Tokens are streamed onto each
    request's response queue, terminated by the string "stop" on success or
    "error" on failure.
    """
    request_queue = queue.Queue()
    ready = threading.Event()
    # Load tokenizer/config on the caller's thread so failures surface early.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    config = BaseModelArgs.from_pretrained(checkpoint_path)

    def _serve():
        model, decode_one_token = load_model(
            checkpoint_path, device, precision, compile=compile, is_agent=True
        )
        with torch.device(device):
            model.setup_caches(
                max_batch_size=1,
                max_seq_len=model.config.max_seq_len,
                dtype=next(model.parameters()).dtype,
            )
        ready.set()

        while True:
            job: GenerateRequest | None = request_queue.get()
            if job is None:  # sentinel: shut the worker down
                break
            try:
                for token in generate_agent(
                    model=model,
                    decode_one_token=decode_one_token,
                    **job.request,
                ):
                    job.response_queue.put(token)
                job.response_queue.put("stop")
            except Exception as e:
                import traceback

                logger.exception(f"Error in worker: {traceback.format_exc()}")
                job.response_queue.put("error")

    threading.Thread(target=_serve, daemon=True).start()
    ready.wait()
    return request_queue, tokenizer, config
@click.command()
@click.option(
    "--text",
    type=str,
    default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
)
@click.option("--prompt-text", type=str, default=None, multiple=True)
@click.option(
    "--prompt-tokens",
    type=click.Path(path_type=Path, exists=True),
    default=None,
    multiple=True,
)
@click.option("--num-samples", type=int, default=1)
@click.option("--max-new-tokens", type=int, default=0)
@click.option("--top-p", type=float, default=0.7)
@click.option("--repetition-penalty", type=float, default=1.2)
@click.option("--temperature", type=float, default=0.7)
@click.option(
    "--checkpoint-path",
    type=click.Path(path_type=Path, exists=True),
    default="checkpoints/fish-speech-1.5",
)
@click.option("--device", type=str, default="cuda")
@click.option("--compile/--no-compile", default=False)
@click.option("--seed", type=int, default=42)
@click.option("--half/--no-half", default=False)
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
@click.option("--chunk-length", type=int, default=100)
@click.option("--output-dir", type=Path, default="temp")
def main(
    text: str,
    prompt_text: Optional[list[str]],
    prompt_tokens: Optional[list[Path]],
    num_samples: int,
    max_new_tokens: int,
    top_p: int,
    repetition_penalty: float,
    temperature: float,
    checkpoint_path: Path,
    device: str,
    compile: bool,
    seed: int,
    half: bool,
    iterative_prompt: bool,
    chunk_length: int,
    output_dir: Path,
) -> None:
    """CLI entry point: run text-to-semantic generation for `text` and save each
    sample's concatenated codebook tokens to `output_dir/codes_<idx>.npy`.
    """
    os.makedirs(output_dir, exist_ok=True)
    precision = torch.half if half else torch.bfloat16
    # Every --prompt-text must be paired with a --prompt-tokens file.
    if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
        raise ValueError(
            f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
        )
    logger.info("Loading model ...")
    t0 = time.time()
    model, decode_one_token = load_model(
        checkpoint_path, device, precision, compile=compile
    )
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_len=model.config.max_seq_len,
            dtype=next(model.parameters()).dtype,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
    # Load prompt token .npy files onto the target device.
    if prompt_tokens is not None:
        prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    generator = generate_long(
        model=model,
        device=device,
        decode_one_token=decode_one_token,
        text=text,
        num_samples=num_samples,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        temperature=temperature,
        compile=compile,
        iterative_prompt=iterative_prompt,
        chunk_length=chunk_length,
        prompt_text=prompt_text,
        prompt_tokens=prompt_tokens,
    )
    idx = 0
    codes = []
    # Accumulate per-segment codes; flush to disk at each "next" sample boundary.
    for response in generator:
        if response.action == "sample":
            codes.append(response.codes)
            logger.info(f"Sampled text: {response.text}")
        elif response.action == "next":
            if codes:
                codes_npy_path = os.path.join(output_dir, f"codes_{idx}.npy")
                np.save(codes_npy_path, torch.cat(codes, dim=1).cpu().numpy())
                logger.info(f"Saved codes to {codes_npy_path}")
            logger.info(f"Next sample")
            codes = []
            idx += 1
        else:
            logger.error(f"Error: {response}")


if __name__ == "__main__":
    main()
from typing import Any, Optional
import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler
import fish_speech.utils as utils
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
from fish_speech.models.text2semantic.llama import NaiveTransformer
# Rank-zero logger: in multi-process training only rank 0 emits messages.
log = utils.RankedLogger(__name__, rank_zero_only=True)
class TextToSemantic(L.LightningModule):
    """Lightning module that trains a text-to-semantic transformer.

    Wraps a `NaiveTransformer` and optimizes a combined token-level and
    codebook-level cross-entropy loss.
    """

    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
    ):
        super().__init__()
        self.model = model
        # Factories invoked lazily in configure_optimizers (e.g. hydra partials).
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        # Save only LoRA parameters: if any parameter name contains "lora",
        # strip everything else so the checkpoint holds just the adapter.
        state_dict = checkpoint["state_dict"]
        use_lora = any("lora" in name for name in state_dict.keys())
        if not use_lora:
            return
        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the optimizer with selective weight decay and a per-step scheduler."""
        # Biases, norm scales, and embeddings are excluded from weight decay.
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)
        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )
        # Print the parameters and their weight decay
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )
        lr_scheduler = self.lr_scheduler_builder(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.
        Args:
            logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
        Returns:
            A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape
        labels = labels.clone()
        loss_mask = labels != -100
        # dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0
        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)
        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        """Shared train/val step: forward pass, combined loss, and metric logging.

        NOTE(review): assumes batch["labels"] row 0 holds text-token targets and
        rows 1..num_codebooks hold codebook targets — confirm against the datamodule.
        """
        is_train = stage == "train"
        if is_train:
            # Key part to make lora work
            # Otherwise the parameters are merged, which lead to incorrect gradients
            self.model.train()
        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits
        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.view(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )
        # .mT swaps the last two dims so codebook targets align with the logits layout.
        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.view(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )
        loss = base_loss + semantic_loss
        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )
        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )
        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
            sync_dist=not is_train,
        )
        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
            sync_dist=not is_train,
        )
        return loss

    def get_accuracy(self, logits, labels):
        """Top-5 accuracy over codebook targets, ignoring -100 and pad tokens."""
        mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
        if mask.sum() == 0:
            return torch.tensor(0.0, device=logits.device)
        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~mask] = 0
        correct = correct.sum()
        accuracy = correct / mask.sum()
        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
import dataclasses
import json
import math
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from loguru import logger
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer
from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
from fish_speech.utils import RankedLogger
from .lora import LoraConfig, setup_lora
# Rank-zero logger: in multi-process settings only rank 0 emits messages.
log = RankedLogger(__name__, rank_zero_only=True)
def find_multiple(n: int, k: int) -> int:
    """Round *n* up to the nearest multiple of *k*."""
    remainder = n % k
    return n if remainder == 0 else n + k - remainder
@dataclass
class BaseModelArgs:
    """Configuration for the text-to-semantic transformer family."""

    model_type: str = "base"
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # Derived in __post_init__ when left as None (SwiGLU sizing rule).
    intermediate_size: Optional[int] = None
    # -1 means "same as n_head" (no grouped-query attention).
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False
    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4
    # Gradient checkpointing
    use_gradient_checkpointing: bool = True
    # Initialize the model
    initializer_range: float = 0.02
    # Dummy vars
    is_reward_model: bool = False
    share_codebook_embeddings: bool = True
    scale_codebook_embeddings: bool = False

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            # 2/3 of 4*dim, rounded up to a multiple of 256 (LLaMA-style FFN sizing).
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        # head_dim always follows dim/n_head, overriding any provided value.
        self.head_dim = self.dim // self.n_head

    @staticmethod
    def from_pretrained(path: str):
        """Load a config from a checkpoint dir (config.json) or a JSON file path."""
        path = Path(path)
        if path.is_dir():
            path = path / "config.json"
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Dispatch to the concrete args class by the "model_type" field.
        match data["model_type"]:
            case "naive":
                cls = NaiveModelArgs
            case "dual_ar":
                cls = DualARModelArgs
            case _:
                raise ValueError(f"Unknown model type: {data['model_type']}")
        return cls(**data)

    def save(self, path: str):
        """Serialize this config as pretty-printed JSON."""
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Config for the single-stage (naive) transformer variant."""

    model_type: str = "naive"
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Config for the dual autoregressive variant (slow token AR + fast codebook AR).

    Any `fast_*` field left as None falls back to the corresponding slow-model
    value in __post_init__.
    """

    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None

    def __post_init__(self):
        super().__post_init__()
        # Default every fast-transformer dimension to its slow counterpart.
        self.fast_dim = self.fast_dim or self.dim
        self.fast_n_head = self.fast_n_head or self.n_head
        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
        self.fast_head_dim = self.fast_head_dim or self.head_dim
        self.fast_intermediate_size = (
            self.fast_intermediate_size or self.intermediate_size
        )
        self.fast_attention_qkv_bias = (
            self.fast_attention_qkv_bias
            if self.fast_attention_qkv_bias is not None
            else self.attention_qkv_bias
        )
class KVCache(nn.Module):
    """Pre-allocated key/value cache for incremental decoding.

    Buffers have shape (max_batch_size, n_heads, max_seq_len, head_dim) and are
    filled in place at the positions given to `update`.
    """

    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        for name in ("k_cache", "v_cache"):
            self.register_buffer(name, torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S]; k_val / v_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        # Return the full caches (including still-zero future positions).
        return self.k_cache, self.v_cache
@dataclass
class TransformerForwardResult:
    """Final model output: text-token logits plus per-codebook logits."""

    token_logits: Tensor
    codebook_logits: Tensor
@dataclass
class BaseTransformerForwardResult:
    """Intermediate output of the slow transformer: logits and hidden states."""

    logits: Tensor
    hidden_states: Tensor
class BaseTransformer(nn.Module):
def __init__(
self,
config: BaseModelArgs,
tokenizer: FishTokenizer,
init_weights: bool = True,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
self.semantic_token_ids = [
tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
]
# Slow transformer
self.embeddings = nn.Embedding(
config.vocab_size,
config.dim,
)
self.codebook_embeddings = nn.Embedding(
config.codebook_size * config.num_codebooks,
config.dim,
)
self.layers = nn.ModuleList(
TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
)
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
if self.config.tie_word_embeddings is False:
self.output = nn.Linear(
config.dim,
config.vocab_size,
bias=False,
)
self.register_buffer(
"freqs_cis",
precompute_freqs_cis(
config.max_seq_len,
config.dim // config.n_head,
config.rope_base,
),
persistent=False,
)
self.register_buffer(
"causal_mask",
torch.tril(
torch.ones(
config.max_seq_len,
config.max_seq_len,
dtype=torch.bool,
)
),
persistent=False,
)
# For kv cache
self.max_batch_size = -1
self.max_seq_len = -1
if init_weights:
self.apply(self._init_weights)
def setup_caches(
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
):
if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_len = find_multiple(max_seq_len, 8)
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
for b in self.layers:
b.attention.kv_cache = KVCache(
max_batch_size,
max_seq_len,
self.config.n_local_heads,
head_dim,
dtype=dtype,
)
def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor:
embeds = []
semantic_token_ids_tensor = torch.tensor(
self.semantic_token_ids, device=inp.device, dtype=inp.dtype
)
for i in range(self.config.num_codebooks):
if share_codebook_embeddings:
emb = self.codebook_embeddings(
inp[:, i + 1] + i * self.config.codebook_size
)
else:
emb = self.codebook_embeddings(inp[:, i + 1])
embeds.append(emb)
vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0
x = self.embeddings(inp[:, 0]) + vq_embeds_sum
return x
def forward(
self,
inp: Tensor,
key_padding_mask: Optional[Tensor] = None,
) -> BaseTransformerForwardResult:
seq_len = inp.size(2)
# Here we want to merge the embeddings of the codebooks
x = self.embed(inp)
freqs_cis = self.freqs_cis[:seq_len]
# Not that the causal mask here follows the definition of scaled_dot_product_attention
# That is, FALSE means masked out
# To maintain consistency, key_padding_mask use TRUE to mask out
mask = None
if key_padding_mask is not None:
causal = self.causal_mask[:seq_len, :seq_len]
causal = rearrange(causal, "q k -> 1 1 q k")
atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
atten_mask = atten_mask.logical_not()
mask = causal & atten_mask
# return freqs_cis, mask
for layer in self.layers:
if self.config.use_gradient_checkpointing and self.training:
x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
else:
x = layer(x, freqs_cis, mask)
# We got slow_out here
slow_out = self.norm(x)
if self.config.tie_word_embeddings:
token_logits = F.linear(slow_out, self.embeddings.weight)
else:
token_logits = self.output(slow_out)
return BaseTransformerForwardResult(
logits=token_logits,
hidden_states=x,
)
def forward_generate(
self,
inp: Tensor,
input_pos: Optional[Tensor] = None,
return_all: bool = False,
) -> BaseTransformerForwardResult:
x = self.embed(
inp, share_codebook_embeddings=self.config.share_codebook_embeddings
)
if input_pos is None:
input_pos = torch.arange(inp.shape[-1], device=x.device)
max_seq_len = inp.shape[-1]
else:
max_seq_len = self.max_seq_len
mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
freqs_cis = self.freqs_cis[input_pos]
for layer in self.layers:
x = layer(x, freqs_cis, mask, input_pos=input_pos)
# If prefill, we only calculate the logits of last token
if x.size(1) > 1 and not return_all:
x = x[:, -1:]
# We got slow_out here
slow_out = self.norm(x)
if self.config.is_reward_model:
token_logits = self.score_output(slow_out)
elif self.config.tie_word_embeddings:
token_logits = F.linear(slow_out, self.embeddings.weight)
else:
token_logits = self.output(slow_out)
return BaseTransformerForwardResult(
logits=token_logits,
hidden_states=x,
)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@staticmethod
def from_pretrained(
path: str,
load_weights: bool = False,
max_length: int | None = None,
lora_config: LoraConfig | None = None,
rope_base: int | None = None,
is_agent: bool = False,
) -> "BaseTransformer":
config = BaseModelArgs.from_pretrained(str(path))
if max_length is not None:
config.max_seq_len = max_length
log.info(f"Override max_seq_len to {max_length}")
if rope_base is not None:
config.rope_base = rope_base
log.info(f"Override rope_base to {rope_base}")
match config.model_type:
case "naive":
model_cls = NaiveTransformer
case "dual_ar":
model_cls = DualARTransformer
case _:
raise ValueError(f"Unknown model type: {config.model_type}")
tokenizer_path = str(path) + "/tokenizer.tiktoken"
tokenizer = FishTokenizer(tokenizer_path)
log.info(f"Loading model from {path}, config: {config}")
model = model_cls(config, tokenizer=tokenizer)
if lora_config is not None:
setup_lora(model, lora_config)
log.info(f"LoRA setup: {lora_config}")
if load_weights is False:
log.info("Randomly initialized model")
else:
if "int8" in str(Path(path)):
logger.info("Using int8 weight-only quantization!")
from tools.llama.quantize import WeightOnlyInt8QuantHandler
simple_quantizer = WeightOnlyInt8QuantHandler(model)
model = simple_quantizer.convert_for_runtime()
if "int4" in str(Path(path)):
logger.info("Using int4 quantization!")
path_comps = path.name.split("-")
assert path_comps[-2].startswith("g")
groupsize = int(path_comps[-2][1:])
from tools.llama.quantize import WeightOnlyInt4QuantHandler
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
model = simple_quantizer.convert_for_runtime()
weights = torch.load(
Path(path) / "model.pth",
map_location="cpu",
mmap=True,
weights_only=True,
)
if "state_dict" in weights:
logger.warning(
"Using a TextToSemantic LightningModule checkpoint, "
"please make sure it is a full model, not a LoRA model."
)
weights = weights["state_dict"]
if next(iter(weights.keys())).startswith("model."):
logger.info(
f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
)
new_weights = OrderedDict()
for k, v in weights.items():
new_weights[k.replace("model.", "")] = v
weights = new_weights
# Verify the name and shape of parameters since strict=False in load_state_dict.
for k, v in model.named_parameters():
if k not in weights:
logger.warning(f"No weight for {k}")
elif v.shape != weights[k].shape:
logger.warning(
f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
)
err = model.load_state_dict(weights, strict=False, assign=True)
log.info(f"Loaded weights with error: {err}")
return model
def save_pretrained(self, path: str, drop_lora: bool = False):
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
self.config.save(path / "config.json")
state_dict = self.state_dict()
if drop_lora:
for key in list(state_dict.keys()):
if "lora" not in key:
continue
state_dict.pop(key)
log.info(f"Drop LoRA parameter: {key}")
torch.save(state_dict, path / "model.pth")
self.tokenizer.save_pretrained(path)
class NaiveTransformer(BaseTransformer):
    """Single-stage transformer: one linear head predicts all codebooks at once."""

    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        # Defer weight init until the codebook head exists, then init everything.
        super().__init__(config, init_weights=False, tokenizer=tokenizer)
        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )
        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        """Project hidden states to per-codebook logits and pair with token logits."""
        normed = self.codebook_norm(result.hidden_states)
        per_codebook = rearrange(
            self.codebook_output(normed),
            "b n (c d) -> b n c d",
            c=self.config.num_codebooks,
        )
        return TransformerForwardResult(
            token_logits=result.logits,
            codebook_logits=per_codebook,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        return self.decode(super().forward(inp=inp, key_padding_mask=key_padding_mask))

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        return self.decode(super().forward_generate(x, input_pos))
class DualARTransformer(BaseTransformer):
    """Dual autoregressive transformer: the slow model predicts text tokens while
    a small "fast" transformer autoregresses over the codebooks at each position.
    """

    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)
        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()
        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
        # The equivalent bs is so large that sdpa doesn't work
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
        )
        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )
        # The fast model's sequence axis is the codebook index, so its RoPE table
        # only needs num_codebooks positions.
        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_dim // config.fast_n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        super().setup_caches(max_batch_size, max_seq_len, dtype)
        head_dim = self.config.fast_dim // self.config.fast_n_head
        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Training forward: slow pass over tokens, then a fast pass over codebooks
        at every sequence position (flattened into the batch dimension).
        """
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states
        x = self.fast_project_in(x)
        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        # Prepend the slow hidden state as position 0 of each fast sequence.
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len
        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)
        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False
        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]
        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(
                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
                )
            else:
                x = layer(x, self.fast_freqs_cis, fast_mask)
        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)
        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer
        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )
        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """One incremental step of the fast (codebook) transformer for a single
        hidden vector; returns logits over the codebook vocabulary.
        """
        # Fast transformer
        x = x.view(1, 1, -1)
        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]
        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)
        return codebook_logits

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        # NOTE(review): the base-class forward_generate's third parameter is
        # `return_all`, so `vq_masks` is passed into that slot here — harmless
        # while vq_masks is None (falsy), but verify against the base signature.
        x = super().forward_generate(x, input_pos, vq_masks)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then feed-forward, each residual."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        attn_out = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        residual = x + attn_out
        return residual + self.feed_forward(self.ffn_norm(residual))
class Attention(nn.Module):
    """Multi-head attention with fused QKV projection, RoPE, grouped-query KV
    heads, and an optional KV cache for incremental decoding.
    """

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Set externally by setup_caches(); None means full (non-cached) attention.
        self.kv_cache = None
        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Back-compat: fuse separate wq/wk/wv checkpoint weights into wqkv.
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape
        kv_size = self.n_local_heads * self.head_dim
        # Split the fused projection into query / key / value parts.
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)
        # (B, S, H, D) -> (B, H, S, D) for attention.
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)
        # Expand grouped KV heads to match the number of query heads.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No third party attn_mask here to use flash_attention
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention
        # It's low efficient, but it doesn't raise cuda error
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean mask: False positions get -inf before softmax.
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        # NOTE(review): train=True applies dropout even in eval mode when
        # dropout_p > 0 — the caller passes 0.0 outside training, so it is
        # effectively inert, but confirm if reusing this helper elsewhere.
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        dim, hidden = config.dim, config.intermediate_size
        # Keep creation order w1, w3, w2 (matches checkpoint/init ordering).
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, learned scale only."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS of the last dimension (eps guards against zero input).
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in fp32 for stability, then cast back to the input dtype.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """Precompute rotary-embedding frequencies.

    Args:
        seq_len: number of positions to precompute.
        n_elem: number of elements in the frequency tensor.
        base: frequency scaling base (default 10000).

    Returns:
        Tensor of shape (seq_len, n_elem // 2, 2) holding the real and
        imaginary parts of the complex exponentials, in bfloat16.
    """
    half = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float()
    inv_freq = 1.0 / (base ** (half / n_elem))
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    cis = torch.polar(torch.ones_like(angles), angles)
    # Store real/imag parts in the trailing dimension.
    return torch.stack([cis.real, cis.imag], dim=-1).to(dtype=torch.bfloat16)
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of x by the precomputed angles.

    x has shape (batch, seq, heads, dim); freqs_cis holds (cos, sin) pairs
    broadcastable to (1, seq, 1, dim // 2, 2).
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    fc = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)
    cos, sin = fc[..., 0], fc[..., 1]
    re, im = pairs[..., 0], pairs[..., 1]
    # Complex multiplication (re + i*im) * (cos + i*sin).
    rotated = torch.stack([re * cos - im * sin, im * cos + re * sin], -1)
    return rotated.flatten(3).type_as(x)
from dataclasses import dataclass
import loralib as lora
@dataclass
class LoraConfig:
    """Hyper-parameters for LoRA fine-tuning (consumed by `setup_lora`)."""

    # Rank of the low-rank update matrices.
    r: int
    # Scaling factor applied to the LoRA update.
    lora_alpha: float
    # Dropout on the LoRA path; 0.0 disables it.
    lora_dropout: float = 0.0
def setup_lora(model, lora_config):
    """Replace the model's embedding and linear layers with LoRA versions.

    Mutates `model` in place and finally marks only the LoRA parameters as
    trainable. The replacement layers are freshly initialized, so pretrained
    weights are expected to be loaded after this call.

    Args:
        model: text2semantic transformer; if it has `fast_layers` it is
            treated as a Dual-AR model and the fast path is wrapped too.
        lora_config: LoraConfig providing r / lora_alpha / lora_dropout.
    """
    # Replace the embedding layers with LoRA embeddings.
    model.embeddings = lora.Embedding(
        num_embeddings=model.embeddings.num_embeddings,
        embedding_dim=model.embeddings.embedding_dim,
        padding_idx=model.embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )
    model.codebook_embeddings = lora.Embedding(
        num_embeddings=model.codebook_embeddings.num_embeddings,
        embedding_dim=model.codebook_embeddings.embedding_dim,
        padding_idx=model.codebook_embeddings.padding_idx,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
    )

    # Collect (owner, attribute) pairs for every linear layer to wrap:
    # the output head plus all attention and feed-forward projections.
    linears = [(model, "output")]
    for layer in model.layers:
        linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
        linears.extend(
            [
                (layer.feed_forward, "w1"),
                (layer.feed_forward, "w2"),
                (layer.feed_forward, "w3"),
            ]
        )

    if hasattr(model, "fast_layers"):
        model.fast_embeddings = lora.Embedding(
            num_embeddings=model.fast_embeddings.num_embeddings,
            embedding_dim=model.fast_embeddings.embedding_dim,
            padding_idx=model.fast_embeddings.padding_idx,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
        )

        # Dual-AR model: also wrap the fast head and fast-layer projections.
        linears.append((model, "fast_output"))
        for layer in model.fast_layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

    for module, attr in linears:
        old_linear = getattr(module, attr)
        updated_linear = lora.Linear(
            in_features=old_linear.in_features,
            out_features=old_linear.out_features,
            # BUGFIX: pass a bool, not the bias Parameter/None, as the
            # `bias` flag (a Parameter is truthy and None is falsy, so the
            # resulting layers are unchanged, but the type is now correct).
            bias=old_linear.bias is not None,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
        )
        setattr(module, attr, updated_linear)

    # Mark only the LoRA layers as trainable.
    lora.mark_only_lora_as_trainable(model, bias="none")
def get_merged_state_dict(model):
    """Return the model state dict with LoRA weights merged and stripped.

    Switching to eval() makes loralib merge the low-rank deltas into the
    base weights; the LoRA-specific entries are then removed so the result
    can be loaded into a vanilla (non-LoRA) model.
    """
    # eval() triggers loralib's merge of the LoRA deltas into the weights.
    model.eval()

    merged = model.state_dict()
    # Drop the now-redundant LoRA parameters.
    for key in [k for k in merged if "lora" in k]:
        merged.pop(key)

    return merged
from pathlib import Path
import click
import hydra
import numpy as np
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from fish_speech.utils.file import AUDIO_EXTENSIONS
# register eval resolver
OmegaConf.register_new_resolver("eval", eval)
def load_model(config_name, checkpoint_path, device="cuda"):
    """Instantiate the VQ-GAN from a hydra config and load a checkpoint.

    Lightning-style checkpoints are unwrapped, GAN checkpoints are reduced
    to their generator weights, and non-matching keys are tolerated
    (strict=False); the load result is logged for inspection.
    """
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    with initialize(version_base="1.3", config_path="../../configs"):
        cfg = compose(config_name=config_name)

    model = instantiate(cfg)

    checkpoint = torch.load(
        checkpoint_path, map_location=device, mmap=True, weights_only=True
    )

    # Unwrap Lightning-style checkpoints.
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    # Keep only generator weights when the checkpoint holds a full GAN.
    if any("generator" in key for key in checkpoint):
        checkpoint = {
            key.replace("generator.", ""): value
            for key, value in checkpoint.items()
            if "generator." in key
        }

    result = model.load_state_dict(checkpoint, strict=False, assign=True)
    model.eval()
    model.to(device)

    logger.info(f"Loaded model: {result}")
    return model
@torch.no_grad()
@click.command()
@click.option(
    "--input-path",
    "-i",
    default="test.wav",
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
@click.option(
    "--device",
    "-d",
    default="cuda",
)
def main(input_path, output_path, config_name, checkpoint_path, device):
    """Round-trip audio through the VQ-GAN.

    Audio input: resample, encode to VQ indices (also saved next to the
    output as a .npy file), then decode back to a waveform. ``.npy`` input:
    decode the precomputed indices directly. The reconstructed waveform is
    written to ``output_path``.
    """
    model = load_model(config_name, checkpoint_path, device=device)

    if input_path.suffix in AUDIO_EXTENSIONS:
        logger.info(f"Processing in-place reconstruction of {input_path}")

        # Load audio
        audio, sr = torchaudio.load(str(input_path))
        if audio.shape[0] > 1:
            # Downmix multi-channel audio to mono.
            audio = audio.mean(0, keepdim=True)
        audio = torchaudio.functional.resample(
            audio, sr, model.spec_transform.sample_rate
        )

        audios = audio[None].to(device)
        logger.info(
            f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
        indices = model.encode(audios, audio_lengths)[0][0]

        # print("audios:", audios.tolist())
        logger.info(f"Generated indices of shape {indices.shape}")

        # Save indices
        np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
    elif input_path.suffix == ".npy":
        logger.info(f"Processing precomputed indices from {input_path}")
        indices = np.load(input_path)
        indices = torch.from_numpy(indices).to(device).long()
        assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
    else:
        raise ValueError(f"Unknown input type: {input_path}")

    # Restore
    feature_lengths = torch.tensor([indices.shape[1]], device=device)
    fake_audios, _ = model.decode(
        indices=indices[None], feature_lengths=feature_lengths
    )
    audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate

    logger.info(
        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
    )

    # Save audio
    # torchaudio.save(output_path, fake_audios[0].cpu(), model.spec_transform.sample_rate)
    fake_audio = fake_audios[0, 0].float().cpu().numpy()
    sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
    logger.info(f"Saved audio to {output_path}")


if __name__ == "__main__":
    main()
import math
from functools import partial
from math import prod
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.checkpoint import checkpoint
def sequence_mask(length, max_length=None):
    """Boolean mask of shape (B, max_length): True where position < length[b]."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
def init_weights(m, mean=0.0, std=0.01):
    # Normal-initialize conv weights (intended for `.apply(init_weights)`).
    # NOTE(review): this matches "Conv1D" (capital D), but PyTorch's class is
    # named "Conv1d" and this file's wrappers are "FishConvNet" /
    # "FishTransConvNet", so the condition appears to never fire (compare the
    # second init_weights later in this file, which matches "Conv"). Confirm
    # intent before changing it — a fix would alter fresh-training init.
    classname = m.__class__.__name__
    if classname.find("Conv1D") != -1:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a stride-1 dilated convolution length-preserving."""
    effective_span = (kernel_size - 1) * dilation
    return effective_span // 2
def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
    """Remove (left, right) padding from the last dim of a 1d signal."""
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    assert (left + right) <= x.shape[-1]
    # Slice off `left` samples at the front and `right` at the back.
    return x[..., left : x.shape[-1] - right]
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Extra right padding so a strided conv consumes every input sample.

    See `pad_for_conv1d`.
    """
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    # Keep torch ops when tracing: math.ceil would constant-fold in the ONNX
    # graph.
    if isinstance(n_frames, torch.Tensor):
        full_frames = torch.ceil(n_frames).long()
    else:
        full_frames = math.ceil(n_frames)
    ideal_length = (full_frames - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
def pad1d(
    x: torch.Tensor,
    paddings: tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Wrapper around F.pad that supports reflect padding on short inputs.

    Reflect padding requires the input to be longer than the pad; when it is
    not, the input is first extended with zeros on the right, reflected, and
    the temporary extension is trimmed afterwards.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)

    if mode != "reflect":
        return F.pad(x, paddings, mode, value)

    length = x.shape[-1]
    max_pad = max(left, right)
    extra_pad = max_pad - length + 1 if length <= max_pad else 0
    if extra_pad:
        x = F.pad(x, (0, extra_pad))
    padded = F.pad(x, paddings, mode, value)
    return padded[..., : padded.shape[-1] - extra_pad]
class FishConvNet(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
):
super(FishConvNet, self).__init__()
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
)
self.stride = stride
self.kernel_size = (kernel_size - 1) * dilation + 1
self.dilation = dilation
def forward(self, x):
pad = self.kernel_size - self.stride
extra_padding = get_extra_padding_for_conv1d(
x, self.kernel_size, self.stride, pad
)
x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
return self.conv(x).contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_parametrizations(self, name="weight"):
self.conv = remove_parametrizations(self.conv, name)
return self
class FishTransConvNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
super(FishTransConvNet, self).__init__()
self.conv = nn.ConvTranspose1d(
in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
)
self.stride = stride
self.kernel_size = kernel_size
def forward(self, x):
x = self.conv(x)
pad = self.kernel_size - self.stride
padding_right = math.ceil(pad)
padding_left = pad - padding_right
x = unpad1d(x, (padding_left, padding_right))
return x.contiguous()
def weight_norm(self, name="weight", dim=0):
self.conv = weight_norm(self.conv, name=name, dim=dim)
return self
def remove_parametrizations(self, name="weight"):
self.conv = remove_parametrizations(self.conv, name)
return self
class ResBlock1(torch.nn.Module):
    """Residual block: three pairs of dilated causal convs with SiLU."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.convs1 = self._build_convs(channels, kernel_size, dilation)
        self.convs1.apply(init_weights)
        self.convs2 = self._build_convs(channels, kernel_size, dilation)
        self.convs2.apply(init_weights)

    @staticmethod
    def _build_convs(channels, kernel_size, dilation):
        # One weight-normalized conv per dilation value (exactly three).
        return nn.ModuleList(
            [
                FishConvNet(
                    channels, channels, kernel_size, stride=1, dilation=d
                ).weight_norm()
                for d in (dilation[0], dilation[1], dilation[2])
            ]
        )

    def forward(self, x):
        for conv_a, conv_b in zip(self.convs1, self.convs2):
            residual = x
            y = conv_a(F.silu(x))
            y = conv_b(F.silu(y))
            x = y + residual
        return x

    def remove_parametrizations(self):
        for conv in (*self.convs1, *self.convs2):
            conv.remove_parametrizations()
class ParallelBlock(nn.Module):
    """Averages several ResBlock1 branches with different kernel sizes."""

    def __init__(
        self,
        channels: int,
        kernel_sizes: tuple[int] = (3, 7, 11),
        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    ):
        super().__init__()
        assert len(kernel_sizes) == len(dilation_sizes)
        self.blocks = nn.ModuleList(
            ResBlock1(channels, k, d) for k, d in zip(kernel_sizes, dilation_sizes)
        )

    def forward(self, x):
        # Mean over the parallel branches.
        branch_outputs = [block(x) for block in self.blocks]
        return torch.stack(branch_outputs, dim=0).mean(dim=0)

    def remove_parametrizations(self):
        for block in self.blocks:
            block.remove_parametrizations()
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN style decoder: features (B, num_mels, T) -> waveform
    (B, 1, T * hop_length) in [-1, 1]."""

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        # BUGFIX: `forward` reads `self.checkpointing`, but it was never
        # initialized anywhere, so any call with self.training True raised
        # AttributeError. Default to off; training code may enable it.
        self.checkpointing = False

        self.conv_pre = FishConvNet(
            num_mels,
            upsample_initial_channel,
            pre_conv_kernel_size,
            stride=1,
        ).weight_norm()

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()
        # Each stage halves the channel count while upsampling time by `u`.
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                FishTransConvNet(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    stride=u,
                ).weight_norm()
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.resblocks.append(
                ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )

        self.activation_post = post_activation()
        self.conv_post = FishConvNet(
            ch, 1, post_conv_kernel_size, stride=1
        ).weight_norm()
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.training and self.checkpointing:
                # Trade compute for memory during training.
                x = checkpoint(
                    self.resblocks[i],
                    x,
                    use_reentrant=False,
                )
            else:
                x = self.resblocks[i](x)

        x = self.activation_post(x)
        x = self.conv_post(x)
        # Squash to the [-1, 1] waveform range.
        x = torch.tanh(x)

        return x

    def remove_parametrizations(self):
        """Strip weight-norm parametrizations for inference/export."""
        for up in self.ups:
            up.remove_parametrizations()
        for block in self.resblocks:
            block.remove_parametrizations()
        self.conv_pre.remove_parametrizations()
        self.conv_post.remove_parametrizations()
# DropPath copied from timm library
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Stochastic depth: zero out whole samples of the batch with prob
    `drop_prob` (training only), optionally rescaling survivors by
    1/keep_prob so the expected value is unchanged.

    Same as timm's drop_path / EfficientNet's "DropConnect" variant.
    """
    if drop_prob == 0.0 or not training:
        return x

    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast across all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        mask.div_(keep_prob)
    return x * mask
class DropPath(nn.Module):
    """Module wrapper around `drop_path` (stochastic depth per sample)."""

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Only drops while the module is in training mode.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
class LayerNorm(nn.Module):
    r"""LayerNorm supporting both channel layouts.

    `channels_last` expects inputs shaped (batch, ..., channels) and uses
    F.layer_norm; `channels_first` expects (batch, channels, ...) and
    normalizes over dim 1 manually.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalize over the channel dimension by hand.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None] * x + self.bias[:, None]
# ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt block on (N, C, L) tensors.

    Depthwise conv -> LayerNorm (channels_last) -> pointwise MLP with GELU,
    with optional layer scale (gamma) and stochastic depth, plus an optional
    residual connection.

    Args:
        dim: number of input channels.
        drop_path: stochastic depth rate (0 disables it).
        layer_scale_init_value: initial value of the per-channel scale;
            values <= 0 disable layer scale.
        mlp_ratio: hidden/input width ratio of the pointwise MLP.
        kernel_size: depthwise conv kernel size.
        dilation: depthwise conv dilation.
    """

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()
        hidden_dim = int(mlp_ratio * dim)

        # Depthwise conv mixes information along time only.
        self.dwconv = FishConvNet(
            dim,
            dim,
            kernel_size=kernel_size,
            groups=dim,
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs implemented as linear layers.
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        shortcut = x

        y = self.dwconv(x)
        y = y.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        y = self.norm(y)
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        y = self.drop_path(y)

        return shortcut + y if apply_residual else y
class ConvNeXtEncoder(nn.Module):
    """Stack of ConvNeXt stages joined by 1x1 channel-projection layers."""

    def __init__(
        self,
        input_channels: int = 3,
        depths: list[int] = [3, 3, 9, 3],
        dims: list[int] = [96, 192, 384, 768],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        kernel_size: int = 7,
    ):
        super().__init__()
        assert len(depths) == len(dims)

        # Stem followed by (len(depths) - 1) channel-projection transitions.
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(
            nn.Sequential(
                FishConvNet(
                    input_channels,
                    dims[0],
                    kernel_size=7,
                ),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        )
        for prev_dim, next_dim in zip(dims[:-1], dims[1:]):
            self.downsample_layers.append(
                nn.Sequential(
                    LayerNorm(prev_dim, eps=1e-6, data_format="channels_first"),
                    nn.Conv1d(prev_dim, next_dim, kernel_size=1),
                )
            )

        # Per-block stochastic-depth rates, increasing linearly with depth.
        rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]
        self.stages = nn.ModuleList()
        offset = 0
        for depth, dim in zip(depths, dims):
            self.stages.append(
                nn.Sequential(
                    *[
                        ConvNeXtBlock(
                            dim=dim,
                            drop_path=rates[offset + j],
                            layer_scale_init_value=layer_scale_init_value,
                            kernel_size=kernel_size,
                        )
                        for j in range(depth)
                    ]
                )
            )
            offset += depth

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights and zero biases for convs/linears.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for transition, stage in zip(self.downsample_layers, self.stages):
            x = stage(transition(x))
        return self.norm(x)
class FireflyArchitecture(nn.Module):
    """Complete VQ-GAN pipeline: spec_transform -> backbone -> quantizer -> head."""

    def __init__(
        self,
        backbone: nn.Module,
        head: nn.Module,
        quantizer: nn.Module,
        spec_transform: nn.Module,
    ):
        super().__init__()

        self.backbone = backbone
        self.head = head
        self.quantizer = quantizer
        self.spec_transform = spec_transform
        # Total temporal downsampling introduced by the quantizer.
        self.downsample_factor = math.prod(self.quantizer.downsample_factor)

    def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
        """Full reconstruction pass.

        Returns (audio, vq_result) when a quantizer is present, otherwise
        just the audio tensor.
        """
        if self.spec_transform is not None:
            x = self.spec_transform(x)

        x = self.backbone(x)
        if mask is not None:
            x = x * mask

        if self.quantizer is not None:
            vq_result = self.quantizer(x)
            x = vq_result.z

            if mask is not None:
                x = x * mask

        x = self.head(x, template=template)

        # Ensure a channel dimension on 2D head outputs.
        if x.ndim == 2:
            x = x[:, None, :]

        # BUGFIX: this previously tested `self.vq`, an attribute that is never
        # defined (the quantizer is stored as `self.quantizer`), so every
        # forward call ended in AttributeError.
        if self.quantizer is not None:
            return x, vq_result

        return x

    def encode(self, audios, audio_lengths):
        """Encode waveforms to VQ indices; returns (indices, feature_lengths)."""
        audios = audios.float()

        mels = self.spec_transform(audios)
        mel_lengths = audio_lengths // self.spec_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.backbone(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // self.downsample_factor

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths) -> torch.Tensor:
        """Decode VQ indices back to audio; returns (audio, audio_lengths)."""
        mel_masks = sequence_mask(
            feature_lengths * self.downsample_factor,
            indices.shape[2] * self.downsample_factor,
        )
        mel_masks_float_conv = mel_masks[:, None, :].float()
        audio_lengths = (
            feature_lengths * self.downsample_factor * self.spec_transform.hop_length
        )

        audio_masks = sequence_mask(
            audio_lengths,
            indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
        )
        audio_masks_float_conv = audio_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        x = self.head(z) * audio_masks_float_conv

        return x, audio_lengths

    def remove_parametrizations(self):
        """Strip weight-norm parametrizations from backbone and head, if any."""
        if hasattr(self.backbone, "remove_parametrizations"):
            self.backbone.remove_parametrizations()

        if hasattr(self.head, "remove_parametrizations"):
            self.head.remove_parametrizations()

    @property
    def device(self):
        # Device of the first parameter found.
        return next(self.parameters()).device
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ
from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
@dataclass
class FSQResult:
    """Output bundle of DownsampleFiniteScalarQuantize.forward."""

    # Quantized latent after upsampling, padded/cropped back to input length.
    z: torch.Tensor
    # Integer codebook indices produced by the residual FSQ.
    codes: torch.Tensor
    # Downsampled latent before quantization.
    latents: torch.Tensor
class DownsampleFiniteScalarQuantize(nn.Module):
    """Grouped residual finite-scalar quantizer wrapped in learned temporal
    down-/up-sampling convolutions.

    Args:
        input_dim: channel width of the incoming features.
        n_codebooks: number of residual quantizers.
        n_groups: FSQ group count.
        levels: per-dimension FSQ levels.
        downsample_factor: stride of each downsampling stage.
        downsample_dims: channel widths of the downsampling stack; defaults
            to keeping `input_dim` throughout.
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
        downsample_factor: tuple[int] = (2, 2),
        downsample_dims: tuple[int] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            # Keep the channel width constant through the downsampling stack.
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        # One strided conv + ConvNeXt block per downsampling stage.
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishConvNet(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        # Mirror image of the downsampling stack (transposed convs).
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishTransConvNet(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights and zero biases for convs/linears.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, z) -> FSQResult:
        """Quantize and reconstruct z; the result length matches the input."""
        original_shape = z.shape
        z = self.downsample(z)
        # residual_fsq works on (batch, time, channel); .mT swaps the last two dims.
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # Pad or crop z to match original shape
        diff = original_shape[-1] - result.z.shape[-1]
        left = diff // 2
        right = diff - left

        if diff > 0:
            result.z = F.pad(result.z, (left, right))
        elif diff < 0:
            # Negative left/right act as symmetric crop offsets here.
            result.z = result.z[..., -left:right]

        return result

    def encode(self, z):
        """Return codebook indices of shape (batch, groups*residuals, time)."""
        z = self.downsample(z)
        _, indices = self.residual_fsq(z.mT)
        # (groups, batch, length, residual) -> (batch, groups*residual, length)
        indices = rearrange(indices, "g b l r -> b (g r) l")
        return indices

    def decode(self, indices: torch.Tensor):
        """Reconstruct the upsampled latent from codebook indices."""
        indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
        z_q = self.residual_fsq.get_output_from_indices(indices)
        z_q = self.upsample(z_q.mT)
        return z_q
import matplotlib
import torch
from matplotlib import pyplot as plt
matplotlib.use("Agg")
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension [[left, right], ...] spec into F.pad's
    reversed flat list."""
    reversed_dims = pad_shape[::-1]
    return [amount for pair in reversed_dims for amount in pair]
def sequence_mask(length, max_length=None):
    """Row b is True for the first length[b] positions; shape (B, max_length)."""
    if max_length is None:
        max_length = length.max()
    steps = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return steps.unsqueeze(0) < length.unsqueeze(1)
def init_weights(m, mean=0.0, std=0.01):
    """Normal-initialize the weights of any Conv-type module (no-op otherwise)."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
    """'Same' padding for a stride-1 dilated convolution."""
    return dilation * (kernel_size - 1) // 2
def plot_mel(data, titles=None):
    """Plot a list of mel spectrograms stacked vertically; returns the figure."""
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None] * len(data)
    plt.tight_layout()

    for idx, mel in enumerate(data):
        if isinstance(mel, torch.Tensor):
            # Detach and move to CPU so matplotlib can consume the array.
            mel = mel.float().detach().cpu().numpy()
        ax = axes[idx][0]
        ax.imshow(mel, origin="lower")
        ax.set_aspect(2.5, adjustable="box")
        ax.set_ylim(0, mel.shape[0])
        ax.set_title(titles[idx], fontsize="medium")
        ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        ax.set_anchor("W")

    return fig
def slice_segments(x, ids_str, segment_size=4):
    """Gather a fixed-size time slice per batch element, starting at ids_str[b]."""
    out = torch.zeros_like(x[:, :, :segment_size])
    for b in range(x.size(0)):
        start = ids_str[b]
        out[b] = x[b, :, start : start + segment_size]
    return out
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Pick a random valid start per batch element and slice segment_size
    frames; returns (segments, start_indices)."""
    batch, _, total = x.size()
    if x_lengths is None:
        x_lengths = total
    # Highest valid start index, clamped to 0 for too-short sequences.
    # NOTE(review): if x_lengths is None this passes a plain int to
    # torch.clamp — confirm callers always supply a tensor.
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([batch], device=x.device) * ids_str_max).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
    # Gated activation: tanh over the first n channels, sigmoid over the rest.
    split = n_channels[0]
    t_act = torch.tanh(in_act[:, :split, :])
    s_act = torch.sigmoid(in_act[:, split:, :])
    return t_act * s_act
def avg_with_mask(x, mask):
    """Mean of x over the positions where mask == 1 (mask broadcast to x)."""
    assert mask.dtype == torch.float, "Mask should be float"

    # Promote a (B, T) mask to (B, 1, T), then expand over channels.
    if mask.ndim == 2:
        mask = mask.unsqueeze(1)
    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    return (x * mask).sum() / mask.sum()
import math
def get_cosine_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int | float,
num_training_steps: int,
num_cycles: float = 0.5,
final_lr_ratio: float = 0.0,
):
if 0 < num_warmup_steps < 1: # float mode
num_warmup_steps = int(num_warmup_steps * num_training_steps)
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(
final_lr_ratio,
0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
)
def get_constant_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int | float,
num_training_steps: int | None = None,
):
if 0 < num_warmup_steps < 1: # float mode
num_warmup_steps = int(num_warmup_steps * num_training_steps)
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return 1.0
from .clean import clean_text
from .spliter import split_text
__all__ = ["clean_text", "split_text"]
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# JetBrains PyCharm
.idea
# Customize
references
url.txt
# Git
.git
# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
# Chn Text Norm
This is a repository for Chinese text normalization (no longer maintained).
## Quick Start ##
### Git Clone Repo ###
Git-clone this repo into the root directory of the project that needs to use it.
cd /path/to/proj
git clone https://github.com/Joee1995/chn-text-norm.git
After that, your directory tree should look like:
```
proj # root of your project
|--- chn_text_norm # this chn-text-norm tool
|--- text.py
|--- ...
|--- text_normalize.py # your text normalization code
|--- ...
```
### How to Use? ###
# text_normalize.py
from chn_text_norm.text import *
raw_text = 'your raw text'
text = Text(raw_text=raw_text).normalize()
### How to add quantums ###
打开test.py,然后你就知道怎么做了。
# -*- coding: utf-8 -*-
"""基本类
中文字符类
中文数字/数位类
中文数字类
中文数位类
中文数字系统类
中文数学符号类
*中文其他符号类
"""
__author__ = "Zhiyang Zhou <zyzhou@stu.xmu.edu.cn>"
__data__ = "2019-05-02"
from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
class ChineseChar(object):
    """A Chinese character with simplified and traditional written forms.

    E.g. simplified '负' vs traditional '負'; conversion may target either
    form.
    """

    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional
        # CLEANUP: removed the previous `self.__repr__ = self.__str__`
        # instance assignment — repr() resolves dunders on the type, and the
        # __repr__ method below already delegates to __str__, so the
        # assignment had no effect.

    def __str__(self):
        # NOTE(review): returns None when both forms are falsy, which would
        # make str(obj) raise TypeError — confirm inputs are never both empty.
        return self.simplified or self.traditional or None

    def __repr__(self):
        return self.__str__()
class ChineseNumberUnit(ChineseChar):
    """A Chinese magnitude/unit character.

    Besides the simplified/traditional forms it carries an extra formal
    ("big") form, e.g. '陆' and '陸'; `power` is the decimal exponent it
    denotes.
    """

    def __init__(self, power, simplified, traditional, big_s, big_t):
        super(ChineseNumberUnit, self).__init__(simplified, traditional)
        self.power = power
        self.big_s = big_s
        self.big_t = big_t

    def __str__(self):
        return "10^{}".format(self.power)

    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
        if small_unit:
            # Small units (十/百/千): powers 1..3; big forms reuse value[1].
            return ChineseNumberUnit(
                power=index + 1,
                simplified=value[0],
                traditional=value[1],
                big_s=value[1],
                big_t=value[1],
            )

        # Large units: the power depends on the numbering convention.
        if numbering_type == NUMBERING_TYPES[0]:
            power = index + 8
        elif numbering_type == NUMBERING_TYPES[1]:
            power = (index + 2) * 4
        elif numbering_type == NUMBERING_TYPES[2]:
            power = pow(2, index + 3)
        else:
            raise ValueError(
                "Counting type should be in {0} ({1} provided).".format(
                    NUMBERING_TYPES, numbering_type
                )
            )
        return ChineseNumberUnit(
            power=power,
            simplified=value[0],
            traditional=value[1],
            big_s=value[0],
            big_t=value[1],
        )
class ChineseNumberDigit(ChineseChar):
    """
    A Chinese digit character (数字), carrying its numeric value plus the
    uppercase/financial and optional alternative written forms.
    """

    def __init__(
        self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
    ):
        super(ChineseNumberDigit, self).__init__(simplified, traditional)
        self.value = value  # the numeric value of this digit
        self.big_s = big_s  # uppercase form, simplified
        self.big_t = big_t  # uppercase form, traditional
        self.alt_s = alt_s  # optional alternative simplified form
        self.alt_t = alt_t  # optional alternative traditional form

    def __str__(self):
        return str(self.value)

    @classmethod
    def create(cls, i, v):
        # v = (simplified, traditional, big_s, big_t)
        return cls(i, *v[:4])
class ChineseMath(ChineseChar):
    """
    A Chinese math-symbol character (e.g. 正/负/点) paired with its
    plain symbol and an optional expression.
    """

    def __init__(self, simplified, traditional, symbol, expression=None):
        super(ChineseMath, self).__init__(simplified, traditional)
        self.symbol = symbol  # the non-Chinese counterpart symbol
        self.expression = expression  # presumably a callable/expression — TODO confirm usage
        # Math symbols have no separate uppercase forms; reuse the plain ones.
        self.big_s = simplified
        self.big_t = traditional
# Short aliases for the character classes defined above.
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
class NumberSystem(object):
    """Chinese number system — an empty container; attributes are
    presumably attached by the code that builds the system."""

    pass
class MathSymbol(object):
    """
    Math symbols (simplified/traditional pairs) for the Chinese number
    system, e.g.
        positive = ['正', '正']
        negative = ['负', '負']
        point = ['点', '點']
    """

    def __init__(self, positive, negative, point):
        self.positive = positive
        self.negative = negative
        self.point = point

    def __iter__(self):
        # Yields positive, negative, point (instance-dict insertion order).
        yield from self.__dict__.values()
# class OtherSymbol(object):
# """
# 其他符号
# """
#
# def __init__(self, sil):
# self.sil = sil
#
# def __iter__(self):
# for v in self.__dict__.values():
# yield v
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment