chenpangpang / transformers · Commits · a504cb49

[examples] fix summarization do_predict (#3866)

Unverified commit a504cb49, authored Apr 20, 2020 by Sam Shleifer, committed via GitHub on Apr 20, 2020.
Parent: 52c85f84
Showing 3 changed files with 20 additions and 6 deletions (+20, -6):

  examples/summarization/bart/finetune.py             +5  -1
  examples/summarization/bart/test_bart_examples.py   +9  -0
  examples/transformer_base.py                        +6  -5
examples/summarization/bart/finetune.py

@@ -166,8 +166,12 @@ def main(args):
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
+        # See https://github.com/huggingface/transformers/issues/3159
+        # pl use this format to create a checkpoint:
+        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
+        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
         checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
+        model = model.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
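The substance of the fix: in PyTorch Lightning, load_from_checkpoint is a classmethod that returns a new, restored model instance; it does not mutate the caller in place. The old line therefore discarded its result, and trainer.test(model) ran on the unrestored model. A minimal sketch of the corrected flow, reusing the names from the diff (args, SummarizationTrainer, and trainer all come from the surrounding script):

    import glob
    import os

    # Collect Lightning checkpoints; the glob pattern matches PL's default
    # filename format, e.g. "checkpointepoch=0.ckpt" (see the link in the diff).
    checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))

    # load_from_checkpoint RETURNS a restored instance, so the result must be
    # assigned; calling it via `model` works because classmethods are
    # reachable through instances.
    model = model.load_from_checkpoint(checkpoints[-1])
    trainer.test(model)

One caveat worth knowing: sorted() compares filenames lexicographically, so checkpointepoch=10.ckpt would sort before checkpointepoch=2.ckpt; for the single-epoch runs these example scripts perform, checkpoints[-1] is still the latest checkpoint.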
examples/summarization/bart/test_bart_examples.py

@@ -94,7 +94,15 @@ class TestBartExamples(unittest.TestCase):
         )
         main(argparse.Namespace(**args_d))
         args_d.update({"do_train": False, "do_predict": True})
         main(argparse.Namespace(**args_d))
+        contents = os.listdir(output_dir)
+        expected_contents = {
+            "checkpointepoch=0.ckpt",
+            "test_results.txt",
+        }
+        created_files = {os.path.basename(p) for p in contents}
+        self.assertSetEqual(expected_contents, created_files)

     def test_t5_run_sum_cli(self):
         args_d: dict = DEFAULT_ARGS.copy()

@@ -111,6 +119,7 @@ class TestBartExamples(unittest.TestCase):
             do_predict=True,
         )
         main(argparse.Namespace(**args_d))
         # args_d.update({"do_train": False, "do_predict": True})
         # main(argparse.Namespace(**args_d))
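The new assertions pin down exactly what a train-then-predict round trip leaves in output_dir: the single training epoch's checkpoint plus the written predictions. A small self-contained sketch of the same check (the helper name is illustrative, not from the repo):

    import os

    def assert_exact_outputs(output_dir, expected):
        # os.listdir already returns bare names, so the basename() call in the
        # test above is a harmless no-op; comparing sets also ignores ordering.
        created = {os.path.basename(p) for p in os.listdir(output_dir)}
        assert created == expected, "unexpected outputs: {}".format(created ^ expected)

    # After main(...) has run with do_train=False, do_predict=True:
    # assert_exact_outputs(output_dir, {"checkpointepoch=0.ckpt", "test_results.txt"})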
examples/transformer_base.py

+import argparse
 import logging
 import os
 import random

@@ -38,7 +39,7 @@ MODEL_MODES = {
 }

-def set_seed(args):
+def set_seed(args: argparse.Namespace):
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)

@@ -47,7 +48,7 @@ def set_seed(args):
 class BaseTransformer(pl.LightningModule):
-    def __init__(self, hparams, num_labels=None, mode="base", **config_kwargs):
+    def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
         "Initialize a model."
         super(BaseTransformer, self).__init__()

@@ -192,7 +193,7 @@ class BaseTransformer(pl.LightningModule):
 class LoggingCallback(pl.Callback):
-    def on_validation_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         logger.info("***** Validation results *****")
         if pl_module.is_logger():
             metrics = trainer.callback_metrics

@@ -201,7 +202,7 @@ class LoggingCallback(pl.Callback):
             if key not in ["log", "progress_bar"]:
                 logger.info("{} = {}\n".format(key, str(metrics[key])))

-    def on_test_end(self, trainer, pl_module):
+    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
         logger.info("***** Test results *****")
         if pl_module.is_logger():

@@ -256,7 +257,7 @@ def add_generic_args(parser, root_dir):
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

-def generic_train(model, args):
+def generic_train(model: BaseTransformer, args: argparse.Namespace):
     # init model
     set_seed(args)
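The transformer_base.py changes are purely type annotations, plus the import argparse they require; runtime behavior is unchanged. For illustration, a brief sketch of the call shapes those annotations document (the hyperparameter values here are made up):

    import argparse

    # set_seed and BaseTransformer.__init__ expect the parsed CLI namespace,
    # not a plain dict: they rely on attribute access such as args.seed.
    args = argparse.Namespace(seed=42, output_dir="out", do_train=True, do_predict=False)
    set_seed(args)

    # A BaseTransformer subclass is built from the same namespace and then
    # handed to generic_train alongside it (sketched, not run here):
    # model = SummarizationTrainer(args)
    # generic_train(model, args)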