Unverified Commit 889d3bfd authored by srush, committed by GitHub

default arg fix (#2937)

parent ea8eba35
@@ -248,16 +248,22 @@ def generic_train(model, args):
         filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
     )
 
-    trainer = pl.Trainer(
+    train_params = dict(
         accumulate_grad_batches=args.gradient_accumulation_steps,
         gpus=args.n_gpu,
         max_epochs=args.num_train_epochs,
-        use_amp=args.fp16,
-        amp_level=args.fp16_opt_level,
-        distributed_backend="ddp",
         gradient_clip_val=args.max_grad_norm,
         checkpoint_callback=checkpoint_callback,
     )
 
+    if args.fp16:
+        train_params["use_amp"] = args.fp16
+        train_params["amp_level"] = args.fp16_opt_level
+
+    if args.n_gpu > 1:
+        train_params["distributed_backend"] = "ddp"
+
+    trainer = pl.Trainer(**train_params)
+
     if args.do_train:
         trainer.fit(model)
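
The change illustrates a general pattern: collect only the always-valid options into a keyword dict, add the AMP and DDP options when their flags are actually set, and unpack the dict into the constructor, so PyTorch Lightning's own defaults apply otherwise. A minimal, self-contained sketch of that pattern; the build_trainer_kwargs helper and the example Namespace values are illustrative, not part of the commit:

    from argparse import Namespace

    def build_trainer_kwargs(args: Namespace) -> dict:
        # Options that are always safe to pass are set unconditionally.
        train_params = dict(
            accumulate_grad_batches=args.gradient_accumulation_steps,
            gpus=args.n_gpu,
            max_epochs=args.num_train_epochs,
            gradient_clip_val=args.max_grad_norm,
        )
        # Optional features are added only when requested, so the
        # library's own defaults are used instead of half-configured values.
        if args.fp16:
            train_params["use_amp"] = args.fp16
            train_params["amp_level"] = args.fp16_opt_level
        if args.n_gpu > 1:
            train_params["distributed_backend"] = "ddp"
        return train_params

    # Hypothetical CLI values: a single-GPU run without mixed precision.
    args = Namespace(
        gradient_accumulation_steps=1,
        n_gpu=1,
        num_train_epochs=3,
        max_grad_norm=1.0,
        fp16=False,
        fp16_opt_level="O1",
    )
    print(build_trainer_kwargs(args))
    # {'accumulate_grad_batches': 1, 'gpus': 1, 'max_epochs': 3, 'gradient_clip_val': 1.0}
    # In the patched script the resulting dict is expanded with pl.Trainer(**train_params).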