Unverified commit 3d76df3a
Authored Mar 24, 2020 by Andre Carrera; committed by GitHub on Mar 24, 2020

BART for summarization training with CNN/DM using pytorch-lightning

Parent: eaabaaf7

Showing 5 changed files with 252 additions and 2 deletions (+252 -2)
examples/summarization/bart/README.md        +13  -0
examples/summarization/bart/run_bart_sum.py  +172 -0
examples/summarization/bart/run_train.sh     +23  -0
examples/summarization/bart/utils.py         +43  -0
examples/transformer_base.py                 +1   -2
examples/summarization/bart/README.md

...
@@ -14,6 +14,19 @@ python evaluate_cnn.py <path_to_test.source> cnn_test_summaries.txt
```
The default batch size, 8, fits in 16 GB of GPU memory but may need to be adjusted to fit your system.

### Training
After downloading the CNN and Daily Mail datasets, preprocess the dataset:
```commandline
git clone https://github.com/artmatsak/cnn-dailymail
cd cnn-dailymail && python make_datafiles.py ../cnn/stories/ ../dailymail/stories/
```
Run the training script: `run_train.sh`

### Where is the code?
The core model is in `src/transformers/modeling_bart.py`. This directory only contains examples.
...
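The preprocessing step above is expected to leave line-aligned `train`/`val`/`test` `.source` and `.target` files in a `cnn_dm` directory, which is the layout the new `SummarizationDataset` in this commit reads. A minimal sanity-check sketch, assuming the default `./cnn-dailymail/cnn_dm` path used elsewhere in this commit (not part of the commit itself):

```python
# Quick sanity check of the preprocessed CNN/DM layout (sketch only).
# Assumes make_datafiles.py leaves line-aligned *.source/*.target files
# under ./cnn-dailymail/cnn_dm, the default data_dir used in this commit.
import os

data_dir = "./cnn-dailymail/cnn_dm"
with open(os.path.join(data_dir, "train.source")) as src, open(os.path.join(data_dir, "train.target")) as tgt:
    article = src.readline().strip()  # one full story per line
    summary = tgt.readline().strip()  # the matching reference summary on the same line
print(article[:200], "...")
print("SUMMARY:", summary)
```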
examples/summarization/bart/run_bart_sum.py
0 → 100644
import argparse
import glob
import logging
import os
import time

import torch
from torch.utils.data import DataLoader

from transformer_base import BaseTransformer, add_generic_args, generic_train, get_linear_schedule_with_warmup
from utils import SummarizationDataset


logger = logging.getLogger(__name__)


class BartSystem(BaseTransformer):

    mode = "language-modeling"

    def __init__(self, hparams):
        super(BartSystem, self).__init__(hparams, num_labels=None, mode=self.mode)

    def forward(
        self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None
    ):
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            lm_labels=lm_labels,
        )

    def _step(self, batch):
        # Teacher forcing: feed the decoder everything but the last target token,
        # predict everything but the first, and mask pad positions out of the loss.
        y = batch["target_ids"]
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone()
        lm_labels[y[:, 1:] == self.tokenizer.pad_token_id] = -100
        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_input_ids=y_ids,
            lm_labels=lm_labels,
        )

        loss = outputs[0]

        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)

        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch)
        return {"val_loss": loss}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}

    def test_step(self, batch, batch_idx):
        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        target = [
            self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for t in batch["target_ids"]
        ]
        loss = self._step(batch)

        return {"val_loss": loss, "preds": preds, "target": target}

    def test_end(self, outputs):
        return self.validation_end(outputs)

    def test_epoch_end(self, outputs):
        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
        # write predictions and targets for later rouge evaluation.
        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
            for output_batch in outputs:
                p_writer.writelines(s + "\n" for s in output_batch["preds"])
                t_writer.writelines(s + "\n" for s in output_batch["target"])
            p_writer.close()
            t_writer.close()

        return self.test_end(outputs)

    def train_dataloader(self):
        train_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="train", block_size=self.hparams.max_seq_length
        )
        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size)
        # Total number of optimizer steps, used to size the linear warmup schedule.
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler
        return dataloader

    def val_dataloader(self):
        val_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="val", block_size=self.hparams.max_seq_length
        )
        return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size)

    def test_dataloader(self):
        test_dataset = SummarizationDataset(
            self.tokenizer, data_dir=self.hparams.data_dir, type_path="test", block_size=self.hparams.max_seq_length
        )
        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        BaseTransformer.add_model_specific_args(parser, root_dir)
        # Add BART specific options
        parser.add_argument(
            "--max_seq_length",
            default=1024,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )

        parser.add_argument(
            "--data_dir",
            default=None,
            type=str,
            required=True,
            help="The input data dir. Should contain the dataset files for the CNN/DM summarization task.",
        )
        return parser


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = BartSystem.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    # If output_dir not provided, a folder will be generated in pwd
    if args.output_dir is None:
        args.output_dir = os.path.join("./results", f"{args.task}_{args.model_type}_{time.strftime('%Y%m%d_%H%M%S')}")
        os.makedirs(args.output_dir)

    model = BartSystem(args)
    trainer = generic_train(model, args)

    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = BartSystem.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)
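`train_dataloader` above derives `t_total`, the optimizer-step count handed to `get_linear_schedule_with_warmup`, from the dataset size, batch size, GPU count, gradient accumulation, and epoch count. A worked sketch of that arithmetic with assumed values (the CNN/DM training-set size and the epoch/accumulation defaults are not shown in this diff):

```python
# Sketch of the t_total computation in BartSystem.train_dataloader.
# Assumed values: ~287k CNN/DM training pairs, 1 GPU, no gradient
# accumulation, 3 epochs -- none of these defaults appear in this diff.
dataset_len = 287_113
train_batch_size = 4              # matches run_train.sh
n_gpu = 1
gradient_accumulation_steps = 1
num_train_epochs = 3

t_total = (
    (dataset_len // (train_batch_size * max(1, n_gpu)))
    // gradient_accumulation_steps
    * float(num_train_epochs)
)
print(t_total)  # 215334.0 steps passed as num_training_steps to the warmup scheduler
```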
examples/summarization/bart/run_train.sh
0 → 100755
# Install newest ptl.
pip install -U git+http://github.com/PyTorchLightning/pytorch-lightning/

export OUTPUT_DIR_NAME=bart_sum
export CURRENT_DIR=${PWD}
export OUTPUT_DIR=${CURRENT_DIR}/${OUTPUT_DIR_NAME}

# Make output directory if it doesn't exist
mkdir -p $OUTPUT_DIR

# Add parent directory to python path to access transformer_base.py
export PYTHONPATH="../../":"${PYTHONPATH}"

python run_bart_sum.py \
--data_dir=./cnn-dailymail/cnn_dm \
--model_type=bart \
--model_name_or_path=bart-large \
--learning_rate=3e-5 \
--train_batch_size=4 \
--eval_batch_size=4 \
--output_dir=$OUTPUT_DIR \
--do_train
\ No newline at end of file
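`run_train.sh` only passes `--do_train`; adding `--do_predict` makes `run_bart_sum.py` run `trainer.test`, and `test_epoch_end` then writes `test_predictions.txt` and `test_targets.txt` into `$OUTPUT_DIR` (here `bart_sum`) "for later rouge evaluation". A hedged sketch of that later step, assuming the third-party `rouge-score` package, which is not used anywhere in this commit:

```python
# Score the files written by test_epoch_end with ROUGE (sketch only).
# Assumes `pip install rouge-score`; the package is not part of this commit.
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)

with open("bart_sum/test_predictions.txt") as preds, open("bart_sum/test_targets.txt") as targets:
    scores = [scorer.score(ref.strip(), pred.strip()) for ref, pred in zip(targets, preds)]

for key in ("rouge1", "rouge2", "rougeL"):
    avg_f = sum(s[key].fmeasure for s in scores) / len(scores)
    print(f"{key} F1: {avg_f:.4f}")
```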
examples/summarization/bart/utils.py
0 → 100644
import os

from torch.utils.data import Dataset


class SummarizationDataset(Dataset):
    def __init__(self, tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="train", block_size=1024):
        super().__init__()
        self.tokenizer = tokenizer

        self.source = []
        self.target = []

        print("loading " + type_path + " source.")

        with open(os.path.join(data_dir, type_path + ".source"), "r") as f:
            for text in f.readlines():  # each text is a line and a full story
                tokenized = tokenizer.batch_encode_plus(
                    [text], max_length=block_size, pad_to_max_length=True, return_tensors="pt"
                )
                self.source.append(tokenized)
            f.close()

        print("loading " + type_path + " target.")

        with open(os.path.join(data_dir, type_path + ".target"), "r") as f:
            for text in f.readlines():  # each text is a line and a summary
                tokenized = tokenizer.batch_encode_plus(
                    [text], max_length=56, pad_to_max_length=True, return_tensors="pt"
                )
                self.target.append(tokenized)
            f.close()

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        source_ids = self.source[index]["input_ids"].squeeze()
        target_ids = self.target[index]["input_ids"].squeeze()
        src_mask = self.source[index]["attention_mask"].squeeze()  # might need to squeeze
        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids}
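For reference, the dataset above is consumed by the `*_dataloader` methods in `run_bart_sum.py`. A minimal standalone sketch of the same usage, assuming the `bart-large` tokenizer name from `run_train.sh` and the default data directory (not part of this commit):

```python
# Minimal usage sketch mirroring BartSystem.val_dataloader (assumes the
# default ./cnn-dailymail/cnn_dm layout and the bart-large tokenizer).
from torch.utils.data import DataLoader
from transformers import BartTokenizer

from utils import SummarizationDataset

tokenizer = BartTokenizer.from_pretrained("bart-large")
dataset = SummarizationDataset(tokenizer, data_dir="./cnn-dailymail/cnn_dm/", type_path="val", block_size=1024)
loader = DataLoader(dataset, batch_size=4)

batch = next(iter(loader))
print(batch["source_ids"].shape)  # (4, 1024) -- articles padded/truncated to block_size
print(batch["target_ids"].shape)  # (4, 56)   -- summaries padded/truncated to 56 tokens
print(batch["source_mask"].shape)
```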
examples/transformer_base.py
...
...
@@ -53,10 +53,9 @@ class BaseTransformer(pl.LightningModule):
        super(BaseTransformer, self).__init__()
        self.hparams = hparams
        self.hparams.model_type = self.hparams.model_type.lower()
        config = AutoConfig.from_pretrained(
            self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
-           num_labels=num_labels,
+           **({"num_labels": num_labels} if num_labels is not None else {}),
            cache_dir=self.hparams.cache_dir if self.hparams.cache_dir else None,
        )
        tokenizer = AutoTokenizer.from_pretrained(
...
...
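The change to `transformer_base.py` stops passing `num_labels` to `AutoConfig.from_pretrained` unconditionally, since `BartSystem` initializes the base class with `num_labels=None`; the dict unpacking only forwards the argument when it is actually set. A small self-contained sketch of the same idiom (`make_config` and `build` are hypothetical names used only for illustration):

```python
# Illustration of the conditional-kwarg idiom used in the transformer_base.py
# change above; make_config stands in for AutoConfig.from_pretrained.
def make_config(name, num_labels=2, cache_dir=None):
    return {"name": name, "num_labels": num_labels, "cache_dir": cache_dir}

def build(name, num_labels=None):
    # Forward num_labels only when the caller set it, so the callee keeps its
    # own default instead of receiving an explicit None.
    return make_config(name, **({"num_labels": num_labels} if num_labels is not None else {}))

print(build("bart-large"))                     # num_labels stays at the callee's default (2)
print(build("bert-base-cased", num_labels=3))  # num_labels forwarded as 3
```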