chenpangpang / transformers · Commit 9660ba1c

Add beam search

Authored Oct 31, 2019 by Rémi Louf; committed Dec 09, 2019 by Julien Chaumond
Parent: 1c71ecc8
Showing 5 changed files with 594 additions and 513 deletions (+594, -513)
examples/run_summarization_finetuning.py    +0    -502
examples/utils_summarization.py             +9    -11
transformers/generate/__init__.py           +1    -0
transformers/generate/beam_search.py        +358  -0
transformers/tests/beam_search_tests.py     +226  -0
examples/run_summarization_finetuning.py   deleted  100644 → 0  (view file @ 1c71ecc8)

# coding=utf-8
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2019 The HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation."""
import argparse
import functools
import logging
import os
import random
import sys

import numpy as np
from tqdm import tqdm, trange

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from transformers import (
    AutoTokenizer,
    BertForMaskedLM,
    BertConfig,
    PreTrainedEncoderDecoder,
    Model2Model,
)

from utils_summarization import (
    CNNDailyMailDataset,
    encode_for_summarization,
    fit_to_block_size,
    build_lm_labels,
    build_mask,
    compute_token_type_ids,
)

logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)


# ------------
# Load dataset
# ------------

def load_and_cache_examples(args, tokenizer):
    dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
    return dataset


def collate(data, tokenizer, block_size):
    """ List of tuple as an input. """
    # remove the files with empty an story/summary, encode and fit to block
    data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
    data = [encode_for_summarization(story, summary, tokenizer) for story, summary in data]
    data = [
        (
            fit_to_block_size(story, block_size, tokenizer.pad_token_id),
            fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
        )
        for story, summary in data
    ]

    stories = torch.tensor([story for story, summary in data])
    summaries = torch.tensor([summary for story, summary in data])
    encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
    encoder_mask = build_mask(stories, tokenizer.pad_token_id)
    decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
    lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)

    return (
        stories,
        summaries,
        encoder_token_type_ids,
        encoder_mask,
        decoder_mask,
        lm_labels,
    )


# ----------
# Optimizers
# ----------

class BertSumOptimizer(object):
    """ Specific optimizer for BertSum.

    As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam Optimizers with different warm-up steps and
    learning rate. They also use a custom learning rate scheduler.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    """

    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.lr = lr
        self.warmup_steps = warmup_steps

        self.optimizers = {
            "encoder": Adam(
                model.encoder.parameters(),
                lr=lr["encoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
            "decoder": Adam(
                model.decoder.parameters(),
                lr=lr["decoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
        }

        self._step = 0

    def _update_rate(self, stack):
        return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-0.5))

    def zero_grad(self):
        self.optimizer_decoder.zero_grad()
        self.optimizer_encoder.zero_grad()

    def step(self):
        self._step += 1
        for stack, optimizer in self.optimizers.items():
            new_rate = self._update_rate(stack)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_rate
            optimizer.step()


# ------------
# Train
# ------------

def train(args, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus. """
    set_seed(args)

    # Load the data
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_dataset = load_and_cache_examples(args, tokenizer)
    train_sampler = RandomSampler(train_dataset)
    model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=args.train_batch_size,
        collate_fn=model_collate_fn,
    )

    # Training schedule
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = t_total // (len(train_dataloader) // args.gradient_accumulation_steps + 1)
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare the optimizer
    lr = {"encoder": 0.002, "decoder": 0.2}
    warmup_steps = {"encoder": 20000, "decoder": 10000}
    optimizer = BertSumOptimizer(model, lr, warmup_steps)

    # Train
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps
        # * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    model.zero_grad()
    train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)

    global_step = 0
    tr_loss = 0.0
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
        for step, batch in enumerate(epoch_iterator):
            source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch

            source = source.to(args.device)
            target = target.to(args.device)
            encoder_token_type_ids = encoder_token_type_ids.to(args.device)
            encoder_mask = encoder_mask.to(args.device)
            decoder_mask = decoder_mask.to(args.device)
            lm_labels = lm_labels.to(args.device)

            model.train()
            outputs = model(
                source,
                target,
                encoder_token_type_ids=encoder_token_type_ids,
                encoder_attention_mask=encoder_mask,
                decoder_attention_mask=decoder_mask,
                decoder_lm_labels=lm_labels,
            )

            loss = outputs[0]
            print(loss)
            if args.gradient_accumulation_steps > 1:
                loss /= args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                model.zero_grad()
                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    return global_step, tr_loss / global_step


# ------------
# Train
# ------------

def evaluate(args, model, tokenizer, prefix=""):
    set_seed(args)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch

        source = source.to(args.device)
        target = target.to(args.device)
        encoder_token_type_ids = encoder_token_type_ids.to(args.device)
        encoder_mask = encoder_mask.to(args.device)
        decoder_mask = decoder_mask.to(args.device)
        lm_labels = lm_labels.to(args.device)

        with torch.no_grad():
            outputs = model(
                source,
                target,
                encoder_token_type_ids=encoder_token_type_ids,
                encoder_attention_mask=encoder_mask,
                decoder_attention_mask=decoder_mask,
                decoder_lm_labels=lm_labels,
            )
            lm_loss = outputs[0]
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    # Save the evaluation's results
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


def save_model_checkpoints(args, model, tokenizer):
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    logger.info("Saving model checkpoint to %s", args.output_dir)

    # Save a trained model, configuration and tokenizer using `save_pretrained()`.
    # They can then be reloaded using `from_pretrained()`
    model_to_save = (
        model.module if hasattr(model, "module") else model
    )  # Take care of distributed/parallel training
    model_to_save.save_pretrained(args.output_dir, model_type='bert')
    tokenizer.save_pretrained(args.output_dir)
    torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))


def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input training data file (a text file).",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--do_evaluate",
        type=bool,
        default=False,
        help="Run model evaluation on out-of-sample data.",
    )
    parser.add_argument("--do_train", type=bool, default=False, help="Run training.")
    parser.add_argument(
        "--do_overwrite_output_dir",
        type=bool,
        default=False,
        help="Whether to overwrite the output dir.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint to initialize the encoder and decoder's weights with.",
    )
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
        help="The decoder architecture to be fine-tuned.",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--to_cpu", default=False, type=bool, help="Whether to force training on CPU.")
    parser.add_argument(
        "--num_train_epochs",
        default=10,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.do_overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
                args.output_dir
            )
        )

    # Set up training device
    if args.to_cpu or not torch.cuda.is_available():
        args.device = torch.device("cpu")
        args.n_gpu = 0
    else:
        args.device = torch.device("cuda")
        args.n_gpu = torch.cuda.device_count()

    # Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    config = BertConfig.from_pretrained(args.model_name_or_path)
    decoder_model = BertForMaskedLM(config)
    model = Model2Model.from_pretrained(args.model_name_or_path, decoder_model=decoder_model)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        0,
        args.device,
        args.n_gpu,
        False,
        False,
    )

    logger.info("Training/evaluation parameters %s", args)

    # Train the model
    model.to(args.device)
    if args.do_train:
        try:
            global_step, tr_loss = train(args, model, tokenizer)
        except KeyboardInterrupt:
            response = input("You interrupted the training. Do you want to save the model checkpoints? [Y/n]")
            if response.lower() in ["", "y", "yes"]:
                save_model_checkpoints(args, model, tokenizer)
            sys.exit(0)

        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
        save_model_checkpoints(args, model, tokenizer)

    # Evaluate the model
    results = {}
    if args.do_evaluate:
        checkpoints = [args.output_dir]
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            encoder_checkpoint = os.path.join(checkpoint, "bert_encoder")
            decoder_checkpoint = os.path.join(checkpoint, "bert_decoder")
            model = PreTrainedEncoderDecoder.from_pretrained(encoder_checkpoint, decoder_checkpoint)
            model.to(args.device)
            print("model loaded")

    return results


if __name__ == "__main__":
    main()
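
For reference, the two-optimizer schedule described in the BertSumOptimizer docstring above reduces to the rule implemented by _update_rate: each stack gets its own learning rate lr[stack] * min(step**-0.5, step * warmup_steps[stack]**-0.5). The snippet below is a small, self-contained sketch of that rule evaluated with the defaults hard-coded in train(); it is an illustration only and not part of the files changed by this commit.

# Illustration (not part of this commit): the per-stack learning rate used by
# BertSumOptimizer, i.e. lr[stack] * min(step**-0.5, step * warmup_steps[stack]**-0.5),
# evaluated with the default values from train().
lr = {"encoder": 0.002, "decoder": 0.2}
warmup_steps = {"encoder": 20000, "decoder": 10000}

def update_rate(stack, step):
    # short linear ramp-up, then inverse-square-root decay in the step count
    return lr[stack] * min(step ** (-0.5), step * warmup_steps[stack] ** (-0.5))

for step in (1, 100, 10000, 100000):
    print(step, update_rate("encoder", step), update_rate("decoder", step))

With these defaults the decoder warms up over fewer steps and keeps a much larger learning rate than the encoder, which is the point of the BertSum two-optimizer setup.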
examples/utils_summarization.py   +9  -11  (view file @ 9660ba1c)

@@ -25,9 +25,8 @@ class CNNDailyMailDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """
 
-    def __init__(self, tokenizer, prefix="train", data_dir=""):
+    def __init__(self, data_dir="", prefix="train"):
         assert os.path.isdir(data_dir)
-        self.tokenizer = tokenizer
 
         # We initialize the class by listing all the files that contain
         # stories and summaries. Files are not read in memory given
@@ -104,31 +103,30 @@ def _add_missing_period(line):
 # --------------------------
 
-def fit_to_block_size(sequence, block_size, pad_token):
+def fit_to_block_size(sequence, block_size, pad_token_id):
     """ Adapt the source and target sequences' lengths to the block size.
-    If the sequence is shorter than the block size we pad it with -1 ids
-    which correspond to padding tokens.
+    If the sequence is shorter we append padding token to the right of the sequence.
     """
     if len(sequence) > block_size:
         return sequence[:block_size]
     else:
-        sequence.extend([pad_token] * (block_size - len(sequence)))
+        sequence.extend([pad_token_id] * (block_size - len(sequence)))
         return sequence
 
 
-def build_lm_labels(sequence, pad_token):
-    """ Padding token, encoded as 0, are represented by the value -1 so they
+def build_lm_labels(sequence, pad_token_id):
+    """ Padding token are replaced by the value -1 so they
     are not taken into account in the loss computation. """
     padded = sequence.clone()
-    padded[padded == pad_token] = -1
+    padded[padded == pad_token_id] = -1
     return padded
 
 
-def build_mask(sequence, pad_token):
+def build_mask(sequence, pad_token_id):
     """ Builds the mask. The attention mechanism will only attend to positions
     with value 1. """
     mask = torch.ones_like(sequence)
-    idx_pad_tokens = sequence == pad_token
+    idx_pad_tokens = sequence == pad_token_id
     mask[idx_pad_tokens] = 0
     return mask
transformers/generate/__init__.py   new file  0 → 100644  (view file @ 9660ba1c)

from .beam_search import BeamSearch
transformers/modeling_beam_search.py → transformers/generate/beam_search.py  (view file @ 9660ba1c)

 # coding=utf-8
-# Copyright (c) 2019 Yang Liu
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
+# MIT License
+# Copyright (c) 2017-Present OpenNMT
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+# of the Software, and to permit persons to whom the Software is furnished to do
+# so, subject to the following conditions:
 # The above copyright notice and this permission notice shall be included in all
 # copies or substantial portions of the Software.
@@ -19,69 +21,161 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 """
-A general wrapper around models with LM heads to generate sequences
-using beam search.
+Use Beam Search to generate sequences using encoder-decoder models.
 """
 import torch
 from torch import nn
 
 
-class TransformerBeamSearch(nn.Module):
+class BeamSearch(nn.Module):
     def __init__(
         self,
         model,
         tokenizer,
-        batch_size,
         beam_size,
         min_length,
         max_length,
+        batch_size=1,
         alpha=0,
-        block_repeating_trigram=True,
+        block_repeating_trigrams=True,
     ):
-        """
-        Attributes:
-            mask_word_id: token id that corresponds to the mask
-        """
-        super(TransformerBeamSearch, self).__init__()
+        r"""
+        Inputs:
+            **model**: instance of ``transformers.PreTrainedEncoderDecoder``
+                The pretrained encoder-decoder model that will be used to generate the sequences.
+            **tokenizer**: instance of ``transformers.PreTrainedTokenizer``
+                The pretrained tokenizer associated to the model used in the encoder-decoder. We only
+                support encoder-decoder that use the same tokenizer for encoder and decoder. The tokenizer
+                needs to be initialized or this function will raise and exception.
+            **batch_size**: (`optional`) int
+                Batch size of the inputs. The value is set automatically when calling `forward`.
+            **beam_size**: int
+                Number of beams that are used for each element on the batch.
+            **min_length**: int
+                Minimum number of steps performed by the beam search before terminating.
+            **max_length**: int
+                Maximum number of steps performed by the beam search. Any beam that has not finished
+                will return its current solution with the highest probability. The sequence that is
+                returned has a length of max_length-1 to account for the end token that is subsequently added.
+            **alpha**: float
+                Parameter of the length penalty. Read the documentation of the `_length_penalty` method for mode details.
+            **block_repeating_trigrams**: bool
+                Whether to block sequences that have repeating 3-grams.
+        """
+        super(BeamSearch, self).__init__()
         self.model = model
         self.tokenizer = tokenizer
-        self.start_token_id = tokenizer.start_token_id
-        self.end_token_id = tokenizer.end_token_id
+        self.bos_token_id = tokenizer.bos_token_id
+        self.eos_token_id = tokenizer.eos_token_id
         self.pad_token_id = tokenizer.pad_token_id
 
-        self.batch_size = batch_size
         self.beam_size = beam_size
         self.min_length = min_length
         self.max_length = max_length
 
-        self.block_repeating_trigram = block_repeating_trigram
+        self.block_repeating_trigram = block_repeating_trigrams
         self.apply_length_penalty = False if alpha == 0 else True
         self.alpha = alpha
 
+        # State of the beam
+        self._init_beam_state(batch_size)
+
+    def __len__(self):
+        try:
+            return self.growing_beams.size(1)
+        except NameError:
+            return 0
+
+    def _init_beam_state(self, batch_size):
+        """ (re-)Initialize the state of the beams. """
         self.hypotheses = [[] for _ in range(batch_size)]
         self.batch_offset = torch.arange(batch_size, dtype=torch.long)
         self.beam_offset = torch.arange(
             0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long
         )
-        self.growing_beam = torch.full(
-            (batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
+        self.growing_beams = torch.full(
+            (batch_size * self.beam_size, 1), self.bos_token_id, dtype=torch.long
         )
         self.topk_log_probabilities = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float
         ).repeat(batch_size)
         self.results = {
-            "prediction": [[] for _ in batch_size],
-            "scores": [[] for _ in batch_size],
+            "predictions": [[] for _ in range(batch_size)],
+            "scores": [[] for _ in range(batch_size)],
         }
         self._step = 0
         self.is_done = False
 
-    def step(self, log_probabilities):
-        """ Grows the beam by one step. """
+    def forward(self, encoder_input_ids, **model_kwargs):
+        """ Generate a sequence using Beam Search. """
+        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
+        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
+        # that apply to the model as whole.
+        # We let the specific kwargs override the common ones in case of conflict.
+        kwargs_common = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
+        }
+        kwargs_decoder = kwargs_common.copy()
+        kwargs_encoder = kwargs_common.copy()
+        kwargs_encoder.update(
+            {
+                argument[len("encoder_") :]: value
+                for argument, value in model_kwargs.items()
+                if argument.startswith("encoder_")
+            }
+        )
+        kwargs_decoder.update(
+            {
+                argument[len("decoder_") :]: value
+                for argument, value in model_kwargs.items()
+                if argument.startswith("decoder_")
+            }
+        )
+
+        # forward pass on the encoder
+        encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
+        kwargs_decoder["encoder_hidden_states"] = tile(encoder_outputs, self.beam_size, dim=0)
+
+        # grow the beam by generating sequences in an autoregressive way
+        batch_size = encoder_input_ids.size(0)
+        self._init_beam_state(batch_size)
+        for step in range(self.max_length):
+            # prepare the decoder input
+            decoder_input = fit_to_block_size(self.growing_beams, self.tokenizer.pad_token_id)
+            kwargs_decoder["decoder_lm_labels"] = build_lm_labels(decoder_input, self.tokenizer.pad_token_id)
+            kwargs_decoder["decoder_attention_mask"] = build_mask(decoder_input, self.tokenizer.pad_token_id)
+
+            outputs = self.model.decoder(decoder_input, kwargs_decoder)
+            log_probabilities = torch.nn.functional.log_softmax(outputs[1])
+            surviving_beams_rows = self.grow(log_probabilities)
+            if self.is_done:
+                break
+
+            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder["encoder_hidden_states"].index_select(0, surviving_beams_rows)
+            kwargs_decoder["encoder_attention_mask"] = kwargs_decoder["encoder_attention_mask"].index_select(0, surviving_beams_rows)
+
+        return self.results
+
+    def grow(self, log_probabilities):
+        """ Grow the beams by one step. """
         self._step += 1
 
-        # The batch size changes as some beams finish so we define _B
+        # The number of beams changes as some beams finish so we define _B
         vocab_size = log_probabilities.size(-1)
         _B = log_probabilities.size(0) // self.beam_size
@@ -89,21 +183,21 @@ class TransformerBeamSearch(nn.Module):
         # next token (conditioned on the words in the beam).
         log_probabilities += self.topk_log_probabilities.view(-1, 1)
 
-        self.enforce_min_length(log_probabilities)
+        self._enforce_min_length(log_probabilities)
         if self.block_repeating_trigram:
-            self.remove_repeating_trigrams(log_probabilities, _B)
+            self._remove_beams_with_repeating_trigrams(log_probabilities, _B)
 
         # Find the `beam_size` (previous_beam + token) combinations with
         # the highest score
-        topk_log_probabilities, topk_ids = log_probabilities.topk(
+        topk_log_probabilities, topk_ids = torch.topk(
             log_probabilities.view(_B, self.beam_size * vocab_size),
-            self.beam_size, dim=1
+            self.beam_size, dim=1,
         )
 
         # Apply the length penalty. The +1 accounts for the [EOS] token
         # that will be added if the beam ends.
-        topk_scores = topk_log_probabilities / self.length_penalty()
+        topk_scores = topk_log_probabilities
+        if self.apply_length_penalty:
+            topk_scores /= self._length_penalty()
 
         # Retrieve the corresponding respective beam and token id
         # topk_token_ids[i] will be added to topk_beam_ids[i]
@@ -112,14 +206,13 @@ class TransformerBeamSearch(nn.Module):
         # Retrieve the row index of the surviving beams in the original
         # view of the log_probabilities tensor
-        surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view(
-            -1
-        )
+        surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1)
+        surviving_beams_rows = surviving_beams_per_batch.view(-1)
 
         # Append the last predictions
-        self.growing_beam = torch.cat(
+        self.growing_beams = torch.cat(
            [
-                self.growing_beam.index_select(0, surviving_beams_rows),
+                self.growing_beams.index_select(0, surviving_beams_rows),
                 topk_token_ids.view(-1, 1),
             ],
             1,
@@ -128,21 +221,38 @@ class TransformerBeamSearch(nn.Module):
         # Check if any of the beam searches has ended during this
         # growth step. Also if top beam (most probable) has ended
         # for one element of the batch.
-        is_finished = topk_token_ids.eq(self.end_token_id)
-        self.enforce_max_length()
-        is_top_beam_finished = is_finished[:, 0].eq(1)
+        is_finished = topk_token_ids.eq(self.eos_token_id)
+        self._enforce_max_length(is_finished)
+        if is_finished.any():
+            non_finished = self._cut_finished(is_finished, topk_scores)
+            self.batch_offset = self.batch_offset.index_select(0, non_finished)
+            surviving_beams_per_batch = surviving_beams_per_batch.index_select(0, non_finished)
+            self.topk_log_probabilities = self.topk_log_probabilities.index_select(0, non_finished)
+            surviving_beams_rows = surviving_beams_per_batch.view(-1)
+            self.growing_beams = self.growing_beams.index_select(0, surviving_beams_rows)
+
+        return surviving_beams_rows
+
+    def _cut_finished(self, is_finished, topk_scores):
+        """ Save the finished searches and cut the correponding sequences off
+        the beams. """
+        is_top_beam_finished = is_finished[:, 0].eq(True)
 
         # Save the finished searches
-        if is_finished.any():
-            predictions = self.growing_beam.view(
-                -1, self.beam_size, self.growing_beam.size(1)
-            )
+        predictions = self.growing_beams.view(
+            -1, self.beam_size, self.growing_beams.size(1)
+        )
         for i in range(is_finished.size(0)):
             if is_top_beam_finished[i]:
                 is_finished[i].fill_(1)
             finished_hyp = is_finished[i].nonzero().view(-1)
 
-            # Store finished hypotheses for this batch.
+            # Store the finished beams as a (score, prediction) hypothesis.
             b = self.batch_offset[i]
             for j in finished_hyp:
                 self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))
@@ -150,95 +260,44 @@ class TransformerBeamSearch(nn.Module):
             # If the batch reached the end, save the best hypotheses
             # in terms of length-penalized score.
             if is_top_beam_finished[i]:
-                best_hyp = sorted(
-                    self.hypotheses[b], key=lambda x: x[0], reverse=True
-                )
-                best_score, best_prediction = best_hyp[0]
+                best_score, best_prediction = max(
+                    self.hypotheses[b], key=lambda x: x[0]
+                )
                 self.results["scores"][b].append(best_score)
                 self.results["predictions"][b].append(best_prediction)
 
-        non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
+        non_finished = is_top_beam_finished.eq(False).nonzero().view(-1)
         if len(non_finished) == 0:
             self.is_done = True
 
-            # Remove finished batches for the next step.
-            topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
-            self.batch_offset = self.batch_offset.index_select(0, non_finished)
-            self.growing_beam = predictions.index_select(0, non_finished).view(
-                -1, self.growing_beam.size(-1)
-            )
-            surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
-
-        return surviving_beams_rows
-
-    def forward(self, encoder_input_ids, **kwargs):
-        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
-        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
-        # that apply to the model as whole.
-        # We let the specific kwargs override the common ones in case of conflict.
-        kwargs_encoder = {
-            argument[len("encoder_") :]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("encoder_")
-        }
-        kwargs_decoder = {
-            argument[len("decoder_") :]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("decoder_")
-        }
-        kwargs_common = {
-            argument: value
-            for argument, value in kwargs.items()
-            if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
-        }
-        kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
-        kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
-
-        # forward pass on the encoder
-        encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder)
-        kwargs_decoder["encoder_hidden_states"] = tile(encoder_outputs, self.beam_size, dim=0)
-
-        # grow the beam by generating sequences in an autoregressive way
-        self.growing_beam = torch.full(
-            (self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long
-        )
-        for step in range(self.max_length):
-            decoder_input = self.growing_beam[:, -1]
-            outputs = self.model.decoder(decoder_input, kwargs_decoder)
-            log_probabilities = torch.nn.functional.log_softmax(outputs[1])
-            surviving_beams_rows = self.step(log_probabilities)
-            if self.is_done:
-                break
-            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder["encoder_hidden_states"].index_select(0, surviving_beams_rows)
-        return self.results
-
-    def remove_repeating_trigrams(self, log_probabilities, _B):
-        if (self._step + 1 > 3):
+        return non_finished
+
+    def _remove_beams_with_repeating_trigrams(self, log_probabilities, _B):
+        if self._step + 1 > 3:  # [BOS] does not count
             for i in range(_B * self.beam_size):
-                tokens = [t for t in self.growing_beam[i]]
-                trigrams = [(tokens[i - 1], tokens[i], tokens[i + 1]) for i in range(1, len(words) - 1)]
+                tokens = self.growing_beams[i]
+                trigrams = [
+                    (tokens[j - 1], tokens[j], tokens[j + 1]) for j in range(1, len(self) - 1)
+                ]
                 last_trigram = tuple(trigrams[-1])
                 if last_trigram in trigrams[:-1]:
                     log_probabilities[i] = -1e20
 
-    def enforce_min_length(self):
+    def _enforce_min_length(self, log_probabilities):
         if self._step < self.min_length:
-            self.log_probabilities[self.end_token_id] = -1e20
+            log_probabilities[:, self.eos_token_id] = -1e20
 
-    def enforce_max_length(self):
+    def _enforce_max_length(self, is_finished):
+        # +1 because we will need to add an [EOS] token
         if self._step + 1 == self.max_length:
-            self.is_finished.fill_(1)
+            is_finished.fill_(1)
 
-    def length_penalty(self):
+    def _length_penalty(self):
+        """ The calculation of the length penalty follows that of [1].
+
+        [1] Wu, Yonghui, et al. "Google's neural machine translation system:
+            Bridging the gap between human and machine translation." arXiv preprint
+            arXiv:1609.08144 (2016).
+        """
         return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha
@@ -269,3 +328,31 @@ def tile(x, count, dim=0):
     if dim != 0:
         x = x.permute(perm).contiguous()
     return x
+
+
+def fit_to_block_size(sequence, block_size, pad_token_id):
+    """ Adapt the source and target sequences' lengths to the block size.
+    If the sequence is shorter we append padding tokens to the right.
+    """
+    if len(sequence) > block_size:
+        return sequence[:block_size]
+    else:
+        sequence.extend([pad_token_id] * (block_size - len(sequence)))
+        return sequence
+
+
+def build_lm_labels(sequence, pad_token_id):
+    """ Padding token, encoded as 0, are represented by the value -1 so they
+    are not taken into account in the loss computation. """
+    padded = sequence.clone()
+    padded[padded == pad_token_id] = -1
+    return padded
+
+
+def build_mask(sequence, pad_token_id):
+    """ Builds the mask. The attention mechanism will only attend to positions
+    with value 1. """
+    mask = torch.ones_like(sequence)
+    idx_pad_tokens = sequence == pad_token_id
+    mask[idx_pad_tokens] = 0
+    return mask
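
A note on the keyword-argument routing at the top of the new BeamSearch.forward: names prefixed with encoder_ or decoder_ are stripped of their prefix and sent to the corresponding stack, while unprefixed names are treated as common to both, with the specific values overriding the common ones. The snippet below is a self-contained sketch of that routing; the argument names and placeholder values are made up for the illustration and are not taken from the commit.

# Illustrative sketch only: how BeamSearch.forward splits **model_kwargs into
# encoder-specific, decoder-specific and common keyword arguments.
model_kwargs = {
    "encoder_token_type_ids": "<encoder token type ids>",  # hypothetical values
    "encoder_attention_mask": "<encoder mask>",
    "decoder_attention_mask": "<decoder mask>",
    "head_mask": "<shared head mask>",  # no prefix: common to both stacks
}

kwargs_common = {
    k: v for k, v in model_kwargs.items()
    if not k.startswith("encoder_") and not k.startswith("decoder_")
}
kwargs_encoder = dict(kwargs_common)
kwargs_decoder = dict(kwargs_common)
kwargs_encoder.update({k[len("encoder_"):]: v for k, v in model_kwargs.items() if k.startswith("encoder_")})
kwargs_decoder.update({k[len("decoder_"):]: v for k, v in model_kwargs.items() if k.startswith("decoder_")})

print(sorted(kwargs_encoder))  # ['attention_mask', 'head_mask', 'token_type_ids']
print(sorted(kwargs_decoder))  # ['attention_mask', 'head_mask']

On the scoring side, when alpha is non-zero the candidate scores are divided by the length penalty ((5.0 + (self._step + 1)) / 6.0) ** self.alpha from Wu et al. (2016); with the default alpha = 0 the penalty is skipped and the raw summed log-probabilities are compared directly.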
transformers/tests/beam_search_tests.py   new file  0 → 100644  (view file @ 9660ba1c)

from collections import namedtuple
import unittest

import numpy as np
import torch

from transformers.generate import BeamSearch
from transformers import PreTrainedEncoderDecoder


StubTokenizer = namedtuple("Tokenizer", ["bos_token_id", "eos_token_id", "pad_token_id"])
StubTransformer = namedtuple("Transformer", ["encoder", "decoder"])


class BeamSearchtest(unittest.TestCase):
    def test_beam_search_encoder_decoder_integration(self):
        """ We make sure that no internal change in the PreTrainedEncoderDecoder
        class will break the integration with the beam search.
        """
        model = PreTrainedEncoderDecoder("encoder", "decoder")
        tokenizer = StubTokenizer(0, 1, 2)
        try:
            _ = BeamSearch(
                model=model,
                tokenizer=tokenizer,
                batch_size=1,
                beam_size=1,
                min_length=1,
                max_length=1,
                alpha=0,
                block_repeating_trigrams=False,
            )
        except:
            self.fail("Instantiating BeamSearch with a PreTrainedEncoderDecoder failed.")

    def test_beam_search_min_length(self):
        """ We keep predicting the end_token for the first beam and check that
        it is not marked as finished until the beam has reached the minimum
        length. """
        eos_idx = 3
        vocab_size = 10
        batch_size = 3
        beam_size = 2
        min_length = 5
        beam = BeamSearch(
            model=StubTransformer("encoder", "decoder"),
            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=eos_idx, pad_token_id=2),
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=5,
            max_length=10,
            alpha=0,
            block_repeating_trigrams=False,
        )

        # To test that the minimum length is correctly enforced we constantly
        # assign the highest probability to the [EOS] token (and assign lower
        # probabilities to some other tokens).
        # Since BeamSearch will reset its probability to 1e-20 as long as
        # min_length has not been reached, we need to reset the value between
        # steps.
        non_eos_idxs = [4, 5, 1, 8, 9]
        score_distribution = torch.log_softmax(torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0)

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
        log_probabilities[0, eos_idx] = score_distribution[0]
        for idx, score in zip(non_eos_idxs, score_distribution[1:]):
            log_probabilities[0, idx] = score

        for step in range(1, min_length + 2):
            log_probabilities[0, eos_idx] = score_distribution[0]

            # Beam #3 and #4 teminate at the first step since the probability
            # of the [EOS] token is -1e20 > -\infty so there are only two beams left.
            surviving_beams_rows = beam.grow(log_probabilities)
            if step < min_length:
                np.testing.assert_array_equal(
                    beam.growing_beams.numpy(),
                    np.repeat(np.array([[0] + [4] * step]), 2, axis=0),
                )
            elif step == min_length:
                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
                self.assertTrue(beam.is_done)
                break

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_max_length(self):
        """ We keep predicting the same non-EOS token until we reach the
        maximum permitted length """
        batch_size = 3
        beam_size = 2
        max_length = 5
        vocab_size = 10
        beam = BeamSearch(
            model=StubTransformer("encoder", "decoder"),
            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=False,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))

        # To test that beam search enforces the max length constraint we
        # keep giving the highest probability to a token that is not the
        # [EOS] token.
        # The beam search will stop at max_length-1, assuming that one would
        # add the [EOS] token at the end of the returned sequence.
        token_idxs = [3, 4, 5]
        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
        for idx, score in zip(token_idxs, score_distribution):
            log_probabilities[:, idx] = score

        for step in range(1, max_length + 2):
            surviving_beams_rows = beam.grow(log_probabilities)
            if step + 1 < max_length:
                self.assertFalse(beam.is_done)
            elif step + 1 == max_length:
                # Now [EOS] is the most probable token
                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
                self.assertTrue(beam.is_done)
                break

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_block_repeating_trigrams(self):
        """ We make sure that the beams that contain repeating trigrams are removed. """
        batch_size = 3
        beam_size = 2
        max_length = 10
        vocab_size = 10
        beam = BeamSearch(
            model=StubTransformer("encoder", "decoder"),
            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=True,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))

        # To test that BeamSearch enforces the 3-gram constraint we give the
        # highest probably to the same tokens in a cyclic fashion and make sure
        # they disappear once the cycle has completed.
        token_idxs = [3, 4, 5]
        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
        for idx, score in zip(token_idxs, score_distribution):
            log_probabilities[:, idx] = score

        for step in range(1, max_length + 2):
            # Rotate the probabilities at each step
            for idx in token_idxs:
                score = score_distribution[(idx + step) % 3]
                log_probabilities[::beam_size, idx] = score

            surviving_beams_rows = beam.grow(log_probabilities)
            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

            if step < 7:
                self.assertFalse(
                    np.array_equal(
                        log_probabilities.numpy()[0, :],
                        np.array([-1e20] * vocab_size, dtype="float32"),
                    )
                )
            if step == 7:
                np.testing.assert_array_equal(
                    log_probabilities.numpy()[0, :],
                    np.array([-1e20] * vocab_size, dtype="float32"),
                )

    def test_beam_search_example_for_one_step(self):
        """ We test that the predictions for one step of growth are correct. """
        batch_size = 2
        beam_size = 2
        max_length = 10
        vocab_size = 5
        beam = BeamSearch(
            model=StubTransformer("encoder", "decoder"),
            tokenizer=StubTokenizer(bos_token_id=0, eos_token_id=1, pad_token_id=2),
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=False,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
        log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
        log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)

        # First pass
        surviving_beams_rows = beam.grow(log_probabilities)
        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
        np.testing.assert_array_equal(
            beam.growing_beams.numpy(), np.array([[0, 3], [0, 4], [0, 4], [0, 3]])
        )
        self.assertFalse(beam.is_done)

        # Second pass
        surviving_beams_rows = beam.grow(log_probabilities)
        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
        np.testing.assert_array_equal(
            beam.growing_beams.numpy(),
            np.array([[0, 3, 3], [0, 3, 4], [0, 4, 4], [0, 4, 3]]),
        )
        self.assertFalse(beam.is_done)


if __name__ == "__name__":
    unittest.main()