chenpangpang / transformers · Commit 3d76df3a

Unverified commit 3d76df3a, authored Mar 24, 2020 by Andre Carrera, committed by GitHub on Mar 24, 2020

BART for summarization training with CNN/DM using pytorch-lightning
Parent: eaabaaf7

Showing 5 changed files with 252 additions and 2 deletions (+252 −2):

- examples/summarization/bart/README.md (+13 −0)
- examples/summarization/bart/run_bart_sum.py (+172 −0)
- examples/summarization/bart/run_train.sh (+23 −0)
- examples/summarization/bart/utils.py (+43 −0)
- examples/transformer_base.py (+1 −2)
examples/summarization/bart/README.md  (view file @ 3d76df3a)

@@ -14,6 +14,19 @@ python evaluate_cnn.py <path_to_test.source> cnn_test_summaries.txt
```
The default batch size of 8 fits in 16 GB of GPU memory but may need to be adjusted for your system.

### Training
After downloading the CNN and Daily Mail datasets, preprocess the dataset:
```commandline
git clone https://github.com/artmatsak/cnn-dailymail
cd cnn-dailymail && python make_datafiles.py ../cnn/stories/ ../dailymail/stories/
```
Run the training script: `run_train.sh`
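The training run assumes the preprocessed data lives under `./cnn-dailymail/cnn_dm/` as `train.source`/`train.target`, `val.source`/`val.target`, and `test.source`/`test.target`, with one article or summary per line; this is the layout read by `SummarizationDataset` in `utils.py` below and the `--data_dir` that `run_train.sh` passes.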
### Where is the code?
The core model is in `src/transformers/modeling_bart.py`. This directory only contains examples.
examples/summarization/bart/run_bart_sum.py  (new file, 0 → 100644, view file @ 3d76df3a)

```python
import argparse
import glob
import logging
import os
import time

import torch
from torch.utils.data import DataLoader

from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
from utils import SummarizationDataset


logger = logging.getLogger(__name__)


class BartSystem(BaseTransformer):

    mode = "language-modeling"

    def __init__(self, hparams):
        super(BartSystem, self).__init__(hparams, num_labels=None, mode=self.mode)

    def forward(
        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
    ):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            lm_labels=lm_labels,
        )

    def _step(self, batch):
        y = batch["target_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_input_ids=y_ids,
            lm_labels=lm_labels,
        )

        loss = outputs[0]

        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)

        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch)
        return {"val_loss": loss}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        target = [
            self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for t in batch["target_ids"]
        ]
        loss = self._step(batch)

        return {"val_loss": loss, "preds": preds, "target": target}

    def test_end(self, outputs):
        return self.validation_end(outputs)

    def test_epoch_end(self, outputs):
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # write predictions and targets for later rouge evaluation.
        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])
            p_writer.close()
            t_writer.close()

        return self.test_end(outputs)

    def train_dataloader(self):
        train_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length
        )
        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        val_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length
        )
        return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self):
        test_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length
        )
        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        # Add BART specific options
        parser.add_argument(
            "--max_seq_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )

        parser.add_argument(
            "--data_dir",
            default=None,
            type=str,
            required=True,
            help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
        )
        return parser


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    # If output_dir not provided, a folder will be generated in pwd
    if args.output_dir is None:
        args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}",)
        os.makedirs(args.output_dir)
    model = BartSystem(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        BartSystem.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)
```
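A note on the shifting in `_step`, for readers new to seq2seq training: the target ids do double duty as decoder inputs (all tokens but the last) and as labels (all tokens but the first), with padded label positions set to -100 so the cross-entropy loss ignores them. A minimal sketch with made-up token ids (assuming BART's pad token id of 1):

```python
import torch

pad_token_id = 1  # BART's pad token id; the other ids below are made up for illustration
y = torch.tensor([[0, 464, 2159, 2],
                  [0, 31414, 2, 1]])          # two toy "summaries", the second one padded

decoder_input_ids = y[:, :-1].contiguous()    # what the decoder reads (sequence without last token)
lm_labels = y[:, 1:].clone()                  # what it must predict (sequence without first token)
lm_labels[y[:, 1:] == pad_token_id] = -100    # padding positions are ignored by the loss

print(decoder_input_ids)  # tensor([[    0,   464,  2159], [    0, 31414,     2]])
print(lm_labels)          # tensor([[  464,  2159,     2], [31414,     2,  -100]])
```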
examples/summarization/bart/run_train.sh  (new file, 0 → 100755, view file @ 3d76df3a)

```bash
# Install newest ptl.
pip install -U git+http://github.com/PyTorchLightning/pytorch-lightning/

export OUTPUT_DIR_NAME=bart_sum
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}

# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR

# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"

python run_bart_sum.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_type=bart \
--model_name_or_path=bart-large \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--output_dir=$OUTPUT_DIR \
--do_train
```
(no newline at end of file)
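Note that `run_train.sh` is meant to be run from `examples/summarization/bart/`: the relative `PYTHONPATH` entry `"../../"` then resolves to the `examples/` directory that contains `transformer_base.py`, and `./cnn-dailymail/cnn_dm` points at the data preprocessed in the step above.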
examples/summarization/bart/utils.py  (new file, 0 → 100644, view file @ 3d76df3a)

```python
import os

from torch.utils.data import Dataset


class SummarizationDataset(Dataset):
    def __init__(self, tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="train", block_size=1024):
        super(SummarizationDataset,).__init__()
        self.tokenizer = tokenizer

        self.source = []
        self.target = []

        print("loading " + type_path + " source.")

        with open(os.path.join(data_dir, type_path + ".source"), "r") as f:
            for text in f.readlines():
                # each text is a line and a full story
                tokenized = tokenizer.batch_encode_plus(
                    [text], max_length=block_size, pad_to_max_length=True, return_tensors="pt"
                )
                self.source.append(tokenized)
            f.close()

        print("loading " + type_path + " target.")

        with open(os.path.join(data_dir, type_path + ".target"), "r") as f:
            for text in f.readlines():
                # each text is a line and a summary
                tokenized = tokenizer.batch_encode_plus(
                    [text], max_length=56, pad_to_max_length=True, return_tensors="pt"
                )
                self.target.append(tokenized)
            f.close()

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()
        src_mask = self.source[index]["attention_mask"].squeeze()  # might need to squeeze
        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
```
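As a quick way to see what the dataloaders feed into `BartSystem`, the dataset can be exercised on its own. A minimal sketch, assuming the preprocessed `cnn_dm` files exist locally and a transformers version contemporary with this commit (where `bart-large` is a valid shortcut name):

```python
from torch.utils.data import DataLoader
from transformers import BartTokenizer

from utils import SummarizationDataset

tokenizer = BartTokenizer.from_pretrained("bart-large")
dataset = SummarizationDataset(tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="val", block_size=1024)
loader = DataLoader(dataset, batch_size=2)

batch = next(iter(loader))
print(batch["source_ids"].shape)   # torch.Size([2, 1024]) -- articles padded/truncated to block_size
print(batch["target_ids"].shape)   # torch.Size([2, 56])   -- summaries padded/truncated to 56 tokens
print(batch["source_mask"].shape)  # torch.Size([2, 1024])
```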
examples/transformer_base.py  (view file @ 3d76df3a)

```diff
@@ -53,10 +53,9 @@ class BaseTransformer(pl.LightningModule):
         super(BaseTransformer, self).__init__()
         self.hparams = hparams
         self.hparams.model_type = self.hparams.model_type.lower()
         config = AutoConfig.from_pretrained(
             self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
-            num_labels=num_labels,
+            **({"num_labels": num_labels} if num_labels is not None else {}),
             cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
         )
         tokenizer = AutoTokenizer.from_pretrained(
...
```
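This one-line change matters for the new example: `BartSystem` constructs its `BaseTransformer` with `num_labels=None`, and the conditional splat keeps `num_labels` out of the `AutoConfig.from_pretrained` call entirely rather than overriding the pretrained config with `None`. A standalone sketch of the idiom, using a hypothetical `configure` helper in place of the real call:

```python
def configure(**kwargs):
    # Stand-in for a call like AutoConfig.from_pretrained: just echoes the kwargs it received.
    return kwargs

num_labels = None
print(configure(task="summarization", **({"num_labels": num_labels} if num_labels is not None else {})))
# {'task': 'summarization'}  -- the key is simply omitted when num_labels is None

num_labels = 3
print(configure(task="summarization", **({"num_labels": num_labels} if num_labels is not None else {})))
# {'task': 'summarization', 'num_labels': 3}
```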