Commit b97afd54 authored by wangwei990215

Initial commit
import time
import torch
try:
import torch_musa
except ImportError:
print("You should install torch_musa if you want to run on a Moore Threads GPU")
import os
import argparse
import torchaudio
from torchaudio.transforms import Resample
import logging
from mooer.datasets.speech_processor import *
from mooer.configs import asr_config
from mooer.models import mooer_model
from mooer.utils.utils import *
from mooer.models.hifigan import save_wav, get_hifigan_model, get_speaker_encoder, encode_prompt_wav
parser = argparse.ArgumentParser()
parser.add_argument("--wav_path", default='demo/resources/demo.wav', type=str, help="decode one wav file")
parser.add_argument("--wav_scp", default=None, type=str, help="decode scp if you want")
parser.add_argument("--task", default='s2s_chat', choices=['asr', 'ast', 's2s_trans', 's2s_chat'],
type=str, help="task: asr or ast or s2s_trans or s2s_chat. "
"Please set ast if you choose a asr/ast/s2s_trans/s2s_chat multitask model")
parser.add_argument("--batch_size", default=1, type=int, help="decode batch for scp")
parser.add_argument("--cmvn_path", default='', type=str, help="cmvn path.")
parser.add_argument("--encoder_path", default='', type=str, help="encoder path.")
parser.add_argument("--llm_path", default='', type=str, help="llm path.")
parser.add_argument("--adapter_path", default='', type=str, help="adapter path.")
parser.add_argument("--lora_dir", default='', type=str, help="lora path.")
parser.add_argument("--vocoder_path", default='', type=str, help="vocoder path")
parser.add_argument("--spk_encoder_path", default='', type=str, help="spk encoder path")
parser.add_argument("--prompt_wav_path", default='', type=str, help="prompt wav path")
parser.add_argument("--output_dir", default="response_wavs_dir", type=str, help="path to save wav generated")
args = parser.parse_args()
assert args.batch_size == 1, "Only batch_size=1 is supported for the S2ST task for now. Batch inference will be supported soon."
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
filemode='w'
)
PROMPT_TEMPLATE_DICT = {
'qwen': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
}
PROMPT_DICT = {
'asr': "Transcribe speech to text. ",
'ast': "Translate speech to english text. ",
's2s_trans': "Translate speech to english speech. ",
's2s_chat': "Answer my question with speech. "
}
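# Illustrative note (not part of the original script): for the default s2s_chat
# task, prompt_template.format(prompt_org) below produces
#   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
#   <|im_start|>user\nAnswer my question with speech. <|im_end|>\n<|im_start|>assistant\n
# i.e. the task prompt is injected into the Qwen chat template as the user turn.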
model_config = asr_config.ModelConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# replace path
if args.llm_path and os.path.exists(args.llm_path):
model_config.llm_path = args.llm_path
if args.encoder_path and os.path.exists(args.encoder_path):
model_config.encoder_path = args.encoder_path
if args.adapter_path and os.path.exists(args.adapter_path):
model_config.adapter_path = args.adapter_path
if args.lora_dir and os.path.exists(args.lora_dir):
model_config.lora_dir = args.lora_dir
if args.cmvn_path and os.path.exists(args.cmvn_path):
model_config.cmvn_path = args.cmvn_path
if args.task:
model_config.prompt_key = args.task
device = str(get_device())
logger.info("This demo will run on {}".format(device.upper()))
logger.info(model_config)
os.makedirs(args.output_dir, exist_ok=True)
logger.info("Response wav will save in {}".format(args.output_dir))
model, tokenizer = mooer_model.init_model(
model_config=model_config)
AUDIO_START_TOKEN_INDEX = tokenizer.get_vocab()['<|audio_start|>']
model.to(device)
model.eval()
# data process
prompt_template_key = model_config.get('prompt_template_key', 'qwen')
prompt_template = PROMPT_TEMPLATE_DICT[prompt_template_key]
prompt_key = model_config.get('prompt_key', 'asr')
prompt_org = PROMPT_DICT[prompt_key]
logger.info(f"Use LLM Type {prompt_template_key}, "
f"Prompt template {prompt_template}, "
f"Use task type {prompt_key}, "
f"Prompt {prompt_org}")
cmvn = load_cmvn(model_config.get('cmvn_path'))
adapter_downsample_rate = model_config.get('adapter_downsample_rate')
hifigan_generator = get_hifigan_model(args.vocoder_path, device, decoder_dim=3584)
spk_encoder = get_speaker_encoder(args.spk_encoder_path, device)
spk_embedding = encode_prompt_wav(spk_encoder, args.prompt_wav_path, device)
def process_wav(wav_path):
audio_raw, sample_rate = torchaudio.load(wav_path)
if sample_rate != 16000:
# resample the data
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
audio_raw = resampler(audio_raw)
if audio_raw.shape[0] > 1:
# convert to mono
audio_raw = audio_raw.mean(dim=0, keepdim=True)
audio_raw = audio_raw[0]
prompt = prompt_template.format(prompt_org)
audio_mel = compute_fbank(waveform=audio_raw)
audio_mel = apply_lfr(inputs=audio_mel, lfr_m=7, lfr_n=6)
audio_mel = apply_cmvn(audio_mel, cmvn=cmvn)
audio_length = audio_mel.shape[0]
audio_length = audio_length // adapter_downsample_rate
audio_pseudo = torch.full((audio_length,), -1)
prompt_ids = tokenizer.encode(prompt)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio, prompt]
example_mask = example_ids.ge(-1)
items = {
"input_ids": example_ids,
"attention_mask": example_mask,
"audio_mel": audio_mel,
"audio_length": audio_length,
"prompt_length": prompt_length,
}
return items
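# Worked example for the front end above (illustrative numbers): a 10 s, 16 kHz
# input gives 160000 samples -> 998 fbank frames (25 ms window, 10 ms shift)
# -> ceil(998 / 6) = 167 LFR frames -> audio_length = 167 // 2 = 83 pseudo audio
# tokens, assuming adapter_downsample_rate = 2 as in the shipped configs.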
load_dtype = model_config.get('load_dtype', 'bfloat16')
dtype = torch.float32
if load_dtype == 'float16':
dtype = torch.float16
elif load_dtype == 'bfloat16':
dtype = torch.bfloat16
logging.info(f"Input data type: {dtype}")
context_scope = torch.musa.amp.autocast if 'musa' in device else torch.cuda.amp.autocast
with torch.no_grad():
if args.wav_scp is not None and os.path.exists(args.wav_scp):
batch_size = args.batch_size
infer_time = []
items = parse_key_text(args.wav_scp)
uttids = list(items.keys())
num_batches = len(uttids) // batch_size + (0 if len(uttids) % batch_size == 0 else 1)
for i in range(num_batches):
try:
batch_uttids = uttids[i * batch_size:(i + 1) * batch_size]
batch_wav_paths = [items[uttid] for uttid in batch_uttids]
samples = []
for wav_path in batch_wav_paths:
samples.append(process_wav(wav_path))
batch = process_batch(samples, tokenizer=tokenizer)
for key in batch.keys():
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
with context_scope(dtype=dtype):
ss = time.perf_counter()
inputs_embeds, attention_mask, kwargs = model.generate(**batch, compute_llm=False)
prompt_and_encoding_length = inputs_embeds.shape[1]
model_outputs = model.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 1000),
num_beams=kwargs.get("num_beams", 4),
do_sample=True,
min_length=kwargs.get("min_length", 1),
top_p=0.85,
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
bos_token_id=model.tokenizer.bos_token_id,
eos_token_id=model.tokenizer.eos_token_id,
pad_token_id=model.tokenizer.pad_token_id,
)
infer_time.append(time.perf_counter() - ss)
logging.info(f"Infer time: {time.perf_counter() - ss}")
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
skip_special_tokens=True)
if hasattr(model.llm.model, "embed_tokens"):
teacher_forcing_input_embeds = model.llm.model.embed_tokens(model_outputs)
teacher_forcing_input_att_mask = torch.ones((1, teacher_forcing_input_embeds.shape[1]),
dtype=torch.bool).to(device)
else:
raise NotImplementedError
inputs_embeds = torch.concat([inputs_embeds, teacher_forcing_input_embeds], dim=-2)
attention_mask = torch.concat([attention_mask, teacher_forcing_input_att_mask], dim=-1)
llm_output = model.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
output_hidden_states=True)
audio_start_index = prompt_and_encoding_length + model_outputs[0].tolist().index(AUDIO_START_TOKEN_INDEX)
audio_latents = llm_output.hidden_states[-1][:, audio_start_index:-6, :]
for idx, text in enumerate(output_text):
logger.info(f"uttid: {batch_uttids[idx]}")
audio_file_out_tts = os.path.join(args.output_dir, f"{batch_uttids[idx]}.tts.wav")
text_ast = text.split("<|audio_start|>")[0]
text_ast = text_ast.replace('\\n', '\n')
logger.info(f"AST: {text_ast}")
save_wav(hifigan_generator, spk_embedding, audio_latents.float(), audio_file_out_tts)
logger.info(f"Finished writing: {audio_file_out_tts}")
except Exception as e:
logging.error(e)
logging.info("Total inference cost")
logging.info(sum(infer_time))
elif args.wav_path != '' and os.path.exists(args.wav_path):
try:
wav_path = args.wav_path
items = process_wav(wav_path)
batch = process_batch([items], tokenizer=tokenizer)
for key in batch.keys():
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
with context_scope(dtype=dtype):
ss = time.perf_counter()
inputs_embeds, attention_mask, kwargs = model.generate(**batch, compute_llm=False)
prompt_and_encoding_length = inputs_embeds.shape[1]
model_outputs = model.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 1000),
num_beams=kwargs.get("num_beams", 4),
do_sample=True,
min_length=kwargs.get("min_length", 1),
top_p=0.85,
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
bos_token_id=model.tokenizer.bos_token_id,
eos_token_id=model.tokenizer.eos_token_id,
pad_token_id=model.tokenizer.pad_token_id,
)
logging.info(f"Infer time: {time.perf_counter() - ss}")
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
skip_special_tokens=True)
if hasattr(model.llm.model, "embed_tokens"):
teacher_forcing_input_embeds = model.llm.model.embed_tokens(model_outputs)
teacher_forcing_input_att_mask = torch.ones((1, teacher_forcing_input_embeds.shape[1]),
dtype=torch.bool).to(device)
else:
raise NotImplementedError
inputs_embeds = torch.concat([inputs_embeds, teacher_forcing_input_embeds], dim=-2)
attention_mask = torch.concat([attention_mask, teacher_forcing_input_att_mask], dim=-1)
llm_output = model.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
output_hidden_states=True)
audio_start_index = prompt_and_encoding_length + model_outputs[0].tolist().index(AUDIO_START_TOKEN_INDEX)
audio_latents = llm_output.hidden_states[-1][:, audio_start_index:-6, :]
for text in output_text:
uttid = os.path.basename(wav_path).replace(".wav", "")
audio_file_out_tts = os.path.join(args.output_dir, f"{uttid}.tts.wav")
text_ast = text.split("<|audio_start|>")[0]
text_ast = text_ast.replace('\\n', '\n')
logger.info(f"Text: {text_ast}")
save_wav(hifigan_generator, spk_embedding, audio_latents.float(), audio_file_out_tts)
logger.info(f"Finished writing: {audio_file_out_tts}")
except Exception as e:
logging.error(e)
else:
raise IOError("You should specify --wav_scp or --wav_path as the input")
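# Example invocation (a sketch: the script filename and checkpoint paths below
# are placeholders, not taken from the original repo):
#   python inference_s2s.py --task s2s_chat \
#       --wav_path demo/resources/demo.wav \
#       --vocoder_path pretrained_models/vocoder.pt \
#       --spk_encoder_path pretrained_models/spk_encoder.pt \
#       --prompt_wav_path demo/resources/prompt.wav \
#       --output_dir response_wavs_dir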
# Model code
modelCode=1067
# Model name
modelName=mooer_pytorch
# Model description
modelDescription=A speech recognition and speech translation system based on a large language model (LLM), developed by Moore Threads.
# Application scenarios (multiple tags separated by commas)
appScenario=Inference,Speech Recognition,Speech Translation,Education,Healthcare
# Framework type (multiple tags separated by commas)
frameType=PyTorch
<Nnet>
<Splice> 560 560
[ 0 ]
<AddShift> 560 560
<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 
-13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
<Rescale> 560 560
<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
</Nnet>
from dataclasses import dataclass
from typing import Optional
@dataclass
class ModelConfig:
def __init__(self):
self.llm_name: str = "qwen2_7b_chat"
# You should set your own path
self.llm_path: str = "pretrained_models/Qwen2-7B-Instruct"
self.encoder_path: str = "pretrained_models/paraformer_encoder/paraformer-encoder.pth"
self.adapter_path: str = "pretrained_models/asr_ast_mtl/adapter_project.pt"
self.lora_dir: str = "pretrained_models/asr_ast_mtl/lora_weights"
self.cmvn_path: str = "pretrained_models/paraformer_encoder/am.mvn"
self.prompt_key: str = 'ast' # or asr for ASR model
###############################
self.llm_type: str = "decoder_only"
self.llm_dim: int = 3584
self.load_dtype: str = "bfloat16"
self.encoder_name: str = 'paraformer'
self.encoder_dim: int = 512
self.adapter: str = "linear"
self.adapter_downsample_rate: int = 2
self.modal: str = "audio"
self.normalize: Optional[bool] = False
self.gradient_checkpoint: bool = False
self.is_inference: bool = True
self.prompt_template_key: str = 'qwen'
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
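# Minimal usage sketch (not part of the original file): ModelConfig supports both
# attribute access and dict-style lookup with an optional default.
if __name__ == "__main__":
    cfg = ModelConfig()
    assert cfg.llm_dim == cfg['llm_dim'] == 3584
    assert cfg.get('not_a_field', 'fallback') == 'fallback'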
from dataclasses import dataclass
from typing import Optional, List
@dataclass
class ModelConfig:
def __init__(self):
self.llm_name: str = "qwen2_7b_chat"
# You should set your own path
self.llm_path: str = "pretrained_models/Qwen2-7B-Instruct"
self.encoder_path: str = "pretrained_models/paraformer_encoder/paraformer-encoder.pth"
self.adapter_path: str = "pretrained_models/asr_ast_mtl/adapter_project.pt"
self.lora_dir: str = "pretrained_models/asr_ast_mtl/lora_weights"
self.cmvn_path: str = "/root/MooER/src/mooer/configs/am.mvn"
self.prompt_key: str = 'asr' # asr, ast... you can add tasks in src/mooer/utils/data_utils.py
###############################
self.llm_type: str = "decoder_only"
self.llm_dim: int = 3584
self.load_dtype: str = "bfloat16"
self.encoder_name: str = 'paraformer'
self.encoder_dim: int = 512
self.adapter: str = "linear"
self.adapter_downsample_rate: int = 2
self.modal: str = "audio"
self.normalize: Optional[bool] = False
self.gradient_checkpoint: bool = False
self.is_inference: bool = True
self.prompt_template_key: str = 'qwen'
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class PeftConfig:
def __init__(self):
self.peft_method: str = "lora" # None , llama_adapter, prefix
self.r: int = 64
self.lora_alpha: int = 16
self.target_modules: List = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
]
self.bias: str = "none"
self.task_type: str = "CAUSAL_LM"
self.lora_dropout: float = 0.05
self.inference_mode: bool = False
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class TrainConfig:
def __init__(self):
self.model_name: str = "asr"
self.enable_deepspeed: bool = True
self.batch_size_training: int = 8 # set this to the same value as in the deepspeed config for throughput
self.batching_strategy: str = 'custom'
self.context_length: int = 4096
self.num_epochs: int = 10
self.num_workers_dataloader: int = 4
# please set it in deepspeed config
# self.warmup_steps: int = 1000
# self.total_steps: int = 1000000
# self.lr: float = 1e-4
# self.weight_decay: float = 0.0
self.save_interval: int = 20000
self.save_merge_rank: bool = True
# merge the deepspeed model shards saved by the individual ranks
self.log_interval: int = 100
self.resume_step: int = 0
self.resume_epoch: int = 0
self.gamma: float = 0.85
self.seed: int = 42
self.use_fp16: bool = False
self.use_bf16: bool = True
self.mixed_precision: bool = True
self.val_batch_size: int = 10
self.use_peft: bool = True
self.output_dir: str = "output/save_models"
self.freeze_llm: bool = True
self.freeze_encoder: bool = True
self.freeze_projector: bool = True
self.find_unused_parameters: bool = False
self.gradient_checkpoint: bool = False
self.deepspeed_config: str = '/root/MooER/src/mooer/configs/deepspeed_config_zero2.json'
# use zero3 if you need a larger batch size or lower memory usage, at the cost of speed
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class DataConfig:
def __init__(self):
self.train_data_path: Optional[str] = ''
self.val_data_path: Optional[str] = ''
self.test_data_dir: str = '/Your/testsets/root'
self.test_sets: str = 'test-clean/test-other/aishell'
# you can put a series of test sets under test_data_dir for testing; use / as the separator
self.decode_path: Optional[str] = ''
self.fix_length_audio: int = -1
self.max_length: int = 2000
self.min_length: int = 20
self.mel_size: int = 80
self.train_data_type: str = 'shard'
self.test_data_type: str = 'shard'
self.prompt_template_key: str = 'qwen'
self.prompt_key: str = 'asr'
self.w2v_bert_path: str = ''
self.sort: bool = False
self.replace_text_path: str = ''
self.replace_type: str = 'replace'
# you can use replace_text_path & replace_type to train another task, e.g. AST, with the same uttid but a different label
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
def update(model_config, train_config, data_config):
train_config.is_inference = model_config.is_inference
data_config.is_inference = model_config.is_inference
data_config.num_epochs = train_config.num_epochs
data_config.adapter_downsample_rate = model_config.adapter_downsample_rate
data_config.cmvn_path = model_config.cmvn_path
data_config.encoder_name = model_config.encoder_name
data_config.normalize = model_config.normalize
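# Typical wiring sketch (an assumption about call order, not taken verbatim from
# the training scripts): build the three configs, then call update() so that the
# dataset pipeline sees the encoder/CMVN/adapter settings chosen in ModelConfig.
#   model_config, train_config, data_config = ModelConfig(), TrainConfig(), DataConfig()
#   update(model_config, train_config, data_config)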
from dataclasses import dataclass
from typing import Optional, List
@dataclass
class ModelConfig:
def __init__(self):
self.llm_name: str = "qwen2_7b_chat"
# You should set your own path
self.llm_path: str = "pretrained_models/Qwen2-7B-Instruct"
self.encoder_path: str = "pretrained_models/paraformer_encoder/paraformer-encoder.pth"
self.adapter_path: Optional[str] = ''
self.lora_dir: Optional[str] = ''
self.cmvn_path: str = "/root/MooER/src/mooer/configs/am.mvn"
self.prompt_key: str = 'asr' # asr, ast... you can add tasks in src/mooer/utils/data_utils.py
###############################
self.llm_type: str = "decoder_only"
self.llm_dim: int = 3584
self.load_dtype: str = "bfloat16"
self.encoder_name: str = 'paraformer'
self.encoder_dim: int = 512
self.adapter: str = "linear"
self.adapter_downsample_rate: int = 2
self.modal: str = "audio"
self.normalize: Optional[bool] = False
self.gradient_checkpoint: bool = False
self.is_inference: bool = False
self.prompt_template_key: str = 'qwen'
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class PeftConfig:
def __init__(self):
self.peft_method: str = "lora" # None , llama_adapter, prefix
self.r: int = 64
self.lora_alpha: int = 16
self.target_modules: List = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
]
self.bias: str = "none"
self.task_type: str = "CAUSAL_LM"
self.lora_dropout: float = 0.05
self.inference_mode: bool = False
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class TrainConfig:
def __init__(self):
self.model_name: str = "asr"
self.enable_deepspeed: bool = True
self.batch_size_training: int = 8 # set this to the same value as in the deepspeed config for throughput
self.batching_strategy: str = 'custom'
self.context_length: int = 4096
self.num_epochs: int = 10
self.num_workers_dataloader: int = 4
# please set it in deepspeed config
# self.warmup_steps: int = 1000
# self.total_steps: int = 1000000
# self.lr: float = 1e-4
# self.weight_decay: float = 0.0
self.save_interval: int = 20000
self.save_merge_rank: bool = True
# merge the deepspeed model shards saved by the individual ranks
self.log_interval: int = 100
self.resume_step: int = 0
self.resume_epoch: int = 0
self.gamma: float = 0.85
self.seed: int = 42
self.use_fp16: bool = False
self.use_bf16: bool = True
self.mixed_precision: bool = True
self.val_batch_size: int = 1
self.use_peft: bool = True
self.output_dir: str = "output/save_models"
self.freeze_llm: bool = True
self.freeze_encoder: bool = True
self.freeze_projector: bool = False
self.find_unused_parameters: bool = False
self.gradient_checkpoint: bool = False
self.deepspeed_config: str = '/root/MooER/src/mooer/configs/deepspeed_config_zero2.json'
# use zero3 if you need a larger batch size or lower memory usage, at the cost of speed
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class DataConfig:
def __init__(self):
self.train_data_path: str = '/YOUR/training/data.0.list'
self.val_data_path: Optional[str] = ''
self.test_data_dir: Optional[str] = ''
self.test_sets: Optional[str] = ''
self.decode_path: Optional[str] = ''
self.fix_length_audio: int = -1
self.max_length: int = 2000
self.min_length: int = 20
self.mel_size: int = 80
self.train_data_type: str = 'shard'
self.test_data_type: str = 'shard'
self.prompt_template_key: str = 'qwen'
self.prompt_key: str = 'asr'
self.w2v_bert_path: str = ''
self.num_epochs: int = 10
self.sort: bool = False
self.replace_text_path: str = ''
self.replace_type: str = 'replace'
# you can use replace_text_path & replace_type to train another task, e.g. AST, with the same uttid but a different label
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
def update(model_config, train_config, data_config):
train_config.is_inference = model_config.is_inference
data_config.is_inference = model_config.is_inference
data_config.num_epochs = train_config.num_epochs
data_config.adapter_downsample_rate = model_config.adapter_downsample_rate
data_config.cmvn_path = model_config.cmvn_path
data_config.encoder_name = model_config.encoder_name
data_config.normalize = model_config.normalize
{
"train_micro_batch_size_per_gpu": 8,
"gradient_accumulation_steps": 2,
"steps_per_print": 100,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"total_num_steps": 1000000,
"warmup_max_lr": 0.0001,
"warmup_num_steps": 1000
}
},
"fp16": {
"enabled": false,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true,
"contiguous_memory_optimization": true,
"synchronize_checkpoint_boundary": false,
"profile": false
}
}
{
"train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps": 2,
"steps_per_print": 100,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"total_num_steps": 10000000,
"warmup_max_lr": 0.0001,
"warmup_num_steps": 1000
}
},
"fp16": {
"enabled": false,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true,
"contiguous_memory_optimization": true,
"synchronize_checkpoint_boundary": false,
"profile": false
}
}
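Note on the two DeepSpeed configs above: the per-GPU effective batch size per optimizer step is train_micro_batch_size_per_gpu * gradient_accumulation_steps, i.e. 8 * 2 = 16 for the ZeRO-2 config and 16 * 2 = 32 for the ZeRO-3 config; the global batch size additionally scales with the number of data-parallel GPUs, which these files do not fix.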
import logging
import os
import random
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from transformers import AutoFeatureExtractor
import mooer.datasets.speech_processor as processor
from mooer.utils.data_utils import PROMPT_DICT, PROMPT_TEMPLATE_DICT
def read_lists(list_file, num_epochs=1, shuffle=False):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
lists.append(line.strip())
lists = lists * num_epochs
if shuffle:
random.shuffle(lists)
return lists
def load_cmvn(cmvn_file):
with open(cmvn_file, "r", encoding="utf-8") as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == "<AddShift>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
add_shift_line = line_item[3 : (len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == "<Rescale>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
rescale_line = line_item[3 : (len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float32)
vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
return cmvn
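# Sanity-check sketch (illustrative): applied to the Paraformer am.mvn shown
# earlier in this commit, load_cmvn returns a (2, 560) float32 tensor, where
# row 0 holds the additive shifts (<AddShift>) and row 1 the rescale factors
# (<Rescale>).
#   cmvn = load_cmvn('pretrained_models/paraformer_encoder/am.mvn')
#   assert tuple(cmvn.shape) == (2, 560)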
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = list(range(len(data)))
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
data = data[self.rank::self.world_size]
data = data[self.worker_id::self.num_workers]
return data
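# Example of the two-level partition above (illustrative): with world_size=2,
# num_workers=2, shuffle=False and 8 shards, rank 0 keeps indices [0, 2, 4, 6];
# its dataloader worker 0 then keeps [0, 4] and worker 1 keeps [2, 6], so every
# shard is read exactly once per epoch across ranks and workers.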
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
for index in indexes:
# yield dict(src=src)
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
class SpeechDatasetShard(torch.utils.data.Dataset):
def __init__(self,
dataset_config,
normalize=True,
mel_size=128,
tokenizer=None):
super().__init__()
self.dataset_config = dataset_config
self.tokenizer = tokenizer
self.IGNORE_INDEX = -100
self.normalize = normalize
self.mel_size = mel_size
self.max_length = dataset_config.get('max_length', 2000)
self.min_length = dataset_config.get('min_length', 20)
self.prompt_template_key = dataset_config.get('prompt_template_key', 'qwen')
self.prompt_template = PROMPT_TEMPLATE_DICT[self.prompt_template_key]
self.prompt_key = dataset_config.get('prompt_key', 'asr')
if self.prompt_key == 'instruction':
self.prompt = PROMPT_DICT
else:
self.prompt = PROMPT_DICT[self.prompt_key]
logging.info(f"Use LLM Type {self.prompt_template_key}, "
f"Prompt template {self.prompt_template}, "
f"Use task type {self.prompt_key}, "
f"Prompt {self.prompt}")
self.is_inference = dataset_config.get('is_inference', False)
if (dataset_config.get('w2v_bert_path', None) is not None) and (os.path.exists(dataset_config.w2v_bert_path)):
self.auto_processer = AutoFeatureExtractor.from_pretrained(dataset_config.w2v_bert_path)
else:
self.auto_processer = None
if (dataset_config.get('cmvn_path', None) is not None) and (os.path.exists(dataset_config.cmvn_path)):
self.cmvn = load_cmvn(dataset_config.cmvn_path)
else:
assert self.dataset_config.encoder_name != 'paraformer', 'paraformer must use cmvn'
self.cmvn = None
self.num_epochs = dataset_config.get('num_epochs', 1)
self.adapter_downsample_rate = dataset_config.get('adapter_downsample_rate', 2)
self.sort = dataset_config.get('sort', True)
self.replace_text_table = None
self.replace_text_path = dataset_config.get('replace_text_path', '')
self.replace_type = dataset_config.get('replace_type', 'replace')
if self.prompt_key == 'instruction':
self.replace_type = 'instruction'
if os.path.exists(self.replace_text_path):
logging.info(f"Parsing replaced table {self.replace_text_path}..., Method {self.replace_type}")
self.replace_text_table = self.parse_txt2dict(self.replace_text_path)
if self.dataset_config.encoder_name in ['paraformer', 'whisper', 'w2v_bert2.0']:
self.input_type = 'mel'
else:
self.input_type = 'raw'
@classmethod
def parse_txt2dict(cls, txt_path):
result = {}
with open(txt_path, 'r') as r:
for line in r.readlines():
line = line.strip()
if line == '':
continue
line = line.split(maxsplit=1)
if len(line) != 2:
continue
key, text = line
result[key] = text
return result
def dataset(self,
data_type,
data_list_file,
shuffle=True,
partition=True):
assert data_type in ['raw', 'shard']
lists = read_lists(data_list_file, num_epochs=self.num_epochs, shuffle=shuffle)
dataset = DataList(lists, shuffle=shuffle, partition=partition)
if data_type == 'shard':
dataset = Processor(dataset, processor.url_opener)
dataset = Processor(dataset, processor.tar_file_and_group)
else:
dataset = Processor(dataset, processor.parse_raw)
if not self.is_inference:
dataset = Processor(dataset, processor.filter,
max_length=self.max_length,
min_length=self.min_length)
if self.replace_text_table is not None:
dataset = Processor(dataset, processor.refine_text,
replaced_table=self.replace_text_table,
replace_type=self.replace_type)
dataset = Processor(dataset, processor.gen_llm_inputs,
tokenizer=self.tokenizer,
ignore_index=self.IGNORE_INDEX,
normalize=self.normalize,
mel_size=self.mel_size,
input_type=self.input_type,
is_paraformer=self.dataset_config.encoder_name == 'paraformer',
prompt_template=self.prompt_template,
is_inference=self.is_inference,
autoprocesser=self.auto_processer,
cmvn=self.cmvn,
prompt_org=self.prompt,
adapter_downsample_rate=self.adapter_downsample_rate,
instruction=self.prompt_key == 'instruction')
if not self.is_inference:
# add shuffle
dataset = Processor(dataset, processor.shuffle)
if self.sort:
sort_key = 'audio' if self.dataset_config.encoder_name == 'hubert' else 'audio_mel'
dataset = Processor(dataset, processor.sort, sort_size=2000, key=sort_key)
return dataset
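# Usage sketch (names are illustrative; a DataConfig-like object and a
# HuggingFace tokenizer are assumed to be available):
#   helper = SpeechDatasetShard(dataset_config=data_config, tokenizer=tokenizer,
#                               normalize=False, mel_size=80)
#   train_ds = helper.dataset('shard', data_config.train_data_path, shuffle=True)
#   loader = torch.utils.data.DataLoader(train_ds, batch_size=8,
#                                        collate_fn=helper.collator, num_workers=4)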
def pad(self, sequence, max_length, padding_idx=0):
if isinstance(sequence, (int, list, tuple)):
if len(sequence) < max_length:
sequence = sequence + [padding_idx] * (max_length - len(sequence))
else:
sequence = sequence[:max_length]
elif isinstance(sequence, torch.Tensor):
if len(sequence) < max_length:
sequence = torch.cat(
(sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx)))
else:
sequence = sequence[:max_length]
elif isinstance(sequence, np.ndarray):
if len(sequence) < max_length:
sequence = np.concatenate(
(sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx)))
else:
sequence = sequence[:max_length]
else:
raise Exception("Type mismatch during padding!")
return sequence
@classmethod
def padding(cls, sequence, padding_length, padding_idx=0, padding_side="right"):
if isinstance(sequence, (int, list, tuple)):
if padding_length >= 0:
sequence = sequence + [padding_idx] * padding_length
else:
sequence = sequence[:padding_length]
elif isinstance(sequence, torch.Tensor):
if sequence.ndimension() == 2:
if padding_length >= 0:
sequence = torch.nn.functional.pad(sequence, (0, padding_length))
else:
sequence = sequence[:, :padding_length]
else:
if padding_length >= 0:
if padding_side == "left":
sequence = torch.cat(
(torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx), sequence))
else:
sequence = torch.cat(
(sequence, torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx)))
else:
sequence = sequence[:padding_length]
elif isinstance(sequence, np.ndarray):
if padding_length >= 0:
sequence = np.concatenate(
(sequence, np.full((padding_length,) + sequence.shape[1:], padding_idx)))
else:
sequence = sequence[:padding_length]
else:
raise Exception("Type mismatch during padding!")
return sequence
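# Behaviour sketch of `padding` above (illustrative): a 1-D tensor [1, 2, 3] with
# padding_length=2 and padding_idx=0 becomes [1, 2, 3, 0, 0] (padding_side="right")
# or [0, 0, 1, 2, 3] (padding_side="left"); a negative padding_length truncates
# instead of padding. 2-D tensors are padded along the last dimension.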
def collator(self, samples):
assert samples is not None
input_prompt_lengths = [s["audio_length"] + s['prompt_length'] for s in samples] # [120, 48, 82, 42]
input_answer_lengths = [len(s["input_ids"]) - s["audio_length"] - s['prompt_length'] for s in
samples] # [0, 0, 0, 0]
input_prompt_max_length = max(input_prompt_lengths)
input_answer_max_length = max(input_answer_lengths)
input_ids = torch.stack([
self.padding(
self.padding(samples[index]["input_ids"], input_prompt_max_length - input_prompt_lengths[index],
self.tokenizer.pad_token_id, padding_side="left"),
input_answer_max_length - input_answer_lengths[index], self.tokenizer.pad_token_id
) for index in range(len(samples))
])
attention_mask = torch.stack([
self.padding(
self.padding(samples[index]["attention_mask"], input_prompt_max_length - input_prompt_lengths[index],
False, padding_side="left"),
input_answer_max_length - input_answer_lengths[index], False
) for index in range(len(samples))
])
if self.auto_processer is not None:
audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples])
audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0)
for s in samples])
audio_mel_post_mask = torch.zeros(len(samples), (
audio_mel_max_length)) # w2v-bert
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0])] = 1
elif self.dataset_config.encoder_name == 'paraformer':
audio_mel_reallen = [s['audio_mel'].shape[0] for s in samples]
audio_mel_max_length = max(audio_mel_reallen)
audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0)
for s in samples])
audio_mel_post_mask = torch.zeros(len(samples), (
audio_mel_max_length)) # paraformer
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0])] = 1
audio_mel_reallen = torch.tensor(audio_mel_reallen, dtype=torch.int32)
elif self.dataset_config.encoder_name == 'hubert':
audio_raw_max_length = max([s['audio'].shape[0] for s in samples])
audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0)
for s in samples])
audio_mask = torch.zeros(len(samples), audio_raw_max_length)
for line, sample in enumerate(samples):
audio_mask[line, :sample['audio'].shape[0]] = 1
elif self.dataset_config.encoder_name == 'whisper':
audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples])
audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0)
for s in samples])
audio_mel_post_mask = torch.zeros(len(samples), (
audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1
modality_mask = torch.zeros_like(attention_mask)
for index in range(len(samples)):
padding_left = input_prompt_max_length - input_prompt_lengths[index]
modality_mask[index, padding_left:padding_left + samples[index]["audio_length"]] = True
keys = [s['key'] for s in samples]
if self.is_inference:
targets = [s['target'] for s in samples]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"audio": audio_raw if self.input_type == "raw" else None,
"audio_mask": audio_mask if self.input_type == "raw" else None,
"audio_mel": audio_mel if self.input_type == "mel" else None,
"audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None,
"modality_mask": modality_mask,
"keys": keys,
"targets": targets,
"audio_mel_reallen": audio_mel_reallen if self.dataset_config.encoder_name == 'paraformer' else None
}
labels = torch.stack([
self.padding(
self.padding(samples[index]['labels'], input_prompt_max_length - input_prompt_lengths[index],
self.IGNORE_INDEX, padding_side="left"),
input_answer_max_length - input_answer_lengths[index], self.IGNORE_INDEX)
for index in range(len(samples))
])
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"audio": audio_raw if self.input_type == "raw" else None,
"audio_mask": audio_mask if self.input_type == "raw" else None,
"audio_mel": audio_mel if self.input_type == "mel" else None,
"audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None,
"modality_mask": modality_mask,
"audio_mel_reallen": audio_mel_reallen if self.dataset_config.encoder_name == 'paraformer' else None,
"keys": keys,
}
import copy
import json
import logging
import random
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
def compute_fbank(waveform,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
fs=16000,
snip_edges=True,
window_type="hamming"):
sample_rate = fs
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
# Kaldi-compatible log-mel filterbank features
mat = kaldi.fbank(waveform,
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
sample_frequency=sample_rate,
window_type=window_type,
snip_edges=snip_edges)
return mat
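# Minimal shape check for compute_fbank (a sketch, not part of the original file):
# 1 s of 16 kHz audio analysed with 25 ms windows and a 10 ms shift yields
# 98 frames of 80 log-mel bins.
if __name__ == "__main__":
    _demo_wav = torch.rand(16000) - 0.5  # 1 s of random audio in [-0.5, 0.5)
    print(compute_fbank(_demo_wav).shape)  # torch.Size([98, 80])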
def apply_lfr(inputs, lfr_m, lfr_n):
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
inputs = torch.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).view(1, -1))
else: # process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n :]).view(-1)
for _ in range(num_padding):
frame = torch.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
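# Effect of apply_lfr with lfr_m=7, lfr_n=6 (the values used for the Paraformer
# encoder): every output frame stacks 7 consecutive 80-dim fbank frames into one
# 560-dim vector and the frame rate drops 6x, so a (98, 80) fbank matrix becomes
# a (ceil(98 / 6), 560) = (17, 560) matrix.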
def load_cmvn(cmvn_file):
with open(cmvn_file, "r", encoding="utf-8") as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == "<AddShift>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
add_shift_line = line_item[3 : (len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == "<Rescale>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
rescale_line = line_item[3 : (len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float32)
vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
return cmvn
def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
device = inputs.device
dtype = inputs.dtype
frame, dim = inputs.shape
means = cmvn[0:1, :dim]
vars = cmvn[1:2, :dim]
inputs += means.to(device)
inputs *= vars.to(device)
return inputs.type(torch.float32)
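# Note on apply_cmvn above: row 0 of the cmvn tensor stores negative means (the
# <AddShift> values), so adding it centres the features, and row 1 stores inverse
# standard deviations (the <Rescale> values), so multiplying scales them to
# roughly unit variance; the operation modifies `inputs` in place.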
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
def url_opener(data):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
pr = urlparse(url)
# local file
if pr.scheme == '' or pr.scheme == 'file':
stream = open(url, 'rb')
# network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
else:
cmd = f'curl -s -L {url}'
process = Popen(cmd, shell=True, stdout=PIPE)
sample.update(process=process)
stream = process.stdout
sample.update(stream=stream)
yield sample
except Exception as ex:
logging.warning('Failed to open {}'.format(url))
def tar_file_and_group(data):
""" Expand a stream of open tar files into a stream of tar file contents.
and group files that share the same prefix.
Args:
data: Iterable[{src, stream}]
Returns:
Iterable[{key, wav, txt, sample_rate}]
"""
for sample in data:
assert 'stream' in sample
stream = tarfile.open(fileobj=sample['stream'], mode="r|*")
prev_prefix = None
example = {}
valid = True
try:
for tarinfo in stream:
name = tarinfo.name
pos = name.rfind('.')
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if prev_prefix is not None and prefix != prev_prefix:
example['key'] = prev_prefix
if valid:
yield example
example = {}
valid = True
with stream.extractfile(tarinfo) as file_obj:
try:
if postfix == 'txt':
example['txt'] = file_obj.read().decode('utf8').strip()
elif postfix in AUDIO_FORMAT_SETS:
waveform, sample_rate = torchaudio.load(file_obj)
example['wav'] = waveform
example['sample_rate'] = sample_rate
else:
example[postfix] = file_obj.read()
except Exception as ex:
valid = False
logging.warning('error to parse {}'.format(name))
prev_prefix = prefix
if prev_prefix is not None:
example['key'] = prev_prefix
yield example
stream.close()
if 'process' in sample:
sample['process'].communicate()
sample['stream'].close()
except Exception as e:
logging.warning(e)
logging.warning('error to parse {}'.format(sample))
def parse_raw(data):
""" Parse key/wav/txt from json line
Args:
data: Iterable[str], str is a json line has key/wav/txt
Returns:
Iterable[{key, wav, txt, sample_rate}]
"""
for sample in data:
assert 'src' in sample
json_line = sample['src']
obj = json.loads(json_line)
assert 'key' in obj
assert 'wav' in obj
assert 'txt' in obj
key = obj['key']
wav_file = obj['wav']
txt = obj['txt']
try:
if 'start' in obj:
assert 'end' in obj
sample_rate = torchaudio.backend.sox_io_backend.info(
wav_file).sample_rate
start_frame = int(obj['start'] * sample_rate)
end_frame = int(obj['end'] * sample_rate)
waveform, _ = torchaudio.backend.sox_io_backend.load(
filepath=wav_file,
num_frames=end_frame - start_frame,
frame_offset=start_frame)
else:
waveform, sample_rate = torchaudio.load(wav_file)
example = dict(key=key,
txt=txt,
wav=waveform,
sample_rate=sample_rate)
yield example
except Exception as ex:
logging.warning('Failed to read {}'.format(wav_file))
def filter(data,
max_length=10240,
min_length=10):
""" Filter sample according to feature and label length
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterances longer than max_length (in units of 10 ms frames)
min_length: drop utterances shorter than min_length (in units of 10 ms frames)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
yield sample
def refine_text(data, replaced_table, replace_type, concat_token='<|im_end|>'):
for sample in data:
assert 'txt' in sample
assert 'key' in sample
uttid = sample['key']
if replaced_table.get(uttid, None) is None:
continue
if replace_type == 'replace':
sample['txt'] = replaced_table[uttid]
elif replace_type == 'concat':
sample['txt'] = sample['txt'] + concat_token + replaced_table[uttid]
elif replace_type == 'concat_r':
sample['txt'] = replaced_table[uttid] + concat_token + sample['txt']
elif replace_type == 'instruction':
sample['txt'] = {
'asr': sample['txt'],
'ast': replaced_table[uttid],
'asr_ast': sample['txt'] + concat_token + replaced_table[uttid]
}
else:
raise KeyError
yield sample
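# Behaviour sketch for refine_text (illustrative): with
# replaced_table = {'utt1': 'HELLO'} and an incoming sample
# {'key': 'utt1', 'txt': 'hi', ...}:
#   replace     -> txt = 'HELLO'
#   concat      -> txt = 'hi<|im_end|>HELLO'
#   concat_r    -> txt = 'HELLO<|im_end|>hi'
#   instruction -> txt = {'asr': 'hi', 'ast': 'HELLO', 'asr_ast': 'hi<|im_end|>HELLO'}
# Samples whose key is missing from the table are skipped.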
def resample(data, resample_rate=16000):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
if sample_rate != resample_rate:
sample['sample_rate'] = resample_rate
sample['wav'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
yield sample
def speed_perturb(data, speeds=None):
""" Apply speed perturb to the data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
speeds(List[float]): optional speed
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
if speeds is None:
speeds = [0.9, 1.0, 1.1]
for sample in data:
assert 'sample_rate' in sample
assert 'wav' in sample
sample_rate = sample['sample_rate']
waveform = sample['wav']
speed = random.choice(speeds)
if speed != 1.0:
wav, _ = torchaudio.sox_effects.apply_effects_tensor(
waveform, sample_rate,
[['speed', str(speed)], ['rate', str(sample_rate)]])
sample['wav'] = wav
yield sample
def compute_w2vbert_fbank(sample,
num_mel_bins=23,
frame_length=25,
frame_shift=10,
dither=0.0):
""" Extract Pretrain w2vbert(4.5M hours) fbank
"""
sample = compute_fbank(sample, num_mel_bins, frame_length, frame_shift,
dither)
mat = sample['feat']
std, mean = torch.std_mean(mat, dim=0)
mat = mat.subtract(mean).divide(std)
sample['feat'] = mat
return sample
def gen_llm_inputs(data, tokenizer, ignore_index=-100, input_type='raw',
normalize=True, mel_size=128, prompt_template=None,
is_inference=False, autoprocesser=None, is_paraformer=False,
cmvn=None, prompt_org="Transcribe speech to text. ", adapter_downsample_rate=5,
instruction=False):
if prompt_template is None:
prompt_template = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
if instruction:
assert isinstance(prompt_org, dict)
else:
prompt = prompt_template.format(prompt_org)
answer_template = "{}"
fix_length_audio = -1
input_type = input_type
normalize = normalize
for sample in data:
audio_raw = sample['wav'][0]
if instruction:
target_dict = sample['txt']
task_list = list(target_dict.keys())
task_now = random.choice(task_list)
prompt = prompt_template.format(prompt_org[task_now])
target = target_dict[task_now].replace('▁', ' ')
else:
target = sample['txt'].replace('▁', ' ')
key = sample['key']
if autoprocesser is not None:
audio_mel = autoprocesser(audio_raw, sampling_rate=16000, return_tensors="pt")['input_features'].squeeze(0)
            audio_length = audio_mel.shape[0]  # w2v-BERT: 4x downsampling is already applied inside the autoprocesser
audio_length = audio_length // adapter_downsample_rate
input_type = "mel"
elif is_paraformer:
audio_mel = compute_fbank(waveform=audio_raw)
audio_mel = apply_lfr(inputs=audio_mel, lfr_m=7, lfr_n=6)
audio_mel = apply_cmvn(audio_mel, cmvn=cmvn)
audio_length = audio_mel.shape[0]
audio_length = audio_length // adapter_downsample_rate
input_type = "mel"
elif input_type == "raw":
if normalize:
audio_raw = torch.nn.functional.layer_norm(audio_raw, audio_raw.shape)
audio_length = len(audio_raw) // 320 # ad-hoc for fairseq 320x downsample
audio_length = audio_length // adapter_downsample_rate
elif input_type == "mel":
# NOTE: this is for whisper, you can use compute_fbank to support your encoder
import whisper
audio_raw = whisper.pad_or_trim(audio_raw)
audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size).permute(1, 0)
audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats
audio_length = audio_length // adapter_downsample_rate
else:
raise KeyError
if fix_length_audio > 0:
audio_length = fix_length_audio
audio_pseudo = torch.full((audio_length,), -1)
prompt_ids = tokenizer.encode(prompt)
prompt_length = len(prompt_ids)
if is_inference:
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt]
example_mask = example_ids.ge(-1)
yield {
"input_ids": example_ids,
"attention_mask": example_mask,
"audio": audio_raw if input_type == "raw" else None,
"audio_mel": audio_mel if input_type == "mel" else None,
"audio_length": audio_length,
"key": key,
"target": target,
"prompt_length": prompt_length,
}
else:
answer = answer_template.format(target)
example = prompt + answer
example_ids = tokenizer.encode(example)
example_ids.append(tokenizer.eos_token_id)
example_ids = torch.tensor(
example_ids, dtype=torch.int64
)
example_ids = torch.cat((audio_pseudo, example_ids))
labels_ids = copy.deepcopy(example_ids)
labels_ids[:audio_length + prompt_length] = -1
example_mask = example_ids.ge(-1)
label_mask = labels_ids.ge(0)
example_ids[~example_mask] = 0
labels_ids[~label_mask] = ignore_index
yield {
"input_ids": example_ids,
"labels": labels_ids,
"attention_mask": example_mask,
"audio": audio_raw if input_type == "raw" else None,
"audio_mel": audio_mel if input_type == "mel" else None,
"audio_length": audio_length,
"key": key,
"prompt_length": prompt_length,
}
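# --- Usage sketch (illustrative only) ---
# gen_llm_inputs turns {'key', 'txt', 'wav'} samples into LLM-ready tensors.
# The stub tokenizer below is hypothetical and only mimics the two members the
# function touches (encode and eos_token_id); in practice the Qwen tokenizer
# loaded with the model would be used.
def _demo_gen_llm_inputs():
    class _StubTokenizer:
        eos_token_id = 2
        def encode(self, text):
            return list(range(len(text.split())))
    samples = [{'key': 'utt1', 'txt': 'hello world',
                'wav': torch.randn(1, 16000), 'sample_rate': 16000}]
    item = next(gen_llm_inputs(samples, _StubTokenizer(), is_inference=True))
    # 16000 samples -> 50 fairseq frames -> 10 adapter frames, prepended as -1 placeholders
    assert item['audio_length'] == 10 and item['input_ids'][:10].eq(-1).all()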
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
""" Do spec augmentation
Inplace operation
Args:
data: Iterable[{key, feat, label}]
num_t_mask: number of time mask to apply
num_f_mask: number of freq mask to apply
max_t: max width of time mask
max_f: max width of freq mask
max_w: max width of time warp
        Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'feat' in sample
x = sample['feat']
assert isinstance(x, torch.Tensor)
y = x.clone().detach()
max_frames = y.size(0)
max_freq = y.size(1)
# time mask
for i in range(num_t_mask):
start = random.randint(0, max_frames - 1)
length = random.randint(1, max_t)
end = min(max_frames, start + length)
y[start:end, :] = 0
# freq mask
for i in range(num_f_mask):
start = random.randint(0, max_freq - 1)
length = random.randint(1, max_f)
end = min(max_freq, start + length)
y[:, start:end] = 0
sample['feat'] = y
yield sample
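# --- Usage sketch (illustrative only) ---
# spec_aug zeroes random time/frequency stripes on a copy of the feature
# matrix; the shape of 'feat' is unchanged. The 200x80 feature below is synthetic.
def _demo_spec_aug():
    feats = [{'key': 'utt1', 'feat': torch.randn(200, 80)}]
    out = next(spec_aug(feats, num_t_mask=1, num_f_mask=1))
    assert out['feat'].shape == (200, 80)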
def spec_sub(data, max_t=20, num_t_sub=3):
""" Do spec substitute
Inplace operation
Args:
data: Iterable[{key, feat, label}]
max_t: max width of time substitute
num_t_sub: number of time substitute to apply
        Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'feat' in sample
x = sample['feat']
assert isinstance(x, torch.Tensor)
y = x.clone().detach()
max_frames = y.size(0)
for i in range(num_t_sub):
start = random.randint(0, max_frames - 1)
length = random.randint(1, max_t)
end = min(max_frames, start + length)
            # substitute the current span with an earlier, randomly chosen span
pos = random.randint(0, start)
y[start:end, :] = x[start - pos:end - pos, :]
sample['feat'] = y
yield sample
def shuffle(data, shuffle_size=20000):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
    # Flush the samples left over in the buffer
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, key='feat'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x[key].size(0))
for x in buf:
yield x
buf = []
    # Flush the samples left over in the buffer
buf.sort(key=lambda x: x[key].size(0))
for x in buf:
yield x
def pad(sequence, max_length, padding_idx=0):
if isinstance(sequence, (int, list, tuple)):
if len(sequence) < max_length:
sequence = sequence + [padding_idx] * (max_length - len(sequence))
else:
sequence = sequence[:max_length]
elif isinstance(sequence, torch.Tensor):
if len(sequence) < max_length:
sequence = torch.cat(
(sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx)))
else:
sequence = sequence[:max_length]
elif isinstance(sequence, np.ndarray):
if len(sequence) < max_length:
sequence = np.concatenate(
(sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx)))
else:
sequence = sequence[:max_length]
else:
raise Exception("Type mismatch during padding!")
return sequence
def padding(sequence, padding_length, padding_idx=0, padding_side="right"):
if isinstance(sequence, (int, list, tuple)):
if padding_length >= 0:
sequence = sequence + [padding_idx] * padding_length
else:
sequence = sequence[:padding_length]
elif isinstance(sequence, torch.Tensor):
if sequence.ndimension() == 2:
if padding_length >= 0:
sequence = torch.nn.functional.pad(sequence, (0, padding_length))
else:
sequence = sequence[:, :padding_length]
else:
if padding_length >= 0:
if padding_side == "left":
sequence = torch.cat(
(torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx), sequence))
else:
sequence = torch.cat(
(sequence, torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx)))
else:
sequence = sequence[:padding_length]
elif isinstance(sequence, np.ndarray):
if padding_length >= 0:
sequence = np.concatenate(
(sequence, np.full((padding_length,) + sequence.shape[1:], padding_idx)))
else:
sequence = sequence[:padding_length]
else:
raise Exception("Type mismatch during padding!")
return sequence
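# --- Usage sketch (illustrative only) ---
# `padding` pads by a *delta* length (not a target length) and supports left
# padding, which process_batch below uses to left-align the audio+prompt part
# of every utterance in a batch.
def _demo_padding():
    seq = torch.tensor([1, 2, 3])
    left = padding(seq, 2, padding_idx=0, padding_side="left")
    right = padding(seq, 2, padding_idx=0)
    assert left.tolist() == [0, 0, 1, 2, 3] and right.tolist() == [1, 2, 3, 0, 0]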
def process_batch(samples, tokenizer):
input_prompt_lengths = [s["audio_length"] + s['prompt_length'] for s in samples] # [120, 48, 82, 42]
input_answer_lengths = [len(s["input_ids"]) - s["audio_length"] - s['prompt_length'] for s in
samples] # [0, 0, 0, 0]
input_prompt_max_length = max(input_prompt_lengths)
input_answer_max_length = max(input_answer_lengths)
input_ids = torch.stack([
padding(
padding(samples[index]["input_ids"], input_prompt_max_length - input_prompt_lengths[index],
tokenizer.pad_token_id, padding_side="left"),
input_answer_max_length - input_answer_lengths[index], tokenizer.pad_token_id
) for index in range(len(samples))
])
attention_mask = torch.stack([
padding(
padding(samples[index]["attention_mask"], input_prompt_max_length - input_prompt_lengths[index],
False, padding_side="left"),
input_answer_max_length - input_answer_lengths[index], False
) for index in range(len(samples))
])
audio_mel_reallen = [s['audio_mel'].shape[0] for s in samples]
audio_mel_max_length = max(audio_mel_reallen)
audio_mel = torch.stack([pad(s['audio_mel'], audio_mel_max_length, 0)
for s in samples])
audio_mel_post_mask = torch.zeros(len(samples), (
audio_mel_max_length)) # paraformer
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0])] = 1
audio_mel_reallen = torch.tensor(audio_mel_reallen, dtype=torch.int32)
modality_mask = torch.zeros_like(attention_mask)
for index in range(len(samples)):
padding_left = input_prompt_max_length - input_prompt_lengths[index]
modality_mask[index, padding_left:padding_left + samples[index]["audio_length"]] = True
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"audio_mel": audio_mel,
"audio_mel_post_mask": audio_mel_post_mask,
"modality_mask": modality_mask,
"audio_mel_reallen": audio_mel_reallen
}
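# --- Usage sketch (illustrative only) ---
# process_batch left-pads the audio+prompt region and right-pads the answer
# region so all tensors in the batch share one time axis. The stub tokenizer
# and the synthetic samples below are hypothetical; only pad_token_id is needed here.
def _demo_process_batch():
    class _StubTokenizer:
        pad_token_id = 0
    samples = [
        {"input_ids": torch.arange(10), "attention_mask": torch.ones(10, dtype=torch.bool),
         "audio_mel": torch.randn(30, 560), "audio_length": 6, "prompt_length": 4},
        {"input_ids": torch.arange(8), "attention_mask": torch.ones(8, dtype=torch.bool),
         "audio_mel": torch.randn(20, 560), "audio_length": 5, "prompt_length": 3},
    ]
    batch = process_batch(samples, _StubTokenizer())
    assert batch["input_ids"].shape == batch["attention_mask"].shape == batch["modality_mask"].shape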
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor.
If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
            See the example.
        maxlen (int, optional): Maximum length of the mask.
            Only usable when xs is None.
    Returns:
        Tensor: Mask tensor containing indices of padded part
            (dtype=torch.uint8 in PyTorch < 1.2, dtype=torch.bool in PyTorch >= 1.2).
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
if maxlen is None:
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = xs.size(length_dim)
else:
assert xs is None
assert maxlen >= int(max(lengths))
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert xs.size(0) == bs, (xs.size(0), bs)
if length_dim < 0:
length_dim = xs.dim() + length_dim
        # ind = (slice(None), None, ..., None, slice(None), None, ..., None)
ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
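# --- Usage sketch (illustrative only) ---
# True marks padded positions, matching the first example in the docstring above.
def _demo_make_pad_mask():
    mask = make_pad_mask([5, 3, 2])
    assert mask.tolist() == [[False, False, False, False, False],
                             [False, False, False, True, True],
                             [False, False, True, True, True]]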
class SinusoidalPositionEncoder(torch.nn.Module):
""" """
def __int__(self, d_model=80, dropout_rate=0.1):
pass
def encode(
self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
):
batch_size = positions.size(0)
positions = positions.type(dtype)
device = positions.device
log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
depth / 2 - 1
)
inv_timescales = torch.exp(
torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
)
inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
inv_timescales, [1, 1, -1]
)
encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
return encoding.type(dtype)
def forward(self, x):
batch_size, timesteps, input_dim = x.size()
positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
return x + position_encoding
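# --- Usage sketch (illustrative only) ---
# The encoder is stateless: it builds a sinusoidal table for the current
# sequence length and adds it element-wise, so the output shape equals the input shape.
def _demo_sinusoidal_position_encoder():
    pe = SinusoidalPositionEncoder()
    x = torch.randn(2, 7, 80)
    assert pe(x).shape == x.shape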
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
Args:
        idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
"""Construct an PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.w_2 = torch.nn.Linear(hidden_units, idim)
self.dropout = torch.nn.Dropout(dropout_rate)
self.activation = activation
def forward(self, x):
"""Forward function."""
return self.w_2(self.dropout(self.activation(self.w_1(x))))
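# --- Usage sketch (illustrative only) ---
# The feed-forward block projects each frame idim -> hidden_units -> idim,
# so it preserves the (batch, time, idim) shape.
def _demo_positionwise_feed_forward():
    ff = PositionwiseFeedForward(idim=512, hidden_units=2048, dropout_rate=0.1)
    x = torch.randn(2, 7, 512)
    assert ff(x).shape == x.shape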
class MultiHeadedAttentionSANM(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(
self,
n_head,
in_feat,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0,
lora_list=None,
lora_rank=8,
lora_alpha=16,
lora_dropout=0.1,
):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
self.fsmn_block = nn.Conv1d(
n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
)
# padding
left_padding = (kernel_size - 1) // 2
if sanm_shfit > 0:
left_padding = left_padding + sanm_shfit
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
b, t, d = inputs.size()
if mask is not None:
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
if mask is not None:
x = x * mask
return x
def forward_qkv(self, x):
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
b, t, d = x.size()
q_k_v = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
1, 2
) # (batch, head, time1, d_k)
k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
1, 2
) # (batch, head, time2, d_k)
v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
1, 2
) # (batch, head, time2, d_k)
return q_h, k_h, v_h, v
def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
if mask_att_chunk_encoder is not None:
mask = mask * mask_att_chunk_encoder
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = -float(
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q_h, k_h, v_h, v = self.forward_qkv(x)
fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs + fsmn_memory
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q_h, k_h, v_h, v = self.forward_qkv(x)
if chunk_size is not None and look_back > 0 or look_back == -1:
if cache is not None:
k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
k_h = torch.cat((cache["k"], k_h), dim=2)
v_h = torch.cat((cache["v"], v_h), dim=2)
cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
if look_back != -1:
cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
else:
cache_tmp = {
"k": k_h[:, :, : -(chunk_size[2]), :],
"v": v_h[:, :, : -(chunk_size[2]), :],
}
cache = cache_tmp
fsmn_memory = self.forward_fsmn(v, None)
q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, None)
return att_outs + fsmn_memory, cache
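# --- Usage sketch (illustrative only) ---
# The layer returns the attention output plus the FSMN memory branch; both have
# n_feat channels, so the output keeps the (batch, time, n_feat) shape. The
# sizes below are arbitrary.
def _demo_sanm_attention():
    attn = MultiHeadedAttentionSANM(n_head=4, in_feat=512, n_feat=512,
                                    dropout_rate=0.0, kernel_size=11)
    x = torch.randn(2, 20, 512)
    mask = torch.ones(2, 1, 20)
    assert attn(x, mask).shape == (2, 20, 512)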
class MultiSequential(torch.nn.Sequential):
"""Multi-input multi-output torch.nn.Sequential."""
def __init__(self, *args, layer_drop_rate=0.0):
"""Initialize MultiSequential with layer_drop.
Args:
layer_drop_rate (float): Probability of dropping out each fn (layer).
"""
super(MultiSequential, self).__init__(*args)
self.layer_drop_rate = layer_drop_rate
def forward(self, *args):
"""Repeat."""
_probs = torch.empty(len(self)).uniform_()
for idx, m in enumerate(self):
if not self.training or (_probs[idx] >= self.layer_drop_rate):
args = m(*args)
return args
class CheckpointFunction(torch.nn.Module):
"""Wrap a module function for gradient checkpointing."""
def __init__(self, module_fn):
super().__init__()
self.module_fn = module_fn
def forward(self, *args, **kwargs):
# Use checkpointing on the forward pass of the module function
return checkpoint(self.module_fn, *args, **kwargs)
def repeat(N, fn, layer_drop_rate=0.0, use_checkpoint=False):
"""Repeat module N times.
Args:
N (int): Number of repeat time.
fn (Callable): Function to generate module.
layer_drop_rate (float): Probability of dropping out each fn (layer).
Returns:
MultiSequential: Repeated model instance.
"""
modules = []
for n in range(N):
module_fn = fn(n)
if use_checkpoint:
# Wrap the module function with checkpointing
module_fn = CheckpointFunction(module_fn)
modules.append(module_fn)
return MultiSequential(*modules, layer_drop_rate=layer_drop_rate)
class LayerNorm(torch.nn.LayerNorm):
"""Layer normalization module.
Args:
nout (int): Output dim size.
dim (int): Dimension to be normalized.
"""
def __init__(self, nout, dim=-1):
"""Construct an LayerNorm object."""
super(LayerNorm, self).__init__(nout, eps=1e-12)
self.dim = dim
def forward(self, x):
"""Apply layer normalization.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Normalized tensor.
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
return super(LayerNorm, self).forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)
class EncoderLayerSANM(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayerSANM, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat(
(
x,
self.self_attn(
x,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
),
),
dim=-1,
)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(
x,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(
x,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(nn.Module):
def __init__(
self,
input_size: int = 560,
output_size: int = 512,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 50,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.1,
input_layer: Optional[str] = "pe",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size: int = 11,
sanm_shfit: int = 0,
lora_list: List[str] = None,
lora_rank: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.1,
selfattention_layer_type: str = "sanm",
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
gradient_checkpoint=False
):
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "sanm":
encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
lora_list,
lora_rank,
lora_alpha,
lora_dropout,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
lora_list,
lora_rank,
lora_alpha,
lora_dropout,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
use_checkpoint=gradient_checkpoint
)
self.encoders = repeat(
num_blocks - 1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
use_checkpoint=gradient_checkpoint
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad = xs_pad * self.output_size() ** 0.5
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
encoder_outs = self.encoders0(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
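# --- Usage sketch (illustrative only) ---
# A reduced configuration just to show the interface; the real model uses the
# defaults above (e.g. num_blocks=50). The encoder keeps the time axis and
# returns the valid lengths alongside the encoded features.
def _demo_sanm_encoder():
    enc = SANMEncoder(input_size=560, output_size=512, num_blocks=2,
                      attention_heads=4, linear_units=1024)
    feats = torch.randn(2, 30, 560)
    lens = torch.tensor([30, 25])
    out, olens, _ = enc(feats, lens)
    assert out.shape == (2, 30, 512) and olens.tolist() == [30, 25]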
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class LinearAdapter(nn.Module):
def __init__(self, config):
super().__init__()
self.k = config.adapter_downsample_rate
self.encoder_dim = config.encoder_dim
self.llm_dim = config.llm_dim
self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(2048, config.llm_dim)
def forward(self, x, gradient_checkpoint=False):
batch_size, seq_len, dim = x.size()
num_frames_to_discard = seq_len % self.k
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, seq_len // self.k, dim * self.k)
if gradient_checkpoint:
x = checkpoint(self.linear1, x)
else:
x = self.linear1(x)
x = self.relu(x)
if gradient_checkpoint:
x = checkpoint(self.linear2, x)
else:
x = self.linear2(x)
return x
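# --- Usage sketch (illustrative only) ---
# The adapter concatenates k adjacent encoder frames and projects them to the
# LLM embedding size; leftover frames that do not fill a group of k are dropped.
# The config class below is a hypothetical stand-in with the three fields the adapter reads.
def _demo_linear_adapter():
    class _Cfg:
        adapter_downsample_rate = 5
        encoder_dim = 512
        llm_dim = 3584
    adapter = LinearAdapter(_Cfg())
    enc_out = torch.randn(2, 103, 512)   # 103 frames: 3 leftover frames are dropped
    assert adapter(enc_out).shape == (2, 20, 3584)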
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
class WhisperWrappedEncoder:
@classmethod
def load(cls, model_config):
def extract_variable_length_features(self, x: torch.Tensor, gradient_checkpoint=False):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
# x = (x + self.positional_embedding).to(x.dtype)
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
for block in self.blocks:
if gradient_checkpoint:
x = checkpoint(block, x)
else:
x = block(x)
if gradient_checkpoint:
x = checkpoint(self.ln_post, x)
else:
x = self.ln_post(x)
return x
import whisper
encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
return encoder
class HubertEncoder:
@classmethod
def load(cls, model_config):
import fairseq
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
model = models[0]
if model_config.encoder_type == "pretrain":
pass
elif model_config.encoder_type == "finetune":
model.w2v_encoder.proj = None
model.w2v_encoder.apply_mask = False
else:
            assert model_config.encoder_type in ["pretrain", "finetune"], "encoder_type must be one of [pretrain, finetune]"
return model
class W2vBert2Encoder(nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
@classmethod
def load(cls, model_config):
from transformers import Wav2Vec2BertModel
model = Wav2Vec2BertModel.from_pretrained(model_config.encoder_path)
return cls(model_config, model)
def extract_features(self, source, attention_mask):
output = self.model(source, attention_mask=attention_mask)
return output.last_hidden_state
class HfTextEncoder:
@classmethod
def load(cls, model_config):
from transformers import AutoModel
model = AutoModel.from_pretrained(model_config.encoder_path)
return model
class ParaformerEncoder(nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
@classmethod
def load(cls, model_config):
from .Paraformer.encoder import SANMEncoder
model = SANMEncoder(gradient_checkpoint=model_config.get('gradient_checkpoint', False))
ckpt_dict = torch.load(model_config.encoder_path, map_location="cpu")
model.load_state_dict(ckpt_dict, strict=False)
return cls(model_config, model)
def extract_features(self, source, reallen):
output, _, _ = self.model(
xs_pad=source,
ilens=reallen
)
# TODO: support streaming @zhenlin.liang
return output
import json
import os
import torch
import torchaudio
from .hifigan import Generator
from .speaker_encoder import ResNetSpeakerEncoder
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def get_hifigan_model(vocoder_path, device,
vocoder_config=None, decoder_dim=1024):
if vocoder_config is None:
vocoder_config = os.path.join(os.path.dirname(__file__), 'hifigan_config.json')
print(f"Loading vocoder config from {vocoder_config}")
with open(vocoder_config, "r") as f:
config = json.load(f)
config = AttrDict(config)
config.input_num_mels = decoder_dim
vocoder = Generator(config)
print(f"Loading vocoder from {vocoder_path}")
ckpt = torch.load(vocoder_path, map_location=device)
vocoder.load_state_dict(ckpt["generator"], strict=True)
vocoder.eval()
vocoder.remove_weight_norm()
vocoder.to(device)
return vocoder
def get_speaker_encoder(model_checkpoint, device):
model = ResNetSpeakerEncoder(
input_dim=64,
proj_dim=512,
log_input=True,
use_torch_spec=True,
audio_config={
"fft_size": 512,
"win_length": 400,
"hop_length": 160,
"sample_rate": 16000,
"preemphasis": 0.97,
"num_mels": 64
},
)
print(f"Loading speaker encoder from {model_checkpoint}")
checkpoint = torch.load(model_checkpoint, map_location=device)["speaker_encoder"]
model.load_state_dict(checkpoint, strict=True)
model.eval()
model.to(device)
return model
def encode_prompt_wav(spk_encoder, prompt_wav, device):
wav, sr = torchaudio.load(prompt_wav)
wav_16k = torchaudio.functional.resample(wav, sr, 16000)
vocoder_latent = spk_encoder.forward(
wav_16k.to(device), l2_norm=True
).unsqueeze(-1)
return vocoder_latent
def save_wav(hifigan_generator, embedding, latent, save_path):
samples = hifigan_generator(latent.permute(0, 2, 1), g=embedding).squeeze(0)
samples = samples.detach().cpu()
torchaudio.save(save_path, samples, 22050)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
LRELU_SLOPE = 0.1
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(
Conv1d(h.input_num_mels if hasattr(h, "input_num_mels") and h.input_num_mels else h.num_mels,
h.upsample_initial_channel, 7, 1, padding=(7 - 1) // 2))
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
k, u, padding=(k - u) // 2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=(7 - 1) // 2))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
if h.cond_channels > 0:
self.cond_layer = nn.Conv1d(h.cond_channels, h.upsample_initial_channel, 1)
if h.cond_in_each_up_layer:
self.conds = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
self.conds.append(nn.Conv1d(h.cond_channels, ch, 1))
def forward(self, x, g=None):
x = torch.nn.functional.interpolate(
x,
scale_factor=[self.h.input_up_scale_factor],
mode="linear",
).squeeze(1)
x = self.conv_pre(x)
if self.h.cond_channels > 0:
x = x + self.cond_layer(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
if self.h.cond_in_each_up_layer:
x = x + self.conds[i](g)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
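# --- Usage sketch (illustrative only) ---
# A scaled-down, hypothetical config mirroring hifigan_config.json below: the
# latent is first upsampled 4x by interpolation and then 8*8*2*2 = 256x by the
# transposed convolutions, i.e. 1024 output samples per input frame. The
# speaker embedding g conditions conv_pre and every upsampling stage.
def _demo_generator():
    class _H:
        resblock = "1"
        upsample_rates = [8, 8, 2, 2]
        upsample_kernel_sizes = [16, 16, 4, 4]
        upsample_initial_channel = 128
        resblock_kernel_sizes = [3, 7, 11]
        resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
        num_mels = 80
        input_num_mels = 64
        cond_channels = 32
        cond_in_each_up_layer = True
        input_up_scale_factor = 4
    gen = Generator(_H())
    latent = torch.randn(1, 64, 10)   # (batch, input_num_mels, frames)
    spk = torch.randn(1, 32, 1)       # speaker embedding as 1-frame conditioning
    assert gen(latent, g=spk).shape == (1, 1, 10 * 1024)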
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 1024,
"learning_rate": 0.0002,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.99,
"seed": 1234,
"upsample_rates": [8,8,2,2],
"upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"discriminator_periods": null,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,
"num_workers": 8,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
},
"input_num_mels": 3584,
"input_n_fft": null,
"input_hop_size": null,
"input_win_size": null,
"input_sampling_rate": null,
"cond_channels": 512,
"cond_in_each_up_layer": true,
"input_up_scale_factor": 4
}
import torch
import torch.nn as nn
import torchaudio
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel),
nn.Sigmoid(),
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
class SEBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
super(SEBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.se = SELayer(planes, reduction)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.relu(out)
out = self.bn1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
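# --- Usage sketch (illustrative only) ---
# Pre-emphasis implements y[t] = x[t] - coefficient * x[t-1] as a 1-D
# convolution with reflect padding, so the waveform length is preserved.
def _demo_preemphasis():
    pre = PreEmphasis(0.97)
    wav = torch.randn(2, 16000)
    assert pre(wav).shape == wav.shape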
class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies."""
# pylint: disable=W0102
def __init__(
self,
input_dim=64,
proj_dim=512,
layers=[3, 4, 6, 3],
num_filters=[32, 64, 128, 256],
encoder_type="ASP",
log_input=False,
use_torch_spec=False,
audio_config=None,
):
super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type
self.input_dim = input_dim
self.log_input = log_input
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
self.proj_dim = proj_dim
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.bn1 = nn.BatchNorm2d(num_filters[0])
self.inplanes = num_filters[0]
self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
else:
self.torch_spec = None
outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential(
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
nn.ReLU(),
nn.BatchNorm1d(128),
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
nn.Softmax(dim=2),
)
if self.encoder_type == "SAP":
out_dim = num_filters[3] * outmap_size
elif self.encoder_type == "ASP":
out_dim = num_filters[3] * outmap_size * 2
else:
raise ValueError("Undefined encoder")
self.fc = nn.Linear(out_dim, proj_dim)
self._init_layers()
def _init_layers(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def create_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
# pylint: disable=R0201
def new_parameter(self, *size):
out = nn.Parameter(torch.FloatTensor(*size))
nn.init.xavier_normal_(out)
return out
def forward(self, x, l2_norm=False):
"""Forward pass of the model.
Args:
x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
to compute the spectrogram on-the-fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
"""
x.squeeze_(1)
        # if use_torch_spec is set, compute the mel spectrogram here; otherwise use the mel spectrogram computed by the audio processor (AP)
if self.use_torch_spec:
x = self.torch_spec(x)
if self.log_input:
x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1)
x = self.conv1(x)
x = self.relu(x)
x = self.bn1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.reshape(x.size()[0], -1, x.size()[-1])
w = self.attention(x)
if self.encoder_type == "SAP":
x = torch.sum(x * w, dim=2)
elif self.encoder_type == "ASP":
mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1)
x = self.fc(x)
if l2_norm:
x = torch.nn.functional.normalize(x, p=2, dim=1)
return x
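# --- Usage sketch (illustrative only) ---
# With use_torch_spec=True the encoder accepts a raw waveform and computes the
# mel spectrogram internally; the audio_config below mirrors the one used in
# get_speaker_encoder. With l2_norm=True the returned embedding has unit norm.
def _demo_resnet_speaker_encoder():
    enc = ResNetSpeakerEncoder(
        input_dim=64, proj_dim=512, log_input=True, use_torch_spec=True,
        audio_config={"fft_size": 512, "win_length": 400, "hop_length": 160,
                      "sample_rate": 16000, "preemphasis": 0.97, "num_mels": 64})
    wav = torch.randn(1, 16000)   # one second of 16 kHz audio
    emb = enc(wav, l2_norm=True)
    assert emb.shape == (1, 512) and torch.allclose(emb.norm(dim=1), torch.ones(1))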