Commit c37fc8fd authored by Myle Ott's avatar Myle Ott
Browse files

Output positional scores in interactive.py

parent bb5f15d1
......@@ -17,7 +17,7 @@ from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos alignments')
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')
def buffered_read(buffer_size):
......@@ -107,6 +107,7 @@ def main(args):
result = Translation(
src_str='O\t{}'.format(src_str),
hypos=[],
pos_scores=[],
alignments=[],
)
......@@ -121,6 +122,12 @@ def main(args):
remove_bpe=args.remove_bpe,
)
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
result.pos_scores.append('P\t{}'.format(
' '.join(map(
lambda x: '{:.4f}'.format(x),
hypo['positional_scores'].tolist(),
))
))
result.alignments.append(
'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
if args.print_alignment else None
......@@ -156,8 +163,9 @@ def main(args):
for i in np.argsort(indices):
result = results[i]
print(result.src_str)
for hypo, align in zip(result.hypos, result.alignments):
for hypo, pos_scores, align in zip(result.hypos, result.pos_scores, result.alignments):
print(hypo)
print(pos_scores)
if align is not None:
print(align)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment