infer.py 8.71 KB
Newer Older
Dmytro Okhonko's avatar
Dmytro Okhonko committed
1
2
3
4
5
6
7
8
9
10
11
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""

import logging
Jeff Cai's avatar
Jeff Cai committed
12
import math
Dmytro Okhonko's avatar
Dmytro Okhonko committed
13
14
15
16
import os

import sentencepiece as spm
import torch
Jeff Cai's avatar
Jeff Cai committed
17
from fairseq import checkpoint_utils, options, progress_bar, utils, tasks
Dmytro Okhonko's avatar
Dmytro Okhonko committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module


# Module-level logger. Its level is fixed at INFO here, so the logger.debug()
# calls in this file are suppressed unless the level is lowered elsewhere.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def add_asr_eval_argument(parser):
    """Register the ASR-evaluation command-line options on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object) to extend.

    Returns:
        The same ``parser``, so the call can be chained.
    """
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictonary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="decoding type for RNN-T models (default: greedy)",
    )
    # Both spellings are accepted; argparse stores the value as ``lm_weight``.
    parser.add_argument(
        "--lm-weight",
        "--lm_weight",
        type=float,
        default=0.2,
        help="weight for lm while interpolating with neural score",
    )
    # type=float so a value given on the command line is parsed as a number
    # instead of being kept as a string (only the default was a float before).
    parser.add_argument(
        "--rnnt_len_penalty",
        type=float,
        default=-0.5,
        help="rnnt length penalty on word level",
    )
    # Options for the wav2letter (w2l) decoders.
    parser.add_argument(
        "--w2l-decoder", choices=["viterbi", "kenlm"], help="use a w2l decoder"
    )
    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
    parser.add_argument("--kenlm-model", help="kenlm model for w2l decoder")
    parser.add_argument("--beam-threshold", type=float, default=25.0)
    parser.add_argument("--word-score", type=float, default=1.0)
    parser.add_argument("--unk-weight", type=float, default=-math.inf)
    parser.add_argument("--sil-weight", type=float, default=0.0)
    return parser


def check_args(args):
    """Validate generation options before any model or data is loaded.

    Raises:
        AssertionError: when ``--path`` or ``--results_path`` is missing, when
            ``--sampling`` is used with ``--nbest`` != ``--beam``, or when
            ``--replace-unk`` is used without ``--raw-text``. These checks are
            ``assert`` statements, so they are skipped under ``python -O``.
    """
    # Required locations.
    assert args.path is not None, "--path required for generation!"
    assert args.results_path is not None, "--results_path required for generation!"
    # Incompatible option combinations.
    assert not (
        args.sampling and args.nbest != args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert not (
        args.replace_unk is not None and not args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def get_dataset_itr(args, task):
    """Return the next-epoch, unshuffled batch iterator over ``args.gen_subset``.

    Batch sizing, sharding and worker settings are read from ``args``.
    ``max_positions`` is pinned to a very large (1e6, 1e6) bound, presumably
    so that no utterance is dropped for its length.
    """
    subset = task.dataset(args.gen_subset)
    batches = task.get_batch_iterator(
        dataset=subset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=(1000000.0, 1000000.0),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    )
    # Inference keeps the dataset's batch order: no shuffling.
    return batches.next_epoch_itr(shuffle=False)


Jeff Cai's avatar
Jeff Cai committed
84
85
86
def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
    """Write the top ``args.nbest`` hypotheses of one utterance, plus its
    reference, to the open result files.

    For every kept hypothesis, one line tagged ``(speaker-id)`` is printed to
    each of ``res_files["hypo.units"]``, ``["hypo.words"]``, ``["ref.units"]``
    and ``["ref.words"]``. The reference is written once per kept hypothesis,
    so the hypothesis and reference files stay line-aligned. Nothing is
    written when ``hypos`` is empty.

    Args:
        args: parsed options; ``nbest`` and ``quiet`` are read.
        hypos: hypotheses for this utterance; each is a dict whose
            ``"tokens"`` entry supports ``.int().cpu()``.
        sp: SentencePiece processor turning pieces into words.
        tgt_dict: target dictionary turning token ids into a piece string.
        target_tokens: reference token ids for this utterance.
        res_files: mapping of the four keys above to writable text files.
        speaker: speaker label used in the line tag.
        id: utterance id used in the line tag.
    """
    # Slicing already clamps to len(hypos); no explicit min() needed.
    kept_hypos = hypos[: args.nbest]
    if not kept_hypos:
        return

    # The reference does not depend on the hypothesis: decode it only once.
    tgt_pieces = tgt_dict.string(target_tokens)
    tgt_words = sp.DecodePieces(tgt_pieces.split())
    utt_tag = "({}-{})".format(speaker, id)

    for hypo in kept_hypos:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        hyp_words = sp.DecodePieces(hyp_pieces.split())
        print("{} {}".format(hyp_pieces, utt_tag), file=res_files["hypo.units"])
        print("{} {}".format(hyp_words, utt_tag), file=res_files["hypo.words"])
        print("{} {}".format(tgt_pieces, utt_tag), file=res_files["ref.units"])
        print("{} {}".format(tgt_words, utt_tag), file=res_files["ref.words"])

        if not args.quiet:
            # NOTE(review): the module logger is set to INFO, so these DEBUG
            # records are dropped unless its level is lowered elsewhere.
            logger.debug("HYPO:%s", hyp_words)
            logger.debug("TARGET:%s", tgt_words)
            logger.debug("___________________")


def prepare_result_files(args):
    """Open the four transcript output files under ``args.results_path``.

    Each file is named ``<prefix>-<basename(args.path)>-<args.gen_subset>.txt``
    and opened for writing (truncating any existing file), line-buffered and
    UTF-8 encoded.

    Returns:
        dict mapping "hypo.words", "hypo.units", "ref.words" and "ref.units"
        to open text files. The caller owns the files and must close them.

    Raises:
        OSError: if a file cannot be opened; any file already opened by this
            call is closed first.
    """

    def get_res_file(file_prefix):
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        # Explicit UTF-8: transcripts may contain non-ASCII text that the
        # platform's default encoding cannot represent.
        return open(path, "w", buffering=1, encoding="utf-8")

    # The word-level file prefixes are singular ("hypo.word", "ref.word")
    # while the keys are plural. The names are kept unchanged in case
    # downstream scoring scripts depend on them.
    key_to_prefix = (
        ("hypo.words", "hypo.word"),
        ("hypo.units", "hypo.units"),
        ("ref.words", "ref.word"),
        ("ref.units", "ref.units"),
    )
    res_files = {}
    try:
        for key, prefix in key_to_prefix:
            res_files[key] = get_res_file(prefix)
    except BaseException:
        # Do not leak the handles that were opened before the failure.
        for opened in res_files.values():
            opened.close()
        raise
    return res_files


Jeff Cai's avatar
Jeff Cai committed
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def load_models_and_criterions(filenames, arg_overrides=None, task=None):
    """Load an ensemble of models, and their criterions, from checkpoint files.

    Args:
        filenames: iterable of checkpoint paths; must contain at least one.
        arg_overrides: forwarded to ``checkpoint_utils.load_checkpoint_to_cpu``.
        task: task used to build models and criterions. If None, a task is set
            up from the first checkpoint's stored args and reused for the rest.

    Returns:
        ``(models, criterions, args)``, where ``args`` are the stored args of
        the last checkpoint loaded.

    Raises:
        ValueError: if ``filenames`` is empty.
        FileNotFoundError: (an ``IOError`` subclass) if a checkpoint is missing.
    """
    filenames = list(filenames)
    if not filenames:
        # Without this guard the final return would hit an unbound ``args``.
        raise ValueError("no model checkpoint given")

    models = []
    criterions = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise FileNotFoundError("Model file not found: {}".format(filename))
        state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(args)
        # Criterion parameters are restored only when the checkpoint has them.
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, args


Dmytro Okhonko's avatar
Dmytro Okhonko committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def optimize_models(args, use_cuda, models):
    """Prepare each model of the ensemble for generation, in place.

    For every model: call ``make_generation_fast_`` with a beamable-matmul
    beam size of ``args.beam`` (or None when ``args.no_beamable_mm``) and
    ``need_attn=args.print_alignment``; then cast to half precision if
    ``args.fp16`` and move to GPU if ``use_cuda``.
    """
    for net in models:
        mm_beam = args.beam if not args.no_beamable_mm else None
        net.make_generation_fast_(
            beamable_mm_beam_size=mm_beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            net.half()
        if use_cuda:
            net.cuda()


def main(args):
    """Decode ``args.gen_subset`` with a model ensemble and write transcripts.

    Validates ``args`` and loads the dataset split. Loads the model(s) listed
    (colon-separated) in ``args.path`` and runs ``task.inference_step`` on
    every batch. For each utterance, hypothesis/reference units and words are
    appended to the files opened by ``prepare_result_files`` under
    ``args.results_path``. Decoding throughput is logged at the end.

    Args:
        args: parsed command-line namespace (see ``cli_main``). It is mutated:
            ``max_tokens`` may get a default and, for ``asg_loss``,
            ``asg_transitions`` is attached.
    """
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    # Fetched once and reused in the decode loop instead of per utterance.
    # NOTE(review): assumes task.dataset() returns the same object for a split
    # that is not reloaded — confirm against FairseqTask.
    dataset = task.dataset(args.gen_subset)
    logger.info(
        "| {} {} {} examples".format(args.data, args.gen_subset, len(dataset))
    )

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    # NOTE(review): eval() executes arbitrary Python from --model-overrides.
    # Tolerable for a local CLI, but never pass untrusted input here.
    models, criterions, _model_args = load_models_and_criterions(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),  # noqa
        task=task,
    )
    optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    # exist_ok avoids the race between a separate exists() check and creation.
    os.makedirs(args.results_path, exist_ok=True)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, "spm.model"))

    res_files = prepare_result_files(args)
    try:
        with progress_bar.build_progress_bar(args, itr) as t:
            wps_meter = TimeMeter()
            for sample in t:
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                if "net_input" not in sample:
                    continue

                prefix_tokens = None
                if args.prefix_size > 0:
                    prefix_tokens = sample["target"][:, : args.prefix_size]

                gen_timer.start()
                hypos = task.inference_step(generator, models, sample, prefix_tokens)
                num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
                gen_timer.stop(num_generated_tokens)

                for i, sample_id in enumerate(sample["id"].tolist()):
                    speaker = dataset.speakers[int(sample_id)]
                    utt_id = dataset.ids[int(sample_id)]
                    target_tokens = (
                        utils.strip_pad(sample["target"][i, :], tgt_dict.pad())
                        .int()
                        .cpu()
                    )
                    # Process top predictions
                    process_predictions(
                        args,
                        hypos[i],
                        sp,
                        tgt_dict,
                        target_tokens,
                        res_files,
                        speaker,
                        utt_id,
                    )

                wps_meter.update(num_generated_tokens)
                t.log({"wps": round(wps_meter.avg)})
                num_sentences += sample["nsentences"]
    finally:
        # Close the result files even if decoding fails part-way.
        for res_file in res_files.values():
            res_file.close()

    _log_generation_summary(args, num_sentences, gen_timer)


def _log_generation_summary(args, num_sentences, gen_timer):
    """Log decoding throughput; safe when no time was recorded."""
    elapsed = gen_timer.sum
    if elapsed > 0:
        sentences_per_sec = num_sentences / elapsed
        tokens_per_sec = gen_timer.n / elapsed
    else:
        # Nothing was timed (e.g. an empty shard): avoid ZeroDivisionError.
        sentences_per_sec = tokens_per_sec = 0.0
    logger.info(
        "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
        "sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            elapsed,
            sentences_per_sec,
            tokens_per_sec,
        )
    )
    logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))


def cli_main():
    """Command-line entry point.

    Builds the generation parser, adds the ASR-evaluation options, parses the
    arguments (including architecture-specific ones) and runs ``main``.
    """
    main(
        options.parse_args_and_arch(
            add_asr_eval_argument(options.get_generation_parser())
        )
    )


# Run the command-line entry point when this file is executed as a script.
if __name__ == "__main__":
    cli_main()