"docs/source/en/run_scripts.mdx" did not exist on "4a353cacb7e9c5a7fc895a77e98452eae525ba38"
Unverified Commit 0b1f552a authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix run_seq2seq.py; porting trainer tests to it (#10162)

* fix run_seq2seq.py; porting DeepSpeed tests to it

* unrefactor

* defensive programming

* defensive programming 2

* port the rest of the trainer tests

* style

* a cleaner scripts dir finder

* cleanup
parent 31b0560a
......@@ -18,6 +18,7 @@ Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import json
import logging
import os
import re
......@@ -38,6 +39,7 @@ from transformers import (
DataCollatorForSeq2Seq,
HfArgumentParser,
MBartTokenizer,
MBartTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator,
......@@ -53,6 +55,11 @@ with FileLock(".lock") as lock:
logger = logging.getLogger(__name__)
def save_json(content, path, indent=4, **json_dump_kwargs):
with open(path, "w") as f:
json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
@dataclass
class ModelArguments:
"""
......@@ -351,8 +358,15 @@ def main():
)
# Set decoder_start_token_id
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
assert (
data_args.target_lang is not None and data_args.source_lang is not None
), "mBart requires --target_lang and --source_lang"
if isinstance(tokenizer, MBartTokenizer):
model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.target_lang]
else:
model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.target_lang)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
......@@ -448,6 +462,8 @@ def main():
if training_args.do_train:
train_dataset = datasets["train"]
if "train" not in datasets:
raise ValueError("--do_train requires a train dataset")
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
......@@ -460,6 +476,8 @@ def main():
if training_args.do_eval:
max_target_length = data_args.val_max_target_length
if "validation" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
......@@ -473,6 +491,8 @@ def main():
if training_args.do_predict:
max_target_length = data_args.val_max_target_length
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
......@@ -550,6 +570,7 @@ def main():
compute_metrics=compute_metrics if training_args.predict_with_generate else None,
)
all_metrics = {}
# Training
if training_args.do_train:
if last_checkpoint is not None:
......@@ -561,13 +582,17 @@ def main():
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
if trainer.is_world_process_zero():
with open(output_train_file, "w") as writer:
logger.info("***** Train results *****")
for key, value in sorted(train_result.metrics.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
logger.info("***** train metrics *****")
for key in sorted(metrics.keys()):
logger.info(f" {key} = {metrics[key]}")
save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
all_metrics.update(metrics)
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
......@@ -577,16 +602,19 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")
results = trainer.evaluate(max_length=data_args.val_max_target_length, num_beams=data_args.num_beams)
results = {k: round(v, 4) for k, v in results.items()}
metrics = trainer.evaluate(
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="val"
)
metrics = {k: round(v, 4) for k, v in metrics.items()}
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["val_samples"] = min(max_val_samples, len(eval_dataset))
output_eval_file = os.path.join(training_args.output_dir, "eval_results_seq2seq.txt")
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in sorted(results.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
logger.info("***** val metrics *****")
for key in sorted(metrics.keys()):
logger.info(f" {key} = {metrics[key]}")
save_json(metrics, os.path.join(training_args.output_dir, "val_results.json"))
all_metrics.update(metrics)
if training_args.do_predict:
logger.info("*** Test ***")
......@@ -597,16 +625,17 @@ def main():
max_length=data_args.val_max_target_length,
num_beams=data_args.num_beams,
)
test_metrics = test_results.metrics
test_metrics["test_loss"] = round(test_metrics["test_loss"], 4)
metrics = test_results.metrics
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
metrics = {k: round(v, 4) for k, v in metrics.items()}
output_test_result_file = os.path.join(training_args.output_dir, "test_results_seq2seq.txt")
if trainer.is_world_process_zero():
with open(output_test_result_file, "w") as writer:
logger.info("***** Test results *****")
for key, value in sorted(test_metrics.items()):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
logger.info("***** test metrics *****")
for key in sorted(metrics.keys()):
logger.info(f" {key} = {metrics[key]}")
save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
all_metrics.update(metrics)
if training_args.predict_with_generate:
test_preds = tokenizer.batch_decode(
......@@ -617,6 +646,9 @@ def main():
with open(output_test_preds_file, "w") as writer:
writer.write("\n".join(test_preds))
if trainer.is_world_process_zero():
save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
return results
......
This diff is collapsed.
{ "translation": { "en": "Corrections to votes and voting intentions: see Minutes Assignment conferred on a Member: see Minutes Membership of committees and delegations: see Minutes Decisions concerning certain documents: see Minutes Forwarding of texts adopted during the sitting: see Minutes Dates for next sittings: see Minutes", "ro": "Corectările voturilor şi intenţiile de vot: a se vedea procesul-verbal Misiune încredinţată unui deputat: consultaţi procesul-verbal Componenţa comisiilor şi a delegaţiilor: a se vedea procesul-verbal Decizii privind anumite documente: a se vedea procesul-verbal Transmiterea textelor adoptate în cursul prezentei şedinţe: a se vedea procesul-verbal Calendarul următoarelor şedinţe: a se vedea procesul-verbal" } }
{ "translation": { "en": "Membership of Parliament: see Minutes Approval of Minutes of previous sitting: see Minutes Membership of Parliament: see Minutes Verification of credentials: see Minutes Documents received: see Minutes Written statements and oral questions (tabling): see Minutes Petitions: see Minutes Texts of agreements forwarded by the Council: see Minutes Action taken on Parliament's resolutions: see Minutes Agenda for next sitting: see Minutes Closure of sitting (The sitting was closed at 7.45 p.m.)", "ro": "Componenţa Parlamentului: a se vedea procesul-verbal Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal Componenţa Parlamentului: a se vedea procesul-verbal Verificarea prerogativelor: a se vedea procesul-verbal Depunere de documente: a se vedea procesul-verbal Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal Petiţii: a se vedea procesul-verbal Transmiterea de către Consiliu a textelor acordurilor: a se vedea procesul-verbal Cursul dat rezoluţiilor Parlamentului: a se vedea procesul-verbal Ordinea de zi a următoarei şedinţe: a se vedea procesul-verbal Ridicarea şedinţei (Se levanta la sesión a las 19.45 horas)" } }
{ "translation": { "en": "Election of Vice-Presidents of the European Parliament (deadline for submitting nominations): see Minutes (The sitting was suspended at 12.40 p.m. and resumed at 3.00 p.m.) Election of Quaestors of the European Parliament (deadline for submitting nominations): see Minutes (The sitting was suspended at 3.25 p.m. and resumed at 6.00 p.m.) Agenda for next sitting: see Minutes Closure of sitting (The sitting was closed at 6.15 p.m.) Opening of the sitting (The sitting was opened at 9.35 a.m.) Documents received: see Minutes Approval of Minutes of previous sitting: see Minutes Membership of Parliament: see Minutes", "ro": "Alegerea vicepreşedinţilor Parlamentului European (termenul de depunere a candidaturilor): consultaţi procesul-verbal (Die Sitzung wird um 12.40 Uhr unterbrochen und um 15.00 Uhr wiederaufgenommen). Alegerea chestorilor Parlamentului European (termenul de depunere a candidaturilor): consultaţi procesul-verbal (Die Sitzung wird um 15.25 Uhr unterbrochen und um 18.00 Uhr wiederaufgenommen). Ordinea de zi a următoarei şedinţe: a se vedea procesul-verbal Ridicarea şedinţei (Die Sitzung wird um 18.15 Uhr geschlossen.) Deschiderea şedinţei (Die Sitzung wird um 9.35 Uhr eröffnet.) Depunerea documentelor: a se vedea procesul-verbal Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal Componenţa Parlamentului: a se vedea procesul-verbal" } }
{ "translation": { "en": "Membership of committees (deadline for tabling amendments): see Minutes (The sitting was suspended at 7 p.m. and resumed at 9 p.m.) Agenda for next sitting: see Minutes Closure of sitting (The sitting was suspended at 23.25 p.m.) Documents received: see Minutes Communication of Council common positions: see Minutes (The sitting was suspended at 11.35 a.m. and resumed for voting time at noon) Approval of Minutes of previous sitting: see Minutes Committee of Inquiry into the crisis of the Equitable Life Assurance Society (extension of mandate): see Minutes", "ro": "Componenţa comisiilor (termenul de depunere a amendamentelor): consultaţi procesul-verbal (La seduta, sospesa alle 19.00, è ripresa alle 21.00) Ordinea de zi a următoarei şedinţe: a se vedea procesul-verbal Ridicarea şedinţei (Die Sitzung wird um 23.25 Uhr geschlossen.) Depunerea documentelor: a se vedea procesul-verbal Comunicarea poziţiilor comune ale Parlamentului: a se vedea procesul-verbal (La séance, suspendue à 11h35 dans l'attente de l'Heure des votes, est reprise à midi) Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal Comisia de anchetă privind criza societăţii de asigurări \"Equitable Life” (prelungirea mandatului): consultaţi procesul-verbal" } }
{ "translation": { "en": "Announcement by the President: see Minutes 1. Membership of committees (vote) 2. Amendment of the ACP-EC Partnership Agreement (vote) 4. Certification of train drivers operating locomotives and trains on the railway system in the Community (vote) 6. Law applicable to non-contractual obligations (\"ROME II\") (vote) 8. Seventh and eighth annual reports on arms exports (vote) Corrections to votes and voting intentions: see Minutes Membership of committees and delegations: see Minutes Request for waiver of parliamentary immunity: see Minutes Decisions concerning certain documents: see Minutes", "ro": "Comunicarea Preşedintelui: consultaţi procesul-verbal 1. Componenţa comisiilor (vot) 2. Modificarea Acordului de parteneriat ACP-CE (\"Acordul de la Cotonou”) (vot) 4. Certificarea mecanicilor de locomotivă care conduc locomotive şi trenuri în sistemul feroviar comunitar (vot) 6. Legea aplicabilă obligaţiilor necontractuale (\"Roma II”) (vot) 8. Al şaptelea şi al optulea raport anual privind exportul de armament (vot) Corectările voturilor şi intenţiile de vot: a se vedea procesul-verbal Componenţa comisiilor şi a delegaţiilor: a se vedea procesul-verbal Cerere de ridicare a imunităţii parlamentare: consultaţi procesul-verbal Decizii privind anumite documente: a se vedea procesul-verbal" } }
{ "translation": { "en": "Written statements for entry", "ro": "Declaraţii scrise înscrise" } }
{ "translation": { "en": "Written statements for entry in the register (Rule 116): see Minutes Forwarding of texts adopted during the sitting: see Minutes Dates for next sittings: see Minutes Adjournment of the session I declare the session of the European Parliament adjourned. (The sitting was closed at 1 p.m.) Approval of Minutes of previous sitting: see Minutes Membership of Parliament: see Minutes Request for the defence of parliamentary immunity: see Minutes Appointments to committees (proposal by the Conference of Presidents): see Minutes Documents received: see Minutes Texts of agreements forwarded by the Council: see Minutes", "ro": "Declaraţii scrise înscrise în registru (articolul 116 din Regulamentul de procedură): a se vedea procesul-verbal Transmiterea textelor adoptate în cursul prezentei şedinţe: a se vedea procesul-verbal Calendarul următoarelor şedinţe: a se vedea procesul-verbal Întreruperea sesiunii Dichiaro interrotta la sessione del Parlamento europeo. (La seduta è tolta alle 13.00) Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal Componenţa Parlamentului: a se vedea procesul-verbal Cerere de apărare a imunităţii parlamentare: consultaţi procesul-verbal Numiri în comisii (propunerea Conferinţei preşedinţilor): consultaţi procesul-verbal Depunerea documentelor: a se vedea procesul-verbal Transmiterea de către Consiliu a textelor acordurilor: a se vedea procesul-verbal" } }
{ "translation": { "en": "Action taken on Parliament's resolutions: see Minutes Oral questions and written statements (tabling): see Minutes Written statements (Rule 116): see Minutes Agenda: see Minutes 1. Appointments to parliamentary committees (vote): see Minutes Voting time Agenda for next sitting: see Minutes Closure of sitting (The sitting was closed at 12 midnight) Opening of the sitting (The sitting was opened at 09.05) Documents received: see Minutes Approval of Minutes of previous sitting: see Minutes 1. Protection of passengers against displaced luggage (vote) 2.", "ro": "Continuări ale rezoluţiilor Parlamentului: consultaţi procesul-verbal Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal Declaraţii scrise (articolul 116 din Regulamentul de procedură) Ordinea de zi: a se vedea procesul-verbal 1. Numiri în comisiile parlamentare (vot): consultaţi procesul-verbal Timpul afectat votului Ordinea de zi a următoarei şedinţe: a se vedea procesul-verbal Ridicarea şedinţei (La seduta è tolta alle 24.00) Deschiderea şedinţei (The sitting was opened at 09.05) Depunerea documentelor: a se vedea procesul-verbal Aprobarea procesului-verbal al şedinţei precedente: a se vedea procesul-verbal 1. Protecţia pasagerilor împotriva deplasării bagajelor (vot) 2." } }
{ "translation": { "en": "Approval of motor vehicles with regard to the forward field of vision of the driver (vote) 3. EC-Korea Agreement on scientific and technological cooperation (vote) 4. Mainstreaming sustainability in development cooperation policies (vote) 5. Draft Amending Budget No 1/2007 (vote) 7. EC-Gabon Fisheries Partnership (vote) 10. Limitation periods in cross-border disputes involving personal injuries and fatal accidents (vote) 12. Strategy for a strengthened partnership with the Pacific Islands (vote) 13. The European private company statute (vote) That concludes the vote.", "ro": "Omologarea vehiculelor cu motor cu privire la câmpul de vizibilitate înainte al conducătorului auto (vot) 3. Acordul CE-Coreea de cooperare ştiinţifică şi tehnologică (vot) 4. Integrarea durabilităţii în politicile de cooperare pentru dezvoltare (vot) 5. Proiect de buget rectificativ nr.1/2007 (vot) 7. Acordul de parteneriat în domeniul pescuitului între Comunitatea Europeană şi Republica Gaboneză (vot) 10. Termenele de prescripţie aplicabile în cadrul litigiilor transfrontaliere cu privire la vătămările corporale şi accidentele mortale (vot) 12. Relaţiile UE cu insulele din Pacific: Strategie pentru un parteneriat consolidat (vot) 13. Statutul societăţii private europene (vot) Damit ist die Abstimmungsstunde beendet." } }
{ "translation": { "en": "Corrections to votes and voting intentions: see Minutes Assignment conferred on a Member: see Minutes Membership of committees and delegations: see Minutes Decisions concerning certain documents: see Minutes Forwarding of texts adopted during the sitting: see Minutes Dates for next sittings: see Minutes", "ro": "Corectările voturilor şi intenţiile de vot: a se vedea procesul-verbal Misiune încredinţată unui deputat: consultaţi procesul-verbal Componenţa comisiilor şi a delegaţiilor: a se vedea procesul-verbal Decizii privind anumite documente: a se vedea procesul-verbal Transmiterea textelor adoptate în cursul prezentei şedinţe: a se vedea procesul-verbal Calendarul următoarelor şedinţe: a se vedea procesul-verbal" } }
{ "translation": { "en": "Written statements for entry", "ro": "Declaraţii scrise înscrise" } }
This diff is collapsed.
......@@ -115,15 +115,16 @@ class TestDeepSpeed(TestCasePlus):
extra_args_str: str = None,
remove_args_str: str = None,
):
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
data_dir = self.examples_dir / "test_data/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
args = f"""
--model_name_or_path {model_name}
--data_dir {data_dir}
--train_file {data_dir}/train.json
--validation_file {data_dir}/val.json
--output_dir {output_dir}
--overwrite_output_dir
--n_train 8
--n_val 8
--max_train_samples 8
--max_val_samples 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
......@@ -139,8 +140,8 @@ class TestDeepSpeed(TestCasePlus):
--label_smoothing_factor 0.1
--adafactor
--task translation
--tgt_lang ro_RO
--src_lang en_XX
--target_lang ro_RO
--source_lang en_XX
""".split()
if extra_args_str is not None:
......@@ -151,7 +152,7 @@ class TestDeepSpeed(TestCasePlus):
args = [x for x in args if x not in remove_args]
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
script = [f"{self.examples_dir_str}/seq2seq/finetune_trainer.py"]
script = [f"{self.examples_dir_str}/seq2seq/run_seq2seq.py"]
num_gpus = get_gpu_count() if distributed else 1
launcher = f"deepspeed --num_gpus {num_gpus}".split()
......
......@@ -30,7 +30,10 @@ from transformers.testing_utils import (
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed
from .finetune_trainer import main
bindir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(f"{bindir}/../../seq2seq")
from run_seq2seq import main # noqa
set_seed(42)
......@@ -60,8 +63,8 @@ def require_apex(test_case):
return test_case
class TestFinetuneTrainer(TestCasePlus):
def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
......@@ -69,35 +72,37 @@ class TestFinetuneTrainer(TestCasePlus):
assert "eval_bleu" in first_step_stats
@require_torch_non_multi_gpu
def test_finetune_trainer_no_dist(self):
self.finetune_trainer_quick()
def test_run_seq2seq_no_dist(self):
self.run_seq2seq_quick()
# the following 2 tests verify that the trainer can handle distributed and non-distributed with n_gpu > 1
# verify that the trainer can handle non-distributed with n_gpu > 1
@require_torch_multi_gpu
def test_finetune_trainer_dp(self):
self.finetune_trainer_quick(distributed=False)
def test_run_seq2seq_dp(self):
self.run_seq2seq_quick(distributed=False)
# verify that the trainer can handle distributed with n_gpu > 1
@require_torch_multi_gpu
def test_finetune_trainer_ddp(self):
self.finetune_trainer_quick(distributed=True)
def test_run_seq2seq_ddp(self):
self.run_seq2seq_quick(distributed=True)
# it's crucial to test --sharded_ddp w/ and w/o --fp16
# test --sharded_ddp w/o --fp16
@require_torch_multi_gpu
@require_fairscale
def test_finetune_trainer_ddp_sharded_ddp(self):
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp")
def test_run_seq2seq_ddp_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")
# test --sharded_ddp w/ --fp16
@require_torch_multi_gpu
@require_fairscale
def test_finetune_trainer_ddp_sharded_ddp_fp16(self):
self.finetune_trainer_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
@require_apex
def test_finetune_trainer_apex(self):
self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")
def test_run_seq2seq_apex(self):
self.run_seq2seq_quick(extra_args_str="--fp16 --fp16_backend=apex")
@slow
def test_finetune_trainer_slow(self):
def test_run_seq2seq_slow(self):
# There is a missing call to __init__process_group somewhere
output_dir = self.run_trainer(
eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
......@@ -115,7 +120,7 @@ class TestFinetuneTrainer(TestCasePlus):
# test if do_predict saves generations and metrics
contents = os.listdir(output_dir)
contents = {os.path.basename(p) for p in contents}
assert "test_generations.txt" in contents
assert "test_preds_seq2seq.txt" in contents
assert "test_results.json" in contents
def run_trainer(
......@@ -127,15 +132,17 @@ class TestFinetuneTrainer(TestCasePlus):
distributed: bool = False,
extra_args_str: str = None,
):
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
data_dir = self.examples_dir / "test_data/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
args = f"""
--model_name_or_path {model_name}
--data_dir {data_dir}
--train_file {data_dir}/train.json
--validation_file {data_dir}/val.json
--test_file {data_dir}/test.json
--output_dir {output_dir}
--overwrite_output_dir
--n_train 8
--n_val 8
--max_train_samples 8
--max_val_samples 8
--max_source_length {max_len}
--max_target_length {max_len}
--val_max_target_length {max_len}
......@@ -156,10 +163,9 @@ class TestFinetuneTrainer(TestCasePlus):
--label_smoothing_factor 0.1
--adafactor
--task translation
--tgt_lang ro_RO
--src_lang en_XX
--target_lang ro_RO
--source_lang en_XX
""".split()
# --eval_beams 2
if extra_args_str is not None:
args.extend(extra_args_str.split())
......@@ -169,12 +175,12 @@ class TestFinetuneTrainer(TestCasePlus):
distributed_args = f"""
-m torch.distributed.launch
--nproc_per_node={n_gpu}
{self.test_file_dir}/finetune_trainer.py
{self.examples_dir_str}/seq2seq/run_seq2seq.py
""".split()
cmd = [sys.executable] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
else:
testargs = ["finetune_trainer.py"] + args
testargs = ["run_seq2seq.py"] + args
with patch.object(sys, "argv", testargs):
main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment