chenpangpang / transformers · commit dfce4096

Commit dfce4096, authored Oct 29, 2019 by Rémi Louf
Message: resolve PR comments
Parent: 4c3ac4a7

Showing 6 changed files with 647 additions and 397 deletions (+647 −397).
Changed files:
  examples/run_summarization_finetuning.py   (+81 −213)
  examples/utils_summarization.py            (+184 −0)
  examples/utils_summarization_test.py       (+133 −0)
  transformers/modeling_beam_search.py       (+177 −146)
  transformers/modeling_bert.py              (+19 −12)
  transformers/modeling_seq2seq.py           (+53 −26)
examples/run_summarization_finetuning.py (view file @ dfce4096)

@@ -16,10 +16,9 @@
 """ Finetuning seq2seq models for sequence generation."""

 import argparse
-from collections import deque
 import functools
 import logging
 import os
 import pickle
 import random
 import sys
@@ -29,7 +28,22 @@ import torch
 from torch.optim import Adam
 from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

-from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model
+from transformers import (
+    AutoTokenizer,
+    BertForMaskedLM,
+    BertConfig,
+    PreTrainedSeq2seq,
+    Model2Model,
+)
+
+from utils_summarization import (
+    CNNDailyMailDataset,
+    encode_for_summarization,
+    fit_to_block_size,
+    build_lm_labels,
+    build_mask,
+    compute_token_type_ids,
+)

 logger = logging.getLogger(__name__)
 logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -46,195 +60,42 @@ def set_seed(args):
 # ------------
 # Data loading
 # ------------

-class TextDataset(Dataset):
-    """ Abstracts the dataset used to train seq2seq models.
-
-    CNN/Daily News:
-    The CNN/Daily News raw datasets are downloaded from [1]. The stories are
-    stored in different files; the summary appears at the end of the story as
-    sentences that are prefixed by the special `@highlight` line. To process
-    the data, untar both datasets in the same folder, and pass the path to this
-    folder as the "data_dir" argument. The formatting code was inspired by [2].
-
-    [1] https://cs.nyu.edu/~kcho/
-    [2] https://github.com/abisee/cnn-dailymail/
-    """
-
-    def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
-        assert os.path.isdir(data_dir)
-
-        # Load the features that have already been computed, if any
-        cached_features_file = os.path.join(
-            data_dir, "cached_lm_{}_{}".format(block_size, prefix)
-        )
-        if os.path.exists(cached_features_file):
-            logger.info("Loading features from cached file %s", cached_features_file)
-            with open(cached_features_file, "rb") as source:
-                self.examples = pickle.load(source)
-            return
-
-        logger.info("Creating features from dataset at %s", data_dir)
-        datasets = ["cnn", "dailymail"]
-        self.examples = {"source": [], "target": []}
-        for dataset in datasets:
-            path_to_stories = os.path.join(data_dir, dataset, "stories")
-            story_filenames_list = os.listdir(path_to_stories)
-            for story_filename in story_filenames_list:
-                path_to_story = os.path.join(path_to_stories, story_filename)
-                if not os.path.isfile(path_to_story):
-                    continue
-                with open(path_to_story, encoding="utf-8") as source:
-                    raw_story = source.read()
-                story_lines, summary_lines = process_story(raw_story)
-                if len(summary_lines) == 0 or len(story_lines) == 0:
-                    continue
-                story_token_ids, summary_token_ids = _encode_for_summarization(
-                    story_lines, summary_lines, tokenizer
-                )
-                story_seq = _fit_to_block_size(story_token_ids, block_size)
-                self.examples["source"].append(story_seq)
-                summary_seq = _fit_to_block_size(summary_token_ids, block_size)
-                self.examples["summary"].append(summary_seq)
-
-        logger.info("Saving features into cache file %s", cached_features_file)
-        with open(cached_features_file, "wb") as sink:
-            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
-
-    def __len__(self):
-        return len(self.examples)
-
-    def __getitem__(self, items):
-        return (
-            torch.tensor(self.examples["source"][items]),
-            torch.tensor(self.examples["target"][items]),
-        )
-
-
-def _fit_to_block_size(sequence, block_size):
-    """ Adapt the source and target sequences' lengths to the block size.
-    If the sequence is shorter than the block size we pad it with -1 ids
-    which correspond to padding tokens.
-    """
-    if len(sequence) > block_size:
-        return sequence[:block_size]
-    else:
-        sequence.extend([0] * (block_size - len(sequence)))
-        return sequence
-
-
-def mask_padding_tokens(sequence):
-    """ Padding token, encoded as 0, are represented by the value -1 in the
-    masks """
-    padded = sequence.clone()
-    padded[padded == 0] = -1
-    return padded
-
-
-def load_and_cache_examples(args, tokenizer):
-    dataset = TextDataset(tokenizer, data_dir=args.data_dir)
-    return dataset

[Also removed from this file: process_story, _add_missing_period, _encode_for_summarization and compute_token_type_ids, which reappear in examples/utils_summarization.py below (compute_token_type_ids now resets sentence_num for every sequence in the batch).]

+def load_and_cache_examples(args, tokenizer):
+    dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
+    return dataset
+
+
+def collate(data, tokenizer, block_size):
+    """ List of tuple as an input. """
+    # remove the files with an empty story/summary, encode and fit to block
+    data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
+    data = [encode_for_summarization(story, summary, tokenizer) for story, summary in data]
+    data = [
+        (
+            fit_to_block_size(story, block_size, tokenizer.pad_token_id),
+            fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
+        )
+        for story, summary in data
+    ]
+
+    stories = torch.tensor([story for story, summary in data])
+    summaries = torch.tensor([summary for story, summary in data])
+    encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
+    encoder_mask = build_mask(stories, tokenizer.pad_token_id)
+    decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
+    lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)
+
+    return (
+        stories,
+        summaries,
+        encoder_token_type_ids,
+        encoder_mask,
+        decoder_mask,
+        lm_labels,
+    )
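For context, the new collate function is meant to be bound to a tokenizer and a block size and handed to a DataLoader, which is what the training-loop hunk further below does. A minimal wiring sketch, assuming a BERT tokenizer and a local data directory (both illustrative, not part of this commit), with collate as defined above:

    import functools

    from torch.utils.data import DataLoader, RandomSampler
    from transformers import BertTokenizer

    from utils_summarization import CNNDailyMailDataset

    # hypothetical checkpoint name and path, for illustration only
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset = CNNDailyMailDataset(tokenizer, data_dir="./cnn_dailymail/")

    # bind the tokenizer and block size once; the DataLoader then calls collate per batch
    collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
    loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=4, collate_fn=collate_fn)

    # each batch is (stories, summaries, token_type_ids, encoder_mask, decoder_mask, lm_labels)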
 # ----------
 # Optimizers
 # ----------

@@ -252,7 +113,7 @@ class BertSumOptimizer(object):
         arXiv preprint arXiv:1908.08345 (2019).
     """

-    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9):
+    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
         self.encoder = model.encoder
         self.decoder = model.decoder
         self.lr = lr
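The hunk above only touches the constructor (the Adam epsilon default moves from 1e-9 to 1e-8). For readers unfamiliar with the reference, the BertSum paper schedules the encoder and decoder with separate learning rates and warmups; a minimal sketch of that kind of schedule (my paraphrase of the paper, not code from this commit, and the numeric values are only the ballpark reported there):

    def noam_like_lr(base_lr, step, warmup_steps):
        # warm up linearly, then decay with the inverse square root of the step
        return base_lr * min(step ** (-0.5), step * warmup_steps ** (-1.5))

    # e.g. a larger learning rate and shorter warmup for the randomly initialized decoder
    encoder_lr_at_step_100 = noam_like_lr(2e-3, 100, 20000)
    decoder_lr_at_step_100 = noam_like_lr(1e-1, 100, 10000)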
@@ -306,8 +167,12 @@ def train(args, model, tokenizer):
     args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
     train_dataset = load_and_cache_examples(args, tokenizer)
     train_sampler = RandomSampler(train_dataset)
+    model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
     train_dataloader = DataLoader(
-        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
+        train_dataset,
+        sampler=train_sampler,
+        batch_size=args.train_batch_size,
+        collate_fn=model_collate_fn,
     )

     # Training schedule
@@ -351,26 +216,23 @@ def train(args, model, tokenizer):
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
         for step, batch in enumerate(epoch_iterator):
-            source, target = batch
-            token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
-            labels_src = mask_padding_tokens(source)
-            labels_tgt = mask_padding_tokens(target)
+            source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch

             source = source.to(args.device)
             target = target.to(args.device)
-            token_type_ids = token_type_ids.to(args.device)
-            labels_src = labels_src.to(args.device)
-            labels_tgt = labels_tgt.to(args.device)
+            encoder_token_type_ids = encoder_token_type_ids.to(args.device)
+            encoder_mask = encoder_mask.to(args.device)
+            decoder_mask = decoder_mask.to(args.device)
+            lm_labels = lm_labels.to(args.device)

             model.train()
             outputs = model(
                 source,
                 target,
-                token_type_ids=token_type_ids,
-                decoder_encoder_attention_mask=labels_src,
-                decoder_attention_mask=labels_tgt,
-                decoder_lm_labels=labels_tgt,
-                decoder_initialize_randomly=True,
+                encoder_token_type_ids=encoder_token_type_ids,
+                encoder_attention_mask=encoder_mask,
+                decoder_attention_mask=decoder_mask,
+                decoder_lm_labels=lm_labels,
             )

             loss = outputs[0]
@@ -421,21 +283,23 @@ def evaluate(args, model, tokenizer, prefix=""):
     model.eval()

     for batch in tqdm(eval_dataloader, desc="Evaluating"):
-        source, target = batch
-        labels_src = mask_padding_tokens(source)
-        labels_tgt = mask_padding_tokens(target)
-        source.to(args.device)
-        target.to(args.device)
-        labels_src.to(args.device)
-        labels_tgt.to(args.device)
+        source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
+        source = source.to(args.device)
+        target = target.to(args.device)
+        encoder_token_type_ids = encoder_token_type_ids.to(args.device)
+        encoder_mask = encoder_mask.to(args.device)
+        decoder_mask = decoder_mask.to(args.device)
+        lm_labels = lm_labels.to(args.device)

         with torch.no_grad():
             outputs = model(
                 source,
                 target,
-                decoder_encoder_attention_mask=labels_src,
-                decoder_attention_mask=labels_tgt,
-                decoder_lm_labels=labels_tgt,
+                encoder_token_type_ids=encoder_token_type_ids,
+                encoder_attention_mask=encoder_mask,
+                decoder_attention_mask=decoder_mask,
+                decoder_lm_labels=lm_labels,
             )
             lm_loss = outputs[0]
             eval_loss += lm_loss.mean().item()
@@ -525,7 +389,7 @@ def main():
     )
     parser.add_argument(
         "--num_train_epochs",
-        default=1,
+        default=10,
         type=int,
         help="Total number of training epochs to perform.",
     )
@@ -558,9 +422,13 @@ def main():
     args.device = torch.device("cuda")
     args.n_gpu = torch.cuda.device_count()

-    # Load pretrained model and tokenizer
+    # Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
     tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
-    model = Model2Model.from_pretrained(args.model_name_or_path)
+    config = BertConfig.from_pretrained(args.model_name_or_path)
+    decoder_model = BertForMaskedLM(config)
+    model = Model2Model.from_pretrained(
+        args.model_name_or_path, decoder_model=decoder_model
+    )

     # Setup logging
     logging.basicConfig(
examples/utils_summarization.py (new file, mode 0 → 100644, view file @ dfce4096)

from collections import deque
import os

import torch
from torch.utils.data import Dataset


# ------------
# Data loading
# ------------

class CNNDailyMailDataset(Dataset):
    """ Abstracts the dataset used to train seq2seq models.

    CNN/Daily News:
    The CNN/Daily News raw datasets are downloaded from [1]. The stories are
    stored in different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to this
    folder as the "data_dir" argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, tokenizer, prefix="train", data_dir=""):
        assert os.path.isdir(data_dir)
        self.tokenizer = tokenizer

        # We initialize the class by listing all the files that contain
        # stories and summaries. Files are not read in memory given
        # the size of the corpus.
        self.stories_path = []
        datasets = ("cnn", "dailymail")
        for dataset in datasets:
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            story_filenames_list = os.listdir(path_to_stories)
            for story_filename in story_filenames_list:
                path_to_story = os.path.join(path_to_stories, story_filename)
                if not os.path.isfile(path_to_story):
                    continue
                self.stories_path.append(path_to_story)

    def __len__(self):
        return len(self.stories_path)

    def __getitem__(self, idx):
        story_path = self.stories_path[idx]
        with open(story_path, encoding="utf-8") as source:
            raw_story = source.read()
        story_lines, summary_lines = process_story(raw_story)
        return story_lines, summary_lines
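A minimal usage sketch for the new dataset class, assuming the CNN and DailyMail stories have been untarred under ./cnn_dailymail/cnn/stories and ./cnn_dailymail/dailymail/stories (the paths and checkpoint name are illustrative, not part of the commit):

    from transformers import BertTokenizer
    from utils_summarization import CNNDailyMailDataset

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # any tokenizer object works here
    dataset = CNNDailyMailDataset(tokenizer, data_dir="./cnn_dailymail/")

    story_lines, summary_lines = dataset[0]  # lists of raw sentences, not token ids
    print(len(dataset), len(story_lines), len(summary_lines))

The class deliberately stays lazy: it only lists file paths up front and reads and splits a story when it is indexed, so tokenization and padding are left to the collate function in the example script.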
def process_story(raw_story):
    """ Extract the story and summary from a story file.

    Attributes:
        raw_story (str): content of the story file as an utf-8 encoded string.

    Raises:
        IndexError: If the story is empty or contains no highlights.
    """
    nonempty_lines = list(
        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
    )

    # for some unknown reason some lines miss a period, add it
    nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]

    # gather article lines
    story_lines = []
    lines = deque(nonempty_lines)
    while True:
        try:
            element = lines.popleft()
            if element.startswith("@highlight"):
                break
            story_lines.append(element)
        except IndexError:
            # if "@highlight" is absent from the file we pop
            # all elements until there is None.
            return story_lines, []

    # gather summary lines
    summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))

    return story_lines, summary_lines


def _add_missing_period(line):
    END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
    if line.startswith("@highlight"):
        return line
    if line[-1] in END_TOKENS:
        return line
    return line + "."
# --------------------------
# Encoding and preprocessing
# --------------------------

def fit_to_block_size(sequence, block_size, pad_token):
    """ Adapt the source and target sequences' lengths to the block size.
    If the sequence is shorter than the block size we pad it with -1 ids
    which correspond to padding tokens.
    """
    if len(sequence) > block_size:
        return sequence[:block_size]
    else:
        sequence.extend([pad_token] * (block_size - len(sequence)))
        return sequence


def build_lm_labels(sequence, pad_token):
    """ Padding token, encoded as 0, are represented by the value -1 so they
    are not taken into account in the loss computation. """
    padded = sequence.clone()
    padded[padded == pad_token] = -1
    return padded


def build_mask(sequence, pad_token):
    """ Builds the mask. The attention mechanism will only attend to positions
    with value 1. """
    mask = sequence.clone()
    mask[mask != pad_token] = 1
    mask[mask == pad_token] = 0
    return mask
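To make the three helpers concrete, here is what they do to a short sequence with a block size of 6 and pad token 0 (a sketch mirroring the unit tests added in the next file):

    import torch
    from utils_summarization import fit_to_block_size, build_mask, build_lm_labels

    ids = fit_to_block_size([5, 6, 7], block_size=6, pad_token=0)   # [5, 6, 7, 0, 0, 0]
    seq = torch.tensor(ids)
    build_mask(seq, pad_token=0)       # tensor([1, 1, 1, 0, 0, 0])  attend to real tokens only
    build_lm_labels(seq, pad_token=0)  # tensor([5, 6, 7, -1, -1, -1])  padding ignored by the loss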
def encode_for_summarization(story_lines, summary_lines, tokenizer):
    """ Encode the story and summary lines, and join them
    as specified in [1] by using `[SEP] [CLS]` tokens to separate
    sentences.
    """
    story_lines_token_ids = [
        tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
        for line in story_lines
    ]
    summary_lines_token_ids = [
        tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
        for line in summary_lines
    ]
    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
    summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]

    return story_token_ids, summary_token_ids


def compute_token_type_ids(batch, separator_token_id):
    """ Segment embeddings as described in [1]

    The values {0,1} were found in the repository [2].

    Attributes:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """
    batch_embeddings = []
    for sequence in batch:
        sentence_num = 0
        embeddings = []
        for s in sequence:
            if s == separator_token_id:
                sentence_num += 1
            embeddings.append(sentence_num % 2)
        batch_embeddings.append(embeddings)
    return torch.tensor(batch_embeddings)
examples/run_summarization_finetuning_test.py → examples/utils_summarization_test.py (renamed, view file @ dfce4096)

@@ -14,47 +14,64 @@
 # limitations under the License.

 import unittest
-from run_summarization_finetuning import _fit_to_block_size, process_story
+
+import numpy as np
+import torch
+
+from utils_summarization import (
+    compute_token_type_ids,
+    fit_to_block_size,
+    build_mask,
+    build_lm_labels,
+    process_story,
+)
-class DataLoaderTest(unittest.TestCase):
+class SummarizationDataProcessingTest(unittest.TestCase):
     def setUp(self):
         self.block_size = 10

-    def test_truncate_sequence_too_small(self):
+    def test_fit_to_block_sequence_too_small(self):
         """ Pad the sequence with 0 if the sequence is smaller than the block size."""
         sequence = [1, 2, 3, 4]
         expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
-        self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
+        self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
-    def test_truncate_sequence_fit_exactly(self):
+    def test_fit_to_block_sequence_fit_exactly(self):
         """ Do nothing if the sequence is the right size. """
         sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
         expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-        self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
+        self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)

-    def test_truncate_sequence_too_big(self):
+    def test_fit_to_block_sequence_too_big(self):
         """ Truncate the sequence if it is too long. """
         sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
         expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-        self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output)
+        self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
     def test_process_story_no_highlights(self):
-        """ Processing a story with no highlights should raise an exception. """
+        """ Processing a story with no highlights returns an empty list for the summary. """
         raw_story = """It was the year of Our Lord one thousand seven hundred and
        seventy-five.\n\nSpiritual revelations were conceded to England at that
        favoured period, as at this."""
-        _, summary = process_story(raw_story)
-        self.assertEqual(summary, [])
+        _, summary_lines = process_story(raw_story)
+        self.assertEqual(summary_lines, [])

     def test_process_empty_story(self):
-        """ An empty story should also raise an exception. """
+        """ An empty story returns an empty collection of lines. """
         raw_story = ""
-        story, summary = process_story(raw_story)
-        self.assertEqual(story, [])
-        self.assertEqual(summary, [])
+        story_lines, summary_lines = process_story(raw_story)
+        self.assertEqual(story_lines, [])
+        self.assertEqual(summary_lines, [])

-    def test_story_with_missing_period(self):
+    def test_process_story_with_missing_period(self):
         raw_story = (
             "It was the year of Our Lord one thousand seven hundred and "
             "seventy-five\n\nSpiritual revelations were conceded to England "
@@ -71,6 +88,46 @@ class DataLoaderTest(unittest.TestCase):
         expected_summary_lines = ["It was the best of times."]
         self.assertEqual(expected_summary_lines, summary_lines)

+    def test_build_lm_labels_no_padding(self):
+        sequence = torch.tensor([1, 2, 3, 4])
+        expected = sequence
+        np.testing.assert_array_equal(
+            build_lm_labels(sequence, 0).numpy(), expected.numpy()
+        )
+
+    def test_build_lm_labels(self):
+        sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0])
+        expected = torch.tensor([1, 2, 3, 4, -1, -1, -1])
+        np.testing.assert_array_equal(
+            build_lm_labels(sequence, 0).numpy(), expected.numpy()
+        )
+
+    def test_build_mask_no_padding(self):
+        sequence = torch.tensor([1, 2, 3, 4])
+        expected = torch.tensor([1, 1, 1, 1])
+        np.testing.assert_array_equal(
+            build_mask(sequence, 0).numpy(), expected.numpy()
+        )
+
+    def test_build_mask(self):
+        sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
+        expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
+        np.testing.assert_array_equal(
+            build_mask(sequence, 23).numpy(), expected.numpy()
+        )
+
+    def test_compute_token_type_ids(self):
+        separator = 101
+        batch = torch.tensor(
+            [[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
+        )
+        expected = torch.tensor(
+            [[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]]
+        )
+
+        result = compute_token_type_ids(batch, separator)
+        np.testing.assert_array_equal(result, expected)


 if __name__ == "__main__":
     unittest.main()
transformers/modeling_beam_search.py (view file @ dfce4096)

@@ -26,119 +26,84 @@ import torch
 from torch import nn


-class ModelWithBeamSearch(nn.Module):
+class TransformerBeamSearch(nn.Module):
     def __init__(
         self,
         model,
         tokenizer,
         batch_size,
         beam_size,
+        start_token_id,
+        end_token_id,
+        pad_token_id,
         min_length,
         max_length,
-        alpha,
-        block_trigram=True,
+        alpha=0,
+        block_repeating_trigram=True,
     ):
         """
         Attributes:
             mask_word_id: token id that corresponds to the mask
         """
-        super(ModelWithBeamSearch, self).__init__()
+        super(TransformerBeamSearch, self).__init__()
         self.model = model
         self.tokenizer = tokenizer
-        self.start_token_id = tokenizer.start_token_id
-        self.end_token_id = tokenizer.end_token_id
-        self.pad_token_id = tokenizer.pad_token_id
         self.beam_size = beam_size
+        self.start_token_id = start_token_id
+        self.end_token_id = end_token_id
+        self.pad_token_id = pad_token_id
         self.min_length = min_length
         self.max_length = max_length
-        self.alpha = alpha
-        self.block_trigram = block_trigram
-
-    def forward(self, input_ids, **kwargs):
-        # Separate the encoder- and decoder- specific kwargs. A kwarg is
-        # decoder-specific if the key starts with `decoder_`
-        kwargs_encoder = {
-            argument: value
-            for argument, value in kwargs.items()
-            if not argument.startswith("decoder_")
-        }
-        kwargs_decoder = {
-            argument[len("decoder_"):]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("decoder_")
-        }
-        batch_size, _ = input_ids.size(0)
+        self.block_repeating_trigram = block_repeating_trigram
+        self.apply_length_penalty = False if alpha == 0 else True
+        self.alpha = alpha

-        # Variables that keep track of the status of the search
-        hypotheses = [[] for _ in range(batch_size)]
-        batch_offset = torch.arange(batch_size, dtype=torch.long)
-        beam_offset = torch.arange(
-            0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long,
-        )
-        growing_beam = torch.full(
-            (batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long,
-        )
-        topk_log_probabilities = torch.tensor(
-            [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float,
-        ).repeat(batch_size)
+        # State of the beam
+        self.hypotheses = [[] for _ in range(batch_size)]
+        self.batch_offset = torch.arange(batch_size, dtype=torch.long)
+        self.beam_offset = torch.arange(
+            0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
+        )
+        self.growing_beam = torch.full(
+            (batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
+        )
+        self.topk_log_probabilities = torch.tensor(
+            [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
+        ).repeat(batch_size)
+        self.results = {
+            "prediction": [[] for _ in batch_size],
+            "scores": [[] for _ in batch_size],
+        }
+        self._step = 0
+        self.is_done = False

-        # Forward pass on the encoder
-        encoder_outputs = self.encoder(input_ids, kwargs_encoder)
-        kwargs_decoder["encoder_hidden_states"] = tile(encoder_outputs, self.beam_size, dim=0)
-
-        results = {}
-        results["predictions"] = [[] for _ in batch_size]
-        results["scores"] = [[] for _ in batch_size]
-
-        for step in range(self.max_length):
-            decoder_input = growing_beam[:, -1]
-            outputs = self.decoder(decoder_input, kwargs_decoder)
-            log_probabilities = torch.nn.functional.log_softmax(outputs[1])
+    def step(self, log_probabilities):
+        """ Grows the beam by one step. """
+        self._step += 1

-            # The batch size changes as some beams finish so we define _B
         vocab_size = log_probabilities.size(-1)
+        # The batch size changes as some beams finish so we define:
         _B = log_probabilities.size(0) // self.beam_size

         # Multiply each beam probability with the probability of the
         # next token (conditioned on the words in the beam).
-        log_probabilities += topk_log_probabilities.view(-1, 1)
+        log_probabilities += self.topk_log_probabilities.view(-1, 1)

-        # if the beam has not attained the minimum required length we
-        # make the end token arbitrarily unlikely.
-        if step < self.min_length:
-            log_probabilities[self.end_token_id] = -1e20
-
-        # Remove repeating tri-grams
-        if self.args.block_trigram:
-            if step + 1 > 3:
-                for i in range(_B * self.beam_size):
-                    tokens = [t for t in growing_beam[i]]
-                    trigrams = [
-                        (tokens[i - 1], tokens[i], tokens[i + 1])
-                        for i in range(1, len(words) - 1)
-                    ]
-                    last_trigram = tuple(trigrams[-1])
-                    if last_trigram in trigrams[:-1]:
-                        log_probabilities[i] = -1e20
+        self.enforce_min_length(log_probabilities)
+        if self.block_repeating_trigram:
+            self.remove_repeating_trigrams(log_probabilities, _B)

         # Find the `beam_size` (previous_beam + token) combinations with
         # the highest score
         topk_log_probabilities, topk_ids = log_probabilities.topk(
             log_probabilities.view(_B, self.beam_size * vocab_size),
             self.beam_size,
-            dim=1
+            dim=1,
         )

         # Apply the length penalty. The +1 accounts for the [EOS] token
         # that will be added if the beam ends.
-        length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
-        topk_scores = topk_log_probabilities / length_penalty
+        topk_scores = topk_log_probabilities / self.length_penalty()

         # Retrieve the corresponding respective beam and token id
         # topk_token_ids[i] will be added to topk_beam_ids[i]
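The key idea in the top-k step above is the view over beam_size * vocab_size: the scores of every (beam, next token) pair belonging to one batch element are ranked together, so the returned flat indices encode both which beam survives and which token it emits. A toy sketch of that decomposition with beam_size 2 and vocab_size 3 (numbers are made up; this is an illustration, not the file's code):

    import torch

    beam_size, vocab_size, _B = 2, 3, 1            # one batch element still alive
    log_probabilities = torch.tensor([[0.0, -1.0, -2.0],     # beam 0
                                      [-0.5, -0.1, -3.0]])   # beam 1
    flat = log_probabilities.view(_B, beam_size * vocab_size)
    topk_log_probabilities, topk_ids = flat.topk(beam_size, dim=1)
    topk_beam_ids = topk_ids // vocab_size   # which beam each winner comes from
    topk_token_ids = topk_ids % vocab_size   # which token it appends
    # topk_ids == tensor([[0, 4]]) -> beams (0, 1), tokens (0, 1)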
@@ -147,14 +112,14 @@ class ModelWithBeamSearch(nn.Module):
         # Retrieve the row index of the surviving beams in the original
         # view of the log_probabilities tensor
-        surviving_beams_rows = (
-            topk_beam_ids + beam_offset[:_B].view(-1, 1)
-        ).view(-1)
+        surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view(-1)

         # Append the last predictions
-        growing_beam = torch.cat(
+        self.growing_beam = torch.cat(
             [
-                growing_beam.index_select(0, surviving_beams_rows),
+                self.growing_beam.index_select(0, surviving_beams_rows),
                 topk_token_ids.view(-1, 1),
             ],
             1,
@@ -164,51 +129,117 @@ class ModelWithBeamSearch(nn.Module):
         # growth step. Also if top beam (most probable) has ended
         # for one element of the batch.
         is_finished = topk_token_ids.eq(self.end_token_id)
-        if step + 1 == self.max_length:
-            is_finished.fill_(1)
+        self.enforce_max_length()
         is_top_beam_finished = is_finished[:, 0].eq(1)

         # Save the finished searches
         if is_finished.any():
-            predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1))
+            predictions = self.growing_beam.view(
+                -1, self.beam_size, self.growing_beam.size(1)
+            )
             for i in range(is_finished.size(0)):
                 if is_top_beam_finished[i]:
                     is_finished[i].fill_(1)
                 finished_hyp = is_finished[i].nonzero().view(-1)

                 # Store finished hypotheses for this batch.
-                b = batch_offset[i]
+                b = self.batch_offset[i]
                 for j in finished_hyp:
-                    hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
+                    self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))

                 # If the batch reached the end, save the best hypotheses
                 # in terms of length-penalized score.
                 if is_top_beam_finished[i]:
                     best_hyp = sorted(
-                        hypotheses[b], key=lambda x: x[0], reverse=True
+                        self.hypotheses[b], key=lambda x: x[0], reverse=True
                     )
                     best_score, best_prediction = best_hyp[0]
-                    results["scores"][b].append(best_score)
-                    results["predictions"][b].append(best_prediction)
+                    self.results["scores"][b].append(best_score)
+                    self.results["predictions"][b].append(best_prediction)

             non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
             if len(non_finished) == 0:
-                break
+                self.is_done = True

             # Remove finished batches for the next step.
-            topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
-            batch_offset = batch_offset.index_select(0, non_finished)
-            growing_beam = predictions.index_select(0, non_finished).view(
-                -1, growing_beam.size(-1)
-            )
+            topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
+            self.batch_offset = self.batch_offset.index_select(0, non_finished)
+            self.growing_beam = predictions.index_select(0, non_finished).view(
+                -1, self.growing_beam.size(-1)
+            )

         # Re-order the state for the next pass
         surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
+        return surviving_beams_rows
+
+    def forward(self, encoder_input_ids, **kwargs):
+        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
+        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
+        # that apply to the model as a whole.
+        # We let the specific kwargs override the common ones in case of conflict.
+        kwargs_encoder = {
+            argument[len("encoder_"):]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("encoder_")
+        }
+        kwargs_decoder = {
+            argument[len("decoder_"):]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("decoder_")
+        }
+        kwargs_common = {
+            argument: value
+            for argument, value in kwargs.items()
+            if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
+        }
+        kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
+        kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
+
+        # forward pass on the encoder
+        encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
+        kwargs_decoder["encoder_hidden_states"] = tile(encoder_outputs, self.beam_size, dim=0)
+
+        # grow the beam by generating sequences in an autoregressive way
+        self.growing_beam = torch.full(
+            (self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
+        )
+        for step in range(self.max_length):
+            decoder_input = self.growing_beam[:, -1]
+            outputs = self.model.decoder(decoder_input, kwargs_decoder)
+            log_probabilities = torch.nn.functional.log_softmax(outputs[1])
+            surviving_beams_rows = self.step(log_probabilities)
+            if self.is_done:
+                break
             kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
                 "encoder_hidden_states"
             ].index_select(0, surviving_beams_rows)

-        return results
+        return self.results
+
+    def remove_repeating_trigrams(self, log_probabilities, _B):
+        if self._step + 1 > 3:
+            for i in range(_B * self.beam_size):
+                tokens = [t for t in self.growing_beam[i]]
+                trigrams = [
+                    (tokens[i - 1], tokens[i], tokens[i + 1])
+                    for i in range(1, len(words) - 1)
+                ]
+                last_trigram = tuple(trigrams[-1])
+                if last_trigram in trigrams[:-1]:
+                    log_probabilities[i] = -1e20
+
+    def enforce_min_length(self, log_probabilities):
+        if self._step < self.min_length:
+            self.log_probabilities[self.end_token_id] = -1e20
+
+    def enforce_max_length(self):
+        if self._step + 1 == self.max_length:
+            self.is_finished.fill_(1)
+
+    def length_penalty(self):
+        return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha


 def tile(x, count, dim=0):
     ...
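The body of tile is collapsed in this view of the diff. Its role in the forward pass above is to repeat each encoder state beam_size times along the batch dimension so that every beam works on its own copy of the encoder output. A behaviourally equivalent sketch (not the file's actual implementation):

    import torch

    def tile_sketch(x, count, dim=0):
        # repeat each slice of `x` `count` times along `dim`:
        # a batch of size B becomes B * count, ordered [x0, x0, ..., x1, x1, ...]
        return torch.repeat_interleave(x, count, dim=dim)

    tile_sketch(torch.tensor([[1, 2], [3, 4]]), 2, dim=0)
    # tensor([[1, 2], [1, 2], [3, 4], [3, 4]])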
transformers/modeling_bert.py (view file @ dfce4096)

@@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel):
         """
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
+        if encoder_attention_mask is None:
+            encoder_attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
             token_type_ids = torch.zeros_like(input_ids)
@@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel):
         extended_attention_mask = extended_attention_mask.to(
             dtype=next(self.parameters()).dtype
         )  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

-        # If a 2D encoder attention mask is provided for the cross-attention
+        # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if encoder_attention_mask is not None:
-            encoder_attention_mask = encoder_attention_mask[:, None, None, :]
-            encoder_attention_mask = encoder_attention_mask.to(
-                dtype=next(self.parameters()).dtype
-            )  # fp16 compatibility
-            encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
+            if encoder_attention_mask.dim() == 3:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+            if encoder_attention_mask.dim() == 2:
+                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+
+            encoder_extended_attention_mask = encoder_extended_attention_mask.to(
+                dtype=next(self.parameters()).dtype
+            )  # fp16 compatibility
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
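As a concrete illustration of the additive mask convention used here (a sketch, not code from the diff): a 2D padding mask of shape [batch, src_len] becomes a [batch, 1, 1, src_len] tensor whose masked positions carry a large negative value, which is added to the attention scores before the softmax so the padded positions receive essentially zero attention.

    import torch

    encoder_attention_mask = torch.tensor([[1, 1, 1, 0]])          # last source position is padding
    extended = encoder_attention_mask[:, None, None, :].float()    # [batch, 1, 1, src_len]
    extended = (1.0 - extended) * -10000.0
    # tensor([[[[-0., -0., -0., -10000.]]]]) -> padded position is effectively ignored by the softmax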
@@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel):
             attention_mask=extended_attention_mask,
             head_mask=head_mask,
             encoder_hidden_states=encoder_hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
+            encoder_attention_mask=encoder_extended_attention_mask,
         )
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
@@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel):
             in ``[0, ..., config.vocab_size]``

     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+        **masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Masked language modeling loss.
+        **next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Next token prediction loss.
         **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
@@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel):
         if lm_labels is not None:
             # we are doing next-token prediction; shift prediction scores and input ids by one
-            prediction_scores = prediction_scores[:, :-1, :]
-            lm_labels = lm_labels[:, 1:]
+            prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            lm_labels = lm_labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            seq2seq_loss = loss_fct(
-                prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1)
-            )
-            outputs = (seq2seq_loss,) + outputs
+            next_token_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)
+            )
+            outputs = (next_token_loss,) + outputs

-        return outputs  # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
+        return outputs  # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions)


 @add_start_docstrings(
     """Bert Model with a `next sentence prediction (classification)` head on top. """,
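The shift-by-one above aligns the score at position t with the label token at position t+1, which is what next-token prediction requires. A toy sketch of that alignment (shapes and values are illustrative only):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size = 7
    prediction_scores = torch.randn(1, 4, vocab_size)   # scores for positions 0..3
    lm_labels = torch.tensor([[2, 5, 5, -1]])            # -1 marks padding

    scores = prediction_scores[:, :-1, :].contiguous()   # predict tokens 1..3 from positions 0..2
    labels = lm_labels[:, 1:].contiguous()
    loss = CrossEntropyLoss(ignore_index=-1)(scores.view(-1, vocab_size), labels.view(-1))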
transformers/modeling_seq2seq.py (view file @ dfce4096)

@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 class PreTrainedSeq2seq(nn.Module):
     r"""
-        :class:`~transformers.Seq2seq` is a generic model class that will be
+        :class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
         instantiated as a Seq2seq model with one of the base model classes of
         the library as encoder and (optionally) as decoder when created with
         the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class

@@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module):
         *model_args,
         **kwargs
     ):
-        r""" Instantiates an encoder and a decoder from one or two base classes
-        of the library from pre-trained model checkpoints.
+        r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.

         The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
@@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module):
             model = PreTrainedSeq2seq.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
         """

-        # Separate the encoder- and decoder- specific kwargs. A kwarg is
-        # decoder-specific if the key starts with `decoder_`
+        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
+        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
+        # that apply to the model as a whole.
+        # We let the specific kwargs override the common ones in case of conflict.
         kwargs_encoder = {
-            argument: value
+            argument[len("encoder_"):]: value
             for argument, value in kwargs.items()
-            if not argument.startswith("decoder_")
+            if argument.startswith("encoder_")
         }
         kwargs_decoder = {
             argument[len("decoder_"):]: value
             for argument, value in kwargs.items()
             if argument.startswith("decoder_")
         }
+        kwargs_common = {
+            argument: value
+            for argument, value in kwargs.items()
+            if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
+        }
+        kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
+        kwargs_encoder = dict(kwargs_common, **kwargs_encoder)

         # Load and initialize the encoder and decoder
         # The distinction between encoder and decoder at the model level is made
         # by the value of the flag `is_decoder` that we need to set correctly.
-        encoder = kwargs_encoder.pop("encoder_model", None)
+        encoder = kwargs_encoder.pop("model", None)
         if encoder is None:
-            kwargs_encoder["is_decoder"] = False
             encoder = AutoModel.from_pretrained(
                 encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
             )
+        encoder.config.is_decoder = False

         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
-            kwargs_decoder["is_decoder"] = True
             decoder = AutoModelWithLMHead.from_pretrained(
                 decoder_pretrained_model_name_or_path, **kwargs_decoder
             )
+        decoder.config.is_decoder = True

         model = cls(encoder, decoder)
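The effect of the three-way kwargs split is easiest to see on a concrete dictionary (an illustrative sketch; the prefixed key names are the ones the example script above passes, and the values are dummy strings):

    kwargs = {
        "encoder_token_type_ids": "ett",    # goes to the encoder only, prefix stripped
        "decoder_lm_labels": "lab",         # goes to the decoder only, prefix stripped
        "attention_mask": "mask",           # common: copied to both
    }
    kwargs_encoder = {k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")}
    kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
    kwargs_common = {k: v for k, v in kwargs.items()
                     if not (k.startswith("encoder_") or k.startswith("decoder_"))}
    kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
    kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
    # kwargs_encoder == {"attention_mask": "mask", "token_type_ids": "ett"}
    # kwargs_decoder == {"attention_mask": "mask", "lm_labels": "lab"}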
@@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module):
             decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
                 Indices of decoder input sequence tokens in the vocabulary.
         """
-        # Separate the encoder- and decoder- specific kwargs. A kwarg is
-        # decoder-specific if the key starts with `decoder_`
+        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
+        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
+        # that apply to the model as a whole.
+        # We let the specific kwargs override the common ones in case of conflict.
         kwargs_encoder = {
-            argument: value
+            argument[len("encoder_"):]: value
             for argument, value in kwargs.items()
-            if not argument.startswith("decoder_")
+            if argument.startswith("encoder_")
         }
         kwargs_decoder = {
             argument[len("decoder_"):]: value
             for argument, value in kwargs.items()
             if argument.startswith("decoder_")
         }
+        kwargs_common = {
+            argument: value
+            for argument, value in kwargs.items()
+            if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
+        }
+        kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
+        kwargs_encoder = dict(kwargs_common, **kwargs_encoder)

         # Encode if needed (training, first prediction pass)
-        encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
+        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[0][-1]  # output of the encoder *stack*
+            encoder_hidden_states = encoder_outputs[0]  # output the last layer hidden state
         else:
             encoder_outputs = ()

         # Decode
-        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
         kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
         decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)

         return decoder_outputs + encoder_outputs


 class Model2Model(PreTrainedSeq2seq):
     r"""
-        :class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
+        :class:`~transformers.Model2Model` instantiates a Seq2Seq model
         where both the encoder and decoder are of the same family. If the
         name of or the path to a pretrained model is specified, the encoder and
         the decoder will be initialized with the pretrained weights (the
         cross-attention will be initialized randomly if its weights are not
         present).

         It is possible to override this behavior and initialize, say, the decoder randomly
         by creating it beforehand as follows

             config = BertConfig.from_pretrained()
             decoder = BertForMaskedLM(config)
             model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
     """

     def __init__(self, *args, **kwargs):
         super(Model2Model, self).__init__(*args, **kwargs)
         self.tie_weights()
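Putting the class docstring together with the changes to the example script above, a minimal end-to-end usage sketch (the checkpoint name and the input tensors are illustrative, not part of the commit):

    import torch
    from transformers import BertConfig, BertForMaskedLM, BertTokenizer, Model2Model

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # decoder built from scratch, encoder loaded from the checkpoint
    decoder = BertForMaskedLM(BertConfig.from_pretrained("bert-base-uncased"))
    model = Model2Model.from_pretrained("bert-base-uncased", decoder_model=decoder)

    source = torch.tensor([tokenizer.encode("a very long article to condense")])
    target = torch.tensor([tokenizer.encode("a short summary")])
    outputs = model(source, target, decoder_lm_labels=target)
    loss = outputs[0]

This replaces the decoder_initialize_randomly flag removed in the next hunk: instead of asking from_pretrained to re-initialize the decoder, the caller constructs a fresh decoder and passes it in via decoder_model.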
@@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq):
         model = super(Model2Model, cls).from_pretrained(
             encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
             decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
             *args,
             **kwargs
         )
-
-        # Some architectures require for the decoder to be initialized randomly
-        # before fine-tuning.
-        if kwargs.get("decoder_initialize_randomly", False):
-            model.decoder.init_weights()

         return model