Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c59b1e68
Unverified
Commit
c59b1e68
authored
Apr 15, 2020
by
Sam Shleifer
Committed by
GitHub
Apr 15, 2020
Browse files
[examples] unit test for run_bart_sum (#3544)
- adds pytorch-lightning dependency
parent
301bf8d1
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
121 additions
and
25 deletions
+121
-25
examples/requirements.txt
examples/requirements.txt
+2
-1
examples/summarization/bart/run_bart_sum.py
examples/summarization/bart/run_bart_sum.py
+21
-15
examples/summarization/bart/run_train.sh
examples/summarization/bart/run_train.sh
+1
-5
examples/summarization/bart/run_train_tiny.sh
examples/summarization/bart/run_train_tiny.sh
+33
-0
examples/summarization/bart/test_bart_examples.py
examples/summarization/bart/test_bart_examples.py
+62
-2
examples/transformer_base.py
examples/transformer_base.py
+2
-2
No files found.
examples/requirements.txt
View file @
c59b1e68
...
@@ -5,4 +5,5 @@ seqeval
...
@@ -5,4 +5,5 @@ seqeval
psutil
psutil
sacrebleu
sacrebleu
rouge-score
rouge-score
tensorflow_datasets
tensorflow_datasets
\ No newline at end of file
pytorch-lightning==0.7.3 # April 10, 2020 release
examples/summarization/bart/run_bart_sum.py
View file @
c59b1e68
...
@@ -8,7 +8,12 @@ import torch
...
@@ -8,7 +8,12 @@ import torch
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
transformer_base
import
BaseTransformer
,
add_generic_args
,
generic_train
,
get_linear_schedule_with_warmup
from
transformer_base
import
BaseTransformer
,
add_generic_args
,
generic_train
,
get_linear_schedule_with_warmup
from
utils
import
SummarizationDataset
try
:
from
.utils
import
SummarizationDataset
except
ImportError
:
from
utils
import
SummarizationDataset
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -20,6 +25,11 @@ class BartSystem(BaseTransformer):
...
@@ -20,6 +25,11 @@ class BartSystem(BaseTransformer):
def
__init__
(
self
,
hparams
):
def
__init__
(
self
,
hparams
):
super
().
__init__
(
hparams
,
num_labels
=
None
,
mode
=
self
.
mode
)
super
().
__init__
(
hparams
,
num_labels
=
None
,
mode
=
self
.
mode
)
self
.
dataset_kwargs
:
dict
=
dict
(
data_dir
=
self
.
hparams
.
data_dir
,
max_source_length
=
self
.
hparams
.
max_source_length
,
max_target_length
=
self
.
hparams
.
max_target_length
,
)
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
decoder_input_ids
=
None
,
lm_labels
=
None
):
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
decoder_input_ids
=
None
,
lm_labels
=
None
):
return
self
.
model
(
return
self
.
model
(
...
@@ -92,14 +102,6 @@ class BartSystem(BaseTransformer):
...
@@ -92,14 +102,6 @@ class BartSystem(BaseTransformer):
return
self
.
test_end
(
outputs
)
return
self
.
test_end
(
outputs
)
@
property
def
dataset_kwargs
(
self
):
return
dict
(
data_dir
=
self
.
hparams
.
data_dir
,
max_source_length
=
self
.
hparams
.
max_source_length
,
max_target_length
=
self
.
hparams
.
max_target_length
,
)
def
get_dataloader
(
self
,
type_path
:
str
,
batch_size
:
int
)
->
DataLoader
:
def
get_dataloader
(
self
,
type_path
:
str
,
batch_size
:
int
)
->
DataLoader
:
dataset
=
SummarizationDataset
(
self
.
tokenizer
,
type_path
=
type_path
,
**
self
.
dataset_kwargs
)
dataset
=
SummarizationDataset
(
self
.
tokenizer
,
type_path
=
type_path
,
**
self
.
dataset_kwargs
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
collate_fn
=
dataset
.
collate_fn
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
collate_fn
=
dataset
.
collate_fn
)
...
@@ -153,17 +155,12 @@ class BartSystem(BaseTransformer):
...
@@ -153,17 +155,12 @@ class BartSystem(BaseTransformer):
return
parser
return
parser
if
__name__
==
"__main__"
:
def
main
(
args
):
parser
=
argparse
.
ArgumentParser
()
add_generic_args
(
parser
,
os
.
getcwd
())
parser
=
BartSystem
.
add_model_specific_args
(
parser
,
os
.
getcwd
())
args
=
parser
.
parse_args
()
# If output_dir not provided, a folder will be generated in pwd
# If output_dir not provided, a folder will be generated in pwd
if
not
args
.
output_dir
:
if
not
args
.
output_dir
:
args
.
output_dir
=
os
.
path
.
join
(
"./results"
,
f
"
{
args
.
task
}
_
{
args
.
model_type
}
_
{
time
.
strftime
(
'%Y%m%d_%H%M%S'
)
}
"
,)
args
.
output_dir
=
os
.
path
.
join
(
"./results"
,
f
"
{
args
.
task
}
_
{
args
.
model_type
}
_
{
time
.
strftime
(
'%Y%m%d_%H%M%S'
)
}
"
,)
os
.
makedirs
(
args
.
output_dir
)
os
.
makedirs
(
args
.
output_dir
)
model
=
BartSystem
(
args
)
model
=
BartSystem
(
args
)
trainer
=
generic_train
(
model
,
args
)
trainer
=
generic_train
(
model
,
args
)
...
@@ -172,3 +169,12 @@ if __name__ == "__main__":
...
@@ -172,3 +169,12 @@ if __name__ == "__main__":
checkpoints
=
list
(
sorted
(
glob
.
glob
(
os
.
path
.
join
(
args
.
output_dir
,
"checkpointepoch=*.ckpt"
),
recursive
=
True
)))
checkpoints
=
list
(
sorted
(
glob
.
glob
(
os
.
path
.
join
(
args
.
output_dir
,
"checkpointepoch=*.ckpt"
),
recursive
=
True
)))
BartSystem
.
load_from_checkpoint
(
checkpoints
[
-
1
])
BartSystem
.
load_from_checkpoint
(
checkpoints
[
-
1
])
trainer
.
test
(
model
)
trainer
.
test
(
model
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
add_generic_args
(
parser
,
os
.
getcwd
())
parser
=
BartSystem
.
add_model_specific_args
(
parser
,
os
.
getcwd
())
args
=
parser
.
parse_args
()
main
(
args
)
examples/summarization/bart/run_train.sh
View file @
c59b1e68
# Install newest ptl.
pip
install
-U
git+http://github.com/PyTorchLightning/pytorch-lightning/
export
OUTPUT_DIR_NAME
=
bart_sum
export
OUTPUT_DIR_NAME
=
bart_sum
export
CURRENT_DIR
=
${
PWD
}
export
CURRENT_DIR
=
${
PWD
}
export
OUTPUT_DIR
=
${
CURRENT_DIR
}
/
${
OUTPUT_DIR_NAME
}
export
OUTPUT_DIR
=
${
CURRENT_DIR
}
/
${
OUTPUT_DIR_NAME
}
...
@@ -20,4 +16,4 @@ python run_bart_sum.py \
...
@@ -20,4 +16,4 @@ python run_bart_sum.py \
--train_batch_size
=
4
\
--train_batch_size
=
4
\
--eval_batch_size
=
4
\
--eval_batch_size
=
4
\
--output_dir
=
$OUTPUT_DIR
\
--output_dir
=
$OUTPUT_DIR
\
--do_train
--do_train
$@
\ No newline at end of file
examples/summarization/bart/run_train_tiny.sh
0 → 100755
View file @
c59b1e68
# Script for verifying that run_bart_sum can be invoked from its directory
# Get tiny dataset with cnn_dm format (4 examples for train, val, test)
wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_tiny.tgz
tar
-xzvf
cnn_tiny.tgz
rm
cnn_tiny.tgz
export
OUTPUT_DIR_NAME
=
bart_utest_output
export
CURRENT_DIR
=
${
PWD
}
export
OUTPUT_DIR
=
${
CURRENT_DIR
}
/
${
OUTPUT_DIR_NAME
}
# Make output directory if it doesn't exist
mkdir
-p
$OUTPUT_DIR
# Add parent directory to python path to access transformer_base.py and utils.py
export
PYTHONPATH
=
"../../"
:
"
${
PYTHONPATH
}
"
python run_bart_sum.py
\
--data_dir
=
cnn_tiny/
\
--model_type
=
bart
\
--model_name_or_path
=
sshleifer/bart-tiny-random
\
--learning_rate
=
3e-5
\
--train_batch_size
=
2
\
--eval_batch_size
=
2
\
--output_dir
=
$OUTPUT_DIR
\
--num_train_epochs
=
1
\
--n_gpu
=
0
\
--do_train
$@
rm
-rf
cnn_tiny
rm
-rf
$OUTPUT_DIR
examples/summarization/bart/test_bart_examples.py
View file @
c59b1e68
import
argparse
import
logging
import
logging
import
os
import
sys
import
sys
import
tempfile
import
tempfile
import
unittest
import
unittest
...
@@ -10,6 +12,7 @@ from torch.utils.data import DataLoader
...
@@ -10,6 +12,7 @@ from torch.utils.data import DataLoader
from
transformers
import
BartTokenizer
from
transformers
import
BartTokenizer
from
.evaluate_cnn
import
run_generate
from
.evaluate_cnn
import
run_generate
from
.run_bart_sum
import
main
from
.utils
import
SummarizationDataset
from
.utils
import
SummarizationDataset
...
@@ -17,16 +20,61 @@ logging.basicConfig(level=logging.DEBUG)
...
@@ -17,16 +20,61 @@ logging.basicConfig(level=logging.DEBUG)
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
()
DEFAULT_ARGS
=
{
"output_dir"
:
""
,
"fp16"
:
False
,
"fp16_opt_level"
:
"O1"
,
"n_gpu"
:
1
,
"n_tpu_cores"
:
0
,
"max_grad_norm"
:
1.0
,
"do_train"
:
True
,
"do_predict"
:
False
,
"gradient_accumulation_steps"
:
1
,
"server_ip"
:
""
,
"server_port"
:
""
,
"seed"
:
42
,
"model_type"
:
"bart"
,
"model_name_or_path"
:
"sshleifer/bart-tiny-random"
,
"config_name"
:
""
,
"tokenizer_name"
:
""
,
"cache_dir"
:
""
,
"do_lower_case"
:
False
,
"learning_rate"
:
3e-05
,
"weight_decay"
:
0.0
,
"adam_epsilon"
:
1e-08
,
"warmup_steps"
:
0
,
"num_train_epochs"
:
1
,
"train_batch_size"
:
2
,
"eval_batch_size"
:
2
,
"max_source_length"
:
12
,
"max_target_length"
:
12
,
}
def
_dump_articles
(
path
:
Path
,
articles
:
list
):
def
_dump_articles
(
path
:
Path
,
articles
:
list
):
with
path
.
open
(
"w"
)
as
f
:
with
path
.
open
(
"w"
)
as
f
:
f
.
write
(
"
\n
"
.
join
(
articles
))
f
.
write
(
"
\n
"
.
join
(
articles
))
def
make_test_data_dir
():
tmp_dir
=
Path
(
tempfile
.
gettempdir
())
articles
=
[
" Sam ate lunch today"
,
"Sams lunch ingredients"
]
summaries
=
[
"A very interesting story about what I ate for lunch."
,
"Avocado, celery, turkey, coffee"
]
for
split
in
[
"train"
,
"val"
,
"test"
]:
_dump_articles
((
tmp_dir
/
f
"
{
split
}
.source"
),
articles
)
_dump_articles
((
tmp_dir
/
f
"
{
split
}
.target"
),
summaries
)
return
tmp_dir
class
TestBartExamples
(
unittest
.
TestCase
):
class
TestBartExamples
(
unittest
.
TestCase
):
def
test_bart_cnn_cli
(
self
):
@
classmethod
def
setUpClass
(
cls
):
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logger
.
addHandler
(
stream_handler
)
logging
.
disable
(
logging
.
CRITICAL
)
# remove noisy download output from tracebacks
return
cls
def
test_bart_cnn_cli
(
self
):
tmp
=
Path
(
tempfile
.
gettempdir
())
/
"utest_generations_bart_sum.hypo"
tmp
=
Path
(
tempfile
.
gettempdir
())
/
"utest_generations_bart_sum.hypo"
output_file_name
=
Path
(
tempfile
.
gettempdir
())
/
"utest_output_bart_sum.hypo"
output_file_name
=
Path
(
tempfile
.
gettempdir
())
/
"utest_output_bart_sum.hypo"
articles
=
[
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."
]
articles
=
[
" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."
]
...
@@ -34,7 +82,19 @@ class TestBartExamples(unittest.TestCase):
...
@@ -34,7 +82,19 @@ class TestBartExamples(unittest.TestCase):
testargs
=
[
"evaluate_cnn.py"
,
str
(
tmp
),
str
(
output_file_name
),
"sshleifer/bart-tiny-random"
]
testargs
=
[
"evaluate_cnn.py"
,
str
(
tmp
),
str
(
output_file_name
),
"sshleifer/bart-tiny-random"
]
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
run_generate
()
run_generate
()
self
.
assertTrue
(
output_file_name
.
exists
())
self
.
assertTrue
(
Path
(
output_file_name
).
exists
())
os
.
remove
(
Path
(
output_file_name
))
def
test_bart_run_sum_cli
(
self
):
args_d
:
dict
=
DEFAULT_ARGS
.
copy
()
tmp_dir
=
make_test_data_dir
()
output_dir
=
tempfile
.
mkdtemp
(
prefix
=
"output_"
)
args_d
.
update
(
data_dir
=
tmp_dir
,
model_type
=
"bart"
,
train_batch_size
=
2
,
eval_batch_size
=
2
,
n_gpu
=
0
,
output_dir
=
output_dir
,
)
args
=
argparse
.
Namespace
(
**
args_d
)
main
(
args
)
def
test_bart_summarization_dataset
(
self
):
def
test_bart_summarization_dataset
(
self
):
tmp_dir
=
Path
(
tempfile
.
gettempdir
())
tmp_dir
=
Path
(
tempfile
.
gettempdir
())
...
...
examples/transformer_base.py
View file @
c59b1e68
...
@@ -104,8 +104,8 @@ class BaseTransformer(pl.LightningModule):
...
@@ -104,8 +104,8 @@ class BaseTransformer(pl.LightningModule):
self
.
lr_scheduler
.
step
()
self
.
lr_scheduler
.
step
()
def
get_tqdm_dict
(
self
):
def
get_tqdm_dict
(
self
):
tqdm_dict
=
{
"loss"
:
"{:.3f}"
.
format
(
self
.
trainer
.
avg_loss
)
,
"lr"
:
self
.
lr_scheduler
.
get_last_lr
()[
-
1
]}
avg_loss
=
getattr
(
self
.
trainer
,
"
avg_loss
"
,
0.0
)
tqdm_dict
=
{
"loss"
:
"{:.3f}"
.
format
(
avg_loss
),
"lr"
:
self
.
lr_scheduler
.
get_last_lr
()[
-
1
]}
return
tqdm_dict
return
tqdm_dict
def
test_step
(
self
,
batch
,
batch_nb
):
def
test_step
(
self
,
batch
,
batch_nb
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment