Unverified Commit de9e2979 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] distributed eval cleanup (#7110)

parent 54395d87
import argparse
import warnings
from logging import getLogger
from pathlib import Path
from typing import Dict
......@@ -18,6 +17,7 @@ try:
except ImportError:
from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
......@@ -51,6 +51,8 @@ def eval_data_dir(
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
use_task_specific_params(model, task) # update config with task specific params
if max_source_length is None:
max_source_length = tokenizer.model_max_length
ds = Seq2SeqDataset(
tokenizer,
data_dir,
......@@ -97,9 +99,11 @@ def run_generate():
default="sshleifer/distilbart-xsum-12-3",
)
parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
parser.add_argument("--max_source_length", type=int, default=None)
parser.add_argument(
"--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
)
parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
parser.add_argument(
......@@ -113,24 +117,23 @@ def run_generate():
parser.add_argument("--save_source", action="store_true")
args, rest = parser.parse_known_args()
parsed = parse_numeric_cl_kwargs(rest)
if parsed:
print(f"parsed the following generate kwargs: {parsed}")
generate_kwargs = parse_numeric_cl_kwargs(rest)
if generate_kwargs:
print(f"parsed the following generate kwargs: {generate_kwargs}")
Path(args.save_dir).mkdir(exist_ok=True)
if args.reference_path is None and Path(args.score_path).exists():
warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
eval_data_dir(
args.input_path,
args.save_dir,
args.model_name,
prefix=args.prefix,
type_path=args.type_path,
batch_size=args.bs,
fp16=args.fp16,
task=args.task,
local_rank=args.local_rank,
n_obs=args.n_obs,
save_source=args.save_source,
**parsed,
max_source_length=args.max_source_length,
**generate_kwargs,
)
......
......@@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset):
self.max_target_length = max_target_length
assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
self.tokenizer = tokenizer
self.prefix = prefix
self.prefix = prefix if prefix is not None else ""
if n_obs is not None:
self.src_lens = self.src_lens[:n_obs]
self.pad_token_id = self.tokenizer.pad_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment