"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "c808f156e9e38995faa74ec0219ce79d487fc585"
Unverified commit b6b2f227, authored by Sam Shleifer, committed by GitHub

s2s: fix LR logging, remove some dead code. (#6205)

parent 06f1692b
@@ -58,7 +58,6 @@ class BaseTransformer(pl.LightningModule):
         self.hparams = hparams
         self.step_count = 0
-        self.tfmr_ckpts = {}
         self.output_dir = Path(self.hparams.output_dir)
         cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
         if config is None:
@@ -99,7 +98,7 @@ class BaseTransformer(pl.LightningModule):
         self.model = self.model_type.from_pretrained(*args, **kwargs)

     def configure_optimizers(self):
-        "Prepare optimizer and schedule (linear warmup and decay)"
+        """Prepare optimizer and schedule (linear warmup and decay)"""
         model = self.model
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
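The hunk above cuts off at the start of the parameter grouping. For reference, a grouping like this is typically completed along the following lines (a sketch, not part of this diff; it assumes weight_decay is available on self.hparams):

    # Sketch: apply weight decay to all parameters except biases and LayerNorm weights.
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": self.hparams.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]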
@@ -159,11 +158,9 @@ class BaseTransformer(pl.LightningModule):
     @pl.utilities.rank_zero_only
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         save_path = self.output_dir.joinpath("best_tfmr")
-        save_path.mkdir(exist_ok=True)
         self.model.config.save_step = self.step_count
         self.model.save_pretrained(save_path)
         self.tokenizer.save_pretrained(save_path)
-        self.tfmr_ckpts[self.step_count] = save_path

     @staticmethod
     def add_model_specific_args(parser, root_dir):
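on_save_checkpoint writes the current model and tokenizer to the best_tfmr directory under the output directory in the standard transformers format, so it can be reloaded without going through a Lightning checkpoint. A minimal reload sketch (the path is a placeholder, and AutoModelForSeq2SeqLM is assumed here because these examples are seq2seq):

    # Sketch: reload the weights and tokenizer written by on_save_checkpoint.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    best_dir = "output_dir/best_tfmr"  # placeholder: whatever --output_dir was set to
    model = AutoModelForSeq2SeqLM.from_pretrained(best_dir)
    tokenizer = AutoTokenizer.from_pretrained(best_dir)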
@@ -274,7 +271,6 @@ def add_generic_args(parser, root_dir) -> None:
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
...
@@ -19,6 +19,10 @@ logger = logging.getLogger(__name__)
 class Seq2SeqLoggingCallback(pl.Callback):
+    def on_batch_end(self, trainer, pl_module):
+        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
+        pl_module.logger.log_metrics(lrs)
+
     @rank_zero_only
     def _write_logs(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
...
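The added on_batch_end hook reads the current learning rate of each param group on the first optimizer and forwards it to the experiment logger, which is the "fix LR logging" part of this commit. A rough usage sketch (assumed wiring, not shown in this diff; the Trainer arguments are illustrative):

    # Sketch: attach the callback so per-group learning rates are logged on every training batch.
    import pytorch_lightning as pl

    trainer = pl.Trainer(
        max_epochs=6,
        callbacks=[Seq2SeqLoggingCallback()],  # the callback defined in the file above
    )
    trainer.fit(model)  # `model` is a LightningModule such as BaseTransformer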
@@ -5,7 +5,6 @@ python finetune.py \
     --learning_rate=3e-5 \
     --fp16 \
     --do_train \
-    --do_predict \
     --val_check_interval=0.25 \
     --adam_eps 1e-06 \
     --num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
@@ -15,6 +14,5 @@ python finetune.py \
     --task translation \
     --warmup_steps 500 \
     --freeze_embeds \
-    --early_stopping_patience 4 \
     --model_name_or_path=facebook/mbart-large-cc25 \
     $@
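Note that the trailing $@ forwards any extra command-line arguments from the shell script to finetune.py, so options not fixed in the script (for example --output_dir or --data_dir, assuming finetune.py accepts them) can still be supplied at invocation time.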