# Commit 5e8fb565 authored by chenzk — v1.0
# (Git web-UI capture residue: "Browse files", "parents",
#  "Pipeline #1799 canceled with stages" — kept here as a comment so the file parses.)
import os
import torch
from litgpt.generate.base import next_token_image_batch
import soundfile as sf
from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
from utils.snac_utils import get_snac, generate_audio_data
import clip
import inference
from tqdm import tqdm
from inference import OmniInference, load_model, load_audio, download_model
from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
from PIL import Image
torch.set_printoptions(sci_mode=False)
_image = inference._image
_eoimage = inference._eoimage
_pad_t = inference._pad_t
_input_t = inference._input_t
_answer_t = inference._answer_t
_eot = inference._eot
_eoa = inference._eoa
_pad_a = inference._pad_a
_input_a = inference._input_a
_answer_a = inference._answer_a
def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
with torch.no_grad():
mel = mel.unsqueeze(0).to(device)
audio_feature = whispermodel.embed_audio(mel)[0][:leng]
audio_len = audio_feature.size(0)
input_ids = []
input_ids_item = [[] for i in range(8)]
for i in range(7):
input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)]
input_ids_item[i] += [layershift(_answer_a,i)]
input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
input_ids_item = [torch.tensor(item) for item in input_ids_item]
input_ids.append(input_ids_item)
input_ids_item = [[] for i in range(8)]
for i in range(7):
input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)] + [layershift(_pad_a,i)]
input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
input_ids_item = [torch.tensor(item) for item in input_ids_item]
input_ids.append(input_ids_item)
stacked_inputids = [[] for _ in range(8)]
for i in range(2):
for j in range(8):
stacked_inputids[j].append(input_ids[i][j])
stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
return torch.stack([audio_feature,audio_feature]), stacked_inputids
def load_clip_model(ckpt_dir, device):
clip_model_path = ckpt_dir + "/ViT-B-32.pt"
if not os.path.exists(clip_model_path):
clip_model_path = "ViT-B/32"
clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
return clipmodel, clippreprocess
class OmniVisionInference(OmniInference):
def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
self.device = device
if not os.path.exists(ckpt_dir):
print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
download_model(ckpt_dir)
self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)
def warm_up(self,
audio_sample='./data/samples/vision_qa_audio.wav',
image_sample='./data/samples/vision_qa_image.jpg'
):
for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
save_path="./data/samples/vision_qa_output.wav",
warm_up=True):
pass
@torch.inference_mode()
def run_vision_AA_batch_stream(self, audio_path, image_path,
stream_stride=4,
max_returned_tokens=2048,
temperature=0.9,
top_k=1,
top_p=1.0,
eos_id_a=_eoa,
eos_id_t=_eot,
pad_id=_pad_t,
save_path=None,
warm_up=False
):
with self.fabric.init_tensor():
self.model.set_kv_cache(batch_size=2)
model = self.model
mel, leng = load_audio(audio_path)
img = Image.open(image_path)
audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)
ima_feature = torch.stack([ima_feature.clone(),ima_feature.clone()]).to(self.device)
leng = [leng,leng]
task = ['ImageQA_A','ImageQA_AT']
T = input_ids[0].size(1)
assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
if model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(
f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
)
list_output = [[] for i in range(8)]
tokens_A , token_T = next_token_image_batch(
model,
audio_feature.to(torch.float32).to(self.device),
ima_feature.to(torch.float32).to(self.device) ,
input_ids ,
whisper_lens = leng ,
task = task,
input_pos = torch.arange(0, T, device=self.device),
temperature=temperature,
top_k=top_k,
top_p=top_p
)
for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
list_output[7].append(token_T.tolist()[0])
text_end = False
index = 1
nums_generate = stream_stride
begin_generate = False
current_index = 0
input_pos = torch.tensor([T], device=self.device)
model_input_ids = [[] for i in range(8)]
for i in range(7):
tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
model_input_ids[i] = torch.stack(model_input_ids[i])
model_input_ids[-1].append(token_T.clone().to(torch.int32))
model_input_ids[-1].append(token_T.clone().to(torch.int32))
model_input_ids[-1] = torch.stack(model_input_ids[-1])
text_index = 0
is_text_end = False
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
tokens_A , token_T = next_token_image_batch(model, None , None ,
input_ids = model_input_ids,
whisper_lens= None,
task = None,
input_pos = input_pos,
temperature=temperature,
top_k=top_k,
top_p=top_p)
if text_end:
token_T = torch.tensor([_pad_t], device=self.device)
if tokens_A[-1] == eos_id_a:
break
if token_T == eos_id_t:
text_end = True
for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
list_output[7].append(token_T.tolist()[0])
if index == 7:
begin_generate = True
if begin_generate:
current_index += 1
if current_index == nums_generate:
current_index = 0
snac = get_snac(list_output,index,nums_generate)
audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
if is_text_end:
text_stream = ""
else:
text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
yield (audio_stream, text_stream)
if warm_up:
break
input_pos = input_pos.add_(1)
model_input_ids = [[] for i in range(8)]
for i in range(7):
tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
model_input_ids[i] = torch.stack(model_input_ids[i])
model_input_ids[-1].append(token_T.clone().to(torch.int32))
model_input_ids[-1].append(token_T.clone().to(torch.int32))
model_input_ids[-1] = torch.stack(model_input_ids[-1])
index += 1
text_tokens = list_output[-1]
if text_vocabsize in text_tokens:
text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
print(f"text output: {res_text}")
if save_path is not None:
audiolist = reconscruct_snac(list_output)
audio = reconstruct_tensors(audiolist)
with torch.inference_mode():
audio_hat = self.snacmodel.decode(audio)
sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
model.clear_kv_cache()
def test_vision_infer():
client = OmniVisionInference()
client.warm_up()
input_audio_path = './data/samples/vision_qa_audio.wav'
input_image_path = './data/samples/vision_qa_image.jpg'
res_text = ""
for audio_stream, text_stream in client.run_vision_AA_batch_stream(
input_audio_path,
input_image_path,
save_path="./vision_qa_output.wav"
):
res_text += text_stream
print(f"text_output: {res_text}")
if __name__ == "__main__":
test_vision_infer()
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import logging
import re
from litgpt.model import GPT # needs to be imported before config
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer
# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
pattern = re.compile(".*Profiler function .* will be ignored")
logging.getLogger("torch._dynamo.variables.torch").addFilter(
lambda record: not pattern.search(record.getMessage())
)
# Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True
__all__ = ["GPT", "Config", "Tokenizer"]
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional, Type, Union
import torch
import yaml
from typing_extensions import Self
import litgpt.model
from litgpt.utils import find_multiple
@dataclass
class Config:
name: str = ""
hf_config: dict = field(default_factory=dict)
scale_embeddings: bool = False
block_size: int = 4096
vocab_size: int = 50254
padding_multiple: int = 512
padded_vocab_size: Optional[int] = None
n_layer: int = 16
n_head: int = 32
head_size: Optional[int] = None
n_embd: int = 4096
rotary_percentage: float = 0.25
parallel_residual: bool = True
bias: bool = True
lm_head_bias: bool = False
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
# Example with `n_head=4`
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ │ │ │
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
# │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
# └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
# │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
# ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
# │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
# └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
# ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
# MHA GQA MQA
# n_query_groups=4 n_query_groups=2 n_query_groups=1
#
# credit https://arxiv.org/pdf/2305.13245.pdf
n_query_groups: Optional[int] = None
shared_attention_norm: bool = False
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
"GptNeoxMLP"
)
gelu_approximate: str = "none"
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
rope_base: int = 10000
n_expert: int = 0
n_expert_per_token: int = 0
add_qkv_bias: Optional[bool] = None
prompt_vocab_size: Optional[int] = None
attn_dropout: float = 0.0
pos_type: str = "rope"
force_align: bool = False
use_pretrain_phoneme_emb: bool = False
tie_word_embeddings: bool = False
# setting for mini-omni
text_vocab_size:int = 152000
cat_audio_vocab_size: int = 29120
audio_vocab_size: int = 4160
whisper_adapter_dim: int = 768
vision_adapter_dim: int = 512
post_adapter: bool = False
post_adapter_layers: int = 6
asr_adapter: str = "llamamlp"
def __post_init__(self):
if not self.name:
self.name = self.hf_config.get("name", self.name)
if self.head_size is None:
assert self.n_embd % self.n_head == 0
self.head_size = self.n_embd // self.n_head
# vocab size should be a power of 2 to be optimal on hardware. compute the closest value
if self.padded_vocab_size is None:
self.padded_vocab_size = find_multiple(
self.vocab_size, self.padding_multiple
)
else:
# vocab size shouldn't be larger than padded vocab size
self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
# compute the number of query groups
if self.n_query_groups is not None:
assert self.n_head % self.n_query_groups == 0
else:
self.n_query_groups = self.n_head
# compute the intermediate size for MLP if not set
if self.intermediate_size is None:
if self.mlp_class_name == "LLaMAMLP":
raise ValueError(
f"The config {self.name!r}, needs to set the `intermediate_size`"
)
self.intermediate_size = 4 * self.n_embd
self.rope_n_elem = int(self.rotary_percentage * self.head_size)
if self.add_qkv_bias is None:
self.add_qkv_bias = self.bias
@classmethod
def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
if name not in name_to_config:
# search through all `config['hf_config']['name']`
try:
conf_dict = next(
config
for config in configs
if name == config["hf_config"]["name"]
or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
== name
)
except StopIteration:
raise ValueError(f"{name!r} is not a supported config name")
else:
conf_dict = name_to_config[name]
conf_dict = conf_dict.copy()
conf_dict.update(kwargs)
return cls(**conf_dict)
@classmethod
def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
with open(path, encoding="utf-8") as fp:
file_kwargs = yaml.safe_load(fp)
if file_kwargs is None:
raise ValueError(f"{path} is empty which is likely unexpected.")
file_kwargs.update(kwargs)
return cls(**file_kwargs)
@classmethod
def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
"""Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
if (config_path := path / "model_config.yaml").is_file():
return cls.from_file(config_path, **kwargs)
if (model_name := path.name) in name_to_config:
return cls.from_name(model_name, **kwargs)
raise FileNotFoundError(
f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
)
@property
def mlp_class(self) -> Type:
# `self.mlp_class_name` cannot be the type to keep the config serializable
return getattr(litgpt.model, self.mlp_class_name)
@property
def norm_class(self) -> Type:
# `self.norm_class_name` cannot be the type to keep the config serializable
if self.norm_class_name == "RMSNorm":
from functools import partial
from litgpt.model import RMSNorm
return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
return getattr(torch.nn, self.norm_class_name)
configs = []
name_to_config = {config["name"]: config for config in configs}
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from typing import Any, Literal, Optional
import torch
# import torch._dynamo.config
# import torch._inductor.config
from litgpt.model import GPT
from utils.snac_utils import layershift, snac_config
from tqdm import tqdm
def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
if torch._dynamo.is_compiling():
# Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
distribution = torch.empty_like(probs).exponential_(1)
return torch.argmax(probs / distribution, dim=-1, keepdim=True)
return torch.multinomial(probs, num_samples=1)
def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Example:
# sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
# sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# Keep at least 1 token always to prevent the case where no token is selected
# In this case the most probable one is always kept
sorted_indices_to_remove[-1:] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
0, sorted_indices, sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, float("-inf"))
return logits
def sample(
logits: torch.Tensor,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
) -> torch.Tensor:
if top_p < 0.0 or top_p > 1.0:
raise ValueError(f"top_p must be in [0, 1], got {top_p}")
logits = logits[0, -1]
# optionally crop the logits to only the top k options
if top_k is not None:
v, i = torch.topk(logits, min(top_k, logits.size(-1)))
# do not use `torch.where` as in nanogpt because it will repeat top-k collisions
logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
# optionally scale the logits and sample from a probability distribution
if temperature > 0.0 or top_p > 0.0:
if temperature > 0.0:
logits = logits / temperature
# optionally crop the logits to smallest set of logits with a cumulative probability above top_p
if top_p < 1.0:
logits = sample_top_p(logits, top_p)
probs = torch.nn.functional.softmax(logits, dim=-1)
return multinomial_num_samples_1(probs)
return torch.argmax(logits, dim=-1, keepdim=True)
def next_token(
model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
) -> torch.Tensor:
input_pos = input_pos.to(model.device)
logits_a, logit_t = model(None, x, None, input_pos)
next_audio_tokens = []
for logit_a in logits_a:
next_a = sample(logit_a, **kwargs).to(dtype=x[0].dtype)
next_audio_tokens.append(next_a)
next_t = sample(logit_t, **kwargs).to(dtype=x[0].dtype)
return next_audio_tokens, next_t
def next_token_asr(
model: GPT,
input_pos: torch.Tensor,
audio_features: torch.tensor,
lens: int,
input_ids: list,
**kwargs: Any,
) -> torch.Tensor:
input_pos = input_pos.to(model.device)
input_ids = [input_id.to(model.device) for input_id in input_ids]
logits_a, logit_t = model(audio_features, input_ids, None, input_pos, whisper_lens=lens)
next_audio_tokens = []
for logit_a in logits_a:
next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
next_audio_tokens.append(next_a)
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
return next_audio_tokens, next_t
def next_token_A1T2(
model: GPT,
audio_features: torch.tensor,
input_ids: list,
whisper_lens: int,
task: list,
input_pos: torch.Tensor,
**kwargs: Any,
) -> torch.Tensor:
input_pos = input_pos.to(model.device)
input_ids = [input_id.to(model.device) for input_id in input_ids]
logits_a, logit_t = model(
audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
)
next_audio_tokens = []
for logit_a in logits_a:
next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
next_audio_tokens.append(next_a)
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
return next_audio_tokens, next_t
def next_token_A1T1(
model: GPT,
audio_features: torch.tensor,
input_ids: list,
whisper_lens: int,
task: list,
input_pos: torch.Tensor,
**kwargs: Any,
) -> torch.Tensor:
input_pos = input_pos.to(model.device)
input_ids = [input_id.to(model.device) for input_id in input_ids]
logits_a, logit_t = model(
audio_features, input_ids, None, input_pos, whisper_lens=whisper_lens, task=task
)
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
return next_t
def next_token_image_batch(model: GPT,
audio_features: torch.tensor,
clip_features: torch.tensor,
input_ids: list,
whisper_lens: int,
task: list,
input_pos: torch.Tensor,
**kwargs: Any) -> torch.Tensor:
input_pos = input_pos.to(model.device)
input_ids = [input_id.to(model.device) for input_id in input_ids]
logits_a,logit_t = model(audio_features, input_ids, clip_features,
input_pos, whisper_lens=whisper_lens, task=task)
for i in range(7):
logits_a[i] = logits_a[i][0].unsqueeze(0)
logit_t = logit_t[1].unsqueeze(0)
next_audio_tokens = []
for logit_a in logits_a:
next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
next_audio_tokens.append(next_a)
next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
return next_audio_tokens, next_t
# torch._dynamo.config.automatic_dynamic_shapes = True
# torch._inductor.config.triton.unique_kernel_names = True
# torch._inductor.config.coordinate_descent_tuning = True
# next_token = torch.compile(next_token, mode="reduce-overhead")
@torch.inference_mode()
def generate(
model: GPT,
input_ids: list,
max_returned_tokens: int,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
eos_id_a: Optional[int] = None,
eos_id_t: Optional[int] = None,
pad_id: Optional[int] = None,
shift: Optional[int] = None,
include_prompt: bool = True,
generate_text=False,
) -> torch.Tensor:
# print("eos_id_a:", eos_id_a)
# print("eos_id_t:", eos_id_t)
# print("pad_id:", pad_id)
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
The implementation of this function is modified from A. Karpathy's nanoGPT.
Args:
model: The model to use.
prompt: Tensor of shape (T) with indices of the prompt sequence.
max_returned_tokens: The maximum number of tokens to return (given plus generated).
temperature: Scales the predicted logits by 1 / temperature.
top_k: If specified, only sample among the tokens with the k highest probabilities.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
In top-p sampling, the next token is sampled from the highest probability tokens
whose cumulative probability exceeds the threshold `top_p`. When specified,
it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
to sampling the most probable token, while `top_p=1` samples from the whole distribution.
It can be used in conjunction with `top_k` and `temperature` with the following order
of application:
1. `top_k` sampling
2. `temperature` scaling
3. `top_p` sampling
For more details, see https://arxiv.org/abs/1904.09751
or https://huyenchip.com/2024/01/16/sampling.html#top_p
eos_id: If specified, stop generating any more token once the <eos> token is triggered.
include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
"""
T = input_ids[0].size(0)
device = input_ids[0].device
assert max_returned_tokens > T
if model.max_seq_length < max_returned_tokens - 1:
# rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
# data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
# not support it to avoid negatively impacting the overall speed
raise NotImplementedError(
f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
)
for input_id in input_ids:
input_id = [input_id]
(
tokens_A1,
tokens_A2,
tokens_A3,
tokens_A4,
tokens_A5,
tokens_A6,
tokens_A7,
tokens_T,
) = input_ids
tokens_A1_output = [tokens_A1]
tokens_A2_output = [tokens_A2]
tokens_A3_output = [tokens_A3]
tokens_A4_output = [tokens_A4]
tokens_A5_output = [tokens_A5]
tokens_A6_output = [tokens_A6]
tokens_A7_output = [tokens_A7]
tokens_T_output = [tokens_T]
list_output = [
tokens_A1_output,
tokens_A2_output,
tokens_A3_output,
tokens_A4_output,
tokens_A5_output,
tokens_A6_output,
tokens_A7_output,
tokens_T_output,
]
input_pos = torch.tensor([T], device=device)
model_input_ids = [
tokens_A1.view(1, -1),
tokens_A2.view(1, -1),
tokens_A3.view(1, -1),
tokens_A4.view(1, -1),
tokens_A5.view(1, -1),
tokens_A6.view(1, -1),
tokens_A7.view(1, -1),
tokens_T.view(1, -1),
]
tokens_A, token_T = next_token(
model,
torch.arange(0, T, device=device),
model_input_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
for i in range(7):
list_output[i].append(tokens_A[i].clone())
list_output[7].append(token_T.clone())
# prepare the input for the next iteration
for i in range(7):
tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
token_T = token_T.clone()
text_end = False
max_returned_tokens = 1000
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
model_input_ids = [
token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
] + [token_T.view(1, -1).to(torch.int32)]
tokens_A, token_T = next_token(
model,
input_pos,
model_input_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
if text_end:
token_T = torch.tensor([pad_id], device=device)
for i in range(7):
list_output[i].append(tokens_A[i].clone())
list_output[7].append(token_T.clone())
if tokens_A[-1] == eos_id_a:
break
if token_T == eos_id_t:
if generate_text:
break
text_end = True
for i in range(7):
tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
token_T = token_T.clone()
input_pos = input_pos.add_(1)
for i in range(len(list_output)):
list_output[i] = torch.cat(list_output[i])
return list_output
@torch.inference_mode()
def generate_TA_BATCH(
model: GPT,
audio_features: torch.Tensor,
input_ids: list,
leng,
task,
max_returned_tokens: int = 1000,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
eos_id_a: Optional[int] = None,
eos_id_t: Optional[int] = None,
pad_id_t: Optional[int] = None,
shift: Optional[int] = None,
include_prompt: bool = True,
generate_text=False,
) -> torch.Tensor:
T = input_ids[0].size(1)
device = input_ids[0].device
assert max_returned_tokens > T
if model.max_seq_length < max_returned_tokens - 1:
raise NotImplementedError(
f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
)
input_pos = torch.tensor([T], device=device)
model_input_ids = input_ids
list_output = [[] for i in range(8)]
tokens_A, token_T = next_token_image_batch(
model,
audio_features.to(torch.float32).to(model.device),
None,
input_ids,
[T - 3, T - 3],
["A1T2", "A1T2"],
input_pos=torch.arange(0, T, device=device),
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
for i in range(7):
list_output[i].append(tokens_A[i].tolist()[0])
list_output[7].append(token_T.tolist()[0])
model_input_ids = [[] for i in range(8)]
for i in range(7):
tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
model_input_ids[i].append(torch.tensor([layershift(snac_config.end_of_audio, i)], device=device))
model_input_ids[i] = torch.stack(model_input_ids[i])
model_input_ids[-1].append(token_T.clone().to(torch.int32))
model_input_ids[-1].append(token_T.clone().to(torch.int32))
model_input_ids[-1] = torch.stack(model_input_ids[-1])
text_end = False
for _ in range(2, max_returned_tokens - T + 1):
tokens_A, token_T = next_token_image_batch(
model,
None,
None,
model_input_ids,
None,
None,
input_pos=input_pos,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
if text_end:
token_T = torch.tensor([pad_id_t], device=device)
if tokens_A[-1] == eos_id_a:
break
if token_T == eos_id_t:
text_end = True
for i in range(7):
list_output[i].append(tokens_A[i].tolist()[0])
list_output[7].append(token_T.tolist()[0])
model_input_ids = [[] for i in range(8)]
for i in range(7):
tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
model_input_ids[i].append(
torch.tensor([layershift(snac_config.end_of_audio, i)], device=device)
)
model_input_ids[i] = torch.stack(model_input_ids[i])
model_input_ids[-1].append(token_T.clone().to(torch.int32))
model_input_ids[-1].append(token_T.clone().to(torch.int32))
model_input_ids[-1] = torch.stack(model_input_ids[-1])
input_pos = input_pos.add_(1)
return list_output
@torch.inference_mode()
def generate_TT(
model: GPT,
audio_features: torch.Tensor,
input_ids: list,
leng,
task,
max_returned_tokens: int = 2048,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
eos_id_a: Optional[int] = None,
eos_id_t: Optional[int] = None,
pad_id_t: Optional[int] = None,
shift: Optional[int] = None,
include_prompt: bool = True,
generate_text=False,
) -> torch.Tensor:
T = input_ids[0].size(1)
device = input_ids[0].device
output = []
token_T = next_token_A1T1(
model,
None,
input_ids,
None,
None,
input_pos=torch.arange(0, T, device=device),
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
output.append(token_T.clone().tolist()[0])
input_pos = torch.tensor([T], device=device)
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
model_input_ids = []
for i in range(7):
model_input_ids.append(
torch.tensor([layershift(snac_config.end_of_audio, i)])
.view(1, -1)
.to(torch.int32)
.to(device)
)
model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
token_T = next_token_A1T1(
model,
None,
model_input_ids,
None,
None,
input_pos=input_pos,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
if token_T == eos_id_t:
break
output.append(token_T.clone().tolist()[0])
input_pos = input_pos.add_(1)
return output
@torch.inference_mode()
def generate_AT(
model: GPT,
audio_features: torch.Tensor,
input_ids: list,
leng,
task,
max_returned_tokens: int = 2048,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
eos_id_a: Optional[int] = None,
eos_id_t: Optional[int] = None,
pad_id_t: Optional[int] = None,
shift: Optional[int] = None,
include_prompt: bool = True,
generate_text=False,
) -> torch.Tensor:
T = input_ids[0].size(1)
device = input_ids[0].device
output = []
token_T = next_token_A1T1(
model,
audio_features.to(torch.float32).to(model.device),
input_ids,
[T - 3],
["AT"],
input_pos=torch.arange(0, T, device=device),
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
output.append(token_T.clone().tolist()[0])
input_pos = torch.tensor([T], device=device)
text_end = False
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
model_input_ids = []
for i in range(7):
model_input_ids.append(
torch.tensor([layershift(snac_config.end_of_audio, i)])
.view(1, -1)
.to(torch.int32)
.to(device)
)
model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
token_T = next_token_A1T1(
model,
None,
model_input_ids,
None,
None,
input_pos=input_pos,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
if token_T == eos_id_t:
break
output.append(token_T.clone().tolist()[0])
input_pos = input_pos.add_(1)
return output
@torch.inference_mode()
def generate_TA(
model: GPT,
audio_features: torch.Tensor,
input_ids: list,
leng,
task,
max_returned_tokens: int = 2048,
*,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
eos_id_a: Optional[int] = None,
eos_id_t: Optional[int] = None,
pad_id_t: Optional[int] = None,
shift: Optional[int] = None,
include_prompt: bool = True,
generate_text=False,
) -> torch.Tensor:
T = input_ids[0].size(1)
device = input_ids[0].device
output = [[] for _ in range(8)]
tokens_A, token_T = next_token_A1T2(
model,
None,
input_ids,
None,
None,
input_pos=torch.arange(0, T, device=device),
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
for i in range(7):
output[i].append(tokens_A[i].clone().tolist()[0])
output[7].append(token_T.clone().tolist()[0])
input_pos = torch.tensor([T], device=device)
text_end = False
for _ in tqdm(range(2, max_returned_tokens - T + 1)):
model_input_ids = []
for i in range(7):
model_input_ids.append(
layershift(tokens_A[i].clone(), i)
.view(1, -1)
.to(torch.int32)
.to(device)
)
model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
tokens_A, token_T = next_token_A1T2(
model,
None,
model_input_ids,
None,
None,
input_pos=input_pos,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
if text_end:
token_T = torch.tensor([pad_id_t], device=device)
if tokens_A[-1] == eos_id_a:
break
if token_T == eos_id_t:
text_end = True
for i in range(7):
output[i].append(tokens_A[i].clone().tolist()[0])
output[7].append(token_T.clone().tolist()[0])
input_pos = input_pos.add_(1)
return output
@torch.inference_mode()
def generate_AA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Autoregressively decode 8 parallel token streams (7 audio layers + 1 text)
    from an audio prompt: the prefill step feeds ``audio_features`` with task
    tag "A1T2"; subsequent steps feed back sampled tokens only.

    Decoding stops when the last audio layer emits ``eos_id_a``; once the text
    stream emits ``eos_id_t`` it is padded with ``pad_id_t`` afterwards.
    Returns a list of 8 python lists of token ids.
    """
    prompt_len = input_ids[0].size(1)
    device = input_ids[0].device
    output = [[] for _ in range(8)]

    def _record(a_toks, t_tok):
        # append the scalar token of each stream to its output list
        for layer, tok in enumerate(a_toks):
            output[layer].append(tok.clone().tolist()[0])
        output[7].append(t_tok.clone().tolist()[0])

    # prefill: the whole prompt plus the whisper features in one pass
    tokens_A, token_T = next_token_A1T2(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [prompt_len - 3],
        ["A1T2"],
        input_pos=torch.arange(0, prompt_len, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    _record(tokens_A, token_T)
    input_pos = torch.tensor([prompt_len], device=device)
    text_done = False
    for _ in tqdm(range(2, max_returned_tokens - prompt_len + 1)):
        # feed back the previous step's tokens, layer-shifted per audio codebook
        step_ids = [
            layershift(tokens_A[layer].clone(), layer)
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        step_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            step_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if text_done:
            token_T = torch.tensor([pad_id_t], device=device)
        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_done = True
        _record(tokens_A, token_T)
        input_pos = input_pos.add_(1)
    return output
@torch.inference_mode()
def generate_ASR(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1200,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Decode only the text stream ("asr" task) from an audio prompt.

    The 7 audio input streams are held at the layer-shifted end-of-audio token
    on every step; generation stops at ``eos_id_t``.  Returns a flat python
    list of text token ids (the terminating eos is not included).
    """
    prompt_len = input_ids[0].size(1)
    device = input_ids[0].device
    transcript = []
    # prefill: the whole prompt plus the whisper features in one pass
    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [prompt_len - 3],
        ["asr"],
        input_pos=torch.arange(0, prompt_len, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    transcript.append(token_T.clone().tolist()[0])
    input_pos = torch.tensor([prompt_len], device=device)
    for _ in tqdm(range(2, max_returned_tokens - prompt_len + 1)):
        # audio streams stay pinned to end-of-audio; only the text token advances
        step_ids = [
            torch.tensor([layershift(snac_config.end_of_audio, layer)])
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        step_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
        token_T = next_token_A1T1(
            model,
            None,
            step_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if token_T == eos_id_t:
            break
        transcript.append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)
    return transcript
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Full definition of a decoder-only transformer-based language model, all of it in this single file.
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
"""
import math
from typing import Any, Optional, Tuple
import torch
import torch.nn as nn
from typing_extensions import Self
from litgpt.config import Config
class GPT(nn.Module):
    """Decoder-only transformer with multimodal input adapters.

    Inputs arrive as 8 parallel token-id streams (7 audio codebook layers +
    1 text layer); their embeddings are averaged into a single hidden state.
    Whisper audio features and CLIP vision features are projected by small
    adapters and spliced over placeholder positions (see `concat_feat`).
    `forward` returns `(xa, xt)`: a list of 7 audio logit slices and the text
    logits.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config
        # adapter that projects whisper encoder features to the model width
        if self.config.asr_adapter == "mlp":
            print("Using MLP adapter for ASR feature")
            self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
        elif self.config.asr_adapter == "llamamlp":
            print("using LLAMA MLP adapter for ASR feature")
            self.whisper_adapter = whisperMLP(config=config)
        else:
            raise ValueError("asr_adapter should be mlp or llamamlp")
        # single head over the padded vocab; text and audio logits are sliced
        # out of its output in `forward` (unless post_adapter is enabled)
        self.lm_head = nn.Linear(
            config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
        )
        # adapter that projects CLIP vision features to the model width
        self.vision_adapter = visionMLP(config = config)
        if config.post_adapter:
            # extra transformer blocks + a dedicated audio head after the trunk
            self.transformer = nn.ModuleDict(
                dict(
                    wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                    h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                    post_adapter=nn.ModuleList(
                        Block(config) for _ in range(config.post_adapter_layers)
                    ),
                    ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
                    post_adapter_audio_ln=config.norm_class(
                        config.n_embd, eps=config.norm_eps
                    ),
                    post_adapter_audio_lm_head=nn.Linear(
                        config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
                    ),
                )
            )
        else:
            self.transformer = nn.ModuleDict(
                dict(
                    wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                    h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                    ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
                )
            )
        # goes through the property setter below, which also builds the rope cache
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None
        if config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight

    @property
    def max_seq_length(self) -> int:
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory
        """
        if value > self.config.block_size:
            raise ValueError(
                f"Cannot attend to {value}, block size is only {self.config.block_size}"
            )
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # first call
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        # override
        elif value != self.cos.size(0):
            self.cos, self.sin = self.rope_cache(device=self.cos.device)
        # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
        # if the kv cache is expected

    def reset_parameters(self) -> None:
        # Trigger resetting the rope-cache
        self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def concat_feat(self, audio_feature, clip_feature, input_ids, T, task):
        """Splice adapter features into the embedded id streams, in place.

        `input_ids` here is the list of 8 embedding tensors (one per stream).
        Per sample j, features overwrite placeholder positions in the 7 audio
        streams: audio tasks put audio features at positions 1..T[j]; image
        tasks put 50 CLIP tokens at positions 1..50, and for ImageQA_A /
        ImageQA_AT the audio features follow at 52..52+T[j].
        NOTE(review): these offsets assume the prompt layout built by the
        caller (1 marker + 50 image slots + 1 marker + ...) — confirm there.
        """
        for j in range(len(T)):
            if task[j] != 'T1T2' and task[j] != 'T1A2' and task[j]!='ImageQA_T' and not task[j] == 'ImageCAP' and not task[j] == 'ImageQA_A' and not task[j] == 'ImageQA_AT':
                for i in range(7):
                    input_ids[i][j,1:T[j]+1,:] = audio_feature[j][:T[j]].clone()
                assert task[j] != 'ImageQ', "ImageQ should be concat with audio feature"
            elif task[j] == 'ImageQA_A' or task[j] == 'ImageQA_AT':
                print("concat ImageQA_A feature")
                for i in range(7):
                    input_ids[i][j,1:51,:] = clip_feature[j].clone()
                    input_ids[i][j,52 : 52 + T[j],:] = audio_feature[j][:T[j]].clone()
            elif task[j] == 'ImageQA_T' or task[j] =='ImageCAP':
                for i in range(7):
                    input_ids[i][j,1:51,:] = clip_feature[j].clone()
        return input_ids

    def forward(
        self,
        audio_features: torch.Tensor,
        input_ids: torch.Tensor,
        clip_features: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        whisper_lens: Optional[list] = None,
        task: Optional[str] = None,
    ) -> Tuple[list, torch.Tensor]:
        """Run the trunk and both output heads.

        When `input_pos` is given, the kv cache is used and only the supplied
        positions are attended (requires a prior `set_kv_cache()` call).
        Returns `(xa, xt)`: `xa` is a list of 7 audio logit tensors (one per
        codebook), `xt` the text logits.
        """
        show = False  # NOTE(review): debug flag, unused in this method
        T = input_ids[0].size(1)
        if self.max_seq_length < T:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )
        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None
        if audio_features is not None:
            # get whisper feature
            x_a = self.whisper_adapter(audio_features)
            if clip_features is not None:
                x_v = self.vision_adapter(clip_features)
            else:
                x_v = None
            # get input_ids embedding
            x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
            x0 = self.transformer.wte(x0)
            x1 = self.transformer.wte(x1)
            x2 = self.transformer.wte(x2)
            x3 = self.transformer.wte(x3)
            x4 = self.transformer.wte(x4)
            x5 = self.transformer.wte(x5)
            x6 = self.transformer.wte(x6)
            x7 = self.transformer.wte(x7)
            # concat whisper feature
            input_emb = self.concat_feat(
                x_a, x_v, [x0, x1, x2, x3, x4, x5, x6, x7], whisper_lens, task
            )
            x0, x1, x2, x3, x4, x5, x6, x7 = input_emb
        else:
            x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
            x0 = self.transformer.wte(x0)
            x1 = self.transformer.wte(x1)
            x2 = self.transformer.wte(x2)
            x3 = self.transformer.wte(x3)
            x4 = self.transformer.wte(x4)
            x5 = self.transformer.wte(x5)
            x6 = self.transformer.wte(x6)
            x7 = self.transformer.wte(x7)
        # average the 8 stream embeddings into one hidden state for the trunk
        x = (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
        if self.config.scale_embeddings:
            x = x * (self.config.n_embd**0.5)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        text_vocab_size = self.config.text_vocab_size
        audio_vocab_size = self.config.audio_vocab_size
        # shared head path: text logits are the first text_vocab_size slots
        x_ori = x
        x_ori = self.transformer.ln_f(x_ori)
        x_ori = self.lm_head(x_ori)  # (b, t, vocab_size)
        xt = x_ori[..., :text_vocab_size]
        if self.config.post_adapter:
            # audio logits come from the dedicated post-adapter head instead
            for block in self.transformer.post_adapter:
                x = block(x, cos, sin, mask, input_pos)
            x = self.transformer.post_adapter_audio_ln(x)
            x = self.transformer.post_adapter_audio_lm_head(x)  # (b, t, vocab_size)
            xa = []
            for i in range(7):
                xa.append(x[..., audio_vocab_size * i : audio_vocab_size * (i + 1)])
        else:
            # audio logits are 7 consecutive slices after the text slots
            xa = []
            for i in range(7):
                xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
        return xa, xt

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Alternate constructor: build from a named litgpt `Config`."""
        return cls(Config.from_name(name, **kwargs))

    def rope_cache(
        self, device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # cos/sin tables sized to the current max_seq_length
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Allocate per-block kv caches (and the causal mask cache) for inference."""
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length
        # initialize the kv cache for all blocks
        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )
        if self.config.post_adapter:
            for block in self.transformer.post_adapter:
                block.attn.kv_cache = block.attn.build_kv_cache(
                    batch_size, max_seq_length, rope_cache_length, device, dtype
                )
        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
            # for the kv-cache support (only during inference), we only create it in that situation
            self.mask_cache = build_mask_cache(max_seq_length, device)

    def clear_kv_cache(self) -> None:
        """Drop all kv caches and the mask cache (frees inference memory)."""
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None
class visionMLP(nn.Module):
    """SwiGLU adapter projecting vision features (vision_adapter_dim) to n_embd."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        in_dim = config.vision_adapter_dim
        self.fc_1 = nn.Linear(in_dim, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(in_dim, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu-gated product of the two parallel projections
        gated = torch.nn.functional.silu(self.fc_1(x)) * self.fc_2(x)
        return self.proj(gated)
class Block(nn.Module):
    """One transformer layer: causal self-attention + MLP with residuals.

    Supports both the parallel-residual layout (attention and MLP branch off
    the same input; with an optionally shared norm) and the sequential layout
    (attention residual first, then the MLP residual).
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        if not config.parallel_residual and config.shared_attention_norm:
            raise NotImplementedError(
                "No checkpoint amongst the ones we support uses this configuration"
                " (non-parallel residual and shared attention norm)."
            )
        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        if config.shared_attention_norm:
            self.norm_2 = None
        else:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the block; `cos`/`sin` are the rope tables, `mask`/`input_pos`
        enable kv-cache inference."""
        normed = self.norm_1(x)
        attn_out = self.attn(normed, cos, sin, mask, input_pos)
        if self.config.parallel_residual:
            # attention and MLP both read the (possibly shared) normed input
            if not self.config.shared_attention_norm:
                normed = self.norm_2(x)
            return self.mlp(normed) + attn_out + x
        # sequential: attention residual first, then the MLP residual
        residual = attn_out + x
        return self.mlp(self.norm_2(residual)) + residual
class CausalSelfAttention(nn.Module):
    """Causal self-attention with a fused QKV projection and grouped-query
    support (`n_query_groups` may be smaller than `n_head`)."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        # key, query, value projections for all heads, but in a batch
        self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
        # output projection
        # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head`
        self.proj = nn.Linear(
            config.head_size * config.n_head, config.n_embd, bias=config.bias
        )
        # disabled by default
        self.kv_cache: Optional[KVCache] = None
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over `x`; with `input_pos` set, keys/values are written into
        and read back from the kv cache (requires `build_kv_cache` first)."""
        B, T, C = (
            x.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)
        qkv = self.attn(x)
        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(
            B, T, self.config.n_query_groups, total_qkv, self.config.head_size
        )
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)
        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
        # maybe repeat k and v if for the non multi-head attention cases
        # training: flash attention requires it
        # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
        if self.config.n_query_groups != self.config.n_head and (
            input_pos is None or self.config.n_query_groups != 1
        ):
            k = k.expand(
                B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
            )
            v = v.expand(
                B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
            )
        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)
        # rope is applied only to the first `rope_n_elem` dims of each head
        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
        if input_pos is not None:
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            # write this step's k/v into the cache and attend over the full cache
            k, v = self.kv_cache(input_pos, k, v)
        y = self.scaled_dot_product_attention(q, k, v, mask)
        y = y.reshape(
            B, T, self.config.head_size * self.config.n_head
        )  # re-assemble all head outputs side by side
        # output projection
        return self.proj(y)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """SDPA wrapper; without an explicit mask, `is_causal=True` is used."""
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)

    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "KVCache":
        """Create a KVCache sized for this layer's head layout.

        With a single query group, only one kv head needs caching; the key
        buffer's last dim accounts for the roped + non-roped split.
        """
        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
        if rope_cache_length is None:
            if self.config.rotary_percentage != 1.0:
                raise TypeError(
                    "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
                )
            k_shape = v_shape
        else:
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
class GptNeoxMLP(nn.Module):
    """Classic two-layer GELU feed-forward block (GPT-NeoX style)."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = torch.nn.functional.gelu(
            self.fc(x), approximate=self.config.gelu_approximate
        )
        return self.proj(hidden)
class LLaMAMLP(nn.Module):
    """SwiGLU feed-forward block (LLaMA style): silu(fc_1(x)) * fc_2(x) -> proj."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.nn.functional.silu(self.fc_1(x))
        return self.proj(gate * self.fc_2(x))
class whisperMLP(nn.Module):
    """SwiGLU adapter projecting whisper features (whisper_adapter_dim) to n_embd."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        in_dim = config.whisper_adapter_dim
        self.fc_1 = nn.Linear(in_dim, config.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(in_dim, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.nn.functional.silu(self.fc_1(x))
        return self.proj(gate * self.fc_2(x))
class GemmaMLP(LLaMAMLP):
    """LLaMAMLP variant that gates with (approximate) GELU instead of SiLU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.nn.functional.gelu(
            self.fc_1(x), approximate=self.config.gelu_approximate
        )
        return self.proj(gate * self.fc_2(x))
class LLaMAMoE(nn.Module):
    """Sparse mixture-of-experts FFN: a linear router selects
    `n_expert_per_token` experts per token and combines their outputs with
    softmax-normalized router weights."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
        See also figure 1 in https://arxiv.org/abs/2211.15841
        """
        B, T, C = (
            x.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)
        x = x.view(-1, C)  # (B*T, C)
        router = self.gate(x)  # (B*T, n_expert)
        probs, indices = torch.topk(
            router, self.config.n_expert_per_token
        )  # (B*T, n_expert_per_token)
        # normalize the selected router logits in float32, then cast back
        probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
        # boolean routing mask per expert: which (token, slot) pairs chose it
        masks = indices.unsqueeze(-1) == torch.arange(
            self.config.n_expert, device=x.device
        )
        masks = masks.permute(2, 0, 1)  # (n_expert, B*T, n_expert_per_token)
        y = torch.zeros_like(x)  # (B*T, C)
        # accumulate each expert's weighted output only for its routed tokens
        for mask, expert in zip(masks, self.experts):
            token_idx, expert_idx = torch.where(mask)
            y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
        return y.view(B, T, C)
def build_rope_cache(
    seq_len: int,
    n_elem: int,
    device: Optional[torch.device] = None,
    base: int = 10000,
    condense_ratio: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the RoPE cos/sin tables, each of shape (seq_len, n_elem).

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # inverse frequencies: theta_i = base^(-2(i-1)/d) for i in [1 .. d/2]
    inv_freq = 1.0 / (
        base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)
    )
    # positions [0 .. seq_len-1], optionally condensed for context extension
    positions = torch.arange(seq_len, device=device) / condense_ratio
    # outer product position x frequency, duplicated to cover both halves
    angles = torch.outer(positions, inv_freq).repeat(1, 2)
    return torch.cos(angles), torch.sin(angles)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to `x` (last dim is the head size)."""
    half = x.size(-1) // 2
    lo = x[..., :half]  # (B, nh, T, hs/2)
    hi = x[..., half:]  # (B, nh, T, hs/2)
    rotated = torch.cat((-hi, lo), dim=-1)  # (B, nh, T, hs)
    out = (x * cos) + (rotated * sin)
    # keep the caller's dtype even if cos/sin promoted the computation
    return out.to(dtype=x.dtype)
class KVCache(nn.Module):
    """Pre-allocated key/value buffers for attention during kv-cached inference.

    Buffers are zero-initialized, non-persistent (not saved in checkpoints),
    and updated in place at the given time positions on each forward call.
    """

    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        for name, shape in (("k", k_shape), ("v", v_shape)):
            self.register_buffer(
                name, torch.zeros(shape, device=device, dtype=dtype), persistent=False
            )

    def forward(
        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write `k`/`v` at `input_pos` (dim 2) and return the full caches."""
        # move the buffer to the activation dtype for when AMP is used
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        return (
            self.k.index_copy_(2, input_pos, k),
            self.v.index_copy_(2, input_pos, v),
        )

    def reset_parameters(self) -> None:
        """Zero both caches in place."""
        for buf in (self.k, self.v):
            torch.nn.init.zeros_(buf)
def build_mask_cache(
    max_seq_length: int, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Return a boolean causal mask of shape (1, 1, T, T), True on/below the diagonal."""
    mask = torch.tril(
        torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
    )
    return mask[None, None, :, :]
class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    Normalizes by the RMS along `dim` (computed in float32) and rescales with a
    learned weight; with `add_unit_offset` the scale is `1 + weight` (Gemma).
    """

    def __init__(
        self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim
        self.add_unit_offset = add_unit_offset

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        xf = x.float()
        # NOTE: the original RMSNorm paper implementation is not equivalent
        mean_sq = torch.mean(xf * xf, dim=self.dim, keepdim=True)
        normed = (xf * torch.rsqrt(mean_sq + self.eps)).to(dtype=in_dtype)
        # Gemma checkpoints require the unit offset:
        # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
        scale = (1 + self.weight) if self.add_unit_offset else self.weight
        return normed * scale

    def reset_parameters(self) -> None:
        torch.nn.init.ones_(self.weight)
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import json
from pathlib import Path
from typing import Optional, Union
import torch
class Tokenizer:
    """Thin wrapper around a HuggingFace `tokenizers` or a `sentencepiece`
    tokenizer loaded from a checkpoint directory.

    The backend is picked from the files present: `tokenizer.json` (HF,
    takes precedence) or `tokenizer.model` (sentencepiece).  BOS/EOS ids are
    resolved from `tokenizer_config.json` / `generation_config.json` for HF,
    or from the sentencepiece model itself.
    """

    def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
        checkpoint_dir = Path(checkpoint_dir)
        if not checkpoint_dir.exists():
            raise NotADirectoryError(
                f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
            )
        self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
        self.bos_id = None
        self.eos_id = None
        # some checkpoints have both files, `.json` takes precedence
        if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
            from tokenizers import Tokenizer as HFTokenizer

            self.processor = HFTokenizer.from_file(str(vocabulary_path))
            self.backend = "huggingface"
            if (
                special_tokens_path := checkpoint_dir / "tokenizer_config.json"
            ).is_file():
                with open(special_tokens_path, encoding="utf-8") as fp:
                    config = json.load(fp)
                bos_token = config.get("bos_token")
                eos_token = config.get("eos_token")
                # tokens may be given as plain strings or {"content": ...} dicts
                if bos_token is not None and isinstance(bos_token, dict):
                    bos_token = bos_token.get("content")
                if eos_token is not None and isinstance(eos_token, dict):
                    eos_token = eos_token.get("content")
                self.bos_id = (
                    self.token_to_id(bos_token) if bos_token is not None else None
                )
                self.eos_id = (
                    self.token_to_id(eos_token) if eos_token is not None else None
                )
            if (
                special_tokens_path := checkpoint_dir / "generation_config.json"
            ).is_file():
                with open(special_tokens_path, encoding="utf-8") as fp:
                    config = json.load(fp)
                # generation_config only fills ids still unresolved above
                if self.bos_id is None:
                    self.bos_id = config.get("bos_token_id")
                if self.eos_id is None:
                    self.eos_id = config.get("eos_token_id")
        elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
            from sentencepiece import SentencePieceProcessor

            self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
            self.backend = "sentencepiece"
            self.bos_id = self.processor.bos_id()
            self.eos_id = self.processor.eos_id()
        else:
            raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Base vocabulary size (HF: without added tokens)."""
        if self.backend == "huggingface":
            return self.processor.get_vocab_size(with_added_tokens=False)
        if self.backend == "sentencepiece":
            return self.processor.vocab_size()
        raise RuntimeError

    def token_to_id(self, token: str) -> int:
        """Map a token string to its id; raises ValueError if unknown."""
        if self.backend == "huggingface":
            id_ = self.processor.token_to_id(token)
        elif self.backend == "sentencepiece":
            id_ = self.processor.piece_to_id(token)
        else:
            raise RuntimeError
        if id_ is None:
            raise ValueError(f"token {token!r} not found in the collection.")
        return id_

    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
        """Whether this checkpoint's tokenizer prepends a BOS token by default."""
        if not (
            tokenizer_config_path := checkpoint_dir / "tokenizer_config.json"
        ).is_file():
            return False
        with open(tokenizer_config_path, encoding="utf-8") as fp:
            config = json.load(fp)
        if "add_bos_token" in config:
            return config["add_bos_token"]
        # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
        # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
        return config.get("tokenizer_class") == "LlamaTokenizer"

    def encode(
        self,
        string: str,
        device: Optional[torch.device] = None,
        bos: Optional[bool] = None,
        eos: bool = False,
        max_length: int = -1,
    ) -> torch.Tensor:
        """Encode `string` into a 1-D int tensor.

        `bos=None` defers to the checkpoint's `use_bos` default; `eos=True`
        appends the eos id if not already present; `max_length > 0` truncates.
        Raises ValueError if the backend returns None, NotImplementedError if
        bos is requested but the tokenizer has no bos id.
        """
        if self.backend == "huggingface":
            tokens = self.processor.encode(string).ids
        elif self.backend == "sentencepiece":
            tokens = self.processor.encode(string)
        else:
            raise RuntimeError
        # fix: this guard used to sit *after* tokens[0] was indexed below,
        # so it could never trigger — check before any use of `tokens`
        if tokens is None:
            raise ValueError("`tokens` is None")
        if bos or (bos is None and self.use_bos):
            bos_id = self.bos_id
            if bos_id is None:
                raise NotImplementedError(
                    "This tokenizer does not have a defined a bos token"
                )
            # fix: guard the empty-encoding case before indexing tokens[0]
            if not tokens or tokens[0] != bos_id:
                tokens = [bos_id] + tokens
        if eos and (not tokens or tokens[-1] != self.eos_id):
            tokens = tokens + [self.eos_id]
        if max_length > 0:
            tokens = tokens[:max_length]
        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tensor: torch.Tensor) -> str:
        """Decode a 0-D or 1-D id tensor back to a string."""
        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
        return self.processor.decode(tokens)
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
"""Utility functions for training and inference."""
import inspect
import math
import os
import pickle
import shutil
import sys
from dataclasses import asdict, is_dataclass
from io import BytesIO
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
TypeVar,
Union,
)
import lightning as L
import torch
import torch.nn as nn
import torch.utils._device
import yaml
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.load import _lazy_load as lazy_load
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.cli import instantiate_class
from torch.serialization import normalize_storage_type
from typing_extensions import Self
if TYPE_CHECKING:
from litgpt import GPT, Config
def init_out_dir(out_dir: Path) -> Path:
    """Resolve a relative output dir under `LIGHTNING_ARTIFACTS_DIR` when set."""
    artifacts_root = os.environ.get("LIGHTNING_ARTIFACTS_DIR")
    if artifacts_root is not None and not out_dir.is_absolute():
        return Path(artifacts_root) / out_dir
    return out_dir
def find_resume_path(
    resume: Union[bool, Literal["auto"], Path], out_dir: Path
) -> Optional[Path]:
    """Locate the latest `step-*/<file>.pth` checkpoint under `out_dir`.

    Falsy `resume` or an explicit Path is returned unchanged.  `"auto"`
    returns the latest checkpoint or None; `True` additionally raises
    FileNotFoundError when nothing is found.
    """
    if not resume or isinstance(resume, Path):
        return resume

    def _step_number(path: Path) -> int:
        # "step-123" -> 123
        return int(path.parent.name.split("-")[1])

    latest = max(out_dir.rglob("step-*/*.pth"), key=_step_number, default=None)
    if resume is True and latest is None:
        raise FileNotFoundError(
            f"You passed `--resume=True`, but no checkpont file was found in `--out_dir={out_dir}`."
        )
    return latest
def find_multiple(n: int, k: int) -> int:
    """Round `n` up to the nearest multiple of `k` (k must be positive)."""
    assert k > 0
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)
def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
    """Count parameters of `module`, optionally filtered by `requires_grad`.

    Parameters carrying a `quant_state` (bitsandbytes 4-bit layers) are
    counted by their logical shape rather than the packed element count.
    """
    total = 0
    for param in module.parameters():
        if requires_grad is not None and param.requires_grad != requires_grad:
            continue
        if hasattr(param, "quant_state"):
            # bitsandbytes 4bit layer support
            total += math.prod(param.quant_state.shape)
        else:
            total += param.numel()
    return total
def reset_parameters(module: nn.Module) -> None:
    """Invoke `reset_parameters` on `module` and every submodule that defines it."""
    for submodule in module.modules():
        reset_fn = getattr(submodule, "reset_parameters", None)
        if callable(reset_fn):
            reset_fn()
def check_valid_checkpoint_dir(
    checkpoint_dir: Path,
    model_filename: str = "lit_model.pth",
    verbose: bool = True,
    raise_error: bool = False,
) -> None:
    """Verify `checkpoint_dir` contains a complete litgpt checkpoint.

    Returns silently on success.  Otherwise optionally prints guidance to
    stderr, then raises FileNotFoundError (`raise_error=True`) or SystemExit.
    """
    files = {
        model_filename: (checkpoint_dir / model_filename).is_file(),
        "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
        "tokenizer.json OR tokenizer.model": (
            (checkpoint_dir / "tokenizer.json").is_file()
            or (checkpoint_dir / "tokenizer.model").is_file()
        ),
        "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
    }
    if checkpoint_dir.is_dir():
        if all(files.values()):
            # we're good
            return
        missing = [name for name, exists in files.items() if not exists]
        problem = f" is missing the files: {missing!r}"
    else:
        problem = " is not a checkpoint directory"
    # list locally available checkpoints
    available = list(Path("checkpoints").glob("*/*"))
    extra = ""
    if available:
        options = "\n".join([""] + [repr(str(p.resolve())) for p in available])
        extra = f"\nYou have downloaded locally:{options}\n"
    if verbose:
        print(
            f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
            "\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
            f"{extra}\nSee all download options by running:\n litgpt download",
            file=sys.stderr,
        )
    if raise_error:
        raise FileNotFoundError(
            f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
        )
    raise SystemExit(1)
class SavingProxyForStorage:
    """Pickle-time stand-in for a torch storage used by incremental saving.

    The raw bytes are written out immediately through `saver`, and only a
    serialization reference tuple (`storage_info`) is retained.
    """

    def __init__(self, obj, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.saver = saver
        if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
            raise TypeError(f"expected storage, not {type(obj)}")
        # this logic is taken from PyTorch 2.0+ torch/serialization.py
        if isinstance(obj, torch.storage.TypedStorage):
            # PT upstream wants to deprecate this eventually...
            storage = obj._untyped_storage
            storage_type_str = obj._pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj._size()
        else:
            storage = obj
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()
        # write the bytes now; keep only the reference for the pickle stream
        storage_key = saver._write_storage_and_return_key(storage)
        location = torch.serialization.location_tag(storage)
        self.storage_info = (
            "storage",
            storage_type,
            storage_key,
            location,
            storage_numel,
        )

    def __reduce_ex__(self, protocol_version):
        # the pickler's persistent_id hook handles this object instead
        assert False, "this should be handled with out of band"
class SavingProxyForTensor:
    """Pickle-time stand-in for a tensor used by incremental saving.

    Captures the tensor's `__reduce_ex__` output and swaps its storage for a
    `SavingProxyForStorage`, so the bytes are streamed out via `saver` rather
    than accumulated by the standard torch.save path.
    """

    def __init__(self, tensor, saver, protocol_version=5):
        self.protocol_version = protocol_version
        self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
        if reduce_args[0] == torch._utils._rebuild_tensor_v2:
            # for Tensors with Python attributes
            (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
            assert isinstance(
                storage, torch.storage.TypedStorage
            ), "Please check for updates"
            storage_proxy = SavingProxyForStorage(
                storage, saver, protocol_version=protocol_version
            )
            self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
        else:
            (storage, *other_reduce_args) = reduce_args
            assert isinstance(
                storage, torch.storage.TypedStorage
            ), "Please check for updates"
            storage_proxy = SavingProxyForStorage(
                storage, saver, protocol_version=protocol_version
            )
            self.reduce_args = (storage_proxy, *other_reduce_args)

    def __reduce_ex__(self, protocol_version):
        # replay the captured reduce with the storage proxy substituted in
        if protocol_version != self.protocol_version:
            raise RuntimeError(
                f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
            )
        return self.reduce_ret_fn, self.reduce_args
class IncrementalPyTorchPickler(pickle.Pickler):
    """Pickler that streams tensor storages to ``saver`` as it encounters them.

    Mirrors torch's own zipfile-format pickling, but writes each storage to the
    saver immediately instead of collecting them after the pickle pass.
    """

    def __init__(self, saver, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # data_ptr -> dtype; used to detect conflicting views of the same data
        self.storage_dtypes = {}
        self.saver = saver
        # storage._cdata -> zipfile record key; deduplicates shared storages
        self.id_map = {}

    # this logic is taken from PyTorch 2.0+ torch/serialization.py
    def persistent_id(self, obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, SavingProxyForStorage):
            # the proxy already wrote its bytes; reuse its recorded metadata
            return obj.storage_info

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()
            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in self.storage_dtypes:
                    if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            "Cannot save multiple tensors or storages that view the same data as different types"
                        )
                else:
                    self.storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = self.id_map.get(storage._cdata)
            if storage_key is None:
                # first time this storage is seen: stream it out now
                storage_key = self.saver._write_storage_and_return_key(storage)
                self.id_map[storage._cdata] = storage_key
            location = torch.serialization.location_tag(storage)

            return ("storage", storage_type, storage_key, location, storage_numel)

        # non-storage objects are pickled inline
        return None
class incremental_save:
    """Context manager that writes a torch zipfile checkpoint incrementally.

    Tensors registered via :meth:`store_early` have their storages written to
    disk immediately, while :meth:`save` emits the final pickle referencing
    them. Produces the same on-disk layout as ``torch.save``.
    """

    def __init__(self, name):
        self.name = name
        self.zipfile = torch._C.PyTorchFileWriter(str(name))
        # save() may only be called once per file
        self.has_saved = False
        # next zipfile record key to assign to a storage
        self.next_key = 0

    def __enter__(self):
        return self

    def store_early(self, tensor):
        """Write ``tensor``'s storage now and return a pickleable proxy for it."""
        if isinstance(tensor, torch.Tensor):
            return SavingProxyForTensor(tensor, self)
        raise TypeError(f"can only store tensors early, not {type(tensor)}")

    def save(self, obj):
        """Pickle ``obj`` (with storages redirected) into the zipfile's data.pkl."""
        if self.has_saved:
            raise RuntimeError("have already saved")
        # Write the pickle data for `obj`
        data_buf = BytesIO()
        pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
        pickler.dump(obj)
        data_value = data_buf.getvalue()
        self.zipfile.write_record("data.pkl", data_value, len(data_value))
        self.has_saved = True

    def _write_storage_and_return_key(self, storage):
        # All storages must be written before the final pickle record.
        if self.has_saved:
            raise RuntimeError("have already saved")
        key = self.next_key
        self.next_key += 1
        name = f"data/{key}"
        # the file format stores raw CPU bytes
        if storage.device.type != "cpu":
            storage = storage.cpu()
        num_bytes = storage.nbytes()
        self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
        return key

    def __exit__(self, type, value, traceback):
        self.zipfile.write_end_of_file()
# Generic type variable used by annotations in this module.
T = TypeVar("T")
def chunked_cross_entropy(
    logits: Union[torch.Tensor, List[torch.Tensor]],
    targets: torch.Tensor,
    chunk_size: int = 128,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Cross entropy computed over chunks to reduce peak memory.

    With large max_sequence_lengths, the beginning of ``backward`` allocates a
    large memory chunk which can dominate the memory usage in fine-tuning
    settings with a low number of parameters. As a workaround hack, the cross
    entropy computation is chunked to force it to deallocate on the go,
    reducing the memory spike's magnitude.

    Args:
        logits: either a full ``(B, T, V)`` tensor, or a list of tensors when
            the lm_head itself was chunked (fine-tuning).
        targets: ``(B, T)`` class indices; ``ignore_index`` entries are masked.
        chunk_size: number of rows per chunk; 0 disables chunking.
        ignore_index: target value excluded from the loss.
    Returns:
        scalar mean loss over the non-ignored targets.
    """
    ce = torch.nn.functional.cross_entropy

    if isinstance(logits, list):
        # lm_head was chunked (we are fine-tuning)
        if chunk_size == 0:
            # don't want to chunk cross entropy
            merged = torch.cat(logits, dim=1)
            return ce(
                merged.reshape(-1, merged.size(-1)),
                targets.reshape(-1),
                ignore_index=ignore_index,
            )

        # chunk cross entropy: pair each logit chunk with the matching target slice
        target_slices = targets.split(logits[0].size(1), dim=1)
        pieces = [
            ce(
                logit_chunk.reshape(-1, logit_chunk.size(-1)),
                target_chunk.reshape(-1),
                ignore_index=ignore_index,
                reduction="none",
            )
            for logit_chunk, target_chunk in zip(logits, target_slices)
        ]
        valid = (targets != ignore_index).sum()
        # See [valid div note] below.
        return torch.cat(pieces).sum() / valid.maximum(torch.ones_like(valid))

    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = targets.reshape(-1)
    if chunk_size == 0:
        # no chunking at all
        return ce(flat_logits, flat_targets, ignore_index=ignore_index)

    # lm_head wasn't chunked, chunk cross entropy
    pieces = [
        ce(logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none")
        for logit_chunk, target_chunk in zip(
            flat_logits.split(chunk_size), flat_targets.split(chunk_size)
        )
    ]
    valid = (flat_targets != ignore_index).sum()
    # [valid div note]:
    # max(1, valid) would be more ergonomic to avoid a division by zero. However
    # that results in a python int which is then passed back to torch division.
    # The `x.maximum(torch.ones_like(x))` pattern avoids a cudaStreamSynchronize.
    return torch.cat(pieces).sum() / valid.maximum(torch.ones_like(valid))
def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
    """Rename legacy checkpoint keys in ``state_dict`` according to ``mapping``.

    ``mapping`` maps old key suffixes to new ones; each is looked up under
    ``prefix``. Keys absent from the state dict are skipped. The dict is
    modified in place and also returned for convenience.
    """
    for old_suffix, new_suffix in mapping.items():
        old_key = prefix + old_suffix
        if old_key not in state_dict:
            continue
        state_dict[prefix + new_suffix] = state_dict.pop(old_key)
    return state_dict
def get_default_supported_precision(training: bool) -> str:
    """Return default precision that is supported by the hardware: either `bf16` or `16`.

    Args:
        training: `-mixed` or `-true` version of the precision to use

    Returns:
        default precision that is suitable for the task and is supported by the hardware
    """
    from lightning.fabric.accelerators import MPSAccelerator

    # MPS and CUDA devices without bf16 support must fall back to fp16.
    bf16_unavailable = MPSAccelerator.is_available() or (
        torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
    )
    suffix = "-mixed" if training else "-true"
    return ("16" if bf16_unavailable else "bf16") + suffix
def load_checkpoint(
    fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
) -> None:
    """Load model weights from ``checkpoint_path`` into ``model``.

    FSDP checkpoints are loaded through fabric's raw loader; otherwise the
    checkpoint is lazily loaded and may nest the weights under a "model" key.
    """
    if isinstance(fabric.strategy, FSDPStrategy):
        fabric.load_raw(checkpoint_path, model, strict=strict)
        return
    checkpoint = lazy_load(checkpoint_path)
    # some checkpoints store the state dict under a "model" key
    checkpoint = checkpoint.get("model", checkpoint)
    model.load_state_dict(checkpoint, strict=strict)
def flops_per_param(
    max_seq_length: int, n_layer: int, n_embd: int, n_params: int
) -> int:
    """Estimate FLOPs per sequence for a model with ``n_params`` parameters.

    Each parameter is used for a MAC (2 FLOPS) per network operation. This
    assumes that all samples have a fixed length equal to the block size,
    which is most likely false during finetuning.
    """
    flops_per_token = 2 * n_params
    dense_flops = flops_per_token * max_seq_length
    # attention cost is quadratic in sequence length
    attention_flops = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
    return dense_flops + attention_flops
def estimate_flops(model: "GPT", training: bool) -> int:
    """Measures estimated FLOPs for MFU.

    Refs:
        * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
        * https://ar5iv.labs.arxiv.org/html/2204.02311#A2

    Using all parameters for this is a naive over estimation because not all
    model parameters actually contribute to this FLOP computation (e.g.
    embedding, norm). For this reason, the result will be higher by a fixed
    percentage (~10%) compared to the measured FLOPs, making those lower but
    more realistic. For a proper estimate, this needs a more fine-grained
    calculation as in Appendix A of the paper.
    """
    cfg = model.config
    trainable = num_parameters(model, requires_grad=True)
    frozen = num_parameters(model, requires_grad=False)
    trainable_flops = flops_per_param(
        model.max_seq_length, cfg.n_layer, cfg.n_embd, trainable
    )
    frozen_flops = flops_per_param(
        model.max_seq_length, cfg.n_layer, cfg.n_embd, frozen
    )
    # forward + backward + gradients (assumes no gradient accumulation)
    trainable_ops = 3 if training else 1
    # forward + backward only: frozen weights produce no gradients
    frozen_ops = 2 if training else 1
    return trainable_ops * trainable_flops + frozen_ops * frozen_flops
class CycleIterator:
    """An iterator that cycles through an iterable indefinitely.

    Example:
        >>> iterator = CycleIterator([1, 2, 3])
        >>> [next(iterator) for _ in range(5)]
        [1, 2, 3, 1, 2]

    Note:
        Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
    """

    def __init__(self, iterable: Iterable) -> None:
        self.iterable = iterable
        # number of completed passes over the iterable
        self.epoch = 0
        self._it = None

    def __next__(self) -> Any:
        if self._it is None:
            self._it = iter(self.iterable)
        try:
            return next(self._it)
        except StopIteration:
            # exhausted: restart from a fresh iterator and count the pass
            self.epoch += 1
            self._it = iter(self.iterable)
            return next(self._it)

    def __iter__(self) -> "CycleIterator":
        return self
def copy_config_files(source_dir: Path, out_dir: Path) -> None:
    """Copies the specified configuration and tokenizer files into the output directory."""
    wanted = (
        # configuration files
        "config.json",
        "generation_config.json",
        "model_config.yaml",
        # tokenizer files
        "tokenizer.json",
        "tokenizer.model",
        "tokenizer_config.json",
    )
    for name in wanted:
        candidate = source_dir / name
        # silently skip files the source checkpoint does not ship
        if candidate.exists():
            shutil.copy(candidate, out_dir)
def CLI(*args: Any, **kwargs: Any) -> Any:
    """Thin wrapper around jsonargparse's CLI with project-wide settings applied."""
    from jsonargparse import CLI as jsonargparse_cli, set_config_read_mode, set_docstring_parse_options

    # parse attribute docstrings so dataclass fields get help text
    set_docstring_parse_options(attribute_docstrings=True)
    # allow --config values to point at URLs
    set_config_read_mode(urls_enabled=True)
    return jsonargparse_cli(*args, **kwargs)
def capture_hparams() -> Dict[str, Any]:
    """Captures the local variables ('hyperparameters') from where this function gets called."""
    caller_locals = inspect.currentframe().f_back.f_locals
    captured = {}
    for name, value in caller_locals.items():
        if value is None or isinstance(value, (int, float, str, bool, Path)):
            # simple values are kept as-is
            captured[name] = value
        elif is_dataclass(value):
            # dataclasses are expanded into plain dicts
            captured[name] = asdict(value)
        else:
            # everything else falls back to its string form
            captured[name] = str(value)
    return captured
def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
    """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
    from jsonargparse import capture_parser

    # TODO: Make this more robust
    # This hack strips away the subcommands from the top-level CLI
    # to parse the file as if it was called as a script
    known_commands = [
        ("finetune_full",),  # For subcommands, use `("finetune", "full")` etc
        ("finetune_lora",),
        ("finetune_adapter",),
        ("finetune_adapter_v2",),
        ("finetune",),
        ("pretrain",),
    ]
    for known_command in known_commands:
        unwanted = slice(1, 1 + len(known_command))
        # NOTE: mutates sys.argv so the parser below sees script-style args only
        if tuple(sys.argv[unwanted]) == known_command:
            sys.argv[unwanted] = []

    # capture the parser jsonargparse would build for `function` without calling it
    parser = capture_parser(lambda: CLI(function))
    config = parser.parse_args()
    parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
def save_config(config: "Config", checkpoint_dir: Path) -> None:
    """Serialize the model config as model_config.yaml inside the checkpoint dir."""
    with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
        yaml.dump(asdict(config), fp)
def parse_devices(devices: Union[str, int]) -> int:
    """Normalize a ``--devices`` argument to a concrete positive device count.

    ``"auto"`` / ``-1`` resolve to the number of visible CUDA devices (at
    least 1); a positive integer passes through; anything else raises.
    """
    if devices == -1 or devices == "auto":
        # fall back to 1 when no CUDA device is visible
        return torch.cuda.device_count() or 1
    if isinstance(devices, int) and devices > 0:
        return devices
    raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
def choose_logger(
    logger_name: Literal["csv", "tensorboard", "wandb"],
    out_dir: Path,
    name: str,
    log_interval: int = 1,
    resume: Optional[bool] = None,
    **kwargs: Any,
):
    """Instantiate the experiment logger selected on the command line.

    Args:
        logger_name: one of "csv", "tensorboard", "wandb".
        out_dir: run output directory; file loggers write under out_dir/logs.
        name: project name (used by wandb only).
        log_interval: flush frequency for the CSV logger.
        resume: wandb resume flag.
        kwargs: forwarded to the chosen logger class.
    Raises:
        ValueError: if ``logger_name`` is not a known option.
    """
    log_root = out_dir / "logs"
    if logger_name == "csv":
        return CSVLogger(
            root_dir=log_root,
            name="csv",
            flush_logs_every_n_steps=log_interval,
            **kwargs,
        )
    if logger_name == "tensorboard":
        return TensorBoardLogger(root_dir=log_root, name="tensorboard", **kwargs)
    if logger_name == "wandb":
        return WandbLogger(project=name, resume=resume, **kwargs)
    raise ValueError(
        f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'."
    )
def get_argument_names(cls):
    """Return the keyword-assignable parameter names of ``cls.__init__``.

    Includes positional-or-keyword and keyword-only parameters (``self`` is
    among them for an unbound ``__init__``); ``*args``/``**kwargs`` are not.
    """
    accepted_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    params = inspect.signature(cls.__init__).parameters
    return {name for name, param in params.items() if param.kind in accepted_kinds}
def instantiate_bnb_optimizer(optimizer, model_parameters):
    """Create a bitsandbytes PagedAdamW optimizer.

    Args:
        optimizer: either the string "AdamW" (possibly qualified) or a
            jsonargparse-style dict with "class_path"/"init_args".
        model_parameters: parameters to optimize.
    Raises:
        ValueError: when a non-AdamW optimizer is requested; quantized
            training only supports AdamW.
    """
    if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
        isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
    ):
        raise ValueError(
            "The chosen quantization format only supports the AdamW optimizer."
        )

    import bitsandbytes as bnb

    if isinstance(optimizer, str):
        return bnb.optim.PagedAdamW(model_parameters)

    # keep only the init args PagedAdamW actually accepts
    accepted = get_argument_names(bnb.optim.PagedAdamW)
    given = optimizer["init_args"]
    filtered = {key: given[key] for key in accepted & given.keys()}
    return bnb.optim.PagedAdamW(model_parameters, **filtered)
def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
    """Create a torch optimizer from a class name or a jsonargparse dict spec.

    ``kwargs`` override/extend the spec's init args. The caller's dict is
    shallow-copied before its "init_args" are updated.
    """
    if isinstance(optimizer, str):
        # e.g. "AdamW" -> torch.optim.AdamW
        optimizer_cls = getattr(torch.optim, optimizer)
        return optimizer_cls(model_parameters, **kwargs)

    spec = dict(optimizer)  # shallow copy
    spec["init_args"].update(kwargs)
    return instantiate_class(model_parameters, spec)
def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
    """Resolve a bare model name to its location under ./checkpoints.

    Redirects only when the given path is relative, does not exist as a
    directory, is not already under "checkpoints", and the candidate
    "checkpoints/<path>" actually exists; otherwise the input is returned.
    """
    candidate = "checkpoints" / checkpoint_dir
    if (
        checkpoint_dir.is_dir()
        or checkpoint_dir.parts[0] == "checkpoints"
        or checkpoint_dir.is_absolute()
        or not candidate.exists()
    ):
        return checkpoint_dir
    return candidate
# 模型编码
modelCode=1054
# 模型名称
modelName=mini-omni2_pytorch
# 模型描述
modelDescription=Mini-Omni2是个视觉-音频助理,能同时处理视觉、听觉和文本三种多模态,实时提供端到端的语音对话响应。
# 应用场景
appScenario=推理,对话问答,制造,广媒,金融,能源,医疗,家居,教育
# 框架类型
frameType=pytorch
#torch==2.3.1
#torchvision==0.18.1
#torchaudio==2.3.1
litgpt==0.4.3
snac==1.2.0
soundfile==0.12.1
openai-whisper
streamlit==1.37.1
# PyAudio==0.2.14
pydub==0.25.1
onnxruntime==1.19.0
# numpy==1.26.3
gradio==4.42.0
librosa==0.10.2.post1
#flask==3.0.3
fire
transformers==4.45.2
tokenizers==0.20.1
#git+https://github.com/mini-omni/CLIP.git
import flask
import base64
import tempfile
import traceback
from flask import Flask, Response, stream_with_context
from inference_vision import OmniVisionInference
class OmniChatServer(object):
    """Flask HTTP wrapper around OmniVisionInference.

    Exposes a single POST /chat endpoint that accepts base64-encoded audio
    (and optionally an image) and streams back interleaved audio/text parts.
    """

    def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
                 ckpt_dir='./checkpoint', device='cuda:0') -> None:
        server = Flask(__name__)
        # CORS(server, resources=r"/*")
        # server.config["JSON_AS_ASCII"] = False

        self.client = OmniVisionInference(ckpt_dir, device)
        # run the model once so the first real request does not pay warm-up cost
        self.client.warm_up()

        server.route("/chat", methods=["POST"])(self.chat)

        if run_app:
            # blocking call; threaded=False serves one request at a time
            server.run(host=ip, port=port, threaded=False)
        else:
            # factory mode (see create_app): the caller decides how to run it
            self.server = server

    def chat(self) -> Response:
        """Handle one chat request.

        Expected JSON body: {"audio": <base64 wav>, "image": <base64 jpg,
        optional>, "stream_stride": int, "max_tokens": int}. Returns a
        multipart/x-mixed-replace streaming response, or 500 on any error.
        """
        req_data = flask.request.get_json()
        try:
            # decode the mandatory base64 audio payload
            audio_data_buf = req_data["audio"].encode("utf-8")
            audio_data_buf = base64.b64decode(audio_data_buf)
            stream_stride = req_data.get("stream_stride", 4)
            max_tokens = req_data.get("max_tokens", 2048)

            # image is optional; its presence selects the vision pipeline below
            image_data_buf = req_data.get("image", None)
            if image_data_buf:
                image_data_buf = image_data_buf.encode("utf-8")
                image_data_buf = base64.b64decode(image_data_buf)

            audio_path, img_path = None, None
            # delete=False keeps the files readable by name after writing
            # NOTE(review): these temp files are never removed — they accumulate
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_f, \
                    tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img_f:
                audio_f.write(audio_data_buf)
                audio_path = audio_f.name

                if image_data_buf:
                    img_f.write(image_data_buf)
                    img_path = img_f.name
                else:
                    img_path = None

                if img_path is not None:
                    # audio + image -> vision question answering
                    resp_generator = self.client.run_vision_AA_batch_stream(audio_f.name, img_f.name,
                                                                            stream_stride, max_tokens,
                                                                            save_path='./vision_qa_out_cache.wav')
                else:
                    # audio only -> audio question answering
                    resp_generator = self.client.run_AT_batch_stream(audio_f.name, stream_stride,
                                                                     max_tokens,
                                                                     save_path='./audio_qa_out_cache.wav')
                return Response(stream_with_context(self.generator(resp_generator)),
                                mimetype='multipart/x-mixed-replace; boundary=frame')
        except Exception as e:
            print(traceback.format_exc())
            return Response("An error occurred", status=500)

    def generator(self, resp_generator):
        """Frame each (audio, text) pair as multipart/x-mixed-replace parts."""
        for audio_stream, text_stream in resp_generator:
            yield b'\r\n--frame\r\n'
            yield b'Content-Type: audio/wav\r\n\r\n'
            yield audio_stream
            yield b'\r\n--frame\r\n'
            yield b'Content-Type: text/plain\r\n\r\n'
            yield text_stream.encode()
# CUDA_VISIBLE_DEVICES=1 gunicorn -w 2 -b 0.0.0.0:60808 'server:create_app()'
def create_app():
    """App factory for WSGI servers (see the gunicorn example above)."""
    chat_server = OmniChatServer(run_app=False)
    return chat_server.server
def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
    """Start the chat server in blocking mode (CLI entry point)."""
    OmniChatServer(ip, port=port, run_app=True, device=device)
# Run the server directly; fire exposes serve()'s arguments as CLI flags.
if __name__ == "__main__":
    import fire
    fire.Fire(serve)
import torch
import time
import numpy as np
class SnacConfig:
    """Vocabulary constants for SNAC audio codec tokens."""

    # size of the raw SNAC codebook
    audio_vocab_size = 4096
    # per-layer vocab slice size including special/padding ids
    padded_vocab_size = 4160
    # token id marking the end of an audio stream
    end_of_audio = 4097


# shared module-level instance
snac_config = SnacConfig()
def get_time_str():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    return time.strftime("%Y%m%d_%H%M%S", time.localtime())
def layershift(input_id, layer, stride=4160, shift=152000):
    """Map a codec token id into the global vocabulary slice for ``layer``.

    Audio layers live after the text vocabulary (offset ``shift``), each
    occupying ``stride`` consecutive ids.
    """
    layer_offset = shift + layer * stride
    return input_id + layer_offset
def generate_audio_data(snac_tokens, snacmodel, device=None):
    """Decode flattened SNAC tokens into 16-bit PCM bytes.

    Args:
        snac_tokens: '#'-separated flat token list (see reconstruct_tensors).
        snacmodel: SNAC decoder model exposing ``decode``.
        device: optional device for the code tensors.
    """
    codes = reconstruct_tensors(snac_tokens, device)
    with torch.inference_mode():
        waveform = snacmodel.decode(codes)
    # scale float audio into the int16 PCM range
    pcm = (waveform.cpu().numpy().astype(np.float64) * 32768.0).astype(np.int16)
    return pcm.tobytes()
def get_snac(list_output, index, nums_generate):
    """Collect the next ``nums_generate`` SNAC frames from per-layer outputs.

    Each frame is emitted as a '#' marker followed by one token per layer;
    layer ``j`` is read at an offset that undoes the layer-wise delay.
    """
    frames = []
    base = index - nums_generate - 5
    for step in range(nums_generate):
        frames.append("#")  # frame separator
        frames.extend(
            list_output[layer][base + layer + step] for layer in range(7)
        )
    return frames
def reconscruct_snac(output_list):
    """Flatten per-layer SNAC token streams into one '#'-separated frame list.

    An 8th stream (the text layer), if present, is dropped. Layer ``i`` is
    offset by ``i + 1`` positions to undo the delayed-generation pattern, then
    tokens are interleaved frame by frame: ["#", l0, ..., l6, "#", ...].

    Fix over the original: the caller's ``output_list`` is no longer mutated
    (the original reassigned trimmed slices back into the passed-in list).
    """
    # keep only the 7 audio layers
    layers = output_list[:7]
    # undo the layer-wise delay on fresh lists, leaving the input untouched
    trimmed = [layer[i + 1:] for i, layer in enumerate(layers)]
    flattened = []
    # the last (most delayed) layer determines how many complete frames exist
    for frame in range(len(trimmed[-1])):
        flattened.append("#")
        for layer in range(7):
            flattened.append(trimmed[layer][frame])
    return flattened
def reconstruct_tensors(flattened_output, device=None):
    """Reconstructs the list of tensors from the flattened output.

    Frames are delimited by '#'. A 7-token frame maps to 3 SNAC codebook
    levels (1 + 2 + 4 tokens); a 15-token frame maps to 4 levels
    (1 + 2 + 4 + 8 tokens). Returns a list of (1, n) tensors per level, or
    an empty list when the frame size is neither 7 nor 15.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _frame_size(lst):
        # number of elements between the first two '#' markers
        try:
            first_index = lst.index("#")
            second_index = lst.index("#", first_index + 1)
            return second_index - first_index - 1
        except ValueError:
            # Handle the case where there aren't enough '#' symbols
            return "List does not contain two '#' symbols"

    def _trim_to_first_hash(lst):
        # drop any leading tokens before the first frame marker
        try:
            return lst[lst.index("#"):]
        except ValueError:
            # Handle the case where there is no '#'
            return "List does not contain the symbol '#'"

    def _as_row_tensor(values):
        # shape the collected level tokens as a (1, n) tensor
        return torch.tensor(values).unsqueeze(0)

    flattened_output = _trim_to_first_hash(flattened_output)
    codes = []
    n_tensors = _frame_size(flattened_output)

    if n_tensors == 7:
        # within-frame offset -> codebook level (1 + 2 + 4 layout)
        layout = {1: 0, 2: 1, 3: 2, 4: 2, 5: 1, 6: 2, 7: 2}
        buckets = [[], [], []]
        for base in range(0, len(flattened_output), 8):
            for offset, level in layout.items():
                buckets[level].append(flattened_output[base + offset])
        codes = [_as_row_tensor(b).to(device) for b in buckets]

    if n_tensors == 15:
        # within-frame offset -> codebook level (1 + 2 + 4 + 8 layout)
        layout = {
            1: 0,
            2: 1, 9: 1,
            3: 2, 6: 2, 10: 2, 13: 2,
            4: 3, 5: 3, 7: 3, 8: 3, 11: 3, 12: 3, 14: 3, 15: 3,
        }
        buckets = [[], [], [], []]
        for base in range(0, len(flattened_output), 16):
            for offset, level in layout.items():
                buckets[level].append(flattened_output[base + offset])
        codes = [_as_row_tensor(b).to(device) for b in buckets]

    return codes
import bisect
import functools
import os
import warnings
from typing import List, NamedTuple, Optional
import numpy as np
# The code below is adapted from https://github.com/snakers4/silero-vad.
class VadOptions(NamedTuple):
    """VAD options.

    Attributes:
      threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
      min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
      max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.
      min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
        before separating it
      window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
        WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
        Values other than these may affect model performance!!
      speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
    """

    threshold: float = 0.5
    min_speech_duration_ms: int = 250
    max_speech_duration_s: float = float("inf")
    # NOTE(review): a 2 s silence gap — presumably tuned for dialogue turn-taking; confirm
    min_silence_duration_ms: int = 2000
    window_size_samples: int = 1024
    speech_pad_ms: int = 400
def get_speech_timestamps(
    audio: np.ndarray,
    vad_options: Optional[VadOptions] = None,
    **kwargs,
) -> List[dict]:
    """This method is used for splitting long audios into speech chunks using silero VAD.

    Args:
      audio: One dimensional float array.
      vad_options: Options for VAD processing.
      kwargs: VAD options passed as keyword arguments for backward compatibility.

    Returns:
      List of dicts containing begin and end samples of each speech chunk.
    """
    if vad_options is None:
        vad_options = VadOptions(**kwargs)

    threshold = vad_options.threshold
    min_speech_duration_ms = vad_options.min_speech_duration_ms
    max_speech_duration_s = vad_options.max_speech_duration_s
    min_silence_duration_ms = vad_options.min_silence_duration_ms
    window_size_samples = vad_options.window_size_samples
    speech_pad_ms = vad_options.speech_pad_ms

    if window_size_samples not in [512, 1024, 1536]:
        warnings.warn(
            "Unusual window_size_samples! Supported window_size_samples:\n"
            " - [512, 1024, 1536] for 16000 sampling_rate"
        )

    # convert all ms/s thresholds into sample counts at the fixed 16 kHz rate
    sampling_rate = 16000
    min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
    speech_pad_samples = sampling_rate * speech_pad_ms / 1000
    max_speech_samples = (
        sampling_rate * max_speech_duration_s
        - window_size_samples
        - 2 * speech_pad_samples
    )
    min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
    # shorter silence tolerance used only when force-splitting over-long chunks
    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

    audio_length_samples = len(audio)

    # score every fixed-size window with the (stateful) silero model
    model = get_vad_model()
    state = model.get_initial_state(batch_size=1)

    speech_probs = []
    for current_start_sample in range(0, audio_length_samples, window_size_samples):
        chunk = audio[current_start_sample : current_start_sample + window_size_samples]
        if len(chunk) < window_size_samples:
            # zero-pad the trailing partial window
            chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
        speech_prob, state = model(chunk, state, sampling_rate)
        speech_probs.append(speech_prob)

    # --- state machine over per-window probabilities ---
    triggered = False  # currently inside a speech chunk
    speeches = []
    current_speech = {}
    # hysteresis: only leave speech once prob drops below threshold - 0.15
    neg_threshold = threshold - 0.15

    # to save potential segment end (and tolerate some silence)
    temp_end = 0
    # to save potential segment limits in case of maximum segment size reached
    prev_end = next_start = 0

    for i, speech_prob in enumerate(speech_probs):
        if (speech_prob >= threshold) and temp_end:
            # speech resumed before the silence timeout: cancel the pending end
            temp_end = 0
            if next_start < prev_end:
                next_start = window_size_samples * i

        if (speech_prob >= threshold) and not triggered:
            # rising edge: open a new speech chunk
            triggered = True
            current_speech["start"] = window_size_samples * i
            continue

        if (
            triggered
            and (window_size_samples * i) - current_speech["start"] > max_speech_samples
        ):
            # chunk exceeded the maximum length and must be split
            if prev_end:
                # split at the last tolerated silence
                current_speech["end"] = prev_end
                speeches.append(current_speech)
                current_speech = {}
                # previously reached silence (< neg_thres) and is still not speech (< thres)
                if next_start < prev_end:
                    triggered = False
                else:
                    current_speech["start"] = next_start
                prev_end = next_start = temp_end = 0
            else:
                # no silence was seen: split aggressively right here
                current_speech["end"] = window_size_samples * i
                speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

        if (speech_prob < neg_threshold) and triggered:
            # possible end of speech; wait out min_silence before committing
            if not temp_end:
                temp_end = window_size_samples * i
            # condition to avoid cutting in very short silence
            if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                prev_end = temp_end
            if (window_size_samples * i) - temp_end < min_silence_samples:
                continue
            else:
                current_speech["end"] = temp_end
                # discard chunks that are too short to count as speech
                if (
                    current_speech["end"] - current_speech["start"]
                ) > min_speech_samples:
                    speeches.append(current_speech)
                current_speech = {}
                prev_end = next_start = temp_end = 0
                triggered = False
                continue

    # flush a chunk still open at end-of-audio
    if (
        current_speech
        and (audio_length_samples - current_speech["start"]) > min_speech_samples
    ):
        current_speech["end"] = audio_length_samples
        speeches.append(current_speech)

    # pad every chunk; close neighbors share the silence gap between them
    for i, speech in enumerate(speeches):
        if i == 0:
            speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
        if i != len(speeches) - 1:
            silence_duration = speeches[i + 1]["start"] - speech["end"]
            if silence_duration < 2 * speech_pad_samples:
                speech["end"] += int(silence_duration // 2)
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - silence_duration // 2)
                )
            else:
                speech["end"] = int(
                    min(audio_length_samples, speech["end"] + speech_pad_samples)
                )
                speeches[i + 1]["start"] = int(
                    max(0, speeches[i + 1]["start"] - speech_pad_samples)
                )
        else:
            speech["end"] = int(
                min(audio_length_samples, speech["end"] + speech_pad_samples)
            )

    return speeches
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
    """Concatenate the regions of ``audio`` selected by ``chunks``.

    Each chunk is a dict with "start"/"end" sample indices. An empty chunk
    list yields an empty float32 array.
    """
    if not chunks:
        # preserve the float32 dtype contract even for the empty result
        return np.array([], dtype=np.float32)
    pieces = [audio[chunk["start"]:chunk["end"]] for chunk in chunks]
    return np.concatenate(pieces)
class SpeechTimestampsMap:
    """Helper class to restore original speech timestamps.

    After silence is cut out (see ``collect_chunks``), maps a time measured
    in the concatenated audio back to the corresponding time in the original
    recording.
    """

    def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
        self.sampling_rate = sampling_rate
        self.time_precision = time_precision
        self.chunk_end_sample = []
        self.total_silence_before = []

        cumulative_silence = 0
        last_end = 0
        for chunk in chunks:
            cumulative_silence += chunk["start"] - last_end
            last_end = chunk["end"]
            # chunk end expressed on the silence-stripped timeline
            self.chunk_end_sample.append(chunk["end"] - cumulative_silence)
            self.total_silence_before.append(cumulative_silence / sampling_rate)

    def get_original_time(
        self,
        time: float,
        chunk_index: Optional[int] = None,
    ) -> float:
        """Translate a concatenated-audio time into original-audio seconds."""
        if chunk_index is None:
            chunk_index = self.get_chunk_index(time)
        silence = self.total_silence_before[chunk_index]
        return round(silence + time, self.time_precision)

    def get_chunk_index(self, time: float) -> int:
        """Return the index of the chunk containing the given time."""
        sample = int(time * self.sampling_rate)
        position = bisect.bisect(self.chunk_end_sample, sample)
        # clamp so times past the last chunk still map to it
        return min(position, len(self.chunk_end_sample) - 1)
@functools.lru_cache
def get_vad_model():
    """Returns the VAD model instance (loaded once, then cached)."""
    model_path = os.path.join(os.path.dirname(__file__), "assets", "silero_vad.onnx")
    return SileroVADModel(model_path)
class SileroVADModel:
    """Thin onnxruntime wrapper around the Silero VAD network."""

    def __init__(self, path):
        try:
            import onnxruntime
        except ImportError as e:
            raise RuntimeError(
                "Applying the VAD filter requires the onnxruntime package"
            ) from e

        options = onnxruntime.SessionOptions()
        # VAD inference is tiny; single-threaded avoids scheduling overhead
        options.inter_op_num_threads = 1
        options.intra_op_num_threads = 1
        options.log_severity_level = 4

        self.session = onnxruntime.InferenceSession(
            path,
            providers=["CPUExecutionProvider"],
            sess_options=options,
        )

    def get_initial_state(self, batch_size: int):
        """Return zeroed LSTM (h, c) state for a fresh audio stream."""
        shape = (2, batch_size, 64)
        return np.zeros(shape, dtype=np.float32), np.zeros(shape, dtype=np.float32)

    def __call__(self, x, state, sr: int):
        """Run one audio window; returns (speech_prob, new_state)."""
        if x.ndim == 1:
            # promote a single window to a batch of one
            x = np.expand_dims(x, 0)
        if x.ndim > 2:
            raise ValueError(
                f"Too many dimensions for input audio chunk {x.ndim}"
            )
        if sr / x.shape[1] > 31.25:
            raise ValueError("Input audio chunk is too short")

        h, c = state
        feeds = {
            "input": x,
            "h": h,
            "c": c,
            "sr": np.array(sr, dtype="int64"),
        }
        prob, h, c = self.session.run(None, feeds)
        return prob, (h, c)
"""A simple web interactive chat demo based on gradio."""
import os
import time
import gradio as gr
import base64
import numpy as np
import requests
# Remote inference endpoint; when unset, run the model locally in-process.
API_URL = os.getenv("API_URL", None)
client = None  # NOTE(review): appears unused — the local path uses omni_client
if API_URL is None:
    from inference import OmniInference
    # load the model once at import time so requests are served immediately
    omni_client = OmniInference('./checkpoint', 'cuda:0')
    omni_client.warm_up()

# playback parameters for the streamed response audio
OUT_CHUNK = 4096
OUT_RATE = 24000
OUT_CHANNELS = 1
def process_audio(audio):
    """Gradio callback: stream the model's spoken reply for a recorded file.

    Yields ``(sample_rate, int16 ndarray)`` tuples that Gradio plays as they
    arrive. Uses the remote API when API_URL is set, else the local model.
    """
    filepath = audio
    print(f"filepath: {filepath}")
    if filepath is None:
        return

    cnt = 0
    if API_URL is not None:
        # remote inference: POST base64-encoded audio, stream raw PCM back
        with open(filepath, "rb") as f:
            data = f.read()
            base64_encoded = str(base64.b64encode(data), encoding="utf-8")
            files = {"audio": base64_encoded}
            tik = time.time()
            with requests.post(API_URL, json=files, stream=True) as response:
                try:
                    for chunk in response.iter_content(chunk_size=OUT_CHUNK):
                        if chunk:
                            # Convert chunk to numpy array
                            if cnt == 0:
                                # log time-to-first-audio once per request
                                print(f"first chunk time cost: {time.time() - tik:.3f}")
                            cnt += 1
                            audio_data = np.frombuffer(chunk, dtype=np.int16)
                            audio_data = audio_data.reshape(-1, OUT_CHANNELS)
                            yield OUT_RATE, audio_data.astype(np.int16)
                except Exception as e:
                    print(f"error: {e}")
    else:
        # local inference via the in-process OmniInference client
        tik = time.time()
        for chunk in omni_client.run_AT_batch_stream(filepath):
            # Convert chunk to numpy array
            if cnt == 0:
                print(f"first chunk time cost: {time.time() - tik:.3f}")
            cnt += 1
            audio_data = np.frombuffer(chunk, dtype=np.int16)
            audio_data = audio_data.reshape(-1, OUT_CHANNELS)
            yield OUT_RATE, audio_data.astype(np.int16)
def main(port=None):
    """Build and launch the Gradio streaming-chat UI.

    Args:
        port: optional fixed port; when given, the app binds to all interfaces.
    """
    demo = gr.Interface(
        process_audio,
        inputs=gr.Audio(type="filepath", label="Microphone"),
        outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
        title="Chat Mini-Omni Demo",
        live=True,
    )
    if port is None:
        demo.queue().launch()
    else:
        # explicit port: expose on all interfaces without a share link
        demo.queue().launch(share=False, server_name="0.0.0.0", server_port=port)
# Launch the demo when run as a script; fire exposes main()'s args as flags.
if __name__ == "__main__":
    import fire
    fire.Fire(main)
import streamlit as st
import wave
# from ASR import recognize
import requests
import pyaudio
import numpy as np
import base64
import io
from typing import List
import av
import os
import time
import tempfile
import librosa
import traceback
from pydub import AudioSegment
from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
from datetime import datetime
from PIL import Image
import streamlit_webrtc
from io import BytesIO
# set wide mode
# st.set_page_config(layout="wide")

# Most recent webcam frame (set by queued_video_frames_callback via
# frame.to_image()) plus the wall-clock time it was captured.
last_video_frame = None
last_video_frame_ts = time.time()

# Server endpoint for chat inference; an empty string means "unset".
API_URL = os.getenv("API_URL", "http://127.0.0.1:60808/chat")
API_URL = None if API_URL == "" else API_URL

# recording parameters
IN_FORMAT = pyaudio.paInt16  # 16-bit signed PCM capture
IN_CHANNELS = 1              # mono
IN_RATE = 24000              # Hz
IN_CHUNK = 1024              # frames per stream.read()
IN_SAMPLE_WIDTH = 2          # bytes per sample
VAD_STRIDE = 0.5             # seconds of audio accumulated per VAD pass

# playing parameters
OUT_FORMAT = pyaudio.paInt16  # 16-bit signed PCM playback
OUT_CHANNELS = 1
OUT_RATE = 24000
OUT_SAMPLE_WIDTH = 2
OUT_CHUNK = 5760

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
def run_vad(ori_audio, sr):
    """Run voice-activity detection over raw int16 PCM bytes.

    Args:
        ori_audio: raw little-endian int16 PCM bytes.
        sr: sample rate of ``ori_audio`` in Hz.

    Returns:
        ``(duration_after_vad_s, speech_bytes, time_cost_s)``. On any
        failure the input is returned unchanged with duration -1
        (best-effort behaviour, error is printed).
    """
    started = time.time()
    try:
        target_sr = 16000
        samples = np.frombuffer(ori_audio, dtype=np.int16).astype(np.float32) / 32768.0
        if sr != target_sr:
            # VAD model expects 16 kHz input.
            samples = librosa.resample(samples, orig_sr=sr, target_sr=target_sr)
        options = VadOptions()
        timestamps = get_speech_timestamps(samples, options)
        speech = collect_chunks(samples, timestamps)
        duration_after_vad = speech.shape[0] / target_sr
        if sr != target_sr:
            # resample the kept speech back to the original sampling rate
            out = librosa.resample(speech, orig_sr=target_sr, target_sr=sr)
        else:
            out = speech
        out_bytes = np.round(out * 32768.0).astype(np.int16).tobytes()
        return duration_after_vad, out_bytes, round(time.time() - started, 4)
    except Exception:
        msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
        print(msg)
        return -1, ori_audio, round(time.time() - started, 4)
def warm_up():
    """Prime the VAD pipeline once with silence so the first real call is fast."""
    silence = b"\x00\x00" * 1024 * 2  # 2048 zero samples of int16 silence
    _dur, _frames, tcost = run_vad(silence, 16000)
    print(f"warm up done, time_cost: {tcost:.3f} s")
def save_tmp_audio(audio_bytes):
    """Persist raw reply PCM as a temporary WAV file and return its path.

    Uses the OUT_* playback parameters for the WAV header; the temp file is
    kept on disk (delete=False) so the chat history widget can replay it.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        file_name = tmpfile.name
        segment = AudioSegment(
            data=audio_bytes,
            sample_width=OUT_SAMPLE_WIDTH,
            frame_rate=OUT_RATE,
            channels=OUT_CHANNELS,
        )
        segment.export(file_name, format="wav")
        return file_name
def speaking(status, resp_text_holder=None, encoded_img=None):
    """Send the recorded audio (plus optional image) to the server and play the streamed reply.

    Args:
        status: ``st.empty()`` placeholder used for progress/warning text.
        resp_text_holder: placeholder updated with the streamed text reply.
        encoded_img: optional base64-encoded JPEG sent alongside the audio.

    Side effects:
        Plays reply audio through PyAudio, writes ``input_audio.wav`` and a
        temp WAV of the recording, appends chat messages to
        ``st.session_state.messages`` and resets the recording flags.
    """
    # Initialize PyAudio and open an output stream for playback.
    p = pyaudio.PyAudio()
    stream = p.open(
        format=OUT_FORMAT, channels=OUT_CHANNELS, rate=OUT_RATE, output=True
    )
    # Pack the captured frames into an in-memory WAV container.
    audio_buffer = io.BytesIO()
    wf = wave.open(audio_buffer, "wb")
    wf.setnchannels(IN_CHANNELS)
    wf.setsampwidth(IN_SAMPLE_WIDTH)
    wf.setframerate(IN_RATE)
    total_frames = b"".join(st.session_state.frames)
    dur = len(total_frames) / (IN_RATE * IN_CHANNELS * IN_SAMPLE_WIDTH)
    status.warning(f"Speaking... recorded audio duration: {dur:.3f} s")
    wf.writeframes(total_frames)
    # Persist the recording: a temp file for the chat-history widget plus a
    # fixed-name copy for debugging.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        with open(tmpfile.name, "wb") as f:
            f.write(audio_buffer.getvalue())
        with open("input_audio.wav", "wb") as f:
            f.write(audio_buffer.getvalue())
        file_name = tmpfile.name
    with st.chat_message("user"):
        st.audio(file_name, format="audio/wav", loop=False, autoplay=False)
    # BUGFIX: this entry is the *user's* recording; it was previously stored
    # with role "assistant", mislabelling it when history is replayed.
    st.session_state.messages.append(
        {"role": "user", "content": file_name, "type": "audio"}
    )
    st.session_state.frames = []
    audio_bytes = audio_buffer.getvalue()
    base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
    if API_URL is not None:
        output_audio_bytes = b""
        files = {"audio": base64_encoded}
        if encoded_img is not None:
            files["image"] = encoded_img
        print("sending request to server")
        resp_text_holder.empty()
        resp_text = ""
        with requests.post(API_URL, json=files, stream=True) as response:
            try:
                buffer = b''
                for chunk in response.iter_content(chunk_size=2048):
                    buffer += chunk
                    # The server replies with multipart frames separated by
                    # "\r\n--frame\r\n"; each frame carries audio or text.
                    while b'\r\n--frame\r\n' in buffer:
                        frame, buffer = buffer.split(b'\r\n--frame\r\n', 1)
                        if b'Content-Type: audio/wav' in frame:
                            audio_data = frame.split(b'\r\n\r\n', 1)[1]
                            # audio_data = base64.b64decode(audio_data)
                            output_audio_bytes += audio_data
                            # BUGFIX: the stream is opened as paInt16, so view
                            # the payload as int16 — the old int8 view wrote
                            # the same bytes but mislabelled the format.
                            audio_array = np.frombuffer(audio_data, dtype=np.int16)
                            stream.write(audio_array)
                        elif b'Content-Type: text/plain' in frame:
                            text_data = frame.split(b'\r\n\r\n', 1)[1].decode()
                            resp_text += text_data
                            if len(text_data) > 0:
                                print(resp_text, end='\r')
                                resp_text_holder.write(resp_text)
            except Exception as e:
                st.error(f"Error during audio streaming: {e}")
        # Save the full reply and append both modalities to history.
        out_file = save_tmp_audio(output_audio_bytes)
        with st.chat_message("assistant"):
            st.write(resp_text)
        with st.chat_message("assistant"):
            st.audio(out_file, format="audio/wav", loop=False, autoplay=False)
        st.session_state.messages.append(
            {"role": "assistant", "content": resp_text, "type": "text"}
        )
        st.session_state.messages.append(
            {"role": "assistant", "content": out_file, "type": "audio"}
        )
    else:
        st.error("API_URL is not set. Please set the API_URL environment variable.")
    # Let the tail of the playback drain before tearing the stream down.
    time.sleep(1)
    wf.close()
    # Close PyAudio stream and terminate PyAudio
    stream.stop_stream()
    stream.close()
    p.terminate()
    # Hand control back to the capture loop.
    st.session_state.speaking = False
    st.session_state.recording = True
def recording(status):
    """Capture microphone audio until VAD detects the end of an utterance.

    Appends raw int16 PCM windows to ``st.session_state.frames`` and clears
    ``st.session_state.recording`` once speech has started and then stopped.

    Args:
        status: ``st.empty()`` placeholder used for the "Listening..." banner.
    """
    audio = pyaudio.PyAudio()
    stream = audio.open(
        format=IN_FORMAT,
        channels=IN_CHANNELS,
        rate=IN_RATE,
        input=True,
        frames_per_buffer=IN_CHUNK,
    )
    temp_audio = b""
    start_talking = False
    last_temp_audio = None
    st.session_state.frames = []
    while st.session_state.recording:
        status.success("Listening...")
        audio_bytes = stream.read(IN_CHUNK)
        temp_audio += audio_bytes
        # Run VAD once VAD_STRIDE seconds of audio have accumulated.
        if len(temp_audio) > IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE:
            dur_vad, vad_audio_bytes, time_vad = run_vad(temp_audio, IN_RATE)
            print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
            if dur_vad > 0.2 and not start_talking:
                # Speech onset: also keep the previous window so the first
                # syllable is not clipped.
                if last_temp_audio is not None:
                    st.session_state.frames.append(last_temp_audio)
                start_talking = True
            if start_talking:
                st.session_state.frames.append(temp_audio)
            if dur_vad < 0.1 and start_talking:
                # Trailing silence after speech: stop capturing.
                st.session_state.recording = False
                # BUGFIX: was an f-string with no placeholders and a typo
                # ("excit").
                print("speech end detected. exit")
            last_temp_audio = temp_audio
            temp_audio = b""
    stream.stop_stream()
    stream.close()
    audio.terminate()
async def queued_video_frames_callback(frames: List[av.VideoFrame]) -> List[av.VideoFrame]:
    """Sample at most one webcam frame per second into the module globals.

    Stores the newest frame (converted to a PIL Image) in
    ``last_video_frame`` / ``last_video_frame_ts``; frames pass through
    unmodified.
    """
    global last_video_frame
    global last_video_frame_ts
    # Throttle to ~1 Hz: only keep a frame when at least a second has passed.
    if frames and time.time() - last_video_frame_ts > 1:
        last_video_frame = frames[-1].to_image()
        last_video_frame_ts = time.time()
    return frames
def main():
    """Streamlit entry point: mode selection, capture/speak state machine, chat history.

    Flow per rerun: pick audio-only vs audio-vision mode, (re)initialize
    session-state flags, render history, then loop: record until VAD stops,
    optionally grab the latest webcam frame, and call speaking().
    """
    st.title("Chat Mini-Omni2 Demo")
    status = st.empty()
    # Mode selection
    mode = st.radio(
        "Select mode:",
        ("Audio-only", "Audio-vision"),
        key="mode_selection",
        horizontal=True
    )
    if mode == "Audio-only":
        st.session_state.use_vision = False
        st.info("Audio-only mode selected. The system will process only audio input.")
    else:  # Audio-vision
        st.session_state.use_vision = True
        st.info("Audio-vision mode selected. The system will process both audio and video input.")
    # One-time VAD warm-up per session.
    if "warm_up" not in st.session_state:
        warm_up()
        st.session_state.warm_up = True
    # Default session-state flags (streamlit reruns this script on every
    # interaction, so guards keep the values stable).
    if "start" not in st.session_state:
        st.session_state.start = False
    if "recording" not in st.session_state:
        st.session_state.recording = False
    if "speaking" not in st.session_state:
        st.session_state.speaking = False
    if "frames" not in st.session_state:
        st.session_state.frames = []
    if not st.session_state.start:
        status.warning("Click Start to chat")
    start_col, stop_col, _ = st.columns([0.2, 0.2, 0.6])
    start_button = start_col.button("Start", key="start_button")
    # stop_button = stop_col.button("Stop", key="stop_button")
    if start_button:
        time.sleep(1)
        st.session_state.recording = True
        st.session_state.start = True
    if st.session_state.use_vision:
        # Webcam preview in the sidebar; queued_video_frames_callback keeps
        # the latest frame in module globals for speaking() to send.
        with st.sidebar:
            webrtc_ctx = streamlit_webrtc.webrtc_streamer(
                key="speech-w-video",
                mode=streamlit_webrtc.WebRtcMode.SENDRECV,
                # rtc_configuration={"iceServers": get_ice_servers()},
                media_stream_constraints={"video": True, "audio": False},
                # video_receiver_size=10, # Increased from default 4 to 10
                queued_video_frames_callback=queued_video_frames_callback,
            )
        if not webrtc_ctx.state.playing:
            st.warning("Please allow camera access and try again.")
            return
    resp_text_holder = st.empty()
    # Replay the chat history accumulated in earlier turns.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            if message["type"] == "text":
                st.markdown(message["content"])
            elif message["type"] == "img":
                st.image(message["content"], width=300)
            elif message["type"] == "audio":
                st.audio(
                    message["content"], format="audio/wav", loop=False, autoplay=False
                )
    # Main turn loop: record -> (optionally capture image) -> speak, until
    # session_state.start is cleared.
    while st.session_state.start:
        if st.session_state.recording:
            recording(status)
        if not st.session_state.recording and st.session_state.start:
            encoded_img = None
            if st.session_state.use_vision:
                # last_img = webrtc_ctx.video_receiver.get_frame(timeout=5).to_image()
                last_img = last_video_frame
                if last_img:
                    with st.chat_message("user"):
                        st.image(last_img, width=300)
                    st.session_state.messages.append({"role": "user", "content": last_img, "type": "img"})
                    input_img = last_img
                    # JPEG-encode the frame; keep a debug copy on disk.
                    buffer = BytesIO()
                    input_img.save(buffer, format="JPEG")
                    with open("input_image.jpg", "wb") as f:
                        f.write(buffer.getvalue())
                    encoded_img = str(base64.b64encode(buffer.getvalue()), encoding="utf-8")
                else:
                    st.error("No image captured. Please allow camera access and try again.")
                    return
            st.session_state.speaking = True
            speaking(status, resp_text_holder, encoded_img)
    # if stop_button:
    #     status.warning("Stopped, click Start to chat")
    #     st.session_state.start = False
    #     st.session_state.recording = False
    #     st.session_state.frames = []
    #     break


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment