import json
import os
import tqdm
from fairseq.data import Dictionary, data_utils
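# Convert fairseq mmap-binarized (GPT-2 BPE) datasets back to raw text: for each
# configured split, load the indexed dataset, detokenize every example, and write
# one sentence per line to the target directory.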
def load_dictionary(dict_path):
return Dictionary.load(dict_path)
def load_dataset(split_path, src_dict):
dataset = data_utils.load_indexed_dataset(
split_path,
src_dict,
combine=False, # set to true for loading `train*`
)
if dataset is None:
raise FileNotFoundError(f"Dataset not found: {split_path}")
return dataset
def load_bpe(enc_path):
with open(enc_path) as f:
bpe2idx = json.load(f)
idx2bpe = {v: k for k, v in bpe2idx.items()}
return bpe2idx, idx2bpe
def detokenize(tokens, src_dict, idx2bpe):
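# src_dict.string() returns the stored symbols (GPT-2 BPE ids) as a space-separated
# string; map them back through encoder.json and replace "\u0120" (the byte-level BPE
# marker for a leading space) to recover raw text.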
raw_inds = map(int, src_dict.string(tokens).split())
raw_chrs = "".join([idx2bpe[raw_ind] for raw_ind in raw_inds])
raw_chrs = raw_chrs.replace("\u0120", " ")
return raw_chrs
def _main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits):
src_dict = load_dictionary(src_dict_path)
bpe2idx, idx2bpe = load_bpe(src_bpe_path)
assert len(src_splits) == len(tgt_splits)
for src_split, tgt_split in zip(src_splits, tgt_splits):
src_dataset = load_dataset(f"{src_root}/{src_split}", src_dict)
tgt_path = f"{tgt_root}/{tgt_split}.txt"
print(f"processing {src_split} (dump to {tgt_path})...")
os.makedirs(os.path.dirname(tgt_path), exist_ok=True)
with open(tgt_path, "w") as f:
for tokens in tqdm.tqdm(src_dataset):
raw_str = detokenize(tokens, src_dict, idx2bpe)
f.write(raw_str + "\n")
def main_pt():
src_root = "/datasets01/bookwiki_CC-NEWS_openwebtext_stories-mmap2-bin/121219/bookwiki_CC-NEWS_openwebtext_stories-mmap2-bin"
src_dict_path = f"{src_root}/dict.txt"
src_bpe_path = f"{src_root}/encoder.json"
src_splits = [
"bookwiki_aml-mmap2-bin/shard0/train",
"bookwiki_aml-mmap2-bin/shard1/train",
"bookwiki_aml-mmap2-bin/shard2/train",
"bookwiki_aml-mmap2-bin/shard3/train",
"bookwiki_aml-mmap2-bin/shard4/train",
"bookwiki_aml-mmap2-bin/valid/valid",
]
tgt_root = "/checkpoint/wnhsu/data/data2vec2/data/text/bookwiki_aml-full-mmap2-txt"
tgt_splits = [
"train0",
"train1",
"train2",
"train3",
"train4",
"valid",
]
_main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits)
def main_ft():
src_root = "/fsx-wav2vec/wnhsu/data/data2vec2/data/text/GLUE"
src_dict_path = f"{src_root}/dict.txt"
src_bpe_path = f"{src_root}/encoder.json"
src_splits = [
"CoLA-bin/input0/train",
"CoLA-bin/input0/valid",
"CoLA-bin/input0/test",
"MNLI-bin/input0/train",
"MNLI-bin/input0/valid",
"MNLI-bin/input0/test",
"MNLI-bin/input0/test1",
"MNLI-bin/input1/train",
"MNLI-bin/input1/valid",
"MNLI-bin/input1/test",
"MNLI-bin/input1/test1",
"MRPC-bin/input0/train",
"MRPC-bin/input0/valid",
"MRPC-bin/input0/test",
"MRPC-bin/input1/train",
"MRPC-bin/input1/valid",
"MRPC-bin/input1/test",
"QNLI-bin/input0/train",
"QNLI-bin/input0/valid",
"QNLI-bin/input0/test",
"QNLI-bin/input1/train",
"QNLI-bin/input1/valid",
"QNLI-bin/input1/test",
"QQP-bin/input0/train",
"QQP-bin/input0/valid",
"QQP-bin/input0/test",
"QQP-bin/input1/train",
"QQP-bin/input1/valid",
"QQP-bin/input1/test",
"RTE-bin/input0/train",
"RTE-bin/input0/valid",
"RTE-bin/input0/test",
"RTE-bin/input1/train",
"RTE-bin/input1/valid",
"RTE-bin/input1/test",
"SST-2-bin/input0/train",
"SST-2-bin/input0/valid",
"SST-2-bin/input0/test",
"STS-B-bin/input0/train",
"STS-B-bin/input0/valid",
"STS-B-bin/input0/test",
"STS-B-bin/input1/train",
"STS-B-bin/input1/valid",
"STS-B-bin/input1/test",
]
tgt_root = "/fsx-wav2vec/wnhsu/data/data2vec2/data/text/GLUE_chr"
tgt_splits = [
"CoLA-bin/input0/train",
"CoLA-bin/input0/valid",
"CoLA-bin/input0/test",
"MNLI-bin/input0/train",
"MNLI-bin/input0/valid",
"MNLI-bin/input0/test",
"MNLI-bin/input0/test1",
"MNLI-bin/input1/train",
"MNLI-bin/input1/valid",
"MNLI-bin/input1/test",
"MNLI-bin/input1/test1",
"MRPC-bin/input0/train",
"MRPC-bin/input0/valid",
"MRPC-bin/input0/test",
"MRPC-bin/input1/train",
"MRPC-bin/input1/valid",
"MRPC-bin/input1/test",
"QNLI-bin/input0/train",
"QNLI-bin/input0/valid",
"QNLI-bin/input0/test",
"QNLI-bin/input1/train",
"QNLI-bin/input1/valid",
"QNLI-bin/input1/test",
"QQP-bin/input0/train",
"QQP-bin/input0/valid",
"QQP-bin/input0/test",
"QQP-bin/input1/train",
"QQP-bin/input1/valid",
"QQP-bin/input1/test",
"RTE-bin/input0/train",
"RTE-bin/input0/valid",
"RTE-bin/input0/test",
"RTE-bin/input1/train",
"RTE-bin/input1/valid",
"RTE-bin/input1/test",
"SST-2-bin/input0/train",
"SST-2-bin/input0/valid",
"SST-2-bin/input0/test",
"STS-B-bin/input0/train",
"STS-B-bin/input0/valid",
"STS-B-bin/input0/test",
"STS-B-bin/input1/train",
"STS-B-bin/input1/valid",
"STS-B-bin/input1/test",
]
_main(src_root, src_dict_path, src_bpe_path, src_splits, tgt_root, tgt_splits)
if __name__ == "__main__":
main_pt()
main_ft()
import os, argparse, re, json, copy, math
from collections import OrderedDict
import numpy as np
parser = argparse.ArgumentParser(description='Find and summarize metrics from fairseq training logs.')
parser.add_argument('base', help='base log path')
parser.add_argument('--file_name', default='train.log', help='the log file name')
parser.add_argument('--target', default='valid_loss', help='target metric')
parser.add_argument('--last', type=int, default=999999999, help='print last n matches')
parser.add_argument('--last_files', type=int, default=None, help='print last x files')
parser.add_argument('--everything', action='store_true', help='print everything instead of only last match')
parser.add_argument('--path_contains', help='only consider matching file pattern')
parser.add_argument('--group_on', help='if set, groups by this metric and shows table of differences')
parser.add_argument('--epoch', help='epoch for comparison', type=int)
parser.add_argument('--skip_empty', action='store_true', help='skip empty results')
parser.add_argument('--skip_containing', help='skips entries containing this attribute')
parser.add_argument('--unique_epochs', action='store_true', help='only consider the last line for each epoch')
parser.add_argument('--best', action='store_true', help='print the last best result')
parser.add_argument('--avg_params', help='average these params through entire log')
parser.add_argument('--extract_prev', help='extracts this metric from previous line')
parser.add_argument('--remove_metric', help='remove this param from the result key and append its value to the grouped value')
parser.add_argument('--compact', action='store_true', help='if true, just prints checkpoint <tab> best val')
parser.add_argument('--hydra', action='store_true', help='if true, uses hydra param conventions')
parser.add_argument('--best_biggest', action='store_true', help='if true, best is the biggest number, not smallest')
parser.add_argument('--key_len', type=int, default=10, help='max length of key')
parser.add_argument('--best_only', action='store_true', help='if set, only prints the best value')
parser.add_argument('--flat', action='store_true', help='just print the best results')
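# Example invocation (hypothetical paths; assumes the logs contain JSON lines with the target metric):
#   python <this_script>.py /path/to/sweep_dir --target valid_loss --best --group_on lr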
def main(args, print_output):
ret = {}
entries = []
def extract_metric(s, metric):
try:
j = json.loads(s)
except:
return None
if args.epoch is not None and ('epoch' not in j or j['epoch'] != args.epoch):
return None
return j[metric] if metric in j else None
def extract_params(s):
s = s.replace(args.base, '', 1)
if args.path_contains is not None:
s = s.replace(args.path_contains, '', 1)
if args.hydra:
num_matches = re.findall(r'(?:/|__)([^/:]+):(\d+\.?\d*)', s)
# str_matches = re.findall(r'(?:/|__)([^/:]+):([^\.]*[^\d\.]+)(?:/|__)', s)
str_matches = re.findall(r'(?:/|__)?((?:(?!(?:\:|__)).)+):([^\.]*[^\d\.]+\d*)(?:/|__)', s)
lr_matches = re.findall(r'optimization.(lr):\[([\d\.,]+)\]', s)
task_matches = re.findall(r'.*/(\d+)$', s)
else:
num_matches = re.findall(r'\.?([^\.]+?)(\d+(e\-\d+)?(?:\.\d+)?)(\.|$)', s)
str_matches = re.findall(r'[/\.]([^\.]*[^\d\.]+\d*)(?=\.)', s)
lr_matches = []
task_matches = []
cp_matches = re.findall(r'checkpoint(?:_\d+)?_(\d+).pt', s)
items = OrderedDict()
for m in str_matches:
if isinstance(m, tuple):
if 'checkpoint' not in m[0]:
items[m[0]] = m[1]
else:
items[m] = ''
for m in num_matches:
items[m[0]] = m[1]
for m in lr_matches:
items[m[0]] = m[1]
for m in task_matches:
items["hydra_task"] = m
for m in cp_matches:
items['checkpoint'] = m
return items
abs_best = None
sources = []
for root, _, files in os.walk(args.base):
if args.path_contains is not None and not args.path_contains in root:
continue
for f in files:
if f.endswith(args.file_name):
sources.append((root, f))
if args.last_files is not None:
sources = sources[-args.last_files:]
for root, file in sources:
with open(os.path.join(root, file), 'r') as fin:
found = []
avg = {}
prev = None
for line in fin:
line = line.rstrip()
if line.find(args.target) != -1 and (
args.skip_containing is None or line.find(args.skip_containing) == -1):
try:
idx = line.index("{")
line = line[idx:]
line_json = json.loads(line)
except:
continue
if prev is not None:
try:
prev.update(line_json)
line_json = prev
except:
pass
if args.target in line_json:
found.append(line_json)
if args.avg_params:
avg_params = args.avg_params.split(',')
for p in avg_params:
m = extract_metric(line, p)
if m is not None:
prev_v, prev_c = avg.get(p, (0, 0))
avg[p] = prev_v + float(m), prev_c + 1
if args.extract_prev:
try:
prev = json.loads(line)
except:
pass
best = None
if args.best:
curr_best = None
for i in range(len(found)):
cand_best = found[i][args.target] if args.target in found[i] else None
def cmp(a, b):
a = float(a)
b = float(b)
if args.best_biggest:
return a > b
return a < b
if cand_best is not None and not math.isnan(float(cand_best)) and (
curr_best is None or cmp(cand_best, curr_best)):
curr_best = cand_best
if abs_best is None or cmp(curr_best, abs_best):
abs_best = curr_best
best = found[i]
if args.unique_epochs or args.epoch:
last_found = []
last_epoch = None
for i in reversed(range(len(found))):
epoch = found[i]['epoch']
if args.epoch and args.epoch != epoch:
continue
if epoch != last_epoch:
last_epoch = epoch
last_found.append(found[i])
found = list(reversed(last_found))
if len(found) == 0:
if print_output and (args.last_files is not None or not args.skip_empty):
# print(root.split('/')[-1])
print(root[len(args.base):])
print('Nothing')
else:
if not print_output:
ret[root[len(args.base):]] = best
continue
if args.compact:
# print('{}\t{}'.format(root.split('/')[-1], curr_best))
print('{}\t{}'.format(root[len(args.base)+1:], curr_best))
continue
if args.group_on is None and not args.best_only:
# print(root.split('/')[-1])
print(root[len(args.base):])
if not args.everything:
if best is not None and args.group_on is None and not args.best_only and not args.flat:
print(best, '(best)')
if args.group_on is None and args.last and not args.best_only and not args.flat:
for f in found[-args.last:]:
if args.extract_prev is not None:
try:
print('{}\t{}'.format(f[args.extract_prev], f[args.target]))
except Exception as e:
print('Exception!', e)
else:
print(f)
try:
metric = found[-1][args.target] if not args.best or best is None else best[args.target]
except:
print(found[-1])
raise
if metric is not None:
entries.append((extract_params(root), metric))
else:
for f in found:
print(f)
if not args.group_on and print_output:
print()
if len(avg) > 0:
for k, (v, c) in avg.items():
print(f'{k}: {v/c}')
if args.best_only:
print(abs_best)
if args.flat:
print("\t".join(m for _, m in entries))
if args.group_on is not None:
by_val = OrderedDict()
for e, m in entries:
k = args.group_on
if k not in e:
m_keys = [x for x in e.keys() if x.startswith(k)]
if len(m_keys) == 0:
val = "False"
else:
assert len(m_keys) == 1
k = m_keys[0]
val = m_keys[0]
else:
val = e[args.group_on]
if val == "":
val = "True"
scrubbed_entry = copy.deepcopy(e)
if k in scrubbed_entry:
del scrubbed_entry[k]
if args.remove_metric and args.remove_metric in scrubbed_entry:
val += '_' + scrubbed_entry[args.remove_metric]
del scrubbed_entry[args.remove_metric]
by_val.setdefault(tuple(scrubbed_entry.items()), dict())[val] = m
distinct_vals = set()
for v in by_val.values():
distinct_vals.update(v.keys())
try:
distinct_vals = {int(d) for d in distinct_vals}
except:
print(distinct_vals)
print()
print("by_val", len(by_val))
for k,v in by_val.items():
print(k, '=>', v)
print()
# , by_val, entries)
raise
from natsort import natsorted
svals = list(map(str, natsorted(distinct_vals)))
print('{}\t{}'.format(args.group_on, '\t'.join(svals)))
sums = OrderedDict({n:[] for n in svals})
for k, v in by_val.items():
kstr = '.'.join(':'.join(x) for x in k)
vstr = ''
for mv in svals:
x = v[mv] if mv in v else ''
vstr += '\t{}'.format(round(x, 5) if isinstance(x, float) else x)
try:
sums[mv].append(float(x))
except:
pass
print('{}{}'.format(kstr[:args.key_len], vstr))
if any(len(x) > 0 for x in sums.values()):
print('min:', end='')
for v in sums.values():
min = np.min(v)
print(f'\t{round(min, 5)}', end='')
print()
print('max:', end='')
for v in sums.values():
max = np.max(v)
print(f'\t{round(max, 5)}', end='')
print()
print('avg:', end='')
for v in sums.values():
mean = np.mean(v)
print(f'\t{round(mean, 5)}', end='')
print()
print('median:', end='')
for v in sums.values():
median = np.median(v)
print(f'\t{round(median, 5)}', end='')
print()
return ret
if __name__ == "__main__":
args = parser.parse_args()
main(args, print_output=True)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .image_pretraining import ImagePretrainingTask, ImagePretrainingConfig
from .image_classification import ImageClassificationTask, ImageClassificationConfig
from .mae_image_pretraining import MaeImagePretrainingTask, MaeImagePretrainingConfig
__all__ = [
"ImageClassificationTask",
"ImageClassificationConfig",
"ImagePretrainingTask",
"ImagePretrainingConfig",
"MaeImagePretrainingTask",
"MaeImagePretrainingConfig",
]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
import numpy as np
import math
import torch
from sklearn import metrics as sklearn_metrics
from dataclasses import dataclass
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
from fairseq.tasks import register_task
from fairseq.logging import metrics
from ..data.add_class_target_dataset import AddClassTargetDataset
logger = logging.getLogger(__name__)
@dataclass
class AudioClassificationConfig(AudioPretrainingConfig):
label_descriptors: str = "label_descriptors.csv"
labels: str = "lbl"
@register_task("audio_classification", dataclass=AudioClassificationConfig)
class AudioClassificationTask(AudioPretrainingTask):
""" """
cfg: AudioClassificationConfig
def __init__(
self,
cfg: AudioClassificationConfig,
):
super().__init__(cfg)
self.state.add_factory("labels", self.load_labels)
def load_labels(self):
labels = {}
path = os.path.join(self.cfg.data, self.cfg.label_descriptors)
with open(path, "r") as ldf:
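# each non-empty line of the label descriptor file is expected to look like "<index>,<label>[,...]"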
for line in ldf:
if line.strip() == "":
continue
items = line.split(",")
idx = items[0]
lbl = items[1]
assert lbl not in labels, lbl
labels[lbl] = idx
return labels
@property
def labels(self):
return self.state.labels
def load_dataset(
self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
):
super().load_dataset(split, task_cfg, **kwargs)
task_cfg = task_cfg or self.cfg
data_path = self.cfg.data
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
labels = []
with open(label_path, "r") as f:
for i, line in enumerate(f):
if i not in skipped_indices:
lbl_items = line.rstrip().split("\t")
labels.append([int(x) for x in lbl_items[2].split(",")])
assert len(labels) == len(self.datasets[split]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.datasets[split])}) do not match"
)
self.datasets[split] = AddClassTargetDataset(
self.datasets[split],
labels,
multi_class=True,
add_to_input=True,
num_classes=len(self.labels),
)
def calculate_stats(self, output, target):
classes_num = target.shape[-1]
stats = []
# Accuracy, only used for single-label classification such as ESC-50, not for multi-label datasets such as AudioSet
# acc = sklearn_metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
# Class-wise statistics
for k in range(classes_num):
# Average precision
avg_precision = sklearn_metrics.average_precision_score(
target[:, k], output[:, k], average=None
)
dict = {
"AP": avg_precision,
}
# # AUC
# try:
# auc = sklearn_metrics.roc_auc_score(target[:, k], output[:, k], average=None)
# except:
# auc = 0
#
# # Precisions, recalls
# (precisions, recalls, thresholds) = sklearn_metrics.precision_recall_curve(
# target[:, k], output[:, k]
# )
#
# # FPR, TPR
# (fpr, tpr, thresholds) = sklearn_metrics.roc_curve(target[:, k], output[:, k])
#
# save_every_steps = 1000 # Sample statistics to reduce size
# dict = {
# "precisions": precisions[0::save_every_steps],
# "recalls": recalls[0::save_every_steps],
# "AP": avg_precision,
# "fpr": fpr[0::save_every_steps],
# "fnr": 1.0 - tpr[0::save_every_steps],
# "auc": auc,
# # note acc is not class-wise, this is just to keep consistent with other metrics
# "acc": acc,
# }
stats.append(dict)
return stats
def valid_step(self, sample, model, criterion):
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
return loss, sample_size, logging_output
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if "_predictions" in logging_outputs[0]:
metrics.log_concat_tensor(
"_predictions",
torch.cat([l["_predictions"].cpu() for l in logging_outputs], dim=0),
)
metrics.log_concat_tensor(
"_targets",
torch.cat([l["_targets"].cpu() for l in logging_outputs], dim=0),
)
def compute_stats(meters):
if meters["_predictions"].tensor.shape[0] < 100:
return 0
stats = self.calculate_stats(
meters["_predictions"].tensor, meters["_targets"].tensor
)
return np.nanmean([stat["AP"] for stat in stats])
metrics.log_derived("mAP", compute_stats)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os.path as osp
import logging
from dataclasses import dataclass
import torch
from torchvision import transforms
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.logging import metrics
try:
from ..data import ImageDataset
except:
import sys
sys.path.append("..")
from data import ImageDataset
from .image_pretraining import (
ImagePretrainingConfig,
ImagePretrainingTask,
IMG_EXTENSIONS,
)
logger = logging.getLogger(__name__)
@dataclass
class ImageClassificationConfig(ImagePretrainingConfig):
pass
@register_task("image_classification", dataclass=ImageClassificationConfig)
class ImageClassificationTask(ImagePretrainingTask):
cfg: ImageClassificationConfig
@classmethod
def setup_task(cls, cfg: ImageClassificationConfig, **kwargs):
return cls(cfg)
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
data_path = self.cfg.data
cfg = task_cfg or self.cfg
path_with_split = osp.join(data_path, split)
if osp.exists(path_with_split):
data_path = path_with_split
from timm.data import create_transform
if split == "train":
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=cfg.input_size,
is_training=True,
auto_augment="rand-m9-mstd0.5-inc1",
interpolation="bicubic",
re_prob=0.25,
re_mode="pixel",
re_count=1,
mean=cfg.normalization_mean,
std=cfg.normalization_std,
)
if not cfg.input_size > 32:
transform.transforms[0] = transforms.RandomCrop(
cfg.input_size, padding=4
)
else:
t = []
if cfg.input_size > 32:
crop_pct = 1
if cfg.input_size < 384:
crop_pct = 224 / 256
size = int(cfg.input_size / crop_pct)
t.append(
transforms.Resize(
size, interpolation=3
), # to maintain same ratio w.r.t. 224 images
)
t.append(transforms.CenterCrop(cfg.input_size))
t.append(transforms.ToTensor())
t.append(
transforms.Normalize(cfg.normalization_mean, cfg.normalization_std)
)
transform = transforms.Compose(t)
logger.info(transform)
self.datasets[split] = ImageDataset(
root=data_path,
extensions=IMG_EXTENSIONS,
load_classes=True,
transform=transform,
)
for k in self.datasets.keys():
if k != split:
assert self.datasets[k].classes == self.datasets[split].classes
def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
model = super().build_model(model_cfg, from_checkpoint)
actualized_cfg = getattr(model, "cfg", None)
if actualized_cfg is not None:
if hasattr(actualized_cfg, "pretrained_model_args"):
model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
return model
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if "correct" in logging_outputs[0]:
zero = torch.scalar_tensor(0.0)
correct = sum(log.get("correct", zero) for log in logging_outputs)
metrics.log_scalar_sum("_correct", correct)
metrics.log_derived(
"accuracy",
lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import sys
import os.path as osp
from dataclasses import dataclass, field
from typing import List
from omegaconf import MISSING
import torch
from torchvision import transforms
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
try:
from ..data import ImageDataset
except:
sys.path.append("..")
from data import ImageDataset
logger = logging.getLogger(__name__)
IMG_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
}
@dataclass
class ImagePretrainingConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
input_size: int = 224
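# defaults below are the standard ImageNet channel mean/std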
normalization_mean: List[float] = (0.485, 0.456, 0.406)
normalization_std: List[float] = (0.229, 0.224, 0.225)
@register_task("image_pretraining", dataclass=ImagePretrainingConfig)
class ImagePretrainingTask(FairseqTask):
""" """
cfg: ImagePretrainingConfig
@classmethod
def setup_task(cls, cfg: ImagePretrainingConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
cfg (ImagePretrainingConfig): configuration of this task
"""
return cls(cfg)
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
data_path = self.cfg.data
cfg = task_cfg or self.cfg
path_with_split = osp.join(data_path, split)
if osp.exists(path_with_split):
data_path = path_with_split
transform = transforms.Compose(
[
transforms.ColorJitter(0.4, 0.4, 0.4),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomResizedCrop(
size=cfg.input_size,
interpolation=transforms.InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(cfg.normalization_mean),
std=torch.tensor(cfg.normalization_std),
),
]
)
logger.info(transform)
self.datasets[split] = ImageDataset(
root=data_path,
extensions=IMG_EXTENSIONS,
load_classes=False,
transform=transform,
)
@property
def source_dictionary(self):
return None
@property
def target_dictionary(self):
return None
def max_positions(self):
"""Maximum input length supported by the encoder."""
return sys.maxsize, sys.maxsize
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import sys
import torch
from typing import Optional
from dataclasses import dataclass, field
from omegaconf import MISSING
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from fairseq.logging import metrics
try:
from ..data import MaeFinetuningImageDataset
except:
sys.path.append("..")
from data import MaeFinetuningImageDataset
logger = logging.getLogger(__name__)
@dataclass
class MaeImageClassificationConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
input_size: int = 224
local_cache_path: Optional[str] = None
rebuild_batches: bool = True
@register_task("mae_image_classification", dataclass=MaeImageClassificationConfig)
class MaeImageClassificationTask(FairseqTask):
""" """
cfg: MaeImageClassificationConfig
@classmethod
def setup_task(cls, cfg: MaeImageClassificationConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
cfg (MaeImageClassificationConfig): configuration of this task
"""
return cls(cfg)
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
data_path = self.cfg.data
cfg = task_cfg or self.cfg
self.datasets[split] = MaeFinetuningImageDataset(
root=data_path,
split=split,
is_train=split == "train",
input_size=cfg.input_size,
local_cache_path=cfg.local_cache_path,
shuffle=split == "train",
)
def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
model = super().build_model(model_cfg, from_checkpoint)
actualized_cfg = getattr(model, "cfg", None)
if actualized_cfg is not None:
if hasattr(actualized_cfg, "pretrained_model_args"):
model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
return model
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if "correct" in logging_outputs[0]:
zero = torch.scalar_tensor(0.0)
correct = sum(log.get("correct", zero) for log in logging_outputs)
metrics.log_scalar_sum("_correct", correct)
metrics.log_derived(
"accuracy",
lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
)
@property
def source_dictionary(self):
return None
@property
def target_dictionary(self):
return None
def max_positions(self):
"""Maximum input length supported by the encoder."""
return sys.maxsize, sys.maxsize
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import sys
from typing import Optional, List
from dataclasses import dataclass, field
from omegaconf import MISSING, II
from fairseq.data import SubsampleDataset
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
try:
from ..data import MaeImageDataset
except:
sys.path.append("..")
from data import MaeImageDataset
logger = logging.getLogger(__name__)
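# II("...") fields below are omegaconf interpolations: their values are pulled from the
# corresponding model/common config entries when the full config is composed.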
@dataclass
class ImageMaskingConfig:
patch_size: int = II("model.modalities.image.patch_size")
mask_prob: float = II("model.modalities.image.mask_prob")
mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust")
mask_length: int = II("model.modalities.image.mask_length")
inverse_mask: bool = II("model.modalities.image.inverse_mask")
mask_dropout: float = II("model.modalities.image.mask_dropout")
clone_batch: int = II("model.clone_batch")
expand_adjacent: bool = False
non_overlapping: bool = False
@dataclass
class MaeImagePretrainingConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
multi_data: Optional[List[str]] = None
input_size: int = 224
local_cache_path: Optional[str] = None
key: str = "imgs"
beit_transforms: bool = False
target_transform: bool = False
no_transform: bool = False
rebuild_batches: bool = True
precompute_mask_config: Optional[ImageMaskingConfig] = None
subsample: float = 1
seed: int = II("common.seed")
dataset_type: str = "imagefolder"
@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig)
class MaeImagePretrainingTask(FairseqTask):
""" """
cfg: MaeImagePretrainingConfig
@classmethod
def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
cfg (MaeImagePretrainingConfig): configuration of this task
"""
return cls(cfg)
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
data_path = self.cfg.data
cfg = task_cfg or self.cfg
compute_mask = cfg.precompute_mask_config is not None
mask_args = {}
if compute_mask:
mask_args = cfg.precompute_mask_config
self.datasets[split] = MaeImageDataset(
root=data_path if cfg.multi_data is None else cfg.multi_data,
split=split,
input_size=cfg.input_size,
local_cache_path=cfg.local_cache_path,
key=cfg.key,
beit_transforms=cfg.beit_transforms,
target_transform=cfg.target_transform,
no_transform=cfg.no_transform,
compute_mask=compute_mask,
dataset_type=cfg.dataset_type,
**mask_args,
)
if cfg.subsample < 1:
self.datasets[split] = SubsampleDataset(
self.datasets[split],
cfg.subsample,
shuffle=True,
seed=cfg.seed,
)
@property
def source_dictionary(self):
return None
@property
def target_dictionary(self):
return None
def max_positions(self):
"""Maximum input length supported by the encoder."""
return sys.maxsize, sys.maxsize
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os, sys
sys.path.append("../../../../")
from dataclasses import dataclass
from typing import Optional, List
from omegaconf import II
from fairseq.data.iterators import GroupedEpochBatchIterator
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
from fairseq.tasks.masked_lm import MaskedLMConfig, MaskedLMTask
from .mae_image_pretraining import MaeImagePretrainingConfig, MaeImagePretrainingTask
from examples.data2vec.data.modality import Modality
from fairseq.data.audio.multi_modality_dataset import (
MultiModalityDataset,
ModalityDatasetItem,
)
@dataclass
class MultimodalPretrainingConfig(FairseqDataclass):
audio: Optional[AudioPretrainingConfig] = None
image: Optional[MaeImagePretrainingConfig] = None
text: Optional[MaskedLMConfig] = None
audio_ratio: float = 1
image_ratio: float = 1
text_ratio: float = 1
max_tokens: Optional[int] = II("dataset.max_tokens")
batch_size: Optional[int] = II("dataset.batch_size")
update_freq: List[int] = II("optimization.update_freq")
rebuild_batches: bool = True
@register_task("multimodal_pretraining", dataclass=MultimodalPretrainingConfig)
class MultimodalPretrainingTask(FairseqTask):
""" """
cfg: MultimodalPretrainingConfig
def __init__(self, cfg: MultimodalPretrainingConfig):
super().__init__(cfg)
self.audio_task = (
AudioPretrainingTask(cfg.audio) if cfg.audio is not None else None
)
self.image_task = (
MaeImagePretrainingTask(cfg.image) if cfg.image is not None else None
)
self.text_task = MaskedLMTask(cfg.text) if cfg.text is not None else None
self.mult_ratios = []
@classmethod
def setup_task(cls, cfg: MultimodalPretrainingConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
cfg (MultimodalPretrainingConfig): configuration of this task
"""
return cls(cfg)
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
datasets = []
self.mult_ratios = []
def load_ds(task, name, ratio):
if task is not None:
task.load_dataset(split)
ds = ModalityDatasetItem(
datasetname=name,
dataset=task.dataset(split),
max_positions=task.max_positions(),
max_tokens=self.cfg.max_tokens,
max_sentences=self.cfg.batch_size,
)
datasets.append(ds)
self.mult_ratios.append(ratio)
load_ds(self.audio_task, Modality.AUDIO, self.cfg.audio_ratio)
load_ds(self.image_task, Modality.IMAGE, self.cfg.image_ratio)
load_ds(self.text_task, Modality.TEXT, self.cfg.text_ratio)
assert len(datasets) > 0
self.datasets[split] = MultiModalityDataset(datasets)
@property
def supported_modalities(self):
modalities = []
if self.cfg.text is not None:
modalities.append(Modality.TEXT)
if self.cfg.audio is not None:
modalities.append(Modality.AUDIO)
if self.cfg.image is not None:
modalities.append(Modality.IMAGE)
return modalities
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=0,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
batch_samplers = dataset.get_batch_samplers(
self.mult_ratios, required_batch_size_multiple, seed
)
# return a reusable, sharded iterator
epoch_iter = GroupedEpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_samplers=batch_samplers,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
mult_rate=max(self.cfg.update_freq),
buffer_size=data_buffer_size,
skip_remainder_batch=skip_remainder_batch,
)
self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
return epoch_iter
@property
def source_dictionary(self):
return None
@property
def target_dictionary(self):
return None
def max_positions(self):
"""Maximum input length supported by the encoder."""
return sys.maxsize, sys.maxsize
# Discriminative Reranking for Neural Machine Translation
https://aclanthology.org/2021.acl-long.563/
This folder contains source code for training DrNMT, a discriminatively trained reranker for neural machine translation.
## Data preparation
1. Follow the instructions under `examples/translation` to build a base MT model. Prepare three files, one with source sentences, one with ground truth target sentences, and one with hypotheses generated from the base MT model. Each line in the file contains one sentence in raw text (i.e. no sentencepiece, etc.). Below is an example of the files with _N_ hypotheses for each source sentence.
```
# Example of the source sentence file: (The file should contain L lines.)
source_sentence_1
source_sentence_2
source_sentence_3
...
source_sentence_L
# Example of the target sentence file: (The file should contain L lines.)
target_sentence_1
target_sentence_2
target_sentence_3
...
target_sentence_L
# Example of the hypotheses file: (The file should contain L*N lines.)
source_sentence_1_hypo_1
source_sentence_1_hypo_2
...
source_sentence_1_hypo_N
source_sentence_2_hypo_1
...
source_sentence_2_hypo_N
...
source_sentence_L_hypo_1
...
source_sentence_L_hypo_N
```
2. Download the [XLMR model](https://github.com/fairinternal/fairseq-py/tree/main/examples/xlmr#pre-trained-models).
```
wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.tar.gz
tar zxvf xlmr.base.tar.gz
# The folder should contain dict.txt, model.pt and sentencepiece.bpe.model.
```
3. Prepare scores and BPE data.
* `N`: Number of hypotheses per source sentence. We use 50 in the paper.
* `SPLIT`: Name of the data split, i.e. train, valid, test. Use split_name, split_name1, split_name2, ..., if there are multiple datasets for a split, e.g. train, train1, valid, valid1.
* `NUM_SHARDS`: Number of shards. Set this to 1 for non-train splits.
* `METRIC`: The metric for DrNMT to optimize for. We support either `bleu` or `ter`.
```
# For each data split, e.g. train, valid, test, etc., run the following:
SOURCE_FILE=/path/to/source_sentence_file
TARGET_FILE=/path/to/target_sentence_file
HYPO_FILE=/path/to/hypo_file
XLMR_DIR=/path/to/xlmr
OUTPUT_DIR=/path/to/output
python scripts/prep_data.py \
--input-source ${SOURCE_FILE} \
--input-target ${TARGET_FILE} \
--input-hypo ${HYPO_FILE} \
--output-dir ${OUTPUT_DIR} \
--split $SPLIT \
--beam $N \
--sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
--metric $METRIC \
--num-shards ${NUM_SHARDS}
# The script will create ${OUTPUT_DIR}/$METRIC with ${NUM_SHARDS} splits.
# Under split*/input_src, split*/input_tgt and split*/$METRIC, there will be $SPLIT.bpe and $SPLIT.$METRIC files, respectively.
```
4. Pre-process the data into fairseq format.
```
# use commas to separate the paths if there is more than one train or valid set
for suffix in src tgt ; do
fairseq-preprocess --only-source \
--trainpref ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/train.bpe \
--validpref ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/valid.bpe \
--destdir ${OUTPUT_DIR}/$METRIC/split1/input_${suffix} \
--workers 60 \
--srcdict ${XLMR_DIR}/dict.txt
done
for i in `seq 2 ${NUM_SHARDS}`; do
for suffix in src tgt ; do
fairseq-preprocess --only-source \
--trainpref ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix}/train.bpe \
--destdir ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix} \
--workers 60 \
--srcdict ${XLMR_DIR}/dict.txt
ln -s ${OUTPUT_DIR}/$METRIC/split1/input_${suffix}/valid* ${OUTPUT_DIR}/$METRIC/split${i}/input_${suffix}/.
done
ln -s ${OUTPUT_DIR}/$METRIC/split1/$METRIC/valid* ${OUTPUT_DIR}/$METRIC/split${i}/$METRIC/.
done
```
## Training
```
EXP_DIR=/path/to/exp
# An example of training the model with the config for De-En experiment in the paper.
# The config uses 16 GPUs and 50 hypotheses.
# For training with fewer GPUs, set
# distributed_training.distributed_world_size=k +optimization.update_freq='[x]' where x = 16/k
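# (e.g. with 8 GPUs: distributed_training.distributed_world_size=8 +optimization.update_freq='[2]')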
# For training with fewer hypotheses, set
# task.mt_beam=N dataset.batch_size=N dataset.required_batch_size_multiple=N
fairseq-hydra-train -m \
--config-dir config/ --config-name deen \
task.data=${OUTPUT_DIR}/$METRIC/split1/ \
task.num_data_splits=${NUM_SHARDS} \
model.pretrained_model=${XLMR_DIR}/model.pt \
common.user_dir=${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
checkpoint.save_dir=${EXP_DIR}
```
## Inference & scoring
Perform DrNMT reranking (fw + reranker score)
1. Tune weights on valid sets.
```
# generate N hypotheses with the base MT model (fw score)
VALID_SOURCE_FILE=/path/to/source_sentences # one sentence per line, converted to the sentencepiece used by the base MT model
VALID_TARGET_FILE=/path/to/target_sentences # one sentence per line in raw text, i.e. no sentencepiece and tokenization
MT_MODEL=/path/to/mt_model
MT_DATA_PATH=/path/to/mt_data
cat ${VALID_SOURCE_FILE} | \
fairseq-interactive ${MT_DATA_PATH} \
--max-tokens 4000 --buffer-size 16 \
--num-workers 32 --path ${MT_MODEL} \
--beam $N --nbest $N \
--post-process sentencepiece &> valid-hypo.out
# replace "bleu" with "ter" to optimize for TER
python drnmt_rerank.py \
${OUTPUT_DIR}/$METRIC/split1/ \
--path ${EXP_DIR}/checkpoint_best.pt \
--in-text valid-hypo.out \
--results-path ${EXP_DIR} \
--gen-subset valid \
--target-text ${VALID_TARGET_FILE} \
--user-dir ${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
--bpe sentencepiece \
--sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
--beam $N \
--batch-size $N \
--metric bleu \
--tune
```
2. Apply best weights on test sets
```
# generate N hypotheses with the base MT model (fw score)
TEST_SOURCE_FILE=/path/to/source_sentences # one sentence per line, converted to the sentencepiece used by the base MT model
cat ${TEST_SOURCE_FILE} | \
fairseq-interactive ${MT_DATA_PATH} \
--max-tokens 4000 --buffer-size 16 \
--num-workers 32 --path ${MT_MODEL} \
--beam $N --nbest $N \
--post-process sentencepiece &> test-hypo.out
# replace "bleu" with "ter" to evaluate TER
# Add --target-text for evaluating BLEU/TER,
# otherwise the script will only generate the hypotheses with the highest scores.
python drnmt_rerank.py \
${OUTPUT_DIR}/$METRIC/split1/ \
--path ${EXP_DIR}/checkpoint_best.pt \
--in-text test-hypo.out \
--results-path ${EXP_DIR} \
--gen-subset test \
--user-dir ${FAIRSEQ_ROOT}/examples/discriminative_reranking_nmt \
--bpe sentencepiece \
--sentencepiece-model ${XLMR_DIR}/sentencepiece.bpe.model \
--beam $N \
--batch-size $N \
--metric bleu \
--fw-weight ${BEST_FW_WEIGHT} \
--lenpen ${BEST_LENPEN}
```
## Citation
```bibtex
@inproceedings{lee2021discriminative,
title={Discriminative Reranking for Neural Machine Translation},
author={Lee, Ann and Auli, Michael and Ranzato, Marc'Aurelio},
booktitle={ACL},
year={2021}
}
```
from . import criterions, models, tasks # noqa
# @package _group_
common:
fp16: true
log_format: json
log_interval: 50
seed: 2
checkpoint:
no_epoch_checkpoints: true
best_checkpoint_metric: bleu
maximize_best_checkpoint_metric: true
task:
_name: discriminative_reranking_nmt
data: ???
num_data_splits: ???
include_src: true
mt_beam: 50
eval_target_metric: true
target_metric: bleu
dataset:
batch_size: 50
num_workers: 6
required_batch_size_multiple: 50
valid_subset: ???
criterion:
_name: kl_divergence_rereanking
target_dist_norm: minmax
temperature: 0.5
optimization:
max_epoch: 200
lr: [0.00005]
update_freq: [32]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
lr_scheduler:
_name: polynomial_decay
warmup_updates: 8000
total_num_update: 320000
model:
_name: discriminative_nmt_reranker
pretrained_model: ???
classifier_dropout: 0.2
distributed_training:
ddp_backend: no_c10d
distributed_world_size: 16
from .discriminative_reranking_criterion import KLDivergenceRerankingCriterion
__all__ = [
"KLDivergenceRerankingCriterion",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
_EPSILON = torch.finfo(torch.float32).eps
TARGET_DIST_NORM_CHOICES = ChoiceEnum(["none", "minmax"])
@dataclass
class KLDivergenceRerankingCriterionConfig(FairseqDataclass):
target_dist_norm: TARGET_DIST_NORM_CHOICES = field(
default="none",
metadata={"help": "method to normalize the range of target scores"},
)
temperature: float = field(
default=1.0,
metadata={"help": "temperature in softmax for target distributions"},
)
forward_batch_size: int = field(
default=32,
metadata={
"help": "number of hypotheses per batch for model forward (set a value smaller than --mt-beam to avoid OOM when training with a large beam size)"
},
)
@register_criterion(
"kl_divergence_rereanking", dataclass=KLDivergenceRerankingCriterionConfig
)
class KLDivergenceRerankingCriterion(FairseqCriterion):
def __init__(
self, task, target_dist_norm, temperature, forward_batch_size,
):
super().__init__(task)
self.target_dist_norm = target_dist_norm
self.temperature = temperature
self.forward_batch_size = forward_batch_size
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
sample_size = sample["id"].numel()
assert sample_size % self.task.cfg.mt_beam == 0, (
f"sample_size ({sample_size}) cannot be divided by beam size ({self.task.cfg.mt_beam})."
f"Please set --required-batch-size-multiple={self.task.cfg.mt_beam}."
)
# split into smaller batches for model forward
batch_out = []
for i in range(0, sample_size, self.forward_batch_size):
j = min(i + self.forward_batch_size, sample_size)
out = model(
src_tokens=sample["net_input"]["src_tokens"][i:j, :],
src_lengths=sample["net_input"]["src_lengths"][i:j],
)
batch_out.append(
model.sentence_forward(out, sample["net_input"]["src_tokens"][i:j, :])
)
batch_out = torch.cat(batch_out, dim=0).view(
self.task.cfg.mt_beam, sample_size // self.task.cfg.mt_beam, -1
) # T x B x C
if model.joint_classification == "sent":
batch_out = model.joint_forward(batch_out)
scores = model.classification_forward(batch_out.view(sample_size, 1, -1)).view(
-1, self.task.cfg.mt_beam
) # input: B x T x C
loss = self.compute_kl_loss(
scores, sample["target"][:, 0].view(-1, self.task.cfg.mt_beam)
)
sample_size = sample_size // self.task.cfg.mt_beam
logging_output = {
"loss": loss.detach(),
"ntokens": sample["ntokens"],
"nsentences": sample_size * self.task.cfg.mt_beam,
"sample_size": sample_size,
"scores": scores.detach(),
}
return loss, sample_size, logging_output
def compute_kl_loss(self, logits, target):
norm_target = target
if self.target_dist_norm == "minmax":
min_v = torch.min(target, 1, keepdim=True).values
max_v = torch.max(target, 1, keepdim=True).values
norm_target = (target - min_v) / (max_v - min_v + _EPSILON)
target_dist = F.softmax(
norm_target / self.temperature, dim=-1, dtype=torch.float32
)
model_dist = F.log_softmax(logits, dim=-1, dtype=torch.float32)
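# KL(p || q) summed over the batch: model_dist holds log q, so -(p * log q - p * log p) = p * (log p - log q)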
loss = -(target_dist * model_dist - target_dist * target_dist.log()).sum()
return loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
loss = loss_sum / sample_size / math.log(2)
metrics.log_scalar("loss", loss, sample_size, round=3)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Score raw text with a trained model.
"""
from collections import namedtuple
import logging
from multiprocessing import Pool
import sys
import os
import random
import numpy as np
import sacrebleu
import torch
from fairseq import checkpoint_utils, options, utils
logger = logging.getLogger("fairseq_cli.drnmt_rerank")
logger.setLevel(logging.INFO)
Batch = namedtuple("Batch", "ids src_tokens src_lengths")
pool_init_variables = {}
def init_loaded_scores(mt_scores, model_scores, hyp, ref):
global pool_init_variables
pool_init_variables["mt_scores"] = mt_scores
pool_init_variables["model_scores"] = model_scores
pool_init_variables["hyp"] = hyp
pool_init_variables["ref"] = ref
def parse_fairseq_gen(filename, task):
source = {}
hypos = {}
scores = {}
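# lines from fairseq-interactive output are expected to look like:
#   S-<id>\t<source sentence>
#   D-<id>\t<score>\t<detokenized hypothesis>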
with open(filename, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("S-"): # source
uid, text = line.split("\t", 1)
uid = int(uid[2:])
source[uid] = text
elif line.startswith("D-"): # hypo
uid, score, text = line.split("\t", 2)
uid = int(uid[2:])
if uid not in hypos:
hypos[uid] = []
scores[uid] = []
hypos[uid].append(text)
scores[uid].append(float(score))
else:
continue
source_out = [source[i] for i in range(len(hypos))]
hypos_out = [h for i in range(len(hypos)) for h in hypos[i]]
scores_out = [s for i in range(len(scores)) for s in scores[i]]
return source_out, hypos_out, scores_out
def read_target(filename):
with open(filename, "r", encoding="utf-8") as f:
output = [line.strip() for line in f]
return output
def make_batches(args, src, hyp, task, max_positions, encode_fn):
assert len(src) * args.beam == len(
hyp
), f"Expect {len(src) * args.beam} hypotheses for {len(src)} source sentences with beam size {args.beam}. Got {len(hyp)} hypotheses intead."
hyp_encode = [
task.source_dictionary.encode_line(encode_fn(h), add_if_not_exist=False).long()
for h in hyp
]
if task.cfg.include_src:
src_encode = [
task.source_dictionary.encode_line(
encode_fn(s), add_if_not_exist=False
).long()
for s in src
]
tokens = [(src_encode[i // args.beam], h) for i, h in enumerate(hyp_encode)]
lengths = [(t1.numel(), t2.numel()) for t1, t2 in tokens]
else:
tokens = [(h,) for h in hyp_encode]
lengths = [(h.numel(),) for h in hyp_encode]
itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(tokens, lengths),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=max_positions,
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
ids=batch["id"],
src_tokens=batch["net_input"]["src_tokens"],
src_lengths=batch["net_input"]["src_lengths"],
)
def decode_rerank_scores(args):
if args.max_tokens is None and args.batch_size is None:
args.batch_size = 1
logger.info(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Load ensemble
logger.info("loading model(s) from {}".format(args.path))
models, _model_args, task = checkpoint_utils.load_model_ensemble_and_task(
[args.path], arg_overrides=eval(args.model_overrides),
)
for model in models:
if args.fp16:
model.half()
if use_cuda:
model.cuda()
# Initialize generator
generator = task.build_generator(args)
# Handle tokenization and BPE
tokenizer = task.build_tokenizer(args)
bpe = task.build_bpe(args)
def encode_fn(x):
if tokenizer is not None:
x = tokenizer.encode(x)
if bpe is not None:
x = bpe.encode(x)
return x
max_positions = utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models]
)
src, hyp, mt_scores = parse_fairseq_gen(args.in_text, task)
model_scores = {}
logger.info("decode reranker score")
for batch in make_batches(args, src, hyp, task, max_positions, encode_fn):
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
if use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
sample = {
"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths},
}
scores = task.inference_step(generator, models, sample)
for id, sc in zip(batch.ids.tolist(), scores.tolist()):
model_scores[id] = sc[0]
model_scores = [model_scores[i] for i in range(len(model_scores))]
return src, hyp, mt_scores, model_scores
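# combined reranking score: length-normalized forward MT score weighted by w1, plus the reranker (model) score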
def get_score(mt_s, md_s, w1, lp, tgt_len):
return mt_s / (tgt_len ** lp) * w1 + md_s
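# hypotheses arrive in contiguous blocks of `beam` per source sentence; keep the argmax of the combined score within each block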
def get_best_hyps(mt_scores, md_scores, hypos, fw_weight, lenpen, beam):
assert len(mt_scores) == len(md_scores) and len(mt_scores) == len(hypos)
hypo_scores = []
best_hypos = []
best_scores = []
offset = 0
for i in range(len(hypos)):
tgt_len = len(hypos[i].split())
hypo_scores.append(
get_score(mt_scores[i], md_scores[i], fw_weight, lenpen, tgt_len)
)
if (i + 1) % beam == 0:
max_i = np.argmax(hypo_scores)
best_hypos.append(hypos[offset + max_i])
best_scores.append(hypo_scores[max_i])
hypo_scores = []
offset += beam
return best_hypos, best_scores
def eval_metric(args, hypos, ref):
if args.metric == "bleu":
score = sacrebleu.corpus_bleu(hypos, [ref]).score
else:
score = sacrebleu.corpus_ter(hypos, [ref]).score
return score
def score_target_hypo(args, fw_weight, lp):
mt_scores = pool_init_variables["mt_scores"]
model_scores = pool_init_variables["model_scores"]
hyp = pool_init_variables["hyp"]
ref = pool_init_variables["ref"]
best_hypos, _ = get_best_hyps(
mt_scores, model_scores, hyp, fw_weight, lp, args.beam
)
rerank_eval = None
if ref:
rerank_eval = eval_metric(args, best_hypos, ref)
print(f"fw_weight {fw_weight}, lenpen {lp}, eval {rerank_eval}")
return rerank_eval
def print_result(best_scores, best_hypos, output_file):
for i, (s, h) in enumerate(zip(best_scores, best_hypos)):
print(f"{i}\t{s}\t{h}", file=output_file)
def main(args):
utils.import_user_module(args)
src, hyp, mt_scores, model_scores = decode_rerank_scores(args)
assert (
not args.tune or args.target_text is not None
), "--target-text has to be set when tuning weights"
if args.target_text:
ref = read_target(args.target_text)
assert len(src) == len(
ref
), f"different numbers of source and target sentences ({len(src)} vs. {len(ref)})"
orig_best_hypos = [hyp[i] for i in range(0, len(hyp), args.beam)]
orig_eval = eval_metric(args, orig_best_hypos, ref)
if args.tune:
logger.info("tune weights for reranking")
random_params = np.array(
[
[
random.uniform(
args.lower_bound_fw_weight, args.upper_bound_fw_weight
),
random.uniform(args.lower_bound_lenpen, args.upper_bound_lenpen),
]
for k in range(args.num_trials)
]
)
logger.info("launching pool")
with Pool(
32,
initializer=init_loaded_scores,
initargs=(mt_scores, model_scores, hyp, ref),
) as p:
rerank_scores = p.starmap(
score_target_hypo,
[
(args, random_params[i][0], random_params[i][1],)
for i in range(args.num_trials)
],
)
if args.metric == "bleu":
best_index = np.argmax(rerank_scores)
else:
best_index = np.argmin(rerank_scores)
best_fw_weight = random_params[best_index][0]
best_lenpen = random_params[best_index][1]
else:
assert (
args.lenpen is not None and args.fw_weight is not None
), "--lenpen and --fw-weight should be set"
best_fw_weight, best_lenpen = args.fw_weight, args.lenpen
best_hypos, best_scores = get_best_hyps(
mt_scores, model_scores, hyp, best_fw_weight, best_lenpen, args.beam
)
if args.results_path is not None:
os.makedirs(args.results_path, exist_ok=True)
output_path = os.path.join(
args.results_path, "generate-{}.txt".format(args.gen_subset),
)
with open(output_path, "w", buffering=1, encoding="utf-8") as o:
print_result(best_scores, best_hypos, o)
else:
print_result(best_scores, best_hypos, sys.stdout)
if args.target_text:
rerank_eval = eval_metric(args, best_hypos, ref)
print(f"before reranking, {args.metric.upper()}:", orig_eval)
print(
f"after reranking with fw_weight={best_fw_weight}, lenpen={best_lenpen}, {args.metric.upper()}:",
rerank_eval,
)
def cli_main():
parser = options.get_generation_parser(interactive=True)
parser.add_argument(
"--in-text",
default=None,
required=True,
help="text from fairseq-interactive output, containing source sentences and hypotheses",
)
parser.add_argument("--target-text", default=None, help="reference text")
parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu")
parser.add_argument(
"--tune",
action="store_true",
help="if set, tune weights on fw scores and lenpen instead of applying fixed weights for reranking",
)
parser.add_argument(
"--lower-bound-fw-weight",
default=0.0,
type=float,
help="lower bound of search space",
)
parser.add_argument(
"--upper-bound-fw-weight",
default=3,
type=float,
help="upper bound of search space",
)
parser.add_argument(
"--lower-bound-lenpen",
default=0.0,
type=float,
help="lower bound of search space",
)
parser.add_argument(
"--upper-bound-lenpen",
default=3,
type=float,
help="upper bound of search space",
)
parser.add_argument(
"--fw-weight", type=float, default=None, help="weight on the fw model score"
)
parser.add_argument(
"--num-trials",
default=1000,
type=int,
help="number of trials to do for random search",
)
args = options.parse_args_and_arch(parser)
main(args)
if __name__ == "__main__":
cli_main()
from .discriminative_reranking_model import DiscriminativeNMTReranker
__all__ = [
"DiscriminativeNMTReranker",
]
from dataclasses import dataclass, field
import os
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
BaseFairseqModel,
register_model,
)
from fairseq.models.roberta.model import RobertaClassificationHead
from fairseq.modules import (
LayerNorm,
TransformerSentenceEncoder,
TransformerSentenceEncoderLayer,
)
ACTIVATION_FN_CHOICES = ChoiceEnum(utils.get_available_activation_fns())
JOINT_CLASSIFICATION_CHOICES = ChoiceEnum(["none", "sent"])
SENTENCE_REP_CHOICES = ChoiceEnum(["head", "meanpool", "maxpool"])
def update_init_roberta_model_state(state):
"""
update the state_dict of a Roberta model for initializing
weights of the BertRanker
"""
for k in list(state.keys()):
if ".lm_head." in k or "version" in k:
del state[k]
continue
# remove 'encoder/decoder.sentence_encoder.' from the key
assert k.startswith("encoder.sentence_encoder.") or k.startswith(
"decoder.sentence_encoder."
), f"Cannot recognize parameter name {k}"
if "layernorm_embedding" in k:
new_k = k.replace(".layernorm_embedding.", ".emb_layer_norm.")
state[new_k[25:]] = state[k]
else:
state[k[25:]] = state[k]
del state[k]
class BaseRanker(nn.Module):
def __init__(self, args, task):
super().__init__()
self.separator_token = task.dictionary.eos()
self.padding_idx = task.dictionary.pad()
def forward(self, src_tokens):
raise NotImplementedError
def get_segment_labels(self, src_tokens):
segment_boundary = (src_tokens == self.separator_token).long()
segment_labels = (
segment_boundary.cumsum(dim=1)
- segment_boundary
- (src_tokens == self.padding_idx).long()
)
return segment_labels
def get_positions(self, src_tokens, segment_labels):
segment_positions = (
torch.arange(src_tokens.shape[1])
.to(src_tokens.device)
.repeat(src_tokens.shape[0], 1)
)
segment_boundary = (src_tokens == self.separator_token).long()
_, col_idx = (segment_positions * segment_boundary).nonzero(as_tuple=True)
col_idx = torch.cat([torch.zeros(1).type_as(col_idx), col_idx])
offset = torch.cat(
[
torch.zeros(1).type_as(segment_boundary),
segment_boundary.sum(dim=1).cumsum(dim=0)[:-1],
]
)
segment_positions -= col_idx[segment_labels + offset.unsqueeze(1)] * (
segment_labels != 0
)
padding_mask = src_tokens.ne(self.padding_idx)
segment_positions = (segment_positions + 1) * padding_mask.type_as(
segment_positions
) + self.padding_idx
return segment_positions
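    # Note (added, not in the original): position indices are computed per segment
    # rather than over the whole concatenated sequence. For the hypothesis segment the
    # column of the first separator is subtracted so its positions restart, and,
    # following the RoBERTa convention, non-padding positions are offset by
    # padding_idx + 1 while padding positions map to padding_idx itself.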
class BertRanker(BaseRanker):
def __init__(self, args, task):
super(BertRanker, self).__init__(args, task)
init_model = getattr(args, "pretrained_model", "")
self.joint_layers = nn.ModuleList()
if os.path.isfile(init_model):
print(f"initialize weight from {init_model}")
from fairseq import hub_utils
x = hub_utils.from_pretrained(
os.path.dirname(init_model),
checkpoint_file=os.path.basename(init_model),
)
in_state_dict = x["models"][0].state_dict()
init_args = x["args"].model
num_positional_emb = init_args.max_positions + task.dictionary.pad() + 1
# follow the setup in roberta
self.model = TransformerSentenceEncoder(
padding_idx=task.dictionary.pad(),
vocab_size=len(task.dictionary),
num_encoder_layers=getattr(
args, "encoder_layers", init_args.encoder_layers
),
embedding_dim=init_args.encoder_embed_dim,
ffn_embedding_dim=init_args.encoder_ffn_embed_dim,
num_attention_heads=init_args.encoder_attention_heads,
dropout=init_args.dropout,
attention_dropout=init_args.attention_dropout,
activation_dropout=init_args.activation_dropout,
num_segments=2, # add language embeddings
max_seq_len=num_positional_emb,
offset_positions_by_padding=False,
encoder_normalize_before=True,
apply_bert_init=True,
activation_fn=init_args.activation_fn,
freeze_embeddings=args.freeze_embeddings,
n_trans_layers_to_freeze=args.n_trans_layers_to_freeze,
)
# still need to learn segment embeddings as we added a second language embedding
if args.freeze_embeddings:
for p in self.model.segment_embeddings.parameters():
p.requires_grad = False
update_init_roberta_model_state(in_state_dict)
print("loading weights from the pretrained model")
self.model.load_state_dict(
in_state_dict, strict=False
) # ignore mismatch in language embeddings
ffn_embedding_dim = init_args.encoder_ffn_embed_dim
num_attention_heads = init_args.encoder_attention_heads
dropout = init_args.dropout
attention_dropout = init_args.attention_dropout
activation_dropout = init_args.activation_dropout
activation_fn = init_args.activation_fn
classifier_embed_dim = getattr(
args, "embed_dim", init_args.encoder_embed_dim
)
if classifier_embed_dim != init_args.encoder_embed_dim:
self.transform_layer = nn.Linear(
init_args.encoder_embed_dim, classifier_embed_dim
)
else:
self.model = TransformerSentenceEncoder(
padding_idx=task.dictionary.pad(),
vocab_size=len(task.dictionary),
num_encoder_layers=args.encoder_layers,
embedding_dim=args.embed_dim,
ffn_embedding_dim=args.ffn_embed_dim,
num_attention_heads=args.attention_heads,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
max_seq_len=task.max_positions()
if task.max_positions()
else args.tokens_per_sample,
num_segments=2,
offset_positions_by_padding=False,
encoder_normalize_before=args.encoder_normalize_before,
apply_bert_init=args.apply_bert_init,
activation_fn=args.activation_fn,
)
classifier_embed_dim = args.embed_dim
ffn_embedding_dim = args.ffn_embed_dim
num_attention_heads = args.attention_heads
dropout = args.dropout
attention_dropout = args.attention_dropout
activation_dropout = args.activation_dropout
activation_fn = args.activation_fn
self.joint_classification = args.joint_classification
if args.joint_classification == "sent":
if args.joint_normalize_before:
self.joint_layer_norm = LayerNorm(classifier_embed_dim)
else:
self.joint_layer_norm = None
self.joint_layers = nn.ModuleList(
[
TransformerSentenceEncoderLayer(
embedding_dim=classifier_embed_dim,
ffn_embedding_dim=ffn_embedding_dim,
num_attention_heads=num_attention_heads,
dropout=dropout,
attention_dropout=attention_dropout,
activation_dropout=activation_dropout,
activation_fn=activation_fn,
)
for _ in range(args.num_joint_layers)
]
)
self.classifier = RobertaClassificationHead(
classifier_embed_dim,
classifier_embed_dim,
1, # num_classes
"tanh",
args.classifier_dropout,
)
def forward(self, src_tokens, src_lengths):
segment_labels = self.get_segment_labels(src_tokens)
positions = self.get_positions(src_tokens, segment_labels)
inner_states, _ = self.model(
tokens=src_tokens,
segment_labels=segment_labels,
last_state_only=True,
positions=positions,
)
return inner_states[-1].transpose(0, 1) # T x B x C -> B x T x C
def sentence_forward(self, encoder_out, src_tokens=None, sentence_rep="head"):
# encoder_out: B x T x C
if sentence_rep == "head":
x = encoder_out[:, :1, :]
else: # 'meanpool', 'maxpool'
            assert src_tokens is not None, "meanpool/maxpool require src_tokens input"
segment_labels = self.get_segment_labels(src_tokens)
padding_mask = src_tokens.ne(self.padding_idx)
encoder_mask = segment_labels * padding_mask.type_as(segment_labels)
if sentence_rep == "meanpool":
ntokens = torch.sum(encoder_mask, dim=1, keepdim=True)
x = torch.sum(
encoder_out * encoder_mask.unsqueeze(2), dim=1, keepdim=True
) / ntokens.unsqueeze(2).type_as(encoder_out)
else: # 'maxpool'
encoder_out[
(encoder_mask == 0).unsqueeze(2).repeat(1, 1, encoder_out.shape[-1])
] = -float("inf")
x, _ = torch.max(encoder_out, dim=1, keepdim=True)
if hasattr(self, "transform_layer"):
x = self.transform_layer(x)
return x # B x 1 x C
def joint_forward(self, x):
# x: T x B x C
if self.joint_layer_norm:
x = self.joint_layer_norm(x.transpose(0, 1))
x = x.transpose(0, 1)
for layer in self.joint_layers:
x, _ = layer(x, self_attn_padding_mask=None)
return x
def classification_forward(self, x):
# x: B x T x C
return self.classifier(x)
@dataclass
class DiscriminativeNMTRerankerConfig(FairseqDataclass):
pretrained_model: str = field(
default="", metadata={"help": "pretrained model to load"}
)
sentence_rep: SENTENCE_REP_CHOICES = field(
default="head",
metadata={
"help": "method to transform the output of the transformer stack to a sentence-level representation"
},
)
dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
attention_dropout: float = field(
default=0.0, metadata={"help": "dropout probability for attention weights"}
)
activation_dropout: float = field(
default=0.0, metadata={"help": "dropout probability after activation in FFN"}
)
classifier_dropout: float = field(
default=0.0, metadata={"help": "classifier dropout probability"}
)
embed_dim: int = field(default=768, metadata={"help": "embedding dimension"})
ffn_embed_dim: int = field(
default=2048, metadata={"help": "embedding dimension for FFN"}
)
encoder_layers: int = field(default=12, metadata={"help": "num encoder layers"})
attention_heads: int = field(default=8, metadata={"help": "num attention heads"})
encoder_normalize_before: bool = field(
default=False, metadata={"help": "apply layernorm before each encoder block"}
)
apply_bert_init: bool = field(
default=False, metadata={"help": "use custom param initialization for BERT"}
)
activation_fn: ACTIVATION_FN_CHOICES = field(
default="relu", metadata={"help": "activation function to use"}
)
freeze_embeddings: bool = field(
default=False, metadata={"help": "freeze embeddings in the pretrained model"}
)
n_trans_layers_to_freeze: int = field(
default=0,
metadata={
"help": "number of layers to freeze in the pretrained transformer model"
},
)
    # joint classification
joint_classification: JOINT_CLASSIFICATION_CHOICES = field(
default="none",
metadata={"help": "method to compute joint features for classification"},
)
num_joint_layers: int = field(
default=1, metadata={"help": "number of joint layers"}
)
joint_normalize_before: bool = field(
default=False,
metadata={"help": "apply layer norm on the input to the joint layer"},
)
@register_model(
"discriminative_nmt_reranker", dataclass=DiscriminativeNMTRerankerConfig
)
class DiscriminativeNMTReranker(BaseFairseqModel):
@classmethod
def build_model(cls, args, task):
model = BertRanker(args, task)
return DiscriminativeNMTReranker(args, model)
def __init__(self, args, model):
super().__init__()
self.model = model
self.sentence_rep = args.sentence_rep
self.joint_classification = args.joint_classification
def forward(self, src_tokens, src_lengths, **kwargs):
return self.model(src_tokens, src_lengths)
def sentence_forward(self, encoder_out, src_tokens):
return self.model.sentence_forward(encoder_out, src_tokens, self.sentence_rep)
def joint_forward(self, x):
return self.model.joint_forward(x)
def classification_forward(self, x):
return self.model.classification_forward(x)
#!/usr/bin/env python
import argparse
from multiprocessing import Pool
from pathlib import Path
import sacrebleu
import sentencepiece as spm
def read_text_file(filename):
with open(filename, "r") as f:
output = [line.strip() for line in f]
return output
def get_bleu(in_sent, target_sent):
bleu = sacrebleu.corpus_bleu([in_sent], [[target_sent]])
out = " ".join(
map(str, [bleu.score, bleu.sys_len, bleu.ref_len] + bleu.counts + bleu.totals)
)
return out
def get_ter(in_sent, target_sent):
ter = sacrebleu.corpus_ter([in_sent], [[target_sent]])
out = " ".join(map(str, [ter.score, ter.num_edits, ter.ref_length]))
return out
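# Format note (added, not in the original script): each score line written below is
# later consumed as a training target by the reranking task. For BLEU a line holds
# 11 space-separated fields (score, sys_len, ref_len, four n-gram counts, four n-gram
# totals, i.e. EVAL_BLEU_ORDER * 2 + 3 columns); for TER it holds 3 fields
# (score, num_edits, ref_length).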
def init(sp_model):
global sp
sp = spm.SentencePieceProcessor()
sp.Load(sp_model)
def process(source_sent, target_sent, hypo_sent, metric):
source_bpe = " ".join(sp.EncodeAsPieces(source_sent))
hypo_bpe = [" ".join(sp.EncodeAsPieces(h)) for h in hypo_sent]
if metric == "bleu":
score_str = [get_bleu(h, target_sent) for h in hypo_sent]
else: # ter
score_str = [get_ter(h, target_sent) for h in hypo_sent]
return source_bpe, hypo_bpe, score_str
def main(args):
assert (
args.split.startswith("train") or args.num_shards == 1
), "--num-shards should be set to 1 for valid and test sets"
assert (
args.split.startswith("train")
or args.split.startswith("valid")
or args.split.startswith("test")
), "--split should be set to train[n]/valid[n]/test[n]"
source_sents = read_text_file(args.input_source)
target_sents = read_text_file(args.input_target)
num_sents = len(source_sents)
assert num_sents == len(
target_sents
), f"{args.input_source} and {args.input_target} should have the same number of sentences."
hypo_sents = read_text_file(args.input_hypo)
assert (
len(hypo_sents) % args.beam == 0
), f"Number of hypotheses ({len(hypo_sents)}) cannot be divided by beam size ({args.beam})."
hypo_sents = [
hypo_sents[i : i + args.beam] for i in range(0, len(hypo_sents), args.beam)
]
assert num_sents == len(
hypo_sents
), f"{args.input_hypo} should contain {num_sents * args.beam} hypotheses but only has {len(hypo_sents) * args.beam}. (--beam={args.beam})"
output_dir = args.output_dir / args.metric
for ns in range(args.num_shards):
print(f"processing shard {ns+1}/{args.num_shards}")
shard_output_dir = output_dir / f"split{ns+1}"
source_output_dir = shard_output_dir / "input_src"
hypo_output_dir = shard_output_dir / "input_tgt"
metric_output_dir = shard_output_dir / args.metric
source_output_dir.mkdir(parents=True, exist_ok=True)
hypo_output_dir.mkdir(parents=True, exist_ok=True)
metric_output_dir.mkdir(parents=True, exist_ok=True)
if args.n_proc > 1:
with Pool(
args.n_proc, initializer=init, initargs=(args.sentencepiece_model,)
) as p:
output = p.starmap(
process,
[
(source_sents[i], target_sents[i], hypo_sents[i], args.metric)
for i in range(ns, num_sents, args.num_shards)
],
)
else:
init(args.sentencepiece_model)
output = [
process(source_sents[i], target_sents[i], hypo_sents[i], args.metric)
for i in range(ns, num_sents, args.num_shards)
]
with open(source_output_dir / f"{args.split}.bpe", "w") as s_o, open(
hypo_output_dir / f"{args.split}.bpe", "w"
) as h_o, open(metric_output_dir / f"{args.split}.{args.metric}", "w") as m_o:
for source_bpe, hypo_bpe, score_str in output:
assert len(hypo_bpe) == len(score_str)
for h, m in zip(hypo_bpe, score_str):
s_o.write(f"{source_bpe}\n")
h_o.write(f"{h}\n")
m_o.write(f"{m}\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input-source", type=Path, required=True)
parser.add_argument("--input-target", type=Path, required=True)
parser.add_argument("--input-hypo", type=Path, required=True)
parser.add_argument("--output-dir", type=Path, required=True)
parser.add_argument("--split", type=str, required=True)
parser.add_argument("--beam", type=int, required=True)
parser.add_argument("--sentencepiece-model", type=str, required=True)
parser.add_argument("--metric", type=str, choices=["bleu", "ter"], default="bleu")
parser.add_argument("--num-shards", type=int, default=1)
parser.add_argument("--n-proc", type=int, default=8)
args = parser.parse_args()
main(args)
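# Resulting layout (added note, not in the original script), following the paths built
# in main() for --output-dir OUT, --metric bleu, --split train and shard n:
#   OUT/bleu/split{n}/input_src/train.bpe   source BPE, repeated once per hypothesis
#   OUT/bleu/split{n}/input_tgt/train.bpe   hypothesis BPE, --beam lines per source
#   OUT/bleu/split{n}/bleu/train.bleu       one score line per hypothesis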
from .discriminative_reranking_task import DiscriminativeRerankingNMTTask
__all__ = [
"DiscriminativeRerankingNMTTask",
]
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import itertools
import logging
import os
import numpy as np
import torch
from fairseq.logging import metrics
from fairseq.data import (
ConcatDataset,
ConcatSentencesDataset,
data_utils,
Dictionary,
IdDataset,
indexed_dataset,
NestedDictionaryDataset,
NumSamplesDataset,
NumelDataset,
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
SortDataset,
TruncateDataset,
TokenBlockDataset,
)
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import II, MISSING
EVAL_BLEU_ORDER = 4
TARGET_METRIC_CHOICES = ChoiceEnum(["bleu", "ter"])
logger = logging.getLogger(__name__)
@dataclass
class DiscriminativeRerankingNMTConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
num_data_splits: int = field(
default=1, metadata={"help": "total number of data splits"}
)
no_shuffle: bool = field(
default=False, metadata={"help": "do not shuffle training data"}
)
max_positions: int = field(
default=512, metadata={"help": "number of positional embeddings to learn"}
)
include_src: bool = field(
default=False, metadata={"help": "include source sentence"}
)
mt_beam: int = field(default=50, metadata={"help": "beam size of input hypotheses"})
eval_target_metric: bool = field(
default=False,
metadata={"help": "evaluation with the target metric during validation"},
)
target_metric: TARGET_METRIC_CHOICES = field(
default="bleu", metadata={"help": "name of the target metric to optimize for"}
)
train_subset: str = field(
default=II("dataset.train_subset"),
metadata={"help": "data subset to use for training (e.g. train, valid, test)"},
)
seed: int = field(
default=II("common.seed"),
metadata={"help": "pseudo random number generator seed"},
)
class RerankerScorer(object):
"""Scores the target for a given (source (optional), target) input."""
def __init__(self, args, mt_beam):
self.mt_beam = mt_beam
@torch.no_grad()
def generate(self, models, sample, **kwargs):
"""Score a batch of translations."""
net_input = sample["net_input"]
assert len(models) == 1, "does not support model ensemble"
model = models[0]
bs = net_input["src_tokens"].shape[0]
assert (
model.joint_classification == "none" or bs % self.mt_beam == 0
), f"invalid batch size ({bs}) for joint classification with beam size ({self.mt_beam})"
model.eval()
logits = model(**net_input)
batch_out = model.sentence_forward(logits, net_input["src_tokens"])
if model.joint_classification == "sent":
batch_out = model.joint_forward(
batch_out.view(self.mt_beam, bs // self.mt_beam, -1)
)
scores = model.classification_forward(
batch_out.view(bs, 1, -1)
) # input: B x T x C
return scores
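    # Usage sketch (added, not in the original): `generate` returns one score per input
    # row. Since rows are laid out as mt_beam consecutive hypotheses per source
    # sentence, a caller could pick the best hypothesis per source roughly as
    #   best = scores.view(-1, self.mt_beam).argmax(dim=1)
    # where `best` is a hypothetical variable holding the per-source argmax index.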
@register_task(
"discriminative_reranking_nmt", dataclass=DiscriminativeRerankingNMTConfig
)
class DiscriminativeRerankingNMTTask(FairseqTask):
"""
    Translation reranking task.
The input can be either (src, tgt) sentence pairs or tgt sentence only.
"""
cfg: DiscriminativeRerankingNMTConfig
def __init__(self, cfg: DiscriminativeRerankingNMTConfig, data_dictionary=None):
super().__init__(cfg)
self.dictionary = data_dictionary
self._max_positions = cfg.max_positions
# args.tokens_per_sample = self._max_positions
# self.num_classes = 1 # for model
@classmethod
def load_dictionary(cls, cfg, filename):
"""Load the dictionary from the filename"""
dictionary = Dictionary.load(filename)
dictionary.add_symbol("<mask>") # for loading pretrained XLMR model
return dictionary
@classmethod
def setup_task(cls, cfg: DiscriminativeRerankingNMTConfig, **kwargs):
# load data dictionary (assume joint dictionary)
data_path = cfg.data
data_dict = cls.load_dictionary(
cfg, os.path.join(data_path, "input_src/dict.txt")
)
logger.info("[input] src dictionary: {} types".format(len(data_dict)))
return DiscriminativeRerankingNMTTask(cfg, data_dict)
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
"""Load a given dataset split (e.g., train, valid, test)."""
if self.cfg.data.endswith("1"):
data_shard = (epoch - 1) % self.cfg.num_data_splits + 1
data_path = self.cfg.data[:-1] + str(data_shard)
else:
data_path = self.cfg.data
def get_path(type, data_split):
return os.path.join(data_path, str(type), data_split)
def make_dataset(type, dictionary, data_split, combine):
split_path = get_path(type, data_split)
dataset = data_utils.load_indexed_dataset(
split_path,
dictionary,
combine=combine,
)
return dataset
def load_split(data_split, metric):
input_src = None
if self.cfg.include_src:
input_src = make_dataset(
"input_src", self.dictionary, data_split, combine=False
)
assert input_src is not None, "could not find dataset: {}".format(
get_path("input_src", data_split)
)
input_tgt = make_dataset(
"input_tgt", self.dictionary, data_split, combine=False
)
assert input_tgt is not None, "could not find dataset: {}".format(
get_path("input_tgt", data_split)
)
label_path = f"{get_path(metric, data_split)}.{metric}"
assert os.path.exists(label_path), f"could not find dataset: {label_path}"
np_labels = np.loadtxt(label_path)
if self.cfg.target_metric == "ter":
np_labels = -np_labels
label = RawLabelDataset(np_labels)
return input_src, input_tgt, label
src_datasets = []
tgt_datasets = []
label_datasets = []
if split == self.cfg.train_subset:
for k in itertools.count():
split_k = "train" + (str(k) if k > 0 else "")
prefix = os.path.join(data_path, "input_tgt", split_k)
if not indexed_dataset.dataset_exists(prefix, impl=None):
if k > 0:
break
else:
raise FileNotFoundError(f"Dataset not found: {prefix}")
input_src, input_tgt, label = load_split(
split_k, self.cfg.target_metric
)
src_datasets.append(input_src)
tgt_datasets.append(input_tgt)
label_datasets.append(label)
else:
input_src, input_tgt, label = load_split(split, self.cfg.target_metric)
src_datasets.append(input_src)
tgt_datasets.append(input_tgt)
label_datasets.append(label)
if len(tgt_datasets) == 1:
input_tgt, label = tgt_datasets[0], label_datasets[0]
if self.cfg.include_src:
input_src = src_datasets[0]
else:
input_tgt = ConcatDataset(tgt_datasets)
label = ConcatDataset(label_datasets)
if self.cfg.include_src:
input_src = ConcatDataset(src_datasets)
input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
if self.cfg.include_src:
input_src = PrependTokenDataset(input_src, self.dictionary.bos())
input_src = TruncateDataset(input_src, self.cfg.max_positions)
src_lengths = NumelDataset(input_src, reduce=False)
src_tokens = ConcatSentencesDataset(input_src, input_tgt)
else:
src_tokens = PrependTokenDataset(input_tgt, self.dictionary.bos())
src_lengths = NumelDataset(src_tokens, reduce=False)
dataset = {
"id": IdDataset(),
"net_input": {
"src_tokens": RightPadDataset(
src_tokens,
pad_idx=self.source_dictionary.pad(),
),
"src_lengths": src_lengths,
},
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens, reduce=True),
"target": label,
}
dataset = NestedDictionaryDataset(
dataset,
sizes=[src_tokens.sizes],
)
assert (
len(dataset) % self.cfg.mt_beam == 0
), "dataset size (%d) is not a multiple of beam size (%d)" % (
len(dataset),
self.cfg.mt_beam,
)
# no need to shuffle valid/test sets
if not self.cfg.no_shuffle and split == self.cfg.train_subset:
            # need to keep all hypotheses of the same source sentence together:
            # shuffle the mt_beam-sized blocks, not individual rows
start_idx = np.arange(0, len(dataset), self.cfg.mt_beam)
with data_utils.numpy_seed(self.cfg.seed + epoch):
np.random.shuffle(start_idx)
idx = np.arange(0, self.cfg.mt_beam)
shuffle = np.tile(idx, (len(start_idx), 1)).reshape(-1) + np.tile(
start_idx, (self.cfg.mt_beam, 1)
).transpose().reshape(-1)
dataset = SortDataset(
dataset,
sort_order=[shuffle],
)
logger.info(f"Loaded {split} with #samples: {len(dataset)}")
self.datasets[split] = dataset
return self.datasets[split]
def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs):
assert not self.cfg.include_src or len(src_tokens[0]) == 2
input_src = None
if self.cfg.include_src:
input_src = TokenBlockDataset(
[t[0] for t in src_tokens],
[l[0] for l in src_lengths],
block_size=None, # ignored for "eos" break mode
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode="eos",
)
input_src = PrependTokenDataset(input_src, self.dictionary.bos())
input_src = TruncateDataset(input_src, self.cfg.max_positions)
input_tgt = TokenBlockDataset(
[t[-1] for t in src_tokens],
[l[-1] for l in src_lengths],
block_size=None, # ignored for "eos" break mode
pad=self.source_dictionary.pad(),
eos=self.source_dictionary.eos(),
break_mode="eos",
)
input_tgt = TruncateDataset(input_tgt, self.cfg.max_positions)
if self.cfg.include_src:
src_tokens = ConcatSentencesDataset(input_src, input_tgt)
src_lengths = NumelDataset(input_src, reduce=False)
else:
input_tgt = PrependTokenDataset(input_tgt, self.dictionary.bos())
src_tokens = input_tgt
src_lengths = NumelDataset(src_tokens, reduce=False)
dataset = {
"id": IdDataset(),
"net_input": {
"src_tokens": RightPadDataset(
src_tokens,
pad_idx=self.source_dictionary.pad(),
),
"src_lengths": src_lengths,
},
"nsentences": NumSamplesDataset(),
"ntokens": NumelDataset(src_tokens, reduce=True),
}
return NestedDictionaryDataset(
dataset,
sizes=[src_tokens.sizes],
)
def build_model(self, cfg: FairseqDataclass, from_checkpoint: bool = False):
return super().build_model(cfg)
def build_generator(self, args):
return RerankerScorer(args, mt_beam=self.cfg.mt_beam)
def max_positions(self):
return self._max_positions
@property
def source_dictionary(self):
return self.dictionary
@property
def target_dictionary(self):
return self.dictionary
def create_dummy_batch(self, device):
        # size the dummy target to match the label layout: EVAL_BLEU_ORDER * 2 + 3
        # columns for BLEU, 3 columns for TER (the config defines target_metric,
        # not an eval_ter flag)
        dummy_target = (
            torch.zeros(self.cfg.mt_beam, EVAL_BLEU_ORDER * 2 + 3).long().to(device)
            if self.cfg.target_metric != "ter"
            else torch.zeros(self.cfg.mt_beam, 3).long().to(device)
        )
return {
"id": torch.zeros(self.cfg.mt_beam, 1).long().to(device),
"net_input": {
"src_tokens": torch.zeros(self.cfg.mt_beam, 4).long().to(device),
"src_lengths": torch.ones(self.cfg.mt_beam, 1).long().to(device),
},
"nsentences": 0,
"ntokens": 0,
"target": dummy_target,
}
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
):
if ignore_grad and sample is None:
sample = self.create_dummy_batch(model.device)
return super().train_step(
sample, model, criterion, optimizer, update_num, ignore_grad
)
def valid_step(self, sample, model, criterion):
if sample is None:
sample = self.create_dummy_batch(model.device)
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
if not self.cfg.eval_target_metric:
return loss, sample_size, logging_output
scores = logging_output["scores"]
if self.cfg.target_metric == "bleu":
assert sample["target"].shape[1] == EVAL_BLEU_ORDER * 2 + 3, (
"target does not contain enough information ("
+ str(sample["target"].shape[1])
+ "for evaluating BLEU"
)
max_id = torch.argmax(scores, dim=1)
select_id = max_id + torch.arange(
0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
).to(max_id.device)
bleu_data = sample["target"][select_id, 1:].sum(0).data
logging_output["_bleu_sys_len"] = bleu_data[0]
logging_output["_bleu_ref_len"] = bleu_data[1]
for i in range(EVAL_BLEU_ORDER):
logging_output["_bleu_counts_" + str(i)] = bleu_data[2 + i]
logging_output["_bleu_totals_" + str(i)] = bleu_data[
2 + EVAL_BLEU_ORDER + i
]
elif self.cfg.target_metric == "ter":
assert sample["target"].shape[1] == 3, (
"target does not contain enough information ("
+ str(sample["target"].shape[1])
+ "for evaluating TER"
)
max_id = torch.argmax(scores, dim=1)
select_id = max_id + torch.arange(
0, sample_size * self.cfg.mt_beam, self.cfg.mt_beam
).to(max_id.device)
ter_data = sample["target"][select_id, 1:].sum(0).data
logging_output["_ter_num_edits"] = -ter_data[0]
logging_output["_ter_ref_len"] = -ter_data[1]
return loss, sample_size, logging_output
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if not self.cfg.eval_target_metric:
return
def sum_logs(key):
return sum(log.get(key, 0) for log in logging_outputs)
if self.cfg.target_metric == "bleu":
counts, totals = [], []
for i in range(EVAL_BLEU_ORDER):
counts.append(sum_logs("_bleu_counts_" + str(i)))
totals.append(sum_logs("_bleu_totals_" + str(i)))
if max(totals) > 0:
# log counts as numpy arrays -- log_scalar will sum them correctly
metrics.log_scalar("_bleu_counts", np.array(counts))
metrics.log_scalar("_bleu_totals", np.array(totals))
metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len"))
metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len"))
def compute_bleu(meters):
import inspect
import sacrebleu
fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0]
if "smooth_method" in fn_sig:
smooth = {"smooth_method": "exp"}
else:
smooth = {"smooth": "exp"}
bleu = sacrebleu.compute_bleu(
correct=meters["_bleu_counts"].sum,
total=meters["_bleu_totals"].sum,
sys_len=meters["_bleu_sys_len"].sum,
ref_len=meters["_bleu_ref_len"].sum,
**smooth,
)
return round(bleu.score, 2)
metrics.log_derived("bleu", compute_bleu)
elif self.cfg.target_metric == "ter":
num_edits = sum_logs("_ter_num_edits")
ref_len = sum_logs("_ter_ref_len")
if ref_len > 0:
metrics.log_scalar("_ter_num_edits", num_edits)
metrics.log_scalar("_ter_ref_len", ref_len)
def compute_ter(meters):
score = meters["_ter_num_edits"].sum / meters["_ter_ref_len"].sum
return round(score.item(), 2)
metrics.log_derived("ter", compute_ter)