simul_trans_agent.py 5.76 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

from fairseq import checkpoint_utils, tasks, utils

from . import DEFAULT_EOS, GET, SEND
from .agent import Agent


class SimulTransAgent(Agent):
    """Base agent for simultaneous translation.

    Implements the shared read/write policy loop: at each step the model
    decides whether to READ another source token from the server or WRITE
    a predicted target token back.  Subclasses must implement
    ``load_dictionary``, ``update_states`` and ``finish_read`` (they raise
    ``NotImplementedError`` here), and are expected to provide
    ``build_word_splitter`` and the ``dict`` / ``word_splitter`` mappings
    used below.
    """

    def __init__(self, args):
        # Load the pretrained model checkpoint given by --model-path.
        self.load_model(args)

        # Build subword splitters for source/target text.
        self.build_word_splitter(args)

        # Hard cap on the number of generated target tokens before we
        # force the hypothesis to finish (see write_action).
        self.max_len = args.max_len

        self.eos = DEFAULT_EOS

    @staticmethod
    def add_args(parser):
        """Register agent command-line options on ``parser`` and return it."""
        # fmt: off
        parser.add_argument('--model-path', type=str, required=True,
                            help='path to your pretrained model.')
        parser.add_argument("--data-bin", type=str, required=True,
                            help="Path of data binary")
        parser.add_argument("--user-dir", type=str, default="example/simultaneous_translation",
                            help="User directory for simultaneous translation")
        parser.add_argument("--src-splitter-type", type=str, default=None,
                            help="Subword splitter type for source text")
        parser.add_argument("--tgt-splitter-type", type=str, default=None,
                            help="Subword splitter type for target text")
        parser.add_argument("--src-splitter-path", type=str, default=None,
                            help="Subword splitter model path for source text")
        parser.add_argument("--tgt-splitter-path", type=str, default=None,
                            help="Subword splitter model path for target text")
        parser.add_argument("--max-len", type=int, default=150,
                            help="Maximum length difference between source and target prediction")
        parser.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                            help='A dictionary used to override model args at generation '
                                 'that were used during model training')
        # fmt: on
        return parser

    def load_dictionary(self, task):
        """Set source/target dictionaries from ``task``; subclass responsibility."""
        raise NotImplementedError

    def load_model(self, args):
        """Load the checkpoint at ``args.model_path`` onto CPU and build the task.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        # Point user_dir at the package root so custom fairseq extensions
        # (tasks/models) register themselves, overriding any CLI value.
        args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
        utils.import_user_module(args)
        filename = args.model_path
        if not os.path.exists(filename):
            # FileNotFoundError subclasses OSError/IOError, so callers
            # catching IOError keep working.
            raise FileNotFoundError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(
            filename, json.loads(args.model_overrides)
        )

        # Rebuild the training task with the evaluation data directory.
        saved_args = state["args"]
        saved_args.data = args.data_bin

        task = tasks.setup_task(saved_args)

        # Build the model from the saved args and restore its weights.
        self.model = task.build_model(saved_args)
        self.model.load_state_dict(state["model"], strict=True)

        # Set dictionary (implemented by subclasses).
        self.load_dictionary(task)

    def init_states(self):
        """Return a fresh per-sentence decoding state.

        ``indices``/``tokens``/``segments`` accumulate per-side histories;
        ``steps`` tracks how many tokens have already been sent/consumed;
        ``finished`` marks a completed hypothesis and ``finish_read`` marks
        an exhausted source stream.
        """
        return {
            "indices": {"src": [], "tgt": []},
            "tokens": {"src": [], "tgt": []},
            "segments": {"src": [], "tgt": []},
            "steps": {"src": 0, "tgt": 0},
            "finished": False,
            "finish_read": False,
            "model_states": {},
        }

    def update_states(self, states, new_state):
        """Merge a new server message into ``states``; subclass responsibility."""
        raise NotImplementedError

    def policy(self, states):
        """Run the read/write policy until an action for the server is ready.

        Returns a ``{"key": GET|SEND, "value": ...}`` action dict.
        """
        action = None

        while action is None:
            if states["finished"]:
                # Finish the hypothesis by sending EOS to the server.
                return self.finish_action()

            # Model makes a decision given the current states:
            # 0 means READ, anything else means WRITE.
            decision = self.model.decision_from_states(states)

            if decision == 0 and not self.finish_read(states):
                # READ
                action = self.read_action(states)
            else:
                # WRITE
                action = self.write_action(states)

            # None means we make the decision again without sending the
            # server anything.  This happens when reading a buffered token
            # or when predicting a subword that does not yet complete a word.
        return action

    def finish_read(self, states):
        """Return True when the source stream is exhausted; subclass responsibility."""
        raise NotImplementedError

    def write_action(self, states):
        """Predict one target token; send a detokenized word when one completes.

        Returns a SEND action when a full word is ready, otherwise None so
        the policy loop runs again.
        """
        token, index = self.model.predict_from_states(states)

        if (
            index == self.dict["tgt"].eos()
            or len(states["tokens"]["tgt"]) > self.max_len
        ):
            # Finish this sentence if EOS is predicted or the length cap is hit.
            states["finished"] = True
            end_idx_last_full_word = self._target_length(states)

        else:
            states["tokens"]["tgt"] += [token]
            end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
                states["tokens"]["tgt"]
            )
            self._append_indices(states, [index], "tgt")

        if end_idx_last_full_word > states["steps"]["tgt"]:
            # Only send detokenized full words to the server.
            word = self.word_splitter["tgt"].merge(
                states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
            )
            states["steps"]["tgt"] = end_idx_last_full_word
            states["segments"]["tgt"] += [word]

            return {"key": SEND, "value": word}
        else:
            return None

    def read_action(self, states):
        """Request the next source segment from the server."""
        return {"key": GET, "value": None}

    def finish_action(self):
        """Send EOS to the server, closing the current hypothesis."""
        return {"key": SEND, "value": DEFAULT_EOS}

    def reset(self):
        """Reset per-sentence agent state; nothing to do in the base class."""
        pass

    def finish_eval(self, states, new_state):
        """Return True when the server sent nothing and no source remains."""
        return not new_state and not states["indices"]["src"]

    def _append_indices(self, states, new_indices, key):
        # Extend the index history for the given side ("src" or "tgt").
        states["indices"][key] += new_indices

    def _target_length(self, states):
        # Number of target tokens generated so far.
        return len(states["tokens"]["tgt"])