"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ecfddc6034a12c3ef6c4ef3e6f56f7d034ec3075"
Unverified commit e7f8d2ab, authored by Sam Shleifer and committed by GitHub

[s2s] two stage run_distributed_eval.py (#7105)

parent 0ec63afe
from pathlib import Path

import fire

try:
    from .utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file
except ImportError:
    from utils import calculate_bleu, calculate_rouge, load_json, save_json, write_txt_file


def combine_partial_results(
    result_dir: str, save_dir: str = None, save_prefix=None, calc_bleu=False, just_metrics=False
):
    """Combine the rank_*.json partial results in result_dir, compute metrics, and write combined text files to save_dir."""
    src_dir = Path(result_dir)
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)
    paths_to_combine = list(src_dir.glob("rank*.json"))
    records = []
    for partial_result in paths_to_combine:
        records.extend(load_json(partial_result))
    preds = [x["pred"] for x in records]
    labels = [x["label"] for x in records]
    score_fn = calculate_bleu if calc_bleu else calculate_rouge
    metrics = score_fn(preds, labels)
    save_json(metrics, save_dir.joinpath("metrics.json"))  # better would be {prefix}_{rouge|bleu}.json
    print(metrics)
    if just_metrics:
        return
    if save_prefix is None:
        save_prefix = "generated"
        print("using generated as prefix")
    tgt_path = save_dir.joinpath(f"{save_prefix}.target")
    write_txt_file(labels, tgt_path)
    pred_path = save_dir.joinpath(f"{save_prefix}.pred_target")
    write_txt_file(preds, pred_path)
    if "source" in records[0]:
        src_path = save_dir.joinpath(f"{save_prefix}.source")
        write_txt_file([x["source"] for x in records], src_path)


if __name__ == "__main__":
    fire.Fire(combine_partial_results)
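# A minimal usage sketch (not part of the commit; the directory names and the
# save_prefix below are placeholder examples). Once every rank has written its
# rank_*.json file into the same directory, the partial results can be merged
# and scored with a call like:
#
#     combine_partial_results("tmp_gen", save_dir="tmp_gen", save_prefix="test")
#
# which writes metrics.json plus test.target and test.pred_target to save_dir
# (and test.source when the per-rank records contain a "source" field).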
import argparse
import warnings
from logging import getLogger
from pathlib import Path
from typing import Dict

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


logger = getLogger(__name__)

try:
    from .utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
except ImportError:
    from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def eval_data_dir(
    data_dir,
    save_dir: str,
    model_name: str,
    bs: int = 8,
    max_source_length: int = 1024,
    type_path="val",
    n_obs=None,
    fp16=False,
    save_source=False,
    num_beams: int = 4,
    task="summarization",
    local_rank=None,
    **generate_kwargs,
) -> Dict:
    """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
    model_name = str(model_name)
    assert local_rank is not None
    torch.distributed.init_process_group(backend="nccl", rank=local_rank)

    save_dir = Path(save_dir)
    save_path = save_dir.joinpath(f"rank_{local_rank}_output.json")
    torch.cuda.set_device(local_rank)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).cuda()
    if fp16:
        model = model.half()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.

    use_task_specific_params(model, task)  # update config with task specific params
    ds = Seq2SeqDataset(
        tokenizer,
        data_dir,
        max_source_length,
        max_target_length=1024,
        type_path=type_path,
        n_obs=n_obs,
        prefix=model.config.prefix,
    )
    sampler = ds.make_sortish_sampler(bs, distributed=True)
    data_loader = DataLoader(ds, sampler=sampler, batch_size=bs, collate_fn=ds.collate_fn)
    dec_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=False)  # kwargs for tokenizer.batch_decode
    results = []
    for batch in tqdm(data_loader):
        summaries = model.generate(
            input_ids=batch["input_ids"].to(model.device),
            attention_mask=batch["attention_mask"].to(model.device),
            num_beams=num_beams,
            **generate_kwargs,
        )
        preds = tokenizer.batch_decode(summaries, **dec_kwargs)
        labels = tokenizer.batch_decode(batch["labels"], **dec_kwargs)
        if save_source:
            docs = tokenizer.batch_decode(batch["input_ids"], **dec_kwargs)
        for i in range(len(labels)):
            label, pred = labels[i], preds[i]
            if save_source:
                results.append(dict(pred=pred, label=label, source=docs[i]))
            else:
                results.append(dict(pred=pred, label=label))

    save_json(results, save_path)
    return results


def run_generate():
    parser = argparse.ArgumentParser(
        epilog="Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate"
    )
    parser.add_argument(
        "--input_path", type=str, help="like cnn_dm (a data directory containing {prefix}.source/{prefix}.target)"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="like facebook/bart-large-cnn, t5-base, etc.",
        default="sshleifer/distilbart-xsum-12-3",
    )
    parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
    parser.add_argument(
        "--prefix", type=str, default="test", help="which subset to evaluate, typically train/val/test"
    )
    parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
    parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
    parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
    parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
    parser.add_argument(
        "--local_rank", type=int, default=-1, required=False, help="should be passed by distributed.launch"
    )
    parser.add_argument(
        "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
    )
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--save_source", action="store_true")
    args, rest = parser.parse_known_args()
    parsed = parse_numeric_cl_kwargs(rest)
    if parsed:
        print(f"parsed the following generate kwargs: {parsed}")
    Path(args.save_dir).mkdir(exist_ok=True)
    if args.reference_path is None and Path(args.score_path).exists():
        warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
    eval_data_dir(
        args.input_path,
        args.save_dir,
        args.model_name,
        type_path=args.prefix,  # --prefix selects the dataset split read by Seq2SeqDataset
        bs=args.bs,
        fp16=args.fp16,
        task=args.task,
        local_rank=args.local_rank,
        n_obs=args.n_obs,
        save_source=args.save_source,
        **parsed,
    )


if __name__ == "__main__":
    # Usage for MT:
    run_generate()
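# Illustrative two-stage launch (not part of the commit; the world size and data
# paths are placeholder examples). --local_rank is injected by
# torch.distributed.launch, and each rank writes rank_{rank}_output.json to
# --save_dir:
#
#     python -m torch.distributed.launch --nproc_per_node=2 run_distributed_eval.py \
#         --model_name sshleifer/distilbart-xsum-12-3 \
#         --input_path cnn_dm --prefix test --save_dir tmp_gen --fp16
#
# The combine script above then aggregates those per-rank files into metrics and
# combined prediction/target text files.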
@@ -387,3 +387,10 @@ def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
        result[unparsed_args[i][2:]] = value
    return result
def write_txt_file(ordered_tgt, path):
    """Write each element of ordered_tgt to path, one per line."""
    f = Path(path).open("w")
    for ln in ordered_tgt:
        f.write(ln + "\n")
        f.flush()