Commit 60a2c57a authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

update conformer

parent 4a699441
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`."""
import json
import logging
import torch
from packaging.version import parse as V
from espnet.asr.asr_utils import add_results_to_json, get_model_conf, torch_load
from espnet.asr.pytorch_backend.asr import load_trained_model
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import BeamSearch
from espnet.nets.lm_interface import dynamic_import_lm
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.io_utils import LoadInputsAndTargets
def recog_v2(args):
    """Decode with custom models that implement ScorerInterface.

    Notes:
        The previous backend espnet.asr.pytorch_backend.asr.recog
        only supports E2E and RNNLM

    Args:
        args (namespace): The program arguments.
            See py:func:`espnet.bin.asr_recog.get_parser` for details

    Raises:
        NotImplementedError: for multi-utt batch decoding, streaming mode,
            word LM, or multi-GPU decoding, none of which are supported here.
        ValueError: for unsupported quantization configurations.
    """
    logging.warning("experimental API for custom LMs is selected by --api v2")
    if args.batchsize > 1:
        raise NotImplementedError("multi-utt batch decoding is not implemented")
    if args.streaming_mode is not None:
        raise NotImplementedError("streaming mode is not implemented")
    if args.word_rnnlm:
        raise NotImplementedError("word LM is not implemented")

    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)

    # Set of torch.nn module classes to quantize dynamically (Linear by default).
    if args.quantize_config is not None:
        q_config = {getattr(torch.nn, q) for q in args.quantize_config}
    else:
        q_config = {torch.nn.Linear}

    if args.quantize_asr_model:
        logging.info("Use quantized asr model for decoding")

        # See https://github.com/espnet/espnet/pull/3616 for more information.
        if (
            V(torch.__version__) < V("1.4.0")
            and "lstm" in train_args.etype
            and torch.nn.LSTM in q_config
        ):
            raise ValueError(
                "Quantized LSTM in ESPnet is only supported with torch 1.4+."
            )

        if args.quantize_dtype == "float16" and V(torch.__version__) < V("1.5.0"):
            # NOTE: the previous message claimed to be "switching to qint8"
            # while the code actually raises; the text now matches behavior.
            raise ValueError(
                "float16 dtype for dynamic quantization is not supported with "
                "torch version < 1.5.0. Use qint8 dtype instead."
            )

        dtype = getattr(torch, args.quantize_dtype)
        model = torch.quantization.quantize_dynamic(model, q_config, dtype=dtype)

    model.eval()

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    if args.rnnlm:
        lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        # NOTE: for a compatibility with less than 0.5.0 version models
        lm_model_module = getattr(lm_args, "model_module", "default")
        lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
        lm = lm_class(len(train_args.char_list), lm_args)
        torch_load(args.rnnlm, lm)
        if args.quantize_lm_model:
            logging.info("Use quantized lm model")
            dtype = getattr(torch, args.quantize_dtype)
            lm = torch.quantization.quantize_dynamic(lm, q_config, dtype=dtype)
        lm.eval()
    else:
        lm = None

    if args.ngram_model:
        from espnet.nets.scorers.ngram import NgramFullScorer, NgramPartScorer

        if args.ngram_scorer == "full":
            ngram = NgramFullScorer(args.ngram_model, train_args.char_list)
        else:
            ngram = NgramPartScorer(args.ngram_model, train_args.char_list)
    else:
        ngram = None

    scorers = model.scorers()
    scorers["lm"] = lm
    scorers["ngram"] = ngram
    scorers["length_bonus"] = LengthBonus(len(train_args.char_list))
    # Interpolation weights for the joint decoder/CTC/LM/ngram/penalty score.
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        lm=args.lm_weight,
        ngram=args.ngram_weight,
        length_bonus=args.penalty,
    )
    beam_search = BeamSearch(
        beam_size=args.beam_size,
        vocab_size=len(train_args.char_list),
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        token_list=train_args.char_list,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
    )
    # TODO(karita): make all scorers batchfied
    if args.batchsize == 1:
        non_batch = [
            k
            for k, v in beam_search.full_scorers.items()
            if not isinstance(v, BatchScorerInterface)
        ]
        if len(non_batch) == 0:
            beam_search.__class__ = BatchBeamSearch
            logging.info("BatchBeamSearch implementation is selected.")
        else:
            logging.warning(
                f"As non-batch scorers {non_batch} are found, "
                f"fall back to non-batch implementation."
            )

    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    if args.ngpu == 1:
        device = "cuda"
    else:
        device = "cpu"
    dtype = getattr(torch, args.dtype)
    logging.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype).eval()
    beam_search.to(device=device, dtype=dtype).eval()

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]
    new_js = {}
    with torch.no_grad():
        for idx, name in enumerate(js.keys(), 1):
            # Pass the name as a lazy %-argument instead of concatenating it
            # into the format string, which would break on names containing '%'.
            logging.info("(%d/%d) decoding %s", idx, len(js.keys()), name)
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            enc = model.encode(torch.as_tensor(feat).to(device=device, dtype=dtype))
            nbest_hyps = beam_search(
                x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
            )
            nbest_hyps = [
                h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
            ]
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
#!/usr/bin/env python3
import logging
import os
import random
import sys
from distutils.util import strtobool
import configargparse
import numpy as np
from espnet.asr.pytorch_backend.asr import enhance
# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Build the argument parser for the speech-enhancement front end.

    Returns:
        configargparse.ArgumentParser: parser accepting general, task,
        model, output, and iSTFT options.
    """
    parser = configargparse.ArgumentParser(
        description="Enhance noisy speech for speech recognition",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="second config file path that overwrites the settings in `--config`.",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="third config file path that overwrites the settings "
        "in `--config` and `--config2`.",
    )
    parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
    parser.add_argument(
        "--backend",
        default="chainer",
        type=str,
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        default=1,
        type=int,
        help="Batch size for beam search (0: means no batch processing)",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    # task related
    parser.add_argument(
        "--recog-json", type=str, help="Filename of recognition data (json)"
    )
    # model (parameter) related
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )
    # Outputs configuration
    parser.add_argument(
        "--enh-wspecifier",
        type=str,
        default=None,
        help="Specify the output way for enhanced speech."
        "e.g. ark,scp:outdir,wav.scp",
    )
    parser.add_argument(
        "--enh-filetype",
        type=str,
        default="sound",
        choices=["mat", "hdf5", "sound.hdf5", "sound"],
        help="Specify the file format for enhanced speech. "
        '"mat" is the matrix format in kaldi',
    )
    parser.add_argument("--fs", type=int, default=16000, help="The sample frequency")
    parser.add_argument(
        "--keep-length",
        type=strtobool,
        default=True,
        help="Adjust the output length to match " "with the input for enhanced speech",
    )
    parser.add_argument(
        "--image-dir", type=str, default=None, help="The directory saving the images."
    )
    parser.add_argument(
        "--num-images",
        type=int,
        default=20,
        help="The number of images files to be saved. "
        "If negative, all samples are to be saved.",
    )
    # IStft
    parser.add_argument(
        "--apply-istft",
        type=strtobool,
        default=True,
        help="Apply istft to the output from the network",
    )
    parser.add_argument(
        "--istft-win-length",
        type=int,
        default=512,
        help="The window length for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    # BUGFIX: this option was declared with type=str (while its default is the
    # int 256), so CLI-supplied values reached the istft as strings, and its
    # help text was copy-pasted from --istft-window.
    parser.add_argument(
        "--istft-n-shift",
        type=int,
        default=256,
        help="The shift length for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    parser.add_argument(
        "--istft-window",
        type=str,
        default="hann",
        help="The window type for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    return parser
def main(args):
    """Parse CLI arguments, configure logging and seeds, and run enhancement.

    Args:
        args (list): Command-line arguments (excluding the program name).
    """
    parser = get_parser()
    args = parser.parse_args(args)

    # logging info
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose == 1:
        logging.basicConfig(level=logging.INFO, format=log_format)
    elif args.verbose == 2:
        logging.basicConfig(level=logging.DEBUG, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible_devices is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(visible_devices.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

    # TODO(kamo): support of multiple GPUs
    if args.ngpu > 1:
        logging.error("The program only supports ngpu=1.")
        sys.exit(1)

    # display PYTHONPATH
    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    # seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info("set random seed = %d" % args.seed)

    # recog
    logging.info("backend = " + args.backend)
    if args.backend != "pytorch":
        raise ValueError("Only pytorch is supported.")
    enhance(args)
# Script entry point: forward the command-line arguments (sans program name).
if __name__ == "__main__":
    main(sys.argv[1:])
#
# SPDX-FileCopyrightText:
# Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Initialize sub package."""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.