"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "5cc557336ccb76541198f636d546f1338b46ab05"
Commit b97afd54 authored by wangwei990215's avatar wangwei990215
Browse files

Initial commit

parents
Pipeline #1825 failed with stages
in 0 seconds
import time
import torch
try:
import torch_musa
except ImportError as e:
print("You should install torch_musa if you want to run on Moore Threads GPU")
import os
import argparse
import torchaudio
from torchaudio.transforms import Resample
import logging
from mooer.datasets.speech_processor import *
from mooer.configs import asr_config
from mooer.models import mooer_model
from mooer.utils.utils import *
from mooer.models.hifigan import save_wav, get_hifigan_model, get_speaker_encoder, encode_prompt_wav
parser = argparse.ArgumentParser()
parser.add_argument("--wav_path", default='demo/resources/demo.wav', type=str, help="decode one wav file")
parser.add_argument("--wav_scp", default=None, type=str, help="decode scp if you want")
parser.add_argument("--task", default='s2s_chat', choices=['asr', 'ast', 's2s_trans', 's2s_chat'],
type=str, help="task: asr or ast or s2s_trans or s2s_chat. "
"Please set ast if you choose a asr/ast/s2s_trans/s2s_chat multitask model")
parser.add_argument("--batch_size", default=1, type=int, help="decode batch for scp")
parser.add_argument("--cmvn_path", default='', type=str, help="cmvn path.")
parser.add_argument("--encoder_path", default='', type=str, help="encoder path.")
parser.add_argument("--llm_path", default='', type=str, help="llm path.")
parser.add_argument("--adapter_path", default='', type=str, help="adapter path.")
parser.add_argument("--lora_dir", default='', type=str, help="lora path.")
parser.add_argument("--vocoder_path", default='', type=str, help="vocoder path")
parser.add_argument("--spk_encoder_path", default='', type=str, help="spk encoder path")
parser.add_argument("--prompt_wav_path", default='', type=str, help="prompt wav path")
parser.add_argument("--output_dir", default="response_wavs_dir", type=str, help="path to save wav generated")
args = parser.parse_args()
assert args.batch_size == 1, "Only batch_size=1 is supported for the S2ST task for now. Batch inference will be supported soon."
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
filemode='w'
)
PROMPT_TEMPLATE_DICT = {
'qwen': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
}
PROMPT_DICT = {
'asr': "Transcribe speech to text. ",
'ast': "Translate speech to english text. ",
's2s_trans': "Translate speech to english speech. ",
's2s_chat': "Answer my question with speech. "
}
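# For reference: the final LLM prompt is the task prompt inserted into the chat template,
# exactly as prompt_template.format(prompt_org) does in process_wav() below. For example:
#   PROMPT_TEMPLATE_DICT['qwen'].format(PROMPT_DICT['s2s_chat'])
#   -> "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
#      "Answer my question with speech. <|im_end|>\n<|im_start|>assistant\n"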
model_config = asr_config.ModelConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# replace path
if args.llm_path and os.path.exists(args.llm_path):
model_config.llm_path = args.llm_path
if args.encoder_path and os.path.exists(args.encoder_path):
model_config.encoder_path = args.encoder_path
if args.adapter_path and os.path.exists(args.adapter_path):
model_config.adapter_path = args.adapter_path
if args.lora_dir and os.path.exists(args.lora_dir):
model_config.lora_dir = args.lora_dir
if args.cmvn_path and os.path.exists(args.cmvn_path):
model_config.cmvn_path = args.cmvn_path
if args.task:
model_config.prompt_key = args.task
device = str(get_device())
logger.info("This demo will run on {}".format(device.upper()))
logger.info(model_config)
os.makedirs(args.output_dir, exist_ok=True)
logger.info("Response wav will save in {}".format(args.output_dir))
model, tokenizer = mooer_model.init_model(
model_config=model_config)
AUDIO_START_TOKEN_INDEX = tokenizer.get_vocab()['<|audio_start|>']
model.to(device)
model.eval()
# data process
prompt_template_key = model_config.get('prompt_template_key', 'qwen')
prompt_template = PROMPT_TEMPLATE_DICT[prompt_template_key]
prompt_key = model_config.get('prompt_key', 'asr')
prompt_org = PROMPT_DICT[prompt_key]
logger.info(f"Use LLM Type {prompt_template_key}, "
f"Prompt template {prompt_template}, "
f"Use task type {prompt_key}, "
f"Prompt {prompt_org}")
cmvn = load_cmvn(model_config.get('cmvn_path'))
adapter_downsample_rate = model_config.get('adapter_downsample_rate')
hifigan_generator = get_hifigan_model(args.vocoder_path, device, decoder_dim=3584)
spk_encoder = get_speaker_encoder(args.spk_encoder_path, device)
spk_embedding = encode_prompt_wav(spk_encoder, args.prompt_wav_path, device)
def process_wav(wav_path):
audio_raw, sample_rate = torchaudio.load(wav_path)
if sample_rate != 16000:
# resample the data
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
audio_raw = resampler(audio_raw)
if audio_raw.shape[0] > 1:
# convert to mono
audio_raw = audio_raw.mean(dim=0, keepdim=True)
audio_raw = audio_raw[0]
prompt = prompt_template.format(prompt_org)
audio_mel = compute_fbank(waveform=audio_raw)
audio_mel = apply_lfr(inputs=audio_mel, lfr_m=7, lfr_n=6)
audio_mel = apply_cmvn(audio_mel, cmvn=cmvn)
audio_length = audio_mel.shape[0]
audio_length = audio_length // adapter_downsample_rate
audio_pseudo = torch.full((audio_length,), -1)
prompt_ids = tokenizer.encode(prompt)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio, prompt]
example_mask = example_ids.ge(-1)
items = {
"input_ids": example_ids,
"attention_mask": example_mask,
"audio_mel": audio_mel,
"audio_length": audio_length,
"prompt_length": prompt_length,
}
return items
load_dtype = model_config.get('load_dtype', 'bfloat16')
dtype = torch.float32
if load_dtype == 'float16':
dtype = torch.float16
elif load_dtype == 'bfloat16':
dtype = torch.bfloat16
logging.info(f"Input data type: {dtype}")
context_scope = torch.musa.amp.autocast if 'musa' in device else torch.cuda.amp.autocast
with torch.no_grad():
if args.wav_scp is not None and os.path.exists(args.wav_scp):
batch_size = args.batch_size
infer_time = []
items = parse_key_text(args.wav_scp)
uttids = list(items.keys())
num_batches = len(uttids) // batch_size + (0 if len(uttids) % batch_size == 0 else 1)
for i in range(num_batches):
try:
batch_uttids = uttids[i * batch_size:(i + 1) * batch_size]
batch_wav_paths = [items[uttid] for uttid in batch_uttids]
samples = []
for wav_path in batch_wav_paths:
samples.append(process_wav(wav_path))
batch = process_batch(samples, tokenizer=tokenizer)
for key in batch.keys():
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
with context_scope(dtype=dtype):
ss = time.perf_counter()
inputs_embeds, attention_mask, kwargs = model.generate(**batch, compute_llm=False)
prompt_and_encoding_length = inputs_embeds.shape[1]
model_outputs = model.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 1000),
num_beams=kwargs.get("num_beams", 4),
do_sample=True,
min_length=kwargs.get("min_length", 1),
top_p=0.85,
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
bos_token_id=model.tokenizer.bos_token_id,
eos_token_id=model.tokenizer.eos_token_id,
pad_token_id=model.tokenizer.pad_token_id,
)
infer_time.append(time.perf_counter() - ss)
logging.info(f"Infer time: {time.perf_counter() - ss}")
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
skip_special_tokens=True)
if hasattr(model.llm.model, "embed_tokens"):
teacher_forcing_input_embeds = model.llm.model.embed_tokens(model_outputs)
teacher_forcing_input_att_mask = torch.ones((1, teacher_forcing_input_embeds.shape[1]),
dtype=torch.bool).to(device)
else:
raise NotImplementedError
inputs_embeds = torch.concat([inputs_embeds, teacher_forcing_input_embeds], dim=-2)
attention_mask = torch.concat([attention_mask, teacher_forcing_input_att_mask], dim=-1)
llm_output = model.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
output_hidden_states=True)
audio_start_index = prompt_and_encoding_length + model_outputs[0].tolist().index(AUDIO_START_TOKEN_INDEX)
audio_latents = llm_output.hidden_states[-1][:, audio_start_index:-6, :]
for idx, text in enumerate(output_text):
logger.info(f"uttid: {batch_uttids[idx]}")
audio_file_out_tts = os.path.join(args.output_dir, f"{batch_uttids[idx]}.tts.wav")
text_ast = text.split("<|audio_start|>")[0]
text_ast = text_ast.replace('\\n', '\n')
logger.info(f"AST: {text_ast}")
save_wav(hifigan_generator, spk_embedding, audio_latents.float(), audio_file_out_tts)
logger.info(f"Finished writing: {audio_file_out_tts}")
except Exception as e:
logging.error(e)
logging.info("Total inference cost")
logging.info(sum(infer_time))
elif args.wav_path != '' and os.path.exists(args.wav_path):
try:
wav_path = args.wav_path
items = process_wav(wav_path)
batch = process_batch([items], tokenizer=tokenizer)
for key in batch.keys():
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
with context_scope(dtype=dtype):
ss = time.perf_counter()
inputs_embeds, attention_mask, kwargs = model.generate(**batch, compute_llm=False)
prompt_and_encoding_length = inputs_embeds.shape[1]
model_outputs = model.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 1000),
num_beams=kwargs.get("num_beams", 4),
do_sample=True,
min_length=kwargs.get("min_length", 1),
top_p=0.85,
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
bos_token_id=model.tokenizer.bos_token_id,
eos_token_id=model.tokenizer.eos_token_id,
pad_token_id=model.tokenizer.pad_token_id,
)
logging.info(f"Infer time: {time.perf_counter() - ss}")
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
skip_special_tokens=True)
if hasattr(model.llm.model, "embed_tokens"):
teacher_forcing_input_embeds = model.llm.model.embed_tokens(model_outputs)
teacher_forcing_input_att_mask = torch.ones((1, teacher_forcing_input_embeds.shape[1]),
dtype=torch.bool).to(device)
else:
raise NotImplementedError
inputs_embeds = torch.concat([inputs_embeds, teacher_forcing_input_embeds], dim=-2)
attention_mask = torch.concat([attention_mask, teacher_forcing_input_att_mask], dim=-1)
llm_output = model.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
output_hidden_states=True)
audio_start_index = prompt_and_encoding_length + model_outputs[0].tolist().index(AUDIO_START_TOKEN_INDEX)
audio_latents = llm_output.hidden_states[-1][:, audio_start_index:-6, :]
for text in output_text:
uttid = os.path.basename(wav_path).replace(".wav", "")
audio_file_out_tts = os.path.join(args.output_dir, f"{uttid}.tts.wav")
text_ast = text.split("<|audio_start|>")[0]
text_ast = text_ast.replace('\\n', '\n')
logger.info(f"Text: {text_ast}")
save_wav(hifigan_generator, spk_embedding, audio_latents.float(), audio_file_out_tts)
logger.info(f"Finished writing: {audio_file_out_tts}")
except Exception as e:
logging.error(e)
else:
raise IOError("You should specify --wav_scp or --wav_path as the input")
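# Note on --wav_scp: parse_key_text() (imported from mooer.utils.utils) is used above as a
# dict mapping uttid -> wav path, so the scp file is presumably the usual Kaldi-style layout,
# one "uttid path" pair per line, e.g.:
#   utt001 /path/to/utt001.wav
#   utt002 /path/to/utt002.wav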
# Model code
modelCode=1067
# Model name
modelName=mooer_pytorch
# Model description
modelDescription=A speech recognition and speech translation system developed by Moore Threads, based on large language models (LLMs).
# Application scenarios (separate multiple tags with half-width commas)
appScenario=inference,speech recognition,speech translation,education,healthcare
# Framework type (separate multiple tags with half-width commas)
frameType=PyTorch
<Nnet>
<Splice> 560 560
[ 0 ]
<AddShift> 560 560
<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 
-13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
<Rescale> 560 560
<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
</Nnet>
from dataclasses import dataclass
from typing import Optional
@dataclass
class ModelConfig:
def __init__(self):
self.llm_name: str = "qwen2_7b_chat"
# You should set your own path
self.llm_path: str = "pretrained_models/Qwen2-7B-Instruct"
self.encoder_path: str = "pretrained_models/paraformer_encoder/paraformer-encoder.pth"
self.adapter_path: str = "pretrained_models/asr_ast_mtl/adapter_project.pt"
self.lora_dir: str = "pretrained_models/asr_ast_mtl/lora_weights"
self.cmvn_path: str = "pretrained_models/paraformer_encoder/am.mvn"
self.prompt_key: str = 'ast' # or asr for ASR model
###############################
self.llm_type: str = "decoder_only"
self.llm_dim: int = 3584
self.load_dtype: str = "bfloat16"
self.encoder_name: str = 'paraformer'
self.encoder_dim: int = 512
self.adapter: str = "linear"
self.adapter_downsample_rate: int = 2
self.modal: str = "audio"
self.normalize: Optional[bool] = False
self.gradient_checkpoint: bool = False
self.is_inference: bool = True
self.prompt_template_key: str = 'qwen'
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
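These config classes expose both attribute access and mapping-style access (via __getitem__ and get), which the inference code above relies on. A minimal usage sketch:

cfg = ModelConfig()
cfg.llm_dim                             # plain attribute access -> 3584
cfg['encoder_name']                     # dict-style access via __getitem__ -> 'paraformer'
cfg.get('prompt_template_key', 'qwen')  # .get() with a default, as used throughout the code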
from dataclasses import dataclass
from typing import Optional, List
@dataclass
class ModelConfig:
def __init__(self):
self.llm_name: str = "qwen2_7b_chat"
# You should set your own path
self.llm_path: str = "pretrained_models/Qwen2-7B-Instruct"
self.encoder_path: str = "pretrained_models/paraformer_encoder/paraformer-encoder.pth"
self.adapter_path: str = "pretrained_models/asr_ast_mtl/adapter_project.pt"
self.lora_dir: str = "pretrained_models/asr_ast_mtl/lora_weights"
self.cmvn_path: str = "/root/MooER/src/mooer/configs/am.mvn"
self.prompt_key: str = 'asr' # asr, ast... you can add tasks in src/mooer/utils/data_utils.py
###############################
self.llm_type: str = "decoder_only"
self.llm_dim: int = 3584
self.load_dtype: str = "bfloat16"
self.encoder_name: str = 'paraformer'
self.encoder_dim: int = 512
self.adapter: str = "linear"
self.adapter_downsample_rate: int = 2
self.modal: str = "audio"
self.normalize: Optional[bool] = False
self.gradient_checkpoint: bool = False
self.is_inference: bool = True
self.prompt_template_key: str = 'qwen'
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class PeftConfig:
def __init__(self):
self.peft_method: str = "lora" # None , llama_adapter, prefix
self.r: int = 64
self.lora_alpha: int = 16
self.target_modules: List = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
]
self.bias: str = "none"
self.task_type: str = "CAUSAL_LM"
self.lora_dropout: float = 0.05
self.inference_mode: bool = False
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class TrainConfig:
def __init__(self):
self.model_name: str = "asr"
self.enable_deepspeed: bool = True
        self.batch_size_training: int = 8  # keep this consistent with train_micro_batch_size_per_gpu in the deepspeed config
self.batching_strategy: str = 'custom'
self.context_length: int = 4096
self.num_epochs: int = 10
self.num_workers_dataloader: int = 4
# please set it in deepspeed config
# self.warmup_steps: int = 1000
# self.total_steps: int = 1000000
# self.lr: float = 1e-4
# self.weight_decay: float = 0.0
self.save_interval: int = 20000
self.save_merge_rank: bool = True
        # will merge the deepspeed model saved from several ranks
self.log_interval: int = 100
self.resume_step: int = 0
self.resume_epoch: int = 0
self.gamma: float = 0.85
self.seed: int = 42
self.use_fp16: bool = False
self.use_bf16: bool = True
self.mixed_precision: bool = True
self.val_batch_size: int = 10
self.use_peft: bool = True
self.output_dir: str = "output/save_models"
self.freeze_llm: bool = True
self.freeze_encoder: bool = True
self.freeze_projector: bool = True
self.find_unused_parameters: bool = False
self.gradient_checkpoint: bool = False
self.deepspeed_config: str = '/root/MooER/src/mooer/configs/deepspeed_config_zero2.json'
        # if you want a large batch size or to reduce memory, use ZeRO-3, but it will be slower
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class DataConfig:
def __init__(self):
self.train_data_path: Optional[str] = ''
self.val_data_path: Optional[str] = ''
self.test_data_dir: str = '/Your/testsets/root'
self.test_sets: str = 'test-clean/test-other/aishell'
        # you can put a series of test sets under test_data_dir for testing; separate them with /
self.decode_path: Optional[str] = ''
self.fix_length_audio: int = -1
self.max_length: int = 2000
self.min_length: int = 20
self.mel_size: int = 80
self.train_data_type: str = 'shard'
self.test_data_type: str = 'shard'
self.prompt_template_key: str = 'qwen'
self.prompt_key: str = 'asr'
self.w2v_bert_path: str = ''
self.sort: bool = False
self.replace_text_path: str = ''
self.replace_type: str = 'replace'
        # you can use replace_text_path & replace_type to train another task, e.g., AST, with the same uttid but a different label
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
def update(model_config, train_config, data_config):
train_config.is_inference = model_config.is_inference
data_config.is_inference = model_config.is_inference
data_config.num_epochs = train_config.num_epochs
data_config.adapter_downsample_rate = model_config.adapter_downsample_rate
data_config.cmvn_path = model_config.cmvn_path
data_config.encoder_name = model_config.encoder_name
data_config.normalize = model_config.normalize
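A short, illustrative sketch of how these config objects fit together; the actual training entry point is not part of this commit, so the wiring below is assumed:

model_config = ModelConfig()
train_config = TrainConfig()
data_config = DataConfig()
# propagate shared fields (is_inference, num_epochs, cmvn_path, encoder_name, ...) onto data_config
update(model_config, train_config, data_config)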
from dataclasses import dataclass
from typing import Optional, List
@dataclass
class ModelConfig:
def __init__(self):
self.llm_name: str = "qwen2_7b_chat"
# You should set your own path
self.llm_path: str = "pretrained_models/Qwen2-7B-Instruct"
self.encoder_path: str = "pretrained_models/paraformer_encoder/paraformer-encoder.pth"
self.adapter_path: Optional[str] = ''
self.lora_dir: Optional[str] = ''
self.cmvn_path: str = "/root/MooER/src/mooer/configs/am.mvn"
self.prompt_key: str = 'asr' # asr, ast... you can add tasks in src/mooer/utils/data_utils.py
###############################
self.llm_type: str = "decoder_only"
self.llm_dim: int = 3584
self.load_dtype: str = "bfloat16"
self.encoder_name: str = 'paraformer'
self.encoder_dim: int = 512
self.adapter: str = "linear"
self.adapter_downsample_rate: int = 2
self.modal: str = "audio"
self.normalize: Optional[bool] = False
self.gradient_checkpoint: bool = False
self.is_inference: bool = False
self.prompt_template_key: str = 'qwen'
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class PeftConfig:
def __init__(self):
self.peft_method: str = "lora" # None , llama_adapter, prefix
self.r: int = 64
self.lora_alpha: int = 16
self.target_modules: List = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"up_proj",
"gate_proj",
"down_proj",
]
self.bias: str = "none"
self.task_type: str = "CAUSAL_LM"
self.lora_dropout: float = 0.05
self.inference_mode: bool = False
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class TrainConfig:
def __init__(self):
self.model_name: str = "asr"
self.enable_deepspeed: bool = True
        self.batch_size_training: int = 8  # keep this consistent with train_micro_batch_size_per_gpu in the deepspeed config
self.batching_strategy: str = 'custom'
self.context_length: int = 4096
self.num_epochs: int = 10
self.num_workers_dataloader: int = 4
# please set it in deepspeed config
# self.warmup_steps: int = 1000
# self.total_steps: int = 1000000
# self.lr: float = 1e-4
# self.weight_decay: float = 0.0
self.save_interval: int = 20000
self.save_merge_rank: bool = True
        # will merge the deepspeed model saved from several ranks
self.log_interval: int = 100
self.resume_step: int = 0
self.resume_epoch: int = 0
self.gamma: float = 0.85
self.seed: int = 42
self.use_fp16: bool = False
self.use_bf16: bool = True
self.mixed_precision: bool = True
self.val_batch_size: int = 1
self.use_peft: bool = True
self.output_dir: str = "output/save_models"
self.freeze_llm: bool = True
self.freeze_encoder: bool = True
self.freeze_projector: bool = False
self.find_unused_parameters: bool = False
self.gradient_checkpoint: bool = False
self.deepspeed_config: str = '/root/MooER/src/mooer/configs/deepspeed_config_zero2.json'
        # if you want a large batch size or to reduce memory, use ZeRO-3, but it will be slower
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
@dataclass
class DataConfig:
def __init__(self):
self.train_data_path: str = '/YOUR/training/data.0.list'
self.val_data_path: Optional[str] = ''
self.test_data_dir: Optional[str] = ''
self.test_sets: Optional[str] = ''
self.decode_path: Optional[str] = ''
self.fix_length_audio: int = -1
self.max_length: int = 2000
self.min_length: int = 20
self.mel_size: int = 80
self.train_data_type: str = 'shard'
self.test_data_type: str = 'shard'
self.prompt_template_key: str = 'qwen'
self.prompt_key: str = 'asr'
self.w2v_bert_path: str = ''
self.num_epochs: int = 10
self.sort: bool = False
self.replace_text_path: str = ''
self.replace_type: str = 'replace'
        # you can use replace_text_path & replace_type to train another task, e.g., AST, with the same uttid but a different label
def __getitem__(self, key):
return getattr(self, key)
def get(self, attribute_name, default_value=None):
return getattr(self, attribute_name, default_value)
def update(model_config, train_config, data_config):
train_config.is_inference = model_config.is_inference
data_config.is_inference = model_config.is_inference
data_config.num_epochs = train_config.num_epochs
data_config.adapter_downsample_rate = model_config.adapter_downsample_rate
data_config.cmvn_path = model_config.cmvn_path
data_config.encoder_name = model_config.encoder_name
data_config.normalize = model_config.normalize
{
"train_micro_batch_size_per_gpu": 8,
"gradient_accumulation_steps": 2,
"steps_per_print": 100,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"total_num_steps": 1000000,
"warmup_max_lr": 0.0001,
"warmup_num_steps": 1000
}
},
"fp16": {
"enabled": false,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true,
"contiguous_memory_optimization": true,
"synchronize_checkpoint_boundary": false,
"profile": false
}
}
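TrainConfig.deepspeed_config points at a JSON file like the one above. A minimal sketch of handing it to DeepSpeed (illustrative only; the project's real launcher is not shown in this commit, and the placeholder module below stands in for the actual MooER model):

import deepspeed
import torch.nn as nn

model = nn.Linear(8, 8)  # placeholder; the real code would build the encoder/adapter/LLM stack here
# engine wraps the model together with the Adam optimizer and WarmupDecayLR scheduler declared in the JSON
engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config='deepspeed_config_zero2.json',   # the file shown above
)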
{
"train_micro_batch_size_per_gpu": 16,
"gradient_accumulation_steps": 2,
"steps_per_print": 100,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"total_num_steps": 10000000,
"warmup_max_lr": 0.0001,
"warmup_num_steps": 1000
}
},
"fp16": {
"enabled": false,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true,
"contiguous_memory_optimization": true,
"synchronize_checkpoint_boundary": false,
"profile": false
}
}
import logging
import os
import random
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from transformers import AutoFeatureExtractor
import mooer.datasets.speech_processor as processor
from mooer.utils.data_utils import PROMPT_DICT, PROMPT_TEMPLATE_DICT
def read_lists(list_file, num_epochs=1, shuffle=False):
lists = []
with open(list_file, 'r', encoding='utf8') as fin:
for line in fin:
lists.append(line.strip())
lists = lists * num_epochs
if shuffle:
random.shuffle(lists)
return lists
def load_cmvn(cmvn_file):
with open(cmvn_file, "r", encoding="utf-8") as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == "<AddShift>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
add_shift_line = line_item[3 : (len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == "<Rescale>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
rescale_line = line_item[3 : (len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float32)
vars = np.array(vars_list).astype(np.float32)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
return cmvn
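# Usage sketch: load_cmvn() parses the Kaldi-nnet style am.mvn shown earlier in this commit.
# The <AddShift> row is added to the features and the <Rescale> row multiplies them
# (negated means and inverse standard deviations), so the returned tensor has shape (2, dim),
# e.g. (2, 560) for 80-dim fbank stacked with LFR (lfr_m=7):
#   cmvn = load_cmvn('pretrained_models/paraformer_encoder/am.mvn')
#   shifts, scales = cmvn[0], cmvn[1]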
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = list(range(len(data)))
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
data = data[self.rank::self.world_size]
data = data[self.worker_id::self.num_workers]
return data
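# Sharding illustration: sample() slices the index list first by distributed rank and then by
# dataloader worker, e.g. with 12 entries, rank=1 of world_size=2 and worker_id=0 of num_workers=2:
#   list(range(12))[1::2]        -> [1, 3, 5, 7, 9, 11]
#   [1, 3, 5, 7, 9, 11][0::2]    -> [1, 5, 9]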
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
for index in indexes:
# yield dict(src=src)
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
class SpeechDatasetShard(torch.utils.data.Dataset):
def __init__(self,
dataset_config,
normalize=True,
mel_size=128,
tokenizer=None):
super().__init__()
self.dataset_config = dataset_config
self.tokenizer = tokenizer
self.IGNORE_INDEX = -100
self.normalize = normalize
self.mel_size = mel_size
self.max_length = dataset_config.get('max_length', 2000)
self.min_length = dataset_config.get('min_length', 20)
self.prompt_template_key = dataset_config.get('prompt_template_key', 'qwen')
self.prompt_template = PROMPT_TEMPLATE_DICT[self.prompt_template_key]
self.prompt_key = dataset_config.get('prompt_key', 'asr')
if self.prompt_key == 'instruction':
self.prompt = PROMPT_DICT
else:
self.prompt = PROMPT_DICT[self.prompt_key]
logging.info(f"Use LLM Type {self.prompt_template_key}, "
f"Prompt template {self.prompt_template}, "
f"Use task type {self.prompt_key}, "
f"Prompt {self.prompt}")
self.is_inference = dataset_config.get('is_inference', False)
if (dataset_config.get('w2v_bert_path', None) is not None) and (os.path.exists(dataset_config.w2v_bert_path)):
self.auto_processer = AutoFeatureExtractor.from_pretrained(dataset_config.w2v_bert_path)
else:
self.auto_processer = None
if (dataset_config.get('cmvn_path', None) is not None) and (os.path.exists(dataset_config.cmvn_path)):
self.cmvn = load_cmvn(dataset_config.cmvn_path)
else:
assert self.dataset_config.encoder_name != 'paraformer', 'paraformer must use cmvn'
self.cmvn = None
self.num_epochs = dataset_config.get('num_epochs', 1)
self.adapter_downsample_rate = dataset_config.get('adapter_downsample_rate', 2)
self.sort = dataset_config.get('sort', True)
self.replace_text_table = None
self.replace_text_path = dataset_config.get('replace_text_path', '')
self.replace_type = dataset_config.get('replace_type', 'replace')
if self.prompt_key == 'instruction':
self.replace_type = 'instruction'
if os.path.exists(self.replace_text_path):
logging.info(f"Parsing replaced table {self.replace_text_path}..., Method {self.replace_type}")
self.replace_text_table = self.parse_txt2dict(self.replace_text_path)
if self.dataset_config.encoder_name in ['paraformer', 'whisper', 'w2v_bert2.0']:
self.input_type = 'mel'
else:
self.input_type = 'raw'
@classmethod
def parse_txt2dict(cls, txt_path):
result = {}
with open(txt_path, 'r') as r:
for line in r.readlines():
line = line.strip()
if line == '':
continue
line = line.split(maxsplit=1)
if len(line) != 2:
continue
key, text = line
result[key] = text
return result
def dataset(self,
data_type,
data_list_file,
shuffle=True,
partition=True):
assert data_type in ['raw', 'shard']
lists = read_lists(data_list_file, num_epochs=self.num_epochs, shuffle=shuffle)
dataset = DataList(lists, shuffle=shuffle, partition=partition)
if data_type == 'shard':
dataset = Processor(dataset, processor.url_opener)
dataset = Processor(dataset, processor.tar_file_and_group)
else:
dataset = Processor(dataset, processor.parse_raw)
if not self.is_inference:
dataset = Processor(dataset, processor.filter,
max_length=self.max_length,
min_length=self.min_length)
if self.replace_text_table is not None:
dataset = Processor(dataset, processor.refine_text,
replaced_table=self.replace_text_table,
replace_type=self.replace_type)
dataset = Processor(dataset, processor.gen_llm_inputs,
tokenizer=self.tokenizer,
ignore_index=self.IGNORE_INDEX,
normalize=self.normalize,
mel_size=self.mel_size,
input_type=self.input_type,
is_paraformer=self.dataset_config.encoder_name == 'paraformer',
prompt_template=self.prompt_template,
is_inference=self.is_inference,
autoprocesser=self.auto_processer,
cmvn=self.cmvn,
prompt_org=self.prompt,
adapter_downsample_rate=self.adapter_downsample_rate,
instruction=self.prompt_key == 'instruction')
if not self.is_inference:
# add shuffle
dataset = Processor(dataset, processor.shuffle)
if self.sort:
dataset = Processor(dataset, processor.sort, sort_size=2000, key='audio'
if self.dataset_config.encoder_name == 'hubert' else 'audio_mel')
return dataset
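    # Usage sketch (illustrative; the training script that does this wiring is not part of this
    # commit, so the names below are assumptions):
    #   ds_helper = SpeechDatasetShard(data_config, normalize=False, mel_size=80, tokenizer=tokenizer)
    #   train_set = ds_helper.dataset('shard', data_config.train_data_path, shuffle=True, partition=True)
    #   loader = torch.utils.data.DataLoader(train_set, batch_size=train_config.batch_size_training,
    #                                        collate_fn=ds_helper.collator, num_workers=4)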
def pad(self, sequence, max_length, padding_idx=0):
if isinstance(sequence, (int, list, tuple)):
if len(sequence) < max_length:
sequence = sequence + [padding_idx] * (max_length - len(sequence))
else:
sequence = sequence[:max_length]
elif isinstance(sequence, torch.Tensor):
if len(sequence) < max_length:
sequence = torch.cat(
(sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx)))
else:
sequence = sequence[:max_length]
elif isinstance(sequence, np.ndarray):
if len(sequence) < max_length:
sequence = np.concatenate(
(sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx)))
else:
sequence = sequence[:max_length]
else:
raise Exception("Type mismatch during padding!")
return sequence
@classmethod
def padding(cls, sequence, padding_length, padding_idx=0, padding_side="right"):
if isinstance(sequence, (int, list, tuple)):
if padding_length >= 0:
sequence = sequence + [padding_idx] * padding_length
else:
sequence = sequence[:padding_length]
elif isinstance(sequence, torch.Tensor):
if sequence.ndimension() == 2:
if padding_length >= 0:
sequence = torch.nn.functional.pad(sequence, (0, padding_length))
else:
sequence = sequence[:, :padding_length]
else:
if padding_length >= 0:
if padding_side == "left":
sequence = torch.cat(
(torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx), sequence))
else:
sequence = torch.cat(
(sequence, torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx)))
else:
sequence = sequence[:padding_length]
elif isinstance(sequence, np.ndarray):
if padding_length >= 0:
sequence = np.concatenate(
(sequence, np.full((padding_length,) + sequence.shape[1:], padding_idx)))
else:
sequence = sequence[:padding_length]
else:
raise Exception("Type mismatch during padding!")
return sequence
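    # Illustration: padding() pads (or truncates, when padding_length is negative) on the chosen side:
    #   SpeechDatasetShard.padding(torch.tensor([1, 2, 3]), 2, 0, padding_side="left")  -> tensor([0, 0, 1, 2, 3])
    #   SpeechDatasetShard.padding(torch.tensor([1, 2, 3]), 2, 0)                       -> tensor([1, 2, 3, 0, 0])
    #   SpeechDatasetShard.padding(torch.tensor([1, 2, 3]), -1)                         -> tensor([1, 2])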
def collator(self, samples):
assert samples is not None
input_prompt_lengths = [s["audio_length"] + s['prompt_length'] for s in samples] # [120, 48, 82, 42]
input_answer_lengths = [len(s["input_ids"]) - s["audio_length"] - s['prompt_length'] for s in
samples] # [0, 0, 0, 0]
input_prompt_max_length = max(input_prompt_lengths)
input_answer_max_length = max(input_answer_lengths)
input_ids = torch.stack([
self.padding(
self.padding(samples[index]["input_ids"], input_prompt_max_length - input_prompt_lengths[index],
self.tokenizer.pad_token_id, padding_side="left"),
input_answer_max_length - input_answer_lengths[index], self.tokenizer.pad_token_id
) for index in range(len(samples))
])
attention_mask = torch.stack([
self.padding(
self.padding(samples[index]["attention_mask"], input_prompt_max_length - input_prompt_lengths[index],
False, padding_side="left"),
input_answer_max_length - input_answer_lengths[index], False
) for index in range(len(samples))
])
if self.auto_processer is not None:
audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples])
audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0)
for s in samples])
audio_mel_post_mask = torch.zeros(len(samples), (
audio_mel_max_length)) # w2v-bert
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0])] = 1
elif self.dataset_config.encoder_name == 'paraformer':
audio_mel_reallen = [s['audio_mel'].shape[0] for s in samples]
audio_mel_max_length = max(audio_mel_reallen)
audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0)
for s in samples])
audio_mel_post_mask = torch.zeros(len(samples), (
audio_mel_max_length)) # paraformer
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0])] = 1
audio_mel_reallen = torch.tensor(audio_mel_reallen, dtype=torch.int32)
elif self.dataset_config.encoder_name == 'hubert':
audio_raw_max_length = max([s['audio'].shape[0] for s in samples])
audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0)
for s in samples])
audio_mask = torch.zeros(len(samples), audio_raw_max_length)
for line, sample in enumerate(samples):
audio_mask[line, :sample['audio'].shape[0]] = 1
elif self.dataset_config.encoder_name == 'whisper':
audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples])
audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0)
for s in samples])
audio_mel_post_mask = torch.zeros(len(samples), (
audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats
for line, sample in enumerate(samples):
audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1
modality_mask = torch.zeros_like(attention_mask)
for index in range(len(samples)):
padding_left = input_prompt_max_length - input_prompt_lengths[index]
modality_mask[index, padding_left:padding_left + samples[index]["audio_length"]] = True
keys = [s['key'] for s in samples]
if self.is_inference:
targets = [s['target'] for s in samples]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"audio": audio_raw if self.input_type == "raw" else None,
"audio_mask": audio_mask if self.input_type == "raw" else None,
"audio_mel": audio_mel if self.input_type == "mel" else None,
"audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None,
"modality_mask": modality_mask,
"keys": keys,
"targets": targets,
"audio_mel_reallen": audio_mel_reallen if self.dataset_config.encoder_name == 'paraformer' else None
}
labels = torch.stack([
self.padding(
self.padding(samples[index]['labels'], input_prompt_max_length - input_prompt_lengths[index],
self.IGNORE_INDEX, padding_side="left"),
input_answer_max_length - input_answer_lengths[index], self.IGNORE_INDEX)
for index in range(len(samples))
])
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"audio": audio_raw if self.input_type == "raw" else None,
"audio_mask": audio_mask if self.input_type == "raw" else None,
"audio_mel": audio_mel if self.input_type == "mel" else None,
"audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None,
"modality_mask": modality_mask,
"audio_mel_reallen": audio_mel_reallen if self.dataset_config.encoder_name == 'paraformer' else None,
"keys": keys,
}
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class LinearAdapter(nn.Module):
def __init__(self, config):
super().__init__()
self.k = config.adapter_downsample_rate
self.encoder_dim = config.encoder_dim
self.llm_dim = config.llm_dim
self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(2048, config.llm_dim)
def forward(self, x, gradient_checkpoint=False):
batch_size, seq_len, dim = x.size()
num_frames_to_discard = seq_len % self.k
if num_frames_to_discard > 0:
x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, seq_len // self.k, dim * self.k)
if gradient_checkpoint:
x = checkpoint(self.linear1, x)
else:
x = self.linear1(x)
x = self.relu(x)
if gradient_checkpoint:
x = checkpoint(self.linear2, x)
else:
x = self.linear2(x)
return x
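A quick shape sketch of the adapter using the dimensions from the configs in this commit (encoder_dim=512, adapter_downsample_rate=2, llm_dim=3584); the SimpleNamespace stand-in for the config object is illustrative only:

if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(adapter_downsample_rate=2, encoder_dim=512, llm_dim=3584)
    adapter = LinearAdapter(cfg)
    enc_out = torch.randn(1, 100, 512)   # (batch, encoder frames, encoder_dim)
    llm_in = adapter(enc_out)            # adjacent frame pairs are stacked -> (1, 50, 3584)
    assert llm_in.shape == (1, 50, 3584)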
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
class WhisperWrappedEncoder:
@classmethod
def load(cls, model_config):
def extract_variable_length_features(self, x: torch.Tensor, gradient_checkpoint=False):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
# x = (x + self.positional_embedding).to(x.dtype)
x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
for block in self.blocks:
if gradient_checkpoint:
x = checkpoint(block, x)
else:
x = block(x)
if gradient_checkpoint:
x = checkpoint(self.ln_post, x)
else:
x = self.ln_post(x)
return x
import whisper
encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder
encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder)
return encoder
class HubertEncoder:
@classmethod
def load(cls, model_config):
import fairseq
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path])
model = models[0]
if model_config.encoder_type == "pretrain":
pass
elif model_config.encoder_type == "finetune":
model.w2v_encoder.proj = None
model.w2v_encoder.apply_mask = False
else:
            assert model_config.encoder_type in ["pretrain", "finetune"], "encoder_type must be one of [pretrain, finetune]"
return model
class W2vBert2Encoder(nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
@classmethod
def load(cls, model_config):
from transformers import Wav2Vec2BertModel
model = Wav2Vec2BertModel.from_pretrained(model_config.encoder_path)
return cls(model_config, model)
def extract_features(self, source, attention_mask):
output = self.model(source, attention_mask=attention_mask)
return output.last_hidden_state
class HfTextEncoder:
@classmethod
def load(cls, model_config):
from transformers import AutoModel
model = AutoModel.from_pretrained(model_config.encoder_path)
return model
class ParaformerEncoder(nn.Module):
def __init__(self, config, model):
super().__init__()
self.config = config
self.model = model
@classmethod
def load(cls, model_config):
from .Paraformer.encoder import SANMEncoder
model = SANMEncoder(gradient_checkpoint=model_config.get('gradient_checkpoint', False))
ckpt_dict = torch.load(model_config.encoder_path, map_location="cpu")
model.load_state_dict(ckpt_dict, strict=False)
return cls(model_config, model)
def extract_features(self, source, reallen):
output, _, _ = self.model(
xs_pad=source,
ilens=reallen
)
# TODO: support streaming @zhenlin.liang
return output
import json
import os
import torch
import torchaudio
from .hifigan import Generator
from .speaker_encoder import ResNetSpeakerEncoder
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def get_hifigan_model(vocoder_path, device,
vocoder_config=None, decoder_dim=1024):
if vocoder_config is None:
vocoder_config = os.path.join(os.path.dirname(__file__), 'hifigan_config.json')
print(f"Loading vocoder config from {vocoder_config}")
with open(vocoder_config, "r") as f:
config = json.load(f)
config = AttrDict(config)
config.input_num_mels = decoder_dim
vocoder = Generator(config)
print(f"Loading vocoder from {vocoder_path}")
ckpt = torch.load(vocoder_path, map_location=device)
vocoder.load_state_dict(ckpt["generator"], strict=True)
vocoder.eval()
vocoder.remove_weight_norm()
vocoder.to(device)
return vocoder
def get_speaker_encoder(model_checkpoint, device):
model = ResNetSpeakerEncoder(
input_dim=64,
proj_dim=512,
log_input=True,
use_torch_spec=True,
audio_config={
"fft_size": 512,
"win_length": 400,
"hop_length": 160,
"sample_rate": 16000,
"preemphasis": 0.97,
"num_mels": 64
},
)
print(f"Loading speaker encoder from {model_checkpoint}")
checkpoint = torch.load(model_checkpoint, map_location=device)["speaker_encoder"]
model.load_state_dict(checkpoint, strict=True)
model.eval()
model.to(device)
return model
def encode_prompt_wav(spk_encoder, prompt_wav, device):
wav, sr = torchaudio.load(prompt_wav)
wav_16k = torchaudio.functional.resample(wav, sr, 16000)
vocoder_latent = spk_encoder.forward(
wav_16k.to(device), l2_norm=True
).unsqueeze(-1)
return vocoder_latent
def save_wav(hifigan_generator, embedding, latent, save_path):
samples = hifigan_generator(latent.permute(0, 2, 1), g=embedding).squeeze(0)
samples = samples.detach().cpu()
torchaudio.save(save_path, samples, 22050)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
LRELU_SLOPE = 0.1
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(
Conv1d(h.input_num_mels if hasattr(h, "input_num_mels") and h.input_num_mels else h.num_mels,
h.upsample_initial_channel, 7, 1, padding=(7 - 1) // 2))
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)),
k, u, padding=(k - u) // 2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=(7 - 1) // 2))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
if h.cond_channels > 0:
self.cond_layer = nn.Conv1d(h.cond_channels, h.upsample_initial_channel, 1)
if h.cond_in_each_up_layer:
self.conds = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
self.conds.append(nn.Conv1d(h.cond_channels, ch, 1))
def forward(self, x, g=None):
x = torch.nn.functional.interpolate(
x,
scale_factor=[self.h.input_up_scale_factor],
mode="linear",
).squeeze(1)
x = self.conv_pre(x)
if self.h.cond_channels > 0:
x = x + self.cond_layer(g)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
if self.h.cond_in_each_up_layer:
x = x + self.conds[i](g)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
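# Shape sketch (added for illustration; not part of the original file). The
# values below mirror the accompanying config; treat them as assumptions if
# you use a different checkpoint.
def _example_generator_forward():
    from types import SimpleNamespace
    h = SimpleNamespace(
        resblock="1",
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        num_mels=80,
        input_num_mels=3584,
        cond_channels=512,
        cond_in_each_up_layer=True,
        input_up_scale_factor=4,
    )
    net = Generator(h)
    net.eval()
    latent = torch.randn(1, 3584, 10)   # (B, input_num_mels, T), channels first
    spk = torch.randn(1, 512, 1)        # (B, cond_channels, 1) speaker embedding
    with torch.no_grad():
        wav = net(latent, g=spk)        # (1, 1, 10 * 4 * 8*8*2*2) = (1, 1, 10240)
    return wav.shape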
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 1024,
"learning_rate": 0.0002,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.99,
"seed": 1234,
"upsample_rates": [8,8,2,2],
"upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"discriminator_periods": null,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,
"num_workers": 8,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
},
"input_num_mels": 3584,
"input_n_fft": null,
"input_hop_size": null,
"input_win_size": null,
"input_sampling_rate": null,
"cond_channels": 512,
"cond_in_each_up_layer": true,
"input_up_scale_factor": 4
}
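# Loading sketch (added for illustration; the repo's own loader may differ and
# the filename below is a placeholder). Generator reads these values through
# attribute access (h.num_mels, h.upsample_rates, ...), so the parsed dict is
# wrapped in a simple namespace before being handed to the model.
import json
from types import SimpleNamespace
with open("hifigan_config.json") as f:
    h = SimpleNamespace(**json.load(f))
assert h.sampling_rate == 22050 and h.hop_size == 256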
import torch
import torch.nn as nn
import torchaudio
class SELayer(nn.Module):
def __init__(self, channel, reduction=8):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel),
nn.Sigmoid(),
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
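# Note (added): SELayer is squeeze-and-excitation channel gating -- global
# average pooling to (B, C), a bottleneck MLP with a sigmoid, then per-channel
# rescaling of the input feature map.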
class SEBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
super(SEBasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.se = SELayer(planes, reduction)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.relu(out)
out = self.bn1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class PreEmphasis(nn.Module):
def __init__(self, coefficient=0.97):
super().__init__()
self.coefficient = coefficient
self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
def forward(self, x):
assert len(x.size()) == 2
x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
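# Note (added): PreEmphasis applies y[t] = x[t] - coefficient * x[t-1] via a
# length-2 convolution over the reflect-padded waveform, boosting high
# frequencies before the mel spectrogram is computed.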
class ResNetSpeakerEncoder(nn.Module):
"""This is copied from 🐸TTS to remove it from the dependencies."""
# pylint: disable=W0102
def __init__(
self,
input_dim=64,
proj_dim=512,
layers=[3, 4, 6, 3],
num_filters=[32, 64, 128, 256],
encoder_type="ASP",
log_input=False,
use_torch_spec=False,
audio_config=None,
):
super(ResNetSpeakerEncoder, self).__init__()
self.encoder_type = encoder_type
self.input_dim = input_dim
self.log_input = log_input
self.use_torch_spec = use_torch_spec
self.audio_config = audio_config
self.proj_dim = proj_dim
self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.bn1 = nn.BatchNorm2d(num_filters[0])
self.inplanes = num_filters[0]
self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
self.instancenorm = nn.InstanceNorm1d(input_dim)
if self.use_torch_spec:
self.torch_spec = torch.nn.Sequential(
PreEmphasis(audio_config["preemphasis"]),
torchaudio.transforms.MelSpectrogram(
sample_rate=audio_config["sample_rate"],
n_fft=audio_config["fft_size"],
win_length=audio_config["win_length"],
hop_length=audio_config["hop_length"],
window_fn=torch.hamming_window,
n_mels=audio_config["num_mels"],
),
)
else:
self.torch_spec = None
outmap_size = int(self.input_dim / 8)
self.attention = nn.Sequential(
nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
nn.ReLU(),
nn.BatchNorm1d(128),
nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
nn.Softmax(dim=2),
)
if self.encoder_type == "SAP":
out_dim = num_filters[3] * outmap_size
elif self.encoder_type == "ASP":
out_dim = num_filters[3] * outmap_size * 2
else:
raise ValueError("Undefined encoder")
self.fc = nn.Linear(out_dim, proj_dim)
self._init_layers()
def _init_layers(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def create_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
# pylint: disable=R0201
def new_parameter(self, *size):
out = nn.Parameter(torch.FloatTensor(*size))
nn.init.xavier_normal_(out)
return out
def forward(self, x, l2_norm=False):
"""Forward pass of the model.
Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If the input is a waveform,
                `use_torch_spec` must be `True` so the spectrogram is computed on the fly.
l2_norm (bool): Whether to L2-normalize the outputs.
Shapes:
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
"""
x.squeeze_(1)
        # if use_torch_spec is enabled, compute the mel spectrogram here;
        # otherwise x is expected to already contain spectrogram frames
if self.use_torch_spec:
x = self.torch_spec(x)
if self.log_input:
x = (x + 1e-6).log()
x = self.instancenorm(x).unsqueeze(1)
x = self.conv1(x)
x = self.relu(x)
x = self.bn1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.reshape(x.size()[0], -1, x.size()[-1])
w = self.attention(x)
if self.encoder_type == "SAP":
x = torch.sum(x * w, dim=2)
elif self.encoder_type == "ASP":
mu = torch.sum(x * w, dim=2)
sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
x = torch.cat((mu, sg), 1)
x = x.view(x.size()[0], -1)
x = self.fc(x)
if l2_norm:
x = torch.nn.functional.normalize(x, p=2, dim=1)
return x
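# Usage sketch (added for illustration). The audio_config values here are
# plausible assumptions, not necessarily those of the pretrained checkpoint
# shipped with this repo.
def _example_speaker_embedding():
    audio_config = {
        "preemphasis": 0.97,
        "sample_rate": 16000,
        "fft_size": 512,
        "win_length": 400,
        "hop_length": 160,
        "num_mels": 64,
    }
    encoder = ResNetSpeakerEncoder(
        input_dim=64,
        proj_dim=512,
        log_input=True,
        use_torch_spec=True,
        audio_config=audio_config,
    )
    encoder.eval()
    wav = torch.randn(1, 16000)              # one second of random 16 kHz audio
    with torch.no_grad():
        emb = encoder(wav, l2_norm=True)     # (1, 512) unit-norm speaker embedding
    return emb.shape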