chenpangpang / transformers · Commit 3d3e605a (unverified)

[cleanup] generate_beam_search comments (#5115)

Authored Jun 18, 2020 by Sam Shleifer; committed by GitHub on Jun 18, 2020. Parent: ca2d0f98.
Showing 2 changed files, with 14 additions and 18 deletions:
  src/transformers/modeling_tf_utils.py  (+3, -5)
  src/transformers/modeling_utils.py     (+11, -13)
src/transformers/modeling_tf_utils.py
...
@@ -1219,9 +1219,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     if len(next_sent_beam) == num_beams:
                         break

-                # Check if were done so that we can save a pad step if all(done)
+                # Check if we are done so that we can save a pad step if all(done)
                 done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
-                    tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
+                    tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len
                 )

                 # update next beam content
...
@@ -1509,7 +1509,7 @@ class BeamHypotheses(object):
         else:
             self.worst_score = min(score, self.worst_score)

-    def is_done(self, best_sum_logprobs, cur_len=None):
+    def is_done(self, best_sum_logprobs, cur_len):
         """
         If there are enough hypotheses and that none of the hypotheses being generated
         can become better than the worst one in the heap, then we are done with this sentence.
...
@@ -1520,8 +1520,6 @@ class BeamHypotheses(object):
         elif self.early_stopping:
             return True
         else:
-            if cur_len is None:
-                cur_len = self.max_length
             cur_score = best_sum_logprobs / cur_len ** self.length_penalty
             ret = self.worst_score >= cur_score
             return ret
...
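The is_done change above (mirrored below in modeling_utils.py) makes cur_len a required argument and deletes the cur_len = self.max_length fallback. A minimal runnable sketch of the resulting contract follows; only the else-branch of is_done is taken from this diff, while the class name, constructor fields, and the first guard are reconstructions for illustration.

# Toy sketch, not the library class: everything except the else-branch
# of is_done is assumed for context.
class BeamHypothesesSketch:
    def __init__(self, num_beams, length_penalty, early_stopping):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.worst_score = 1e9   # lowered as finished hypotheses are added
        self.beams = []          # finished (score, token_ids) pairs

    def is_done(self, best_sum_logprobs, cur_len):  # cur_len is now required
        if len(self.beams) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            # best score any still-open hypothesis could reach at this length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            return self.worst_score >= cur_score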
src/transformers/modeling_utils.py
...
@@ -1462,7 +1462,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             # for each sentence
             for batch_idx in range(batch_size):

-                # if we are done with this sentence
+                # if we are done with this sentence, add a pad token
                 if done[batch_idx]:
                     assert (
                         len(generated_hyps[batch_idx]) >= num_beams
@@ -1473,7 +1473,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
...
@@ -1473,7 +1473,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
next_batch_beam
.
extend
([(
0
,
pad_token_id
,
0
)]
*
num_beams
)
# pad the batch
next_batch_beam
.
extend
([(
0
,
pad_token_id
,
0
)]
*
num_beams
)
# pad the batch
continue
continue
# next sentence beam content
# next sentence beam content
, this will get added to next_batch_beam
next_sent_beam
=
[]
next_sent_beam
=
[]
# next tokens for this sentence
# next tokens for this sentence
...
@@ -1485,7 +1485,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
...
@@ -1485,7 +1485,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
token_id
=
beam_token_id
%
vocab_size
token_id
=
beam_token_id
%
vocab_size
effective_beam_id
=
batch_idx
*
num_beams
+
beam_id
effective_beam_id
=
batch_idx
*
num_beams
+
beam_id
# add to generated hypotheses if end of sentence
or last iteration
# add to generated hypotheses if end of sentence
if
(
eos_token_id
is
not
None
)
and
(
token_id
.
item
()
==
eos_token_id
):
if
(
eos_token_id
is
not
None
)
and
(
token_id
.
item
()
==
eos_token_id
):
# if beam_token does not belong to top num_beams tokens, it should not be added
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams
=
beam_token_rank
>=
num_beams
is_beam_token_worse_than_top_num_beams
=
beam_token_rank
>=
num_beams
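Aside: the two context lines kept above decode a flat index, since candidate scores are ranked over all beams' vocabularies at once. A toy worked example with invented numbers (the beam_id = beam_token_id // vocab_size line sits just above this hunk in the file and is assumed here):

# Invented numbers illustrating the index arithmetic in this hunk.
vocab_size, num_beams, batch_idx = 100, 3, 1
beam_token_id = 242                                  # flat index over num_beams * vocab_size
beam_id = beam_token_id // vocab_size                # 2: third beam of this sentence (assumed)
token_id = beam_token_id % vocab_size                # 42: the actual vocabulary token
effective_beam_id = batch_idx * num_beams + beam_id  # 5: this beam's row in input_ids
assert (beam_id, token_id, effective_beam_id) == (2, 42, 5)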
...
@@ -1495,22 +1495,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                             input_ids[effective_beam_id].clone(), beam_token_score.item(),
                         )
                     else:
-                        # add next predicted token if it is not eos_token
+                        # add next predicted token since it is not eos_token
                         next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

-                    # the beam for next step is full
+                    # once the beam for next step is full, don't add more tokens to it.
                     if len(next_sent_beam) == num_beams:
                         break

-                # Check if were done so that we can save a pad step if all(done)
+                # Check if we are done so that we can save a pad step if all(done)
                 done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
-                    next_scores[batch_idx].max().item(), cur_len=cur_len
+                    next_scores[batch_idx].max().item(), cur_len
                 )

                 # update next beam content
                 assert len(next_sent_beam) == num_beams, "Beam should always be full"
                 next_batch_beam.extend(next_sent_beam)
-                assert len(next_batch_beam) == num_beams * (batch_idx + 1)
+                assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"

             # stop when we are done with each sentence
             if all(done):
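The comments rewritten in this hunk spell out the invariant the asserts enforce: each sentence contributes exactly num_beams (score, token_id, effective_beam_id) triples per step, and a finished sentence is padded with (0, pad_token_id, 0). A self-contained toy illustration of that bookkeeping, with invented candidates (the real code ranks 2 * num_beams candidates taken from the model's scores):

# Toy bookkeeping demo: two sentences, the second already done; every step
# must add exactly num_beams triples per sentence to next_batch_beam.
num_beams, pad_token_id = 3, 0
done = [False, True]
candidates = [(-0.1, 42), (-0.5, 7), (-0.9, 13), (-1.2, 99), (-1.6, 3), (-2.0, 8)]

next_batch_beam = []
for batch_idx in range(len(done)):
    if done[batch_idx]:
        next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
        continue
    next_sent_beam = []  # this will get added to next_batch_beam
    for beam_id, (score, token_id) in enumerate(candidates):
        effective_beam_id = batch_idx * num_beams + beam_id
        next_sent_beam.append((score, token_id, effective_beam_id))
        if len(next_sent_beam) == num_beams:  # the beam for next step is full
            break
    assert len(next_sent_beam) == num_beams, "Beam should always be full"
    next_batch_beam.extend(next_sent_beam)
    assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"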
...
@@ -1537,7 +1537,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
             )

-        # finalize all open beam hypotheses and end to generated hypotheses
+        # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx in range(batch_size):
             if done[batch_idx]:
                 continue
...
@@ -1576,7 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             sent_lengths[effective_batch_idx] = len(best_hyp)
             best.append(best_hyp)

-        # shorter batches are filled with pad_token
+        # shorter batches are padded
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`Pad_token_id` has to be defined"
             sent_max_len = min(sent_lengths.max().item() + 1, max_length)
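For reference, a runnable miniature of the padding step the rewritten comment refers to; the decoded/torch.full lines mirror the code that follows this hunk in the file, trimmed and fed invented tensors:

# Miniature of the padding the comment describes: shorter finished
# hypotheses are filled with pad_token_id up to a common length.
import torch

pad_token_id, eos_token_id, max_length = 0, 2, 8
best = [torch.tensor([5, 8, 2]), torch.tensor([5, 9, 4, 7, 2])]  # invented hypotheses
sent_lengths = torch.tensor([len(h) for h in best])

if sent_lengths.min().item() != sent_lengths.max().item():
    assert pad_token_id is not None, "`Pad_token_id` has to be defined"
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    decoded = torch.full((len(best), sent_max_len), pad_token_id, dtype=torch.long)
    for i, hypo in enumerate(best):
        decoded[i, : sent_lengths[i]] = hypo
        if sent_lengths[i] < max_length:
            decoded[i, sent_lengths[i]] = eos_token_id
    print(decoded)  # rows: [5, 8, 2, 2, 0, 0] and [5, 9, 4, 7, 2, 2]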
...
@@ -1731,7 +1731,7 @@ class BeamHypotheses(object):
         else:
             self.worst_score = min(score, self.worst_score)

-    def is_done(self, best_sum_logprobs, cur_len=None):
+    def is_done(self, best_sum_logprobs, cur_len):
         """
         If there are enough hypotheses and that none of the hypotheses being generated
         can become better than the worst one in the heap, then we are done with this sentence.
...
@@ -1742,8 +1742,6 @@ class BeamHypotheses(object):
         elif self.early_stopping:
             return True
         else:
-            if cur_len is None:
-                cur_len = self.max_length
             cur_score = best_sum_logprobs / cur_len ** self.length_penalty
             ret = self.worst_score >= cur_score
             return ret
...
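A behavioral note on the cur_len change: with the deleted fallback, an is_done call that omitted cur_len normalized the best achievable score by self.max_length; both call sites in this diff now pass the true current length. A small numeric illustration with invented values:

# Invented numbers showing why normalizing by cur_len vs. max_length differs.
best_sum_logprobs, length_penalty = -6.0, 1.0
cur_len, max_length = 10, 20
worst_score = -0.5

score_now = best_sum_logprobs / cur_len ** length_penalty          # -0.6
score_fallback = best_sum_logprobs / max_length ** length_penalty  # -0.3

print(worst_score >= score_now)       # True: this sentence can stop early
print(worst_score >= score_fallback)  # False: the old fallback kept it open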