chenpangpang / transformers · Commits

Unverified commit c59b1e68, authored Apr 15, 2020 by Sam Shleifer, committed by GitHub on Apr 15, 2020
Parent: 301bf8d1

[examples] unit test for run_bart_sum (#3544)

- adds pytorch-lightning dependency
Showing 6 changed files with 121 additions and 25 deletions (+121 −25)
examples/requirements.txt                           +2   -1
examples/summarization/bart/run_bart_sum.py         +21  -15
examples/summarization/bart/run_train.sh            +1   -5
examples/summarization/bart/run_train_tiny.sh       +33  -0
examples/summarization/bart/test_bart_examples.py   +62  -2
examples/transformer_base.py                        +2   -2
examples/requirements.txt

@@ -5,4 +5,5 @@ seqeval
 psutil
 sacrebleu
 rouge-score
-tensorflow_datasets
\ No newline at end of file
+tensorflow_datasets
+pytorch-lightning==0.7.3  # April 10, 2020 release
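The exact pin matters: Lightning's API was moving quickly at the time, and run_train.sh (below) previously installed it from git master. A minimal sketch (not part of the commit) of failing fast when a different version is installed:

import pytorch_lightning as pl

# Sanity check: the examples were tested against the pinned release.
assert pl.__version__ == "0.7.3", f"expected pytorch-lightning==0.7.3, got {pl.__version__}"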
examples/summarization/bart/run_bart_sum.py

@@ -8,7 +8,12 @@ import torch
 from torch.utils.data import DataLoader

 from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
-from utils import SummarizationDataset
+
+try:
+    from .utils import SummarizationDataset
+except ImportError:
+    from utils import SummarizationDataset


 logger = logging.getLogger(__name__)
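The try/except lets the same file work both ways: imported as part of a package (the new unit test imports it relatively) and executed as a plain script from its directory, where a relative import raises ImportError. The pattern in isolation, with a hypothetical module name:

# Dual-import sketch (hypothetical module "helpers"): the relative form
# succeeds under "python -m mypkg.runner"; "python runner.py" has no
# parent package, so the absolute fallback is used instead.
try:
    from .helpers import load_data  # package execution
except ImportError:
    from helpers import load_data   # script execution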
@@ -20,6 +25,11 @@ class BartSystem(BaseTransformer):

     def __init__(self, hparams):
         super().__init__(hparams, num_labels=None, mode=self.mode)
+        self.dataset_kwargs: dict = dict(
+            data_dir=self.hparams.data_dir,
+            max_source_length=self.hparams.max_source_length,
+            max_target_length=self.hparams.max_target_length,
+        )

     def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
         return self.model(
@@ -92,14 +102,6 @@ class BartSystem(BaseTransformer):
         return self.test_end(outputs)

-    @property
-    def dataset_kwargs(self):
-        return dict(
-            data_dir=self.hparams.data_dir,
-            max_source_length=self.hparams.max_source_length,
-            max_target_length=self.hparams.max_target_length,
-        )
-
     def get_dataloader(self, type_path: str, batch_size: int) -> DataLoader:
         dataset = SummarizationDataset(self.tokenizer, type_path=type_path, **self.dataset_kwargs)
         dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn)
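get_dataloader hands the dataset's own collate_fn to DataLoader, so batching (padding variable-length sequences) is owned by the dataset rather than PyTorch's default stacking. A self-contained sketch of that contract with a toy dataset (SummarizationDataset's internals are not shown in this diff):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyTextDataset(Dataset):
    # Hypothetical stand-in: variable-length token id lists, like tokenized articles.
    def __init__(self):
        self.examples = [[1, 2, 3], [4, 5], [6]]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]

    def collate_fn(self, batch):
        # Pad to the longest sequence in the batch, as a seq2seq collator would.
        max_len = max(len(x) for x in batch)
        return torch.tensor([x + [0] * (max_len - len(x)) for x in batch])

ds = ToyTextDataset()
dl = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)
for batch in dl:
    print(batch.shape)  # torch.Size([2, 3]) then torch.Size([1, 1])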
@@ -153,17 +155,12 @@ class BartSystem(BaseTransformer):
         return parser


-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    add_generic_args(parser, os.getcwd())
-    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
-    args = parser.parse_args()
-
+def main(args):
     # If output_dir not provided, a folder will be generated in pwd
     if not args.output_dir:
         args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
         os.makedirs(args.output_dir)

     model = BartSystem(args)
     trainer = generic_train(model, args)
@@ -172,3 +169,12 @@ if __name__ == "__main__":
     checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
     BartSystem.load_from_checkpoint(checkpoints[-1])
     trainer.test(model)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    add_generic_args(parser, os.getcwd())
+    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
+    args = parser.parse_args()
+
+    main(args)
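Factoring the body into main(args) is what makes the script unit-testable: a test can build an argparse.Namespace and call main() in-process instead of shelling out. A hedged sketch of the idea (attribute names abbreviated; the real test below passes the full DEFAULT_ARGS dict):

import argparse

def main(args):
    # Stand-in for run_bart_sum.main: consumes any parsed-args object.
    print("training with lr =", args.learning_rate)

# Any object with the right attributes works; no CLI parsing needed.
args = argparse.Namespace(learning_rate=3e-5, output_dir="/tmp/out")
main(args)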
examples/summarization/bart/run_train.sh

@@
-# Install newest ptl.
-pip install -U git+http://github.com/PyTorchLightning/pytorch-lightning/
-
 export OUTPUT_DIR_NAME=bart_sum
 export CURRENT_DIR=${PWD}
 export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}

@@ -20,4 +16,4 @@ python run_bart_sum.py \
 --train_batch_size=4 \
 --eval_batch_size=4 \
 --output_dir=$OUTPUT_DIR \
---do_train
\ No newline at end of file
+--do_train $@
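Appending "$@" forwards whatever extra flags the caller passes to run_train.sh straight through to run_bart_sum.py (e.g. bash run_train.sh --do_predict). The Python analogue of that pass-through, as a small illustration:

import sys

# Forward this script's own CLI arguments to a child command, the way
# "$@" does in the shell script above (command printed, not executed).
cmd = ["python", "run_bart_sum.py", "--do_train", *sys.argv[1:]]
print(" ".join(cmd))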
examples/summarization/bart/run_train_tiny.sh  (new file, mode 100755)

+# Script for verifying that run_bart_sum can be invoked from its directory
+
+# Get tiny dataset with cnn_dm format (4 examples for train, val, test)
+wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_tiny.tgz
+tar -xzvf cnn_tiny.tgz
+rm cnn_tiny.tgz
+
+export OUTPUT_DIR_NAME=bart_utest_output
+export CURRENT_DIR=${PWD}
+export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}
+
+# Make output directory if it doesn't exist
+mkdir -p $OUTPUT_DIR
+
+# Add parent directory to python path to access transformer_base.py and utils.py
+export PYTHONPATH="../../":"${PYTHONPATH}"
+
+python run_bart_sum.py \
+--data_dir=cnn_tiny/ \
+--model_type=bart \
+--model_name_or_path=sshleifer/bart-tiny-random \
+--learning_rate=3e-5 \
+--train_batch_size=2 \
+--eval_batch_size=2 \
+--output_dir=$OUTPUT_DIR \
+--num_train_epochs=1 \
+--n_gpu=0 \
+--do_train $@
+
+rm -rf cnn_tiny
+rm -rf $OUTPUT_DIR
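The script is a CPU-only smoke test: it downloads a 4-example dataset, trains the tiny random BART checkpoint for one epoch, then deletes both the data and the output. One hedged way to drive it from Python (assumes bash, network access for cnn_tiny.tgz, and the repo's directory layout):

import subprocess

# Run the smoke script from its directory and fail if it exits non-zero.
result = subprocess.run(["bash", "run_train_tiny.sh"], cwd="examples/summarization/bart")
assert result.returncode == 0, "run_train_tiny.sh exited non-zero"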
examples/summarization/bart/test_bart_examples.py

@@
+import argparse
 import logging
+import os
 import sys
 import tempfile
 import unittest

@@ -10,6 +12,7 @@ from torch.utils.data import DataLoader
 from transformers import BartTokenizer

 from .evaluate_cnn import run_generate
+from .run_bart_sum import main
 from .utils import SummarizationDataset

@@ -17,16 +20,61 @@ logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()

+DEFAULT_ARGS = {
+    "output_dir": "",
+    "fp16": False,
+    "fp16_opt_level": "O1",
+    "n_gpu": 1,
+    "n_tpu_cores": 0,
+    "max_grad_norm": 1.0,
+    "do_train": True,
+    "do_predict": False,
+    "gradient_accumulation_steps": 1,
+    "server_ip": "",
+    "server_port": "",
+    "seed": 42,
+    "model_type": "bart",
+    "model_name_or_path": "sshleifer/bart-tiny-random",
+    "config_name": "",
+    "tokenizer_name": "",
+    "cache_dir": "",
+    "do_lower_case": False,
+    "learning_rate": 3e-05,
+    "weight_decay": 0.0,
+    "adam_epsilon": 1e-08,
+    "warmup_steps": 0,
+    "num_train_epochs": 1,
+    "train_batch_size": 2,
+    "eval_batch_size": 2,
+    "max_source_length": 12,
+    "max_target_length": 12,
+}
+

 def _dump_articles(path: Path, articles: list):
     with path.open("w") as f:
         f.write("\n".join(articles))


+def make_test_data_dir():
+    tmp_dir = Path(tempfile.gettempdir())
+    articles = [" Sam ate lunch today", "Sams lunch ingredients"]
+    summaries = ["A very interesting story about what I ate for lunch.", "Avocado, celery, turkey, coffee"]
+    for split in ["train", "val", "test"]:
+        _dump_articles((tmp_dir / f"{split}.source"), articles)
+        _dump_articles((tmp_dir / f"{split}.target"), summaries)
+    return tmp_dir
+
+
 class TestBartExamples(unittest.TestCase):
-    def test_bart_cnn_cli(self):
-        stream_handler = logging.StreamHandler(sys.stdout)
-        logger.addHandler(stream_handler)
+    @classmethod
+    def setUpClass(cls):
+        stream_handler = logging.StreamHandler(sys.stdout)
+        logger.addHandler(stream_handler)
+        logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
+        return cls
+
+    def test_bart_cnn_cli(self):
         tmp = Path(tempfile.gettempdir()) / "utest_generations_bart_sum.hypo"
         output_file_name = Path(tempfile.gettempdir()) / "utest_output_bart_sum.hypo"
         articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]

@@ -34,7 +82,19 @@ class TestBartExamples(unittest.TestCase):
         testargs = ["evaluate_cnn.py", str(tmp), str(output_file_name), "sshleifer/bart-tiny-random"]
         with patch.object(sys, "argv", testargs):
             run_generate()
-            self.assertTrue(output_file_name.exists())
+            self.assertTrue(Path(output_file_name).exists())
+            os.remove(Path(output_file_name))
+
+    def test_bart_run_sum_cli(self):
+        args_d: dict = DEFAULT_ARGS.copy()
+        tmp_dir = make_test_data_dir()
+        output_dir = tempfile.mkdtemp(prefix="output_")
+        args_d.update(
+            data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
+        )
+        args = argparse.Namespace(**args_d)
+        main(args)

     def test_bart_summarization_dataset(self):
         tmp_dir = Path(tempfile.gettempdir())
...
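test_bart_cnn_cli exercises a real CLI entry point in-process by temporarily replacing sys.argv with patch.object, so no subprocess is needed and assertions run in the same interpreter. The pattern in isolation, with a stand-in entry point:

import sys
from unittest.mock import patch

def fake_cli():
    # Stand-in for a script entry point such as run_generate(),
    # which reads its arguments from sys.argv.
    return sys.argv[1:]

# patch.object swaps sys.argv only inside the with-block.
with patch.object(sys, "argv", ["evaluate_cnn.py", "in.hypo", "out.hypo"]):
    assert fake_cli() == ["in.hypo", "out.hypo"]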
examples/transformer_base.py

@@ -104,8 +104,8 @@ class BaseTransformer(pl.LightningModule):
             self.lr_scheduler.step()

     def get_tqdm_dict(self):
-        tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
+        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
+        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
         return tqdm_dict

     def test_step(self, batch, batch_nb):
...
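The getattr change is a defensive fix: on the first progress-bar refresh the Lightning trainer may not have set avg_loss yet, and attribute access would raise AttributeError. A minimal illustration of the pattern:

class FakeTrainer:
    """Stand-in for a trainer that hasn't recorded a loss yet."""

trainer = FakeTrainer()
# Before: trainer.avg_loss would raise AttributeError.
# getattr with a default makes the first refresh safe.
avg_loss = getattr(trainer, "avg_loss", 0.0)
print("{:.3f}".format(avg_loss))  # -> 0.000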