import rerank_utils
import rerank_generate
import rerank_score_bw
import rerank_score_lm
from fairseq import bleu, options
from fairseq.data import dictionary
from examples.noisychannel import rerank_options
from multiprocessing import Pool
import math
import numpy as np


def score_target_hypo(args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize):

    print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
    gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
    dict = dictionary.Dictionary()
    scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())

    ordered_hypos = {}
    ordered_targets = {}

    for shard_id in range(len(bitext1_lst)):
        bitext1 = bitext1_lst[shard_id]
        bitext2 = bitext2_lst[shard_id]
        gen_output = gen_output_lst[shard_id]
        lm_res = lm_res_lst[shard_id]

        total = len(bitext1.rescore_source.keys())
        source_lst = []
        hypo_lst = []
        score_lst = []
        reference_lst = []
        j = 1
        best_score = -math.inf

        for i in range(total):
            # length is measured in terms of words, not bpe tokens, since models may not share the same bpe
            target_len = len(bitext1.rescore_hypo[i].split())

            if lm_res is not None:
                lm_score = lm_res.score[i]
            else:
                lm_score = 0

            if bitext2 is not None:
                bitext2_score = bitext2.rescore_score[i]
                bitext2_backwards = bitext2.backwards
            else:
                bitext2_score = None
                bitext2_backwards = None

            score = rerank_utils.get_score(a, b, c, target_len,
                                           bitext1.rescore_score[i], bitext2_score, lm_score=lm_score,
                                           lenpen=lenpen, src_len=bitext1.source_lengths[i],
                                           tgt_len=bitext1.target_lengths[i],
                                           bitext1_backwards=bitext1.backwards,
                                           bitext2_backwards=bitext2_backwards, normalize=normalize)

            if score > best_score:
                best_score = score
                best_hypo = bitext1.rescore_hypo[i]

            if j == gen_output.num_hypos[i] or j == args.num_rescore:
                j = 1
                hypo_lst.append(best_hypo)
                score_lst.append(best_score)
                source_lst.append(bitext1.rescore_source[i])
                reference_lst.append(bitext1.rescore_target[i])

                best_score = -math.inf
                best_hypo = ""
            else:
                j += 1

        gen_keys = list(sorted(gen_output.no_bpe_target.keys()))

        for key in range(len(gen_keys)):
            if args.prefix_len is None:
                assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
                    ("pred and rescore hypo mismatch: i: " + str(key) + ", "
                     + str(hypo_lst[key]) + str(gen_keys[key]) + str(gen_output.no_bpe_hypo[key]))
                sys_tok = dict.encode_line(hypo_lst[key])
                ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
                scorer.add(ref_tok, sys_tok)
            else:
                full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
                sys_tok = dict.encode_line(full_hypo)
                ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
                scorer.add(ref_tok, sys_tok)

        # if only one set of hyperparameters is provided, write the predictions to a file
        if write_hypos:
            # recover the original ids from n-best list generation
            for key in range(len(gen_output.no_bpe_target)):
                if args.prefix_len is None:
                    assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], \
                        "pred and rescore hypo mismatch: i: " + str(key) + str(hypo_lst[key]) + str(gen_output.no_bpe_hypo[key])
                    ordered_hypos[gen_keys[key]] = hypo_lst[key]
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]
                else:
                    full_hypo = rerank_utils.get_full_from_prefix(hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]])
                    ordered_hypos[gen_keys[key]] = full_hypo
                    ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[gen_keys[key]]

    # write the hypos in the original order from nbest list generation
    if args.num_shards == (len(bitext1_lst)):
        with open(target_outfile, 'w') as t:
            with open(hypo_outfile, 'w') as h:
                for key in range(len(ordered_hypos)):
                    t.write(ordered_targets[key])
                    h.write(ordered_hypos[key])

    res = scorer.result_string(4)
    if write_hypos:
        print(res)
    score = rerank_utils.parse_bleu_scoring(res)
    return score

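
# The actual combination logic lives in rerank_utils.get_score; the helper below
# is only a simplified, illustrative sketch of a weighted score combination with
# a length penalty (assuming a, b and c weight the first rescoring model, the
# second rescoring model and the language model respectively). All names here
# are hypothetical and the function is not used anywhere in this file.
def _example_combined_score(a, b, c, score1, score2, lm_score, target_len, lenpen):
    """Hypothetical sketch: weighted sum of model scores, length-penalized."""
    score2 = 0.0 if score2 is None else score2
    combined = a * score1 + b * score2 + c * lm_score
    # dividing by target_len ** lenpen penalizes (or rewards) longer hypotheses
    # depending on the value of lenpen
    return combined / (target_len ** lenpen)
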

def match_target_hypo(args, target_outfile, hypo_outfile):
    """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
    if len(args.weight1) == 1:
        res = score_target_hypo(args, args.weight1[0], args.weight2[0],
                                args.weight3[0], args.lenpen[0], target_outfile,
                                hypo_outfile, True, args.normalize)
        rerank_scores = [res]
    else:
        print("launching pool")
        with Pool(32) as p:
            rerank_scores = p.starmap(score_target_hypo,
                                      [(args, args.weight1[i], args.weight2[i], args.weight3[i],
                                        args.lenpen[i], target_outfile, hypo_outfile,
                                        False, args.normalize)
                                       for i in range(len(args.weight1))])

    if len(rerank_scores) > 1:
        best_index = np.argmax(rerank_scores)
        best_score = rerank_scores[best_index]
        print("best score", best_score)
        print("best lenpen", args.lenpen[best_index])
        print("best weight1", args.weight1[best_index])
        print("best weight2", args.weight2[best_index])
        print("best weight3", args.weight3[best_index])
        return args.lenpen[best_index], args.weight1[best_index], \
            args.weight2[best_index], args.weight3[best_index], best_score
    else:
        return args.lenpen[0], args.weight1[0], args.weight2[0], args.weight3[0], rerank_scores[0]


def load_score_files(args):
    if args.all_shards:
        shard_ids = list(range(args.num_shards))
    else:
        shard_ids = [args.shard_id]

    gen_output_lst = []
    bitext1_lst = []
    bitext2_lst = []
    lm_res1_lst = []

    for shard_id in shard_ids:
        using_nbest = args.nbest_list is not None
        pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
            backwards_preprocessed_dir, lm_preprocessed_dir = \
            rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                         args.gen_model_name, shard_id, args.num_shards, args.sampling,
                                         args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)

        rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
        rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None

        score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
                                                     target_prefix_frac=args.target_prefix_frac,
                                                     source_prefix_frac=args.source_prefix_frac,
                                                     backwards=args.backwards1)
        if args.score_model2 is not None:
            score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
                                                         target_prefix_frac=args.target_prefix_frac,
                                                         source_prefix_frac=args.source_prefix_frac,
                                                         backwards=args.backwards2)
        if args.language_model is not None:
            lm_score_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.lm_name, lm_file=True)

        # get gen output
        predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
        if using_nbest:
            print("Using predefined n-best list from interactive.py")
            predictions_bpe_file = args.nbest_list

        gen_output = rerank_utils.BitextOutputFromGen(predictions_bpe_file, bpe_symbol=args.remove_bpe,
                                                      nbest=using_nbest, prefix_len=args.prefix_len,
                                                      target_prefix_frac=args.target_prefix_frac)

        if rerank1_is_gen:
            bitext1 = gen_output
        else:
            bitext1 = rerank_utils.BitextOutput(score1_file, args.backwards1, args.right_to_left1,
                                                args.remove_bpe, args.prefix_len, args.target_prefix_frac,
                                                args.source_prefix_frac)

        if args.score_model2 is not None or args.nbest_list is not None:
            if rerank2_is_gen:
                bitext2 = gen_output
            else:
                bitext2 = rerank_utils.BitextOutput(score2_file, args.backwards2, args.right_to_left2,
                                                    args.remove_bpe, args.prefix_len, args.target_prefix_frac,
                                                    args.source_prefix_frac)

                assert bitext2.source_lengths == bitext1.source_lengths, \
                    "source lengths for rescoring models do not match"
                assert bitext2.target_lengths == bitext1.target_lengths, \
                    "target lengths for rescoring models do not match"
        else:
            if args.diff_bpe:
                assert args.score_model2 is None
                bitext2 = gen_output
            else:
                bitext2 = None

        if args.language_model is not None:
            lm_res1 = rerank_utils.LMOutput(lm_score_file, args.lm_dict, args.prefix_len,
                                            args.remove_bpe, args.target_prefix_frac)
        else:
            lm_res1 = None

        gen_output_lst.append(gen_output)
        bitext1_lst.append(bitext1)
        bitext2_lst.append(bitext2)
        lm_res1_lst.append(lm_res1)

    return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst

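
# Note on hyperparameter search: match_target_hypo treats args.lenpen,
# args.weight1, args.weight2 and args.weight3 as parallel lists, so the i-th
# entry of each list forms one candidate configuration and the configuration
# with the highest BLEU is kept. rerank() below wraps scalar values into
# singleton lists before the search. A minimal sketch of driving a sweep
# programmatically (the numeric values are purely illustrative):
#
#     parser = rerank_options.get_reranking_parser()
#     args = options.parse_args_and_arch(parser)
#     args.lenpen = [0.1, 0.2, 0.3]
#     args.weight1 = [1.0, 1.0, 1.0]
#     args.weight2 = [0.9, 0.9, 0.9]
#     args.weight3 = [0.5, 0.6, 0.7]
#     best_lenpen, w1, w2, w3, best_bleu = rerank(args)
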

def rerank(args):
    if type(args.lenpen) is not list:
        args.lenpen = [args.lenpen]
    if type(args.weight1) is not list:
        args.weight1 = [args.weight1]
    if type(args.weight2) is not list:
        args.weight2 = [args.weight2]
    if type(args.weight3) is not list:
        args.weight3 = [args.weight3]
    if args.all_shards:
        shard_ids = list(range(args.num_shards))
    else:
        shard_ids = [args.shard_id]

    for shard_id in shard_ids:
        pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir, \
            backwards_preprocessed_dir, lm_preprocessed_dir = \
            rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                         args.gen_model_name, shard_id, args.num_shards, args.sampling,
                                         args.prefix_len, args.target_prefix_frac, args.source_prefix_frac)
        rerank_generate.gen_and_reprocess_nbest(args)
        rerank_score_bw.score_bw(args)
        rerank_score_lm.score_lm(args)

        if args.write_hypos is None:
            write_targets = pre_gen + "/matched_targets"
            write_hypos = pre_gen + "/matched_hypos"
        else:
            write_targets = args.write_hypos + "_targets" + args.gen_subset
            write_hypos = args.write_hypos + "_hypos" + args.gen_subset

    if args.all_shards:
        write_targets += "_all_shards"
        write_hypos += "_all_shards"

    best_lenpen, best_weight1, best_weight2, best_weight3, best_score = \
        match_target_hypo(args, write_targets, write_hypos)

    return best_lenpen, best_weight1, best_weight2, best_weight3, best_score


def cli_main():
    parser = rerank_options.get_reranking_parser()
    args = options.parse_args_and_arch(parser)
    rerank(args)


if __name__ == '__main__':
    cli_main()