Unverified Commit e53138a1 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] add src_lang kwarg for distributed eval (#7300)

parent a9c7849c
......@@ -38,6 +38,9 @@ def eval_data_dir(
fp16=False,
task="summarization",
local_rank=None,
src_lang=None,
tgt_lang=None,
prefix="",
**generate_kwargs,
) -> Dict:
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
......@@ -57,6 +60,8 @@ def eval_data_dir(
use_task_specific_params(model, task) # update config with task specific params
if max_source_length is None:
max_source_length = tokenizer.model_max_length
if prefix is None:
prefix = prefix or getattr(model.config, "prefix", "") or ""
ds = Seq2SeqDataset(
tokenizer,
data_dir,
......@@ -64,7 +69,9 @@ def eval_data_dir(
max_target_length=1024,
type_path=type_path,
n_obs=n_obs,
prefix=model.config.prefix,
src_lang=src_lang,
tgt_lang=tgt_lang,
prefix=prefix,
)
# I set shuffle=True for a more accurate progress bar.
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
......@@ -118,6 +125,11 @@ def run_generate():
required=False,
help="How long should master process wait for other processes to finish.",
)
parser.add_argument("--src_lang", type=str, default=None, required=False)
parser.add_argument("--tgt_lang", type=str, default=None, required=False)
parser.add_argument(
"--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
)
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--debug", action="store_true")
start_time = time.time()
......@@ -144,6 +156,9 @@ def run_generate():
local_rank=args.local_rank,
n_obs=args.n_obs,
max_source_length=args.max_source_length,
prefix=args.prefix,
src_lang=args.src_lang,
tgt_lang=args.tgt_lang,
**generate_kwargs,
)
......
......@@ -168,6 +168,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
truncation: bool = True,
padding: str = "longest",
return_tensors: str = "pt",
add_prefix_space: bool = False, # ignored
**kwargs,
) -> BatchEncoding:
"""Prepare a batch that can be passed directly to an instance of MBartModel.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment