chenpangpang/transformers · Commits · 8f2f0f0f
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7e51a441e47919536f6b831648b9774181ff2091"
Unverified commit 8f2f0f0f, authored Feb 26, 2024 by Raushan Turganbay; committed by GitHub on Feb 26, 2024.
Track each row separately for stopping criteria (#29116)
parent ece1b62b
Showing 3 changed files with 43 additions and 45 deletions:
src/transformers/generation/stopping_criteria.py (+16, -10)
src/transformers/generation/utils.py (+16, -24)
tests/generation/test_stopping_criteria.py (+11, -11)
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -29,7 +29,8 @@ STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
             Additional stopping criteria specific kwargs.

     Return:
-        `bool`. `False` indicates we should continue, `True` indicates we should stop.
+        `torch.BoolTensor` of shape `(batch_size,)`, where `True` indicates we stop generation
+            for a particular row, and `False` indicates we should continue.

 """

@@ -42,7 +43,7 @@ class StoppingCriteria(ABC):
     """

     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         raise NotImplementedError("StoppingCriteria needs to be subclassed")

@@ -63,7 +64,7 @@ class MaxLengthCriteria(StoppingCriteria):
         self.max_position_embeddings = max_position_embeddings

     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         cur_len = input_ids.shape[-1]
         is_done = cur_len >= self.max_length
         if self.max_position_embeddings is not None and not is_done and cur_len >= self.max_position_embeddings:

@@ -72,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
                 f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
                 "exceptions, performance degradation, or nothing at all."
             )
-        return is_done
+        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)


 class MaxNewTokensCriteria(StoppingCriteria):

@@ -100,8 +101,9 @@ class MaxNewTokensCriteria(StoppingCriteria):
         self.max_length = start_length + max_new_tokens

     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return input_ids.shape[-1] >= self.max_length
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        is_done = input_ids.shape[-1] >= self.max_length
+        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)


 class MaxTimeCriteria(StoppingCriteria):

@@ -122,14 +124,18 @@ class MaxTimeCriteria(StoppingCriteria):
         self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp

     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return time.time() - self.initial_timestamp > self.max_time
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        is_done = time.time() - self.initial_timestamp > self.max_time
+        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)


 class StoppingCriteriaList(list):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return any(criteria(input_ids, scores, **kwargs) for criteria in self)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+        is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device)
+        for criteria in self:
+            is_done = is_done | criteria(input_ids, scores, **kwargs)
+        return is_done

     @property
     def max_length(self) -> Optional[int]:
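With this contract change, a criterion returns one boolean per batch row instead of a single Python bool. A minimal sketch of a custom per-row criterion built on the new interface; the class `EosTokenCriteria` is hypothetical and not part of this commit:

import torch

from transformers.generation.stopping_criteria import StoppingCriteria


class EosTokenCriteria(StoppingCriteria):
    """Hypothetical per-row criterion: stop a row once it has just emitted `eos_token_id`."""

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Shape (batch_size,): True where the last token of that row is EOS.
        return input_ids[:, -1] == self.eos_token_id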
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2194,12 +2194,10 @@ class GenerationMixin:
                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                 )

-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            if stopping_criteria(input_ids, scores):
-                this_peer_finished = True
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+
+            # stop when each sentence is finished
+            if unfinished_sequences.max() == 0:
+                this_peer_finished = True

             if this_peer_finished and not synced_gpus:

@@ -2478,12 +2476,10 @@ class GenerationMixin:
                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                 )

-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            if stopping_criteria(input_ids, scores):
-                this_peer_finished = True
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+
+            # stop when each sentence is finished
+            if unfinished_sequences.max() == 0:
+                this_peer_finished = True

             if this_peer_finished and not synced_gpus:

@@ -2772,12 +2768,10 @@ class GenerationMixin:
                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                 )

-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            if stopping_criteria(input_ids, scores):
-                this_peer_finished = True
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+
+            # stop when each sentence is finished
+            if unfinished_sequences.max() == 0:
+                this_peer_finished = True

             if this_peer_finished and not synced_gpus:
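The three hunks above all fold the per-row result into `unfinished_sequences` the same way. A toy, self-contained illustration of that bookkeeping; the tensor values are made up for the example:

import torch

# Rows still generating, as the decoding loops track them.
unfinished_sequences = torch.tensor([True, True, True, True])

# Per-row "stop now" mask, as returned by StoppingCriteriaList.__call__.
is_done = torch.tensor([False, True, False, True])

# The new update: a row stays unfinished only if no criterion stopped it.
unfinished_sequences = unfinished_sequences & ~is_done
print(unfinished_sequences)                   # tensor([ True, False,  True, False])

# The loop exits once every row has stopped.
print(bool(unfinished_sequences.max() == 0))  # False: two rows keep generating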
@@ -3169,7 +3163,7 @@ class GenerationMixin:
             # increase cur_len
             cur_len = cur_len + 1

-            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
+            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
                 if not synced_gpus:
                     break
                 else:

@@ -3516,7 +3510,7 @@ class GenerationMixin:
             # increase cur_len
             cur_len = cur_len + 1

-            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
+            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
                 if not synced_gpus:
                     break
                 else:

@@ -3912,7 +3906,7 @@ class GenerationMixin:
             # increase cur_len
             cur_len = cur_len + 1

-            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
+            if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
                 if not synced_gpus:
                     break
                 else:

@@ -4267,7 +4261,7 @@ class GenerationMixin:
             # increase cur_len
             cur_len = cur_len + 1

-            if constrained_beam_scorer.is_done or stopping_criteria(input_ids, scores):
+            if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
                 if not synced_gpus:
                     break
                 else:

@@ -4657,12 +4651,10 @@ class GenerationMixin:
                     .prod(dim=0)
                 )

-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-
-            # stop if we exceed the maximum length
-            if stopping_criteria(input_ids, scores):
-                this_peer_finished = True
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+
+            # stop when each sentence is finished
+            if unfinished_sequences.max() == 0:
+                this_peer_finished = True

             if this_peer_finished and not synced_gpus:
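The four beam-search hunks above halt the whole batch at once, so they reduce the per-row mask to a scalar with Python's built-in `all(...)`; the final hunk applies the same per-row folding as the earlier ones. A quick sketch of that reduction with made-up values:

import torch

# all() iterates the elements of a 1-D bool tensor, giving the scalar
# "stop the whole batch" decision used by the beam-search loops.
print(all(torch.tensor([True, False, True])))  # False: keep decoding
print(all(torch.tensor([True, True, True])))   # True: break out of the loop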
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -54,37 +54,37 @@ class StoppingCriteriaTestCase(unittest.TestCase):
             ]
         )

-        self.assertFalse(criteria(input_ids, scores))
+        self.assertFalse(all(criteria(input_ids, scores)))

         input_ids, scores = self._get_tensors(9)
-        self.assertFalse(criteria(input_ids, scores))
+        self.assertFalse(all(criteria(input_ids, scores)))

         input_ids, scores = self._get_tensors(10)
-        self.assertTrue(criteria(input_ids, scores))
+        self.assertTrue(all(criteria(input_ids, scores)))

     def test_max_length_criteria(self):
         criteria = MaxLengthCriteria(max_length=10)

         input_ids, scores = self._get_tensors(5)
-        self.assertFalse(criteria(input_ids, scores))
+        self.assertFalse(all(criteria(input_ids, scores)))

         input_ids, scores = self._get_tensors(9)
-        self.assertFalse(criteria(input_ids, scores))
+        self.assertFalse(all(criteria(input_ids, scores)))

         input_ids, scores = self._get_tensors(10)
-        self.assertTrue(criteria(input_ids, scores))
+        self.assertTrue(all(criteria(input_ids, scores)))

     def test_max_new_tokens_criteria(self):
         criteria = MaxNewTokensCriteria(start_length=5, max_new_tokens=5)

         input_ids, scores = self._get_tensors(5)
-        self.assertFalse(criteria(input_ids, scores))
+        self.assertFalse(all(criteria(input_ids, scores)))

         input_ids, scores = self._get_tensors(9)
-        self.assertFalse(criteria(input_ids, scores))
+        self.assertFalse(all(criteria(input_ids, scores)))

         input_ids, scores = self._get_tensors(10)
-        self.assertTrue(criteria(input_ids, scores))
+        self.assertTrue(all(criteria(input_ids, scores)))

         criteria_list = StoppingCriteriaList([criteria])
         self.assertEqual(criteria_list.max_length, 10)

@@ -93,10 +93,10 @@ class StoppingCriteriaTestCase(unittest.TestCase):
         input_ids, scores = self._get_tensors(5)

         criteria = MaxTimeCriteria(max_time=0.1)
-        self.assertFalse(criteria(input_ids, scores))
+        self.assertFalse(all(criteria(input_ids, scores)))

         criteria = MaxTimeCriteria(max_time=0.1, initial_timestamp=time.time() - 0.2)
-        self.assertTrue(criteria(input_ids, scores))
+        self.assertTrue(all(criteria(input_ids, scores)))

     def test_validate_stopping_criteria(self):
         validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 10)
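The tests now wrap each criteria call in `all(...)` because the call returns a tensor rather than a bool. A small standalone check along the same lines; the dummy shapes are chosen arbitrarily, and this mirrors, but is not, the test code:

import torch

from transformers import MaxLengthCriteria

criteria = MaxLengthCriteria(max_length=10)

input_ids = torch.zeros((2, 10), dtype=torch.long)  # batch of 2, already at length 10
scores = torch.zeros((2, 50))                       # dummy logits, unused here

out = criteria(input_ids, scores)
print(out)       # tensor([True, True]): each row has reached max_length
print(all(out))  # True, matching self.assertTrue(all(criteria(input_ids, scores)))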