Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8f2f0f0f
Unverified
Commit
8f2f0f0f
authored
Feb 26, 2024
by
Raushan Turganbay
Committed by
GitHub
Feb 26, 2024
Browse files
Track each row separately for stopping criteria (#29116)
parent
ece1b62b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
45 deletions
+43
-45
src/transformers/generation/stopping_criteria.py
src/transformers/generation/stopping_criteria.py
+16
-10
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+16
-24
tests/generation/test_stopping_criteria.py
tests/generation/test_stopping_criteria.py
+11
-11
No files found.
src/transformers/generation/stopping_criteria.py
View file @
8f2f0f0f
...
@@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
...
@@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
Additional stopping criteria specific kwargs.
Additional stopping criteria specific kwargs.
Return:
Return:
`bool`. `False` indicates we should continue, `True` indicates we should stop.
`torch.BoolTensor`. (`torch.BoolTensor` of shape `(batch_size,)`), where `True` indicates we stop generation
for a particular row, `False` indicates we should continue.
"""
"""
...
@@ -42,7 +43,7 @@ class StoppingCriteria(ABC):
...
@@ -42,7 +43,7 @@ class StoppingCriteria(ABC):
"""
"""
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
raise
NotImplementedError
(
"StoppingCriteria needs to be subclassed"
)
raise
NotImplementedError
(
"StoppingCriteria needs to be subclassed"
)
...
@@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria):
...
@@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria):
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
cur_len
=
input_ids
.
shape
[
-
1
]
cur_len
=
input_ids
.
shape
[
-
1
]
is_done
=
cur_len
>=
self
.
max_length
is_done
=
cur_len
>=
self
.
max_length
if
self
.
max_position_embeddings
is
not
None
and
not
is_done
and
cur_len
>=
self
.
max_position_embeddings
:
if
self
.
max_position_embeddings
is
not
None
and
not
is_done
and
cur_len
>=
self
.
max_position_embeddings
:
...
@@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
...
@@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
f
"maximum length (
{
self
.
max_position_embeddings
}
). Depending on the model, you may observe "
f
"maximum length (
{
self
.
max_position_embeddings
}
). Depending on the model, you may observe "
"exceptions, performance degradation, or nothing at all."
"exceptions, performance degradation, or nothing at all."
)
)
return
is_done
return
torch
.
full
((
input_ids
.
shape
[
0
],),
is_done
,
device
=
input_ids
.
device
)
class
MaxNewTokensCriteria
(
StoppingCriteria
):
class
MaxNewTokensCriteria
(
StoppingCriteria
):
...
@@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria):
...
@@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria):
self
.
max_length
=
start_length
+
max_new_tokens
self
.
max_length
=
start_length
+
max_new_tokens
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
return
input_ids
.
shape
[
-
1
]
>=
self
.
max_length
is_done
=
input_ids
.
shape
[
-
1
]
>=
self
.
max_length
return
torch
.
full
((
input_ids
.
shape
[
0
],),
is_done
,
device
=
input_ids
.
device
)
class
MaxTimeCriteria
(
StoppingCriteria
):
class
MaxTimeCriteria
(
StoppingCriteria
):
...
@@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria):
...
@@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria):
self
.
initial_timestamp
=
time
.
time
()
if
initial_timestamp
is
None
else
initial_timestamp
self
.
initial_timestamp
=
time
.
time
()
if
initial_timestamp
is
None
else
initial_timestamp
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
return
time
.
time
()
-
self
.
initial_timestamp
>
self
.
max_time
is_done
=
time
.
time
()
-
self
.
initial_timestamp
>
self
.
max_time
return
torch
.
full
((
input_ids
.
shape
[
0
],),
is_done
,
device
=
input_ids
.
device
)
class
StoppingCriteriaList
(
list
):
class
StoppingCriteriaList
(
list
):
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
@
add_start_docstrings
(
STOPPING_CRITERIA_INPUTS_DOCSTRING
)
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
bool
:
def
__call__
(
self
,
input_ids
:
torch
.
LongTensor
,
scores
:
torch
.
FloatTensor
,
**
kwargs
)
->
torch
.
BoolTensor
:
return
any
(
criteria
(
input_ids
,
scores
,
**
kwargs
)
for
criteria
in
self
)
is_done
=
torch
.
full
((
input_ids
.
shape
[
0
],),
False
,
device
=
input_ids
.
device
)
for
criteria
in
self
:
is_done
=
is_done
|
criteria
(
input_ids
,
scores
,
**
kwargs
)
return
is_done
@
property
@
property
def
max_length
(
self
)
->
Optional
[
int
]:
def
max_length
(
self
)
->
Optional
[
int
]:
...
...
src/transformers/generation/utils.py
View file @
8f2f0f0f
...
@@ -2194,12 +2194,10 @@ class GenerationMixin:
...
@@ -2194,12 +2194,10 @@ class GenerationMixin:
next_tokens
.
tile
(
eos_token_id_tensor
.
shape
[
0
],
1
).
ne
(
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
)
next_tokens
.
tile
(
eos_token_id_tensor
.
shape
[
0
],
1
).
ne
(
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
)
)
)
# stop when each sentence is finished
# stop when each sentence is finished
if
unfinished_sequences
.
max
()
==
0
:
unfinished_sequences
=
unfinished_sequences
&
~
stopping_criteria
(
input_ids
,
scores
)
this_peer_finished
=
True
# stop if we exceed the maximum length
if
unfinished_sequences
.
max
()
==
0
:
if
stopping_criteria
(
input_ids
,
scores
):
this_peer_finished
=
True
this_peer_finished
=
True
if
this_peer_finished
and
not
synced_gpus
:
if
this_peer_finished
and
not
synced_gpus
:
...
@@ -2478,12 +2476,10 @@ class GenerationMixin:
...
@@ -2478,12 +2476,10 @@ class GenerationMixin:
next_tokens
.
tile
(
eos_token_id_tensor
.
shape
[
0
],
1
).
ne
(
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
)
next_tokens
.
tile
(
eos_token_id_tensor
.
shape
[
0
],
1
).
ne
(
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
)
)
)
# stop when each sentence is finished
unfinished_sequences
=
unfinished_sequences
&
~
stopping_criteria
(
input_ids
,
scores
)
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
# stop
if we exceed the maximum length
# stop
when each sentence is finished
if
stopping_criteria
(
input_ids
,
scores
)
:
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
this_peer_finished
=
True
if
this_peer_finished
and
not
synced_gpus
:
if
this_peer_finished
and
not
synced_gpus
:
...
@@ -2772,12 +2768,10 @@ class GenerationMixin:
...
@@ -2772,12 +2768,10 @@ class GenerationMixin:
next_tokens
.
tile
(
eos_token_id_tensor
.
shape
[
0
],
1
).
ne
(
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
)
next_tokens
.
tile
(
eos_token_id_tensor
.
shape
[
0
],
1
).
ne
(
eos_token_id_tensor
.
unsqueeze
(
1
)).
prod
(
dim
=
0
)
)
)
# stop when each sentence is finished
unfinished_sequences
=
unfinished_sequences
&
~
stopping_criteria
(
input_ids
,
scores
)
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
# stop
if we exceed the maximum length
# stop
when each sentence is finished
if
stopping_criteria
(
input_ids
,
scores
)
:
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
this_peer_finished
=
True
if
this_peer_finished
and
not
synced_gpus
:
if
this_peer_finished
and
not
synced_gpus
:
...
@@ -3169,7 +3163,7 @@ class GenerationMixin:
...
@@ -3169,7 +3163,7 @@ class GenerationMixin:
# increase cur_len
# increase cur_len
cur_len
=
cur_len
+
1
cur_len
=
cur_len
+
1
if
beam_scorer
.
is_done
or
stopping_criteria
(
input_ids
,
scores
):
if
beam_scorer
.
is_done
or
all
(
stopping_criteria
(
input_ids
,
scores
)
)
:
if
not
synced_gpus
:
if
not
synced_gpus
:
break
break
else
:
else
:
...
@@ -3516,7 +3510,7 @@ class GenerationMixin:
...
@@ -3516,7 +3510,7 @@ class GenerationMixin:
# increase cur_len
# increase cur_len
cur_len
=
cur_len
+
1
cur_len
=
cur_len
+
1
if
beam_scorer
.
is_done
or
stopping_criteria
(
input_ids
,
scores
):
if
beam_scorer
.
is_done
or
all
(
stopping_criteria
(
input_ids
,
scores
)
)
:
if
not
synced_gpus
:
if
not
synced_gpus
:
break
break
else
:
else
:
...
@@ -3912,7 +3906,7 @@ class GenerationMixin:
...
@@ -3912,7 +3906,7 @@ class GenerationMixin:
# increase cur_len
# increase cur_len
cur_len
=
cur_len
+
1
cur_len
=
cur_len
+
1
if
beam_scorer
.
is_done
or
stopping_criteria
(
input_ids
,
scores
):
if
beam_scorer
.
is_done
or
all
(
stopping_criteria
(
input_ids
,
scores
)
)
:
if
not
synced_gpus
:
if
not
synced_gpus
:
break
break
else
:
else
:
...
@@ -4267,7 +4261,7 @@ class GenerationMixin:
...
@@ -4267,7 +4261,7 @@ class GenerationMixin:
# increase cur_len
# increase cur_len
cur_len
=
cur_len
+
1
cur_len
=
cur_len
+
1
if
constrained_beam_scorer
.
is_done
or
stopping_criteria
(
input_ids
,
scores
):
if
constrained_beam_scorer
.
is_done
or
all
(
stopping_criteria
(
input_ids
,
scores
)
)
:
if
not
synced_gpus
:
if
not
synced_gpus
:
break
break
else
:
else
:
...
@@ -4657,12 +4651,10 @@ class GenerationMixin:
...
@@ -4657,12 +4651,10 @@ class GenerationMixin:
.
prod
(
dim
=
0
)
.
prod
(
dim
=
0
)
)
)
# stop when each sentence is finished
unfinished_sequences
=
unfinished_sequences
&
~
stopping_criteria
(
input_ids
,
scores
)
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
# stop
if we exceed the maximum length
# stop
when each sentence is finished
if
stopping_criteria
(
input_ids
,
scores
)
:
if
unfinished_sequences
.
max
()
==
0
:
this_peer_finished
=
True
this_peer_finished
=
True
if
this_peer_finished
and
not
synced_gpus
:
if
this_peer_finished
and
not
synced_gpus
:
...
...
tests/generation/test_stopping_criteria.py
View file @
8f2f0f0f
...
@@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase):
...
@@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase):
]
]
)
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
self
.
assertTrue
(
criteria
(
input_ids
,
scores
))
self
.
assertTrue
(
all
(
criteria
(
input_ids
,
scores
))
)
def
test_max_length_criteria
(
self
):
def
test_max_length_criteria
(
self
):
criteria
=
MaxLengthCriteria
(
max_length
=
10
)
criteria
=
MaxLengthCriteria
(
max_length
=
10
)
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
self
.
assertTrue
(
criteria
(
input_ids
,
scores
))
self
.
assertTrue
(
all
(
criteria
(
input_ids
,
scores
))
)
def
test_max_new_tokens_criteria
(
self
):
def
test_max_new_tokens_criteria
(
self
):
criteria
=
MaxNewTokensCriteria
(
start_length
=
5
,
max_new_tokens
=
5
)
criteria
=
MaxNewTokensCriteria
(
start_length
=
5
,
max_new_tokens
=
5
)
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
input_ids
,
scores
=
self
.
_get_tensors
(
9
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
input_ids
,
scores
=
self
.
_get_tensors
(
10
)
self
.
assertTrue
(
criteria
(
input_ids
,
scores
))
self
.
assertTrue
(
all
(
criteria
(
input_ids
,
scores
))
)
criteria_list
=
StoppingCriteriaList
([
criteria
])
criteria_list
=
StoppingCriteriaList
([
criteria
])
self
.
assertEqual
(
criteria_list
.
max_length
,
10
)
self
.
assertEqual
(
criteria_list
.
max_length
,
10
)
...
@@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
...
@@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
input_ids
,
scores
=
self
.
_get_tensors
(
5
)
criteria
=
MaxTimeCriteria
(
max_time
=
0.1
)
criteria
=
MaxTimeCriteria
(
max_time
=
0.1
)
self
.
assertFalse
(
criteria
(
input_ids
,
scores
))
self
.
assertFalse
(
all
(
criteria
(
input_ids
,
scores
))
)
criteria
=
MaxTimeCriteria
(
max_time
=
0.1
,
initial_timestamp
=
time
.
time
()
-
0.2
)
criteria
=
MaxTimeCriteria
(
max_time
=
0.1
,
initial_timestamp
=
time
.
time
()
-
0.2
)
self
.
assertTrue
(
criteria
(
input_ids
,
scores
))
self
.
assertTrue
(
all
(
criteria
(
input_ids
,
scores
))
)
def
test_validate_stopping_criteria
(
self
):
def
test_validate_stopping_criteria
(
self
):
validate_stopping_criteria
(
StoppingCriteriaList
([
MaxLengthCriteria
(
10
)]),
10
)
validate_stopping_criteria
(
StoppingCriteriaList
([
MaxLengthCriteria
(
10
)]),
10
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment