# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
from functools import partial
from multiprocessing.pool import ThreadPool as Pool

from . import DEFAULT_EOS, GET, SEND


class Agent(object):
    """Base class for decoding agents; an agent needs to follow this pattern.

    Subclasses implement the hooks ``init_states``, ``update_states``,
    ``finish_eval``, ``policy`` and ``reset``. ``decode`` drives ``reset``,
    ``init_states``, ``policy`` and ``update_states`` against a session
    object, one sentence at a time.
    """

    def __init__(self, *args, **kwargs):
        pass

    def init_states(self, *args, **kwargs):
        """Return the initial states object for a new sentence."""
        raise NotImplementedError

    def update_states(self, states, new_state):
        """Fold ``new_state`` into ``states`` and return the updated states.

        ``new_state`` is whatever ``session.get_src`` returned for a ``GET``
        action.
        """
        raise NotImplementedError

    def finish_eval(self, states, new_state):
        """Subclass hook; not called by ``decode`` in this class.

        NOTE(review): contract not visible here - presumably finalizes
        ``states`` with ``new_state``; confirm against subclasses.
        """
        raise NotImplementedError

    def policy(self, state):
        """Return the next action dict for ``state``.

        The action must have a ``"key"`` and a ``"value"``. With key ``GET``
        the value is passed to ``session.get_src``; with key ``SEND`` it is
        passed to ``session.send_hypo``. Decoding of the current sentence
        stops after an action whose value equals ``DEFAULT_EOS``.
        """
        raise NotImplementedError

    def reset(self):
        """Called at the start of every sentence, before ``init_states``."""
        raise NotImplementedError

    def decode(self, session, low=0, high=100000, num_thread=10):
        """Decode sentences ``low`` through ``high`` (inclusive) of a session.

        Args:
            session: object exposing ``corpus_info()`` (a dict with a
                ``"num_sentences"`` entry), ``get_src()`` and ``send_hypo()``.
            low: first sentence id to decode.
            high: last sentence id to decode; clamped to
                ``num_sentences - 1``.
            num_thread: number of worker threads. Values ``<= 1`` decode
                sequentially in the calling thread.
        """
        corpus_info = session.corpus_info()
        high = min(corpus_info["num_sentences"] - 1, high)
        # The id range is inclusive, so low == high still means one sentence;
        # only an empty range (including an empty corpus) returns early.
        if low > high:
            return

        t0 = time.time()
        sent_ids = list(range(low, high + 1))
        if num_thread > 1:
            # NOTE(review): every worker shares this agent instance and
            # _decode_one calls self.reset() per sentence; if a subclass keeps
            # mutable per-sentence state on self, concurrent decoding races.
            with Pool(num_thread) as p:
                p.map(partial(self._decode_one, session), sent_ids)
        else:
            for sent_id in sent_ids:
                self._decode_one(session, sent_id)

        print(f"Finished {low} to {high} in {time.time() - t0}s")

    def _decode_one(self, session, sent_id):
        """Run the policy loop for a single sentence ``sent_id``.

        Repeatedly asks ``policy`` for an action. ``GET`` fetches data via
        ``session.get_src`` and folds it into the states with
        ``update_states``; ``SEND`` forwards the value to
        ``session.send_hypo``. The loop ends after an action whose value is
        ``DEFAULT_EOS``, then the target tokens are printed.
        """
        # Start with an empty action so the loop body runs at least once.
        action = {}
        self.reset()
        states = self.init_states()
        while action.get("value") != DEFAULT_EOS:
            action = self.policy(states)
            if action["key"] == GET:
                new_states = session.get_src(sent_id, action["value"])
                states = self.update_states(states, new_states)
            elif action["key"] == SEND:
                session.send_hypo(sent_id, action["value"])
        # Assumes states is a dict with states["tokens"]["tgt"] a sequence of
        # str - TODO confirm against subclass init_states/update_states.
        print(" ".join(states["tokens"]["tgt"]))