chenpangpang / transformers
"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "c808f156e9e38995faa74ec0219ce79d487fc585"
Unverified commit b6b2f227, authored Aug 03, 2020 by Sam Shleifer, committed by GitHub on Aug 03, 2020

s2s: fix LR logging, remove some dead code. (#6205)
Parent: 06f1692b
Showing 3 changed files with 5 additions and 7 deletions.
examples/lightning_base.py                  +1 -5
examples/seq2seq/callbacks.py               +4 -0
examples/seq2seq/train_mbart_cc25_enro.sh   +0 -2
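The LR-logging half of this change relies on the fact that a PyTorch optimizer exposes the current learning rate of each parameter group through optimizer.param_groups; the new on_batch_end hook in examples/seq2seq/callbacks.py (below) simply reads those values every batch. The following sketch is not from this commit; it only illustrates that mechanism with a hypothetical two-group AdamW optimizer, using the same lr_group_{i} naming as the callback.

# Minimal sketch (not from this commit): reading per-group learning rates
# the way the new Seq2SeqLoggingCallback.on_batch_end does.
import torch
from torch.optim import AdamW

# Hypothetical model with a decay group and a no-decay group.
model = torch.nn.Linear(4, 2)
optimizer = AdamW(
    [
        {"params": [model.weight], "weight_decay": 0.01},
        {"params": [model.bias], "weight_decay": 0.0},
    ],
    lr=3e-5,
)

# One entry per param group, keyed like the callback's metrics.
lrs = {f"lr_group_{i}": group["lr"] for i, group in enumerate(optimizer.param_groups)}
print(lrs)  # {'lr_group_0': 3e-05, 'lr_group_1': 3e-05}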
examples/lightning_base.py
@@ -58,7 +58,6 @@ class BaseTransformer(pl.LightningModule):
         self.hparams = hparams
         self.step_count = 0
-        self.tfmr_ckpts = {}
         self.output_dir = Path(self.hparams.output_dir)
         cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
         if config is None:
@@ -99,7 +98,7 @@ class BaseTransformer(pl.LightningModule):
         self.model = self.model_type.from_pretrained(*args, **kwargs)

     def configure_optimizers(self):
         "Prepare optimizer and schedule (linear warmup and decay)"
         model = self.model
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
@@ -159,11 +158,9 @@ class BaseTransformer(pl.LightningModule):
     @pl.utilities.rank_zero_only
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         save_path = self.output_dir.joinpath("best_tfmr")
-        save_path.mkdir(exist_ok=True)
         self.model.config.save_step = self.step_count
         self.model.save_pretrained(save_path)
         self.tokenizer.save_pretrained(save_path)
-        self.tfmr_ckpts[self.step_count] = save_path

     @staticmethod
     def add_model_specific_args(parser, root_dir):
@@ -274,7 +271,6 @@ def add_generic_args(parser, root_dir) -> None:
         default=1,
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
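For context on on_save_checkpoint above: the best_tfmr directory it writes is a regular transformers save directory, so it can be reloaded with the standard from_pretrained API. A minimal sketch, assuming a hypothetical output_dir that already contains best_tfmr from a finished run:

# Minimal sketch (not part of this diff): reloading what on_save_checkpoint
# wrote to <output_dir>/best_tfmr.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

save_path = "output_dir/best_tfmr"  # hypothetical path; depends on --output_dir
model = AutoModelForSeq2SeqLM.from_pretrained(save_path)
tokenizer = AutoTokenizer.from_pretrained(save_path)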
examples/seq2seq/callbacks.py
@@ -19,6 +19,10 @@ logger = logging.getLogger(__name__)
 class Seq2SeqLoggingCallback(pl.Callback):
+    def on_batch_end(self, trainer, pl_module):
+        lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)}
+        pl_module.logger.log_metrics(lrs)
+
     @rank_zero_only
     def _write_logs(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
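The on_batch_end hook added above only fires when the callback is registered with the Lightning Trainer; in this example directory that wiring lives in the existing training scripts. The sketch below is a hand-rolled illustration rather than code from this repo (the import assumes it is run from examples/seq2seq/, and the model is left abstract):

# Minimal sketch (assumptions: run from examples/seq2seq/, `model` is any
# LightningModule whose configure_optimizers returns an optimizer).
# With the callback attached, lr_group_{i} is sent to the logger after every batch.
import pytorch_lightning as pl
from callbacks import Seq2SeqLoggingCallback

trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[Seq2SeqLoggingCallback()],
)
# trainer.fit(model)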
examples/seq2seq/train_mbart_cc25_enro.sh
@@ -5,7 +5,6 @@ python finetune.py \
     --learning_rate=3e-5 \
     --fp16 \
     --do_train \
-    --do_predict \
     --val_check_interval=0.25 \
     --adam_eps 1e-06 \
     --num_train_epochs 6 --src_lang en_XX --tgt_lang ro_RO \
@@ -15,6 +14,5 @@ python finetune.py \
     --task translation \
     --warmup_steps 500 \
     --freeze_embeds \
-    --early_stopping_patience 4 \
     --model_name_or_path=facebook/mbart-large-cc25 \
     $@