#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import time
from collections import namedtuple

import numpy as np
import torch
from torch.serialization import default_restore_location

from fairseq import data, options, tokenizer, utils, log_helper
from fairseq.sequence_generator import SequenceGenerator
from fairseq.meters import StopwatchMeter
from fairseq.models.transformer import TransformerModel
import dllogger

from apply_bpe import BPE

Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')


def load_ensemble_for_inference(filenames):
    """Load an ensemble of models for inference.

    Every checkpoint is loaded onto the CPU and a ``TransformerModel`` is
    built from the ``args`` stored inside that checkpoint.

    Args:
        filenames: non-empty sequence of checkpoint paths.

    Returns:
        A tuple ``(ensemble, args, src_dict, tgt_dict)`` where ``ensemble`` is
        the list of built models, ``args`` are the arguments stored in the
        *last* checkpoint, and the dictionaries come from the ``extra_state``
        of the *first* checkpoint.

    Raises:
        ValueError: if ``filenames`` is empty.
        IOError: if one of the checkpoint files does not exist.
    """
    if not filenames:
        raise ValueError('At least one model file is required')

    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        states.append(state)

    ensemble = []
    for state in states:
        args = state['args']

        # build model for ensemble
        model = TransformerModel.build_model(args)
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)

    src_dict = states[0]['extra_state']['src_dict']
    tgt_dict = states[0]['extra_state']['tgt_dict']

    return ensemble, args, src_dict, tgt_dict


def buffered_read(buffer_size, data_descriptor):
    """Yield lists of at most ``buffer_size`` stripped lines.

    Args:
        buffer_size: number of lines collected before a list is yielded.
        data_descriptor: iterable of strings (e.g. an open text file or stdin).

    Yields:
        Lists of stripped lines; the last list may be shorter than
        ``buffer_size``. Nothing is yielded for an empty input.
    """
    buffer = []
    for src_str in data_descriptor:
        buffer.append(src_str.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []

    # Flush whatever is left once the input is exhausted.
    if buffer:
        yield buffer


def make_batches(lines, args, src_dict, max_positions, bpe=None):
    """Tokenize ``lines`` and yield them grouped into batches.

    Args:
        lines: list of source sentences.
        args: parsed options; ``max_tokens`` and ``max_sentences`` are read.
        src_dict: source dictionary handed to the tokenizer and the dataset.
        max_positions: forwarded to ``data.EpochBatchIterator``.
        bpe: optional BPE object forwarded to the tokenizer.

    Yields:
        ``(Batch, ids)`` pairs, where ``ids`` is the ``'id'`` entry of the
        underlying batch and indexes into ``lines``.
    """
    tokens = [
        tokenizer.Tokenizer.tokenize(
            src_str,
            src_dict,
            tokenize=tokenizer.tokenize_en,
            add_if_not_exist=False,
            bpe=bpe
        ).long()
        for src_str in lines
    ]
    lengths = np.array([t.numel() for t in tokens])
    itr = data.EpochBatchIterator(
        dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            srcs=[lines[i] for i in batch['id']],
            tokens=batch['net_input']['src_tokens'],
            lengths=batch['net_input']['src_lengths'],
        ), batch['id']


def setup_logger(args):
    """Initialise dllogger according to ``args``.

    Unless ``args.no_dllogger`` is set, a JSON stream backend writing to
    ``args.stat_file`` is installed, every parsed argument and the framework
    environment variables are logged as ``PARAMETER`` steps, and metadata for
    the ``throughput`` metric is registered. Otherwise dllogger is initialised
    with no backends.
    """
    if not args.no_dllogger:
        dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=1, filename=args.stat_file)])
        for k, v in vars(args).items():
            dllogger.log(step='PARAMETER', data={k: v}, verbosity=0)
        container_setup_info = log_helper.get_framework_env_vars()
        dllogger.log(step='PARAMETER', data=container_setup_info, verbosity=0)
        # ':.3f' is a valid format spec; the previous ':/3f' was a typo.
        dllogger.metadata(
            'throughput',
            {'unit': 'tokens/s', 'format': ':.3f', 'GOAL': 'MAXIMIZE', 'STAGE': 'INFER'},
        )
    else:
        dllogger.init(backends=[])


def main(args):
    """Translate sentences read from ``args.file`` or stdin.

    Translations are printed to stdout; source sentences, scores, positional
    scores, optional alignments and timing statistics go to stderr. Summary
    statistics are also sent to dllogger.
    """
    setup_logger(args)

    # Interactive means: reading from a terminal rather than a file or pipe.
    args.interactive = sys.stdin.isatty() and not args.file

    # Opened early so that a missing input file fails before models are loaded.
    if args.file:
        data_descriptor = open(args.file, 'r')
    else:
        data_descriptor = sys.stdin

    if args.interactive:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1
    if args.buffer_size > 50000:
        print("WARNING: To prevent memory exhaustion buffer size is set to 50000", file=sys.stderr)
        args.buffer_size = 50000

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    print(args, file=sys.stderr)

    use_cuda = torch.cuda.is_available() and not args.cpu

    processing_start = time.time()

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path), file=sys.stderr)
    model_paths = args.path.split(':')
    models, _model_args, src_dict, tgt_dict = load_ensemble_for_inference(model_paths)
    if args.fp16:
        for model in models:
            model.half()

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(need_attn=args.print_alignment)

    # Initialize generator
    translator = SequenceGenerator(
        models,
        tgt_dict.get_metadata(),
        maxlen=args.max_target_positions,
        beam_size=args.beam,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        minlen=args.min_len,
        sampling_temperature=args.sampling_temperature
    )

    if use_cuda:
        translator.cuda()

    # Load BPE codes file. `bpe` must exist even without --bpe-codes because
    # it is always passed to make_batches below.
    bpe = None
    if args.bpe_codes:
        # NOTE(review): assumes BPE consumes the codes file inside its
        # constructor (as subword-nmt's BPE does) -- confirm for the local copy.
        with open(args.bpe_codes, 'r') as codes:
            bpe = BPE(codes)

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    def make_result(src_str, hypos):
        """Turn the hypotheses for one sentence into a ``Translation``."""
        result = Translation(
            src_str=src_str,
            hypos=[],
            pos_scores=[],
            alignments=[],
        )

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            hypo_str = tokenizer.Tokenizer.detokenize(hypo_str, 'de').strip()
            result.hypos.append((hypo['score'], hypo_str))
            result.pos_scores.append(
                'P\t' + ' '.join(f'{x:.4f}' for x in hypo['positional_scores'].tolist())
            )
            # The conditional applies to the whole concatenation: either the
            # formatted alignment line or None is stored.
            result.alignments.append(
                ('A\t' + ' '.join(str(utils.item(x)) for x in alignment))
                if args.print_alignment else None
            )

        return result

    gen_timer = StopwatchMeter()

    def process_batch(batch):
        """Translate one batch and return a ``Translation`` per sentence."""
        tokens = batch.tokens
        lengths = batch.lengths

        if use_cuda:
            tokens = tokens.cuda()
            lengths = lengths.cuda()

        translation_start = time.time()
        gen_timer.start()
        translations = translator.generate(
            tokens,
            lengths,
            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
        )
        # The timer is credited with the token count of each best hypothesis.
        gen_timer.stop(sum(len(h[0]['tokens']) for h in translations))
        dllogger.log(step='infer', data={'latency': time.time() - translation_start})

        return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]

    if args.interactive:
        print('| Type the input sentence and press return:')

    try:
        for inputs in buffered_read(args.buffer_size, data_descriptor):
            indices = []
            results = []
            for batch, batch_indices in make_batches(inputs, args, src_dict, args.max_positions, bpe):
                indices.extend(batch_indices)
                results += process_batch(batch)

            # Batches may reorder sentences; restore the input order.
            for i in np.argsort(indices):
                result = results[i]
                print(result.src_str, file=sys.stderr)
                for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
                    print(f'Score {hypo[0]}', file=sys.stderr)
                    print(hypo[1])
                    print(pos_scores, file=sys.stderr)
                    if align is not None:
                        print(align, file=sys.stderr)
    finally:
        # Close the input file even when translation fails; never close stdin.
        if args.file:
            data_descriptor.close()

    if not gen_timer.intervals:
        # Nothing was translated (empty input): the averages below would
        # divide by zero, so only the total run time is reported.
        print('| no input sentences were translated; skipping statistics', file=sys.stderr)
        dllogger.log(step=(), data={'total_run_time': time.time() - processing_start})
        return

    log_dict = {
        'throughput': 1./gen_timer.avg,
        'latency_avg': sum(gen_timer.intervals)/len(gen_timer.intervals),
        'latency_p90': gen_timer.p(90),
        'latency_p95': gen_timer.p(95),
        'latency_p99': gen_timer.p(99),
        # Key spelling kept as-is: consumers of the log may depend on it.
        'total_infernece_time': gen_timer.sum,
        'total_run_time': time.time() - processing_start,
    }
    print('Translation time: {} s'.format(log_dict['total_infernece_time']),
          file=sys.stderr)
    print('Model throughput (beam {}): {} tokens/s'.format(args.beam, log_dict['throughput']),
          file=sys.stderr)
    print('Latency:\n\tAverage {:.3f}s\n\tp90 {:.3f}s\n\tp95 {:.3f}s\n\tp99 {:.3f}s'.format(
        log_dict['latency_avg'], log_dict['latency_p90'], log_dict['latency_p95'], log_dict['latency_p99']),
          file=sys.stderr)
    print('End to end time: {} s'.format(log_dict['total_run_time']), file=sys.stderr)
    dllogger.log(step=(), data=log_dict)


if __name__ == '__main__':
    parser = options.get_inference_parser()
    parser.add_argument('--no-dllogger', action='store_true')
    ARGS = options.parse_args_and_arch(parser)
    main(ARGS)