chenpangpang / transformers · Commits

Commit 76e5af4c (unverified)
Authored by Sam Shleifer on Jun 23, 2020; committed by GitHub on Jun 23, 2020

[pl_examples] revert deletion of optimizer_step (#5227)

Parent: c01480bb
Showing 3 changed files with 15 additions and 1 deletion.
examples/lightning_base.py               +13 -0
examples/summarization/finetune.py        +1 -1
examples/summarization/run_distiller.sh   +1 -0
examples/lightning_base.py
@@ -116,6 +116,19 @@ class BaseTransformer(pl.LightningModule):
         self.opt = optimizer
         return [optimizer]

+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
+        if self.trainer.use_tpu:
+            xm.optimizer_step(optimizer)
+        else:
+            optimizer.step()
+        optimizer.zero_grad()
+        self.lr_scheduler.step()
+
+    def get_tqdm_dict(self):
+        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
+        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
+        return tqdm_dict
+
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)
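For readers skimming the revert: on TPU the optimizer must be stepped through torch_xla so that gradients are reduced across cores before the weight update, which is why the restored hook branches on self.trainer.use_tpu. Below is a minimal sketch of the per-batch update this hook performs, written outside of pytorch-lightning; it is not part of the commit, and the function name and arguments are illustrative only.

def per_batch_update(optimizer, lr_scheduler, use_tpu=False):
    # Hypothetical standalone version of the restored optimizer_step logic.
    if use_tpu:
        import torch_xla.core.xla_model as xm  # only importable on TPU hosts
        xm.optimizer_step(optimizer)  # reduces gradients across TPU cores, then steps
    else:
        optimizer.step()
    optimizer.zero_grad()  # clear gradients before the next batch
    lr_scheduler.step()    # this setup advances the LR schedule once per batch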
examples/summarization/finetune.py
@@ -149,7 +149,7 @@ class SummarizationModule(BaseTransformer):
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
         t0 = time.time()
         generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
-        gen_time = time.time() - t0 / source_ids.shape[0]
+        gen_time = (time.time() - t0) / source_ids.shape[0]
         preds = self.ids_to_clean_text(generated_ids)
         target = self.ids_to_clean_text(y)
         loss_tensors = self._step(batch)
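The finetune.py change is an operator-precedence fix: / binds tighter than -, so the old expression divided only t0 by the batch size rather than the elapsed time. A quick self-contained illustration, with made-up numbers:

t0, now, batch_size = 100.0, 106.0, 4  # hypothetical timestamps (seconds) and batch size

wrong = now - t0 / batch_size    # parsed as now - (t0 / batch_size) -> 81.0
right = (now - t0) / batch_size  # per-example generation time -> 1.5

print(wrong, right)  # 81.0 1.5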
examples/summarization/run_distiller.sh
@@ -7,5 +7,6 @@ python distillation.py \
   --learning_rate=3e-4 \
   --do_train \
   --do_predict \
+  --fp16 \
   --val_check_interval 0.1 \
   $@
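The one-line script change adds --fp16, which in these examples enables 16-bit mixed-precision training in the Lightning trainer (in this era of the codebase, presumably through NVIDIA apex). Note also the trailing $@, which forwards any extra arguments given to run_distiller.sh straight through to distillation.py.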