Unverified commit 0203ad43, authored by Sam Shleifer, committed by GitHub

[s2s] distributed eval cleanup (#7186)

parent 3babef81
@@ -227,6 +227,20 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
     --fp16 \
     --bs 32
 ```
+### Multi-GPU Evaluation
+Here is a command to run xsum evaluation on 8 GPUs. In some cases it is more than linearly faster than run_eval.py
+because it uses SortishSampler to minimize padding. You can also use it on 1 GPU (see the sketch after this diff).
+`data_dir` must contain `{type_path}.source` and `{type_path}.target`. Run `python run_distributed_eval.py --help`
+for all command line arguments.
+```bash
+python -m torch.distributed.launch --nproc_per_node=8 run_distributed_eval.py \
+    --model_name sshleifer/distilbart-large-xsum-12-3 \
+    --save_dir xsum_generations \
+    --data_dir xsum \
+    --fp16  # you can pass generate kwargs like num_beams here, just like run_eval.py
+```
+Contributions that implement this command for other distributed hardware setups are welcome!
 #### run_eval tips and tricks
...
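As noted in the README text above, the script also runs on a single GPU. A minimal sketch of that invocation, assuming the launcher flags carry over unchanged with `--nproc_per_node=1`; this command is illustrative and not part of the commit:

```bash
# Single-process variant of the 8-GPU command above (illustrative; assumes the
# same model and data layout as the multi-GPU example).
python -m torch.distributed.launch --nproc_per_node=1 run_distributed_eval.py \
    --model_name sshleifer/distilbart-large-xsum-12-3 \
    --save_dir xsum_generations \
    --data_dir xsum \
    --fp16
```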
@@ -4,7 +4,7 @@ import time
 from json import JSONDecodeError
 from logging import getLogger
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List
 import torch
 from torch.utils.data import DataLoader
@@ -22,7 +22,7 @@ try:
     calculate_rouge,
     lmap,
     load_json,
-    parse_numeric_cl_kwargs,
+    parse_numeric_n_bool_cl_kwargs,
     save_json,
     use_task_specific_params,
     write_txt_file,
@@ -34,7 +34,7 @@ except ImportError:
     calculate_rouge,
     lmap,
     load_json,
-    parse_numeric_cl_kwargs,
+    parse_numeric_n_bool_cl_kwargs,
    save_json,
     use_task_specific_params,
     write_txt_file,
@@ -50,7 +50,6 @@ def eval_data_dir(
     type_path="val",
     n_obs=None,
     fp16=False,
-    num_beams: int = 4,
     task="summarization",
     local_rank=None,
     **generate_kwargs,
@@ -81,23 +80,21 @@ def eval_data_dir(
         n_obs=n_obs,
         prefix=model.config.prefix,
     )
-    sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False)
+    # I set shuffle=True for a more accurate progress bar.
+    # If all the longest samples are first, the prog bar estimate is too high at the beginning.
+    sampler = ds.make_sortish_sampler(bs, distributed=True, add_extra_examples=False, shuffle=True)
     data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
-    dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)  # tokenizer.decode
     results = []
     for batch in tqdm(data_loader):
         summaries = model.generate(
             input_ids=batch["input_ids"].to(model.device),
             attention_mask=batch["attention_mask"].to(model.device),
-            num_beams=num_beams,
             **generate_kwargs,
         )
-        preds = tokenizer.batch_decode(summaries, **dec_kwargs)
-        labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
+        preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         ids = batch["ids"]
-        for i in range(len(labels)):
-            label, pred = labels[i], preds[i]
-            results.append(dict(pred=pred, label=label, id=ids[i].item()))
+        for i, pred in enumerate(preds):
+            results.append(dict(pred=pred, id=ids[i].item()))
     save_json(results, save_path)
     return results, sampler.num_replicas
@@ -139,8 +136,8 @@ def run_generate():
     parser.add_argument("--debug", action="store_true")
     start_time = time.time()
     args, rest = parser.parse_known_args()
-    generate_kwargs = parse_numeric_cl_kwargs(rest)
-    if generate_kwargs:
+    generate_kwargs = parse_numeric_n_bool_cl_kwargs(rest)
+    if generate_kwargs and args.local_rank <= 0:
         print(f"parsed the following generate kwargs: {generate_kwargs}")
     json_save_dir = Path(args.save_dir + "_tmp")
     Path(json_save_dir).mkdir(exist_ok=True)  # this handles locking.
@@ -168,7 +165,10 @@ def run_generate():
     save_dir = Path(args.save_dir)
     save_dir.mkdir(exist_ok=True)
     partial_results = gather_results_from_each_node(num_replicas, json_save_dir, args.sync_timeout)
-    preds, labels = combine_partial_results(partial_results)
+    preds = combine_partial_results(partial_results)
+    tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
+    labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)]
     # Calculate metrics, save metrics, and save _generations.txt
     calc_bleu = "translation" in args.task
     score_fn = calculate_bleu if calc_bleu else calculate_rouge
@@ -179,7 +179,7 @@ def run_generate():
     metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 2)
     # TODO(@stas00): add whatever metadata to metrics
     metrics_save_path = save_dir.joinpath(f"{args.type_path}_{metric_name}.json")
-    save_json(metrics, metrics_save_path)
+    save_json(metrics, metrics_save_path, indent=None)
     print(metrics)
     write_txt_file(preds, save_dir.joinpath(f"{args.type_path}_generations.txt"))
     if args.debug:
@@ -188,15 +188,14 @@ def run_generate():
     shutil.rmtree(json_save_dir)
-def combine_partial_results(partial_results) -> Tuple[List, List]:
+def combine_partial_results(partial_results) -> List:
     """Concatenate partial results into one file, then sort it by id."""
     records = []
     for partial_result in partial_results:
         records.extend(partial_result)
     records = list(sorted(records, key=lambda x: x["id"]))
     preds = [x["pred"] for x in records]
-    labels = [x["label"] for x in records]
-    return preds, labels
+    return preds
 def gather_results_from_each_node(num_replicas, save_dir, timeout) -> List[Dict[str, List]]:
...
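The gather step above relies on a simple file-based rendezvous: each rank writes its partial predictions as JSON into the shared `_tmp` directory, and the merging process polls until `num_replicas` files exist. The body of `gather_results_from_each_node` is not shown in this diff; here is a minimal sketch of that pattern, where the `rank_*.json` naming and the `wait_and_merge` helper are hypothetical stand-ins, not the script's exact code:

```python
import json
import time
from pathlib import Path

def wait_and_merge(save_dir: Path, num_replicas: int, timeout: int = 600):
    # Poll until every rank has written its partial JSON file, then merge.
    # (Hypothetical sketch of the file-based gather; naming is illustrative.)
    start = time.time()
    while len(list(save_dir.glob("rank_*.json"))) < num_replicas:
        if time.time() - start > timeout:
            raise TimeoutError(f"did not see {num_replicas} result files in {timeout}s")
        time.sleep(1)
    records = []
    for f in sorted(save_dir.glob("rank_*.json")):
        records.extend(json.loads(f.read_text()))
    # Sort by the original dataset index so predictions line up with targets.
    return [r["pred"] for r in sorted(records, key=lambda r: r["id"])]
```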
@@ -156,7 +156,7 @@ def run_generate(verbose=True):
     scores["info"] = args.info
     if verbose:
-        print(*scores)
+        print(scores)
     if args.score_path is not None:
         path = args.score_path
...
@@ -115,11 +115,11 @@ class AbstractSeq2SeqDataset(Dataset):
     def get_char_lens(data_file):
         return [len(x) for x in Path(data_file).open().readlines()]
-    def make_sortish_sampler(self, batch_size, distributed=False, **kwargs):
+    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=True, **kwargs):
         if distributed:
-            return DistributedSortishSampler(self, batch_size, **kwargs)
+            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
         else:
-            return SortishSampler(self.src_lens, batch_size)
+            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)
     def __getitem__(self, item):
         raise NotImplementedError("You must implement this")
@@ -193,18 +193,20 @@ class Seq2SeqDataset(AbstractSeq2SeqDataset):
 class SortishSampler(Sampler):
     "Go through the text data by order of src length with a bit of randomness. From fastai repo."
-    def __init__(self, data, batch_size):
-        self.data, self.bs = data, batch_size
+    def __init__(self, data, batch_size, shuffle=True):
+        self.data, self.bs, self.shuffle = data, batch_size, shuffle
     def __len__(self) -> int:
         return len(self.data)
     def __iter__(self):
-        return iter(sortish_sampler_indices(self.data, self.bs))
+        return iter(sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle))
-def sortish_sampler_indices(data: List, bs: int) -> np.array:
+def sortish_sampler_indices(data: List, bs: int, shuffle=True) -> np.array:
     "Go through the text data by order of src length with a bit of randomness. From fastai repo."
+    if not shuffle:
+        return np.argsort(np.array(data) * -1)
     def key_fn(i):
         return data[i]
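For context on the `shuffle` flag introduced above: with `shuffle=False`, `sortish_sampler_indices` degenerates to a plain descending argsort of source lengths, while the default fastai-style path sorts by length only within randomly shuffled chunks, so batches stay length-homogeneous (little padding) but arrive in random order. A toy sketch of the idea, where `toy_sortish_indices` is a hypothetical stand-in, not the repo's implementation:

```python
import numpy as np

def toy_sortish_indices(lengths, bs, shuffle=True):
    # Hypothetical helper illustrating "sortish" ordering; not the repo's code.
    lengths = np.array(lengths)
    if not shuffle:
        # Deterministic path: longest examples first, mirroring
        # np.argsort(np.array(data) * -1) in the diff above.
        return np.argsort(-lengths)
    idxs = np.random.permutation(len(lengths))
    # Sort by length only inside chunks of bs * 50 indices: each batch drawn
    # from a chunk has similar lengths (little padding), but chunk contents
    # and hence batch order remain random.
    sz = bs * 50
    chunks = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    return np.concatenate([chunk[np.argsort(-lengths[chunk])] for chunk in chunks])

print(toy_sortish_indices([5, 3, 9, 1, 7, 2], bs=2, shuffle=False))  # -> [2 4 0 1 5 3]
```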
@@ -225,7 +227,7 @@ def sortish_sampler_indices(data: List, bs: int) -> np.array:
 class DistributedSortishSampler(Sampler):
     """Copied from torch DistributedSampler"""
-    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True):
+    def __init__(self, dataset, batch_size, num_replicas=None, rank=None, add_extra_examples=True, shuffle=True):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -246,13 +248,14 @@ class DistributedSortishSampler(Sampler):
         self.num_samples = len(self.available_indices)
         self.batch_size = batch_size
         self.add_extra_examples = add_extra_examples
+        self.shuffle = shuffle
     def __iter__(self) -> Iterable:
         g = torch.Generator()
         g.manual_seed(self.epoch)
         sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
-        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size)
+        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
         indices = [self.available_indices[i] for i in sortish_indices]
         assert len(indices) == self.num_samples
         return iter(indices)
@@ -309,9 +312,9 @@ def save_git_info(folder_path: str) -> None:
     save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
-def save_json(content, path):
+def save_json(content, path, indent=4, **json_dump_kwargs):
     with open(path, "w") as f:
-        json.dump(content, f, indent=4)
+        json.dump(content, f, indent=indent, **json_dump_kwargs)
 def load_json(path):
...