"...git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "1396186815aab1f2380c3627efd26cb12d760500"
Unverified Commit 6fe8a693, authored by Sam Shleifer, committed by GitHub
Browse files

[s2s] Fix t5 warning for distributed eval (#7487)

parent 4c672846
...@@ -42,8 +42,7 @@ def eval_data_dir( ...@@ -42,8 +42,7 @@ def eval_data_dir(
task="summarization", task="summarization",
local_rank=None, local_rank=None,
num_return_sequences=1, num_return_sequences=1,
src_lang=None, dataset_kwargs: Dict = None,
tgt_lang=None,
prefix="", prefix="",
**generate_kwargs, **generate_kwargs,
) -> Dict: ) -> Dict:
...@@ -78,9 +77,8 @@ def eval_data_dir( ...@@ -78,9 +77,8 @@ def eval_data_dir(
max_target_length=1024, max_target_length=1024,
type_path=type_path, type_path=type_path,
n_obs=n_obs, n_obs=n_obs,
src_lang=src_lang,
tgt_lang=tgt_lang,
prefix=prefix, prefix=prefix,
**dataset_kwargs,
) )
# I set shuffle=True for a more accurate progress bar. # I set shuffle=True for a more accurate progress bar.
# If all the longest samples are first, the prog bar estimate is too high at the beginning. # If all the longest samples are first, the prog bar estimate is too high at the beginning.
...@@ -158,6 +156,11 @@ def run_generate(): ...@@ -158,6 +156,11 @@ def run_generate():
if intermediate_files: if intermediate_files:
raise ValueError(f"Found files at {json_save_dir} please move or remove them.") raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
# In theory, a node could finish and save before another node hits this. If this happens, we can address later. # In theory, a node could finish and save before another node hits this. If this happens, we can address later.
dataset_kwargs = {}
if args.src_lang is not None:
dataset_kwargs["src_lang"] = args.src_lang
if args.tgt_lang is not None:
dataset_kwargs["tgt_lang"] = args.tgt_lang
Path(args.save_dir).mkdir(exist_ok=True) Path(args.save_dir).mkdir(exist_ok=True)
results, num_replicas = eval_data_dir( results, num_replicas = eval_data_dir(
...@@ -173,9 +176,7 @@ def run_generate(): ...@@ -173,9 +176,7 @@ def run_generate():
max_source_length=args.max_source_length, max_source_length=args.max_source_length,
num_return_sequences=args.num_return_sequences, num_return_sequences=args.num_return_sequences,
prefix=args.prefix, prefix=args.prefix,
src_lang=args.src_lang, dataset_kwargs=dataset_kwargs, **generate_kwargs,
tgt_lang=args.tgt_lang,
**generate_kwargs,
) )
if args.local_rank <= 0: if args.local_rank <= 0:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment