Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8f2f0f0f
Unverified
Commit
8f2f0f0f
authored
Feb 26, 2024
by
Raushan Turganbay
Committed by
GitHub
Feb 26, 2024
Browse files
Track each row separately for stopping criteria (#29116)
parent
ece1b62b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
45 deletions
+43
-45
src/transformers/generation/stopping_criteria.py
src/transformers/generation/stopping_criteria.py
+16
-10
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+16
-24
tests/generation/test_stopping_criteria.py
tests/generation/test_stopping_criteria.py
+11
-11
No files found.
src/transformers/generation/stopping_criteria.py
View file @
8f2f0f0f
...
...
@@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
Additional stopping criteria specific kwargs.
Return:
`bool`. `False` indicates we should continue, `True` indicates we should stop.
`torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size, 1)`), where `True` indicates we stop generation
for a particular row, and `False` indicates we should continue.
"""
...
...
@@ -42,7 +43,7 @@ class StoppingCriteria(ABC):
"""
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
raise
NotImplementedError
(
"StoppingCriteria needs to be subclassed"
)
...
...
@@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria):
self
.
max_position_embeddings
=
max_position_embeddings
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
cur_len
=
input_ids
.
shape
[
-
1
]
is_done
=
cur_len
>=
self
.
max_length
if
self
.
max_position_embeddings
is
not
None
and
not
is_done
and
cur_len
>=
self
.
max_position_embeddings
:
...
...
@@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
f
"maximum length (
{
self
.
max_position_embeddings
}
). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all."
)
return
is_done
return
torch
.
full
((
input_ids
.
shape
[
0
],),
is_done
,
device
=
input_ids
.
device
)
class
MaxNewTokensCriteria
(
StoppingCriteria
):
...
...
@@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria):
self
.
max_length
=
start_length
+
max_new_tokens
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
return
input_ids
.
shape
[
-
1
]
>=
self
.
max_length
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
is_done
=
input_ids
.
shape
[
-
1
]
>=
self
.
max_length
return
torch
.
full
((
input_ids
.
shape
[
0
],),
is_done
,
device
=
input_ids
.
device
)
class
MaxTimeCriteria
(
StoppingCriteria
):
...
...
@@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria):
self
.
initial_timestamp
=
time
.
time
()
if
initial_timestamp
is
None
else
initial_timestamp
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
return
time
.
time
()
-
self
.
initial_timestamp
>
self
.
max_time
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
is_done
=
time
.
time
()
-
self
.
initial_timestamp
>
self
.
max_time
return
torch
.
full
((
input_ids
.
shape
[
0
],),
is_done
,
device
=
input_ids
.
device
)
class StoppingCriteriaList(list):
    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """OR-combine all contained criteria per row: a row stops as soon as ANY
        criterion marks it done. Returns all-False when the list is empty."""
        is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
        for criteria in self:
            is_done = is_done | criteria(input_ids, scores, **kwargs)
        return is_done
@
property
def
max_length
(
self
)
->
Optional
[
int
]:
...
...
src/transformers/generation/utils.py
View file @
8f2f0f0f
...
...
@@ -2195,11 +2195,9 @@ class GenerationMixin:
)
# stop when each sentence is finished
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
unfinished_sequences
=
unfinished_sequences
&
~
stopping_criteria
(
input_ids
,
scores
)
# stop if we exceed the maximum length
if
stopping_criteria
(
input_ids
,
scores
):
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
if
this_peer_finished
and
not
synced_gpus
:
...
...
@@ -2478,14 +2476,12 @@ class GenerationMixin:
next_tokens
.
tile
(
eos_token_id_tensor
.
shape
[
0
],
1
).
ne
(
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
)
)
unfinished_sequences
=
unfinished_sequences
&
~
stopping_criteria
(
input_ids
,
scores
)
# stop when each sentence is finished
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
# stop if we exceed the maximum length
if
stopping_criteria
(
input_ids
,
scores
):
this_peer_finished
=
True
if
this_peer_finished
and
not
synced_gpus
:
break
...
...
@@ -2772,14 +2768,12 @@ class GenerationMixin:
next_tokens
.
tile
(
eos_token_id_tensor
.
shape
[
0
],
1
).
ne
(
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
)
)
unfinished_sequences
=
unfinished_sequences
&
~
stopping_criteria
(
input_ids
,
scores
)
# stop when each sentence is finished
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
# stop if we exceed the maximum length
if
stopping_criteria
(
input_ids
,
scores
):
this_peer_finished
=
True
if
this_peer_finished
and
not
synced_gpus
:
break
...
...
@@ -3169,7 +3163,7 @@ class GenerationMixin:
# increase cur_len
cur_len
=
cur_len
+
1
if
beam_scorer
.
is_done
or
stopping_criteria
(
input_ids
,
scores
):
if
beam_scorer
.
is_done
or
all
(
stopping_criteria
(
input_ids
,
scores
)
)
:
if
not
synced_gpus
:
break
else
:
...
...
@@ -3516,7 +3510,7 @@ class GenerationMixin:
# increase cur_len
cur_len
=
cur_len
+
1
if
beam_scorer
.
is_done
or
stopping_criteria
(
input_ids
,
scores
):
if
beam_scorer
.
is_done
or
all
(
stopping_criteria
(
input_ids
,
scores
)
)
:
if
not
synced_gpus
:
break
else
:
...
...
@@ -3912,7 +3906,7 @@ class GenerationMixin:
# increase cur_len
cur_len
=
cur_len
+
1
if
beam_scorer
.
is_done
or
stopping_criteria
(
input_ids
,
scores
):
if
beam_scorer
.
is_done
or
all
(
stopping_criteria
(
input_ids
,
scores
)
)
:
if
not
synced_gpus
:
break
else
:
...
...
@@ -4267,7 +4261,7 @@ class GenerationMixin:
# increase cur_len
cur_len
=
cur_len
+
1
if
constrained_beam_scorer
.
is_done
or
stopping_criteria
(
input_ids
,
scores
):
if
constrained_beam_scorer
.
is_done
or
all
(
stopping_criteria
(
input_ids
,
scores
)
)
:
if
not
synced_gpus
:
break
else
:
...
...
@@ -4657,14 +4651,12 @@ class GenerationMixin:
.
prod
(
dim
=
0
)
)
unfinished_sequences
=
unfinished_sequences
&
~
stopping_criteria
(
input_ids
,
scores
)
# stop when each sentence is finished
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
# stop if we exceed the maximum length
if
stopping_criteria
(
input_ids
,
scores
):
this_peer_finished
=
True
if
this_peer_finished
and
not
synced_gpus
:
break
...
...
tests/generation/test_stopping_criteria.py
View file @
8f2f0f0f
...
...
@@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase):
]
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
self
.
assertTrue
(
criteria
(
input_ids
,
scores
))
self
.
assertTrue
(
all
(
criteria
(
input_ids
,
scores
))
)
def
test_max_length_criteria
(
self
):
criteria
=
MaxLengthCriteria
(
max_length
=
10
)
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
self
.
assertTrue
(
criteria
(
input_ids
,
scores
))
self
.
assertTrue
(
all
(
criteria
(
input_ids
,
scores
))
)
def
test_max_new_tokens_criteria
(
self
):
criteria
=
MaxNewTokensCriteria
(
start_length
=
5
,
max_new_tokens
=
5
)
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
self
.
assertTrue
(
criteria
(
input_ids
,
scores
))
self
.
assertTrue
(
all
(
criteria
(
input_ids
,
scores
))
)
criteria_list
=
StoppingCriteriaList
([
criteria
])
self
.
assertEqual
(
criteria_list
.
max_length
,
10
)
...
...
@@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
criteria
=
MaxTimeCriteria
(
max_time
=
0.1
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
criteria
=
MaxTimeCriteria
(
max_time
=
0.1
,
initial_timestamp
=
time
.
time
()
-
0.2
)
self
.
assertTrue
(
criteria
(
input_ids
,
scores
))
self
.
assertTrue
(
all
(
criteria
(
input_ids
,
scores
))
)
def
test_validate_stopping_criteria
(
self
):
validate_stopping_criteria
(
StoppingCriteriaList
([
MaxLengthCriteria
(
10
)]),
10
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment