"...git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "d2c9caa4254872cf07f9021407da61dce2ca8918"
Unverified Commit e53138a1 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] add src_lang kwarg for distributed eval (#7300)

parent a9c7849c
...@@ -38,6 +38,9 @@ def eval_data_dir( ...@@ -38,6 +38,9 @@ def eval_data_dir(
fp16=False, fp16=False,
task="summarization", task="summarization",
local_rank=None, local_rank=None,
src_lang=None,
tgt_lang=None,
prefix="",
**generate_kwargs, **generate_kwargs,
) -> Dict: ) -> Dict:
"""Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json""" """Run evaluation on part of the data for one gpu and save to {save_dir}/rank_{rank}_output.json"""
...@@ -57,6 +60,8 @@ def eval_data_dir( ...@@ -57,6 +60,8 @@ def eval_data_dir(
use_task_specific_params(model, task) # update config with task specific params use_task_specific_params(model, task) # update config with task specific params
if max_source_length is None: if max_source_length is None:
max_source_length = tokenizer.model_max_length max_source_length = tokenizer.model_max_length
if prefix is None:
prefix = prefix or getattr(model.config, "prefix", "") or ""
ds = Seq2SeqDataset( ds = Seq2SeqDataset(
tokenizer, tokenizer,
data_dir, data_dir,
...@@ -64,7 +69,9 @@ def eval_data_dir( ...@@ -64,7 +69,9 @@ def eval_data_dir(
max_target_length=1024, max_target_length=1024,
type_path=type_path, type_path=type_path,
n_obs=n_obs, n_obs=n_obs,
prefix=model.config.prefix, src_lang=src_lang,
tgt_lang=tgt_lang,
prefix=prefix,
) )
# I set shuffle=True for a more accurate progress bar. # I set shuffle=True for a more accurate progress bar.
# If all the longest samples are first, the prog bar estimate is too high at the beginning. # If all the longest samples are first, the prog bar estimate is too high at the beginning.
...@@ -118,6 +125,11 @@ def run_generate(): ...@@ -118,6 +125,11 @@ def run_generate():
required=False, required=False,
help="How long should master process wait for other processes to finish.", help="How long should master process wait for other processes to finish.",
) )
parser.add_argument("--src_lang", type=str, default=None, required=False)
parser.add_argument("--tgt_lang", type=str, default=None, required=False)
parser.add_argument(
"--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
)
parser.add_argument("--fp16", action="store_true") parser.add_argument("--fp16", action="store_true")
parser.add_argument("--debug", action="store_true") parser.add_argument("--debug", action="store_true")
start_time = time.time() start_time = time.time()
...@@ -144,6 +156,9 @@ def run_generate(): ...@@ -144,6 +156,9 @@ def run_generate():
local_rank=args.local_rank, local_rank=args.local_rank,
n_obs=args.n_obs, n_obs=args.n_obs,
max_source_length=args.max_source_length, max_source_length=args.max_source_length,
prefix=args.prefix,
src_lang=args.src_lang,
tgt_lang=args.tgt_lang,
**generate_kwargs, **generate_kwargs,
) )
......
...@@ -168,6 +168,7 @@ class MBartTokenizer(XLMRobertaTokenizer): ...@@ -168,6 +168,7 @@ class MBartTokenizer(XLMRobertaTokenizer):
truncation: bool = True, truncation: bool = True,
padding: str = "longest", padding: str = "longest",
return_tensors: str = "pt", return_tensors: str = "pt",
add_prefix_space: bool = False, # ignored
**kwargs, **kwargs,
) -> BatchEncoding: ) -> BatchEncoding:
"""Prepare a batch that can be passed directly to an instance of MBartModel. """Prepare a batch that can be passed directly to an instance of MBartModel.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment