Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4c3ac4a7
Commit
4c3ac4a7
authored
Oct 18, 2019
by
Rémi Louf
Browse files
here's one big commit
parent
932543f7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
951 additions
and
47 deletions
+951
-47
examples/README.md
examples/README.md
+3
-2
examples/run_summarization_finetuning.py
examples/run_summarization_finetuning.py
+620
-0
examples/run_summarization_finetuning_test.py
examples/run_summarization_finetuning_test.py
+14
-14
transformers/__init__.py
transformers/__init__.py
+1
-1
transformers/modeling_beam_search.py
transformers/modeling_beam_search.py
+240
-0
transformers/modeling_bert.py
transformers/modeling_bert.py
+12
-8
transformers/modeling_seq2seq.py
transformers/modeling_seq2seq.py
+61
-22
No files found.
examples/README.md
View file @
4c3ac4a7
...
@@ -393,7 +393,8 @@ This fine-tuned model is available as a checkpoint under the reference
...
@@ -393,7 +393,8 @@ This fine-tuned model is available as a checkpoint under the reference
## Seq2seq model fine-tuning
## Seq2seq model fine-tuning
Based on the script
[
`run_seq2seq_finetuning.py`
](
https://github.com/huggingface/transformers/blob/master/examples/run_seq2seq_finetuning.py
)
.
Based on the script
[
`run_summarization_finetuning.py`
](
https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py
)
.
Before running this script you should download
**both**
CNN and Daily Mail
Before running this script you should download
**both**
CNN and Daily Mail
datasets from
[
Kyunghyun Cho's website
](
https://cs.nyu.edu/~kcho/DMQA/
)
(
the
datasets from
[
Kyunghyun Cho's website
](
https://cs.nyu.edu/~kcho/DMQA/
)
(
the
...
@@ -412,7 +413,7 @@ archive.
...
@@ -412,7 +413,7 @@ archive.
```
bash
```
bash
export
DATA_PATH
=
/path/to/dataset/
export
DATA_PATH
=
/path/to/dataset/
python run_s
eq2seq
_finetuning.py
\
python run_s
ummarization
_finetuning.py
\
--output_dir
=
output
\
--output_dir
=
output
\
--model_type
=
bert2bert
\
--model_type
=
bert2bert
\
--model_name_or_path
=
bert2bert
\
--model_name_or_path
=
bert2bert
\
...
...
examples/run_s
eq2seq
_finetuning.py
→
examples/run_s
ummarization
_finetuning.py
View file @
4c3ac4a7
# coding=utf-8
# coding=utf-8
# Copyright 201
8 The Microsoft Reseach team and
The HuggingFace Inc. team.
# Copyright 201
9
The HuggingFace Inc. team.
# Copyright (c) 201
8 Microsoft and
The HuggingFace Inc. All rights reserved.
# Copyright (c) 201
9
The HuggingFace Inc. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -18,18 +18,21 @@
...
@@ -18,18 +18,21 @@
import
argparse
import
argparse
from
collections
import
deque
from
collections
import
deque
import
logging
import
logging
import
os
import
pickle
import
pickle
import
random
import
random
import
o
s
import
sy
s
import
numpy
as
np
import
numpy
as
np
from
tqdm
import
tqdm
,
trange
from
tqdm
import
tqdm
,
trange
import
torch
import
torch
from
torch.utils.data
import
Dataset
,
RandomSampler
from
torch.optim
import
Adam
from
torch.utils.data
import
Dataset
,
DataLoader
,
RandomSampler
,
SequentialSampler
from
transformers
import
AutoTokenizer
,
Model2Model
from
transformers
import
AutoTokenizer
,
PreTrainedSeq2seq
,
Model2Model
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
logging
.
basicConfig
(
stream
=
sys
.
stdout
,
level
=
logging
.
INFO
)
def
set_seed
(
args
):
def
set_seed
(
args
):
...
@@ -61,7 +64,7 @@ class TextDataset(Dataset):
...
@@ -61,7 +64,7 @@ class TextDataset(Dataset):
def
__init__
(
self
,
tokenizer
,
prefix
=
"train"
,
data_dir
=
""
,
block_size
=
512
):
def
__init__
(
self
,
tokenizer
,
prefix
=
"train"
,
data_dir
=
""
,
block_size
=
512
):
assert
os
.
path
.
isdir
(
data_dir
)
assert
os
.
path
.
isdir
(
data_dir
)
# Load features that have already been computed if
present
# Load
the
features that have already been computed
,
if
any
cached_features_file
=
os
.
path
.
join
(
cached_features_file
=
os
.
path
.
join
(
data_dir
,
"cached_lm_{}_{}"
.
format
(
block_size
,
prefix
)
data_dir
,
"cached_lm_{}_{}"
.
format
(
block_size
,
prefix
)
)
)
...
@@ -72,12 +75,11 @@ class TextDataset(Dataset):
...
@@ -72,12 +75,11 @@ class TextDataset(Dataset):
return
return
logger
.
info
(
"Creating features from dataset at %s"
,
data_dir
)
logger
.
info
(
"Creating features from dataset at %s"
,
data_dir
)
self
.
examples
=
[]
datasets
=
[
"cnn"
,
"dailymail"
]
datasets
=
[
"cnn"
,
"dailymail"
]
self
.
examples
=
{
"source"
:
[],
"target"
:
[]}
for
dataset
in
datasets
:
for
dataset
in
datasets
:
path_to_stories
=
os
.
path
.
join
(
data_dir
,
dataset
,
"stories"
)
path_to_stories
=
os
.
path
.
join
(
data_dir
,
dataset
,
"stories"
)
assert
os
.
path
.
isdir
(
path_to_stories
)
story_filenames_list
=
os
.
listdir
(
path_to_stories
)
story_filenames_list
=
os
.
listdir
(
path_to_stories
)
for
story_filename
in
story_filenames_list
:
for
story_filename
in
story_filenames_list
:
path_to_story
=
os
.
path
.
join
(
path_to_stories
,
story_filename
)
path_to_story
=
os
.
path
.
join
(
path_to_stories
,
story_filename
)
...
@@ -85,19 +87,19 @@ class TextDataset(Dataset):
...
@@ -85,19 +87,19 @@ class TextDataset(Dataset):
continue
continue
with
open
(
path_to_story
,
encoding
=
"utf-8"
)
as
source
:
with
open
(
path_to_story
,
encoding
=
"utf-8"
)
as
source
:
try
:
raw_story
=
source
.
read
()
raw_story
=
source
.
read
()
story_lines
,
summary_lines
=
process_story
(
raw_story
)
story
,
summary
=
process_story
(
raw_story
)
if
len
(
summary_lines
)
==
0
or
len
(
story_lines
)
==
0
:
except
IndexError
:
# skip ill-formed stories
continue
continue
story
=
token
izer
.
encode
(
story
)
story
_
token
_ids
,
summary_token_ids
=
_encode_for_summarization
(
story_
seq
=
_fit_to_block_size
(
story
,
block_s
ize
)
story_
lines
,
summary_lines
,
token
ize
r
)
s
ummary
=
tokenizer
.
encode
(
summary
)
s
tory_seq
=
_fit_to_block_size
(
story_token_ids
,
block_size
)
s
ummary_seq
=
_fit_to_block_size
(
summary
,
block_size
)
s
elf
.
examples
[
"source"
].
append
(
story_seq
)
self
.
examples
.
append
((
story_seq
,
summary_seq
))
summary_seq
=
_fit_to_block_size
(
summary_token_ids
,
block_size
)
self
.
examples
[
"summary"
].
append
(
summary_seq
)
logger
.
info
(
"Saving features into cache file %s"
,
cached_features_file
)
logger
.
info
(
"Saving features into cache file %s"
,
cached_features_file
)
with
open
(
cached_features_file
,
"wb"
)
as
sink
:
with
open
(
cached_features_file
,
"wb"
)
as
sink
:
...
@@ -107,7 +109,10 @@ class TextDataset(Dataset):
...
@@ -107,7 +109,10 @@ class TextDataset(Dataset):
return
len
(
self
.
examples
)
return
len
(
self
.
examples
)
def
__getitem__
(
self
,
items
):
def
__getitem__
(
self
,
items
):
return
torch
.
tensor
(
self
.
examples
[
items
])
return
(
torch
.
tensor
(
self
.
examples
[
"source"
][
items
]),
torch
.
tensor
(
self
.
examples
[
"target"
][
items
]),
)
def
process_story
(
raw_story
):
def
process_story
(
raw_story
):
...
@@ -119,33 +124,55 @@ def process_story(raw_story):
...
@@ -119,33 +124,55 @@ def process_story(raw_story):
Raises:
Raises:
IndexError: If the stoy is empty or contains no highlights.
IndexError: If the stoy is empty or contains no highlights.
"""
"""
file
_lines
=
list
(
nonempty
_lines
=
list
(
filter
(
lambda
x
:
len
(
x
)
!=
0
,
[
line
.
strip
()
for
line
in
raw_story
.
split
(
"
\n
"
)])
filter
(
lambda
x
:
len
(
x
)
!=
0
,
[
line
.
strip
()
for
line
in
raw_story
.
split
(
"
\n
"
)])
)
)
# for some unknown reason some lines miss a period, add it
# for some unknown reason some lines miss a period, add it
file
_lines
=
[
_add_missing_period
(
line
)
for
line
in
file
_lines
]
nonempty
_lines
=
[
_add_missing_period
(
line
)
for
line
in
nonempty
_lines
]
# gather article lines
# gather article lines
story_lines
=
[]
story_lines
=
[]
lines
=
deque
(
file
_lines
)
lines
=
deque
(
nonempty
_lines
)
while
True
:
while
True
:
try
:
try
:
element
=
lines
.
popleft
()
element
=
lines
.
popleft
()
if
element
.
startswith
(
"@highlight"
):
if
element
.
startswith
(
"@highlight"
):
break
break
story_lines
.
append
(
element
)
story_lines
.
append
(
element
)
except
IndexError
as
ie
:
# if "@highlight" absent from file
except
IndexError
:
raise
ie
# if "@highlight" is absent from the file we pop
# all elements until there is None.
return
story_lines
,
[]
# gather summary lines
# gather summary lines
highlights
_lines
=
list
(
filter
(
lambda
t
:
not
t
.
startswith
(
"@highlight"
),
lines
))
summary
_lines
=
list
(
filter
(
lambda
t
:
not
t
.
startswith
(
"@highlight"
),
lines
))
# join the lines
return
story_lines
,
summary_lines
story
=
" "
.
join
(
story_lines
)
summary
=
" "
.
join
(
highlights_lines
)
return
story
,
summary
def
_encode_for_summarization
(
story_lines
,
summary_lines
,
tokenizer
):
""" Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids
=
[
tokenizer
.
add_special_tokens_single_sequence
(
tokenizer
.
encode
(
line
))
for
line
in
story_lines
]
summary_lines_token_ids
=
[
tokenizer
.
add_special_tokens_single_sequence
(
tokenizer
.
encode
(
line
))
for
line
in
summary_lines
]
story_token_ids
=
[
token
for
sentence
in
story_lines_token_ids
for
token
in
sentence
]
summary_token_ids
=
[
token
for
sentence
in
summary_lines_token_ids
for
token
in
sentence
]
return
story_token_ids
,
summary_token_ids
def
_add_missing_period
(
line
):
def
_add_missing_period
(
line
):
...
@@ -170,8 +197,11 @@ def _fit_to_block_size(sequence, block_size):
...
@@ -170,8 +197,11 @@ def _fit_to_block_size(sequence, block_size):
def
mask_padding_tokens
(
sequence
):
def
mask_padding_tokens
(
sequence
):
""" Replace the padding token with -1 values """
""" Padding token, encoded as 0, are represented by the value -1 in the
return
[
s
if
s
!=
0
else
-
1
for
s
in
sequence
]
masks """
padded
=
sequence
.
clone
()
padded
[
padded
==
0
]
=
-
1
return
padded
def
load_and_cache_examples
(
args
,
tokenizer
):
def
load_and_cache_examples
(
args
,
tokenizer
):
...
@@ -179,81 +209,181 @@ def load_and_cache_examples(args, tokenizer):
...
@@ -179,81 +209,181 @@ def load_and_cache_examples(args, tokenizer):
return
dataset
return
dataset
def
compute_token_type_ids
(
batch
,
separator_token_id
):
""" Segment embeddings as described in [1]
The values {0,1} were found in the repository [2].
Attributes:
batch: torch.Tensor, size [batch_size, block_size]
Batch of input.
separator_token_id: int
The value of the token that separates the segments.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
"""
batch_embeddings
=
[]
sentence_num
=
0
for
sequence
in
batch
:
embeddings
=
[]
for
s
in
sequence
:
if
s
==
separator_token_id
:
sentence_num
+=
1
embeddings
.
append
(
sentence_num
%
2
)
batch_embeddings
.
append
(
embeddings
)
return
torch
.
tensor
(
batch_embeddings
)
# ----------
# Optimizers
# ----------
class
BertSumOptimizer
(
object
):
""" Specific optimizer for BertSum.
As described in [1], the authors fine-tune BertSum for abstractive
summarization using two Adam Optimizers with different warm-up steps and
learning rate. They also use a custom learning rate scheduler.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
"""
def
__init__
(
self
,
model
,
lr
,
warmup_steps
,
beta_1
=
0.99
,
beta_2
=
0.999
,
eps
=
1e-9
):
self
.
encoder
=
model
.
encoder
self
.
decoder
=
model
.
decoder
self
.
lr
=
lr
self
.
warmup_steps
=
warmup_steps
self
.
optimizers
=
{
"encoder"
:
Adam
(
model
.
encoder
.
parameters
(),
lr
=
lr
[
"encoder"
],
betas
=
(
beta_1
,
beta_2
),
eps
=
eps
,
),
"decoder"
:
Adam
(
model
.
decoder
.
parameters
(),
lr
=
lr
[
"decoder"
],
betas
=
(
beta_1
,
beta_2
),
eps
=
eps
,
),
}
self
.
_step
=
0
def
_update_rate
(
self
,
stack
):
return
self
.
lr
[
stack
]
*
min
(
self
.
_step
**
(
-
0.5
),
self
.
_step
*
self
.
warmup_steps
[
stack
]
**
(
-
0.5
)
)
def
zero_grad
(
self
):
self
.
optimizer_decoder
.
zero_grad
()
self
.
optimizer_encoder
.
zero_grad
()
def
step
(
self
):
self
.
_step
+=
1
for
stack
,
optimizer
in
self
.
optimizers
.
items
():
new_rate
=
self
.
_update_rate
(
stack
)
for
param_group
in
optimizer
.
param_groups
:
param_group
[
"lr"
]
=
new_rate
optimizer
.
step
()
# ------------
# ------------
# Train
# Train
# ------------
# ------------
def
train
(
args
,
train_dataset
,
model
,
tokenizer
):
def
train
(
args
,
model
,
tokenizer
):
""" Fine-tune the pretrained model on the corpus. """
""" Fine-tune the pretrained model on the corpus. """
set_seed
(
args
)
# Prepare the data loading
# Load the data
args
.
train_bach_size
=
1
args
.
train_batch_size
=
args
.
per_gpu_train_batch_size
*
max
(
1
,
args
.
n_gpu
)
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
)
train_sampler
=
RandomSampler
(
train_dataset
)
train_sampler
=
RandomSampler
(
train_dataset
)
train_dataloader
=
DataLoader
(
train_dataloader
=
DataLoader
(
train_dataset
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_bach_size
train_dataset
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_ba
t
ch_size
)
)
# Prepare the optimizer and schedule (linear warmup and decay)
# Training schedule
no_decay
=
[
"bias"
,
"LayerNorm.weight"
]
if
args
.
max_steps
>
0
:
optimizer_grouped_parameters
=
[
t_total
=
args
.
max_steps
{
args
.
num_train_epochs
=
t_total
//
(
"params"
:
[
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
+
1
p
)
for
n
,
p
in
model
.
named_parameters
()
else
:
if
not
any
(
nd
in
n
for
nd
in
no_decay
)
t_total
=
(
],
len
(
train_dataloader
)
"weight_decay"
:
args
.
weight_decay
,
//
args
.
gradient_accumulation_steps
},
*
args
.
num_train_epochs
{
)
"params"
:
[
p
# Prepare the optimizer
for
n
,
p
in
model
.
named_parameters
()
lr
=
{
"encoder"
:
0.002
,
"decoder"
:
0.2
}
if
any
(
nd
in
n
for
nd
in
no_decay
)
warmup_steps
=
{
"encoder"
:
20000
,
"decoder"
:
10000
}
],
optimizer
=
BertSumOptimizer
(
model
,
lr
,
warmup_steps
)
"weight_decay"
:
0.0
,
},
]
optimizer
=
AdamW
(
optimizer_grouped_parameters
,
lr
=
args
.
learning_rate
,
eps
=
args
.
adam_epsilon
)
scheduler
=
WarmupLinearSchedule
(
optimizer
,
warmup_steps
=
args
.
warmup_steps
,
t_total
=
t_total
)
# Train
# Train
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
" Num examples = %d"
,
len
(
train_dataset
))
logger
.
info
(
" Num examples = %d"
,
len
(
train_dataset
))
logger
.
info
(
" Num Epochs = %d"
,
args
.
num_train_epochs
)
logger
.
info
(
" Num Epochs = %d"
,
args
.
num_train_epochs
)
logger
.
info
(
" Instantaneous batch size per GPU = %d"
,
args
.
per_gpu_train_batch_size
)
logger
.
info
(
logger
.
info
(
" Total train batch size (w. parallel, distributed & accumulation) = %d"
,
" Instantaneous batch size per GPU = %d"
,
args
.
per_gpu_train_batch_size
args
.
train_batch_size
)
*
args
.
gradient_accumulation_steps
logger
.
info
(
*
(
torch
.
distributed
.
get_world_size
()
if
args
.
local_rank
!=
-
1
else
1
),
" Total train batch size (w. parallel, distributed & accumulation) = %d"
,
args
.
train_batch_size
*
args
.
gradient_accumulation_steps
# * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
)
logger
.
info
(
" Gradient Accumulation steps = %d"
,
args
.
gradient_accumulation_steps
)
logger
.
info
(
" Gradient Accumulation steps = %d"
,
args
.
gradient_accumulation_steps
)
logger
.
info
(
" Total optimization steps = %d"
,
t_total
)
logger
.
info
(
" Total optimization steps = %d"
,
t_total
)
global_step
=
0
tr_loss
,
logging_loss
=
0.0
,
0.0
model
.
zero_grad
()
model
.
zero_grad
()
train_iterator
=
trange
(
args
.
num_train_epochs
,
desc
=
"Epoch"
,
disable
=
True
)
train_iterator
=
trange
(
args
.
num_train_epochs
,
desc
=
"Epoch"
,
disable
=
True
)
set_seed
(
args
)
global_step
=
0
tr_loss
=
0.0
for
_
in
train_iterator
:
for
_
in
train_iterator
:
epoch_iterator
=
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
True
)
epoch_iterator
=
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
True
)
for
step
,
batch
in
enumerate
(
epoch_iterator
):
for
step
,
batch
in
enumerate
(
epoch_iterator
):
source
=
([
s
for
s
,
_
in
batch
]).
to
(
args
.
device
)
source
,
target
=
batch
target
=
([
t
for
_
,
t
in
batch
]).
to
(
args
.
device
)
token_type_ids
=
compute_token_type_ids
(
source
,
tokenizer
.
cls_token_id
)
labels_src
=
mask_padding_tokens
(
source
)
labels_tgt
=
mask_padding_tokens
(
target
)
source
=
source
.
to
(
args
.
device
)
target
=
target
.
to
(
args
.
device
)
token_type_ids
=
token_type_ids
.
to
(
args
.
device
)
labels_src
=
labels_src
.
to
(
args
.
device
)
labels_tgt
=
labels_tgt
.
to
(
args
.
device
)
model
.
train
()
model
.
train
()
outputs
=
model
(
source
,
target
,
decoder_lm_labels
=
mask_padding_tokens
(
target
))
outputs
=
model
(
source
,
target
,
token_type_ids
=
token_type_ids
,
decoder_encoder_attention_mask
=
labels_src
,
decoder_attention_mask
=
labels_tgt
,
decoder_lm_labels
=
labels_tgt
,
decoder_initialize_randomly
=
True
,
)
loss
=
outputs
[
0
]
loss
=
outputs
[
0
]
print
(
loss
)
if
args
.
gradient_accumulation_steps
>
1
:
loss
/=
args
.
gradient_accumulation_steps
loss
.
backward
()
loss
.
backward
()
tr_loss
+=
loss
.
item
()
tr_loss
+=
loss
.
item
()
if
(
step
+
1
)
%
args
.
gradient_accumulation_steps
==
0
:
if
(
step
+
1
)
%
args
.
gradient_accumulation_steps
==
0
:
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
args
.
max_grad_norm
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
args
.
max_grad_norm
)
optimizer
.
step
()
optimizer
.
step
()
scheduler
.
step
()
model
.
zero_grad
()
model
.
zero_grad
()
global_step
+=
1
global_step
+=
1
...
@@ -268,6 +398,68 @@ def train(args, train_dataset, model, tokenizer):
...
@@ -268,6 +398,68 @@ def train(args, train_dataset, model, tokenizer):
return
global_step
,
tr_loss
/
global_step
return
global_step
,
tr_loss
/
global_step
# ------------
# Train
# ------------
def
evaluate
(
args
,
model
,
tokenizer
,
prefix
=
""
):
set_seed
(
args
)
args
.
eval_batch_size
=
args
.
per_gpu_eval_batch_size
*
max
(
1
,
args
.
n_gpu
)
eval_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
evaluate
=
True
)
eval_sampler
=
SequentialSampler
(
eval_dataset
)
eval_dataloader
=
DataLoader
(
eval_dataset
,
sampler
=
eval_sampler
,
batch_size
=
args
.
eval_batch_size
)
logger
.
info
(
"***** Running evaluation {} *****"
.
format
(
prefix
))
logger
.
info
(
" Num examples = %d"
,
len
(
eval_dataset
))
logger
.
info
(
" Batch size = %d"
,
args
.
eval_batch_size
)
eval_loss
=
0.0
nb_eval_steps
=
0
model
.
eval
()
for
batch
in
tqdm
(
eval_dataloader
,
desc
=
"Evaluating"
):
source
,
target
=
batch
labels_src
=
mask_padding_tokens
(
source
)
labels_tgt
=
mask_padding_tokens
(
target
)
source
.
to
(
args
.
device
)
target
.
to
(
args
.
device
)
labels_src
.
to
(
args
.
device
)
labels_tgt
.
to
(
args
.
device
)
with
torch
.
no_grad
():
outputs
=
model
(
source
,
target
,
decoder_encoder_attention_mask
=
labels_src
,
decoder_attention_mask
=
labels_tgt
,
decoder_lm_labels
=
labels_tgt
,
)
lm_loss
=
outputs
[
0
]
eval_loss
+=
lm_loss
.
mean
().
item
()
nb_eval_steps
+=
1
eval_loss
=
eval_loss
/
nb_eval_steps
perplexity
=
torch
.
exp
(
torch
.
tensor
(
eval_loss
))
result
=
{
"perplexity"
:
perplexity
}
# Save the evaluation's results
output_eval_file
=
os
.
path
.
join
(
args
.
output_dir
,
"eval_results.txt"
)
if
not
os
.
path
.
exists
(
args
.
output_dir
):
os
.
makedirs
(
args
.
output_dir
)
with
open
(
output_eval_file
,
"w"
)
as
writer
:
logger
.
info
(
"***** Eval results {} *****"
.
format
(
prefix
))
for
key
in
sorted
(
result
.
keys
()):
logger
.
info
(
" %s = %s"
,
key
,
str
(
result
[
key
]))
writer
.
write
(
"%s = %s
\n
"
%
(
key
,
str
(
result
[
key
])))
return
result
def
main
():
def
main
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
...
@@ -289,7 +481,23 @@ def main():
...
@@ -289,7 +481,23 @@ def main():
# Optional parameters
# Optional parameters
parser
.
add_argument
(
parser
.
add_argument
(
"--adam_epsilon"
,
default
=
1e-8
,
type
=
float
,
help
=
"Epsilon for Adam optimizer."
"--gradient_accumulation_steps"
,
type
=
int
,
default
=
1
,
help
=
"Number of updates steps to accumulate before performing a backward/update pass."
,
)
parser
.
add_argument
(
"--do_evaluate"
,
type
=
bool
,
default
=
False
,
help
=
"Run model evaluation on out-of-sample data."
,
)
parser
.
add_argument
(
"--do_train"
,
type
=
bool
,
default
=
False
,
help
=
"Run training."
)
parser
.
add_argument
(
"--do_overwrite_output_dir"
,
type
=
bool
,
default
=
False
,
help
=
"Whether to overwrite the output dir."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--model_name_or_path"
,
"--model_name_or_path"
,
...
@@ -303,12 +511,6 @@ def main():
...
@@ -303,12 +511,6 @@ def main():
type
=
str
,
type
=
str
,
help
=
"The decoder architecture to be fine-tuned."
,
help
=
"The decoder architecture to be fine-tuned."
,
)
)
parser
.
add_argument
(
"--learning_rate"
,
default
=
5e-5
,
type
=
float
,
help
=
"The initial learning rate for Adam."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
"--max_grad_norm"
,
default
=
1.0
,
type
=
float
,
help
=
"Max gradient norm."
)
)
...
@@ -318,43 +520,100 @@ def main():
...
@@ -318,43 +520,100 @@ def main():
type
=
int
,
type
=
int
,
help
=
"If > 0: set total number of training steps to perform. Override num_train_epochs."
,
help
=
"If > 0: set total number of training steps to perform. Override num_train_epochs."
,
)
)
parser
.
add_argument
(
"--to_cpu"
,
default
=
False
,
type
=
bool
,
help
=
"Whether to force training on CPU."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num_train_epochs"
,
"--num_train_epochs"
,
default
=
1
,
default
=
1
,
type
=
int
,
type
=
int
,
help
=
"Total number of training epochs to perform."
,
help
=
"Total number of training epochs to perform."
,
)
)
parser
.
add_argument
(
"--seed"
,
default
=
42
,
type
=
int
)
parser
.
add_argument
(
"--warmup_steps"
,
default
=
0
,
type
=
int
,
help
=
"Linear warmup over warmup_steps."
)
parser
.
add_argument
(
parser
.
add_argument
(
"--weight_decay"
,
default
=
0.0
,
type
=
float
,
help
=
"Weight deay if we apply some."
"--per_gpu_train_batch_size"
,
default
=
4
,
type
=
int
,
help
=
"Batch size per GPU/CPU for training."
,
)
)
parser
.
add_argument
(
"--seed"
,
default
=
42
,
type
=
int
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
model_type
!=
"bert"
:
if
(
os
.
path
.
exists
(
args
.
output_dir
)
and
os
.
listdir
(
args
.
output_dir
)
and
args
.
do_train
and
not
args
.
do_overwrite_output_dir
):
raise
ValueError
(
raise
ValueError
(
"Only the BERT architecture is currently supported for seq2seq."
"Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite."
.
format
(
args
.
output_dir
)
)
)
# Set up training device
# Set up training device
# device = torch.device("cpu")
if
args
.
to_cpu
or
not
torch
.
cuda
.
is_available
():
args
.
device
=
torch
.
device
(
"cpu"
)
# Set seed
args
.
n_gpu
=
0
set_seed
(
args
)
else
:
args
.
device
=
torch
.
device
(
"cuda"
)
args
.
n_gpu
=
torch
.
cuda
.
device_count
()
# Load pretrained model and tokenizer
# Load pretrained model and tokenizer
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
Model2Model
.
from_pretrained
(
args
.
model_name_or_path
)
model
=
Model2Model
.
from_pretrained
(
args
.
model_name_or_path
)
# model.to(device)
# Setup logging
logging
.
basicConfig
(
format
=
"%(asctime)s - %(levelname)s - %(name)s - %(message)s"
,
datefmt
=
"%m/%d/%Y %H:%M:%S"
,
level
=
logging
.
INFO
,
)
logger
.
warning
(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s"
,
0
,
args
.
device
,
args
.
n_gpu
,
False
,
False
,
)
logger
.
info
(
"Training/evaluation parameters %s"
,
args
)
logger
.
info
(
"Training/evaluation parameters %s"
,
args
)
# Training
# Train the model
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
)
model
.
to
(
args
.
device
)
global_step
,
tr_loss
=
train
(
args
,
train_dataset
,
model
,
tokenizer
)
if
args
.
do_train
:
# logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
global_step
,
tr_loss
=
train
(
args
,
model
,
tokenizer
)
logger
.
info
(
" global_step = %s, average loss = %s"
,
global_step
,
tr_loss
)
if
not
os
.
path
.
exists
(
args
.
output_dir
):
os
.
makedirs
(
args
.
output_dir
)
logger
.
info
(
"Saving model checkpoint to %s"
,
args
.
output_dir
)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save
=
(
model
.
module
if
hasattr
(
model
,
"module"
)
else
model
)
# Take care of distributed/parallel training
model_to_save
.
save_pretrained
(
args
.
output_dir
)
tokenizer
.
save_pretrained
(
args
.
output_dir
)
torch
.
save
(
args
,
os
.
path
.
join
(
args
.
output_dir
,
"training_arguments.bin"
))
# Evaluate the model
results
=
{}
if
args
.
do_evaluate
:
checkpoints
=
[]
logger
.
info
(
"Evaluate the following checkpoints: %s"
,
checkpoints
)
for
checkpoint
in
checkpoints
:
encoder_checkpoint
=
os
.
path
.
join
(
checkpoint
,
"encoder"
)
decoder_checkpoint
=
os
.
path
.
join
(
checkpoint
,
"decoder"
)
model
=
PreTrainedSeq2seq
.
from_pretrained
(
encoder_checkpoint
,
decoder_checkpoint
)
model
.
to
(
args
.
device
)
results
=
"placeholder"
return
results
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
examples/run_s
eq2seq
_finetuning_test.py
→
examples/run_s
ummarization
_finetuning_test.py
View file @
4c3ac4a7
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
import
unittest
import
unittest
from
run_s
eq2seq
_finetuning
import
_fit_to_block_size
,
process_story
from
run_s
ummarization
_finetuning
import
_fit_to_block_size
,
process_story
class
DataLoaderTest
(
unittest
.
TestCase
):
class
DataLoaderTest
(
unittest
.
TestCase
):
...
@@ -43,15 +43,16 @@ class DataLoaderTest(unittest.TestCase):
...
@@ -43,15 +43,16 @@ class DataLoaderTest(unittest.TestCase):
raw_story
=
"""It was the year of Our Lord one thousand seven hundred and
raw_story
=
"""It was the year of Our Lord one thousand seven hundred and
seventy-five.
\n\n
Spiritual revelations were conceded to England at that
seventy-five.
\n\n
Spiritual revelations were conceded to England at that
favoured period, as at this."""
favoured period, as at this."""
with
self
.
assertRaises
(
IndexError
):
_
,
summary
=
process_story
(
raw_story
)
process_story
(
raw_story
)
self
.
assertEqual
(
summary
,
[]
)
def
test_process_empty_story
(
self
):
def
test_process_empty_story
(
self
):
""" An empty story should also raise and exception.
""" An empty story should also raise and exception.
"""
"""
raw_story
=
""
raw_story
=
""
with
self
.
assertRaises
(
IndexError
):
story
,
summary
=
process_story
(
raw_story
)
process_story
(
raw_story
)
self
.
assertEqual
(
story
,
[])
self
.
assertEqual
(
summary
,
[])
def
test_story_with_missing_period
(
self
):
def
test_story_with_missing_period
(
self
):
raw_story
=
(
raw_story
=
(
...
@@ -59,17 +60,16 @@ class DataLoaderTest(unittest.TestCase):
...
@@ -59,17 +60,16 @@ class DataLoaderTest(unittest.TestCase):
"seventy-five
\n\n
Spiritual revelations were conceded to England "
"seventy-five
\n\n
Spiritual revelations were conceded to England "
"at that favoured period, as at this.
\n
@highlight
\n\n
It was the best of times"
"at that favoured period, as at this.
\n
@highlight
\n\n
It was the best of times"
)
)
story
,
summary
=
process_story
(
raw_story
)
story
_lines
,
summary
_lines
=
process_story
(
raw_story
)
expected_story
=
(
expected_story_lines
=
[
"It was the year of Our Lord one thousand seven hundred and "
"It was the year of Our Lord one thousand seven hundred and seventy-five."
,
"seventy-five. Spiritual revelations were conceded to England at that "
"Spiritual revelations were conceded to England at that favoured period, as at this."
,
"favoured period, as at this."
]
)
self
.
assertEqual
(
expected_story_lines
,
story_lines
)
self
.
assertEqual
(
expected_story
,
story
)
expected_summary
=
"It was the best of times."
expected_summary
_lines
=
[
"It was the best of times."
]
self
.
assertEqual
(
expected_summary
,
summary
)
self
.
assertEqual
(
expected_summary
_lines
,
summary
_lines
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
transformers/__init__.py
View file @
4c3ac4a7
...
@@ -87,7 +87,7 @@ if is_torch_available():
...
@@ -87,7 +87,7 @@ if is_torch_available():
from
.modeling_distilbert
import
(
DistilBertForMaskedLM
,
DistilBertModel
,
from
.modeling_distilbert
import
(
DistilBertForMaskedLM
,
DistilBertModel
,
DistilBertForSequenceClassification
,
DistilBertForQuestionAnswering
,
DistilBertForSequenceClassification
,
DistilBertForQuestionAnswering
,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_seq2seq
import
Model2Model
from
.modeling_seq2seq
import
PreTrainedSeq2seq
,
Model2Model
# Optimization
# Optimization
from
.optimization
import
(
AdamW
,
ConstantLRSchedule
,
WarmupConstantSchedule
,
WarmupCosineSchedule
,
from
.optimization
import
(
AdamW
,
ConstantLRSchedule
,
WarmupConstantSchedule
,
WarmupCosineSchedule
,
...
...
transformers/modeling_beam_search.py
0 → 100644
View file @
4c3ac4a7
# coding=utf-8
# Copyright (c) 2019 Yang Liu
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
A general wrapper around models with LM heads to generate sequences
using beam search.
"""
import
torch
from
torch
import
nn
class ModelWithBeamSearch(nn.Module):
    """Wrap an encoder-decoder model to generate sequences with beam search.

    The wrapped `model` must expose an `encoder` and a `decoder` attribute.
    The decoder is called with the whole sequence generated so far and
    with the keyword argument `encoder_hidden_states`.
    """

    def __init__(
        self,
        model,
        beam_size,
        start_token_id,
        end_token_id,
        pad_token_id,
        min_length,
        max_length,
        alpha,
        block_trigram=True,
    ):
        """
        Attributes:
            model: the encoder-decoder model used to score the next tokens.
            beam_size: number of hypotheses kept alive per input sequence.
            start_token_id: id of the token every generated sequence starts with.
            end_token_id: id of the token that marks the end of a sequence.
            pad_token_id: id of the padding token (stored, not used by the search).
            min_length: number of steps during which the end token is forbidden.
            max_length: maximum number of generation steps.
            alpha: exponent of the length penalty `((5 + length) / 6) ** alpha`.
            block_trigram: whether to discard beams that end with a trigram
                they already contain.
        """
        super(ModelWithBeamSearch, self).__init__()
        self.model = model
        self.beam_size = beam_size
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id
        self.pad_token_id = pad_token_id
        self.min_length = min_length
        self.max_length = max_length
        self.alpha = alpha
        self.block_trigram = block_trigram

    @staticmethod
    def _ends_with_repeated_trigram(tokens):
        """Return True if the last trigram of `tokens` already occurs earlier in it."""
        trigrams = list(zip(tokens, tokens[1:], tokens[2:]))
        if len(trigrams) < 2:
            return False
        return trigrams[-1] in trigrams[:-1]

    def forward(self, input_ids, **kwargs):
        """Generate one sequence per row of `input_ids` using beam search.

        Args:
            input_ids: tensor whose first dimension is the batch; it is
                passed as is to the encoder.
            **kwargs: keyword arguments prefixed with `decoder_` are passed
                (without the prefix) to the decoder at every step; all the
                others are passed to the encoder.
                NOTE(review): only `encoder_hidden_states` is tiled and
                re-ordered along the beams; other per-example decoder
                tensors (e.g. an attention mask) are forwarded unchanged.

        Returns:
            A dictionary with two keys, `predictions` and `scores`. Each
            maps to a list with one entry per input sequence; the entry is
            a list that holds the best hypothesis (a 1-D tensor of token
            ids) and its length-penalized score (a 0-d tensor) respectively.
        """
        # Separate the encoder- and decoder- specific kwargs. A kwarg is
        # decoder-specific if the key starts with `decoder_`
        kwargs_encoder = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("decoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        batch_size = input_ids.size(0)
        device = input_ids.device

        # Variables that keep track of the status of the search
        hypotheses = [[] for _ in range(batch_size)]
        results = {
            "predictions": [[] for _ in range(batch_size)],
            "scores": [[] for _ in range(batch_size)],
        }
        batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(
            0,
            batch_size * self.beam_size,
            step=self.beam_size,
            dtype=torch.long,
            device=device,
        )
        growing_beam = torch.full(
            (batch_size * self.beam_size, 1),
            self.start_token_id,
            dtype=torch.long,
            device=device,
        )
        # Only the first beam of each batch element is "alive" at the start,
        # otherwise the first step would select `beam_size` identical beams.
        topk_log_probabilities = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1),
            dtype=torch.float,
            device=device,
        ).repeat(batch_size)

        # Forward pass on the encoder. Models of the library return tuples;
        # we assume the first element is the last layer's hidden states.
        encoder_outputs = self.model.encoder(input_ids, **kwargs_encoder)
        if isinstance(encoder_outputs, (tuple, list)):
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_hidden_states = encoder_outputs
        kwargs_decoder["encoder_hidden_states"] = tile(
            encoder_hidden_states, self.beam_size, dim=0
        )

        for step in range(self.max_length):
            # The decoder keeps no cache: feed the whole prefix and keep the
            # scores predicted for the last position.
            outputs = self.model.decoder(growing_beam, **kwargs_decoder)
            scores = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            if scores.dim() == 3:
                scores = scores[:, -1, :]
            log_probabilities = torch.nn.functional.log_softmax(scores, dim=-1)
            vocab_size = log_probabilities.size(-1)

            # The batch size changes as some beams finish so we define:
            _B = log_probabilities.size(0) // self.beam_size

            # Multiply each beam probability with the probability of the
            # next token (conditioned on the words in the beam).
            log_probabilities = log_probabilities + topk_log_probabilities.view(-1, 1)

            # if the beam has not attained the minimum required length we
            # make the end token arbitrarily unlikely.
            if step < self.min_length:
                log_probabilities[:, self.end_token_id] = -1e20

            # Remove repeating tri-grams
            if self.block_trigram and growing_beam.size(1) > 3:
                for row, tokens in enumerate(growing_beam.tolist()):
                    if self._ends_with_repeated_trigram(tokens):
                        log_probabilities[row] = -1e20

            # Find the `beam_size` (previous_beam + token) combinations with
            # the highest score
            topk_log_probabilities, topk_ids = log_probabilities.reshape(
                _B, self.beam_size * vocab_size
            ).topk(self.beam_size, dim=1)

            # Apply the length penalty. The +1 accounts for the [EOS] token
            # that will be added if the beam ends.
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
            topk_scores = topk_log_probabilities / length_penalty

            # Retrieve the corresponding respective beam and token id
            # topk_token_ids[i] will be added to topk_beam_ids[i]
            topk_beam_ids = topk_ids // vocab_size
            topk_token_ids = topk_ids.fmod(vocab_size)

            # Retrieve the row index of the surviving beams in the original
            # view of the log_probabilities tensor
            surviving_beams_rows = (
                topk_beam_ids + beam_offset[:_B].view(-1, 1)
            ).view(-1)

            # Append the last predictions
            growing_beam = torch.cat(
                [
                    growing_beam.index_select(0, surviving_beams_rows),
                    topk_token_ids.view(-1, 1),
                ],
                1,
            )

            # Check if any of the beam searches has ended during this
            # growth step. Also if top beam (most probable) has ended
            # for one element of the batch.
            is_finished = topk_token_ids.eq(self.end_token_id)
            if step + 1 == self.max_length:
                is_finished.fill_(1)
            # Copy: `is_finished` is modified in place below.
            is_top_beam_finished = is_finished[:, 0].clone()

            # Save the finished searches
            if is_finished.any():
                predictions = growing_beam.view(
                    -1, self.beam_size, growing_beam.size(1)
                )
                for i in range(is_finished.size(0)):
                    if is_top_beam_finished[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)

                    # Store finished hypotheses for this batch.
                    b = int(batch_offset[i])
                    for j in finished_hyp.tolist():
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))

                    # If the batch reached the end, save the best hypotheses
                    # in terms of length-penalized score.
                    if is_top_beam_finished[i]:
                        best_score, best_prediction = max(
                            hypotheses[b], key=lambda hypothesis: hypothesis[0]
                        )
                        results["scores"][b].append(best_score)
                        results["predictions"][b].append(best_prediction)

                non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
                if len(non_finished) == 0:
                    break

                # Remove finished batches for the next step.
                topk_log_probabilities = topk_log_probabilities.index_select(
                    0, non_finished
                )
                batch_offset = batch_offset.index_select(0, non_finished)
                growing_beam = predictions.index_select(0, non_finished).view(
                    -1, growing_beam.size(-1)
                )
                # `non_finished` indexes batch elements, so the rows must be
                # grouped per batch element before being filtered.
                surviving_beams_rows = (
                    surviving_beams_rows.view(-1, self.beam_size)
                    .index_select(0, non_finished)
                    .view(-1)
                )

            # Re-order the state for the next pass
            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[
                "encoder_hidden_states"
            ].index_select(0, surviving_beams_rows)

        return results
def tile(x, count, dim=0):
    """
    Tiles `x` along dimension `dim` `count` times.

    Every slice of `x` along `dim` is repeated `count` times in a row, so
    that the copies of a given slice are contiguous in the output.

    Example:
        >> ex = torch.tensor([[1, 2], [3, 4]])
        >> tile(ex, 2, 0)
        tensor([[1, 2], [1, 2], [3, 4], [3, 4]])

    Args:
        x: the tensor to tile.
        count: number of times each slice is repeated.
        dim: dimension along which the slices are repeated.

    Returns:
        A new tensor whose size along `dim` is `count * x.size(dim)`.
    """
    # `repeat_interleave` also handles tensors with an empty dimension, for
    # which a manual `view(batch, -1)` is ambiguous and raises.
    return x.repeat_interleave(count, dim=dim)
transformers/modeling_bert.py
View file @
4c3ac4a7
...
@@ -646,7 +646,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -646,7 +646,7 @@ class BertModel(BertPreTrainedModel):
if
attention_mask
.
dim
()
==
2
:
if
attention_mask
.
dim
()
==
2
:
if
self
.
config
.
is_decoder
:
if
self
.
config
.
is_decoder
:
batch_size
,
seq_length
=
input_ids
.
size
()
batch_size
,
seq_length
=
input_ids
.
size
()
seq_ids
=
torch
.
arange
(
seq_length
)
seq_ids
=
torch
.
arange
(
seq_length
,
device
=
input_ids
.
device
)
causal_mask
=
seq_ids
[
None
,
None
,
:].
repeat
(
batch_size
,
seq_length
,
1
)
<=
seq_ids
[
None
,
:,
None
]
causal_mask
=
seq_ids
[
None
,
None
,
:].
repeat
(
batch_size
,
seq_length
,
1
)
<=
seq_ids
[
None
,
:,
None
]
extended_attention_mask
=
causal_mask
[:,
None
,
:,
:]
*
attention_mask
[:,
None
,
None
,
:]
extended_attention_mask
=
causal_mask
[:,
None
,
:,
:]
*
attention_mask
[:,
None
,
None
,
:]
else
:
else
:
...
@@ -660,6 +660,13 @@ class BertModel(BertPreTrainedModel):
...
@@ -660,6 +660,13 @@ class BertModel(BertPreTrainedModel):
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
extended_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
# If a 2D encoder attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
if
encoder_attention_mask
is
not
None
:
encoder_attention_mask
=
encoder_attention_mask
[:,
None
,
None
,
:]
encoder_attention_mask
=
encoder_attention_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# fp16 compatibility
encoder_attention_mask
=
(
1.0
-
encoder_attention_mask
)
*
-
10000.0
# Prepare head mask if needed
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# attention_probs has shape bsz x n_heads x N x N
...
@@ -819,7 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -819,7 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
self
.
bert
.
embeddings
.
word_embeddings
)
self
.
bert
.
embeddings
.
word_embeddings
)
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
def
forward
(
self
,
input_ids
,
attention_mask
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
masked_lm_labels
=
None
,
lm_labels
=
None
,
encoder_hidden_states
=
None
,
encoder_attention_mask
=
None
):
masked_lm_labels
=
None
,
encoder_hidden_states
=
None
,
encoder_attention_mask
=
None
,
lm_labels
=
None
,
):
outputs
=
self
.
bert
(
input_ids
,
outputs
=
self
.
bert
(
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
...
@@ -838,11 +845,8 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -838,11 +845,8 @@ class BertForMaskedLM(BertPreTrainedModel):
# 1. If a tensor that contains the indices of masked labels is provided,
# 1. If a tensor that contains the indices of masked labels is provided,
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# the cross-entropy is the MLM cross-entropy that measures the likelihood
# of predictions for masked words.
# of predictions for masked words.
# 2. If
encoder hidden states are
provided we are in a causal s
ituat
io
n
where we
# 2. If
`lm_label` is
provided we are in a causal s
cenar
io where we
# try to predict the next word for each input in the encoder.
# try to predict the next word for each input in the encoder.
if
masked_lm_labels
is
not
None
and
lm_labels
is
not
None
:
raise
AttributeError
(
"Masked LM training with an encoder-decoder is not supported."
)
if
masked_lm_labels
is
not
None
:
if
masked_lm_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
# -1 index = padding token
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
# -1 index = padding token
masked_lm_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
masked_lm_labels
.
view
(
-
1
))
masked_lm_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
masked_lm_labels
.
view
(
-
1
))
...
@@ -851,9 +855,9 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -851,9 +855,9 @@ class BertForMaskedLM(BertPreTrainedModel):
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
# we are doing next-token prediction; shift prediction scores and input ids by one
# we are doing next-token prediction; shift prediction scores and input ids by one
prediction_scores
=
prediction_scores
[:,
:
-
1
,
:]
prediction_scores
=
prediction_scores
[:,
:
-
1
,
:]
lm_labels
=
lm_labels
[:,
1
:,
:]
lm_labels
=
lm_labels
[:,
1
:]
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
seq2seq_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
lm_labels
.
view
(
-
1
))
seq2seq_loss
=
loss_fct
(
prediction_scores
.
reshape
(
-
1
,
self
.
config
.
vocab_size
),
lm_labels
.
reshape
(
-
1
))
outputs
=
(
seq2seq_loss
,)
+
outputs
outputs
=
(
seq2seq_loss
,)
+
outputs
return
outputs
# (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
return
outputs
# (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
...
...
transformers/modeling_seq2seq.py
View file @
4c3ac4a7
...
@@ -17,13 +17,12 @@
...
@@ -17,13 +17,12 @@
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
logging
import
logging
import
os
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
.file_utils
import
add_start_docstrings
from
.modeling_auto
import
AutoModel
,
AutoModelWithLMHead
from
.modeling_auto
import
AutoModel
,
AutoModelWithLMHead
from
.modeling_utils
import
PreTrainedModel
,
SequenceSummary
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -43,7 +42,13 @@ class PreTrainedSeq2seq(nn.Module):
...
@@ -43,7 +42,13 @@ class PreTrainedSeq2seq(nn.Module):
self
.
decoder
=
decoder
self
.
decoder
=
decoder
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
encoder_pretrained_model_name_or_path
=
None
,
decoder_pretrained_model_name_or_path
=
None
,
*
model_args
,
**
kwargs
):
def
from_pretrained
(
cls
,
encoder_pretrained_model_name_or_path
=
None
,
decoder_pretrained_model_name_or_path
=
None
,
*
model_args
,
**
kwargs
):
r
""" Instantiates an encoder and a decoder from one or two base classes
r
""" Instantiates an encoder and a decoder from one or two base classes
of the library from pre-trained model checkpoints.
of the library from pre-trained model checkpoints.
...
@@ -108,23 +113,28 @@ class PreTrainedSeq2seq(nn.Module):
...
@@ -108,23 +113,28 @@ class PreTrainedSeq2seq(nn.Module):
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
# decoder-specific if the key starts with `decoder_`
kwargs_decoder
=
{}
kwargs_encoder
=
{
kwargs_encoder
=
kwargs
argument
:
value
for
key
in
kwargs_encoder
.
keys
():
for
argument
,
value
in
kwargs
.
items
()
if
key
.
startswith
(
"decoder_"
):
if
not
argument
.
startswith
(
"decoder_"
)
kwargs_decoder
[
key
.
replace
(
"decoder_"
,
""
)]
=
kwargs_encoder
.
pop
(
key
)
}
kwargs_decoder
=
{
argument
[
len
(
"decoder_"
)
:]:
value
for
argument
,
value
in
kwargs
.
items
()
if
argument
.
startswith
(
"decoder_"
)
}
# Load and initialize the encoder and decoder
# Load and initialize the encoder and decoder
#
The distinction between encoder and decoder at the model level is made
# The distinction between encoder and decoder at the model level is made
#
by the value of the flag `is_decoder` that we need to set correctly.
# by the value of the flag `is_decoder` that we need to set correctly.
encoder
=
kwargs
.
pop
(
"encoder_model"
,
None
)
encoder
=
kwargs
_encoder
.
pop
(
"encoder_model"
,
None
)
if
encoder
is
None
:
if
encoder
is
None
:
kwargs_encoder
[
"is_decoder"
]
=
False
kwargs_encoder
[
"is_decoder"
]
=
False
encoder
=
AutoModel
.
from_pretrained
(
encoder
=
AutoModel
.
from_pretrained
(
encoder_pretrained_model_name_or_path
,
*
model_args
,
**
kwargs_encoder
encoder_pretrained_model_name_or_path
,
*
model_args
,
**
kwargs_encoder
)
)
decoder
=
kwargs
.
pop
(
"
decoder
_
model"
,
None
)
decoder
=
kwargs
_
decoder
.
pop
(
"
model"
,
None
)
if
decoder
is
None
:
if
decoder
is
None
:
kwargs_decoder
[
"is_decoder"
]
=
True
kwargs_decoder
[
"is_decoder"
]
=
True
decoder
=
AutoModelWithLMHead
.
from_pretrained
(
decoder
=
AutoModelWithLMHead
.
from_pretrained
(
...
@@ -135,6 +145,12 @@ class PreTrainedSeq2seq(nn.Module):
...
@@ -135,6 +145,12 @@ class PreTrainedSeq2seq(nn.Module):
return
model
return
model
def
save_pretrained
(
self
,
save_directory
):
""" Save a Seq2Seq model and its configuration file in a format
such that it can be loaded using :func:`~transformers.PreTrainedSeq2seq.from_pretrained` """
self
.
encoder
.
save_pretrained
(
os
.
path
.
join
(
save_directory
,
"encoder"
))
self
.
decoder
.
save_pretrained
(
os
.
path
.
join
(
save_directory
,
"decoder"
))
def
forward
(
self
,
encoder_input_ids
,
decoder_input_ids
,
**
kwargs
):
def
forward
(
self
,
encoder_input_ids
,
decoder_input_ids
,
**
kwargs
):
""" The forward pass on a seq2seq depends on what we are performing:
""" The forward pass on a seq2seq depends on what we are performing:
...
@@ -155,22 +171,29 @@ class PreTrainedSeq2seq(nn.Module):
...
@@ -155,22 +171,29 @@ class PreTrainedSeq2seq(nn.Module):
"""
"""
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# Separate the encoder- and decoder- specific kwargs. A kwarg is
# decoder-specific if the key starts with `decoder_`
# decoder-specific if the key starts with `decoder_`
kwargs_decoder
=
{}
kwargs_encoder
=
{
kwargs_encoder
=
kwargs
argument
:
value
for
key
in
kwargs_encoder
.
keys
():
for
argument
,
value
in
kwargs
.
items
()
if
key
.
startswith
(
"decoder_"
):
if
not
argument
.
startswith
(
"decoder_"
)
kwargs_decoder
[
key
.
replace
(
"decoder_"
,
""
)]
=
kwargs_encoder
.
pop
(
key
)
}
kwargs_decoder
=
{
argument
[
len
(
"decoder_"
)
:]:
value
for
argument
,
value
in
kwargs
.
items
()
if
argument
.
startswith
(
"decoder_"
)
}
# Encode if needed (training, first prediction pass)
# Encode if needed (training, first prediction pass)
encoder_hidden_states
=
kwargs_encoder
.
pop
(
"encoder_hidden_states"
,
None
)
encoder_hidden_states
=
kwargs_encoder
.
pop
(
"encoder_hidden_states"
,
None
)
if
encoder_hidden_states
is
None
:
if
encoder_hidden_states
is
None
:
encoder_outputs
=
self
.
encoder
(
encoder_input_ids
,
**
kwargs_encoder
)
encoder_outputs
=
self
.
encoder
(
encoder_input_ids
,
**
kwargs_encoder
)
encoder_hidden_states
=
encoder_outputs
[
0
][
-
1
]
# output of the encoder *stack*
encoder_hidden_states
=
encoder_outputs
[
0
][
-
1
]
# output of the encoder *stack*
else
:
else
:
encoder_outputs
=
()
encoder_outputs
=
()
# Decode
# Decode
kwargs_decoder
[
"encoder_hidden_states"
]
=
encoder_hidden_states
kwargs_decoder
[
"encoder_hidden_states"
]
=
encoder_hidden_states
[
None
,
:,
:]
decoder_outputs
=
self
.
decoder
(
decoder_input_ids
,
**
kwargs_decoder
)
decoder_outputs
=
self
.
decoder
(
decoder_input_ids
,
**
kwargs_decoder
)
return
decoder_outputs
+
encoder_outputs
return
decoder_outputs
+
encoder_outputs
...
@@ -201,9 +224,25 @@ class Model2Model(PreTrainedSeq2seq):
...
@@ -201,9 +224,25 @@ class Model2Model(PreTrainedSeq2seq):
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
args
,
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
,
*
args
,
**
kwargs
):
model
=
super
(
Model2Model
,
cls
).
from_pretrained
(
encoder_pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
decoder_pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
if
(
**
kwargs
)
"bert"
not
in
pretrained_model_name_or_path
or
"roberta"
in
pretrained_model_name_or_path
or
"distilbert"
in
pretrained_model_name_or_path
):
raise
ValueError
(
"Only the Bert model is currently supported."
)
model
=
super
(
Model2Model
,
cls
).
from_pretrained
(
encoder_pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
decoder_pretrained_model_name_or_path
=
pretrained_model_name_or_path
,
**
kwargs
)
# Some architectures require for the decoder to be initialized randomly
# before fine-tuning.
if
kwargs
.
get
(
"decoder_initialize_randomly"
,
False
):
model
.
decoder
.
init_weights
()
return
model
return
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment