Unverified commit b6a61c3f, authored by sdarkhovsky and committed by GitHub

Fix interactive asr (#900)

* Updated the build_generator call to include the models argument

* Fixed RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
parent c388ec2b
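The second fix addresses a device mismatch: the acoustic features stayed on the CPU while the model weights had been moved to the GPU, so the forward pass raised the RuntimeError quoted above. Below is a minimal, self-contained sketch of that failure mode and the `.to(device)` fix; the model and tensor names are illustrative stand-ins, not code from the example.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the ASR model; in the example the real model is
# loaded from a fairseq checkpoint.
model = nn.Linear(80, 16)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # weights become torch.cuda.FloatTensor on a GPU machine

features = torch.randn(10, 80)  # filterbank features, still a CPU torch.FloatTensor
# model(features)               # on a GPU machine this raises the RuntimeError above
features = features.to(device)  # move the input to the same device as the weights
output = model(features)        # input and weight types now match
```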
@@ -55,8 +55,9 @@ def check_args(args):
 def process_predictions(args, hypos, sp, tgt_dict):
     res = []
+    device = torch.device("cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu")
     for hypo in hypos[: min(len(hypos), args.nbest)]:
-        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
+        hyp_pieces = tgt_dict.string(hypo["tokens"].int().to(device))
         hyp_words = sp.DecodePieces(hyp_pieces.split())
         res.append(hyp_words)
     return res
@@ -96,7 +97,8 @@ def calcMN(features):
 def transcribe(waveform, args, task, generator, models, sp, tgt_dict):
     num_features = 80
     output = torchaudio.compliance.kaldi.fbank(waveform, num_mel_bins=num_features)
-    output_cmvn = calcMN(output.cpu().detach())
+    device = torch.device("cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu")
+    output_cmvn = calcMN(output.to(device).detach())
     # size (m, n)
     source = output_cmvn
@@ -151,7 +153,7 @@ def setup_asr(args, logger):
     optimize_models(args, use_cuda, models)
     # Initialize generator
-    generator = task.build_generator(args)
+    generator = task.build_generator(models, args)
     sp = spm.SentencePieceProcessor()
     sp.Load(os.path.join(args.data, "spm.model"))
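The corrected call passes the model ensemble as the first positional argument, matching the newer `task.build_generator(models, args)` signature used in this diff. Below is a hedged sketch of how that call might sit in a typical fairseq setup flow; the checkpoint path and the use of `checkpoint_utils.load_model_ensemble` and `tasks.setup_task` are illustrative assumptions, not part of this commit, and exact signatures vary across fairseq versions.

```python
from fairseq import checkpoint_utils, tasks

# Sketch only: load a model ensemble and build the generator with the newer
# argument order (models first, then args). The checkpoint path is hypothetical.
models, model_args = checkpoint_utils.load_model_ensemble(["/path/to/checkpoint_best.pt"])
task = tasks.setup_task(model_args)
generator = task.build_generator(models, model_args)
```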