chenpangpang / transformers · Commits

Commit 4c3ac4a7, authored Oct 18, 2019 by Rémi Louf
Parent: 932543f7

    here's one big commit

Showing 7 changed files with 951 additions and 47 deletions (+951 −47)
examples/README.md                                  +3    −2
examples/run_summarization_finetuning.py            +620  −0
examples/run_summarization_finetuning_test.py       +14   −14
transformers/__init__.py                            +1    −1
transformers/modeling_beam_search.py                +240  −0
transformers/modeling_bert.py                       +12   −8
transformers/modeling_seq2seq.py                    +61   −22
examples/README.md (+3 −2) · View file @ 4c3ac4a7

@@ -393,7 +393,8 @@ This fine-tuned model is available as a checkpoint under the reference

 ## Seq2seq model fine-tuning
 
-Based on the script [`run_seq2seq_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_seq2seq_finetuning.py).
+Based on the script [`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
 
 Before running this script you should download **both** CNN and Daily Mail
 datasets from [Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the

@@ -412,7 +413,7 @@ archive.

 ```bash
 export DATA_PATH=/path/to/dataset/
 
-python run_seq2seq_finetuning.py \
+python run_summarization_finetuning.py \
     --output_dir=output \
     --model_type=bert2bert \
     --model_name_or_path=bert2bert \
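For orientation, here is a minimal sketch (not part of the commit) of the directory layout the script expects under `$DATA_PATH`, inferred from the `TextDataset` class in the diff below; the path is illustrative:

```python
# Expected layout, inferred from TextDataset: one "stories" folder per corpus,
# i.e. $DATA_PATH/cnn/stories/*.story and $DATA_PATH/dailymail/stories/*.story
import os

data_dir = "/path/to/dataset/"  # same value as $DATA_PATH above (illustrative)
for dataset in ["cnn", "dailymail"]:
    path_to_stories = os.path.join(data_dir, dataset, "stories")
    print(path_to_stories, "exists:", os.path.isdir(path_to_stories))
```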
examples/run_seq2seq_finetuning.py → examples/run_summarization_finetuning.py (+620 −0) · View file @ 4c3ac4a7

 # coding=utf-8
-# Copyright 2018 The Microsoft Reseach team and The HuggingFace Inc. team.
-# Copyright (c) 2018 Microsoft and The HuggingFace Inc. All rights reserved.
+# Copyright 2019 The HuggingFace Inc. team.
+# Copyright (c) 2019 The HuggingFace Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,18 +18,21 @@

 import argparse
 from collections import deque
 import logging
+import os
 import pickle
 import random
-import os
+import sys
 
 import numpy as np
 from tqdm import tqdm, trange
 
 import torch
-from torch.utils.data import Dataset, RandomSampler
+from torch.optim import Adam
+from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
 
-from transformers import AutoTokenizer, Model2Model
+from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model
 
 logger = logging.getLogger(__name__)
+logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 
 
 def set_seed(args):
@@ -61,7 +64,7 @@ class TextDataset(Dataset):

     def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
         assert os.path.isdir(data_dir)
 
-        # Load features that have already been computed if present
+        # Load the features that have already been computed, if any
         cached_features_file = os.path.join(
             data_dir, "cached_lm_{}_{}".format(block_size, prefix)
         )
@@ -72,12 +75,11 @@ class TextDataset(Dataset):

                 return
 
         logger.info("Creating features from dataset at %s", data_dir)
-        self.examples = []
         datasets = ["cnn", "dailymail"]
+        self.examples = {"source": [], "target": []}
         for dataset in datasets:
             path_to_stories = os.path.join(data_dir, dataset, "stories")
-            assert os.path.isdir(path_to_stories)
             story_filenames_list = os.listdir(path_to_stories)
             for story_filename in story_filenames_list:
                 path_to_story = os.path.join(path_to_stories, story_filename)
@@ -85,19 +87,19 @@ class TextDataset(Dataset):

                     continue
 
                 with open(path_to_story, encoding="utf-8") as source:
+                    try:
                         raw_story = source.read()
-                        story, summary = process_story(raw_story)
-
-                if len(summary_lines) == 0 or len(story_lines) == 0:
-                    continue
-
-                story = tokenizer.encode(story)
-                story_seq = _fit_to_block_size(story, block_size)
-
-                summary = tokenizer.encode(summary)
-                summary_seq = _fit_to_block_size(summary, block_size)
-
-                self.examples.append((story_seq, summary_seq))
+                        story_lines, summary_lines = process_story(raw_story)
+                    except IndexError:
+                        # skip ill-formed stories
+                        continue
+
+                story_token_ids, summary_token_ids = _encode_for_summarization(
+                    story_lines, summary_lines, tokenizer
+                )
+                story_seq = _fit_to_block_size(story_token_ids, block_size)
+                summary_seq = _fit_to_block_size(summary_token_ids, block_size)
+                self.examples["source"].append(story_seq)
+                self.examples["summary"].append(summary_seq)
 
         logger.info("Saving features into cache file %s", cached_features_file)
         with open(cached_features_file, "wb") as sink:
@@ -107,7 +109,10 @@ class TextDataset(Dataset):

         return len(self.examples)
 
     def __getitem__(self, items):
-        return torch.tensor(self.examples[items])
+        return (
+            torch.tensor(self.examples["source"][items]),
+            torch.tensor(self.examples["target"][items]),
+        )
 
 
 def process_story(raw_story):
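Since `__getitem__` now returns a `(source, target)` pair of tensors, the default DataLoader collation stacks each field into its own batch tensor, which is how `train()` below can unpack `source, target = batch`. A minimal sketch with made-up toy tensors, not the real CNN/Daily Mail features:

```python
import torch
from torch.utils.data import DataLoader

# Toy stand-in for TextDataset: two (source, target) pairs of token ids,
# already padded to the same length as _fit_to_block_size would do.
pairs = [
    (torch.tensor([101, 7592, 102, 0]), torch.tensor([101, 2023, 102, 0])),
    (torch.tensor([101, 2088, 102, 0]), torch.tensor([101, 2003, 102, 0])),
]
loader = DataLoader(pairs, batch_size=2)

source, target = next(iter(loader))   # same unpacking as in train() below
print(source.shape, target.shape)     # torch.Size([2, 4]) torch.Size([2, 4])
```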
@@ -119,33 +124,55 @@ def process_story(raw_story):

     Raises:
         IndexError: If the stoy is empty or contains no highlights.
     """
-    file_lines = list(
+    nonempty_lines = list(
         filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
     )
 
     # for some unknown reason some lines miss a period, add it
-    file_lines = [_add_missing_period(line) for line in file_lines]
+    nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
 
     # gather article lines
     story_lines = []
-    lines = deque(file_lines)
+    lines = deque(nonempty_lines)
     while True:
         try:
             element = lines.popleft()
             if element.startswith("@highlight"):
                 break
             story_lines.append(element)
-        except IndexError as ie:  # if "@highlight" absent from file
-            raise ie
+        except IndexError:
+            # if "@highlight" is absent from the file we pop
+            # all elements until there is None.
+            return story_lines, []
 
     # gather summary lines
-    highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
+    summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
 
-    # join the lines
-    story = " ".join(story_lines)
-    summary = " ".join(highlights_lines)
-
-    return story, summary
+    return story_lines, summary_lines
+
+
+def _encode_for_summarization(story_lines, summary_lines, tokenizer):
+    """ Encode the story and summary lines, and join them
+    as specified in [1] by using `[SEP] [CLS]` tokens to separate
+    sentences.
+    """
+    story_lines_token_ids = [
+        tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
+        for line in story_lines
+    ]
+    summary_lines_token_ids = [
+        tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
+        for line in summary_lines
+    ]
+    story_token_ids = [
+        token for sentence in story_lines_token_ids for token in sentence
+    ]
+    summary_token_ids = [
+        token for sentence in summary_lines_token_ids for token in sentence
+    ]
+
+    return story_token_ids, summary_token_ids
 
 
 def _add_missing_period(line):
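To make the new split concrete, here is a small worked example of what `process_story` now returns, assuming the functions from the script above are importable (the exact period-fixing behaviour of `_add_missing_period` is inferred from the updated tests further down):

```python
raw_story = (
    "It was the best of times, it was the worst of times.\n\n"
    "It was the age of wisdom.\n\n"
    "@highlight\n\n"
    "A tale of two cities"
)

story_lines, summary_lines = process_story(raw_story)
print(story_lines)
# ['It was the best of times, it was the worst of times.', 'It was the age of wisdom.']
print(summary_lines)
# ['A tale of two cities.']   (missing period added by _add_missing_period)

# _encode_for_summarization then encodes each line separately, wraps it in the
# tokenizer's special tokens, and flattens the result, which is what makes the
# sentence-level segment ids computed below possible.
```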
@@ -170,8 +197,11 @@ def _fit_to_block_size(sequence, block_size):

 def mask_padding_tokens(sequence):
-    """ Replace the padding token with -1 values """
-    return [s if s != 0 else -1 for s in sequence]
+    """ Padding token, encoded as 0, are represented by the value -1 in the
+    masks """
+    padded = sequence.clone()
+    padded[padded == 0] = -1
+    return padded
 
 
 def load_and_cache_examples(args, tokenizer):
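The new implementation works on tensors rather than Python lists, so it can be applied to a whole batch at once. A quick illustration (not from the commit):

```python
import torch

batch = torch.tensor([[101, 7592,  102,    0,    0],
                      [101, 2023, 2003,  102,    0]])

masked = mask_padding_tokens(batch)
print(masked)
# tensor([[ 101, 7592,  102,   -1,   -1],
#         [ 101, 2023, 2003,  102,   -1]])
# The -1 entries are the ones ignored by CrossEntropyLoss(ignore_index=-1).
```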
@@ -179,81 +209,181 @@ def load_and_cache_examples(args, tokenizer):

     return dataset
 
 
+def compute_token_type_ids(batch, separator_token_id):
+    """ Segment embeddings as described in [1]
+
+    The values {0,1} were found in the repository [2].
+
+    Attributes:
+        batch: torch.Tensor, size [batch_size, block_size]
+            Batch of input.
+        separator_token_id: int
+            The value of the token that separates the segments.
+
+    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
+        arXiv preprint arXiv:1908.08345 (2019).
+    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
+    """
+    batch_embeddings = []
+    sentence_num = 0
+    for sequence in batch:
+        embeddings = []
+        for s in sequence:
+            if s == separator_token_id:
+                sentence_num += 1
+            embeddings.append(sentence_num % 2)
+        batch_embeddings.append(embeddings)
+    return torch.tensor(batch_embeddings)
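As used in `train()` below, the separator passed in is the tokenizer's `[CLS]` id, so the segment id flips every time a new sentence starts. A small illustration with made-up token ids (101 stands in for `[CLS]`):

```python
import torch

# Two sentences in one sequence: [CLS] w w [SEP] [CLS] w [SEP] pad
batch = torch.tensor([[101, 7592, 2088, 102, 101, 2023, 102, 0]])

token_type_ids = compute_token_type_ids(batch, separator_token_id=101)
print(token_type_ids)
# tensor([[1, 1, 1, 1, 0, 0, 0, 0]])  -- alternating 0/1 blocks, one per sentence
```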
+# ----------
+# Optimizers
+# ----------
+
+
+class BertSumOptimizer(object):
+    """ Specific optimizer for BertSum.
+
+    As described in [1], the authors fine-tune BertSum for abstractive
+    summarization using two Adam Optimizers with different warm-up steps and
+    learning rate. They also use a custom learning rate scheduler.
+
+    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
+        arXiv preprint arXiv:1908.08345 (2019).
+    """
+
+    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9):
+        self.encoder = model.encoder
+        self.decoder = model.decoder
+        self.lr = lr
+        self.warmup_steps = warmup_steps
+
+        self.optimizers = {
+            "encoder": Adam(
+                model.encoder.parameters(),
+                lr=lr["encoder"],
+                betas=(beta_1, beta_2),
+                eps=eps,
+            ),
+            "decoder": Adam(
+                model.decoder.parameters(),
+                lr=lr["decoder"],
+                betas=(beta_1, beta_2),
+                eps=eps,
+            ),
+        }
+
+        self._step = 0
+
+    def _update_rate(self, stack):
+        return self.lr[stack] * min(
+            self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-0.5)
+        )
+
+    def zero_grad(self):
+        self.optimizer_decoder.zero_grad()
+        self.optimizer_encoder.zero_grad()
+
+    def step(self):
+        self._step += 1
+        for stack, optimizer in self.optimizers.items():
+            new_rate = self._update_rate(stack)
+            for param_group in optimizer.param_groups:
+                param_group["lr"] = new_rate
+            optimizer.step()
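The rule in `_update_rate` is a Noam-style schedule: the rate ramps up while `step * warmup ** (-0.5)` is the smaller term and decays as `step ** (-0.5)` afterwards (note that as written here the warm-up term uses exponent −0.5, whereas the BertSum reference schedule uses −1.5). A quick numeric check, outside the commit:

```python
def update_rate(step, lr, warmup_steps):
    # mirrors BertSumOptimizer._update_rate for a single stack
    return lr * min(step ** (-0.5), step * warmup_steps ** (-0.5))

for step in [1, 10, 100, 1000, 20000]:
    print(step, round(update_rate(step, lr=0.002, warmup_steps=20000), 6))
# 1      ~1.4e-05
# 10     ~0.000141
# 100    0.0002
# 1000   ~6.3e-05
# 20000  ~1.4e-05   (rises, peaks early, then decays as step**-0.5)
```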
 # ------------
 # Train
 # ------------
 
 
-def train(args, train_dataset, model, tokenizer):
+def train(args, model, tokenizer):
     """ Fine-tune the pretrained model on the corpus. """
+    set_seed(args)
 
-    # Prepare the data loading
-    args.train_bach_size = 1
+    # Load the data
+    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
+    train_dataset = load_and_cache_examples(args, tokenizer)
     train_sampler = RandomSampler(train_dataset)
     train_dataloader = DataLoader(
-        train_dataset, sampler=train_sampler, batch_size=args.train_bach_size
+        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
     )
 
-    # Prepare the optimizer and schedule (linear warmup and decay)
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [
-                p
-                for n, p in model.named_parameters()
-                if not any(nd in n for nd in no_decay)
-            ],
-            "weight_decay": args.weight_decay,
-        },
-        {
-            "params": [
-                p
-                for n, p in model.named_parameters()
-                if any(nd in n for nd in no_decay)
-            ],
-            "weight_decay": 0.0,
-        },
-    ]
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+    # Training schedule
+    if args.max_steps > 0:
+        t_total = args.max_steps
+        args.num_train_epochs = t_total // (
+            len(train_dataloader) // args.gradient_accumulation_steps + 1
+        )
+    else:
+        t_total = (
+            len(train_dataloader)
+            // args.gradient_accumulation_steps
+            * args.num_train_epochs
+        )
+
+    # Prepare the optimizer
+    lr = {"encoder": 0.002, "decoder": 0.2}
+    warmup_steps = {"encoder": 20000, "decoder": 10000}
+    optimizer = BertSumOptimizer(model, lr, warmup_steps)
 
     # Train
     logger.info("***** Running training *****")
     logger.info(" Num examples = %d", len(train_dataset))
     logger.info(" Num Epochs = %d", args.num_train_epochs)
-    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
-    logger.info(
-        " Total train batch size (w. parallel, distributed & accumulation) = %d",
-        args.train_batch_size
-        * args.gradient_accumulation_steps
-        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
-    )
+    logger.info(
+        " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
+    )
+    logger.info(
+        " Total train batch size (w. parallel, distributed & accumulation) = %d",
+        args.train_batch_size * args.gradient_accumulation_steps
+        # * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
+    )
     logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
     logger.info(" Total optimization steps = %d", t_total)
 
-    global_step = 0
-    tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
     train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
-    set_seed(args)
+
+    global_step = 0
+    tr_loss = 0.0
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
         for step, batch in enumerate(epoch_iterator):
-            source = ([s for s, _ in batch]).to(args.device)
-            target = ([t for _, t in batch]).to(args.device)
+            source, target = batch
+            token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
+            labels_src = mask_padding_tokens(source)
+            labels_tgt = mask_padding_tokens(target)
+
+            source = source.to(args.device)
+            target = target.to(args.device)
+            token_type_ids = token_type_ids.to(args.device)
+            labels_src = labels_src.to(args.device)
+            labels_tgt = labels_tgt.to(args.device)
 
             model.train()
-            outputs = model(source, target, decoder_lm_labels=mask_padding_tokens(target))
+            outputs = model(
+                source,
+                target,
+                token_type_ids=token_type_ids,
+                decoder_encoder_attention_mask=labels_src,
+                decoder_attention_mask=labels_tgt,
+                decoder_lm_labels=labels_tgt,
+                decoder_initialize_randomly=True,
+            )
 
             loss = outputs[0]
+            print(loss)
+            if args.gradient_accumulation_steps > 1:
+                loss /= args.gradient_accumulation_steps
+
             loss.backward()
             tr_loss += loss.item()
 
             if (step + 1) % args.gradient_accumulation_steps == 0:
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                 optimizer.step()
-                scheduler.step()
                 model.zero_grad()
                 global_step += 1
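The "Training schedule" block above fixes how many optimizer updates will be taken; a tiny numeric sketch with invented values:

```python
# Suppose the cached dataset yields 1000 batches per epoch.
len_dataloader = 1000
gradient_accumulation_steps = 2
num_train_epochs = 3

# One optimizer update happens every `gradient_accumulation_steps` batches, so
# the value reported as "Total optimization steps" is:
t_total = len_dataloader // gradient_accumulation_steps * num_train_epochs
print(t_total)  # 1500
```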
@@ -268,6 +398,68 @@ def train(args, train_dataset, model, tokenizer):

     return global_step, tr_loss / global_step
 
 
+# ------------
+# Train
+# ------------
+
+
+def evaluate(args, model, tokenizer, prefix=""):
+    set_seed(args)
+
+    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
+    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
+    eval_sampler = SequentialSampler(eval_dataset)
+    eval_dataloader = DataLoader(
+        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size
+    )
+
+    logger.info("***** Running evaluation {} *****".format(prefix))
+    logger.info(" Num examples = %d", len(eval_dataset))
+    logger.info(" Batch size = %d", args.eval_batch_size)
+    eval_loss = 0.0
+    nb_eval_steps = 0
+    model.eval()
+
+    for batch in tqdm(eval_dataloader, desc="Evaluating"):
+        source, target = batch
+        labels_src = mask_padding_tokens(source)
+        labels_tgt = mask_padding_tokens(target)
+
+        source.to(args.device)
+        target.to(args.device)
+        labels_src.to(args.device)
+        labels_tgt.to(args.device)
+
+        with torch.no_grad():
+            outputs = model(
+                source,
+                target,
+                decoder_encoder_attention_mask=labels_src,
+                decoder_attention_mask=labels_tgt,
+                decoder_lm_labels=labels_tgt,
+            )
+            lm_loss = outputs[0]
+            eval_loss += lm_loss.mean().item()
+        nb_eval_steps += 1
+
+    eval_loss = eval_loss / nb_eval_steps
+    perplexity = torch.exp(torch.tensor(eval_loss))
+
+    result = {"perplexity": perplexity}
+
+    # Save the evaluation's results
+    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
+    if not os.path.exists(args.output_dir):
+        os.makedirs(args.output_dir)
+    with open(output_eval_file, "w") as writer:
+        logger.info("***** Eval results {} *****".format(prefix))
+        for key in sorted(result.keys()):
+            logger.info(" %s = %s", key, str(result[key]))
+            writer.write("%s = %s\n" % (key, str(result[key])))
+
+    return result
+
+
 def main():
     parser = argparse.ArgumentParser()
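The perplexity reported by `evaluate` is just the exponential of the average token-level cross-entropy accumulated over the evaluation batches. A small check with invented numbers:

```python
import torch

batch_losses = [2.1, 1.9, 2.3]                     # mean LM loss per eval batch
eval_loss = sum(batch_losses) / len(batch_losses)  # 2.1
perplexity = torch.exp(torch.tensor(eval_loss))
print(perplexity)                                  # tensor(8.1662)
```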
@@ -289,7 +481,23 @@ def main():

     # Optional parameters
     parser.add_argument(
-        "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--do_evaluate",
+        type=bool,
+        default=False,
+        help="Run model evaluation on out-of-sample data.",
+    )
+    parser.add_argument("--do_train", type=bool, default=False, help="Run training.")
+    parser.add_argument(
+        "--do_overwrite_output_dir",
+        type=bool,
+        default=False,
+        help="Whether to overwrite the output dir.",
     )
     parser.add_argument(
         "--model_name_or_path",

@@ -303,12 +511,6 @@ def main():

         type=str,
         help="The decoder architecture to be fine-tuned.",
     )
-    parser.add_argument(
-        "--learning_rate",
-        default=5e-5,
-        type=float,
-        help="The initial learning rate for Adam.",
-    )
     parser.add_argument(
         "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
     )
@@ -318,43 +520,100 @@ def main():

         type=int,
         help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
     )
+    parser.add_argument(
+        "--to_cpu", default=False, type=bool, help="Whether to force training on CPU."
+    )
     parser.add_argument(
         "--num_train_epochs",
         default=1,
         type=int,
         help="Total number of training epochs to perform.",
     )
-    parser.add_argument("--seed", default=42, type=int)
-    parser.add_argument(
-        "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
-    )
     parser.add_argument(
-        "--weight_decay", default=0.0, type=float, help="Weight deay if we apply some."
+        "--per_gpu_train_batch_size",
+        default=4,
+        type=int,
+        help="Batch size per GPU/CPU for training.",
     )
+    parser.add_argument("--seed", default=42, type=int)
     args = parser.parse_args()
 
-    if args.model_type != "bert":
-        raise ValueError(
-            "Only the BERT architecture is currently supported for seq2seq."
-        )
+    if (
+        os.path.exists(args.output_dir)
+        and os.listdir(args.output_dir)
+        and args.do_train
+        and not args.do_overwrite_output_dir
+    ):
+        raise ValueError(
+            "Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
+                args.output_dir
+            )
+        )
 
     # Set up training device
-    # device = torch.device("cpu")
-
-    # Set seed
-    set_seed(args)
+    if args.to_cpu or not torch.cuda.is_available():
+        args.device = torch.device("cpu")
+        args.n_gpu = 0
+    else:
+        args.device = torch.device("cuda")
+        args.n_gpu = torch.cuda.device_count()
 
     # Load pretrained model and tokenizer
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
     model = Model2Model.from_pretrained(args.model_name_or_path)
-    # model.to(device)
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.warning(
+        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+        0,
+        args.device,
+        args.n_gpu,
+        False,
+        False,
+    )
 
     logger.info("Training/evaluation parameters %s", args)
 
-    # Training
-    train_dataset = load_and_cache_examples(args, tokenizer)
-    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
-    # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+    # Train the model
+    model.to(args.device)
+    if args.do_train:
+        global_step, tr_loss = train(args, model, tokenizer)
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+
+        if not os.path.exists(args.output_dir):
+            os.makedirs(args.output_dir)
+
+        logger.info("Saving model checkpoint to %s", args.output_dir)
+        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        model_to_save = (
+            model.module if hasattr(model, "module") else model
+        )  # Take care of distributed/parallel training
+        model_to_save.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
+        torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
+
+    # Evaluate the model
+    results = {}
+    if args.do_evaluate:
+        checkpoints = []
+        logger.info("Evaluate the following checkpoints: %s", checkpoints)
+        for checkpoint in checkpoints:
+            encoder_checkpoint = os.path.join(checkpoint, "encoder")
+            decoder_checkpoint = os.path.join(checkpoint, "decoder")
+            model = PreTrainedSeq2seq.from_pretrained(
+                encoder_checkpoint, decoder_checkpoint
+            )
+            model.to(args.device)
+            results = "placeholder"
+
+    return results
 
 
 if __name__ == "__main__":
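`main()` saves the fine-tuned model with `save_pretrained`, and `PreTrainedSeq2seq.save_pretrained` (see `transformers/modeling_seq2seq.py` below) writes the two stacks to `encoder/` and `decoder/` subdirectories, which is the layout the evaluation loop expects. A hedged sketch of reloading such a checkpoint, with an illustrative path:

```python
import os
from transformers import AutoTokenizer, PreTrainedSeq2seq

checkpoint = "output"  # the --output_dir used for training (illustrative)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = PreTrainedSeq2seq.from_pretrained(
    os.path.join(checkpoint, "encoder"),   # written by save_pretrained()
    os.path.join(checkpoint, "decoder"),
)
```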
examples/run_seq2seq_finetuning_test.py → examples/run_summarization_finetuning_test.py (+14 −14) · View file @ 4c3ac4a7
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
import
unittest
import
unittest
from
run_s
eq2seq
_finetuning
import
_fit_to_block_size
,
process_story
from
run_s
ummarization
_finetuning
import
_fit_to_block_size
,
process_story
class
DataLoaderTest
(
unittest
.
TestCase
):
class
DataLoaderTest
(
unittest
.
TestCase
):
...
@@ -43,15 +43,16 @@ class DataLoaderTest(unittest.TestCase):
...
@@ -43,15 +43,16 @@ class DataLoaderTest(unittest.TestCase):
raw_story
=
"""It was the year of Our Lord one thousand seven hundred and
raw_story
=
"""It was the year of Our Lord one thousand seven hundred and
seventy-five.
\n\n
Spiritual revelations were conceded to England at that
seventy-five.
\n\n
Spiritual revelations were conceded to England at that
favoured period, as at this."""
favoured period, as at this."""
with
self
.
assertRaises
(
IndexError
):
_
,
summary
=
process_story
(
raw_story
)
process_story
(
raw_story
)
self
.
assertEqual
(
summary
,
[]
)
def
test_process_empty_story
(
self
):
def
test_process_empty_story
(
self
):
""" An empty story should also raise and exception.
""" An empty story should also raise and exception.
"""
"""
raw_story
=
""
raw_story
=
""
with
self
.
assertRaises
(
IndexError
):
story
,
summary
=
process_story
(
raw_story
)
process_story
(
raw_story
)
self
.
assertEqual
(
story
,
[])
self
.
assertEqual
(
summary
,
[])
def
test_story_with_missing_period
(
self
):
def
test_story_with_missing_period
(
self
):
raw_story
=
(
raw_story
=
(
...
@@ -59,17 +60,16 @@ class DataLoaderTest(unittest.TestCase):
...
@@ -59,17 +60,16 @@ class DataLoaderTest(unittest.TestCase):
"seventy-five
\n\n
Spiritual revelations were conceded to England "
"seventy-five
\n\n
Spiritual revelations were conceded to England "
"at that favoured period, as at this.
\n
@highlight
\n\n
It was the best of times"
"at that favoured period, as at this.
\n
@highlight
\n\n
It was the best of times"
)
)
story
,
summary
=
process_story
(
raw_story
)
story
_lines
,
summary
_lines
=
process_story
(
raw_story
)
expected_story
=
(
expected_story_lines
=
[
"It was the year of Our Lord one thousand seven hundred and "
"It was the year of Our Lord one thousand seven hundred and seventy-five."
,
"seventy-five. Spiritual revelations were conceded to England at that "
"Spiritual revelations were conceded to England at that favoured period, as at this."
,
"favoured period, as at this."
]
)
self
.
assertEqual
(
expected_story_lines
,
story_lines
)
self
.
assertEqual
(
expected_story
,
story
)
expected_summary
=
"It was the best of times."
expected_summary
_lines
=
[
"It was the best of times."
]
self
.
assertEqual
(
expected_summary
,
summary
)
self
.
assertEqual
(
expected_summary
_lines
,
summary
_lines
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/__init__.py (+1 −1) · View file @ 4c3ac4a7

@@ -87,7 +87,7 @@ if is_torch_available():

     from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
                                       DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
                                       DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
-    from .modeling_seq2seq import Model2Model
+    from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
 
     # Optimization
     from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
transformers/modeling_beam_search.py (new file mode 100644, +240 −0) · View file @ 4c3ac4a7

# coding=utf-8
# Copyright (c) 2019 Yang Liu
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
A general wrapper around models with LM heads to generate sequences
using beam search.
"""
import torch
from torch import nn


class ModelWithBeamSearch(nn.Module):
    def __init__(
        self,
        model,
        beam_size,
        start_token_id,
        end_token_id,
        pad_token_id,
        min_length,
        max_length,
        alpha,
        block_trigram=True,
    ):
        """
        Attributes:
            mask_word_id: token id that corresponds to the mask
        """
        super(ModelWithBeamSearch, self).__init__()
        self.model = model
        self.beam_size = beam_size
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id
        self.pad_token_id = pad_token_id
        self.min_length = min_length
        self.max_length = max_length
        self.alpha = alpha
        self.block_trigram = block_trigram

    def forward(self, input_ids, **kwargs):
        # Separate the encoder- and decoder- specific kwargs. A kwarg is
        # decoder-specific it the key starts with `decoder_`
        kwargs_encoder = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("decoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        batch_size, _ = input_ids.size(0)

        # Variables that keep track of the status of the search
        hypotheses = [[] for _ in range(batch_size)]
        batch_offset = torch.arange(batch_size, dtype=torch.long)
        beam_offset = torch.arange(
            0,
            batch_size * self.beam_size,
            step=self.beam_size,
            dtype=torch.long,
        )
        growing_beam = torch.full(
            (batch_size * self.beam_size, 1),
            self.start_token_id,
            dtype=torch.long,
        )
        topk_log_probabilities = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1),
            dtype=torch.float,
        ).repeat(batch_size)

        # Forward pass on the encoder
        encoder_outputs = self.encoder(input_ids, kwargs_encoder)
        kwargs_decoder["encoder_hidden_states"] = tile(
            encoder_outputs, self.beam_size, dim=0
        )

        results = {}
        results["predictions"] = [[] for _ in batch_size]
        results["scores"] = [[] for _ in batch_size]

        for step in range(self.max_length):
            decoder_input = growing_beam[:, -1]
            outputs = self.decoder(decoder_input, kwargs_decoder)
            log_probabilities = torch.nn.functional.log_softmax(outputs[1])
            vocab_size = log_probabilities.size(-1)

            # The batch size changes as some beams finish so we define:
            _B = log_probabilities.size(0) // self.beam_size

            # Multiply each beam probability with the probability of the
            # next token (conditioned on the words in the beam).
            log_probabilities += topk_log_probabilities.view(-1, 1)

            # if the beam has not attained the minimum required length we
            # make the end token arbitrarily unlikely.
            if step < self.min_length:
                log_probabilities[self.end_token_id] = -1e20

            # Remove repeating tri-grams
            if self.args.block_trigram:
                if step + 1 > 3:
                    for i in range(_B * self.beam_size):
                        tokens = [t for t in growing_beam[i]]
                        trigrams = [
                            (tokens[i - 1], tokens[i], tokens[i + 1])
                            for i in range(1, len(words) - 1)
                        ]
                        last_trigram = tuple(trigrams[-1])
                        if last_trigram in trigrams[:-1]:
                            log_probabilities[i] = -1e20

            # Find the `beam_size` (previous_beam + token) combinations with
            # the highest score
            topk_log_probabilities, topk_ids = log_probabilities.topk(
                log_probabilities.view(_B, self.beam_size * vocab_size),
                self.beam_size,
                dim=1,
            )

            # Apply the length penalty. The +1 accounts for the [EOS] token
            # that will be added if the beam ends.
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
            topk_scores = topk_log_probabilities / length_penalty

            # Retrieve the corresponding respective beam and token id
            # topk_token_ids[i] will be added to topk_beam_ids[i]
            topk_beam_ids = topk_ids.div(vocab_size)
            topk_token_ids = topk_ids.fmod(vocab_size)

            # Retrieve the row index of the surviving beams in the original
            # view of the log_probabilities tensor
            surviving_beams_rows = (
                topk_beam_ids + beam_offset[:_B].view(-1, 1)
            ).view(-1)

            # Append the last predictions
            growing_beam = torch.cat(
                [
                    growing_beam.index_select(0, surviving_beams_rows),
                    topk_token_ids.view(-1, 1),
                ],
                1,
            )

            # Check if any of the beam searches has ended during this
            # growth step. Also if top beam (most probable) has ended
            # for one element of the batch.
            is_finished = topk_token_ids.eq(self.end_token_id)
            if step + 1 == self.max_length:
                is_finished.fill_(1)
            is_top_beam_finished = is_finished[:, 0].eq(1)

            # Save the finished searches
            if is_finished.any():
                predictions = growing_beam.view(
                    -1, self.beam_size, growing_beam.size(1)
                )
                for i in range(is_finished.size(0)):
                    if is_top_beam_finished[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)

                    # Store finished hypotheses for this batch.
                    b = batch_offset[i]
                    for j in finished_hyp:
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))

                    # If the batch reached the end, save the best hypotheses
                    # in terms of length-penalized score.
                    if is_top_beam_finished[i]:
                        best_hyp = sorted(
                            hypotheses[b], key=lambda x: x[0], reverse=True
                        )
                        best_score, best_prediction = best_hyp[0]
                        results["scores"][b].append(best_score)
                        results["predictions"][b].append(best_prediction)

                non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
                if len(non_finished) == 0:
                    break

                # Remove finished batches for the next step.
                topk_log_probabilities = topk_log_probabilities.index_select(
                    0, non_finished
                )
                batch_offset = batch_offset.index_select(0, non_finished)
                growing_beam = predictions.index_select(0, non_finished).view(
                    -1, growing_beam.size(-1)
                )

            # Re-order the state for the next pass
            surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
                "encoder_hidden_states"
            ].index_select(0, surviving_beams_rows)

        return results


def tile(x, count, dim=0):
    """
    Tiles `x` along dimension `dim` `count` times.

    Example:
        >> ex = torch.tensor([1,2],[3,4])
        >> tile(ex, 2, 0)
        torch.Tensor([[1,2],[1,2],[3,4],[3,4]])
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = (
        x.view(batch, -1)
        .transpose(0, 1)
        .repeat(count, 1)
        .transpose(0, 1)
        .contiguous()
        .view(*out_size)
    )
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x
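The wrapper scores finished hypotheses with the GNMT-style length penalty used above, `((5 + length) / 6) ** alpha`, so longer beams are not unfairly punished for accumulating more negative log-probability. A quick numeric illustration with invented values:

```python
def length_penalty(length, alpha):
    # same formula as in ModelWithBeamSearch.forward
    return ((5.0 + length) / 6.0) ** alpha

log_prob = -6.0  # summed log-probability of a hypothesis (made up)
for length in [5, 10, 20]:
    print(length, round(log_prob / length_penalty(length, alpha=1.0), 3))
# 5 -3.6
# 10 -2.4
# 20 -1.44   -- the penalized score improves as the hypothesis grows
```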
transformers/modeling_bert.py (+12 −8) · View file @ 4c3ac4a7

@@ -646,7 +646,7 @@ class BertModel(BertPreTrainedModel):

         if attention_mask.dim() == 2:
             if self.config.is_decoder:
                 batch_size, seq_length = input_ids.size()
-                seq_ids = torch.arange(seq_length)
+                seq_ids = torch.arange(seq_length, device=input_ids.device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:

@@ -660,6 +660,13 @@ class BertModel(BertPreTrainedModel):

         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
+        # If a 2D encoder attention mask is provided for the cross-attention
+        # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = encoder_attention_mask[:, None, None, :]
+            encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+            encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
+
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N

@@ -819,7 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):

             self.bert.embeddings.word_embeddings)
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
+                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,

@@ -838,11 +845,8 @@ class BertForMaskedLM(BertPreTrainedModel):

         # 1. If a tensor that contains the indices of masked labels is provided,
         # the cross-entropy is the MLM cross-entropy that measures the likelihood
         # of predictions for masked words.
-        # 2. If encoder hidden states are provided we are in a causal situation where we
+        # 2. If `lm_label` is provided we are in a causal scenario where we
         # try to predict the next word for each input in the encoder.
-        if masked_lm_labels is not None and lm_labels is not None:
-            raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
-
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))

@@ -851,9 +855,9 @@ class BertForMaskedLM(BertPreTrainedModel):

         if lm_labels is not None:
             # we are doing next-token prediction; shift prediction scores and input ids by one
             prediction_scores = prediction_scores[:, :-1, :]
-            lm_labels = lm_labels[:, 1:, :]
+            lm_labels = lm_labels[:, 1:]
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
+            seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
             outputs = (seq2seq_loss,) + outputs
 
         return outputs  # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
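The decoder path above combines a lower-triangular causal mask with the 2D padding mask by broadcasting. A small standalone illustration of the same expression, detached from the model:

```python
import torch

batch_size, seq_length = 1, 4
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
print(extended[0, 0])
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]])  each query attends only to earlier, non-padded keys
```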
transformers/modeling_seq2seq.py (+61 −22) · View file @ 4c3ac4a7

@@ -17,13 +17,12 @@

 from __future__ import absolute_import, division, print_function, unicode_literals
 
 import logging
+import os
 
 import torch
 from torch import nn
 
-from .file_utils import add_start_docstrings
 from .modeling_auto import AutoModel, AutoModelWithLMHead
-from .modeling_utils import PreTrainedModel, SequenceSummary
 
 logger = logging.getLogger(__name__)

@@ -43,7 +42,13 @@ class PreTrainedSeq2seq(nn.Module):

         self.decoder = decoder
 
     @classmethod
-    def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs):
+    def from_pretrained(
+        cls,
+        encoder_pretrained_model_name_or_path=None,
+        decoder_pretrained_model_name_or_path=None,
+        *model_args,
+        **kwargs
+    ):
         r""" Instantiates an encoder and a decoder from one or two base classes
             of the library from pre-trained model checkpoints.

@@ -108,23 +113,28 @@ class PreTrainedSeq2seq(nn.Module):

         # Separate the encoder- and decoder- specific kwargs. A kwarg is
         # decoder-specific it the key starts with `decoder_`
-        kwargs_decoder = {}
-        kwargs_encoder = kwargs
-        for key in kwargs_encoder.keys():
-            if key.startswith("decoder_"):
-                kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
+        kwargs_encoder = {
+            argument: value
+            for argument, value in kwargs.items()
+            if not argument.startswith("decoder_")
+        }
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("decoder_")
+        }
 
         # Load and initialize the encoder and decoder
         # The distinction between encoder and decoder at the model level is made
         # by the value of the flag `is_decoder` that we need to set correctly.
-        encoder = kwargs.pop("encoder_model", None)
+        encoder = kwargs_encoder.pop("encoder_model", None)
         if encoder is None:
             kwargs_encoder["is_decoder"] = False
             encoder = AutoModel.from_pretrained(
                 encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
             )
 
-        decoder = kwargs.pop("decoder_model", None)
+        decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
             kwargs_decoder["is_decoder"] = True
             decoder = AutoModelWithLMHead.from_pretrained(

@@ -135,6 +145,12 @@ class PreTrainedSeq2seq(nn.Module):

         return model
 
+    def save_pretrained(self, save_directory):
+        """ Save a Seq2Seq model and its configuration file in a format
+        such that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` """
+        self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
+        self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
+
     def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
         """ The forward pass on a seq2eq depends what we are performing:

@@ -155,22 +171,29 @@ class PreTrainedSeq2seq(nn.Module):

         """
         # Separate the encoder- and decoder- specific kwargs. A kwarg is
         # decoder-specific it the key starts with `decoder_`
-        kwargs_decoder = {}
-        kwargs_encoder = kwargs
-        for key in kwargs_encoder.keys():
-            if key.startswith("decoder_"):
-                kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
+        kwargs_encoder = {
+            argument: value
+            for argument, value in kwargs.items()
+            if not argument.startswith("decoder_")
+        }
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("decoder_")
+        }
 
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
             encoder_hidden_states = encoder_outputs[0][-1]  # output of the encoder *stack*
         else:
             encoder_outputs = ()
 
         # Decode
-        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
         decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
 
         return decoder_outputs + encoder_outputs

@@ -201,9 +224,25 @@ class Model2Model(PreTrainedSeq2seq):

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        if (
+            "bert" not in pretrained_model_name_or_path
+            or "roberta" in pretrained_model_name_or_path
+            or "distilbert" in pretrained_model_name_or_path
+        ):
+            raise ValueError("Only the Bert model is currently supported.")
+
         model = super(Model2Model, cls).from_pretrained(
             encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
             decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
             **kwargs
         )
+
+        # Some architectures require for the decoder to be initialized randomly
+        # before fine-tuning.
+        if kwargs.get("decoder_initialize_randomly", False):
+            model.decoder.init_weights()
 
         return model
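To tie the pieces together, here is a hedged sketch of how the training script drives this class: the encoder and decoder share a BERT initialization, and any keyword prefixed with `decoder_` has the prefix stripped and is routed to the decoder's forward pass, while everything else goes to the encoder. Names and tensors below are illustrative, not a guaranteed-working recipe:

```python
import torch
from transformers import AutoTokenizer, Model2Model

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = Model2Model.from_pretrained("bert-base-uncased")

source = torch.tensor([tokenizer.encode("A short news article.")])
target = torch.tensor([tokenizer.encode("A summary.")])

# `decoder_lm_labels` becomes `lm_labels` inside the decoder's forward(),
# mirroring the call made in run_summarization_finetuning.py above.
outputs = model(source, target, decoder_lm_labels=target)
loss = outputs[0]
```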