# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
criterion_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.criterions." + criterion_name
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from examples.simultaneous_translation.utils.latency import LatencyTraining
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
)
@register_criterion("latency_augmented_label_smoothed_cross_entropy")
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
LabelSmoothedCrossEntropyCriterion
):
def __init__(self, args, task):
super().__init__(args, task)
self.eps = args.label_smoothing
self.latency_weight_avg = args.latency_weight_avg
self.latency_weight_avg_type = args.latency_weight_avg_type
self.latency_weight_var = args.latency_weight_var
self.latency_weight_var_type = args.latency_weight_var_type
self.mass_preservation = args.mass_preservation
self.average_method = args.average_method
self.latency_train = LatencyTraining(
self.latency_weight_avg,
self.latency_weight_var,
self.latency_weight_avg_type,
self.latency_weight_var_type,
self.mass_preservation,
self.average_method,
)
@staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # Call the parent's static add_args via super(cls, cls)
        super(
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
        ).add_args(parser)
# fmt: off
parser.add_argument("--latency-weight-avg", default=0., type=float, metavar='D',
help="Average loss weight")
parser.add_argument("--latency-weight-var", default=0., type=float, metavar='D',
help="Variance loss weight")
parser.add_argument("--latency-weight-avg-type", default="differentiable_average_lagging",
help="Statistics for Average loss type")
parser.add_argument("--latency-weight-var-type", default="variance_delay",
help="Statistics for variance loss type")
parser.add_argument("--average-method", default="weighted_average",
help="Average loss type")
# fmt: on
def compute_loss(self, model, net_output, sample, reduce=True):
# Compute cross entropy loss first
loss, nll_loss = super().compute_loss(model, net_output, sample, reduce)
# Obtain the expected alignment
attn_list = [item["alpha"] for item in net_output[-1]["attn_list"]]
target_padding_mask = model.get_targets(sample, net_output).eq(self.padding_idx)
source_padding_mask = net_output[-1].get("encoder_padding_mask", None)
# Get latency loss
latency_loss = self.latency_train.loss(
attn_list, source_padding_mask, target_padding_mask
)
loss += latency_loss
return loss, nll_loss
# **Baseline Simultaneous Translation**
---
This is a guide to training and evaluating a *wait-k* simultaneous LSTM model on the MuST-C English-German dataset. The *wait-k* policy is described in
[STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework](https://www.aclweb.org/anthology/P19-1289/).
## **Requirements**
Install fairseq (make sure to use the correct branch):
```
git clone --branch simulastsharedtask git@github.com:pytorch/fairseq.git
cd fairseq
pip install -e .
```
In what follows, we assume that fairseq is installed in a directory called `FAIRSEQ`.
Install SentencePiece. One easy way is to use anaconda:
```
conda install -c powerai sentencepiece
```
Download the MuST-C data for English-German available at https://ict.fbk.eu/must-c/.
We will assume that the data is downloaded in a directory called `DATA_ROOT`.
## **Text-to-text Model**
---
### Data Preparation
Train a SentencePiece model:
```shell
for lang in en de; do
    python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \
        --data-path $DATA_ROOT/data \
        --vocab-size 10000 \
        --max-frame 3000 \
        --model-type unigram \
        --lang $lang \
        --out-path .
done
```
Process the data with the SentencePiece model:
```shell
proc_dir=proc
mkdir -p $proc_dir
for split in train dev tst-COMMON tst-HE; do
for lang in en de; do
spm_encode \
--model unigram-$lang-10000-3000/spm.model \
< $DATA_ROOT/data/$split/txt/$split.$lang \
> $proc_dir/$split.spm.$lang
done
done
```
Binarize the data:
```shell
proc_dir=proc
fairseq-preprocess \
--source-lang en --target-lang de \
--trainpref $proc_dir/train.spm \
--validpref $proc_dir/dev.spm \
--testpref $proc_dir/tst-COMMON.spm \
--thresholdtgt 0 \
--thresholdsrc 0 \
--workers 20 \
    --destdir ./data-bin/mustc_en_de
```
### Training
```shell
mkdir -p checkpoints
CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \
--save-dir checkpoints \
--arch berard_simul_text_iwslt \
--simul-type waitk \
--waitk-lagging 2 \
--optimizer adam \
--max-epoch 100 \
--lr 0.001 \
--clip-norm 5.0 \
--batch-size 128 \
--log-format json \
--log-interval 10 \
--criterion cross_entropy_acc \
--user-dir $FAIRSEQ/examples/simultaneous_translation
```
## **Speech-to-text Model**
---
### Data Preparation
First, segment the wav files.
```shell
python $FAIRSEQ/examples/simultaneous_translation/data/segment_wav.py \
--datapath $DATA_ROOT
```
As with the text-to-text model, train a SentencePiece model, but only on German:
```shell
python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \
    --data-path $DATA_ROOT/data \
    --vocab-size 10000 \
    --max-frame 3000 \
    --model-type unigram \
    --lang de \
    --out-path .
```
### Training
```shell
mkdir -p checkpoints
CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \
--save-dir checkpoints \
--arch berard_simul_text_iwslt \
--waitk-lagging 2 \
--waitk-stride 10 \
--input-feat-per-channel 40 \
--encoder-hidden-size 512 \
--output-layer-dim 128 \
--decoder-num-layers 3 \
    --task speech_translation \
    --optimizer adam \
--max-epoch 100 \
--lr 0.001 \
--clip-norm 5.0 \
--batch-size 128 \
--log-format json \
--log-interval 10 \
--criterion cross_entropy_acc \
--user-dir $FAIRSEQ/examples/simultaneous_translation
```
## Evaluation
---
### Evaluation Server
For text translation models, the server is set up as follows, given an input file and a reference file.
```shell
python ./eval/server.py \
--hostname localhost \
--port 12321 \
--src-file $DATA_ROOT/data/dev/txt/dev.en \
--ref-file $DATA_ROOT/data/dev/txt/dev.de
```
For speech translation models, the input is the data directory.
```shell
python ./eval/server.py \
--hostname localhost \
--port 12321 \
--ref-file $DATA_ROOT \
--data-type speech
```
### Decode and Evaluate with Client
Once the server is set up, run the client to evaluate translation quality and latency.
```shell
# TEXT
python $FAIRSEQ/examples/simultaneous_translation/evaluate.py \
data-bin/mustc_en_de \
--user-dir $FAIRSEQ/examples/simultaneous_translation \
    --src-spm unigram-en-10000-3000/spm.model \
    --tgt-spm unigram-de-10000-3000/spm.model \
-s en -t de \
--path checkpoints/checkpoint_best.pt
# SPEECH
python $FAIRSEQ/examples/simultaneous_translation/evaluate.py \
data-bin/mustc_en_de \
--user-dir $FAIRSEQ/examples/simultaneous_translation \
--data-type speech \
    --tgt-spm unigram-de-10000-3000/spm.model \
-s en -t de \
--path checkpoints/checkpoint_best.pt
```
# Introduction to evaluation interface
The simultaneous translation models from shared task participants are evaluated under a server-client protocol. Participants are requested to plug their own model API into the protocol and submit a Docker file.
## Server-Client Protocol
A server-client protocol is used in evaluation. For example, when a *wait-k* model (k=3) translates the English sentence "Alice and Bob are good friends" into the German sentence "Alice und Bob sind gute Freunde.", the evaluation process works as follows.
Every time the client needs to read a new state (a word or a speech utterance), it sends a "GET" request to the server. Whenever a new token is generated, it immediately sends a "SEND" request to the server with the predicted (detokenized) word. The server can hence calculate both the latency and the BLEU score of the sentence.
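As a concrete illustration, one round of this exchange can be written with the `requests` library. This is a minimal sketch assuming a server running on the default host and port; the `/src` and `/hypo` endpoints match the evaluation client and server shipped with this example.
```python
import requests

base_url = "http://localhost:12321"

# GET: ask the server for the next source segment of sentence 0
state = requests.get(f"{base_url}/src", params={"sent_id": 0}).json()

# SEND: submit a predicted (detokenized) target word for sentence 0
requests.put(f"{base_url}/hypo", params={"sent_id": 0},
             data="Alice".encode("utf-8"))
```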
### Server
The server code is provided and can be run locally for development purposes. For example, to evaluate on a simultaneous text test set:
```shell
python $FAIRSEQ/examples/simultaneous_translation/eval/server.py \
    --hostname localhost \
--port 1234 \
--src-file SRC_FILE \
--ref-file REF_FILE \
    --data-type text
```
The state that the server sends to the client has the following format:
```json
{
    "sent_id": Int,
    "segment_id": Int,
    "segment": String
}
```
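For instance, the first source word of the example sentence above would arrive as:
```json
{ "sent_id": 0, "segment_id": 0, "segment": "Alice" }
```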
### Client
The client handles the evaluation process described above, and should also work out of the box. The client's protocol is summarized in the following table.
|Action|Content|
|:---:|:---:|
|Request new word / utterance| ```{key: "GET", value: None}```|
|Predict word "W"| ```{key: "SEND", value: "W"}```|
The core of the client module is the agent, which needs to be adapted to each model. The abstract agent class is shown below; the evaluation process happens in the `decode()` function.
```python
class Agent(object):
"an agent needs to follow this pattern"
def __init__(self, *args, **kwargs):
...
def init_states(self):
# Initializing states
...
def update_states(self, states, new_state):
# Update states with given new state from server
# TODO (describe the states)
...
def finish_eval(self, states, new_state):
# Check if evaluation is finished
...
    def policy(self, state: list) -> dict:
        # Provide an action given the current states
        # The action can only be either
        # {key: "GET", value: None}
        # or
        # {key: "SEND", value: W}
...
def reset(self):
# Reset agent
...
def decode(self, session):
states = self.init_states()
self.reset()
        # Evaluation protocol happens here
while True:
# Get action from the current states according to self.policy()
action = self.policy(states)
if action['key'] == GET:
# Read a new state from server
new_state = session.get_src()
states = self.update_states(states, new_state)
if self.finish_eval(states, new_state):
# End of document
break
elif action['key'] == SEND:
# Send a new prediction to server
session.send_hypo(action['value'])
# Clean the history, wait for next sentence
if action['value'] == DEFAULT_EOS:
states = self.init_states()
self.reset()
else:
raise NotImplementedError
```
Here is an implementation of an agent for the text [*wait-k* model](somelink). Note that tokenization is not considered.
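For intuition, the read/write decision of a *wait-k* policy can be sketched as follows. This is a minimal illustration, not the linked implementation; `predict_next_word` is a hypothetical callable standing in for the model's prediction step, and `states` follows the layout used by the agents in this example.
```python
GET, SEND = 0, 1  # same action codes as in the client module

def waitk_policy(states, predict_next_word, k=3):
    """Minimal wait-k decision: stay exactly k source words ahead."""
    num_read = len(states["segments"]["src"])
    num_written = len(states["segments"]["tgt"])
    if num_read < num_written + k and not states["finish_read"]:
        # Fewer than k source words ahead of the target: keep reading
        return {"key": GET, "value": None}
    # Otherwise write the next predicted (detokenized) target word
    return {"key": SEND, "value": predict_next_word(states)}
```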
## Quality
The quality is measured by detokenized BLEU, so make sure that the predicted words sent to the server are detokenized. An implementation can be found [here](some link).
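For example, if the model predicts SentencePiece subwords, merging them back into detokenized text can look like this (a sketch assuming the SentencePiece model trained in the data preparation step; it mirrors the `SentencePieceModelWordSplitter` used by the text agent):
```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("unigram-de-10000-3000/spm.model")

pieces = ["\u2581gute", "\u2581Freund", "e"]  # subword tokens from the model
word = sp.DecodePieces(pieces)               # detokenized text: "gute Freunde"
```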
## Latency
The latency metrics are
* Average Proportion
* Average Lagging
* Differentiable Average Lagging
They, too, are evaluated on detokenized text.
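For reference, with $d_t$ the number of source words read when target word $t$ is written, $|\mathbf{x}|$ and $|\mathbf{y}|$ the source and target lengths, $\gamma = |\mathbf{y}|/|\mathbf{x}|$, and $\tau$ the index of the first target word written after the full source has been read, the standard definitions from the literature are:
```latex
\mathrm{AP}  = \frac{1}{|\mathbf{x}|\,|\mathbf{y}|} \sum_{t=1}^{|\mathbf{y}|} d_t
\qquad
\mathrm{AL}  = \frac{1}{\tau} \sum_{t=1}^{\tau} \left( d_t - \frac{t-1}{\gamma} \right)
\qquad
\mathrm{DAL} = \frac{1}{|\mathbf{y}|} \sum_{t=1}^{|\mathbf{y}|} \left( d'_t - \frac{t-1}{\gamma} \right),
\quad
d'_t = \begin{cases} d_1 & t = 1 \\ \max\!\left(d_t,\; d'_{t-1} + \tfrac{1}{\gamma}\right) & t > 1 \end{cases}
```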
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry(
"--agent-type"
)
DEFAULT_EOS = "</s>"
GET = 0
SEND = 1
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("agents." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
from functools import partial
from multiprocessing.pool import ThreadPool as Pool
from . import DEFAULT_EOS, GET, SEND
class Agent(object):
"an agent needs to follow this pattern"
def __init__(self, *args, **kwargs):
pass
def init_states(self, *args, **kwargs):
raise NotImplementedError
def update_states(self, states, new_state):
raise NotImplementedError
def finish_eval(self, states, new_state):
raise NotImplementedError
def policy(self, state):
raise NotImplementedError
def reset(self):
raise NotImplementedError
def decode(self, session, low=0, high=100000, num_thread=10):
corpus_info = session.corpus_info()
high = min(corpus_info["num_sentences"] - 1, high)
if low >= high:
return
t0 = time.time()
if num_thread > 1:
            with Pool(num_thread) as p:
                p.map(
                    partial(self._decode_one, session),
                    range(low, high + 1),
)
else:
for sent_id in range(low, high + 1):
self._decode_one(session, sent_id)
print(f"Finished {low} to {high} in {time.time() - t0}s")
def _decode_one(self, session, sent_id):
action = {}
self.reset()
states = self.init_states()
while action.get("value", None) != DEFAULT_EOS:
# take an action
action = self.policy(states)
if action["key"] == GET:
new_states = session.get_src(sent_id, action["value"])
states = self.update_states(states, new_states)
elif action["key"] == SEND:
session.send_hypo(sent_id, action["value"])
print(" ".join(states["tokens"]["tgt"]))
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from fairseq import checkpoint_utils, tasks, utils
from . import DEFAULT_EOS, GET, SEND
from .agent import Agent
class SimulTransAgent(Agent):
def __init__(self, args):
# Load Model
self.load_model(args)
        # Build word splitter
self.build_word_splitter(args)
self.max_len = args.max_len
self.eos = DEFAULT_EOS
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--model-path', type=str, required=True,
help='path to your pretrained model.')
parser.add_argument("--data-bin", type=str, required=True,
help="Path of data binary")
parser.add_argument("--user-dir", type=str, default="example/simultaneous_translation",
help="User directory for simultaneous translation")
parser.add_argument("--src-splitter-type", type=str, default=None,
help="Subword splitter type for source text")
parser.add_argument("--tgt-splitter-type", type=str, default=None,
help="Subword splitter type for target text")
parser.add_argument("--src-splitter-path", type=str, default=None,
help="Subword splitter model path for source text")
parser.add_argument("--tgt-splitter-path", type=str, default=None,
help="Subword splitter model path for target text")
parser.add_argument("--max-len", type=int, default=150,
help="Maximum length difference between source and target prediction")
parser.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='A dictionary used to override model args at generation '
'that were used during model training')
# fmt: on
return parser
def load_dictionary(self, task):
raise NotImplementedError
def load_model(self, args):
args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
utils.import_user_module(args)
filename = args.model_path
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = checkpoint_utils.load_checkpoint_to_cpu(
filename, json.loads(args.model_overrides)
)
saved_args = state["args"]
saved_args.data = args.data_bin
task = tasks.setup_task(saved_args)
# build model for ensemble
self.model = task.build_model(saved_args)
self.model.load_state_dict(state["model"], strict=True)
# Set dictionary
self.load_dictionary(task)
def init_states(self):
return {
"indices": {"src": [], "tgt": []},
"tokens": {"src": [], "tgt": []},
"segments": {"src": [], "tgt": []},
"steps": {"src": 0, "tgt": 0},
"finished": False,
"finish_read": False,
"model_states": {},
}
def update_states(self, states, new_state):
raise NotImplementedError
def policy(self, states):
# Read and Write policy
action = None
while action is None:
if states["finished"]:
# Finish the hypo by sending eos to server
return self.finish_action()
            # The model makes a decision given the current states
decision = self.model.decision_from_states(states)
if decision == 0 and not self.finish_read(states):
# READ
action = self.read_action(states)
else:
# WRITE
action = self.write_action(states)
            # None means we make the decision again without sending the
            # server anything. This happens when reading a buffered token
            # or predicting a subword.
return action
def finish_read(self, states):
raise NotImplementedError
def write_action(self, states):
token, index = self.model.predict_from_states(states)
if (
index == self.dict["tgt"].eos()
or len(states["tokens"]["tgt"]) > self.max_len
):
            # Finish this sentence if EOS is predicted
states["finished"] = True
end_idx_last_full_word = self._target_length(states)
else:
states["tokens"]["tgt"] += [token]
end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
states["tokens"]["tgt"]
)
self._append_indices(states, [index], "tgt")
if end_idx_last_full_word > states["steps"]["tgt"]:
            # Only send detokenized full words to the server
word = self.word_splitter["tgt"].merge(
states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
)
states["steps"]["tgt"] = end_idx_last_full_word
states["segments"]["tgt"] += [word]
return {"key": SEND, "value": word}
else:
return None
def read_action(self, states):
return {"key": GET, "value": None}
def finish_action(self):
return {"key": SEND, "value": DEFAULT_EOS}
def reset(self):
pass
def finish_eval(self, states, new_state):
if len(new_state) == 0 and len(states["indices"]["src"]) == 0:
return True
return False
def _append_indices(self, states, new_indices, key):
states["indices"][key] += new_indices
def _target_length(self, states):
return len(states["tokens"]["tgt"])
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import DEFAULT_EOS, GET, register_agent
from .simul_trans_agent import SimulTransAgent
from .word_splitter import SPLITTER_DICT
@register_agent("simul_trans_text")
class SimulTransTextAgent(SimulTransAgent):
def build_word_splitter(self, args):
self.word_splitter = {}
self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
getattr(args, f"src_splitter_path")
)
self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
getattr(args, f"tgt_splitter_path")
)
def load_dictionary(self, task):
self.dict = {}
self.dict["tgt"] = task.target_dictionary
self.dict["src"] = task.source_dictionary
def update_states(self, states, new_state):
if states["finish_read"]:
return states
new_word = new_state["segment"]
# Split words and index the token
if new_word not in [DEFAULT_EOS]:
tokens = self.word_splitter["src"].split(new_word)
# Get indices from dictionary
            # You can change this to your own dictionary
indices = (
self.dict["src"]
.encode_line(
tokens,
line_tokenizer=lambda x: x,
add_if_not_exist=False,
append_eos=False,
)
.tolist()
)
else:
tokens = [new_word]
indices = [self.dict["src"].eos()]
states["finish_read"] = True
# Update states
states["segments"]["src"] += [new_word]
states["tokens"]["src"] += tokens
self._append_indices(states, indices, "src")
return states
def read_action(self, states):
# Increase source step by one
states["steps"]["src"] += 1
        # At least one word must be read
if len(states["tokens"]["src"]) == 0:
return {"key": GET, "value": None}
        # Only request a new word if there are no buffered tokens
if len(states["tokens"]["src"]) <= states["steps"]["src"]:
return {"key": GET, "value": None}
return None
def finish_read(self, states):
        # The first means all segments (full words) have been read from the server
        # The second means all tokens (subwords) have been read locally
return (
states["finish_read"]
and len(states["tokens"]["src"]) == states["steps"]["src"]
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
class SubwordSplitter(object):
def process_line(self, string):
raise NotImplementedError
def split(self, string):
raise NotImplementedError
class NoneWordSplitter(object):
def __init__(self, model):
pass
def split(self, string):
return [string]
def process_line(self, string):
return [string]
def finished_word(self, string):
return True
def merge(self, list_of_string):
return "".join(list_of_string)
def last_full_word_step(self, tokens, step):
return len(tokens)
def end_idx_last_full_word(self, tokens):
return len(tokens)
class BPEWordSplitter(object):
    # TODO: look back here
def __init__(self, model_path):
super().__init__()
from subword_nmt.apply_bpe import BPE
with open(model_path) as f:
self.model = BPE(f)
def split(self, string):
return self.model.process_line(string).split()
def end_idx_last_full_word(self, tokens):
        # Begin-of-word indices: token i+1 starts a new word
        # if token i does not end with the "@@" continuation marker
        bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[:-1]) if t[-2:] != "@@"]
if len(bow_indices) < 2:
return 0
else:
return bow_indices[-1]
    def merge(self, list_of_string):
        # Join tokens, then collapse the "@@ " continuation markers
        return " ".join(list_of_string).replace("@@ ", "")
class SentencePieceModelWordSplitter(object):
def __init__(self, model_path):
super().__init__()
import sentencepiece as spm
self.model = spm.SentencePieceProcessor()
self.model.Load(model_path)
def split(self, string):
return self.model.EncodeAsPieces(string)
def end_idx_last_full_word(self, tokens):
# Begin of word indices
bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]
if len(bow_indices) < 2:
return 0
else:
return bow_indices[-1]
def merge(self, list_of_string):
return self.model.DecodePieces(list_of_string)
SPLITTER_DICT = {
None: NoneWordSplitter,
"BPE": BPEWordSplitter,
"SentencePieceModel": SentencePieceModelWordSplitter,
}
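# Example usage of the splitters above (hypothetical model path):
#
#   splitter = SPLITTER_DICT["SentencePieceModel"]("spm.model")
#   tokens = splitter.split("gute Freunde")  # e.g. ['\u2581gute', '\u2581Freund', 'e']
#   splitter.end_idx_last_full_word(tokens)  # 1: '\u2581Freund' starts the last,
#                                            #    possibly incomplete, word
#   splitter.merge(tokens[:1])               # 'gute'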
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import requests
from scorers import build_scorer
class SimulSTEvaluationService(object):
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
self.hostname = hostname
self.port = port
self.base_url = f"http://{self.hostname}:{self.port}"
    def __enter__(self):
        return self.new_session()
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def new_session(self):
# start eval session
url = f"{self.base_url}"
try:
_ = requests.post(url)
except Exception as e:
print(f"Failed to start an evaluation session: {e}")
print("Evaluation session started.")
return self
def get_scores(self):
# end eval session
url = f"{self.base_url}/result"
try:
r = requests.get(url)
print("Scores: {}".format(r.json()))
print("Evaluation session finished.")
except Exception as e:
print(f"Failed to end an evaluation session: {e}")
def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
url = f"{self.base_url}/src"
params = {"sent_id": sent_id}
if extra_params is not None:
for key in extra_params.keys():
params[key] = extra_params[key]
try:
r = requests.get(url, params=params)
except Exception as e:
print(f"Failed to request a source segment: {e}")
return r.json()
def send_hypo(self, sent_id: int, hypo: str) -> None:
url = f"{self.base_url}/hypo"
params = {"sent_id": sent_id}
try:
requests.put(url, params=params, data=hypo.encode("utf-8"))
except Exception as e:
print(f"Failed to send a translated segment: {e}")
def corpus_info(self):
url = f"{self.base_url}"
try:
r = requests.get(url)
except Exception as e:
print(f"Failed to request corpus information: {e}")
return r.json()
class SimulSTLocalEvaluationService(object):
def __init__(self, args):
self.scorer = build_scorer(args)
def get_scores(self):
return self.scorer.score()
def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
if extra_params is not None:
segment_size = extra_params.get("segment_size", None)
else:
segment_size = None
return self.scorer.send_src(int(sent_id), segment_size)
def send_hypo(self, sent_id: int, hypo: str) -> None:
list_of_tokens = hypo.strip().split()
self.scorer.recv_hyp(sent_id, list_of_tokens)
def corpus_info(self):
return self.scorer.get_info()
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import torch
from examples.simultaneous_translation.utils.latency import LatencyInference
LATENCY_METRICS = [
"differentiable_average_lagging",
"average_lagging",
"average_proportion",
]
class LatencyScorer:
def __init__(self, start_from_zero=True):
self.recorder = []
self.scores = {}
self.scorer = LatencyInference()
self.start_from_zero = start_from_zero
    def update_recorder(self, list_of_dict):
self.recorder = []
for info in list_of_dict:
delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]
delays = torch.LongTensor(delays).unsqueeze(0)
src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
self.recorder.append(self.scorer(delays, src_len))
def cal_latency(self):
self.scores = {}
for metric in LATENCY_METRICS:
self.scores[metric] = sum(
[x[metric][0, 0].item() for x in self.recorder]
) / len(self.recorder)
return self.scores
@classmethod
def score(cls, list_of_dict, start_from_zero=True):
scorer_to_return = cls(start_from_zero)
        scorer_to_return.update_recorder(list_of_dict)
scorer_to_return.cal_latency()
return scorer_to_return.scores
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
parser.add_argument("--start-from-zero", action="store_true")
args = parser.parse_args()
scorer = LatencyInference()
recorder = []
with open(args.input, "r") as f:
for line in f:
info = json.loads(line)
delays = [int(x) - int(not args.start_from_zero) for x in info["delays"]]
delays = torch.LongTensor(delays).unsqueeze(0)
src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
recorder.append(scorer(delays, src_len))
average_results = {}
for metric in LATENCY_METRICS:
average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
recorder
)
print(f"{metric}: {average_results[metric]}")
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from agents import build_agent
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
from fairseq.registry import REGISTRIES
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
)
parser.add_argument(
"--port", type=int, default=DEFAULT_PORT, help="server port number"
)
parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
parser.add_argument("--scorer-type", default="text", help="Scorer type")
parser.add_argument(
"--start-idx",
type=int,
default=0,
help="Start index of the sentence to evaluate",
)
parser.add_argument(
"--end-idx",
type=int,
default=float("inf"),
help="End index of the sentence to evaluate",
)
parser.add_argument(
"--scores", action="store_true", help="Request scores from server"
)
parser.add_argument("--reset-server", action="store_true", help="Reset the server")
parser.add_argument(
"--num-threads", type=int, default=10, help="Number of threads used by agent"
)
parser.add_argument(
"--local", action="store_true", default=False, help="Local evaluation"
)
args, _ = parser.parse_known_args()
for registry_name, REGISTRY in REGISTRIES.items():
choice = getattr(args, registry_name, None)
if choice is not None:
cls = REGISTRY["registry"][choice]
if hasattr(cls, "add_args"):
cls.add_args(parser)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
if args.local:
session = SimulSTLocalEvaluationService(args)
else:
session = SimulSTEvaluationService(args.hostname, args.port)
if args.reset_server:
session.new_session()
if args.agent_type is not None:
agent = build_agent(args)
agent.decode(session, args.start_idx, args.end_idx, args.num_threads)
    if args.scores:
        print(session.get_scores())
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
"--scorer-type"
)
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("scorers." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from collections import defaultdict
from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
from vizseq.scorers.bleu import BLEUScorer
from vizseq.scorers.meteor import METEORScorer
from vizseq.scorers.ter import TERScorer
DEFAULT_EOS = "</s>"
class SimulScorer(object):
def __init__(self, args):
self.tokenizer = args.tokenizer
self.output_dir = args.output
if args.output is not None:
self.output_files = {
"text": os.path.join(args.output, "text"),
"delay": os.path.join(args.output, "delay"),
"scores": os.path.join(args.output, "scores"),
}
else:
self.output_files = None
self.eos = DEFAULT_EOS
self.data = {"tgt": []}
self.reset()
def get_info(self):
return {"num_sentences": len(self)}
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--src-file', type=str, required=True,
help='Source input file')
parser.add_argument('--tgt-file', type=str, required=True,
help='Target reference file')
parser.add_argument('--tokenizer', default="13a", choices=["none", "13a"],
help='Tokenizer used for sacrebleu')
parser.add_argument('--output', type=str, default=None,
help='Path for output directory')
# fmt: on
def send_src(self, sent_id, *args):
raise NotImplementedError
def recv_hyp(self, sent_id, list_of_tokens):
for token in list_of_tokens:
self.translations[sent_id].append((token, self.steps[sent_id]))
def reset(self):
self.steps = defaultdict(int)
self.translations = defaultdict(list)
def src_lengths(self):
raise NotImplementedError
def score(self):
translations = []
delays = []
for i in range(1 + max(self.translations.keys())):
translations += [" ".join(t[0] for t in self.translations[i][:-1])]
delays += [[t[1] for t in self.translations[i]]]
bleu_score = BLEUScorer(
sent_level=False,
corpus_level=True,
extra_args={"bleu_tokenizer": self.tokenizer},
).score(translations, [self.data["tgt"]])
ter_score = TERScorer(sent_level=False, corpus_level=True).score(
translations, [self.data["tgt"]]
)
meteor_score = METEORScorer(sent_level=False, corpus_level=True).score(
translations, [self.data["tgt"]]
)
latency_score = LatencyScorer().score(
[
{"src_len": src_len, "delays": delay}
for src_len, delay in zip(self.src_lengths(), delays)
],
start_from_zero=False,
)
scores = {
"BLEU": bleu_score[0],
"TER": ter_score[0],
"METEOR": meteor_score[0],
"DAL": latency_score["differentiable_average_lagging"],
"AL": latency_score["average_lagging"],
"AP": latency_score["average_proportion"],
}
if self.output_files is not None:
try:
os.makedirs(self.output_dir, exist_ok=True)
self.write_results_to_file(translations, delays, scores)
except BaseException as be:
print(f"Failed to write results to {self.output_dir}.")
print(be)
print("Skip writing predictions")
return scores
def write_results_to_file(self, translations, delays, scores):
if self.output_files["text"] is not None:
with open(self.output_files["text"], "w") as f:
for line in translations:
f.write(line + "\n")
if self.output_files["delay"] is not None:
with open(self.output_files["delay"], "w") as f:
for i, delay in enumerate(delays):
f.write(
json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
+ "\n"
)
with open(self.output_files["scores"], "w") as f:
for key, value in scores.items():
f.write(f"{key}, {value}\n")
@classmethod
def _load_text_file(cls, file, split=False):
with open(file) as f:
if split:
return [r.strip().split() for r in f]
else:
return [r.strip() for r in f]
@classmethod
def _load_text_from_json(cls, file):
list_to_return = []
with open(file) as f:
content = json.load(f)
for item in content["utts"].values():
list_to_return.append(item["output"]["text"].strip())
return list_to_return
@classmethod
def _load_wav_info_from_json(cls, file):
list_to_return = []
with open(file) as f:
content = json.load(f)
for item in content["utts"].values():
list_to_return.append(
{
"path": item["input"]["path"].strip(),
"length": item["input"]["length_ms"],
}
)
return list_to_return
@classmethod
def _load_wav_info_from_list(cls, file):
list_to_return = []
with open(file) as f:
for line in f:
list_to_return.append(
{
"path": line.strip(),
}
)
return list_to_return
def __len__(self):
return len(self.data["tgt"])
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import register_scorer
from .scorer import SimulScorer
@register_scorer("text")
class SimulTextScorer(SimulScorer):
def __init__(self, args):
super().__init__(args)
self.data = {
"src": self._load_text_file(args.src_file, split=True),
"tgt": self._load_text_file(args.tgt_file, split=False),
}
def send_src(self, sent_id, *args):
if self.steps[sent_id] >= len(self.data["src"][sent_id]):
dict_to_return = {
"sent_id": sent_id,
"segment_id": self.steps[sent_id],
"segment": self.eos,
}
# Consider EOS
self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
else:
dict_to_return = {
"sent_id": sent_id,
"segment_id": self.steps[sent_id],
"segment": self.data["src"][sent_id][self.steps[sent_id]],
}
self.steps[sent_id] += 1
return dict_to_return
def src_lengths(self):
# +1 for eos
return [len(sent) + 1 for sent in self.data["src"]]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import sys
from scorers import build_scorer
from tornado import ioloop, web
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
class ScorerHandler(web.RequestHandler):
def initialize(self, scorer):
self.scorer = scorer
class EvalSessionHandler(ScorerHandler):
def post(self):
self.scorer.reset()
def get(self):
r = json.dumps(self.scorer.get_info())
self.write(r)
class ResultHandler(ScorerHandler):
def get(self):
r = json.dumps(self.scorer.score())
self.write(r)
class SourceHandler(ScorerHandler):
def get(self):
sent_id = int(self.get_argument("sent_id"))
segment_size = None
if "segment_size" in self.request.arguments:
string = self.get_argument("segment_size")
if len(string) > 0:
segment_size = int(string)
r = json.dumps(self.scorer.send_src(int(sent_id), segment_size))
self.write(r)
class HypothesisHandler(ScorerHandler):
def put(self):
sent_id = int(self.get_argument("sent_id"))
list_of_tokens = self.request.body.decode("utf-8").strip().split()
self.scorer.recv_hyp(sent_id, list_of_tokens)
def add_args():
parser = argparse.ArgumentParser()
# fmt: off
parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
help='Server hostname')
    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
                        help='Server port number')
    parser.add_argument('--debug', action='store_true',
                        help='Run the server in debug mode')
    args, _ = parser.parse_known_args()
# fmt: on
return args
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
app = web.Application(
[
(r"/result", ResultHandler, dict(scorer=scorer)),
(r"/src", SourceHandler, dict(scorer=scorer)),
(r"/hypo", HypothesisHandler, dict(scorer=scorer)),
(r"/", EvalSessionHandler, dict(scorer=scorer)),
],
debug=debug,
)
app.listen(port, max_buffer_size=1024 ** 3)
sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
ioloop.IOLoop.current().start()
if __name__ == "__main__":
args = add_args()
scorer = build_scorer(args)
start_server(scorer, args.hostname, args.port, args.debug)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.models." + model_name
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
TransformerMonotonicDecoderLayer,
TransformerMonotonicEncoderLayer,
)
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
TransformerModel,
base_architecture,
transformer_iwslt_de_en,
transformer_vaswani_wmt_en_de_big,
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)
def _indices_from_states(self, states):
if type(states["indices"]["src"]) == list:
if next(self.parameters()).is_cuda:
tensor = torch.cuda.LongTensor
else:
tensor = torch.LongTensor
src_indices = tensor(
[states["indices"]["src"][: 1 + states["steps"]["src"]]]
)
tgt_indices = tensor(
[[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
)
else:
src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
tgt_indices = states["indices"]["tgt"]
return src_indices, None, tgt_indices
def predict_from_states(self, states):
decoder_states = self.decoder.output_layer(states["decoder_features"])
lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)
index = lprobs.argmax(dim=-1)
token = self.decoder.dictionary.string(index)
return token, index[0, 0].item()
def decision_from_states(self, states):
"""
This funcion take states dictionary as input, and gives the agent
a decision of whether read a token from server. Moreover, the decoder
states are also calculated here so we can directly generate a target
token without recompute every thing
"""
self.eval()
if len(states["tokens"]["src"]) == 0:
return 0
src_indices, src_lengths, tgt_indices = self._indices_from_states(states)
# Update encoder states if needed
if (
"encoder_states" not in states
or states["encoder_states"][0].size(1) <= states["steps"]["src"]
):
encoder_out_dict = self.encoder(src_indices, src_lengths)
states["encoder_states"] = encoder_out_dict
else:
encoder_out_dict = states["encoder_states"]
# online means we still need tokens to feed the model
states["model_states"]["online"] = not (
states["finish_read"]
and len(states["tokens"]["src"]) == states["steps"]["src"]
)
states["model_states"]["steps"] = states["steps"]
x, outputs = self.decoder.forward(
prev_output_tokens=tgt_indices,
encoder_out=encoder_out_dict,
incremental_state=states["model_states"],
features_only=True,
)
states["decoder_features"] = x
return outputs["action"]
class TransformerMonotonicEncoder(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend(
[TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)]
)
class TransformerMonotonicDecoder(TransformerDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend(
[
TransformerMonotonicDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
def pre_attention(
self, prev_output_tokens, encoder_out_dict, incremental_state=None
):
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_out = encoder_out_dict.encoder_out
encoder_padding_mask = encoder_out_dict.encoder_padding_mask
return x, encoder_out, encoder_padding_mask
def post_attention(self, x):
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x
def extract_features(
self, prev_output_tokens, encoder_out, incremental_state=None, **unused
):
"""
Similar to *forward* but only return features.
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
(x, encoder_outs, encoder_padding_mask) = self.pre_attention(
prev_output_tokens, encoder_out, incremental_state
)
attn = None
inner_states = [x]
attn_list = []
step_list = []
for i, layer in enumerate(self.layers):
x, attn, _ = layer(
x=x,
encoder_out=encoder_outs,
encoder_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
self_attn_mask=self.buffered_future_mask(x)
if incremental_state is None
else None,
)
inner_states.append(x)
attn_list.append(attn)
if incremental_state is not None:
curr_steps = layer.get_steps(incremental_state)
step_list.append(curr_steps)
if incremental_state.get("online", False):
p_choose = (
attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
)
new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)
if (new_steps >= incremental_state["steps"]["src"]).any():
                    # We need to prune the last self_attn saved_state
                    # if the model decides not to read;
                    # otherwise there will be duplicated saved_state
for j in range(i + 1):
self.layers[j].prune_incremental_state(incremental_state)
return x, {"action": 0}
if incremental_state is not None and not incremental_state.get("online", False):
            # This branch is for fast evaluation
fastest_step = (
torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
)
if "fastest_step" in incremental_state:
incremental_state["fastest_step"] = torch.cat(
[incremental_state["fastest_step"], fastest_step], dim=1
)
else:
incremental_state["fastest_step"] = fastest_step
x = self.post_attention(x)
return x, {
"action": 1,
"attn_list": attn_list,
"step_list": step_list,
"encoder_out": encoder_out,
"encoder_padding_mask": encoder_padding_mask,
}
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
if "fastest_step" in incremental_state:
incremental_state["fastest_step"] = incremental_state[
"fastest_step"
].index_select(0, new_order)
@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_architecture(args):
base_architecture(args)
args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
transformer_iwslt_de_en(args)
    base_monotonic_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    # Delegate to the en-de big architecture (calling itself would recurse forever)
    transformer_monotonic_vaswani_wmt_en_de_big(args)
@register_model_architecture(
"transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
transformer_iwslt_de_en(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
(
build_monotonic_attention,
register_monotonic_attention,
MONOTONIC_ATTENTION_REGISTRY,
_,
) = registry.setup_registry("--simul-type")
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.modules." + model_name
)