"docs/vscode:/vscode.git/clone" did not exist on "a39dfe4fb122c11be98a563fb8ca43b322e01036"
Unverified Commit 6fe8a693 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[s2s] Fix t5 warning for distributed eval (#7487)

parent 4c672846
......@@ -42,8 +42,7 @@ def eval_data_dir(
task="summarization",
local_rank=None,
num_return_sequences=1,
src_lang=None,
tgt_lang=None,
dataset_kwargs: Dict = None,
prefix="",
**generate_kwargs,
) -> Dict:
......@@ -78,9 +77,8 @@ def eval_data_dir(
max_target_length=1024,
type_path=type_path,
n_obs=n_obs,
src_lang=src_lang,
tgt_lang=tgt_lang,
prefix=prefix,
**dataset_kwargs,
)
# I set shuffle=True for a more accurate progress bar.
# If all the longest samples are first, the prog bar estimate is too high at the beginning.
......@@ -158,6 +156,11 @@ def run_generate():
if intermediate_files:
raise ValueError(f"Found files at {json_save_dir} please move or remove them.")
# In theory, a node could finish and save before another node hits this. If this happens, we can address later.
dataset_kwargs = {}
if args.src_lang is not None:
dataset_kwargs["src_lang"] = args.src_lang
if args.tgt_lang is not None:
dataset_kwargs["tgt_lang"] = args.tgt_lang
Path(args.save_dir).mkdir(exist_ok=True)
results, num_replicas = eval_data_dir(
......@@ -173,9 +176,7 @@ def run_generate():
max_source_length=args.max_source_length,
num_return_sequences=args.num_return_sequences,
prefix=args.prefix,
src_lang=args.src_lang,
tgt_lang=args.tgt_lang,
**generate_kwargs,
dataset_kwargs=dataset_kwargs ** generate_kwargs,
)
if args.local_rank <= 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment