Commit 0941998c authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

conformer add post and ana

parent fde49a28
import argparse
import multiprocessing
import os
import sys
import time
from typing import Dict, Optional, Tuple

import numpy as np
import soundfile
import torch
from torch.utils.data import DataLoader, Dataset

from espnet2.bin.asr_inference import Speech2Text
from espnet2.torch_utils.device_funcs import to_device

torch.set_num_threads(1)

try:
    from swig_decoders import map_batch, \
        ctc_beam_search_decoder_batch, \
        TrieVector, PathTrie
except ImportError:
    print('Please install ctc decoders first by refering to\n' +
          'https://github.com/Slyne/ctc_decoder.git')
    sys.exit(1)
def lm_batchify_nll(lm_scorer, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute negative log likelihood (nll) with a language-model scorer.

    To avoid OOM, the input is split into mini-batches of at most
    ``batch_size`` sequences; the per-chunk results are concatenated.

    Args:
        lm_scorer: Language model scorer object.
        text: (Batch, Length) token ids.
        text_lengths: (Batch,) valid length of each sequence.
        batch_size: number of sequences scored per chunk; lower it to
            avoid OOM, raise it for throughput.

    Returns:
        Tuple of (nll, x_lengths), each of size (Batch,).
    """
    total_num = text.size(0)
    # Small enough to score in one shot.
    if total_num <= batch_size:
        return _compute_nll_with_lm_scorer(lm_scorer, text, text_lengths)

    # Chunked path: pad every chunk to the global max length so the
    # concatenated results are consistent.
    max_length = text_lengths.max()
    nll_chunks, len_chunks = [], []
    for start_idx in range(0, total_num, batch_size):
        end_idx = min(start_idx + batch_size, total_num)
        chunk_nll, chunk_lens = _compute_nll_with_lm_scorer(
            lm_scorer,
            text[start_idx:end_idx, :],
            text_lengths[start_idx:end_idx],
            max_length=max_length,
        )
        nll_chunks.append(chunk_nll)
        len_chunks.append(chunk_lens)
    nll = torch.cat(nll_chunks)
    x_lengths = torch.cat(len_chunks)
    assert nll.size(0) == total_num
    assert x_lengths.size(0) == total_num
    return nll, x_lengths
def _compute_nll_with_lm_scorer(lm_scorer, text: torch.Tensor, text_lengths: torch.Tensor, max_length: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute negative log likelihood using the scorer's ``score`` method.

    Simulates an ``nll`` method by walking each sequence token by token
    and accumulating the negative log-probability of each ground-truth
    next token.

    Args:
        lm_scorer: object exposing ``score(token, state, x)`` returning
            ``(logp, new_state)``.
        text: (Batch, Length) token ids.
        text_lengths: (Batch,) valid lengths.
        max_length: optional pad width override (used for data parallel).

    Returns:
        Tuple of (nll, x_lengths), each of size (Batch,).
    """
    num_seqs = text.size(0)
    # Trim padding: either to the batch-local max or the caller-given width.
    width = text_lengths.max() if max_length is None else max_length
    text = text[:, :width]
    nll = torch.zeros(num_seqs, device=text.device)
    for seq_idx in range(num_seqs):
        tokens = text[seq_idx][: text_lengths[seq_idx]]
        state = None
        # Score each transition: feed token at pos, charge the nll of the
        # ground-truth token at pos + 1.
        for pos in range(len(tokens) - 1):
            logp, state = lm_scorer.score(tokens[pos].unsqueeze(0), state, None)
            nll[seq_idx] += -logp[tokens[pos + 1]]
    # One transition fewer than tokens; clamp guards zero-length sequences.
    x_lengths = torch.clamp(text_lengths - 1, min=0)
    return nll, x_lengths
class CustomAishellDataset(Dataset):
    """Dataset backed by a Kaldi-style ``wav.scp`` / ``text`` file pair.

    ``wav.scp`` lines are ``<utt-id> <wav-path>``; ``text`` lines are
    ``<utt-id> <token> <token> ...`` (tokens are joined without spaces).
    """

    def __init__(self, wav_scp_file, text_file):
        with open(wav_scp_file, 'r') as wav_scp, open(text_file, 'r') as text:
            wav_entries = [line.split() for line in wav_scp.readlines()]
            text_entries = [line.split() for line in text.readlines()]
        self.wav_names = [entry[0] for entry in wav_entries]
        self.wav_paths = [entry[1] for entry in wav_entries]
        self.labels = ["".join(entry[1:]) for entry in text_entries]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return (waveform, num_samples, transcript, utt-id) for ``idx``."""
        waveform, sample_rate = soundfile.read(self.wav_paths[idx])
        # Decoding assumes 16 kHz audio throughout.
        assert sample_rate == 16000, sample_rate
        waveform = np.array(waveform, dtype=np.float32)
        return waveform, waveform.shape[0], self.labels[idx], self.wav_names[idx]
def collate_wrapper(batch):
    """Collate (speech, speech_len, label, name) samples into padded arrays.

    Waveforms are zero-padded into a shared float32 buffer (pre-sized to
    30 s at 16 kHz) and then truncated to the longest waveform present.

    Returns:
        Tuple of (speeches [B, T_max] float32, lengths [B] int64,
        labels list, names list).
    """
    batch_count = len(batch)
    padded = np.zeros((batch_count, 16000 * 30), dtype=np.float32)
    wave_lens = np.zeros(batch_count, dtype=np.int64)
    labels, names = [], []
    for row, (wave, wave_len, label, name) in enumerate(batch):
        padded[row, :wave_len] = wave
        wave_lens[row] = wave_len
        labels.append(label)
        names.append(name)
    # Drop trailing padding beyond the longest utterance in this batch.
    return padded[:, :max(wave_lens)], wave_lens, labels, names
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make a boolean mask tensor marking the padded positions.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): Mask width; if 0, use ``lengths.max()``.

    Returns:
        torch.Tensor: (B, max_len) bool mask, True at padded positions.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch = lengths.size(0)
    width = max_len if max_len > 0 else lengths.max().item()
    # Column indices 0..width-1, broadcast against each row's length.
    positions = torch.arange(0, width, dtype=torch.int64, device=lengths.device)
    positions = positions.unsqueeze(0).expand(batch, width)
    return positions >= lengths.unsqueeze(-1)
def get_args():
    """Build and parse the command-line arguments for decoding."""
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--lm_config', required=True, help='config file')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--wav_scp', required=True, help='wav scp file')
    parser.add_argument('--text', required=True, help='ground truth text file')
    parser.add_argument('--model_path', required=True, help='torch pt model file')
    parser.add_argument('--lm_path', required=True, help='torch pt model file')
    parser.add_argument('--result_file', default='./predictions.txt',
                        help='asr result file')
    parser.add_argument('--log_file', default='./rtf.txt',
                        help='asr decoding log')
    parser.add_argument('--batch_size', type=int, default=24,
                        help='batch_size')
    parser.add_argument('--beam_size', type=int, default=10,
                        help='beam_size')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring', 'attention_lm_rescoring',
                            'lm_rescoring'],
                        default='attention_lm_rescoring',
                        help='decoding mode')
    return parser.parse_args()
if __name__ == '__main__':
    # Batch ASR decoding driver: CTC beam search + optional attention /
    # neural-LM rescoring, with RTF logging.
    args = get_args()
    # Restrict CUDA to the requested device before any CUDA context exists.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    dataset = CustomAishellDataset(args.wav_scp, args.text)
    test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                  collate_fn=collate_wrapper)
    speech2text = Speech2Text(
        args.config,
        args.model_path,
        None,
        args.lm_config,
        args.lm_path,
        device="cuda"
    )
    # Manually load the full ESPnetLanguageModel object: Speech2Text only
    # stores the raw language model, but the complete object is needed to
    # call its batchify_nll method.
    full_lm_model = None
    if args.lm_config is not None and args.lm_path is not None:
        from espnet2.tasks.lm import LMTask
        full_lm_model, _ = LMTask.build_model_from_file(
            args.lm_config, args.lm_path, "cuda"
        )
        full_lm_model.eval()
    # Optimize model performance with torch.compile; first check that this
    # PyTorch build supports it.
    if hasattr(torch, 'compile') and torch.cuda.is_available():
        print("启用torch.compile优化...")
        # Try backends in order from most compatible to highest performance.
        backends_to_try = [
            ("aot_eager", {}),  # aot_eager does not accept a mode argument
            ("eager", {"mode": "reduce-overhead"}),
            ("inductor", {"mode": "reduce-overhead", "dynamic": False, "fullgraph": False})
        ]
        for backend_name, backend_options in backends_to_try:
            try:
                print(f"尝试使用 {backend_name} 后端进行编译...")
                # Compile the hot ASR components: encoder and CTC projection.
                if hasattr(speech2text.asr_model, 'encode'):
                    speech2text.asr_model.encode = torch.compile(speech2text.asr_model.encode, backend=backend_name, **backend_options)
                if hasattr(speech2text.asr_model.ctc, 'ctc_lo'):
                    speech2text.asr_model.ctc.ctc_lo = torch.compile(speech2text.asr_model.ctc.ctc_lo, backend=backend_name, **backend_options)
                # Compile the language model as well, if one was loaded.
                if full_lm_model is not None and hasattr(full_lm_model, 'batchify_nll'):
                    full_lm_model.batchify_nll = torch.compile(full_lm_model.batchify_nll, backend=backend_name, **backend_options)
                # Compilation succeeded; also enable TensorFloat-32 matmul.
                torch.set_float32_matmul_precision('high')
                print(f"✓ 使用 {backend_name} 后端编译成功")
                print("✓ TensorFloat-32加速已启用")
                break
            except Exception as e:
                print(f"⚠ {backend_name} 后端编译失败: {e}")
                # Roll back to the original callables (compiled wrappers keep
                # the original under ._orig_mod).
                if hasattr(speech2text.asr_model, 'encode'):
                    speech2text.asr_model.encode = speech2text.asr_model.encode._orig_mod if hasattr(speech2text.asr_model.encode, '_orig_mod') else speech2text.asr_model.encode
                if hasattr(speech2text.asr_model.ctc, 'ctc_lo'):
                    speech2text.asr_model.ctc.ctc_lo = speech2text.asr_model.ctc.ctc_lo._orig_mod if hasattr(speech2text.asr_model.ctc.ctc_lo, '_orig_mod') else speech2text.asr_model.ctc.ctc_lo
                if full_lm_model is not None and hasattr(full_lm_model, 'batchify_nll'):
                    full_lm_model.batchify_nll = full_lm_model.batchify_nll._orig_mod if hasattr(full_lm_model.batchify_nll, '_orig_mod') else full_lm_model.batchify_nll
                if backend_name == backends_to_try[-1][0]:  # every backend failed
                    print("⚠ 所有编译后端都失败,将使用未编译模式运行")
                    torch.set_float32_matmul_precision('high')  # still enable TF32 acceleration
                    print("✓ TensorFloat-32加速已启用(未编译模式)")
    audio_sample_len = 0       # total decoded audio, in seconds
    total_inference_time = 0   # wall-clock decoding time, in seconds
    with torch.no_grad(), open(args.result_file, 'w') as fout:
        for _, batch in enumerate(test_data_loader):
            # Start timing this batch (torch.compile warm-up happened above).
            batch_start_time = time.perf_counter()
            speech, speech_lens, labels, names = batch
            audio_sample_len += np.sum(speech_lens) / 16000
            batch = {"speech": speech, "speech_lengths": speech_lens}
            if isinstance(batch["speech"], np.ndarray):
                batch["speech"] = torch.tensor(batch["speech"])
            if isinstance(batch["speech_lengths"], np.ndarray):
                batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])
            # a. To device
            batch = to_device(batch, device='cuda')
            # b. Forward Encoder
            # enc: [N, T, C]
            ll = time.time()
            encoder_out, encoder_out_lens = speech2text.asr_model.encode(**batch)
            # ctc_log_probs: [N, T, C]
            ctc_logits = speech2text.asr_model.ctc.ctc_lo(encoder_out)
            ctc_log_probs = torch.nn.functional.log_softmax(ctc_logits, dim=2)
            # Top-k CTC candidates per frame feed the prefix beam search.
            beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
                                                            args.beam_size, dim=2)
            num_processes = min(multiprocessing.cpu_count(), args.batch_size)
            if args.mode == 'ctc_greedy_search':
                # NOTE(review): this asserts beam_size != 1, yet greedy search
                # only uses the top-1 index — looks inverted; confirm intent.
                assert args.beam_size != 1
                log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = []
                for idx, seq in enumerate(log_probs_idx):
                    batch_sents.append(seq[0:encoder_out_lens[idx]].tolist())
                hyps = map_batch(batch_sents, speech2text.asr_model.token_list,
                                 num_processes, True, 0)
            else:
                # Shared CTC prefix beam search for all rescoring modes.
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []  # only effective in streaming deployment
                batch_root = TrieVector()
                root_dict = {}
                for i in range(len(batch_len_list)):
                    # Truncate each utterance to its valid frame count.
                    num_sent = batch_len_list[i]
                    batch_log_probs_seq.append(
                        batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(
                        batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                # score_hyps: per utterance, a list of (score, token-id tuple).
                score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                                           batch_log_probs_ids,
                                                           batch_root,
                                                           batch_start,
                                                           args.beam_size,
                                                           num_processes,
                                                           0, -2, 0.99999)
                if args.mode == 'ctc_prefix_beam_search':
                    # Take the best-scoring prefix per utterance.
                    hyps = []
                    for cand_hyps in score_hyps:
                        hyps.append(cand_hyps[0][1])
                    hyps = map_batch(hyps, speech2text.asr_model.token_list, num_processes, False, 0)
                elif args.mode == 'attention_rescoring':
                    # Rescore the n-best with the attention decoder.
                    ctc_score, all_hyps = [], []
                    max_len = 0
                    for hyps in score_hyps:
                        cur_len = len(hyps)
                        # Pad short beams with impossible candidates so every
                        # utterance contributes exactly beam_size entries.
                        if len(hyps) < args.beam_size:
                            hyps += (args.beam_size - cur_len) * [(-float("INF"), (0,))]
                        cur_ctc_score = []
                        for hyp in hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                        ctc_score.append(cur_ctc_score)
                    ctc_score = torch.tensor(ctc_score, dtype=torch.float32)
                    # Build sos/eos-framed hypothesis tensors for the decoder.
                    hyps_pad_sos_eos = torch.ones(
                        (args.batch_size, args.beam_size, max_len + 2), dtype=torch.int64) * speech2text.asr_model.ignore_id  # FIXME: ignore id
                    hyps_pad_sos = torch.ones(
                        (args.batch_size, args.beam_size, max_len + 1), dtype=torch.int64) * speech2text.asr_model.eos  # FIXME: eos
                    hyps_pad_eos = torch.ones(
                        (args.batch_size, args.beam_size, max_len + 1), dtype=torch.int64) * speech2text.asr_model.ignore_id  # FIXME: ignore id
                    hyps_lens_sos = torch.ones((args.batch_size, args.beam_size), dtype=torch.int32)
                    # NOTE(review): a final partial batch smaller than
                    # args.batch_size would make all_hyps shorter than
                    # batch_size*beam_size and this loop would IndexError —
                    # confirm the loader always yields full batches.
                    k = 0
                    for i in range(args.batch_size):
                        for j in range(args.beam_size):
                            cand = all_hyps[k]
                            l = len(cand) + 2
                            hyps_pad_sos_eos[i][j][0:l] = torch.tensor([speech2text.asr_model.sos] + cand + [speech2text.asr_model.eos])
                            hyps_pad_sos[i][j][0:l-1] = torch.tensor([speech2text.asr_model.sos] + cand)
                            hyps_pad_eos[i][j][0:l-1] = torch.tensor(cand + [speech2text.asr_model.eos])
                            hyps_lens_sos[i][j] = len(cand) + 1
                            k += 1
                    # Expand the encoder output so each beam candidate sees
                    # its own copy: (B, T, F) -> (B*beam, T, F).
                    bz = args.beam_size
                    B,T,F = encoder_out.shape
                    B2=B*bz
                    encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
                    encoder_out_lens = encoder_out_lens.repeat(bz)
                    hyps_pad = hyps_pad_sos_eos.view(B2, max_len + 2)
                    hyps_lens = hyps_lens_sos.view(B2,)
                    hyps_pad_sos = hyps_pad_sos.view(B2, max_len + 1)
                    hyps_pad_eos = hyps_pad_eos.view(B2, max_len + 1)
                    #hyps_pad_sos = hyps_pad[:, :-1]
                    #hyps_pad_eos = hyps_pad[:, 1:]
                    decoder_out, _ = speech2text.asr_model.decoder(encoder_out,encoder_out_lens,hyps_pad_sos.cuda(), hyps_lens.cuda())
                    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
                    mask = ~make_pad_mask(hyps_lens, max_len+1)  # B2 x T2
                    # mask index, remove ignore id
                    index = torch.unsqueeze(hyps_pad_eos * mask, 2)
                    score = decoder_out.cpu().gather(2, index).squeeze(2)  # B2 X T2
                    # mask padded part
                    score = score * mask
                    # decoder_out = decoder_out.view(B, bz, max_len+1, -1)
                    score = torch.sum(score, axis=1)
                    score = torch.reshape(score,(B,bz))
                    # Linear combination of CTC and attention scores.
                    all_scores = ctc_score + 0.1 * score  # FIX ME need tuned
                    best_index = torch.argmax(all_scores, dim=1)
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        cur_best_sent = all_hyps[k: k + args.beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += args.beam_size
                    hyps = map_batch(best_sents, speech2text.asr_model.token_list, num_processes)
                elif args.mode == 'attention_lm_rescoring':
                    # Rescore with both the attention decoder and a neural LM.
                    ctc_score, all_hyps = [], []
                    max_len = 0
                    for hyps in score_hyps:
                        cur_len = len(hyps)
                        if len(hyps) < args.beam_size:
                            hyps += (args.beam_size - cur_len) * [(-float("INF"), (0,))]
                        cur_ctc_score = []
                        for hyp in hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                        ctc_score.append(cur_ctc_score)
                    ctc_score = torch.tensor(ctc_score, dtype=torch.float32)
                    # Optimization: build hyps_pad in bulk, avoiding a nested
                    # (batch x beam) loop.
                    hyps_pad = torch.full((args.batch_size, args.beam_size, max_len),
                                          speech2text.asr_model.ignore_id, dtype=torch.int64)
                    hyps_lens = torch.zeros((args.batch_size, args.beam_size), dtype=torch.int32)
                    # Fill in each candidate's tokens and length.
                    for k, cand in enumerate(all_hyps):
                        i = k // args.beam_size
                        j = k % args.beam_size
                        l = len(cand)
                        hyps_pad[i, j, :l] = torch.tensor(cand, dtype=torch.int64)
                        hyps_lens[i, j] = l
                    bz = args.beam_size
                    B,T,F = encoder_out.shape
                    B2=B*bz
                    encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
                    encoder_out_lens = encoder_out_lens.repeat(bz)
                    hyps_pad = hyps_pad.view(B2, max_len).cuda()
                    hyps_lens = hyps_lens.view(B2,).cuda()
                    # Attention-decoder score = -nll from the ASR model.
                    decoder_scores = -speech2text.asr_model.batchify_nll(
                        encoder_out, encoder_out_lens, hyps_pad, hyps_lens, 320
                    )
                    decoder_scores = torch.reshape(decoder_scores,(B,bz)).cpu()
                    # Score with the full ESPnetLanguageModel object.
                    if full_lm_model is not None:
                        try:
                            # First clean the data: replace ignore_id with 0
                            # (the language model's padding value).
                            hyps_pad_clean = hyps_pad.clone()
                            hyps_pad_clean[hyps_pad_clean == speech2text.asr_model.ignore_id] = 0
                            # Use a smaller LM batch size to avoid OOM.
                            nnlm_nll, x_lengths = full_lm_model.batchify_nll(hyps_pad_clean, hyps_lens, 64)
                        except Exception as e:
                            print(f"语言模型评分失败: {e}")
                            # On failure, fall back to zero LM scores.
                            nnlm_nll = torch.zeros_like(hyps_pad)
                            x_lengths = hyps_lens
                    else:
                        # No language model loaded: use zero LM scores.
                        nnlm_nll = torch.zeros_like(hyps_pad)
                        x_lengths = hyps_lens
                    nnlm_scores = -nnlm_nll.sum(dim=1)
                    nnlm_scores = torch.reshape(nnlm_scores,(B,bz)).cpu()
                    # Weighted fusion of CTC, attention and LM scores.
                    all_scores = ctc_score - 0.05 * decoder_scores + 1.0 * nnlm_scores  # FIX ME need tuned
                    best_index = torch.argmax(all_scores, dim=1)
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        cur_best_sent = all_hyps[k: k + args.beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += args.beam_size
                    hyps = map_batch(best_sents, speech2text.asr_model.token_list, num_processes)
                elif args.mode == 'lm_rescoring':
                    # Rescore the n-best with only the neural language model.
                    # Optimization: precompute sizes, avoid dynamic growth.
                    ctc_score = []
                    all_hyps = []
                    max_len = 0
                    # First pass: find the maximum hypothesis length.
                    for hyps in score_hyps:
                        for hyp in hyps:
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                    # Second pass: collect scores and hypotheses.
                    for hyps in score_hyps:
                        cur_len = len(hyps)
                        if len(hyps) < args.beam_size:
                            hyps += (args.beam_size - cur_len) * [(-float("INF"), (0,))]
                        cur_ctc_score = []
                        for hyp in hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                        ctc_score.append(cur_ctc_score)
                    ctc_score = torch.tensor(ctc_score, dtype=torch.float32)
                    hyps_pad = torch.ones(
                        (args.batch_size, args.beam_size, max_len), dtype=torch.int64) * speech2text.asr_model.ignore_id  # FIXME: ignore id
                    hyps_lens = torch.ones((args.batch_size, args.beam_size), dtype=torch.int32)
                    k = 0
                    for i in range(args.batch_size):
                        for j in range(args.beam_size):
                            cand = all_hyps[k]
                            l = len(cand)
                            hyps_pad[i][j][0:l] = torch.tensor(cand)
                            hyps_lens[i][j] = len(cand)
                            k += 1
                    bz = args.beam_size
                    B,T,F = encoder_out.shape
                    B2=B*bz
                    hyps_pad = hyps_pad.view(B2, max_len).cuda()
                    hyps_lens = hyps_lens.view(B2,).cuda()
                    # Replace ignore_id padding with 0 for the LM.
                    hyps_pad[hyps_pad == speech2text.asr_model.ignore_id] = 0
                    nnlm_nll, x_lengths = full_lm_model.batchify_nll(hyps_pad, hyps_lens, 320)
                    nnlm_scores = -nnlm_nll.sum(dim=1)
                    nnlm_scores = torch.reshape(nnlm_scores,(B,bz))
                    # Fuse on the GPU to avoid extra CPU-GPU transfers.
                    ctc_score_gpu = ctc_score.cuda()
                    all_scores = ctc_score_gpu + 0.9 * nnlm_scores  # FIX ME need tuned
                    best_index = torch.argmax(all_scores, dim=1)
                    best_index = best_index.cpu()  # transfer to CPU only at the end
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        cur_best_sent = all_hyps[k: k + args.beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += args.beam_size
                    hyps = map_batch(best_sents, speech2text.asr_model.token_list, num_processes)
                else:
                    raise NotImplementedError
            # NOTE(review): the {...} literals below print as one-element
            # sets, and 24 hard-codes the batch size — probably meant
            # f-string fields and args.batch_size; confirm.
            print("耗时:",{time.time()-ll}, "fps:", {24/(time.time()-ll)})
            for i, key in enumerate(names):
                content = hyps[i]
                # print('{} {}'.format(key, content))
                fout.write('{} {}\n'.format(key, content))
            # Record this batch's inference time (excludes torch.compile time).
            batch_end_time = time.perf_counter()
            total_inference_time += batch_end_time - batch_start_time
    # Overall time statistics (excludes torch.compile time); only the
    # process pinned to GPU 0 writes the RTF log.
    if str(args.gpu) == '0':
        with open(args.log_file, 'w') as log:
            log.write(f"Decoding audio {audio_sample_len} secs, cost {total_inference_time} secs (不包含torch.compile时间), RTF: {total_inference_time/audio_sample_len}, process {audio_sample_len/total_inference_time} secs audio per second, decoding args: {args}")
......@@ -158,8 +158,14 @@ if __name__ == '__main__':
# b. Forward Encoder
# enc: [N, T, C]
feats, feats_lengths = speech2text.asr_model.pre_data(**batch)
feats_lengths_1 = torch.ceil(feats_lengths.float() / 4).long()
print("feats_lengths_1:",feats_lengths_1)
# print("feats_lengths:",feats_lengths)
ll_time = time.time()
encoder_out, encoder_out_lens = speech2text.asr_model.encode(**batch)
encoder_out, encoder_out_lens = speech2text.asr_model.encode(feats, feats_lengths)
print("encoder_out_lens:",encoder_out_lens)
# ctc_log_probs: [N, T, C]
ctc_log_probs = torch.nn.functional.log_softmax(
speech2text.asr_model.ctc.ctc_lo(encoder_out), dim=2
......
#!/usr/bin/env python3
import torch
from torch.utils.data import DataLoader, Dataset
import soundfile
......@@ -59,94 +60,6 @@ def collate_wrapper(batch):
return speeches, lengths, labels, names
# def collate_wrapper(batch):
# """
# 实现与ESPNet模型相同的特征处理流程:
# 1. 提取特征(相当于 self._extract_feats)
# 2. 跳过数据增强(仅在训练时使用)
# 3. 特征归一化(相当于 self.normalize)
# """
# speeches = np.zeros((len(batch), 16000 * 30), dtype=np.float32)
# lengths = np.zeros(len(batch), dtype=np.int64)
# labels = []
# names = []
# for i, (speech, speech_len, label, name) in enumerate(batch):
# speeches[i, :speech_len] = speech
# lengths[i] = speech_len
# labels.append(label)
# names.append(name)
# speeches = speeches[:, :max(lengths)]
# try:
# # === 1. 提取特征(相当于 self._extract_feats) ===
# import librosa
# batch_size = speeches.shape[0]
# features_list = []
# for i in range(batch_size):
# audio = speeches[i]
# # 提取梅尔特征(与ESPNet前端处理一致)
# audio = librosa.effects.trim(audio, top_db=20)[0] # 去除静音
# stft = librosa.stft(audio, n_fft=512, hop_length=128, win_length=512)
# spectrogram = np.abs(stft)
# mel_filter = librosa.filters.mel(sr=16000, n_fft=512, n_mels=80)
# mel_spectrogram = np.dot(mel_filter, spectrogram)
# log_mel_spectrogram = np.log(np.clip(mel_spectrogram, a_min=1e-10, a_max=None))
# log_mel_spectrogram = log_mel_spectrogram.T # [time, 80]
# features_list.append(log_mel_spectrogram)
# # 找到最大时间长度并填充
# max_time = max(feat.shape[0] for feat in features_list)
# features = np.zeros((batch_size, max_time, 80), dtype=np.float32)
# for i, feat in enumerate(features_list):
# features[i, :feat.shape[0], :] = feat
# feats_lengths = np.array([feat.shape[0] for feat in features_list], dtype=np.int64)
# # print(f"特征提取完成: 音频形状 {speeches.shape} -> 特征形状 {features.shape}")
# # === 2. 跳过数据增强(仅在训练时使用) ===
# # if self.specaug is not None and self.training: # 跳过
# # feats, feats_lengths = self.specaug(feats, feats_lengths)
# # === 3. 特征归一化(相当于 self.normalize) ===
# stats_file = "/home/sunzhq/workspace/yidong-infer/conformer/34e9cabc2c29fd0e3a2917ffa525d98b/exp/asr_stats_raw_sp/train/feats_stats.npz"
# # 导入GlobalMVN类
# from espnet2.layers.global_mvn import GlobalMVN
# # 创建GlobalMVN实例(与ESPNet配置相同)
# global_mvn = GlobalMVN(
# stats_file=stats_file,
# norm_means=True,
# norm_vars=True
# )
# # 转换为PyTorch张量并应用GlobalMVN
# features_tensor = torch.from_numpy(features).float()
# feats_lengths_tensor = torch.from_numpy(feats_lengths).long()
# # 应用GlobalMVN归一化
# normalized_features, normalized_lengths = global_mvn(features_tensor, feats_lengths_tensor)
# # 转换回numpy
# features = normalized_features.numpy()
# feats_lengths = normalized_lengths.numpy()
# # print(f"特征归一化完成: 使用GlobalMVN,统计文件 {stats_file}")
# # 返回处理后的特征
# return features, feats_lengths, labels, names
# except Exception as e:
# print(f"特征处理失败: {e}")
# print("将返回原始音频数据")
# return speeches, lengths, labels, names
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
......@@ -172,147 +85,33 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
mask = seq_range_expand >= seq_length_expand
return mask
def get_args():
    """Build and parse the command-line arguments for decoding."""
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--lm_config', required=True, help='config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--wav_scp', required=True, help='wav scp file')
    parser.add_argument('--text', required=True, help='ground truth text file')
    parser.add_argument('--model_path', required=True, help='torch pt model file')
    parser.add_argument('--lm_path', required=True, help='torch pt model file')
    parser.add_argument('--result_file', default='./predictions.txt', help='asr result file')
    parser.add_argument('--log_file', default='./rtf.txt', help='asr decoding log')
    parser.add_argument('--batch_size',
                        type=int,
                        default=24,
                        help='batch_size')
    parser.add_argument('--beam_size',
                        type=int,
                        default=10,
                        help='beam_size')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring', 'attention_lm_rescoring', 'lm_rescoring'],
                        default='attention_lm_rescoring',
                        help='decoding mode')
    args = parser.parse_args()
    return args
if __name__ == '__main__':
args = get_args()
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
dataset = CustomAishellDataset(args.wav_scp, args.text)
# test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
# collate_fn=collate_wrapper)
test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
collate_fn=collate_wrapper)
speech2text = Speech2Text(
args.config,
args.model_path,
None,
args.lm_config,
args.lm_path,
device="cuda"
)
# 手动加载完整的ESPnetLanguageModel对象
# 因为Speech2Text中只存储了原始语言模型,我们需要完整的对象来使用batchify_nll方法
full_lm_model = None
if args.lm_config is not None and args.lm_path is not None:
from espnet2.tasks.lm import LMTask
full_lm_model, _ = LMTask.build_model_from_file(
args.lm_config, args.lm_path, "cuda"
)
full_lm_model.eval()
import onnxruntime as ort
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.enable_cpu_mem_arena = False
sess_options.enable_mem_pattern = False
providers = ['ROCMExecutionProvider']
encoder_path = "/home/sunzhq/workspace/yidong-infer/conformer/onnx_models_batch24/transformer_lm/full/default_encoder_fp16.onnx"
encoder_session = ort.InferenceSession(encoder_path,
providers=providers)
# encoder_session_io = encoder_session.io_binding()
output_names = ["encoder_out", "encode_out_lens"]
time_start = time.perf_counter()
audio_sample_len = 0
encoder_times = []
ctc_times = []
decoder_times = []
lm_times = []
beam_search_times = []
count_times = []
with torch.no_grad(), open(args.result_file, 'w') as fout:
for _, batch in enumerate(test_data_loader):
def process_batch_data(batch, speech2text):
"""Process batch data and prepare for inference"""
speech, speech_lens, labels, names = batch
audio_sample_len += np.sum(speech_lens) / 16000
batch = {"speech": speech, "speech_lengths": speech_lens}
if isinstance(batch["speech"], np.ndarray):
batch["speech"] = torch.tensor(batch["speech"])
if isinstance(batch["speech_lengths"], np.ndarray):
batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])
audio_sample_len = np.sum(speech_lens) / 16000
batch_data = {"speech": speech, "speech_lengths": speech_lens}
# encoder_out_lens = np.array([np.sum(np.any(np.array(batch["speech"]) != 0, axis=1)) for i in range(np.array(batch["speech"]).shape[0])])
# encoder_inputs = {
# 'feats': np.array(batch["speech"]).astype(np.float32)}
if isinstance(batch_data["speech"], np.ndarray):
batch_data["speech"] = torch.tensor(batch_data["speech"])
if isinstance(batch_data["speech_lengths"], np.ndarray):
batch_data["speech_lengths"] = torch.tensor(batch_data["speech_lengths"])
batch = to_device(batch, device='cuda')
feats, encoder_out_lens = speech2text.asr_model.encode(**batch)
batch_data = to_device(batch_data, device='cuda')
feats, encoder_out_lens = speech2text.asr_model.pre_data(**batch_data)
encoder_out_lens = torch.ceil(encoder_out_lens.float() / 4).long()
encoder_inputs = {'feats': feats.cpu().numpy().astype(np.float32)}
return encoder_inputs, encoder_out_lens, labels, names, audio_sample_len
ll_time = time.time()
# encoder_time = time.time()
def inference_step(encoder_inputs, encoder_out_lens, speech2text, full_lm_model, args, encoder_session):
"""Perform inference on prepared data"""
# Run encoder inference
encoder_outputs = encoder_session.run(None, encoder_inputs)
# encoder_out_1, encoder_out_lens_1 = encoder_session_io.get_outputs()
encoder_out_numpy = encoder_outputs[0]
# encoder_out_lens = np.array(encoder_session_io.copy_outputs_to_cpu()[1])
encoder_out = torch.from_numpy(encoder_out_numpy).float().cuda()
# encoder_out_lens = torch.from_numpy(encoder_out_lens_numpy).float().cuda()
# encoder_count = time.time() - encoder_time
# print("encode 耗时:", encoder_count)
# encoder_times.append(encoder_count)
# # ctc_log_probs: [N, T, C]
# ctc_time = time.time()
# # print("encoder_out:",encoder_out.size())
# # a. To device
# batch = to_device(batch, device='cuda')
# # b. Forward Encoder
# # enc: [N, T, C]
# # print(batch)
# encoder_time = time.time()
# encoder_out, encoder_out_lens = speech2text.asr_model.encode(**batch)
# encoder_count = time.time() - encoder_time
# print("encoder_out_lens:", encoder_out_lens, encoder_out_lens.size())
# print("encoder_out:", encoder_out.size())
# print("encode 耗时:", encoder_count)
# # **************************************************
# # encoder_out_lens: tensor([129, 105, 180, 171, 153, 199, 299, 211, 247, 222, 141, 277, 83, 197,
# # 179, 154, 148, 165, 178, 165, 179, 241, 288, 137], device='cuda:0') torch.Size([24])
# # encoder_out: torch.Size([24, 299, 256])
# encoder_times.append(encoder_count)
# #ctc_log_probs: [N, T, C]
# ctc_time = time.time()
ctc_log_probs = torch.nn.functional.log_softmax(
speech2text.asr_model.ctc.ctc_lo(encoder_out), dim=2
)
......@@ -320,9 +119,6 @@ if __name__ == '__main__':
beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
args.beam_size, dim=2)
# ctc_count = time.time() - ctc_time
# print("ctc 耗时:", ctc_count)
# ctc_times.append(ctc_count)
num_processes = min(multiprocessing.cpu_count(), args.batch_size)
if args.mode == 'ctc_greedy_search':
......@@ -334,20 +130,15 @@ if __name__ == '__main__':
hyps = map_batch(batch_sents, speech2text.asr_model.token_list,
num_processes, True, 0)
else:
# beam_search_time = time.time()
batch_log_probs_seq_list = beam_log_probs.tolist()
batch_log_probs_idx_list = beam_log_probs_idx.tolist()
batch_len_list = encoder_out_lens.tolist()
# batch_len_list = encoder_out_lens
batch_log_probs_seq = []
batch_log_probs_ids = []
batch_start = [] # only effective in streaming deployment
batch_root = TrieVector()
root_dict = {}
for i in range(len(batch_len_list)):
# print(batch_len_list)
# num_sent = batch_len_list[i]
num_sent = encoder_out.size()[1]
batch_log_probs_seq.append(
batch_log_probs_seq_list[i][0:num_sent])
......@@ -364,12 +155,6 @@ if __name__ == '__main__':
num_processes,
0, -2, 0.99999)
# beam_search_count = time.time() - beam_search_time
# print("beam_search 耗时:", beam_search_count)
# beam_search_times.append(beam_search_count)
# beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
# args.beam_size, dim=2)
if args.mode == 'ctc_prefix_beam_search':
hyps = []
for cand_hyps in score_hyps:
......@@ -420,8 +205,6 @@ if __name__ == '__main__':
hyps_lens = hyps_lens_sos.view(B2,)
hyps_pad_sos = hyps_pad_sos.view(B2, max_len + 1)
hyps_pad_eos = hyps_pad_eos.view(B2, max_len + 1)
#hyps_pad_sos = hyps_pad[:, :-1]
#hyps_pad_eos = hyps_pad[:, 1:]
decoder_out, _ = speech2text.asr_model.decoder(encoder_out,encoder_out_lens,hyps_pad_sos.cuda(), hyps_lens.cuda())
......@@ -511,9 +294,7 @@ if __name__ == '__main__':
k += args.beam_size
hyps = map_batch(best_sents, speech2text.asr_model.token_list, num_processes)
elif args.mode == 'lm_rescoring':
# lm_time = time.time()
ctc_score, all_hyps = [], []
max_len = 0
......@@ -566,40 +347,158 @@ if __name__ == '__main__':
k += args.beam_size
hyps = map_batch(best_sents, speech2text.asr_model.token_list, num_processes)
count_time = time.time() - ll_time
count_times.append(count_time)
# lm_count = time.time() - lm_time
# print("lm 耗时:", lm_count)
# lm_times.append(lm_count)
# print("*"*50)
else:
raise NotImplementedError
return hyps
def get_args():
    """Build and parse the command-line arguments for decoding."""
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--lm_config', required=True, help='config file')
    parser.add_argument('--gpu',
                        type=int,
                        default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--wav_scp', required=True, help='wav scp file')
    parser.add_argument('--text', required=True, help='ground truth text file')
    parser.add_argument('--model_path', required=True, help='torch pt model file')
    parser.add_argument('--lm_path', required=True, help='torch pt model file')
    parser.add_argument('--result_file', default='./predictions.txt', help='asr result file')
    parser.add_argument('--log_file', default='./rtf.txt', help='asr decoding log')
    parser.add_argument('--batch_size',
                        type=int,
                        default=24,
                        help='batch_size')
    parser.add_argument('--beam_size',
                        type=int,
                        default=10,
                        help='beam_size')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring', 'attention_lm_rescoring', 'lm_rescoring'],
                        default='attention_lm_rescoring',
                        help='decoding mode')
    args = parser.parse_args()
    return args
if __name__ == '__main__':
    args = get_args()
    # Pin this process to the requested GPU before any CUDA context exists.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    dataset = CustomAishellDataset(args.wav_scp, args.text)
    test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                  collate_fn=collate_wrapper)
    # ESPnet inference wrapper: ASR model plus (optional) LM for rescoring.
    speech2text = Speech2Text(
        args.config,
        args.model_path,
        None,
        args.lm_config,
        args.lm_path,
        device="cuda"
    )
    # Separately built full LM used by the *_lm_rescoring decoding modes.
    full_lm_model = None
    if args.lm_config is not None and args.lm_path is not None:
        from espnet2.tasks.lm import LMTask
        full_lm_model, _ = LMTask.build_model_from_file(
            args.lm_config, args.lm_path, "cuda"
        )
        full_lm_model.eval()
    import onnxruntime as ort
    # NOTE(review): ONNX encoder path and ROCm provider are hard-coded for
    # one specific machine; consider promoting them to CLI options.
    providers = ['ROCMExecutionProvider']
    encoder_path = "/home/sunzhq/workspace/yidong-infer/conformer/onnx_models_batch24_1/transformer_lm/full/default_encoder_fp16.onnx"
    encoder_session = ort.InferenceSession(encoder_path,
                                           providers=providers)
    output_names = ["encoder_out", "encoder_out_lens"]
    # Warmup: Run inference on first batch to initialize models and cache
    # (ONNX Runtime kernel selection, CUDA context, allocator pools), so the
    # timed loop below is not skewed by one-time startup cost.
    print("Starting warmup...")
    warmup_start = time.time()
    with torch.no_grad():
        for i, batch in enumerate(test_data_loader):
            if i >= 1:  # Warmup with first batch only
                break
            # Process batch data (feature extraction / device transfer)
            encoder_inputs, encoder_out_lens, labels, names, audio_sample_len = process_batch_data(batch, speech2text)
            # Run inference; result is discarded — this pass is timing-only
            hyps = inference_step(encoder_inputs, encoder_out_lens, speech2text, full_lm_model, args, encoder_session)
    print(f"Warmup completed in {time.time() - warmup_start:.2f} seconds")
# Main inference loop
time_start = time.perf_counter()
audio_sample_len_total = 0
infer_times = []
total_infer_times = []
total_start = time.time()
# Open files for saving results in the required format
with torch.no_grad(), open(args.result_file, 'w') as fout, open('ref.trn', 'w') as ref_file, open('hyp.trn', 'w') as hyp_file:
for batch_idx, batch in enumerate(test_data_loader):
# Process batch data (separated from inference)
encoder_inputs, encoder_out_lens, labels, names, audio_sample_len = process_batch_data(batch, speech2text)
audio_sample_len_total += audio_sample_len
# Measure inference time
infer_start = time.time()
# Run inference
hyps = inference_step(encoder_inputs, encoder_out_lens, speech2text, full_lm_model, args, encoder_session)
infer_time = time.time() - infer_start
infer_times.append(infer_time)
# Save results
for i, key in enumerate(names):
content = hyps[i]
# print('{} {}'.format(key, content))
fout.write('{} {}\n'.format(key, content))
# Save to ref.trn and hyp.trn in the required format
# Convert continuous Chinese text to space-separated characters
ref_text = ' '.join(labels[i])
hyp_text = ' '.join(content)
ref_file.write('{} \t ({})\n'.format(ref_text, key))
hyp_file.write('{} \t ({})\n'.format(hyp_text, key))
# print(f"Batch {batch_idx + 1} processed in {infer_time:.3f} seconds")
total_infer_times.append(time.time() - total_start)
total_start = time.time()
# Calculate and print statistics
time_end = time.perf_counter() - time_start
# encoder_times = encoder_times[5:]
# ctc_times = ctc_times[5:]
# beam_search_times = beam_search_times[5:]
# lm_times = lm_times[5:]
# mean_encoder = np.mean(encoder_times)
# mean_ctc = np.mean(ctc_times)
# mean_beam_search = np.mean(beam_search_times)
# mean_lm = np.mean(lm_times)
# print("平均 encode time:", mean_encoder)
# print("平均 ctc time:", mean_ctc)
# print("平均 beam_search time:", mean_beam_search)
# print("平均 lm time:", mean_lm)
count_times = count_times[5:]
mean_count_time = np.mean(count_times)
print("平均 mean_count_time:", mean_count_time, " fps: ", 24/mean_count_time)
# if str(args.gpu) == '0':
# Exclude first few batches for warmup
# if len(infer_times) > 5:
# stable_infer_times = infer_times[5:]
# mean_infer_time = np.mean(stable_infer_times)
# print(f"Average inference time (excluding warmup): {mean_infer_time:.3f} seconds")
# print(f"FPS: {args.batch_size/mean_infer_time:.1f}")
print(f"Total audio processed: {audio_sample_len_total:.1f} seconds")
print(f"Total time: {time_end:.1f} seconds")
print(f"Real-time factor (RTF): {time_end/audio_sample_len_total:.3f}")
print("***************************")
infer_time = sum(infer_times)
avg_infer_fps = 24 * len(infer_times) / sum(infer_times)
print(f"total_infer_time: {infer_time}s")
print(f'avg_infer_fps: {avg_infer_fps}samples/s')
load_data_infer_time = sum(total_infer_times)
load_data_avg_infer_fps = len(total_infer_times) * 24 / sum(total_infer_times)
print(f'load_data_total_infer_time: {load_data_infer_time}s')
print(f'load_data_avg_total_Infer_fps: {load_data_avg_infer_fps} samples/s')
print("******************************")
with open(args.log_file, 'w') as log:
log.write(f"Decoding audio {audio_sample_len} secs, cost {time_end} secs, RTF: {time_end/audio_sample_len}, process {audio_sample_len/time_end} secs audio per second, decoding args: {args}")
log.write(f"Decoding audio {audio_sample_len_total} secs, cost {time_end} secs, RTF: {time_end/audio_sample_len_total}, process {audio_sample_len_total/time_end} secs audio per second, decoding args: {args}")
\ No newline at end of file
#!/usr/bin/bash
# Launch one decoding job per GPU (4 GPUs), each pinned to the matching
# NUMA node via numactl and restricted to a single device through
# HIP_VISIBLE_DEVICES (so every job addresses its GPU as device 0).
#
# Fix: gpu_id was previously never assigned, so all four jobs wrote to the
# same predictions_${mode}_.txt / log_${mode}_.txt and clobbered each other;
# it is now set explicitly per launch.

asr_train_config="/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/config.yaml"
asr_model_file="/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/valid.acc.ave_10best.pth"
lm_train_config="/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/exp/lm_train_lm_transformer_char_batch_bins2000000/config.yaml"
lm_path="/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/exp/lm_train_lm_transformer_char_batch_bins2000000/valid.loss.ave_10best.pth"
manifest="/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/test"

mkdir -p logs

# mode='attention_rescoring'
mode='lm_rescoring'

for gpu_id in 0 1 2 3; do
    # The exported value is captured by the backgrounded process at fork time.
    export HIP_VISIBLE_DEVICES=$gpu_id
    # --gpu 0 is correct: each job only sees its single visible device.
    nohup numactl -N $gpu_id -m $gpu_id python3 infer.py \
        --config $asr_train_config \
        --model_path $asr_model_file \
        --lm_config $lm_train_config \
        --lm_path $lm_path \
        --gpu 0 \
        --wav_scp $manifest/wav.scp --text $manifest/text \
        --result_file ./logs/predictions_${mode}_${gpu_id}.txt \
        --log_file ./logs/log_${mode}_${gpu_id}.txt \
        --batch_size 24 --beam_size 10 \
        --mode $mode 2>&1 | tee result_${gpu_id}.log &
done
......@@ -20,97 +20,6 @@ except ImportError:
'https://github.com/Slyne/ctc_decoder.git')
sys.exit(1)
def lm_batchify_nll(lm_scorer, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute negative log likelihood (nll) with a transformer LM, in chunks.

    To keep memory bounded, the input is split into slices of at most
    ``batch_size`` sequences; each slice is scored independently and the
    results are concatenated.

    Args:
        lm_scorer: Language model scorer object
        text: (Batch, Length)
        text_lengths: (Batch,)
        batch_size: number of sequences scored per slice; lower it to avoid
            OOM, raise it for throughput.

    Returns:
        Tuple of (nll, x_lengths), each of size (Batch,).
    """
    total_num = text.size(0)
    if total_num <= batch_size:
        # Small enough to score in a single call.
        return _compute_nll_with_lm_scorer(lm_scorer, text, text_lengths)

    nll_chunks = []
    length_chunks = []
    # All slices are padded to the same width so they can be concatenated.
    max_length = text_lengths.max()
    for start in range(0, total_num, batch_size):
        stop = min(start + batch_size, total_num)
        chunk_nll, chunk_lengths = _compute_nll_with_lm_scorer(
            lm_scorer,
            text[start:stop, :],
            text_lengths[start:stop],
            max_length=max_length,
        )
        nll_chunks.append(chunk_nll)
        length_chunks.append(chunk_lengths)

    nll = torch.cat(nll_chunks)
    x_lengths = torch.cat(length_chunks)
    assert nll.size(0) == total_num
    assert x_lengths.size(0) == total_num
    return nll, x_lengths
def _compute_nll_with_lm_scorer(lm_scorer, text: torch.Tensor, text_lengths: torch.Tensor, max_length: int = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood using lm_scorer's score method
This function simulates the nll method using the available score method
from the lm_scorer object.
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, : text_lengths.max()]
else:
text = text[:, :max_length]
# Initialize nll for each sequence
nll = torch.zeros(batch_size, device=text.device)
# Process each sequence individually
for batch_idx in range(batch_size):
seq_text = text[batch_idx]
seq_length = text_lengths[batch_idx]
# Truncate to actual sequence length
seq_text = seq_text[:seq_length]
# Initialize state for this sequence
state = None
# Process each token position sequentially
for pos in range(len(seq_text) - 1):
# Get current token
current_token = seq_text[pos].unsqueeze(0) # shape: (1,)
# Score the current token
logp, state = lm_scorer.score(current_token, state, None)
# Get the ground truth next token
next_token = seq_text[pos + 1]
# Get the negative log likelihood for the correct next token
token_nll = -logp[next_token]
nll[batch_idx] += token_nll
# x_lengths is text_lengths - 1 (since we score transitions between tokens)
x_lengths = text_lengths - 1
x_lengths = torch.clamp(x_lengths, min=0) # Ensure non-negative
return nll, x_lengths
class CustomAishellDataset(Dataset):
def __init__(self, wav_scp_file, text_file):
......@@ -210,9 +119,10 @@ if __name__ == '__main__':
args = get_args()
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
dataset = CustomAishellDataset(args.wav_scp, args.text)
# test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
# collate_fn=collate_wrapper)
test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
collate_fn=collate_wrapper)
speech2text = Speech2Text(
args.config,
args.model_path,
......@@ -231,6 +141,20 @@ if __name__ == '__main__':
args.lm_config, args.lm_path, "cuda"
)
full_lm_model.eval()
import onnxruntime as ort
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.enable_cpu_mem_arena = False
sess_options.enable_mem_pattern = False
providers = ['ROCMExecutionProvider']
encoder_path = "/home/sunzhq/workspace/yidong-infer/conformer/onnx_models_batch24_1/transformer_lm/full/default_encoder_fp16.onnx"
encoder_session = ort.InferenceSession(encoder_path,
providers=providers)
encoder_session_io = encoder_session.io_binding()
output_names = ["encoder_out", "encoder_out_lens"]
time_start = time.perf_counter()
audio_sample_len = 0
......@@ -239,6 +163,7 @@ if __name__ == '__main__':
decoder_times = []
lm_times = []
beam_search_times = []
count_times = []
with torch.no_grad(), open(args.result_file, 'w') as fout:
for _, batch in enumerate(test_data_loader):
speech, speech_lens, labels, names = batch
......@@ -250,28 +175,67 @@ if __name__ == '__main__':
if isinstance(batch["speech_lengths"], np.ndarray):
batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])
# a. To device
batch = to_device(batch, device='cuda')
feats, encoder_out_lens = speech2text.asr_model.pre_data(**batch)
encoder_out_lens = torch.ceil(encoder_out_lens.float() / 4).long()
encoder_inputs = {'feats': feats.cpu().numpy().astype(np.float32)}
inputData = {}
for key in encoder_inputs.keys():
inputData[key] = ort.OrtValue.ortvalue_from_numpy(encoder_inputs[key], device_type='cuda')
encoder_session_io.bind_input(
name = key,
device_type = inputData[key].device_name(),
device_id = 0,
element_type = np.float32,
shape = inputData[key].shape(),
buffer_ptr = inputData[key].data_ptr())
# for o_n in output_names:
# encoder_session_io.bind_output("encoder_out")
encoder_session_io.bind_output(name="encoder_out", device_type="cuda", device_id=0)
ll_time = time.time()
encoder_session.run_with_iobinding(encoder_session_io)
outputs = encoder_session_io.get_outputs()[0]
ptr = outputs.data_ptr() # GPU 内存地址
shape = outputs.shape()
dtype = torch.float32
total_elements = np.prod(shape)
element_size = 4
total_bytes = total_elements * element_size
methods = [m for m in dir(outputs) if not m.startswith('_')]
print(methods)
# encoder_out = torch.as_tensor(ptr, dtype=dtype, device='cuda').reshape(shape)
print(outputs)
print(outputs.device_name())
print("Has to_dlpack:", hasattr(outputs, 'to_dlpack'))
print("Shape:", outputs.shape())
# encoder_out = torch.from_dlpack(outputs[0].to_dlpack())
# result = encoder_session_io.copy_outputs_to_cpu()
# encoder_out = torch.tensor(result[0]).float().cuda()
print(encoder_out)
# # encoder_time = time.time()
# encoder_outputs = encoder_session.run(None, encoder_inputs)
# # encoder_out_1, encoder_out_lens_1 = encoder_session_io.get_outputs()
# encoder_out_numpy = encoder_outputs[0]
# # encoder_out_lens = np.array(encoder_session_io.copy_outputs_to_cpu()[1])
# encoder_out = torch.from_numpy(encoder_out_numpy).float().cuda()
# print(encoder_out.size())
# b. Forward Encoder
# enc: [N, T, C]
encoder_time = time.time()
encoder_out, encoder_out_lens = speech2text.asr_model.encode(**batch)
encoder_count = time.time() - encoder_time
print("encode 耗时:", encoder_count)
encoder_times.append(encoder_count)
# ctc_log_probs: [N, T, C]
ctc_time = time.time()
ctc_log_probs = torch.nn.functional.log_softmax(
speech2text.asr_model.ctc.ctc_lo(encoder_out), dim=2
)
ctc_count = time.time() - ctc_time
print("ctc 耗时:", ctc_count)
ctc_times.append(ctc_count)
beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
args.beam_size, dim=2)
# ctc_count = time.time() - ctc_time
# print("ctc 耗时:", ctc_count)
# ctc_times.append(ctc_count)
num_processes = min(multiprocessing.cpu_count(), args.batch_size)
if args.mode == 'ctc_greedy_search':
......@@ -283,18 +247,21 @@ if __name__ == '__main__':
hyps = map_batch(batch_sents, speech2text.asr_model.token_list,
num_processes, True, 0)
else:
beam_search_time = time.time()
# beam_search_time = time.time()
batch_log_probs_seq_list = beam_log_probs.tolist()
batch_log_probs_idx_list = beam_log_probs_idx.tolist()
batch_len_list = encoder_out_lens.tolist()
# batch_len_list = encoder_out_lens
batch_log_probs_seq = []
batch_log_probs_ids = []
batch_start = [] # only effective in streaming deployment
batch_root = TrieVector()
root_dict = {}
for i in range(len(batch_len_list)):
num_sent = batch_len_list[i]
# print(batch_len_list)
# num_sent = batch_len_list[i]
num_sent = encoder_out.size()[1]
batch_log_probs_seq.append(
batch_log_probs_seq_list[i][0:num_sent])
batch_log_probs_ids.append(
......@@ -310,9 +277,9 @@ if __name__ == '__main__':
num_processes,
0, -2, 0.99999)
beam_search_count = time.time() - beam_search_time
print("beam_search 耗时:", beam_search_count)
beam_search_times.append(beam_search_count)
# beam_search_count = time.time() - beam_search_time
# print("beam_search 耗时:", beam_search_count)
# beam_search_times.append(beam_search_count)
# beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
# args.beam_size, dim=2)
......@@ -459,7 +426,7 @@ if __name__ == '__main__':
elif args.mode == 'lm_rescoring':
lm_time = time.time()
# lm_time = time.time()
ctc_score, all_hyps = [], []
max_len = 0
......@@ -512,11 +479,13 @@ if __name__ == '__main__':
k += args.beam_size
hyps = map_batch(best_sents, speech2text.asr_model.token_list, num_processes)
count_time = time.time() - ll_time
count_times.append(count_time)
lm_count = time.time() - lm_time
print("lm 耗时:", lm_count)
lm_times.append(lm_count)
print("*"*50)
# lm_count = time.time() - lm_time
# print("lm 耗时:", lm_count)
# lm_times.append(lm_count)
# print("*"*50)
else:
raise NotImplementedError
......@@ -528,20 +497,22 @@ if __name__ == '__main__':
fout.write('{} {}\n'.format(key, content))
time_end = time.perf_counter() - time_start
encoder_times = encoder_times[5:]
ctc_times = ctc_times[5:]
beam_search_times = beam_search_times[5:]
lm_times = lm_times[5:]
mean_encoder = np.mean(encoder_times)
mean_ctc = np.mean(ctc_times)
mean_beam_search = np.mean(beam_search_times)
mean_lm = np.mean(lm_times)
print("平均 encode time:", mean_encoder)
print("平均 ctc time:", mean_ctc)
print("平均 beam_search time:", mean_beam_search)
print("平均 lm time:", mean_lm)
# encoder_times = encoder_times[5:]
# ctc_times = ctc_times[5:]
# beam_search_times = beam_search_times[5:]
# lm_times = lm_times[5:]
# mean_encoder = np.mean(encoder_times)
# mean_ctc = np.mean(ctc_times)
# mean_beam_search = np.mean(beam_search_times)
# mean_lm = np.mean(lm_times)
# print("平均 encode time:", mean_encoder)
# print("平均 ctc time:", mean_ctc)
# print("平均 beam_search time:", mean_beam_search)
# print("平均 lm time:", mean_lm)
count_times = count_times[5:]
mean_count_time = np.mean(count_times)
print("平均 mean_count_time:", mean_count_time, " fps: ", 24/mean_count_time)
# if str(args.gpu) == '0':
with open(args.log_file, 'w') as log:
log.write(f"Decoding audio {audio_sample_len} secs, cost {time_end} secs, RTF: {time_end/audio_sample_len}, process {audio_sample_len/time_end} secs audio per second, decoding args: {args}")
This source diff could not be displayed because it is too large. You can view the blob instead.
espnet: 0.9.0
files:
asr_model_file: exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/valid.acc.ave_10best.pth
lm_file: exp/lm_train_lm_transformer_char_batch_bins2000000/valid.loss.ave_10best.pth
python: "3.7.3 (default, Mar 27 2019, 22:11:17) \n[GCC 7.3.0]"
timestamp: 1603088092.704853
torch: 1.6.0
yaml_files:
asr_train_config: exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/config.yaml
lm_train_config: exp/lm_train_lm_transformer_char_batch_bins2000000/config.yaml
# Score the decoding run: compare reference vs hypothesis .trn files.
# NOTE(review): infer.py writes ref.trn/hyp.trn to the current directory,
# not ./logs/ — confirm the paths match before running.
python3 conformer-compute-wer.py ./logs/ref.trn ./logs/hyp.trn
\ No newline at end of file
import logging
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from typeguard import check_argument_types
from espnet2.asr.ctc import CTC
from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.encoder.abs_encoder import AbsEncoder
from espnet2.asr.frontend.abs_frontend import AbsFrontend
from espnet2.asr.postencoder.abs_postencoder import AbsPostEncoder
from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder
from espnet2.asr.specaug.abs_specaug import AbsSpecAug
from espnet2.asr.transducer.error_calculator import ErrorCalculatorTransducer
from espnet2.asr_transducer.utils import get_transducer_task_io
from espnet2.layers.abs_normalize import AbsNormalize
from espnet2.torch_utils.device_funcs import force_gatherable
from espnet2.train.abs_espnet_model import AbsESPnetModel
from espnet.nets.e2e_asr_common import ErrorCalculator
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( # noqa: H301
LabelSmoothingLoss,
)
# torch.cuda.amp.autocast exists only from torch 1.6; provide a no-op
# fallback so call sites can use ``with autocast(False):`` unconditionally.
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
class ESPnetASRModel(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: Optional[AbsDecoder],
        ctc: CTC,
        joint_network: Optional[torch.nn.Module],
        aux_ctc: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        transducer_multi_blank_durations: List = [],
        transducer_multi_blank_sigma: float = 0.05,
        # In a regular ESPnet recipe, <sos> and <eos> are both "<sos/eos>"
        # Pretrained HF Tokenizer needs custom sym_sos and sym_eos
        sym_sos: str = "<sos/eos>",
        sym_eos: str = "<sos/eos>",
        extract_feats_in_collect_stats: bool = True,
        lang_token_id: int = -1,
    ):
        """Assemble the hybrid CTC/attention (or transducer) ASR model.

        Resolves special token ids from ``token_list``, wires together the
        frontend/augmentation/normalization/encoder/decoder components, and
        selects the loss branch: transducer when ``joint_network`` is given,
        otherwise CTC + attention weighted by ``ctc_weight``.
        """
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # NOTE (Shih-Lun): else case is for OpenAI Whisper ASR model,
        # which doesn't use <blank> token
        if sym_blank in token_list:
            self.blank_id = token_list.index(sym_blank)
        else:
            self.blank_id = 0
        # <sos>/<eos> fall back to the last vocabulary entry when absent.
        if sym_sos in token_list:
            self.sos = token_list.index(sym_sos)
        else:
            self.sos = vocab_size - 1
        if sym_eos in token_list:
            self.eos = token_list.index(sym_eos)
        else:
            self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.aux_ctc = aux_ctc
        # Copy so later mutations of the caller's list cannot leak in.
        self.token_list = token_list.copy()

        #print("frontend:", frontend)
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder

        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            # Projects CTC posteriors back into the encoder for conditioning.
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        # A joint network implies the transducer decoding branch.
        self.use_transducer_decoder = joint_network is not None

        self.error_calculator = None

        if self.use_transducer_decoder:
            self.decoder = decoder
            self.joint_network = joint_network

            if not transducer_multi_blank_durations:
                from warprnnt_pytorch import RNNTLoss

                self.criterion_transducer = RNNTLoss(
                    blank=self.blank_id,
                    fastemit_lambda=0.0,
                )
            else:
                from espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank import (
                    MultiblankRNNTLossNumba,
                )

                self.criterion_transducer = MultiblankRNNTLossNumba(
                    blank=self.blank_id,
                    big_blank_durations=transducer_multi_blank_durations,
                    sigma=transducer_multi_blank_sigma,
                    reduction="mean",
                    fastemit_lambda=0.0,
                )
                self.transducer_multi_blank_durations = transducer_multi_blank_durations

            if report_cer or report_wer:
                self.error_calculator_trans = ErrorCalculatorTransducer(
                    decoder,
                    joint_network,
                    token_list,
                    sym_space,
                    sym_blank,
                    report_cer=report_cer,
                    report_wer=report_wer,
                )
            else:
                self.error_calculator_trans = None

                if self.ctc_weight != 0:
                    self.error_calculator = ErrorCalculator(
                        token_list, sym_space, sym_blank, report_cer, report_wer
                    )
        else:
            # we set self.decoder = None in the CTC mode since
            # self.decoder parameters were never used and PyTorch complained
            # and threw an Exception in the multi-GPU experiment.
            # thanks Jeff Farris for pointing out the issue.
            if ctc_weight < 1.0:
                assert (
                    decoder is not None
                ), "decoder should not be None when attention is used"
            else:
                decoder = None
                logging.warning("Set decoder to none as ctc_weight==1.0")

            self.decoder = decoder

            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

            if report_cer or report_wer:
                self.error_calculator = ErrorCalculator(
                    token_list, sym_space, sym_blank, report_cer, report_wer
                )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

        self.is_encoder_whisper = "Whisper" in type(self.encoder).__name__

        if self.is_encoder_whisper:
            assert (
                self.frontend is None
            ), "frontend should be None when using full Whisper model"

        if lang_token_id != -1:
            self.lang_token_id = torch.tensor([[lang_token_id]])
        else:
            self.lang_token_id = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            kwargs: "utt_id" is among the input.

        Returns:
            (loss, stats, weight) made gatherable for DataParallel.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # Map legacy -1 padding onto the configured ignore index.
        text[text == -1] = self.ignore_id

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        # NOTE(review): in this file encode() was refactored to expect
        # pre-extracted features (see pre_data), but raw speech is still
        # passed here — correct only if the frontend is a no-op; confirm
        # before using this path for training.
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out

                # use auxillary ctc data if specified
                loss_ic = None
                if self.aux_ctc is not None:
                    idx_key = str(layer_idx)
                    if idx_key in self.aux_ctc:
                        aux_data_key = self.aux_ctc[idx_key]
                        aux_data_tensor = kwargs.get(aux_data_key, None)
                        aux_data_lengths = kwargs.get(aux_data_key + "_lengths", None)

                        if aux_data_tensor is not None and aux_data_lengths is not None:
                            loss_ic, cer_ic = self._calc_ctc_loss(
                                intermediate_out,
                                encoder_out_lens,
                                aux_data_tensor,
                                aux_data_lengths,
                            )
                        else:
                            raise Exception(
                                "Aux. CTC tasks were specified but no data was found"
                            )
                # Fall back to the main transcript when no aux target applies.
                if loss_ic is None:
                    loss_ic, cer_ic = self._calc_ctc_loss(
                        intermediate_out, encoder_out_lens, text, text_lengths
                    )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermedaite CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            # Average over the intermediate layers.
            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        if self.use_transducer_decoder:
            # 2a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer

        else:
            # 2b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 3. CTC-Att loss definition
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
**kwargs,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def pre_data(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
#print("self.normalize:",self.normalize)
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
#if self.preencoder is not None:
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
#if self.encoder.interctc_use_conditioning:
# encoder_out, encoder_out_lens, _ = self.encoder(
# feats, feats_lengths, ctc=self.ctc
# )
#else:
# encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
#intermediate_outs = None
#if isinstance(encoder_out, tuple):
# intermediate_outs = encoder_out[1]
# encoder_out = encoder_out[0]
# Post-encoder, e.g. NLU
#if self.postencoder is not None:
# encoder_out, encoder_out_lens = self.postencoder(
# encoder_out, encoder_out_lens
# )
#assert encoder_out.size(0) == speech.size(0), (
# encoder_out.size(),
# speech.size(0),
#)
#if (
# getattr(self.encoder, "selfattention_layer_type", None) != "lf_selfattn"
# and not self.is_encoder_whisper
#):
# assert encoder_out.size(-2) <= encoder_out_lens.max(), (
# encoder_out.size(),
# encoder_out_lens.max(),
# )
#if intermediate_outs is not None:
# return (encoder_out, intermediate_outs), encoder_out_lens
return feats, feats_lengths
def encode(
self, feats: torch.Tensor, feats_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
# with autocast(False):
# # 1. Extract feats
# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# # 2. Data augmentation
# if self.specaug is not None and self.training:
# feats, feats_lengths = self.specaug(feats, feats_lengths)
# # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
# #print("self.normalize:",self.normalize)
# if self.normalize is not None:
# feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
# if self.preencoder is not None:
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
feats, feats_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
encoder_out = encoder_out[0]
# Post-encoder, e.g. NLU
if self.postencoder is not None:
encoder_out, encoder_out_lens = self.postencoder(
encoder_out, encoder_out_lens
)
assert encoder_out.size(0) == feats.size(0), (
encoder_out.size(),
feats.size(0),
)
if (
getattr(self.encoder, "selfattention_layer_type", None) != "lf_selfattn"
and not self.is_encoder_whisper
):
assert encoder_out.size(-2) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
if intermediate_outs is not None:
return (encoder_out, intermediate_outs), encoder_out_lens
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
return feats, feats_lengths
def nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
) # [batch, seqlen, dim]
batch_size = decoder_out.size(0)
decoder_num_class = decoder_out.size(2)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
decoder_out.view(-1, decoder_num_class),
ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction="none",
)
nll = nll.view(batch_size, -1)
nll = nll.sum(dim=1)
assert nll.size(0) == batch_size
return nll
def batchify_nll(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
batch_size: int, samples each batch contain when computing nll,
you may change this to avoid OOM or increase
GPU memory usage
"""
total_num = encoder_out.size(0)
if total_num <= batch_size:
nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
else:
nll = []
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
batch_ys_pad = ys_pad[start_idx:end_idx, :]
batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
batch_nll = self.nll(
batch_encoder_out,
batch_encoder_out_lens,
batch_ys_pad,
batch_ys_pad_lens,
)
nll.append(batch_nll)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nll)
assert nll.size(0) == total_num
return nll
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
ys_pad = torch.cat(
[
self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
ys_pad,
],
dim=1,
)
ys_pad_lens += 1
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def _calc_transducer_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
labels: torch.Tensor,
):
"""Compute Transducer loss.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
encoder_out_lens: Encoder output sequences lengths. (B,)
labels: Label ID sequences. (B, L)
Return:
loss_transducer: Transducer loss value.
cer_transducer: Character error rate for Transducer.
wer_transducer: Word Error Rate for Transducer.
"""
decoder_in, target, t_len, u_len = get_transducer_task_io(
labels,
encoder_out_lens,
ignore_id=self.ignore_id,
blank_id=self.blank_id,
)
self.decoder.set_device(encoder_out.device)
decoder_out = self.decoder(decoder_in)
joint_out = self.joint_network(
encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
)
loss_transducer = self.criterion_transducer(
joint_out,
target,
t_len,
u_len,
)
cer_transducer, wer_transducer = None, None
if not self.training and self.error_calculator_trans is not None:
cer_transducer, wer_transducer = self.error_calculator_trans(
encoder_out, target
)
return loss_transducer, cer_transducer, wer_transducer
def _calc_batch_ctc_loss(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
):
if self.ctc is None:
return
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# Calc CTC loss
do_reduce = self.ctc.reduce
self.ctc.reduce = False
loss_ctc = self.ctc(encoder_out, encoder_out_lens, text, text_lengths)
self.ctc.reduce = do_reduce
return loss_ctc
espnet: 0.9.0
files:
asr_model_file: exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/valid.acc.ave_10best.pth
lm_file: exp/lm_train_lm_transformer_char_batch_bins2000000/valid.loss.ave_10best.pth
python: "3.7.3 (default, Mar 27 2019, 22:11:17) \n[GCC 7.3.0]"
timestamp: 1603088092.704853
torch: 1.6.0
yaml_files:
asr_train_config: exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/config.yaml
lm_train_config: exp/lm_train_lm_transformer_char_batch_bins2000000/config.yaml
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment