chenpangpang / transformers · Commits

Commit f0c96faf (unverified), authored Apr 16, 2020 by Sam Shleifer; committed by GitHub on Apr 16, 2020

[examples] summarization/bart/finetune.py supports t5 (#3824)

renames `run_bart_sum.py` to `finetune.py`

Parent: 0cec4fab
Showing 5 changed files with 36 additions and 14 deletions (+36 −14):

- examples/summarization/bart/finetune.py (+8 −8)
- examples/summarization/bart/run_train.sh (+1 −1)
- examples/summarization/bart/run_train_tiny.sh (+1 −1)
- examples/summarization/bart/test_bart_examples.py (+21 −3)
- examples/summarization/t5/README.md (+5 −1)
**examples/summarization/bart/run_bart_sum.py → examples/summarization/bart/finetune.py**

```diff
@@ -19,7 +19,7 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
-class BartSystem(BaseTransformer):
+class SummarizationTrainer(BaseTransformer):
 
     mode = "language-modeling"
@@ -64,18 +64,18 @@ class BartSystem(BaseTransformer):
         return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
 
     def test_step(self, batch, batch_idx):
-        # NOTE: this generation will not use the cache.
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
-        # NOTE: these kwargs get more speed and lower quality summaries than those in evaluate_cnn.py.
+        # NOTE: the following kwargs get more speed and lower quality summaries than those in evaluate_cnn.py
         generated_ids = self.model.generate(
-            source_ids,
-            source_mask,
+            input_ids=source_ids,
+            attention_mask=source_mask,
             num_beams=1,
             max_length=80,
             repetition_penalty=2.5,
             length_penalty=1.0,
             early_stopping=True,
+            use_cache=True,
         )
         preds = [
             self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@@ -161,20 +161,20 @@ def main(args):
     if not args.output_dir:
         args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
         os.makedirs(args.output_dir)
-    model = BartSystem(args)
+    model = SummarizationTrainer(args)
     trainer = generic_train(model, args)
 
     # Optionally, predict on dev set and write to output_dir
     if args.do_predict:
         checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
-        BartSystem.load_from_checkpoint(checkpoints[-1])
+        SummarizationTrainer.load_from_checkpoint(checkpoints[-1])
         trainer.test(model)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     add_generic_args(parser, os.getcwd())
-    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
+    parser = SummarizationTrainer.add_model_specific_args(parser, os.getcwd())
     args = parser.parse_args()
     main(args)
```
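Aside from the `BartSystem` → `SummarizationTrainer` rename, the substantive change is in `test_step`: `generate()` is now called with explicit `input_ids=`/`attention_mask=` keywords plus `use_cache=True`, so the call no longer leans on positional argument order and works for T5 checkpoints as well as BART. Below is a minimal standalone sketch of that calling style, assuming a transformers version of this commit's vintage; the input text and checkpoint choice are illustrative, not taken from the commit:

```python
# Illustrative sketch, not part of this commit: keyword-style generate()
# works the same way for BART and T5 seq2seq checkpoints.
from transformers import AutoModelWithLMHead, AutoTokenizer

# The tiny random checkpoint from run_train_tiny.sh keeps this cheap to run;
# any seq2seq model (e.g. a t5-* checkpoint) could be substituted.
tokenizer = AutoTokenizer.from_pretrained("sshleifer/bart-tiny-random")
model = AutoModelWithLMHead.from_pretrained("sshleifer/bart-tiny-random")

batch = tokenizer.batch_encode_plus(["A long news article to summarize."], return_tensors="pt")

generated_ids = model.generate(
    input_ids=batch["input_ids"],            # named, not positional
    attention_mask=batch["attention_mask"],  # named, so it cannot bind to the wrong parameter
    num_beams=1,
    max_length=80,
    repetition_penalty=2.5,
    length_penalty=1.0,
    early_stopping=True,
    use_cache=True,
)
preds = [
    tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    for g in generated_ids
]
print(preds)
```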
**examples/summarization/bart/run_train.sh**

```diff
@@ -8,7 +8,7 @@ mkdir -p $OUTPUT_DIR
 # Add parent directory to python path to access transformer_base.py
 export PYTHONPATH="../../":"${PYTHONPATH}"
 
-python run_bart_sum.py \
+python finetune.py \
     --data_dir=./cnn-dailymail/cnn_dm \
     --model_type=bart \
     --model_name_or_path=bart-large \
```
**examples/summarization/bart/run_train_tiny.sh**

```diff
@@ -14,7 +14,7 @@ mkdir -p $OUTPUT_DIR
 # Add parent directory to python path to access transformer_base.py and utils.py
 export PYTHONPATH="../../":"${PYTHONPATH}"
 
-python run_bart_sum.py \
+python finetune.py \
     --data_dir=cnn_tiny/ \
     --model_type=bart \
     --model_name_or_path=sshleifer/bart-tiny-random \
```
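Both shell scripts change only in the script name they invoke. `run_train_tiny.sh` is the smoke-test variant: `sshleifer/bart-tiny-random` is a deliberately tiny, randomly initialized checkpoint, so a full train-and-predict cycle finishes quickly on CPU. A hypothetical one-off check, not from the commit, that shows why:

```python
# Hypothetical sanity check: the tiny random checkpoint is orders of
# magnitude smaller than bart-large, which is what makes the smoke test fast.
from transformers import AutoModelWithLMHead

model = AutoModelWithLMHead.from_pretrained("sshleifer/bart-tiny-random")
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,} parameters")  # tiny, versus roughly 400M for bart-large
```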
**examples/summarization/bart/test_bart_examples.py**

```diff
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 from transformers import BartTokenizer
 
 from .evaluate_cnn import run_generate
-from .run_bart_sum import main
+from .finetune import main
 from .utils import SummarizationDataset
@@ -92,9 +92,27 @@ class TestBartExamples(unittest.TestCase):
         args_d.update(
             data_dir=tmp_dir, model_type="bart", train_batch_size=2, eval_batch_size=2, n_gpu=0, output_dir=output_dir,
         )
-        main(argparse.Namespace(**args_d))
-        args_d.update({"do_train": False, "do_predict": True})
-        main(argparse.Namespace(**args_d))
+        args = argparse.Namespace(**args_d)
+        main(args)
+
+    def test_t5_run_sum_cli(self):
+        args_d: dict = DEFAULT_ARGS.copy()
+        tmp_dir = make_test_data_dir()
+        output_dir = tempfile.mkdtemp(prefix="output_")
+        args_d.update(
+            data_dir=tmp_dir,
+            model_type="t5",
+            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            train_batch_size=2,
+            eval_batch_size=2,
+            n_gpu=0,
+            output_dir=output_dir,
+            do_predict=True,
+        )
+        main(argparse.Namespace(**args_d))
+        # args_d.update({"do_train": False, "do_predict": True})
+        # main(argparse.Namespace(**args_d))
 
     def test_bart_summarization_dataset(self):
         tmp_dir = Path(tempfile.gettempdir())
```
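Both tests drive `finetune.main` in-process rather than through a subprocess: they build a plain dict of hyperparameters and wrap it in `argparse.Namespace`, so one entry point serves the CLI and the test suite alike. A self-contained illustration of the pattern; the function body and flags below are invented for the example:

```python
import argparse

def main(args: argparse.Namespace) -> str:
    # The entry point reads attributes, never sys.argv,
    # so tests can construct its input directly.
    return f"finetuning {args.model_type} on {args.data_dir}"

# CLI path: attributes come from parsed command-line flags.
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", default="bart")
parser.add_argument("--data_dir", default="./cnn_dm")
print(main(parser.parse_args([])))         # finetuning bart on ./cnn_dm

# Test path: the same attributes come straight from a dict.
args_d = {"model_type": "t5", "data_dir": "cnn_tiny/"}
print(main(argparse.Namespace(**args_d)))  # finetuning t5 on cnn_tiny/
```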
**examples/summarization/t5/README.md**

````diff
@@ -15,7 +15,7 @@ wc -l cnn_articles_input_data.txt # should print 11490
 wc -l cnn_articles_reference_summaries.txt # should print 11490
 ```
 
-### Usage
+### Generating Summaries
 
 To create summaries for each article in dataset, run:
 ```bash
@@ -23,3 +23,7 @@ python evaluate_cnn.py cnn_articles_input_data.txt cnn_generated_articles_summar
 ```
 The default batch size, 8, fits in 16GB GPU memory, but may need to be adjusted to fit your system.
 The rouge scores "rouge1, rouge2, rougeL" are automatically created and saved in ``rouge_score.txt``.
+
+### Finetuning
+
+Pass model_type=t5 and model `examples/summarization/bart/finetune.py`
````
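In practice that means reusing the BART example script with the two model flags swapped. A hedged sketch of such an invocation, mirroring the flags visible in `run_train.sh` above; the `t5-small` checkpoint and output path are illustrative choices, and `--do_train`/`--output_dir` are assumed to come from the shared generic argument parser:

```python
# Hypothetical T5 invocation of the shared script; the checkpoint name and
# paths are illustrative, not mandated by this commit.
import subprocess

subprocess.run(
    [
        "python", "finetune.py",
        "--data_dir=./cnn-dailymail/cnn_dm",
        "--model_type=t5",
        "--model_name_or_path=t5-small",
        "--output_dir=./t5_cnn_dm",
        "--do_train",
    ],
    check=True,
    cwd="examples/summarization/bart",  # finetune.py lives in the bart example dir
)
```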