Unverified Commit b6a61c3f authored by sdarkhovsky's avatar sdarkhovsky Committed by GitHub
Browse files

Fix interactive asr (#900)

* updated the build_generator call to include the models argument

* fixed RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
parent c388ec2b
...@@ -55,8 +55,9 @@ def check_args(args): ...@@ -55,8 +55,9 @@ def check_args(args):
def process_predictions(args, hypos, sp, tgt_dict): def process_predictions(args, hypos, sp, tgt_dict):
res = [] res = []
device = torch.device("cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu")
for hypo in hypos[: min(len(hypos), args.nbest)]: for hypo in hypos[: min(len(hypos), args.nbest)]:
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu()) hyp_pieces = tgt_dict.string(hypo["tokens"].int().to(device))
hyp_words = sp.DecodePieces(hyp_pieces.split()) hyp_words = sp.DecodePieces(hyp_pieces.split())
res.append(hyp_words) res.append(hyp_words)
return res return res
...@@ -96,7 +97,8 @@ def calcMN(features): ...@@ -96,7 +97,8 @@ def calcMN(features):
def transcribe(waveform, args, task, generator, models, sp, tgt_dict): def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
num_features = 80 num_features = 80
output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features) output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
output_cmvn = calcMN(output.cpu().detach()) device = torch.device("cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu")
output_cmvn = calcMN(output.to(device).detach())
# size (m, n) # size (m, n)
source = output_cmvn source = output_cmvn
...@@ -151,7 +153,7 @@ def setup_asr(args, logger): ...@@ -151,7 +153,7 @@ def setup_asr(args, logger):
optimize_models(args, use_cuda, models) optimize_models(args, use_cuda, models)
# Initialize generator # Initialize generator
generator = task.build_generator(args) generator = task.build_generator(models, args)
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.Load(os.path.join(args.data, "spm.model")) sp.Load(os.path.join(args.data, "spm.model"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment