Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b6b2f227
Unverified
Commit
b6b2f227
authored
Aug 03, 2020
by
Sam Shleifer
Committed by
GitHub
Aug 03, 2020
Browse files
s2s: fix LR logging, remove some dead code. (#6205)
parent
06f1692b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
7 deletions
+5
-7
examples/lightning_base.py
examples/lightning_base.py
+1
-5
examples/seq2seq/callbacks.py
examples/seq2seq/callbacks.py
+4
-0
examples/seq2seq/train_mbart_cc25_enro.sh
examples/seq2seq/train_mbart_cc25_enro.sh
+0
-2
No files found.
examples/lightning_base.py
View file @
b6b2f227
...
@@ -58,7 +58,6 @@ class BaseTransformer(pl.LightningModule):
...
@@ -58,7 +58,6 @@ class BaseTransformer(pl.LightningModule):
self
.
hparams
=
hparams
self
.
hparams
=
hparams
self
.
step_count
=
0
self
.
step_count
=
0
self
.
tfmr_ckpts
=
{}
self
.
output_dir
=
Path
(
self
.
hparams
.
output_dir
)
self
.
output_dir
=
Path
(
self
.
hparams
.
output_dir
)
cache_dir
=
self
.
hparams
.
cache_dir
if
self
.
hparams
.
cache_dir
else
None
cache_dir
=
self
.
hparams
.
cache_dir
if
self
.
hparams
.
cache_dir
else
None
if
config
is
None
:
if
config
is
None
:
...
@@ -99,7 +98,7 @@ class BaseTransformer(pl.LightningModule):
...
@@ -99,7 +98,7 @@ class BaseTransformer(pl.LightningModule):
self
.
model
=
self
.
model_type
.
from_pretrained
(
*
args
,
**
kwargs
)
self
.
model
=
self
.
model_type
.
from_pretrained
(
*
args
,
**
kwargs
)
def
configure_optimizers
(
self
):
def
configure_optimizers
(
self
):
"Prepare optimizer and schedule (linear warmup and decay)"
""
"Prepare optimizer and schedule (linear warmup and decay)"
""
model
=
self
.
model
model
=
self
.
model
no_decay
=
[
"bias"
,
"LayerNorm.weight"
]
no_decay
=
[
"bias"
,
"LayerNorm.weight"
]
optimizer_grouped_parameters
=
[
optimizer_grouped_parameters
=
[
...
@@ -159,11 +158,9 @@ class BaseTransformer(pl.LightningModule):
...
@@ -159,11 +158,9 @@ class BaseTransformer(pl.LightningModule):
@
pl
.
utilities
.
rank_zero_only
@
pl
.
utilities
.
rank_zero_only
def
on_save_checkpoint
(
self
,
checkpoint
:
Dict
[
str
,
Any
])
->
None
:
def
on_save_checkpoint
(
self
,
checkpoint
:
Dict
[
str
,
Any
])
->
None
:
save_path
=
self
.
output_dir
.
joinpath
(
"best_tfmr"
)
save_path
=
self
.
output_dir
.
joinpath
(
"best_tfmr"
)
save_path
.
mkdir
(
exist_ok
=
True
)
self
.
model
.
config
.
save_step
=
self
.
step_count
self
.
model
.
config
.
save_step
=
self
.
step_count
self
.
model
.
save_pretrained
(
save_path
)
self
.
model
.
save_pretrained
(
save_path
)
self
.
tokenizer
.
save_pretrained
(
save_path
)
self
.
tokenizer
.
save_pretrained
(
save_path
)
self
.
tfmr_ckpts
[
self
.
step_count
]
=
save_path
@
staticmethod
@
staticmethod
def
add_model_specific_args
(
parser
,
root_dir
):
def
add_model_specific_args
(
parser
,
root_dir
):
...
@@ -274,7 +271,6 @@ def add_generic_args(parser, root_dir) -> None:
...
@@ -274,7 +271,6 @@ def add_generic_args(parser, root_dir) -> None:
default
=
1
,
default
=
1
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
)
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
42
,
help
=
"random seed for initialization"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
42
,
help
=
"random seed for initialization"
)
...
...
examples/seq2seq/callbacks.py
View file @
b6b2f227
...
@@ -19,6 +19,10 @@ logger = logging.getLogger(__name__)
...
@@ -19,6 +19,10 @@ logger = logging.getLogger(__name__)
class
Seq2SeqLoggingCallback
(
pl
.
Callback
):
class
Seq2SeqLoggingCallback
(
pl
.
Callback
):
def
on_batch_end
(
self
,
trainer
,
pl_module
):
lrs
=
{
f
"lr_group_
{
i
}
"
:
param
[
"lr"
]
for
i
,
param
in
enumerate
(
pl_module
.
trainer
.
optimizers
[
0
].
param_groups
)}
pl_module
.
logger
.
log_metrics
(
lrs
)
@
rank_zero_only
@
rank_zero_only
def
_write_logs
(
def
_write_logs
(
self
,
trainer
:
pl
.
Trainer
,
pl_module
:
pl
.
LightningModule
,
type_path
:
str
,
save_generations
=
True
self
,
trainer
:
pl
.
Trainer
,
pl_module
:
pl
.
LightningModule
,
type_path
:
str
,
save_generations
=
True
...
...
examples/seq2seq/train_mbart_cc25_enro.sh
View file @
b6b2f227
...
@@ -5,7 +5,6 @@ python finetune.py \
...
@@ -5,7 +5,6 @@ python finetune.py \
--learning_rate
=
3e-5
\
--learning_rate
=
3e-5
\
--fp16
\
--fp16
\
--do_train
\
--do_train
\
--do_predict
\
--val_check_interval
=
0.25
\
--val_check_interval
=
0.25
\
--adam_eps
1e-06
\
--adam_eps
1e-06
\
--num_train_epochs
6
--src_lang
en_XX
--tgt_lang
ro_RO
\
--num_train_epochs
6
--src_lang
en_XX
--tgt_lang
ro_RO
\
...
@@ -15,6 +14,5 @@ python finetune.py \
...
@@ -15,6 +14,5 @@ python finetune.py \
--task
translation
\
--task
translation
\
--warmup_steps
500
\
--warmup_steps
500
\
--freeze_embeds
\
--freeze_embeds
\
--early_stopping_patience
4
\
--model_name_or_path
=
facebook/mbart-large-cc25
\
--model_name_or_path
=
facebook/mbart-large-cc25
\
$@
$@
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment