Unverified Commit 33d479d2 authored by Sam Shleifer, committed by GitHub

[s2s] distributed eval in one command (#7124)

parent 206b78d4
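For context, a sketch of the "one command" workflow this commit enables. The script filename and the model/data placeholders below are assumptions (they are not named in this excerpt); the flags shown are the ones added or referenced in this diff, and `torch.distributed.launch` supplies `--local_rank` to each process.

```bash
# Sketch only: the entry-point name and $MODEL_NAME/$DATA_DIR/$SAVE_DIR are placeholders, not from this diff.
# torch.distributed.launch spawns one process per GPU and passes --local_rank to each.
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
    --model_name $MODEL_NAME \
    --data_dir $DATA_DIR \
    --save_dir $SAVE_DIR \
    --type_path test \
    --bs 16 \
    --fp16 \
    --sync_timeout 600
# Each rank writes a rank_*.json file under {save_dir}_tmp; rank 0 (local_rank <= 0) waits for all of them,
# computes ROUGE (or BLEU for translation tasks), and writes {type_path}_generations.txt plus a metrics json.
```

Unspecified extra args such as `--num_beams=2` are forwarded to `model.generate`, per the parser epilog shown below.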
from pathlib import Path

import fire

try:
    from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
except ImportError:
    from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file


def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Combine the rank*.json partial results in result_dir, compute metrics, and write outputs to save_dir."""
    src_dir = Path(result_dir)
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)
    paths_to_combine = list(src_dir.glob("rank*.json"))
    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    save_json(metrics, save_dir.joinpath("metrics.json"))  # better would be {prefix}_{rouge|bleu}.json
    print(metrics)
    if just_metrics:
        return
    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")
    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    if "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)


if __name__ == "__main__":
    fire.Fire(combine_partial_results)
@@ -12,7 +12,7 @@ Note: You need to have your test_generations.txt before you start this process.
cd $HOME
git clone git@github.com:moses-smt/mosesdecoder.git
cd mosesdecoder
git@github.com:rsennrich/wmt16-scripts.git
git clone git@github.com:rsennrich/wmt16-scripts.git
```
(2) Define a function for post-processing.
......
import argparse
import shutil
import time
from json import JSONDecodeError
from logging import getLogger
from pathlib import Path
from typing import Dict
from typing import Dict, List, Tuple
import torch
from torch.utils.data import DataLoader
@@ -13,12 +16,29 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger = getLogger(__name__)

try:
    from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
    from .utils import (
        Seq2SeqDataset,
        calculate_bleu,
        calculate_rouge,
        lmap,
        load_json,
        parse_numeric_cl_kwargs,
        save_json,
        use_task_specific_params,
        write_txt_file,
    )
except ImportError:
    from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    from utils import (
        Seq2SeqDataset,
        calculate_bleu,
        calculate_rouge,
        lmap,
        load_json,
        parse_numeric_cl_kwargs,
        save_json,
        use_task_specific_params,
        write_txt_file,
    )
def eval_data_dir(
@@ -30,7 +50,6 @@ def eval_data_dir(
    type_path="val",
    n_obs=None,
    fp16=False,
    save_source=False,
    num_beams: int = 4,
    task="summarization",
    local_rank=None,
@@ -62,7 +81,7 @@ def eval_data_dir(
        n_obs=n_obs,
        prefix=model.config.prefix,
    )
    sampler = ds.make_sortish_sampler(bs, distributed=True)
    sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
    dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)  # tokenizer.decode
    results = []
@@ -75,23 +94,19 @@
        )
        preds = tokenizer.batch_decode(summaries, **dec_kwargs)
        labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
        if save_source:
            docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
        ids = batch["ids"]
        for i in range(len(labels)):
            label, pred = labels[i], preds[i]
            if save_source:
                results.append(dict(pred=pred, label=label, source=docs[i]))
            else:
                results.append(dict(pred=pred, label=label))
            results.append(dict(pred=pred, label=label, id=ids[i].item()))
    save_json(results, save_path)
    return results
    return results, sampler.num_replicas
def run_generate():
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    parser.add_argument("--input_path", type=str, help="like cnn_dm/test.source")
    parser.add_argument("--data_dir", type=str, help="like cnn_dm")
    parser.add_argument(
        "--model_name",
        type=str,
@@ -113,17 +128,31 @@ def run_generate():
    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument(
        "--sync_timeout",
        type=int,
        default=600,
        required=False,
        help="How long (in seconds) the master process should wait for the other processes to finish.",
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--save_source", action="store_true")
    parser.add_argument("--debug", action="store_true")
    start_time = time.time()
    args, rest = parser.parse_known_args()
    generate_kwargs = parse_numeric_cl_kwargs(rest)
    if generate_kwargs:
        print(f"parsed the following generate kwargs: {generate_kwargs}")
    json_save_dir = Path(args.save_dir + "_tmp")
    Path(json_save_dir).mkdir(exist_ok=True)  # this handles locking.
    intermediate_files = list(json_save_dir.glob("rank_*.json"))
    if intermediate_files:
        raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
        # In theory, a node could finish and save before another node hits this. If this happens, we can address later.
    Path(args.save_dir).mkdir(exist_ok=True)
    eval_data_dir(
        args.input_path,
        args.save_dir,
    results, num_replicas = eval_data_dir(
        args.data_dir,
        json_save_dir,
        args.model_name,
        type_path=args.type_path,
        batch_size=args.bs,
@@ -131,11 +160,64 @@ def run_generate():
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        save_source=args.save_source,
        max_source_length=args.max_source_length,
        **generate_kwargs,
    )

    if args.local_rank <= 0:
        save_dir = Path(args.save_dir)
        save_dir.mkdir(exist_ok=True)
        partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
        preds, labels = combine_partial_results(partial_results)
        # Calculate metrics, save metrics, and save _generations.txt
        calc_bleu = "translation" in args.task
        score_fn = calculate_bleu if calc_bleu else calculate_rouge
        metric_name = "bleu" if calc_bleu else "rouge"
        metrics: Dict = score_fn(preds, labels)
        metrics["n_obs"] = len(preds)
        runtime = time.time() - start_time
        metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
        # TODO(@stas00): add whatever metadata to metrics
        metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
        save_json(metrics, metrics_save_path)
        print(metrics)
        write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
        if args.debug:
            write_txt_file(labels, save_dir.joinpath(f"{args.type_path}.target"))
        else:
            shutil.rmtree(json_save_dir)
def combine_partial_results(partial_results) -> Tuple[List, List]:
    """Concatenate the partial results, sort them by example id, and return (preds, labels)."""
    records = []
    for partial_result in partial_results:
        records.extend(partial_result)
    records = list(sorted(records, key=lambda x: x["id"]))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    return preds, labels
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
    # WAIT FOR lots of .json files
    start_wait = time.time()
    logger.info("waiting for all nodes to finish")
    json_data = None
    while (time.time() - start_wait) < timeout:
        json_files = list(save_dir.glob("rank_*.json"))
        if len(json_files) < num_replicas:
            continue
        try:
            # make sure all json files are fully saved
            json_data = lmap(load_json, json_files)
            return json_data
        except JSONDecodeError:
            continue
    else:
        raise TimeoutError("Rank 0 gave up on waiting for other processes")
    # Unreachable
if __name__ == "__main__":
    # Usage for MT:
......
@@ -18,6 +18,7 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from transformers import BartTokenizer
from transformers.file_utils import cached_property
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
@@ -114,9 +115,9 @@ class AbstractSeq2SeqDataset(Dataset):
    def get_char_lens(data_file):
        return [len(x) for x in Path(data_file).open().readlines()]

    def make_sortish_sampler(self, batch_size, distributed=False):
    def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
        if distributed:
            return DistributedSortishSampler(self, batch_size)
            return DistributedSortishSampler(self, batch_size, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size)
@@ -171,14 +172,11 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
        tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
        assert source_line, f"empty source line for index {index}"
        assert tgt_line, f"empty tgt line for index {index}"
        return {
            "tgt_texts": tgt_line,
            "src_texts": source_line,
        }
        return {"tgt_texts": tgt_line, "src_texts": source_line, "id": index - 1}

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Call prepare_seq2seq_batch."""
        batch_encoding = self.tokenizer.prepare_seq2seq_batch(
        batch_encoding: Dict[str, torch.Tensor] = self.tokenizer.prepare_seq2seq_batch(
            [x["src_texts"] for x in batch],
            src_lang=self.src_lang,
            tgt_texts=[x["tgt_texts"] for x in batch],
@@ -187,8 +185,9 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
            max_target_length=self.max_target_length,
            return_tensors="pt",
            add_prefix_space=self.add_prefix_space,
        )
        return batch_encoding.data
        ).data
        batch_encoding["ids"] = torch.tensor([x["id"] for x in batch])
        return batch_encoding
class SortishSampler(Sampler):
@@ -226,7 +225,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
class DistributedSortishSampler(Sampler):
"""Copied from torch DistributedSampler"""
def __init__(self, dataset, batch_size, num_replicas=None, rank=None):
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@@ -239,22 +238,27 @@ class DistributedSortishSampler(Sampler):
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        if add_extra_examples:
            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples

    def __iter__(self) -> Iterable:
        g = torch.Generator()
        g.manual_seed(self.epoch)
        available_indices = self.get_indices_for_rank()  # indices[self.rank: self.total_size: self.num_replicas]
        sortish_data = [self.dataset.src_lens[i] for i in available_indices]
        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
        indices = [available_indices[i] for i in sortish_indices]
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    def get_indices_for_rank(self) -> np.array:
    @cached_property
    def available_indices(self) -> np.array:
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
......