chenpangpang / transformers · Commits

Commit c0443df5
Authored Dec 05, 2019 by Rémi Louf; committed by Julien Chaumond on Dec 09, 2019
Parent: 2403a665

    remove beam search

Showing 2 changed files with 0 additions and 619 deletions (+0 −619):

    transformers/generate/beam_search.py      +0 −376
    transformers/tests/beam_search_tests.py   +0 −243
transformers/generate/beam_search.py (deleted, 100644 → 0)
# coding=utf-8
# MIT License
# Copyright (c) 2017-Present OpenNMT
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Use Beam Search to generate sequences using encoder-decoder models.
"""
import torch
from torch import nn

import logging


logger = logging.getLogger(__name__)


class BeamSearch(object):
    def __init__(
        self,
        model,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        batch_size,
        beam_size,
        min_length,
        max_length,
        alpha=0,
        block_repeating_trigrams=True,
        device=torch.device("cpu"),
    ):
        r"""
        Inputs:
            **model**: instance of ``transformers.PreTrainedEncoderDecoder``
                The pretrained encoder-decoder model that will be used to generate the sequences.
            **bos_token_id**: int
                Id that is used by the tokenizer to represent the beginning of a sentence.
            **pad_token_id**: int
                Id that is used by the tokenizer for padding.
            **eos_token_id**: int
                Id that is used by the tokenizer to represent the end of a sentence.
            **batch_size**: (`optional`) int
                Batch size of the inputs. The value is set automatically when the beam search is called.
            **beam_size**: int
                Number of beams that are used for each element of the batch.
            **min_length**: int
                Minimum number of steps performed by the beam search before terminating.
            **max_length**: int
                Maximum number of steps performed by the beam search. Any beam that has not finished
                will return its current solution with the highest probability. The sequence that is
                returned has a length of max_length-1 to account for the end token that is subsequently added.
            **alpha**: float
                Parameter of the length penalty. Read the documentation of the `_length_penalty` method for more details.
            **block_repeating_trigrams**: bool
                Whether to block sequences that have repeating 3-grams.
        """
        super(BeamSearch, self).__init__()
        self.model = model
        self.device = next(model.parameters()).device  # only works if all parameters of the model are stored on a single GPU

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

        self.batch_size = batch_size
        self.beam_size = beam_size
        self.min_length = min_length
        self.max_length = max_length

        self.block_repeating_trigram = block_repeating_trigrams
        self.apply_length_penalty = False if alpha == 0 else True
        self.alpha = alpha

        self._init_beam_state(batch_size)

    def __len__(self):
        return self.growing_beams.size(1)

    def _init_beam_state(self, batch_size):
        """ (re-)Initialize the state of the beams. """
        self.hypotheses = [[] for _ in range(batch_size)]
        self.batch_offset = torch.arange(batch_size, dtype=torch.long, device=self.device)
        self.beam_offset = torch.arange(
            0,
            batch_size * self.beam_size,
            step=self.beam_size,
            dtype=torch.long,
            device=self.device,
        )
        self.growing_beams = torch.full(
            (batch_size * self.beam_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=self.device,
        )
        self.topk_log_probabilities = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1),
            dtype=torch.float,
            device=self.device,
        ).repeat(batch_size)
        self.results = {
            "predictions": [[] for _ in range(batch_size)],
            "scores": [[] for _ in range(batch_size)],
        }
        self._step = 0
        self.is_done = False

    def __call__(self, encoder_input_ids, **model_kwargs):
        """ Generate a sequence using Beam Search. """
        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
        # that apply to the model as a whole.
        # We let the specific kwargs override the common ones in case of conflict.
        kwargs_common = {
            argument: value
            for argument, value in model_kwargs.items()
            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
        }
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder = kwargs_common.copy()
        kwargs_encoder.update(
            {
                argument[len("encoder_"):]: value
                for argument, value in model_kwargs.items()
                if argument.startswith("encoder_")
            }
        )
        kwargs_decoder.update(
            {
                argument[len("decoder_"):]: value
                for argument, value in model_kwargs.items()
                if argument.startswith("decoder_")
            }
        )

        # forward pass on the encoder
        encoder_outputs = self.model.encoder(encoder_input_ids, **kwargs_encoder)
        encoder_hidden_states = encoder_outputs[0]
        kwargs_decoder["encoder_hidden_states"] = tile(encoder_hidden_states, self.beam_size, dim=0)
        try:
            kwargs_decoder["encoder_attention_mask"] = tile(
                kwargs_encoder["attention_mask"], self.beam_size, dim=0
            )
        except KeyError:  # no encoder attention mask was passed
            pass
        kwargs_decoder["state"].src = tile(kwargs_decoder["state"].src, self.beam_size, dim=0)

        # grow the beam iteratively
        batch_size, block_size = encoder_input_ids.size()
        self._init_beam_state(batch_size)
        for step in range(self.max_length):
            decoder_input = fit_to_block_size(self.growing_beams, block_size, self.pad_token_id)
            kwargs_decoder["attention_mask"] = build_mask(decoder_input, self.pad_token_id)
            outputs, state = self.model.decoder(decoder_input, **kwargs_decoder)
            next_token_scores = outputs[0][:, -1, :].squeeze(1)
            log_probabilities = torch.nn.functional.log_softmax(next_token_scores, dim=0)

            surviving_beams_rows = self.grow(log_probabilities)
            if self.is_done:
                break

            kwargs_decoder["encoder_hidden_states"] = kwargs_decoder["encoder_hidden_states"].index_select(
                0, surviving_beams_rows
            )
            try:
                kwargs_decoder["encoder_attention_mask"] = kwargs_decoder["encoder_attention_mask"].index_select(
                    0, surviving_beams_rows
                )
            except KeyError:  # no encoder attention mask was passed
                pass
            kwargs_decoder["state"] = state

        return self.results

    def grow(self, log_probabilities):
        """ Grow the beams by one step. """
        self._step += 1

        # The number of beams changes as some beams finish so we define _B
        vocab_size = log_probabilities.size(-1)
        _B = log_probabilities.size(0) // self.beam_size

        # Multiply each beam probability with the probability of the
        # next token (conditioned on the words in the beam).
        log_probabilities += self.topk_log_probabilities.view(-1, 1)

        self._enforce_min_length(log_probabilities)
        if self.block_repeating_trigram:
            self._remove_beams_with_repeating_trigrams(log_probabilities, _B)

        # Find the `beam_size` (previous_beam + token) combinations with
        # the highest score
        self.topk_log_probabilities, topk_ids = torch.topk(
            log_probabilities.view(_B, self.beam_size * vocab_size), self.beam_size, dim=1
        )

        # Apply the length penalty. The +1 accounts for the [EOS] token
        # that will be added if the beam ends.
        topk_scores = self.topk_log_probabilities
        if self.apply_length_penalty:
            topk_scores /= self._length_penalty()

        # Retrieve the corresponding beam and token id;
        # topk_token_ids[i] will be added to topk_beam_ids[i]
        topk_beam_ids = topk_ids.div(vocab_size)
        topk_token_ids = topk_ids.fmod(vocab_size)

        # Retrieve the row index of the surviving beams in the original
        # view of the log_probabilities tensor
        surviving_beams_per_batch = topk_beam_ids + self.beam_offset[:_B].view(-1, 1)
        surviving_beams_rows = surviving_beams_per_batch.view(-1)

        # Append the last predictions
        self.growing_beams = torch.cat(
            [
                self.growing_beams.index_select(0, surviving_beams_rows),
                topk_token_ids.view(-1, 1),
            ],
            1,
        )

        # Check whether any of the beam searches ended during this growth step,
        # and whether the top (most probable) beam ended for some element of the batch.
        is_finished = topk_token_ids.eq(self.eos_token_id)
        self._enforce_max_length(is_finished)
        if is_finished.any():
            non_finished = self._cut_finished(is_finished, topk_scores)
            self.batch_offset = self.batch_offset.index_select(0, non_finished)
            surviving_beams_per_batch = surviving_beams_per_batch.index_select(0, non_finished)
            self.topk_log_probabilities = self.topk_log_probabilities.index_select(0, non_finished)
            surviving_beams_rows = surviving_beams_per_batch.view(-1)
            self.growing_beams = self.growing_beams.index_select(0, surviving_beams_rows)

        return surviving_beams_rows

    def _cut_finished(self, is_finished, topk_scores):
        """ Save the finished searches and cut the corresponding sequences off
        the beams. """
        is_top_beam_finished = is_finished[:, 0].eq(True)

        # Save the finished searches
        predictions = self.growing_beams.view(-1, self.beam_size, self.growing_beams.size(1))
        for i in range(is_finished.size(0)):
            if is_top_beam_finished[i]:
                is_finished[i].fill_(1)
            finished_hyp = is_finished[i].nonzero().view(-1)

            # Store the finished beams as a (score, prediction) hypothesis.
            b = self.batch_offset[i]
            for j in finished_hyp:
                self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))

            # If the batch reached the end, save the best hypotheses
            # in terms of length-penalized score.
            if is_top_beam_finished[i]:
                best_score, best_prediction = max(self.hypotheses[b], key=lambda x: x[0])
                self.results["scores"][b].append(best_score)
                self.results["predictions"][b].append(best_prediction)

        non_finished = is_top_beam_finished.eq(False).nonzero().view(-1)
        if len(non_finished) == 0:
            self.is_done = True

        return non_finished

    def _remove_beams_with_repeating_trigrams(self, log_probabilities, _B):
        if self._step + 1 > 3:  # [BOS] does not count
            for i in range(_B * self.beam_size):
                tokens = self.growing_beams[i]
                trigrams = [
                    (tokens[j - 1], tokens[j], tokens[j + 1]) for j in range(1, len(self) - 1)
                ]
                last_trigram = tuple(trigrams[-1])
                if last_trigram in trigrams[:-1]:
                    log_probabilities[i] = -1e20

    def _enforce_min_length(self, log_probabilities):
        if self._step < self.min_length:
            log_probabilities[:, self.eos_token_id] = -1e20

    def _enforce_max_length(self, is_finished):
        # +1 because we will need to add an [EOS] token
        if self._step + 1 == self.max_length:
            is_finished.fill_(1)

    def _length_penalty(self):
        """ The calculation of the length penalty follows that of [1].

        [1] Wu, Yonghui, et al. "Google's neural machine translation system:
        Bridging the gap between human and machine translation." arXiv preprint
        arXiv:1609.08144 (2016).
        """
        return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha


def tile(x, count, dim=0):
    """
    Tiles `x` along dimension `dim` `count` times.

    Example:
        >> ex = torch.tensor([[1, 2], [3, 4]])
        >> tile(ex, 2, 0)
        torch.Tensor([[1, 2], [1, 2], [3, 4], [3, 4]])
    """
    perm = list(range(len(x.size())))
    if dim != 0:
        perm[0], perm[dim] = perm[dim], perm[0]
        x = x.permute(perm).contiguous()
    out_size = list(x.size())
    out_size[0] *= count
    batch = x.size(0)
    x = (
        x.view(batch, -1)
        .transpose(0, 1)
        .repeat(count, 1)
        .transpose(0, 1)
        .contiguous()
        .view(*out_size)
    )
    if dim != 0:
        x = x.permute(perm).contiguous()
    return x


def fit_to_block_size(sequence, block_size, pad_token_id):
    """ Adapt the source and target sequences' lengths to the block size.
    If the sequence is shorter we append padding tokens to the right.
    """
    padded_sequence = torch.full(
        (sequence.size(0), block_size),
        pad_token_id,
        dtype=torch.long,
        device=sequence.device,
    )
    padded_sequence[:, : sequence.size(1)] = sequence
    return padded_sequence


def build_mask(sequence, pad_token_id):
    """ Builds the mask. The attention mechanism will only attend to positions
    with value 1. """
    mask = torch.ones_like(sequence)
    idx_pad_tokens = sequence == pad_token_id
    mask[idx_pad_tokens] = 0
    return mask
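For orientation before the tests, here is a small, hedged usage sketch of the module removed above. It assumes the module is still importable as transformers.generate.beam_search (the path shown in this diff); the toy model and the concrete argument values are illustrative only and mirror the stub-model trick used by the test file below.

# Illustrative sketch, not part of this commit: it assumes the deleted module is
# importable as transformers.generate.beam_search and substitutes a toy model
# for a real PreTrainedEncoderDecoder.
import torch
from torch import nn

from transformers.generate.beam_search import BeamSearch, tile


class ToyModel(nn.Module):
    """Minimal stand-in: BeamSearch only reads the device of the first parameter."""

    def __init__(self):
        super().__init__()
        self.encoder = None
        self.decoder = None
        self.dummy = nn.Parameter(torch.zeros(1))


beam = BeamSearch(
    model=ToyModel(),
    bos_token_id=0,
    pad_token_id=2,
    eos_token_id=1,
    batch_size=2,
    beam_size=4,
    min_length=3,
    max_length=20,
    alpha=0.6,  # a non-zero alpha switches on the GNMT-style length penalty
)

# `tile` repeats each row `beam_size` times along dim 0, so every batch element
# carries `beam_size` copies of its encoder states during decoding.
encoder_hidden_states = torch.randn(2, 7, 16)
print(tile(encoder_hidden_states, beam.beam_size, dim=0).shape)  # torch.Size([8, 7, 16])

# The length penalty that divides the running scores once a beam holds
# `step + 1` tokens (Wu et al., 2016): ((5 + (step + 1)) / 6) ** alpha
step = 9
print(((5.0 + (step + 1)) / 6.0) ** 0.6)

A real decoding run would go through BeamSearch.__call__, which routes keyword arguments prefixed with encoder_ / decoder_ to the respective sub-models and additionally expects the decoder-side state object whose .src field it tiles; that object is specific to the decoder this class was written for, which is why the sketch stops short of a full call.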
transformers/tests/beam_search_tests.py (deleted, 100644 → 0)
from collections import namedtuple
import unittest

import pytest
import numpy as np
import torch
from torch import nn

from transformers.generate import BeamSearch
from transformers import PreTrainedEncoderDecoder


class StubTransformer(nn.Module):
    def __init__(self):
        self.encoder = None
        self.decoder = None
        self._parameters = {"dummy": torch.tensor([1])}

    def forward(self):
        pass


class BeamSearchTest(unittest.TestCase):
    def test_beam_search_encoder_decoder_integration(self):
        """ We make sure that no internal change in the PreTrainedEncoderDecoder
        class will break the integration with the beam search.
        """
        model = StubTransformer()
        try:
            _ = BeamSearch(
                model=model,
                bos_token_id=0,
                eos_token_id=1,
                pad_token_id=2,
                batch_size=1,
                beam_size=1,
                min_length=1,
                max_length=1,
                alpha=0,
                block_repeating_trigrams=False,
            )
        except Exception:
            self.fail("Instantiating BeamSearch with a PreTrainedEncoderDecoder failed.")

    def test_beam_search_min_length(self):
        """ We keep predicting the end_token for the first beam and check that
        it is not marked as finished until the beam has reached the minimum
        length. """
        eos_idx = 3
        vocab_size = 10
        batch_size = 3
        beam_size = 2
        min_length = 5

        beam = BeamSearch(
            model=StubTransformer(),
            bos_token_id=0,
            eos_token_id=eos_idx,
            pad_token_id=2,
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=5,
            max_length=10,
            alpha=0,
            block_repeating_trigrams=False,
        )

        # To test that the minimum length is correctly enforced we constantly
        # assign the highest probability to the [EOS] token (and assign lower
        # probabilities to some other tokens).
        # Since BeamSearch will reset its probability to -1e20 as long as
        # min_length has not been reached, we need to reset the value between
        # steps.
        non_eos_idxs = [4, 5, 1, 8, 9]
        score_distribution = torch.log_softmax(torch.tensor([6.0, 5.0, 4.0, 3.0, 2.0, 1.0]), dim=0)

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
        log_probabilities[0, eos_idx] = score_distribution[0]
        for idx, score in zip(non_eos_idxs, score_distribution[1:]):
            log_probabilities[0, idx] = score

        for step in range(1, min_length + 2):
            log_probabilities[0, eos_idx] = score_distribution[0]

            # Beam #3 and #4 terminate at the first step since the probability
            # of the [EOS] token is -1e20 > -\infty so there are only two beams left.
            # The top beam (most likely) always ends with 4 until we reach min_length.
            surviving_beams_rows = beam.grow(log_probabilities)
            if step < min_length:
                np.testing.assert_array_equal(
                    beam.growing_beams.numpy()[0, :], np.array([0] + [4] * step)
                )
            elif step == min_length:
                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
                self.assertTrue(beam.is_done)
                break

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_max_length(self):
        """ We keep predicting the same non-EOS token until we reach the
        maximum permitted length. """
        batch_size = 3
        beam_size = 2
        max_length = 5
        vocab_size = 10

        beam = BeamSearch(
            model=StubTransformer(),
            bos_token_id=0,
            eos_token_id=1,
            pad_token_id=2,
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=False,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))

        # To test that beam search enforces the max length constraint we
        # keep giving the highest probability to a token that is not the
        # [EOS] token.
        # The beam search will stop at max_length-1, assuming that one would
        # add the [EOS] token at the end of the returned sequence.
        token_idxs = [3, 4, 5]
        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
        for idx, score in zip(token_idxs, score_distribution):
            log_probabilities[:, idx] = score

        for step in range(1, max_length + 2):
            surviving_beams_rows = beam.grow(log_probabilities)
            if step + 1 < max_length:
                self.assertFalse(beam.is_done)
            elif step + 1 == max_length:
                # Now [EOS] is the most probable token
                np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([]))
                self.assertTrue(beam.is_done)
                break

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_block_repeating_trigrams(self):
        """ We make sure that the beams that contain repeating trigrams are removed. """
        batch_size = 3
        beam_size = 2
        max_length = 10
        vocab_size = 10

        beam = BeamSearch(
            model=StubTransformer(),
            bos_token_id=0,
            eos_token_id=1,
            pad_token_id=2,
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=True,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))

        # To test that BeamSearch enforces the 3-gram constraint we give the
        # highest probability to the same tokens in a cyclic fashion and make
        # sure they disappear once the cycle has completed.
        token_idxs = [3, 4, 5]
        score_distribution = torch.log_softmax(torch.tensor([10.0, 6.0, 4.0]), dim=0)
        for idx, score in zip(token_idxs, score_distribution):
            log_probabilities[:, idx] = score

        for step in range(1, max_length + 2):
            # Rotate the probabilities at each step
            for idx in token_idxs:
                score = score_distribution[(idx + step) % 3]
                log_probabilities[::beam_size, idx] = score

            surviving_beams_rows = beam.grow(log_probabilities)

            if step < 7:
                self.assertFalse(
                    np.array_equal(
                        log_probabilities.numpy()[0, :],
                        np.array([-1e20] * vocab_size, dtype="float32"),
                    )
                )
            if step == 7:
                np.testing.assert_array_equal(
                    log_probabilities.numpy()[0, :],
                    np.array([-1e20] * vocab_size, dtype="float32"),
                )

            log_probabilities = log_probabilities.index_select(0, surviving_beams_rows)

    def test_beam_search_example_for_one_step(self):
        """ We test that the predictions for one step of growth are correct. """
        batch_size = 2
        beam_size = 2
        max_length = 10
        vocab_size = 5

        beam = BeamSearch(
            model=StubTransformer(),
            bos_token_id=0,
            eos_token_id=1,
            pad_token_id=2,
            batch_size=batch_size,
            beam_size=beam_size,
            min_length=2,
            max_length=max_length,
            alpha=0,
            block_repeating_trigrams=False,
        )

        log_probabilities = torch.full((batch_size * beam_size, vocab_size), float("-inf"))
        log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
        log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)

        # First pass
        surviving_beams_rows = beam.grow(log_probabilities)
        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
        np.testing.assert_array_equal(
            beam.growing_beams.numpy(), np.array([[0, 3], [0, 4], [0, 4], [0, 3]])
        )
        self.assertFalse(beam.is_done)

        # Second pass
        surviving_beams_rows = beam.grow(log_probabilities)
        np.testing.assert_array_equal(surviving_beams_rows.numpy(), np.array([0, 0, 2, 2]))
        np.testing.assert_array_equal(
            beam.growing_beams.numpy(),
            np.array([[0, 3, 3], [0, 3, 4], [0, 4, 4], [0, 4, 3]]),
        )
        self.assertFalse(beam.is_done)


if __name__ == "__main__":
    unittest.main()
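As a reading aid for test_beam_search_example_for_one_step above: with batch_size=2 and beam_size=2, rows 0-1 of log_probabilities belong to the first batch element and rows 2-3 to the second, and since the initial per-beam scores are [0, -inf] both surviving beams of each element are drawn from that element's first row, which is why the expected row indices are [0, 0, 2, 2]. The sketch below re-derives that first grow() step with plain tensors; it is an independent hand-check, not code from this commit.

# Hand-check of the first grow() step for batch_size=2, beam_size=2, vocab_size=5.
import torch

beam_size, vocab_size = 2, 5
log_probabilities = torch.full((4, vocab_size), float("-inf"))
log_probabilities[0, 3:] = torch.log_softmax(torch.tensor([2.0, 1.0]), dim=0)
log_probabilities[2, 3:] = torch.log_softmax(torch.tensor([1.0, 2.0]), dim=0)

# Initial per-beam scores: only the first beam of each batch element is live.
topk_log_probabilities = torch.tensor([0.0, float("-inf")] * 2)
scores = log_probabilities + topk_log_probabilities.view(-1, 1)

# Pick the best (beam, token) pairs per batch element, as grow() does.
topk_scores, topk_ids = scores.view(2, beam_size * vocab_size).topk(beam_size, dim=1)
topk_beam_ids = topk_ids // vocab_size  # which surviving beam each winner extends
topk_token_ids = topk_ids % vocab_size  # which token it appends

beam_offset = torch.arange(0, 4, step=beam_size)
surviving_beams_rows = (topk_beam_ids + beam_offset.view(-1, 1)).view(-1)
print(surviving_beams_rows)  # tensor([0, 0, 2, 2])
print(topk_token_ids)        # tensor([[3, 4], [4, 3]])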