# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time
from functools import partial
from multiprocessing.pool import ThreadPool as Pool
from . import DEFAULT_EOS, GET, SEND
class Agent(object):
"an agent needs to follow this pattern"
def __init__(self, *args, **kwargs):
pass
def init_states(self, *args, **kwargs):
raise NotImplementedError
def update_states(self, states, new_state):
raise NotImplementedError
def finish_eval(self, states, new_state):
raise NotImplementedError
def policy(self, state):
raise NotImplementedError
def reset(self):
raise NotImplementedError
def decode(self, session, low=0, high=100000, num_thread=10):
corpus_info = session.corpus_info()
high = min(corpus_info["num_sentences"] - 1, high)
if low >= high:
return
t0 = time.time()
        if num_thread > 1:
            with Pool(num_thread) as p:
                p.map(
                    partial(self._decode_one, session),
                    list(range(low, high + 1)),
                )
else:
for sent_id in range(low, high + 1):
self._decode_one(session, sent_id)
print(f"Finished {low} to {high} in {time.time() - t0}s")
def _decode_one(self, session, sent_id):
action = {}
self.reset()
states = self.init_states()
while action.get("value", None) != DEFAULT_EOS:
# take an action
action = self.policy(states)
if action["key"] == GET:
new_states = session.get_src(sent_id, action["value"])
states = self.update_states(states, new_states)
elif action["key"] == SEND:
session.send_hypo(sent_id, action["value"])
print(" ".join(states["tokens"]["tgt"]))
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from fairseq import checkpoint_utils, tasks, utils
from . import DEFAULT_EOS, GET, SEND
from .agent import Agent
class SimulTransAgent(Agent):
def __init__(self, args):
# Load Model
self.load_model(args)
# build word spliter
self.build_word_splitter(args)
self.max_len = args.max_len
self.eos = DEFAULT_EOS
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--model-path', type=str, required=True,
help='path to your pretrained model.')
parser.add_argument("--data-bin", type=str, required=True,
help="Path of data binary")
parser.add_argument("--user-dir", type=str, default="example/simultaneous_translation",
help="User directory for simultaneous translation")
parser.add_argument("--src-splitter-type", type=str, default=None,
help="Subword splitter type for source text")
parser.add_argument("--tgt-splitter-type", type=str, default=None,
help="Subword splitter type for target text")
parser.add_argument("--src-splitter-path", type=str, default=None,
help="Subword splitter model path for source text")
parser.add_argument("--tgt-splitter-path", type=str, default=None,
help="Subword splitter model path for target text")
parser.add_argument("--max-len", type=int, default=150,
help="Maximum length difference between source and target prediction")
parser.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='A dictionary used to override model args at generation '
'that were used during model training')
# fmt: on
return parser
def load_dictionary(self, task):
raise NotImplementedError
def load_model(self, args):
args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
utils.import_user_module(args)
filename = args.model_path
if not os.path.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = checkpoint_utils.load_checkpoint_to_cpu(
filename, json.loads(args.model_overrides)
)
saved_args = state["args"]
saved_args.data = args.data_bin
task = tasks.setup_task(saved_args)
        # Build a single model (no ensemble)
self.model = task.build_model(saved_args)
self.model.load_state_dict(state["model"], strict=True)
# Set dictionary
self.load_dictionary(task)
def init_states(self):
return {
"indices": {"src": [], "tgt": []},
"tokens": {"src": [], "tgt": []},
"segments": {"src": [], "tgt": []},
"steps": {"src": 0, "tgt": 0},
"finished": False,
"finish_read": False,
"model_states": {},
}
def update_states(self, states, new_state):
raise NotImplementedError
def policy(self, states):
# Read and Write policy
action = None
while action is None:
if states["finished"]:
# Finish the hypo by sending eos to server
return self.finish_action()
            # The model makes a decision given the current states
decision = self.model.decision_from_states(states)
if decision == 0 and not self.finish_read(states):
# READ
action = self.read_action(states)
else:
# WRITE
action = self.write_action(states)
            # None means we make the decision again without sending the
            # server anything. This happens when reading a buffered token
            # or when predicting a subword
return action
def finish_read(self, states):
raise NotImplementedError
def write_action(self, states):
token, index = self.model.predict_from_states(states)
if (
index == self.dict["tgt"].eos()
or len(states["tokens"]["tgt"]) > self.max_len
):
            # Finish this sentence if EOS is predicted
states["finished"] = True
end_idx_last_full_word = self._target_length(states)
else:
states["tokens"]["tgt"] += [token]
end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
states["tokens"]["tgt"]
)
self._append_indices(states, [index], "tgt")
if end_idx_last_full_word > states["steps"]["tgt"]:
            # Only send detokenized full words to the server
word = self.word_splitter["tgt"].merge(
states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
)
states["steps"]["tgt"] = end_idx_last_full_word
states["segments"]["tgt"] += [word]
return {"key": SEND, "value": word}
else:
return None
def read_action(self, states):
return {"key": GET, "value": None}
def finish_action(self):
return {"key": SEND, "value": DEFAULT_EOS}
def reset(self):
pass
def finish_eval(self, states, new_state):
if len(new_state) == 0 and len(states["indices"]["src"]) == 0:
return True
return False
def _append_indices(self, states, new_indices, key):
states["indices"][key] += new_indices
def _target_length(self, states):
return len(states["tokens"]["tgt"])
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import DEFAULT_EOS, GET, register_agent
from .simul_trans_agent import SimulTransAgent
from .word_splitter import SPLITTER_DICT
@register_agent("simul_trans_text")
class SimulTransTextAgent(SimulTransAgent):
def build_word_splitter(self, args):
self.word_splitter = {}
self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
getattr(args, f"src_splitter_path")
)
self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
getattr(args, f"tgt_splitter_path")
)
def load_dictionary(self, task):
self.dict = {}
self.dict["tgt"] = task.target_dictionary
self.dict["src"] = task.source_dictionary
def update_states(self, states, new_state):
if states["finish_read"]:
return states
new_word = new_state["segment"]
# Split words and index the token
        if new_word != DEFAULT_EOS:
tokens = self.word_splitter["src"].split(new_word)
            # Get indices from the dictionary
            # You can swap in your own dictionary here
indices = (
self.dict["src"]
.encode_line(
tokens,
line_tokenizer=lambda x: x,
add_if_not_exist=False,
append_eos=False,
)
.tolist()
)
else:
tokens = [new_word]
indices = [self.dict["src"].eos()]
states["finish_read"] = True
# Update states
states["segments"]["src"] += [new_word]
states["tokens"]["src"] += tokens
self._append_indices(states, indices, "src")
return states
def read_action(self, states):
# Increase source step by one
states["steps"]["src"] += 1
        # At least one word must be read
if len(states["tokens"]["src"]) == 0:
return {"key": GET, "value": None}
# Only request new word if there is no buffered tokens
if len(states["tokens"]["src"]) <= states["steps"]["src"]:
return {"key": GET, "value": None}
return None
def finish_read(self, states):
        # The first condition means all segments (full words) have been read
        # from the server; the second means all tokens (subwords) have been
        # consumed locally
return (
states["finish_read"]
and len(states["tokens"]["src"]) == states["steps"]["src"]
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
class SubwordSplitter(object):
def process_line(self, string):
raise NotImplementedError
def split(self, string):
raise NotImplementedError
class NoneWordSplitter(object):
def __init__(self, model):
pass
def split(self, string):
return [string]
def process_line(self, string):
return [string]
def finished_word(self, string):
return True
def merge(self, list_of_string):
return "".join(list_of_string)
def last_full_word_step(self, tokens, step):
return len(tokens)
def end_idx_last_full_word(self, tokens):
return len(tokens)
class BPEWordSplitter(object):
    # TODO: look back here
def __init__(self, model_path):
super().__init__()
from subword_nmt.apply_bpe import BPE
with open(model_path) as f:
self.model = BPE(f)
def split(self, string):
return self.model.process_line(string).split()
def end_idx_last_full_word(self, tokens):
        # Begin-of-word indices: position i + 1 starts a new word when the
        # previous token does not end with the "@@" continuation marker
        bow_indices = [0] + [
            i + 1 for i, t in enumerate(tokens[:-1]) if t[-2:] != "@@"
        ]
if len(bow_indices) < 2:
return 0
else:
return bow_indices[-1]
    def merge(self, list_of_string):
        # "@@ " joins a subword to its successor (subword-nmt convention)
        return " ".join(list_of_string).replace("@@ ", "")
class SentencePieceModelWordSplitter(object):
def __init__(self, model_path):
super().__init__()
import sentencepiece as spm
self.model = spm.SentencePieceProcessor()
self.model.Load(model_path)
def split(self, string):
return self.model.EncodeAsPieces(string)
def end_idx_last_full_word(self, tokens):
# Begin of word indices
bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]
if len(bow_indices) < 2:
return 0
else:
return bow_indices[-1]
def merge(self, list_of_string):
return self.model.DecodePieces(list_of_string)
SPLITTER_DICT = {
None: NoneWordSplitter,
"BPE": BPEWordSplitter,
"SentencePieceModel": SentencePieceModelWordSplitter,
}
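
# Illustrative usage (a sketch; the model path is a placeholder): round-trip
# a sentence through the SentencePiece splitter interface defined above.
#
#   splitter = SPLITTER_DICT["SentencePieceModel"]("/path/to/spm.model")
#   pieces = splitter.split("hello world")    # e.g. ['\u2581hello', '\u2581world']
#   splitter.end_idx_last_full_word(pieces)   # 1: only 'hello' is surely complete
#   splitter.merge(pieces[:1])                # 'hello'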
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import requests
from scorers import build_scorer
class SimulSTEvaluationService(object):
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
self.hostname = hostname
self.port = port
self.base_url = f"http://{self.hostname}:{self.port}"
    def __enter__(self):
        return self.new_session()
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def new_session(self):
# start eval session
url = f"{self.base_url}"
        try:
            _ = requests.post(url)
        except Exception as e:
            print(f"Failed to start an evaluation session: {e}")
        else:
            print("Evaluation session started.")
return self
def get_scores(self):
# end eval session
url = f"{self.base_url}/result"
try:
r = requests.get(url)
print("Scores: {}".format(r.json()))
print("Evaluation session finished.")
except Exception as e:
print(f"Failed to end an evaluation session: {e}")
def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
url = f"{self.base_url}/src"
params = {"sent_id": sent_id}
        if extra_params is not None:
            params.update(extra_params)
        try:
            r = requests.get(url, params=params)
        except Exception as e:
            print(f"Failed to request a source segment: {e}")
            raise
        return r.json()
def send_hypo(self, sent_id: int, hypo: str) -> None:
url = f"{self.base_url}/hypo"
params = {"sent_id": sent_id}
try:
requests.put(url, params=params, data=hypo.encode("utf-8"))
except Exception as e:
print(f"Failed to send a translated segment: {e}")
def corpus_info(self):
url = f"{self.base_url}"
        try:
            r = requests.get(url)
        except Exception as e:
            print(f"Failed to request corpus information: {e}")
            raise
        return r.json()
class SimulSTLocalEvaluationService(object):
def __init__(self, args):
self.scorer = build_scorer(args)
def get_scores(self):
return self.scorer.score()
def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
if extra_params is not None:
segment_size = extra_params.get("segment_size", None)
else:
segment_size = None
return self.scorer.send_src(int(sent_id), segment_size)
def send_hypo(self, sent_id: int, hypo: str) -> None:
list_of_tokens = hypo.strip().split()
self.scorer.recv_hyp(sent_id, list_of_tokens)
def corpus_info(self):
return self.scorer.get_info()
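
# Illustrative usage (a sketch, not part of the original file): drive one
# sentence through the remote service by hand; assumes a server from
# server.py is listening on the default host and port.
#
#   session = SimulSTEvaluationService()
#   session.new_session()                 # POST /
#   info = session.corpus_info()          # GET /  -> {"num_sentences": ...}
#   seg = session.get_src(0)              # GET /src?sent_id=0
#   session.send_hypo(0, "hola </s>")     # PUT /hypo?sent_id=0
#   session.get_scores()                  # GET /result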
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import torch
from examples.simultaneous_translation.utils.latency import LatencyInference
LATENCY_METRICS = [
"differentiable_average_lagging",
"average_lagging",
"average_proportion",
]
class LatencyScorer:
def __init__(self, start_from_zero=True):
self.recorder = []
self.scores = {}
self.scorer = LatencyInference()
self.start_from_zero = start_from_zero
    def update_recorder(self, list_of_dict):
self.recorder = []
for info in list_of_dict:
delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]
delays = torch.LongTensor(delays).unsqueeze(0)
src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
self.recorder.append(self.scorer(delays, src_len))
def cal_latency(self):
self.scores = {}
for metric in LATENCY_METRICS:
self.scores[metric] = sum(
[x[metric][0, 0].item() for x in self.recorder]
) / len(self.recorder)
return self.scores
@classmethod
def score(cls, list_of_dict, start_from_zero=True):
scorer_to_return = cls(start_from_zero)
        scorer_to_return.update_recorder(list_of_dict)
scorer_to_return.cal_latency()
return scorer_to_return.scores
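
# Illustrative usage (a sketch, toy numbers): score() takes one dict per
# sentence; with start_from_zero=False the 1-based delays are shifted down
# by one internally before being passed to LatencyInference.
#
#   scores = LatencyScorer.score(
#       [{"src_len": 5, "delays": [1, 3, 5, 5]}], start_from_zero=False
#   )
#   # scores: {"differentiable_average_lagging": ..., "average_lagging": ...,
#   #          "average_proportion": ...}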
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
parser.add_argument("--start-from-zero", action="store_true")
args = parser.parse_args()
scorer = LatencyInference()
recorder = []
with open(args.input, "r") as f:
for line in f:
info = json.loads(line)
delays = [int(x) - int(not args.start_from_zero) for x in info["delays"]]
delays = torch.LongTensor(delays).unsqueeze(0)
src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
recorder.append(scorer(delays, src_len))
average_results = {}
for metric in LATENCY_METRICS:
average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
recorder
)
print(f"{metric}: {average_results[metric]}")
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from agents import build_agent
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
from fairseq.registry import REGISTRIES
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
)
parser.add_argument(
"--port", type=int, default=DEFAULT_PORT, help="server port number"
)
parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
parser.add_argument("--scorer-type", default="text", help="Scorer type")
parser.add_argument(
"--start-idx",
type=int,
default=0,
help="Start index of the sentence to evaluate",
)
parser.add_argument(
"--end-idx",
type=int,
default=float("inf"),
help="End index of the sentence to evaluate",
)
parser.add_argument(
"--scores", action="store_true", help="Request scores from server"
)
parser.add_argument("--reset-server", action="store_true", help="Reset the server")
parser.add_argument(
"--num-threads", type=int, default=10, help="Number of threads used by agent"
)
parser.add_argument(
"--local", action="store_true", default=False, help="Local evaluation"
)
args, _ = parser.parse_known_args()
for registry_name, REGISTRY in REGISTRIES.items():
choice = getattr(args, registry_name, None)
if choice is not None:
cls = REGISTRY["registry"][choice]
if hasattr(cls, "add_args"):
cls.add_args(parser)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
if args.local:
session = SimulSTLocalEvaluationService(args)
else:
session = SimulSTEvaluationService(args.hostname, args.port)
if args.reset_server:
session.new_session()
if args.agent_type is not None:
agent = build_agent(args)
agent.decode(session, args.start_idx, args.end_idx, args.num_threads)
    if args.scores:
        scores = session.get_scores()
        if scores is not None:
            print(scores)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
"--scorer-type"
)
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("scorers." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from collections import defaultdict
from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
from vizseq.scorers.bleu import BLEUScorer
from vizseq.scorers.meteor import METEORScorer
from vizseq.scorers.ter import TERScorer
DEFAULT_EOS = "</s>"
class SimulScorer(object):
def __init__(self, args):
self.tokenizer = args.tokenizer
self.output_dir = args.output
if args.output is not None:
self.output_files = {
"text": os.path.join(args.output, "text"),
"delay": os.path.join(args.output, "delay"),
"scores": os.path.join(args.output, "scores"),
}
else:
self.output_files = None
self.eos = DEFAULT_EOS
self.data = {"tgt": []}
self.reset()
def get_info(self):
return {"num_sentences": len(self)}
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--src-file', type=str, required=True,
help='Source input file')
parser.add_argument('--tgt-file', type=str, required=True,
help='Target reference file')
parser.add_argument('--tokenizer', default="13a", choices=["none", "13a"],
help='Tokenizer used for sacrebleu')
parser.add_argument('--output', type=str, default=None,
help='Path for output directory')
# fmt: on
def send_src(self, sent_id, *args):
raise NotImplementedError
def recv_hyp(self, sent_id, list_of_tokens):
for token in list_of_tokens:
self.translations[sent_id].append((token, self.steps[sent_id]))
def reset(self):
self.steps = defaultdict(int)
self.translations = defaultdict(list)
def src_lengths(self):
raise NotImplementedError
def score(self):
translations = []
delays = []
for i in range(1 + max(self.translations.keys())):
translations += [" ".join(t[0] for t in self.translations[i][:-1])]
delays += [[t[1] for t in self.translations[i]]]
bleu_score = BLEUScorer(
sent_level=False,
corpus_level=True,
extra_args={"bleu_tokenizer": self.tokenizer},
).score(translations, [self.data["tgt"]])
ter_score = TERScorer(sent_level=False, corpus_level=True).score(
translations, [self.data["tgt"]]
)
meteor_score = METEORScorer(sent_level=False, corpus_level=True).score(
translations, [self.data["tgt"]]
)
latency_score = LatencyScorer().score(
[
{"src_len": src_len, "delays": delay}
for src_len, delay in zip(self.src_lengths(), delays)
],
start_from_zero=False,
)
scores = {
"BLEU": bleu_score[0],
"TER": ter_score[0],
"METEOR": meteor_score[0],
"DAL": latency_score["differentiable_average_lagging"],
"AL": latency_score["average_lagging"],
"AP": latency_score["average_proportion"],
}
if self.output_files is not None:
try:
os.makedirs(self.output_dir, exist_ok=True)
self.write_results_to_file(translations, delays, scores)
except BaseException as be:
print(f"Failed to write results to {self.output_dir}.")
print(be)
print("Skip writing predictions")
return scores
def write_results_to_file(self, translations, delays, scores):
if self.output_files["text"] is not None:
with open(self.output_files["text"], "w") as f:
for line in translations:
f.write(line + "\n")
if self.output_files["delay"] is not None:
with open(self.output_files["delay"], "w") as f:
for i, delay in enumerate(delays):
f.write(
json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
+ "\n"
)
with open(self.output_files["scores"], "w") as f:
for key, value in scores.items():
f.write(f"{key}, {value}\n")
@classmethod
def _load_text_file(cls, file, split=False):
with open(file) as f:
if split:
return [r.strip().split() for r in f]
else:
return [r.strip() for r in f]
@classmethod
def _load_text_from_json(cls, file):
list_to_return = []
with open(file) as f:
content = json.load(f)
for item in content["utts"].values():
list_to_return.append(item["output"]["text"].strip())
return list_to_return
@classmethod
def _load_wav_info_from_json(cls, file):
list_to_return = []
with open(file) as f:
content = json.load(f)
for item in content["utts"].values():
list_to_return.append(
{
"path": item["input"]["path"].strip(),
"length": item["input"]["length_ms"],
}
)
return list_to_return
@classmethod
def _load_wav_info_from_list(cls, file):
list_to_return = []
with open(file) as f:
for line in f:
list_to_return.append(
{
"path": line.strip(),
}
)
return list_to_return
def __len__(self):
return len(self.data["tgt"])
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import register_scorer
from .scorer import SimulScorer
@register_scorer("text")
class SimulTextScorer(SimulScorer):
def __init__(self, args):
super().__init__(args)
self.data = {
"src": self._load_text_file(args.src_file, split=True),
"tgt": self._load_text_file(args.tgt_file, split=False),
}
def send_src(self, sent_id, *args):
if self.steps[sent_id] >= len(self.data["src"][sent_id]):
dict_to_return = {
"sent_id": sent_id,
"segment_id": self.steps[sent_id],
"segment": self.eos,
}
            # Account for the EOS step
            self.steps[sent_id] = len(self.data["src"][sent_id]) + 1
else:
dict_to_return = {
"sent_id": sent_id,
"segment_id": self.steps[sent_id],
"segment": self.data["src"][sent_id][self.steps[sent_id]],
}
self.steps[sent_id] += 1
return dict_to_return
def src_lengths(self):
# +1 for eos
return [len(sent) + 1 for sent in self.data["src"]]
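
# Illustrative trace (a sketch): with self.data["src"][0] == ["a", "b"],
# successive send_src(0) calls return
#   {"sent_id": 0, "segment_id": 0, "segment": "a"}
#   {"sent_id": 0, "segment_id": 1, "segment": "b"}
#   {"sent_id": 0, "segment_id": 2, "segment": "</s>"}
# after which steps[0] == 3, matching the +1-for-EOS source length above.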
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import sys
from scorers import build_scorer
from tornado import ioloop, web
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 12321
class ScorerHandler(web.RequestHandler):
def initialize(self, scorer):
self.scorer = scorer
class EvalSessionHandler(ScorerHandler):
def post(self):
self.scorer.reset()
def get(self):
r = json.dumps(self.scorer.get_info())
self.write(r)
class ResultHandler(ScorerHandler):
def get(self):
r = json.dumps(self.scorer.score())
self.write(r)
class SourceHandler(ScorerHandler):
def get(self):
sent_id = int(self.get_argument("sent_id"))
segment_size = None
if "segment_size" in self.request.arguments:
string = self.get_argument("segment_size")
if len(string) > 0:
segment_size = int(string)
r = json.dumps(self.scorer.send_src(int(sent_id), segment_size))
self.write(r)
class HypothesisHandler(ScorerHandler):
def put(self):
sent_id = int(self.get_argument("sent_id"))
list_of_tokens = self.request.body.decode("utf-8").strip().split()
self.scorer.recv_hyp(sent_id, list_of_tokens)
def add_args():
parser = argparse.ArgumentParser()
# fmt: off
parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
help='Server hostname')
    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
                        help='Server port number')
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Run the tornado server in debug mode')
    args, _ = parser.parse_known_args()
# fmt: on
return args
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
app = web.Application(
[
(r"/result", ResultHandler, dict(scorer=scorer)),
(r"/src", SourceHandler, dict(scorer=scorer)),
(r"/hypo", HypothesisHandler, dict(scorer=scorer)),
(r"/", EvalSessionHandler, dict(scorer=scorer)),
],
debug=debug,
)
app.listen(port, max_buffer_size=1024 ** 3)
sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
ioloop.IOLoop.current().start()
if __name__ == "__main__":
args = add_args()
scorer = build_scorer(args)
start_server(scorer, args.hostname, args.port, args.debug)
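
# Illustrative requests (a sketch, not part of the original file): once the
# server is running, the four routes registered above can be exercised with
# curl:
#
#   curl -X POST "http://localhost:12321/"           # reset the session
#   curl "http://localhost:12321/"                   # corpus info
#   curl "http://localhost:12321/src?sent_id=0"      # next source segment
#   curl -X PUT --data "hello" "http://localhost:12321/hypo?sent_id=0"
#   curl "http://localhost:12321/result"             # final scores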
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.models." + model_name
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
TransformerMonotonicDecoderLayer,
TransformerMonotonicEncoderLayer,
)
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
TransformerModel,
base_architecture,
transformer_iwslt_de_en,
transformer_vaswani_wmt_en_de_big,
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)
def _indices_from_states(self, states):
        if isinstance(states["indices"]["src"], list):
if next(self.parameters()).is_cuda:
tensor = torch.cuda.LongTensor
else:
tensor = torch.LongTensor
src_indices = tensor(
[states["indices"]["src"][: 1 + states["steps"]["src"]]]
)
tgt_indices = tensor(
[[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
)
else:
src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
tgt_indices = states["indices"]["tgt"]
return src_indices, None, tgt_indices
def predict_from_states(self, states):
decoder_states = self.decoder.output_layer(states["decoder_features"])
lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)
index = lprobs.argmax(dim=-1)
token = self.decoder.dictionary.string(index)
return token, index[0, 0].item()
def decision_from_states(self, states):
"""
This funcion take states dictionary as input, and gives the agent
a decision of whether read a token from server. Moreover, the decoder
states are also calculated here so we can directly generate a target
token without recompute every thing
"""
self.eval()
if len(states["tokens"]["src"]) == 0:
return 0
src_indices, src_lengths, tgt_indices = self._indices_from_states(states)
# Update encoder states if needed
if (
"encoder_states" not in states
or states["encoder_states"][0].size(1) <= states["steps"]["src"]
):
encoder_out_dict = self.encoder(src_indices, src_lengths)
states["encoder_states"] = encoder_out_dict
else:
encoder_out_dict = states["encoder_states"]
# online means we still need tokens to feed the model
states["model_states"]["online"] = not (
states["finish_read"]
and len(states["tokens"]["src"]) == states["steps"]["src"]
)
states["model_states"]["steps"] = states["steps"]
x, outputs = self.decoder.forward(
prev_output_tokens=tgt_indices,
encoder_out=encoder_out_dict,
incremental_state=states["model_states"],
features_only=True,
)
states["decoder_features"] = x
return outputs["action"]
class TransformerMonotonicEncoder(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend(
[TransformerMonotonicEncoderLayer(args) for i in range(args.encoder_layers)]
)
class TransformerMonotonicDecoder(TransformerDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)
self.dictionary = dictionary
self.layers = nn.ModuleList([])
self.layers.extend(
[
TransformerMonotonicDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
def pre_attention(
self, prev_output_tokens, encoder_out_dict, incremental_state=None
):
positions = (
self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_out = encoder_out_dict.encoder_out
encoder_padding_mask = encoder_out_dict.encoder_padding_mask
return x, encoder_out, encoder_padding_mask
def post_attention(self, x):
if self.layer_norm:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x
def extract_features(
self, prev_output_tokens, encoder_out, incremental_state=None, **unused
):
"""
Similar to *forward* but only return features.
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
(x, encoder_outs, encoder_padding_mask) = self.pre_attention(
prev_output_tokens, encoder_out, incremental_state
)
attn = None
inner_states = [x]
attn_list = []
step_list = []
for i, layer in enumerate(self.layers):
x, attn, _ = layer(
x=x,
encoder_out=encoder_outs,
encoder_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
self_attn_mask=self.buffered_future_mask(x)
if incremental_state is None
else None,
)
inner_states.append(x)
attn_list.append(attn)
if incremental_state is not None:
curr_steps = layer.get_steps(incremental_state)
step_list.append(curr_steps)
if incremental_state.get("online", False):
p_choose = (
attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
)
new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)
if (new_steps >= incremental_state["steps"]["src"]).any():
                        # We need to prune the last self_attn saved_state
                        # if the model decides not to read; otherwise there
                        # will be duplicated saved_state entries
for j in range(i + 1):
self.layers[j].prune_incremental_state(incremental_state)
return x, {"action": 0}
if incremental_state is not None and not incremental_state.get("online", False):
            # This branch is only used for fast evaluation
fastest_step = (
torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
)
if "fastest_step" in incremental_state:
incremental_state["fastest_step"] = torch.cat(
[incremental_state["fastest_step"], fastest_step], dim=1
)
else:
incremental_state["fastest_step"] = fastest_step
x = self.post_attention(x)
return x, {
"action": 1,
"attn_list": attn_list,
"step_list": step_list,
"encoder_out": encoder_out,
"encoder_padding_mask": encoder_padding_mask,
}
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
if "fastest_step" in incremental_state:
incremental_state["fastest_step"] = incremental_state[
"fastest_step"
].index_select(0, new_order)
@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_architecture(args):
base_architecture(args)
args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
transformer_iwslt_de_en(args)
    base_monotonic_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture(
"transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    transformer_monotonic_vaswani_wmt_en_de_big(args)
@register_model_architecture(
"transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
transformer_iwslt_de_en(args)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
from fairseq import registry
(
build_monotonic_attention,
register_monotonic_attention,
MONOTONIC_ATTENTION_REGISTRY,
_,
) = registry.setup_registry("--simul-type")
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
model_name = file[: file.find(".py")]
importlib.import_module(
"examples.simultaneous_translation.modules." + model_name
)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from examples.simultaneous_translation.utils.functions import (
exclusive_cumprod,
lengths_to_mask,
)
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules import MultiheadAttention
from fairseq.utils import convert_padding_direction
from . import register_monotonic_attention
@with_incremental_state
class MonotonicAttention(nn.Module):
"""
Abstract class of monotonic attentions
"""
def __init__(self, args):
self.eps = args.attention_eps
self.mass_preservation = args.mass_preservation
self.noise_mean = args.noise_mean
self.noise_var = args.noise_var
self.energy_bias_init = args.energy_bias_init
self.energy_bias = (
nn.Parameter(self.energy_bias_init * torch.ones([1]))
if args.energy_bias is True
else 0
)
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--no-mass-preservation', action="store_false", dest="mass_preservation",
help='Do not stay on the last token when decoding')
parser.add_argument('--mass-preservation', action="store_true", dest="mass_preservation",
help='Stay on the last token when decoding')
parser.set_defaults(mass_preservation=True)
        parser.add_argument('--noise-var', type=float, default=1.0,
                            help='Variance of discreteness noise')
        parser.add_argument('--noise-mean', type=float, default=0.0,
                            help='Mean of discreteness noise')
parser.add_argument('--energy-bias', action="store_true", default=False,
help='Bias for energy')
parser.add_argument('--energy-bias-init', type=float, default=-2.0,
help='Initial value of the bias for energy')
parser.add_argument('--attention-eps', type=float, default=1e-6,
help='Epsilon when calculating expected attention')
# fmt: on
def p_choose(self, *args):
raise NotImplementedError
def input_projections(self, *args):
raise NotImplementedError
def attn_energy(self, q_proj, k_proj, key_padding_mask=None):
"""
Calculating monotonic energies
============================================================
Expected input size
q_proj: bsz * num_heads, tgt_len, self.head_dim
k_proj: bsz * num_heads, src_len, self.head_dim
key_padding_mask: bsz, src_len
attn_mask: tgt_len, src_len
"""
bsz, tgt_len, embed_dim = q_proj.size()
bsz = bsz // self.num_heads
src_len = k_proj.size(1)
attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias
attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)
if key_padding_mask is not None:
attn_energy = attn_energy.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
float("-inf"),
)
return attn_energy
def expected_alignment_train(self, p_choose, key_padding_mask):
"""
Calculating expected alignment for MMA
Mask is not need because p_choose will be 0 if masked
q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j}
a_ij = p_ij q_ij
parellel solution:
ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi))
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
"""
# p_choose: bsz * num_heads, tgt_len, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# cumprod_1mp : bsz * num_heads, tgt_len, src_len
cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)
init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
init_attention[:, :, 0] = 1.0
previous_attn = [init_attention]
for i in range(tgt_len):
# p_choose: bsz * num_heads, tgt_len, src_len
# cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
# previous_attn[i]: bsz * num_heads, 1, src_len
# alpha_i: bsz * num_heads, src_len
alpha_i = (
p_choose[:, i]
* cumprod_1mp[:, i]
* torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
).clamp(0, 1.0)
previous_attn.append(alpha_i.unsqueeze(1))
# alpha: bsz * num_heads, tgt_len, src_len
alpha = torch.cat(previous_attn[1:], dim=1)
if self.mass_preservation:
# Last token has the residual probabilities
alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)
assert not torch.isnan(alpha).any(), "NaN detected in alpha."
return alpha
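    # Worked example for the recursion above (a sketch, toy numbers): for a
    # single target step with p_choose = [0.5, 0.5, 1.0] and previous
    # attention [1, 0, 0], exclusive_cumprod(1 - p) = [1.0, 0.5, 0.25], so
    # alpha = p * cumprod * cumsum([1, 0, 0] / cumprod) = [0.5, 0.25, 0.25],
    # i.e. the probability of halting at each source step; rows sum to 1.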
def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state):
"""
Calculating mo alignment for MMA during inference time
============================================================
Expected input size
p_choose: bsz * num_heads, tgt_len, src_len
key_padding_mask: bsz * src_len
incremental_state: dict
"""
# p_choose: bsz * self.num_heads, src_len
bsz_num_heads, tgt_len, src_len = p_choose.size()
# One token at a time
assert tgt_len == 1
p_choose = p_choose[:, 0, :]
monotonic_cache = self._get_monotonic_buffer(incremental_state)
# prev_monotonic_step: bsz, num_heads
bsz = bsz_num_heads // self.num_heads
prev_monotonic_step = monotonic_cache.get(
"step", p_choose.new_zeros([bsz, self.num_heads]).long()
)
bsz, num_heads = prev_monotonic_step.size()
assert num_heads == self.num_heads
assert bsz * num_heads == bsz_num_heads
# p_choose: bsz, num_heads, src_len
p_choose = p_choose.view(bsz, num_heads, src_len)
if key_padding_mask is not None:
src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
else:
src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len
# src_lengths: bsz, num_heads
src_lengths = src_lengths.expand_as(prev_monotonic_step)
# new_monotonic_step: bsz, num_heads
new_monotonic_step = prev_monotonic_step
step_offset = 0
if key_padding_mask is not None:
if key_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
max_steps = src_lengths - 1 if self.mass_preservation else src_lengths
# finish_read: bsz, num_heads
finish_read = new_monotonic_step.eq(max_steps)
while finish_read.sum().item() < bsz * self.num_heads:
# p_choose: bsz * self.num_heads, src_len
# only choose the p at monotonic steps
# p_choose_i: bsz , self.num_heads
p_choose_i = (
p_choose.gather(
2,
(step_offset + new_monotonic_step)
.unsqueeze(2)
.clamp(0, src_len - 1),
)
).squeeze(2)
action = (
(p_choose_i < 0.5)
.type_as(prev_monotonic_step)
.masked_fill(finish_read, 0)
)
            # action: bsz, num_heads
            # 1 means continue reading (advance one source step)
            # 0 means stop reading and attend at the current step
            # Sampling actions on unfinished sequences was an alternative:
# dist = torch.distributions.bernoulli.Bernoulli(p_choose)
# action = dist.sample().type_as(finish_read) * (1 - finish_read)
new_monotonic_step += action
finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
# finish_read = (~ (finish_read.sum(dim=1, keepdim=True) < self.num_heads / 2)) | finish_read
monotonic_cache["step"] = new_monotonic_step
# alpha: bsz * num_heads, 1, src_len
# new_monotonic_step: bsz, num_heads
alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
1,
(step_offset + new_monotonic_step)
.view(bsz * self.num_heads, 1)
.clamp(0, src_len - 1),
1,
)
if not self.mass_preservation:
alpha = alpha.masked_fill(
(new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
)
alpha = alpha.unsqueeze(1)
self._set_monotonic_buffer(incremental_state, monotonic_cache)
return alpha
def v_proj_output(self, value):
raise NotImplementedError
def forward(
self,
query,
key,
value,
key_padding_mask=None,
incremental_state=None,
*args,
**kwargs,
):
tgt_len, bsz, embed_dim = query.size()
src_len = value.size(0)
# stepwise prob
# p_choose: bsz * self.num_heads, tgt_len, src_len
p_choose = self.p_choose(query, key, key_padding_mask)
# expected alignment alpha
# bsz * self.num_heads, tgt_len, src_len
if incremental_state is not None:
alpha = self.expected_alignment_infer(
p_choose, key_padding_mask, incremental_state
)
else:
alpha = self.expected_alignment_train(p_choose, key_padding_mask)
# expected attention beta
# bsz * self.num_heads, tgt_len, src_len
beta = self.expected_attention(
alpha, query, key, value, key_padding_mask, incremental_state
)
attn_weights = beta
v_proj = self.v_proj_output(value)
attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
return attn, {"alpha": alpha, "beta": beta, "p_choose": p_choose}
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
super().reorder_incremental_state(incremental_state, new_order)
input_buffer = self._get_monotonic_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer[k] = input_buffer[k].index_select(0, new_order)
self._set_monotonic_buffer(incremental_state, input_buffer)
def _get_monotonic_buffer(self, incremental_state):
return (
utils.get_incremental_state(
self,
incremental_state,
"monotonic",
)
or {}
)
def _set_monotonic_buffer(self, incremental_state, buffer):
utils.set_incremental_state(
self,
incremental_state,
"monotonic",
buffer,
)
def get_pointer(self, incremental_state):
return (
utils.get_incremental_state(
self,
incremental_state,
"monotonic",
)
or {}
)
def get_fastest_pointer(self, incremental_state):
return self.get_pointer(incremental_state)["step"].max(0)[0]
def set_pointer(self, incremental_state, p_choose):
curr_pointer = self.get_pointer(incremental_state)
if len(curr_pointer) == 0:
buffer = torch.zeros_like(p_choose)
else:
buffer = self.get_pointer(incremental_state)["step"]
buffer += (p_choose < 0.5).type_as(buffer)
utils.set_incremental_state(
self,
incremental_state,
"monotonic",
{"step": buffer},
)
@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
def __init__(self, args):
MultiheadAttention.__init__(
self,
embed_dim=args.decoder_embed_dim,
num_heads=args.decoder_attention_heads,
kdim=getattr(args, "encoder_embed_dim", None),
vdim=getattr(args, "encoder_embed_dim", None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
MonotonicAttention.__init__(self, args)
self.k_in_proj = {"monotonic": self.k_proj}
self.q_in_proj = {"monotonic": self.q_proj}
self.v_in_proj = {"output": self.v_proj}
def input_projections(self, query, key, value, name):
"""
Prepare inputs for multihead attention
============================================================
Expected input size
query: tgt_len, bsz, embed_dim
key: src_len, bsz, embed_dim
value: src_len, bsz, embed_dim
name: monotonic or soft
"""
if query is not None:
bsz = query.size(1)
q = self.q_in_proj[name](query)
q *= self.scaling
q = (
q.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
q = None
if key is not None:
bsz = key.size(1)
k = self.k_in_proj[name](key)
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
k = None
if value is not None:
bsz = value.size(1)
v = self.v_in_proj[name](value)
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
else:
v = None
return q, k, v
def p_choose(self, query, key, key_padding_mask=None):
"""
Calculating step wise prob for reading and writing
1 to read, 0 to write
============================================================
Expected input size
query: bsz, tgt_len, embed_dim
key: bsz, src_len, embed_dim
value: bsz, src_len, embed_dim
key_padding_mask: bsz, src_len
attn_mask: bsz, src_len
query: bsz, tgt_len, embed_dim
"""
# prepare inputs
q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic")
# attention energy
attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)
noise = 0
if self.training:
            # add noise here to encourage discreteness
noise = (
torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
.type_as(attn_energy)
.to(attn_energy.device)
)
p_choose = torch.sigmoid(attn_energy + noise)
_, _, tgt_len, src_len = p_choose.size()
# p_choose: bsz * self.num_heads, tgt_len, src_len
return p_choose.view(-1, tgt_len, src_len)
def expected_attention(self, alpha, *args):
"""
For MMA-H, beta = alpha
"""
return alpha
def v_proj_output(self, value):
_, _, v_proj = self.input_projections(None, None, value, "output")
return v_proj
@register_monotonic_attention("infinite_lookback")
class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHard):
def __init__(self, args):
super().__init__(args)
self.init_soft_attention()
def init_soft_attention(self):
self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.k_in_proj["soft"] = self.k_proj_soft
self.q_in_proj["soft"] = self.q_proj_soft
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(
self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
nn.init.xavier_uniform_(
self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
)
else:
nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)
def expected_attention(
self, alpha, query, key, value, key_padding_mask, incremental_state
):
# monotonic attention, we will calculate milk here
bsz_x_num_heads, tgt_len, src_len = alpha.size()
bsz = int(bsz_x_num_heads / self.num_heads)
q, k, _ = self.input_projections(query, key, None, "soft")
soft_energy = self.attn_energy(q, k, key_padding_mask)
assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len]
soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len)
if incremental_state is not None:
monotonic_cache = self._get_monotonic_buffer(incremental_state)
monotonic_step = monotonic_cache["step"] + 1
step_offset = 0
if key_padding_mask is not None:
if key_padding_mask[:, 0].any():
# left_pad_source = True:
step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
monotonic_step += step_offset
mask = lengths_to_mask(
monotonic_step.view(-1), soft_energy.size(2), 1
).unsqueeze(1)
soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2)
else:
# bsz * num_heads, tgt_len, src_len
soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
exp_soft_energy = torch.exp(soft_energy)
exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2)
if key_padding_mask is not None:
if key_padding_mask.any():
exp_soft_energy_cumsum = (
exp_soft_energy_cumsum.view(
-1, self.num_heads, tgt_len, src_len
)
.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
)
.view(-1, tgt_len, src_len)
)
inner_items = alpha / exp_soft_energy_cumsum
beta = exp_soft_energy * torch.cumsum(
inner_items.flip(dims=[2]), dim=2
).flip(dims=[2])
beta = self.dropout_module(beta)
assert not torch.isnan(beta).any(), "NaN detected in beta."
return beta
@register_monotonic_attention("waitk")
class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookback):
def __init__(self, args):
super().__init__(args)
self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
self.waitk_lagging = args.waitk_lagging
        assert (
            self.waitk_lagging > 0
        ), f"Lagging has to be larger than 0, got {self.waitk_lagging}."
@staticmethod
def add_args(parser):
super(
MonotonicMultiheadAttentionWaitk,
MonotonicMultiheadAttentionWaitk,
).add_args(parser)
parser.add_argument(
"--waitk-lagging", type=int, required=True, help="Wait k lagging"
)
def p_choose(
self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
):
"""
query: bsz, tgt_len
key: bsz, src_len
key_padding_mask: bsz, src_len
"""
src_len, bsz, _ = key.size()
tgt_len, bsz, _ = query.size()
p_choose = query.new_ones(bsz, tgt_len, src_len)
p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1)
p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1)
if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
            # Left-padded source: mark padded positions with -1, then convert
            # to right padding below and drop the -1 entries
p_choose = p_choose.masked_fill(
key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
)
p_choose = convert_padding_direction(
p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
)
p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
# remove -1
p_choose[p_choose.eq(-1)] = 0
# Extend to each head
p_choose = (
p_choose.contiguous()
.unsqueeze(1)
.expand(-1, self.num_heads, -1, -1)
.contiguous()
.view(-1, tgt_len, src_len)
)
return p_choose
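    # Illustrative values (a sketch): with tgt_len = src_len = 4 and
    # waitk_lagging = 2, the tril/triu pair above keeps a single shifted
    # diagonal, so target step i attends once source step i + k - 1 arrives:
    #   [[0, 1, 0, 0],
    #    [0, 0, 1, 0],
    #    [0, 0, 0, 1],
    #    [0, 0, 0, 0]]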
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.modules import LayerNorm, TransformerDecoderLayer, TransformerEncoderLayer
from . import build_monotonic_attention
class TransformerMonotonicEncoderLayer(TransformerEncoderLayer):
def forward(self, x, encoder_padding_mask):
seq_len, _, _ = x.size()
attn_mask = x.new_ones([seq_len, seq_len]).triu(1)
attn_mask = attn_mask.masked_fill(attn_mask.bool(), float("-inf"))
return super().forward(x, encoder_padding_mask, attn_mask)
class TransformerMonotonicDecoderLayer(TransformerDecoderLayer):
def __init__(
self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
):
super().__init__(
args,
no_encoder_attn=True,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
self.encoder_attn = build_monotonic_attention(args)
self.encoder_attn_layer_norm = LayerNorm(
self.embed_dim, export=getattr(args, "char_inputs", False)
)
def prune_incremental_state(self, incremental_state):
def prune(module):
input_buffer = module._get_input_buffer(incremental_state)
for key in ["prev_key", "prev_value"]:
if input_buffer[key].size(2) > 1:
input_buffer[key] = input_buffer[key][:, :, :-1, :]
else:
input_buffer = {}
break
module._set_input_buffer(incremental_state, input_buffer)
prune(self.self_attn)
def get_steps(self, incremental_state):
return self.encoder_attn._get_monotonic_buffer(incremental_state).get("step", 0)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import os
# automatically import any Python files in the criterions/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
module = file[: file.find(".py")]
importlib.import_module("examples.simultaneous_translation.utils." + module)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
"""
Implementing exclusive cumprod.
There is cumprod in pytorch, however there is no exclusive mode.
cumprod(x) = [x1, x1x2, x2x3x4, ..., prod_{i=1}^n x_i]
exclusive means cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i]
"""
tensor_size = list(tensor.size())
tensor_size[dim] = 1
return_tensor = safe_cumprod(
torch.cat([torch.ones(tensor_size).type_as(tensor), tensor], dim=dim),
dim=dim,
eps=eps,
)
if dim == 0:
return return_tensor[:-1]
elif dim == 1:
return return_tensor[:, :-1]
elif dim == 2:
return return_tensor[:, :, :-1]
    else:
        raise RuntimeError("Cumprod on dimension 3 or higher is not implemented")
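
# Illustrative check (a sketch): exclusive_cumprod shifts the running product
# right and seeds it with 1; values are approximate due to the eps term in
# safe_cumprod.
#
#   >>> exclusive_cumprod(torch.tensor([[0.5, 0.5, 0.5]]), dim=1)
#   tensor([[1.0000, 0.5000, 0.2500]])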
def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
"""
An implementation of cumprod to prevent precision issue.
cumprod(x)
= [x1, x1x2, x1x2x3, ....]
= [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...]
= exp(cumsum(log(x)))
"""
    if (tensor + eps < 0).any().item():
        raise RuntimeError(
            "Safe cumprod can only take non-negative tensors as input. "
            "Consider using torch.cumprod if you want to calculate negative values."
        )
log_tensor = torch.log(tensor + eps)
cumsum_log_tensor = torch.cumsum(log_tensor, dim)
exp_cumsum_log_tensor = torch.exp(cumsum_log_tensor)
return exp_cumsum_log_tensor
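
# Illustrative check (a sketch): safe_cumprod agrees with torch.cumprod up to
# eps on non-negative inputs, while the log-space form avoids underflow for
# long products.
#
#   >>> safe_cumprod(torch.tensor([0.1, 0.2, 0.3]), dim=0)
#   tensor([0.1000, 0.0200, 0.0060])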
def lengths_to_mask(lengths, max_len: int, dim: int = 0, negative_mask: bool = False):
"""
Convert a tensor of lengths to mask
For example, lengths = [[2, 3, 4]], max_len = 5
mask =
[[1, 1, 1],
[1, 1, 1],
[0, 1, 1],
[0, 0, 1],
[0, 0, 0]]
"""
    assert len(lengths.size()) <= 2
    if len(lengths.size()) == 2:
        if dim == 1:
            lengths = lengths.t()
    else:
        lengths = lengths.unsqueeze(1)
# lengths : batch_size, 1
lengths = lengths.view(-1, 1)
batch_size = lengths.size(0)
# batch_size, max_len
mask = torch.arange(max_len).expand(batch_size, max_len).type_as(lengths) < lengths
if negative_mask:
mask = ~mask
if dim == 0:
# max_len, batch_size
mask = mask.t()
return mask
def moving_sum(x, start_idx: int, end_idx: int):
"""
From MONOTONIC CHUNKWISE ATTENTION
https://arxiv.org/pdf/1712.05382.pdf
Equation (18)
x = [x_1, x_2, ..., x_N]
MovingSum(x, start_idx, end_idx)_n = Sigma_{m=n−(start_idx−1)}^{n+end_idx-1} x_m
for n in {1, 2, 3, ..., N}
x : src_len, batch_size
start_idx : start idx
end_idx : end idx
Example
src_len = 5
batch_size = 3
x =
[[ 0, 5, 10],
[ 1, 6, 11],
[ 2, 7, 12],
[ 3, 8, 13],
[ 4, 9, 14]]
MovingSum(x, 3, 1) =
[[ 0, 5, 10],
[ 1, 11, 21],
[ 3, 18, 33],
[ 6, 21, 36],
[ 9, 24, 39]]
MovingSum(x, 1, 3) =
[[ 3, 18, 33],
[ 6, 21, 36],
[ 9, 24, 39],
[ 7, 17, 27],
[ 4, 9, 14]]
"""
assert start_idx > 0 and end_idx > 0
assert len(x.size()) == 2
src_len, batch_size = x.size()
# batch_size, 1, src_len
x = x.t().unsqueeze(1)
# batch_size, 1, src_len
moving_sum_weight = x.new_ones([1, 1, end_idx + start_idx - 1])
moving_sum = (
torch.nn.functional.conv1d(
x, moving_sum_weight, padding=start_idx + end_idx - 1
)
.squeeze(1)
.t()
)
moving_sum = moving_sum[end_idx:-start_idx]
assert src_len == moving_sum.size(0)
assert batch_size == moving_sum.size(1)
return moving_sum
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
class LatencyMetric(object):
@staticmethod
def length_from_padding_mask(padding_mask, batch_first: bool = False):
dim = 1 if batch_first else 0
return padding_mask.size(dim) - padding_mask.sum(dim=dim, keepdim=True)
def prepare_latency_metric(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True,
):
assert len(delays.size()) == 2
assert len(src_lens.size()) == 2
if start_from_zero:
delays = delays + 1
if batch_first:
# convert to batch_last
delays = delays.t()
src_lens = src_lens.t()
tgt_len, bsz = delays.size()
_, bsz_1 = src_lens.size()
if target_padding_mask is not None:
target_padding_mask = target_padding_mask.t()
tgt_len_1, bsz_2 = target_padding_mask.size()
assert tgt_len == tgt_len_1
assert bsz == bsz_2
assert bsz == bsz_1
if target_padding_mask is None:
tgt_lens = tgt_len * delays.new_ones([1, bsz]).float()
else:
# 1, batch_size
tgt_lens = self.length_from_padding_mask(target_padding_mask, False).float()
delays = delays.masked_fill(target_padding_mask, 0)
return delays, src_lens, tgt_lens, target_padding_mask
def __call__(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = False,
start_from_zero: bool = True,
):
delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
delays, src_lens, target_padding_mask, batch_first, start_from_zero
)
return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
"""
Expected sizes:
delays: tgt_len, batch_size
src_lens: 1, batch_size
tgt_lens: 1, batch_size
target_padding_mask: tgt_len, batch_size
"""
raise NotImplementedError
class AverageProportion(LatencyMetric):
"""
Function to calculate Average Proportion from
Can neural machine translation do simultaneous translation?
(https://arxiv.org/abs/1606.02012)
Delays are monotonic steps, ranging from 1 to src_len.
Given source x and target y, AP is calculated as:
AP = 1 / (|x||y|) sum_{i=1}^{|y|} delays_i
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
if target_padding_mask is not None:
AP = torch.sum(
delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
)
else:
AP = torch.sum(delays, dim=0, keepdim=True)
AP = AP / (src_lens * tgt_lens)
return AP
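# Worked example for AverageProportion.cal_metric (illustrative): with 1-based
# delays = [[1.], [2.], [3.]] (tgt_len=3, bsz=1), src_lens = [[3.]],
# tgt_lens = [[3.]] and no padding, AP = (1 + 2 + 3) / (3 * 3) = 0.6667.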
class AverageLagging(LatencyMetric):
"""
Function to calculate Average Lagging from
STACL: Simultaneous Translation with Implicit Anticipation
and Controllable Latency using Prefix-to-Prefix Framework
(https://arxiv.org/abs/1810.08398)
Delays are monotonic steps, ranging from 1 to src_len.
Given source x and target y, AL is calculated as:
AL = 1 / tau sum_{i=1}^{tau} (delays_i - (i - 1) / gamma)
Where
gamma = |y| / |x|
tau = argmin_i(delays_i = |x|)
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
# tau = argmin_i(delays_i = |x|)
tgt_len, bsz = delays.size()
lagging_padding_mask = delays >= src_lens
lagging_padding_mask = torch.nn.functional.pad(
lagging_padding_mask.t(), (1, 0)
).t()[:-1, :]
gamma = tgt_lens / src_lens
lagging = (
delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
lagging.masked_fill_(lagging_padding_mask, 0)
tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
AL = lagging.sum(dim=0, keepdim=True) / tau
return AL
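# Worked example for AverageLagging.cal_metric (illustrative): with
# delays = [[1.], [3.], [3.]] (tgt_len=3, bsz=1), src_lens = [[3.]] and
# tgt_lens = [[3.]]: gamma = 1 and the source is exhausted from step 2 on,
# so tau = 2 and AL = ((1 - 0) + (3 - 1)) / 2 = 1.5.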
class DifferentiableAverageLagging(LatencyMetric):
"""
Function to calculate Differentiable Average Lagging from
Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
(https://arxiv.org/abs/1906.05218)
Delays are monotonic steps, ranging from 0 to src_len - 1.
(In the original paper they range from 1 to src_len.)
Given source x and target y, DAL is calculated as:
DAL = 1 / |y| sum_{i=1}^{|y|} (delays'_i - (i - 1) / gamma)
Where
delays'_i =
1. delays_i if i == 1
2. max(delays_i, delays'_{i-1} + 1 / gamma) otherwise
"""
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
tgt_len, bsz = delays.size()
gamma = tgt_lens / src_lens
new_delays = torch.zeros_like(delays)
for i in range(delays.size(0)):
if i == 0:
new_delays[i] = delays[i]
else:
new_delays[i] = torch.cat(
[
new_delays[i - 1].unsqueeze(0) + 1 / gamma,
delays[i].unsqueeze(0),
],
dim=0,
).max(dim=0)[0]
DAL = (
new_delays
- torch.arange(delays.size(0))
.unsqueeze(1)
.type_as(delays)
.expand_as(delays)
/ gamma
)
if target_padding_mask is not None:
DAL = DAL.masked_fill(target_padding_mask, 0)
DAL = DAL.sum(dim=0, keepdim=True) / tgt_lens
return DAL
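# Worked example for DifferentiableAverageLagging.cal_metric (illustrative):
# with delays = [[1.], [1.], [3.]] (tgt_len=3, bsz=1), src_lens = [[3.]] and
# tgt_lens = [[3.]]: gamma = 1, delays' = [1, 2, 3], so
# DAL = ((1 - 0) + (2 - 1) + (3 - 2)) / 3 = 1.0.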
class LatencyMetricVariance(LatencyMetric):
def prepare_latency_metric(
self,
delays,
src_lens,
target_padding_mask=None,
batch_first: bool = True,
start_from_zero: bool = True,
):
assert batch_first
assert len(delays.size()) == 3
assert len(src_lens.size()) == 2
if start_from_zero:
delays = delays + 1
# convert to batch_last
bsz, num_heads_x_layers, tgt_len = delays.size()
bsz_1, _ = src_lens.size()
assert bsz == bsz_1
if target_padding_mask is not None:
bsz_2, tgt_len_1 = target_padding_mask.size()
assert tgt_len == tgt_len_1
assert bsz == bsz_2
if target_padding_mask is None:
tgt_lens = tgt_len * delays.new_ones([bsz, tgt_len]).float()
else:
# batch_size, 1
tgt_lens = self.length_from_padding_mask(target_padding_mask, True).float()
delays = delays.masked_fill(target_padding_mask.unsqueeze(1), 0)
return delays, src_lens, tgt_lens, target_padding_mask
class VarianceDelay(LatencyMetricVariance):
@staticmethod
def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
"""
delays : bsz, num_heads_x_layers, tgt_len
src_lens : bsz, 1
target_lens : bsz, 1
target_padding_mask: bsz, tgt_len or None
"""
if delays.size(1) == 1:
return delays.new_zeros([1])
variance_delays = delays.var(dim=1)
if target_padding_mask is not None:
variance_delays.masked_fill_(target_padding_mask, 0)
return variance_delays.sum(dim=1, keepdim=True) / tgt_lens
class LatencyInference(object):
def __init__(self, start_from_zero=True):
self.metric_calculator = {
"differentiable_average_lagging": DifferentiableAverageLagging(),
"average_lagging": AverageLagging(),
"average_proportion": AverageProportion(),
}
self.start_from_zero = start_from_zero
def __call__(self, monotonic_step, src_lens):
"""
monotonic_step ranges from 0 to src_len; a value of src_len indicates eos
delays: bsz, tgt_len
src_lens: bsz, 1
"""
if not self.start_from_zero:
monotonic_step -= 1
delays = monotonic_step.view(
monotonic_step.size(0), -1, monotonic_step.size(-1)
).max(dim=1)[0]
delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
delays
).masked_fill(delays < src_lens, 0)
return_dict = {}
for key, func in self.metric_calculator.items():
return_dict[key] = func(
delays.float(),
src_lens.float(),
target_padding_mask=None,
batch_first=True,
start_from_zero=True,
).t()
return return_dict
class LatencyTraining(object):
def __init__(
self,
avg_weight,
var_weight,
avg_type,
var_type,
stay_on_last_token,
average_method,
):
self.avg_weight = avg_weight
self.var_weight = var_weight
self.avg_type = avg_type
self.var_type = var_type
self.stay_on_last_token = stay_on_last_token
self.average_method = average_method
self.metric_calculator = {
"differentiable_average_lagging": DifferentiableAverageLagging(),
"average_lagging": AverageLagging(),
"average_proportion": AverageProportion(),
}
self.variance_calculator = {
"variance_delay": VarianceDelay(),
}
def expected_delays_from_attention(
self, attention, source_padding_mask=None, target_padding_mask=None
):
if isinstance(attention, list):
# bsz, num_heads, tgt_len, src_len
bsz, num_heads, tgt_len, src_len = attention[0].size()
attention = torch.cat(attention, dim=1)
bsz, num_heads_x_layers, tgt_len, src_len = attention.size()
# bsz * num_heads * num_layers, tgt_len, src_len
attention = attention.view(-1, tgt_len, src_len)
else:
# bsz * num_heads * num_layers, tgt_len, src_len
bsz, tgt_len, src_len = attention.size()
num_heads_x_layers = 1
attention = attention.view(-1, tgt_len, src_len)
if not self.stay_on_last_token:
residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)
# bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
steps = (
torch.arange(1, 1 + src_len)
.unsqueeze(0)
.unsqueeze(1)
.expand_as(attention)
.type_as(attention)
)
if source_padding_mask is not None:
src_offset = (
source_padding_mask.type_as(attention)
.sum(dim=1, keepdim=True)
.expand(bsz, num_heads_x_layers)
.contiguous()
.view(-1, 1)
)
src_lens = src_len - src_offset
if source_padding_mask[:, 0].any():
# Pad left
src_offset = src_offset.view(-1, 1, 1)
steps = steps - src_offset
steps = steps.masked_fill(steps <= 0, 0)
else:
src_lens = attention.new_ones([bsz, num_heads_x_layers]) * src_len
src_lens = src_lens.view(-1, 1)
# bsz, num_heads_x_layers, tgt_len
expected_delays = (
(steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
)
if target_padding_mask is not None:
expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)
return expected_delays, src_lens
def avg_loss(self, expected_delays, src_lens, target_padding_mask):
bsz, num_heads_x_layers, tgt_len = expected_delays.size()
if target_padding_mask is not None:
target_padding_mask = (
target_padding_mask.unsqueeze(1)
.expand_as(expected_delays)
.contiguous()
.view(-1, tgt_len)
)
if self.average_method == "average":
# bsz, tgt_len
expected_delays = expected_delays.mean(dim=1)
elif self.average_method == "weighted_average":
weights = torch.nn.functional.softmax(expected_delays, dim=1)
expected_delays = torch.sum(expected_delays * weights, dim=1)
elif self.average_method == "max":
# bsz, tgt_len
expected_delays = expected_delays.max(dim=1)[0]
else:
raise RuntimeError(f"{self.average_method} is not supported")
src_lens = src_lens.view(bsz, -1)[:, :1]
if target_padding_mask is not None:
target_padding_mask = target_padding_mask.view(bsz, -1, tgt_len)[:, 0]
if self.avg_weight > 0.0:
if self.avg_type in self.metric_calculator:
average_delays = self.metric_calculator[self.avg_type](
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.avg_type} is not supported.")
# bsz * num_heads_x_num_layers, 1
return self.avg_weight * average_delays.sum()
else:
return 0.0
def var_loss(self, expected_delays, src_lens, target_padding_mask):
src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
:, :1
]
if self.var_weight > 0.0:
if self.var_type in self.variance_calculator:
variance_delays = self.variance_calculator[self.var_type](
expected_delays,
src_lens,
target_padding_mask,
batch_first=True,
start_from_zero=False,
)
else:
raise RuntimeError(f"{self.var_type} is not supported.")
return self.var_weight * variance_delays.sum()
else:
return 0.0
def loss(self, attention, source_padding_mask=None, target_padding_mask=None):
expected_delays, src_lens = self.expected_delays_from_attention(
attention, source_padding_mask, target_padding_mask
)
latency_loss = 0
latency_loss += self.avg_loss(expected_delays, src_lens, target_padding_mask)
latency_loss += self.var_loss(expected_delays, src_lens, target_padding_mask)
return latency_loss
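# A minimal usage sketch (illustrative only; the constructor arguments below
# are hypothetical choices, not defaults defined in this file):
#   latency_trainer = LatencyTraining(
#       avg_weight=1.0,
#       var_weight=1.0,
#       avg_type="differentiable_average_lagging",
#       var_type="variance_delay",
#       stay_on_last_token=True,
#       average_method="average",
#   )
#   attention = torch.softmax(torch.rand(2, 4, 6), dim=2)  # bsz, tgt_len, src_len
#   tgt_mask = torch.zeros(2, 4, dtype=torch.bool)  # no padded target positions
#   latency_loss = latency_trainer.loss(attention, target_padding_mask=tgt_mask)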
# Speech Recognition
`examples/speech_recognition` implements an ASR task in fairseq, along with the features, datasets, models and loss functions needed to train and infer the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
## Additional dependencies
On top of the main fairseq dependencies, there are a couple of additional requirements.
1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from the sctk package [here](http://www.openslr.org/4/). Training and inference do not require the Sclite dependency.
3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create datasets with word-piece targets.
## Preparing librispeech data
```
./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
```
## Training librispeech data
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
```
## Inference for librispeech
`$SET` can be `test_clean` or `test_other`.
Any checkpoint in `$MODEL_PATH` can be selected. In this example we use `checkpoint_last.pt`.
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
```
## Scoring with Sclite
```
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
```
The `Sum/Avg` row of the first table in the report contains the WER.
## Using wav2letter components
[wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes:
* AutoSegmentationCriterion (ASG)
* wav2letter-style Conv/GLU model
* wav2letter's beam search decoder
To use these, follow the instructions on [this page](https://github.com/facebookresearch/wav2letter/tree/master/bindings/python) to install python bindings. Please note that python bindings are for a *subset* of wav2letter and don't require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required).
To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps:
```
# additional prerequisites - use equivalents for your distro
sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev
# install KenLM from source
git clone https://github.com/kpu/kenlm.git
cd kenlm
mkdir -p build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON
make -j16
cd ..
export KENLM_ROOT_DIR=$(pwd)
cd ..
# install wav2letter python bindings
git clone https://github.com/facebookresearch/wav2letter.git
cd wav2letter/bindings/python
# make sure your python environment is active at this point
pip install torch packaging
pip install -e .
# try some examples to verify installation succeeded
python ./examples/criterion_example.py
python ./examples/decoder_example.py ../../src/decoder/test
python ./examples/feature_example.py ../../src/feature/test/data
```
## Training librispeech data (wav2letter style, Conv/GLU + ASG loss)
Training command:
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
```
Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
## Inference for librispeech (wav2letter decoder, n-gram LM)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```
`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
```
doorbell D O 1 R B E L 1 ▁
```
For CTC inference with word-pieces, repetition labels are not used and the lexicon should contain the most common spellings of each word (one can use sentencepiece's `NBestEncodeAsPieces` for this):
```
doorbell ▁DOOR BE LL
doorbell ▁DOOR B E LL
doorbell ▁DO OR BE LL
doorbell ▁DOOR B EL L
doorbell ▁DOOR BE L L
doorbell ▁DO OR B E LL
doorbell ▁DOOR B E L L
doorbell ▁DO OR B EL L
doorbell ▁DO O R BE LL
doorbell ▁DO OR BE L L
```
Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
## Inference for librispeech (wav2letter decoder, viterbi only)
Inference command:
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
```