Unverified commit de9e2979, authored by Sam Shleifer, committed by GitHub

[s2s] distributed eval cleanup (#7110)

parent 54395d87
 import argparse
-import warnings
 from logging import getLogger
 from pathlib import Path
 from typing import Dict
@@ -18,6 +17,7 @@ try:
 except ImportError:
     from utils import Seq2SeqDataset, parse_numeric_cl_kwargs, save_json, use_task_specific_params

 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -51,6 +51,8 @@ def eval_data_dir(
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     logger.info(f"Inferred tokenizer type: {tokenizer.__class__}")  # if this is wrong, check config.model_type.
     use_task_specific_params(model, task)  # update config with task specific params
+    if max_source_length is None:
+        max_source_length = tokenizer.model_max_length
     ds = Seq2SeqDataset(
         tokenizer,
         data_dir,
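The two added lines give `max_source_length` a sensible default when the caller passes `None`. A minimal sketch of the fallback behavior, assuming the script's default checkpoint `sshleifer/distilbart-xsum-12-3` (any BART-family model behaves the same way):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-3")

max_source_length = None  # the new CLI default when --max_source_length is omitted
if max_source_length is None:
    # Fall back to the tokenizer's built-in limit (1024 for BART-family models).
    max_source_length = tokenizer.model_max_length
print(max_source_length)  # -> 1024
```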
@@ -97,9 +99,11 @@ def run_generate():
         default="sshleifer/distilbart-xsum-12-3",
     )
     parser.add_argument("--save_dir", type=str, help="where to save", default="tmp_gen")
-    parser.add_argument("--prefix", type=str, default="test", help="which subset to evaluate typically train/val/test")
+    parser.add_argument("--max_source_length", type=int, default=None)
+    parser.add_argument(
+        "--type_path", type=str, default="test", help="which subset to evaluate typically train/val/test"
+    )
     parser.add_argument("--reference_path", type=str, required=False, help="like cnn_dm/test.target")
-    parser.add_argument("--score_path", type=str, required=False, default="metrics.json", help="where to save metrics")
     parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
     parser.add_argument(
@@ -113,24 +117,23 @@ def run_generate():
     parser.add_argument("--save_source", action="store_true")
     args, rest = parser.parse_known_args()
-    parsed = parse_numeric_cl_kwargs(rest)
-    if parsed:
-        print(f"parsed the following generate kwargs: {parsed}")
+    generate_kwargs = parse_numeric_cl_kwargs(rest)
+    if generate_kwargs:
+        print(f"parsed the following generate kwargs: {generate_kwargs}")
     Path(args.save_dir).mkdir(exist_ok=True)
-    if args.reference_path is None and Path(args.score_path).exists():
-        warnings.warn(f"score_path {args.score_path} will be overwritten unless you type ctrl-c.")
     eval_data_dir(
         args.input_path,
         args.save_dir,
         args.model_name,
-        prefix=args.prefix,
+        type_path=args.type_path,
         batch_size=args.bs,
         fp16=args.fp16,
         task=args.task,
         local_rank=args.local_rank,
         n_obs=args.n_obs,
         save_source=args.save_source,
-        **parsed,
+        max_source_length=args.max_source_length,
+        **generate_kwargs,
     )
...
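The `parsed` → `generate_kwargs` rename makes the intent explicit: whatever CLI tokens `parse_known_args` leaves over are parsed into a dict and forwarded into generation. A rough, hypothetical re-implementation of what `parse_numeric_cl_kwargs` does (the real helper lives in `utils.py`; this sketch only assumes `--key value` pairs with best-effort numeric casting):

```python
from typing import Dict, List, Union


def parse_numeric_cl_kwargs_sketch(rest: List[str]) -> Dict[str, Union[int, float, str]]:
    """Simplified stand-in for utils.parse_numeric_cl_kwargs: pair up
    ``--key value`` tokens left over by ``parse_known_args`` and cast
    each value to int or float where possible."""
    kwargs = {}
    for key, value in zip(rest[::2], rest[1::2]):
        assert key.startswith("--"), f"unexpected token {key}"
        try:
            casted: Union[int, float, str] = int(value)
        except ValueError:
            try:
                casted = float(value)
            except ValueError:
                casted = value  # leave non-numeric values as strings
        kwargs[key[2:]] = casted
    return kwargs


# Leftover tokens after parse_known_args, e.g. from `... --num_beams 8 --max_length 142`:
print(parse_numeric_cl_kwargs_sketch(["--num_beams", "8", "--max_length", "142"]))
# {'num_beams': 8, 'max_length': 142}
```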
@@ -98,7 +98,8 @@ class AbstractSeq2SeqDataset(Dataset):
         self.max_target_length = max_target_length
         assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
         self.tokenizer = tokenizer
-        self.prefix = prefix
+        self.prefix = prefix if prefix is not None else ""
         if n_obs is not None:
             self.src_lens = self.src_lens[:n_obs]
         self.pad_token_id = self.tokenizer.pad_token_id
...
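The `utils.py` change makes a missing prefix harmless: models whose config carries no task prefix end up passing `prefix=None`, and the dataset later prepends the prefix to each source line. A minimal illustration of why the fallback to `""` matters (the `prefix + source_line` concatenation is assumed from the surrounding dataset code):

```python
prefix = None  # e.g. a model config that defines no task prefix

# Old behavior: self.prefix = prefix, so "None + str" blows up later.
try:
    _ = prefix + "Some source document ..."
except TypeError as e:
    print(f"old code would fail: {e}")

# New behavior: fall back to the empty string.
prefix = prefix if prefix is not None else ""
print(prefix + "Some source document ...")  # also fine for T5-style "summarize: " prefixes
```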