Commit 5e24982e in chenpangpang/transformers (unverified)
Authored by Sean Naren on Oct 28, 2020; committed via GitHub on Oct 28, 2020
Upgrade PyTorch Lightning to 1.0.2 (#7852)
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>

Parent commit: 1b6c8d48
Showing 8 changed files with 11 additions and 13 deletions (+11, -13):
examples/lightning_base.py                            +3 -2
examples/requirements.txt                             +1 -1
examples/seq2seq/callbacks.py                         +0 -1
examples/seq2seq/finetune.py                          +2 -3
examples/seq2seq/test_bash_script.py                  +1 -1
examples/seq2seq/test_seq2seq_examples_multi_gpu.py   +0 -1
examples/text-classification/run_pl_glue.py           +1 -1
examples/token-classification/run_pl_ner.py           +3 -3
examples/lightning_base.py

@@ -337,7 +337,7 @@ def add_generic_args(parser, root_dir) -> None:
 def generic_train(
     model: BaseTransformer,
     args: argparse.Namespace,
-    early_stopping_callback=False,
+    early_stopping_callback=None,
     logger=True,  # can pass WandbLogger() here
     extra_callbacks=[],
     checkpoint_callback=None,
@@ -355,6 +355,8 @@ def generic_train(
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1
         )
+    if early_stopping_callback:
+        extra_callbacks.append(early_stopping_callback)
     if logging_callback is None:
         logging_callback = LoggingCallback()
@@ -376,7 +378,6 @@ def generic_train(
         callbacks=[logging_callback] + extra_callbacks,
         logger=logger,
         checkpoint_callback=checkpoint_callback,
-        early_stop_callback=early_stopping_callback,
         **train_params,
     )
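Why this change: PyTorch Lightning 1.0 removed the Trainer's early_stop_callback keyword, so generic_train now takes an optional early-stopping callback (default None instead of False) and routes it through the ordinary callbacks list. A minimal sketch of the post-1.0 pattern, assuming a "val_loss" metric is logged; the monitor/patience values here are illustrative:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

# PL >= 1.0: early stopping is just another entry in `callbacks`;
# there is no `early_stop_callback=` Trainer keyword anymore.
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=3)
trainer = pl.Trainer(max_epochs=3, callbacks=[early_stopping])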
examples/requirements.txt

@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.9.0
+pytorch-lightning==1.0.4
 matplotlib
 git-python==1.0.3
 faiss-cpu
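Note that the pin lands on 1.0.4 even though the commit title says 1.0.2. For scripts that depend on the 1.0-series API, a quick runtime guard can fail fast on a stale install; this check is an illustrative sketch, not part of the commit:

from packaging import version
import pytorch_lightning as pl

# Fail fast if a pre-1.0 Lightning is installed.
assert version.parse(pl.__version__) >= version.parse("1.0.2"), (
    f"examples expect pytorch-lightning>=1.0.2, found {pl.__version__}"
)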
examples/seq2seq/callbacks.py

@@ -102,7 +102,6 @@ def get_checkpoint_callback(output_dir, metric, save_top_k=1, lower_is_better=False
         monitor=f"val_{metric}",
         mode="min" if "loss" in metric else "max",
         save_top_k=save_top_k,
-        period=0,  # maybe save a checkpoint every time val is run, not just end of epoch.
     )
     return checkpoint_callback
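The dropped period=0 line reflects Lightning 1.0's reworked ModelCheckpoint: checkpointing is tied to validation there, so the old "save every time val runs" workaround is no longer needed (and, judging by this removal, no longer valid). A sketch of the equivalent callback under 1.0, mirroring get_checkpoint_callback above; filepath still worked in the 1.0 series, though it was later superseded by dirpath/filename:

from pytorch_lightning.callbacks import ModelCheckpoint

def get_checkpoint_callback_sketch(output_dir: str, metric: str, save_top_k: int = 1) -> ModelCheckpoint:
    # PL >= 1.0: no `period=0` needed; `monitor` is evaluated at each
    # validation pass and the best `save_top_k` checkpoints are kept.
    return ModelCheckpoint(
        filepath=output_dir,
        monitor=f"val_{metric}",
        mode="min" if "loss" in metric else "max",
        save_top_k=save_top_k,
    )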
examples/seq2seq/finetune.py

@@ -182,7 +182,6 @@ class SummarizationModule(BaseTransformer):
         return self._generative_step(batch)
 
     def validation_epoch_end(self, outputs, prefix="val") -> Dict:
-
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
@@ -252,7 +251,7 @@ class SummarizationModule(BaseTransformer):
     def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
         dataset = self.get_dataset(type_path)
 
-        if self.hparams.sortish_sampler and type_path != "test":
+        if self.hparams.sortish_sampler and type_path != "test" and type_path != "val":
             sampler = dataset.make_sortish_sampler(batch_size, distributed=self.hparams.gpus > 1)
             return DataLoader(
                 dataset,
@@ -263,7 +262,7 @@ class SummarizationModule(BaseTransformer):
                 sampler=sampler,
             )
-        elif self.hparams.max_tokens_per_batch is not None and type_path != "test":
+        elif self.hparams.max_tokens_per_batch is not None and type_path != "test" and type_path != "val":
             batch_sampler = dataset.make_dynamic_sampler(
                 self.hparams.max_tokens_per_batch, distributed=self.hparams.gpus > 1
             )
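Both sampler branches in get_dataloader now also exclude the "val" split, so the sortish and dynamic-batch samplers apply to training data only; validation and test loaders fall through to the plain DataLoader path. The repeated guard could be captured in a tiny helper; this is a hypothetical refactor, not part of the commit:

# Hypothetical helper equivalent to the condition added above:
# custom samplers are used for training batches only.
def uses_custom_sampler(type_path: str) -> bool:
    return type_path not in ("test", "val")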
examples/seq2seq/test_bash_script.py

@@ -144,6 +144,7 @@ class TestAll(TestCasePlus):
                 f"--num_train_epochs={epochs}",
                 "--warmup_steps=10",
                 "--val_check_interval=1.0",
+                "--do_predict",
             ]
         )
         with patch.object(sys, "argv", testargs):
@@ -151,7 +152,6 @@ class TestAll(TestCasePlus):
         parser = pl.Trainer.add_argparse_args(parser)
         parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
         args = parser.parse_args()
-        args.do_predict = False
         # assert args.gpus == gpus THIS BREAKS for multigpu
         model = distill_main(args)
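With --do_predict now passed on the simulated command line, the test exercises the predict path end to end instead of disabling it after parsing (the deleted args.do_predict = False). The argv-injection pattern the test relies on, in a self-contained sketch; finetune_main is a stand-in name for the example script's entry point:

import sys
from unittest.mock import patch

def finetune_main():
    # stand-in for the example script's main(); the real code parses sys.argv
    print(sys.argv[1:])

testargs = ["finetune.py", "--num_train_epochs=1", "--warmup_steps=10", "--do_predict"]
with patch.object(sys, "argv", testargs):
    finetune_main()  # sees the injected arguments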
examples/seq2seq/test_seq2seq_examples_multi_gpu.py

@@ -176,7 +176,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
         print(metrics)
         last_step_stats = metrics["val"][-1]
         self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
-        self.assertGreaterEqual(1.0, last_step_stats["val_avg_gen_time"])
         self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
         self.assertEqual(len(metrics["test"]), 1)
         desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
examples/text-classification/run_pl_glue.py

@@ -192,7 +192,7 @@ def main():
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         return trainer.test(model)
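The glob pattern changes because Lightning 1.0 joins the ModelCheckpoint prefix and the formatted filename with a separator, so "checkpoint" plus "epoch=N" now yields checkpoint-epoch=N.ckpt on disk (inferred from this commit and the model_checkpoint.py link in run_pl_ner.py below). A small sketch of the lookup:

import glob
import os

def newest_checkpoint(output_dir: str) -> str:
    # Matches the PL 1.0 default naming used by these examples:
    # <output_dir>/checkpoint-epoch=<N>.ckpt
    checkpoints = sorted(glob.glob(os.path.join(output_dir, "checkpoint-epoch=*.ckpt")))
    if not checkpoints:
        raise FileNotFoundError(f"no checkpoints under {output_dir}")
    return checkpoints[-1]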
examples/token-classification/run_pl_ner.py

@@ -207,9 +207,9 @@ if __name__ == "__main__":
     if args.do_predict:
         # See https://github.com/huggingface/transformers/issues/3159
-        # pl use this format to create a checkpoint:
+        # pl use this default format to create a checkpoint:
         # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
-        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
-        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L322
+        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpoint-epoch=*.ckpt"), recursive=True)))
         model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)