Unverified Commit 0203ad43 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] distributed eval cleanup (#7186)

parent 3babef81
......@@ -227,6 +227,20 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
--fp16 \
--bs 32
```
### Multi-GPU Evaluation
Here is a command to run XSum evaluation on 8 GPUs. In some cases the speedup over `run_eval.py` is more than linear,
because it uses `SortishSampler` to minimize padding. You can also use it on 1 GPU. `data_dir` must contain
`{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help` for all command-line arguments.
```bash
python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
--model_name sshleifer/distilbart-large-xsum-12-3 \
--save_dir xsum_generations \
--data_dir xsum \
--fp16 # you can pass generate kwargs like num_beams here, just like run_eval.py
```
Contributions that implement this command for other distributed hardware setups are welcome!
#### run_eval tips and tricks
......
......@@ -4,7 +4,7 @@ import time
from json import JSONDecodeError
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List
import torch
from torch.utils.data import DataLoader
......@@ -22,7 +22,7 @@ try:
calculate_rouge,
lmap,
load_json,
parse_numeric_cl_kwargs,
parse_numeric_n_bool_cl_kwargs,
save_json,
use_task_specific_params,
write_txt_file,
......@@ -34,7 +34,7 @@ except ImportError:
calculate_rouge,
lmap,
load_json,
parse_numeric_cl_kwargs,
parse_numeric_n_bool_cl_kwargs,
save_json,
use_task_specific_params,
write_txt_file,
......@@ -50,7 +50,6 @@ def eval_data_dir(
type_path="val",
n_obs=None,
fp16=False,
num_beams: int = 4,
task="summarization",
local_rank=None,
**generate_kwargs,
......@@ -81,23 +80,21 @@ def eval_data_dir(
n_obs=n_obs,
prefix=model.config.prefix,
)
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
# I set shuffle=True for a more accurate progress bar.
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True)
data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False) # tokenizer.decode
results = []
for batch in tqdm(data_loader):
summaries = model.generate(
input_ids=batch["input_ids"].to(model.device),
attention_mask=batch["attention_mask"].to(model.device),
num_beams=num_beams,
**generate_kwargs,
)
preds = tokenizer.batch_decode(summaries, **dec_kwargs)
labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
ids = batch["ids"]
for i in range(len(labels)):
label, pred = labels[i], preds[i]
results.append(dict(pred=pred, label=label, id=ids[i].item()))
for i, pred in enumerate(preds):
results.append(dict(pred=pred, id=ids[i].item()))
save_json(results, save_path)
return results, sampler.num_replicas
......@@ -139,8 +136,8 @@ def run_generate():
parser.add_argument("--debug", action="store_true")
start_time = time.time()
args, rest = parser.parse_known_args()
generate_kwargs = parse_numeric_cl_kwargs(rest)
if generate_kwargs:
generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest)
if generate_kwargs and args.local_rank <= 0:
print(f"parsed the following generate kwargs: {generate_kwargs}")
json_save_dir = Path(args.save_dir + "_tmp")
Path(json_save_dir).mkdir(exist_ok=True) # this handles locking.
......@@ -168,7 +165,10 @@ def run_generate():
save_dir = Path(args.save_dir)
save_dir.mkdir(exist_ok=True)
partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
preds, labels = combine_partial_results(partial_results)
preds = combine_partial_results(partial_results)
tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)]
# Calculate metrics, save metrics, and save _generations.txt
calc_bleu = "translation" in args.task
score_fn = calculate_bleu if calc_bleu else calculate_rouge
......@@ -179,7 +179,7 @@ def run_generate():
metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
# TODO(@stas00): add whatever metadata to metrics
metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
save_json(metrics, metrics_save_path)
save_json(metrics, metrics_save_path, indent=None)
print(metrics)
write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
if args.debug:
......@@ -188,15 +188,14 @@ def run_generate():
shutil.rmtree(json_save_dir)
def combine_partial_results(partial_results) -> List:
    """Merge the per-process result lists and return predictions ordered by example id.

    Args:
        partial_results: iterable of lists of records; every record is a dict
            holding at least an ``"id"`` key (used for ordering) and a
            ``"pred"`` key (the generated text).

    Returns:
        The ``"pred"`` values of all records, sorted by ascending ``"id"``.
        Labels are no longer returned; only predictions are.
    """
    # Flatten all partial lists, then restore the original dataset order via the ids.
    records = sorted(
        (record for partial_result in partial_results for record in partial_result),
        key=lambda x: x["id"],
    )
    return [record["pred"] for record in records]
def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
......
......@@ -156,7 +156,7 @@ def run_generate(verbose=True):
scores["info"] = args.info
if verbose:
print(*scores)
print(scores)
if args.score_path is not None:
path = args.score_path
......
......@@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset):
def get_char_lens(data_file):
    """Return the character length of every line in ``data_file``.

    Lines are not stripped, so each length includes the line's trailing
    newline character when one is present.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with Path(data_file).open() as f:
        return [len(line) for line in f]
def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
    """Build a length-aware sampler over this dataset.

    Args:
        batch_size: batch size handed to the sampler.
        distributed: when true, return a ``DistributedSortishSampler`` built
            from the dataset itself; otherwise a ``SortishSampler`` built from
            ``self.src_lens``.
        shuffle: forwarded to whichever sampler is created.
        **kwargs: forwarded to ``DistributedSortishSampler`` only; they are
            not passed to ``SortishSampler`` in the non-distributed case.
    """
    if distributed:
        return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
    return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
def __getitem__(self, item):
raise NotImplementedError("You must implement this")
......@@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
class SortishSampler(Sampler):
    """Go through the text data by order of src length with a bit of randomness. From fastai repo."""

    def __init__(self, data, batch_size, shuffle=True):
        # data: per-example values used as the sort key (it is passed straight
        # to sortish_sampler_indices); bs: batch size; shuffle: forwarded to
        # sortish_sampler_indices on every iteration.
        self.data, self.bs, self.shuffle = data, batch_size, shuffle

    def __len__(self) -> int:
        """Return the number of examples, i.e. ``len(self.data)``."""
        return len(self.data)

    def __iter__(self):
        """Yield example indices in the order computed by ``sortish_sampler_indices``."""
        return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
def sortish_sampler_indices(data: List, bs: int) -> np.array:
def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
"Go through the text data by order of src length with a bit of randomness. From fastai repo."
if not shuffle:
return np.argsort(np.array(data) * -1)
def key_fn(i):
return data[i]
......@@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
class DistributedSortishSampler(Sampler):
"""Copied from torch DistributedSampler"""
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
......@@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler):
self.num_samples = len(self.available_indices)
self.batch_size = batch_size
self.add_extra_examples = add_extra_examples
self.shuffle = shuffle
def __iter__(self) -> Iterable:
g = torch.Generator()
g.manual_seed(self.epoch)
sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
indices = [self.available_indices[i] for i in sortish_indices]
assert len(indices) == self.num_samples
return iter(indices)
......@@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None:
save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
def save_json(content, path, indent=4, **json_dump_kwargs):
    """Serialize ``content`` as JSON to the file at ``path``.

    Args:
        content: JSON-serializable object to write.
        path: destination file; it is opened in ``"w"`` mode, so any existing
            content is overwritten.
        indent: indentation passed to ``json.dump``; pass ``None`` for the
            compact single-line form.
        **json_dump_kwargs: any further keyword arguments for ``json.dump``.
    """
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)
def load_json(path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment