Unverified commit e3963581, authored Sep 30, 2022 by Karim Foda, committed by GitHub on Sep 30, 2022.
Add stop sequence to text generation pipeline (#18444)
Parent: 582d085b

Showing 5 changed files with 54 additions and 1 deletion (+54, -1):
src/transformers/generation_utils.py (+0, -1)
src/transformers/pipelines/text2text_generation.py (+11, -0)
src/transformers/pipelines/text_generation.py (+11, -0)
tests/generation/test_generation_utils.py (+20, -0)
tests/pipelines/test_pipelines_text_generation.py (+12, -0)
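In practice, the new `stop_sequence` argument lets a pipeline caller cut generation off at a chosen string. A minimal usage sketch based on the tests added in this commit (the tiny checkpoints are the ones the test suite uses; any causal LM checkpoint works):

```python
from transformers import pipeline

# Tiny test checkpoint used by the test suite below.
generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# Without a stop sequence, generation runs to the usual length/EOS limits.
print(generator("Hello I believe in"))

# With stop_sequence, the pipeline encodes the string and passes its first
# token id to generate() as eos_token_id, so decoding halts on that token.
print(generator("Hello I believe in", stop_sequence=" fe"))
# e.g. [{'generated_text': 'Hello I believe in fe'}]
```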
src/transformers/generation_utils.py

```diff
@@ -1343,7 +1343,6 @@ class GenerationMixin:
         stopping_criteria = self._get_stopping_criteria(
             max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
         )
-
         # 9. go into different generation modes
         if is_greedy_gen_mode:
             if num_return_sequences > 1:
```
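For context on the hunk above: `_get_stopping_criteria` builds the `StoppingCriteriaList` that `generate()` checks at each step, folding `max_length` and `max_time` in with any user-supplied criteria. A rough sketch of the same behavior through the public API (the checkpoint name is just the test suite's tiny model, and the private helper's exact signature may differ between versions):

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
)

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
inputs = tokenizer("Hello I believe in", return_tensors="pt")

# Roughly what _get_stopping_criteria assembles internally:
criteria = StoppingCriteriaList(
    [
        MaxLengthCriteria(max_length=32),  # hard cap on total sequence length
        MaxTimeCriteria(max_time=5.0),     # wall-clock budget in seconds
    ]
)

outputs = model.generate(**inputs, stopping_criteria=criteria)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```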
src/transformers/pipelines/text2text_generation.py

```diff
 import enum
+import warnings

 from ..tokenization_utils import TruncationStrategy
 from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging

@@ -59,6 +60,7 @@ class Text2TextGenerationPipeline(Pipeline):
         return_type=None,
         clean_up_tokenization_spaces=None,
         truncation=None,
+        stop_sequence=None,
         **generate_kwargs
     ):
         preprocess_params = {}

@@ -76,6 +78,15 @@ class Text2TextGenerationPipeline(Pipeline):
         if clean_up_tokenization_spaces is not None:
             postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

+        if stop_sequence is not None:
+            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
+            if len(stop_sequence_ids) > 1:
+                warnings.warn(
+                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
+                    " the stop sequence will be used as the stop sequence string in the interim."
+                )
+            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+
         return preprocess_params, forward_params, postprocess_params

     def check_inputs(self, input_length: int, min_length: int, max_length: int):
```
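The same `stop_sequence` plumbing is added to `Text2TextGenerationPipeline`, so seq2seq tasks get it too. A hedged sketch (the t5 checkpoint name here is illustrative, not from this commit):

```python
from transformers import pipeline

# Illustrative checkpoint; any seq2seq LM works with the text2text-generation task.
text2text = pipeline("text2text-generation", model="hf-internal-testing/tiny-random-t5")

# A stop sequence that encodes to a single token maps cleanly onto eos_token_id.
print(text2text("translate English to German: hello", stop_sequence="."))

# A multi-token stop sequence triggers the UserWarning in the diff above and
# falls back to the first token of the encoded sequence.
print(text2text("translate English to German: hello", stop_sequence=" the end"))
```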
src/transformers/pipelines/text_generation.py

```diff
 import enum
+import warnings

 from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING

@@ -80,6 +81,7 @@ class TextGenerationPipeline(Pipeline):
         clean_up_tokenization_spaces=None,
         prefix=None,
         handle_long_generation=None,
+        stop_sequence=None,
         **generate_kwargs
     ):
         preprocess_params = {}

@@ -121,6 +123,15 @@ class TextGenerationPipeline(Pipeline):
         if clean_up_tokenization_spaces is not None:
             postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

+        if stop_sequence is not None:
+            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
+            if len(stop_sequence_ids) > 1:
+                warnings.warn(
+                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
+                    " the stop sequence will be used as the stop sequence string in the interim."
+                )
+            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+
         return preprocess_params, forward_params, postprocess_params

     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
```
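What the pipeline does with `stop_sequence` can be reproduced directly against `generate()`: encode the string without special tokens and pass the first id as `eos_token_id`. A minimal sketch of that equivalence, reusing the test suite's tiny checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

stop_sequence = " fe"
# Mirrors the pipeline code above: no special tokens, first id wins.
stop_sequence_ids = tokenizer.encode(stop_sequence, add_special_tokens=False)

inputs = tokenizer("Hello I believe in", return_tensors="pt")
outputs = model.generate(**inputs, eos_token_id=stop_sequence_ids[0])
print(tokenizer.decode(outputs[0]))
```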
tests/generation/test_generation_utils.py

```diff
@@ -37,6 +37,7 @@ if is_torch_available():
         Speech2TextForConditionalGeneration,
         SpeechEncoderDecoderModel,
         VisionEncoderDecoderModel,
+        pipeline,
         top_k_top_p_filtering,
     )
     from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint

@@ -1979,6 +1980,25 @@ class GenerationIntegrationTests(unittest.TestCase):
             [1, 18],
         )

+    def test_stop_sequence_stopping_criteria(self):
+        prompt = """Hello I believe in"""
+        generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
+        output = generator(prompt)
+        self.assertEqual(
+            output,
+            [
+                {
+                    "generated_text": (
+                        "Hello I believe in in in number number number number number number number number number"
+                    )
+                }
+            ],
+        )
+
+        output = generator(prompt, stop_sequence=" number")
+        self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
+
     def test_custom_logits_processor(self):
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
```
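The multi-token fallback the warning describes can be exercised directly; a small sketch that captures the UserWarning (it assumes the phrase encodes to more than one token, which holds for typical BPE tokenizers):

```python
import warnings

from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # " number one" should encode to two tokens, so only the first is used
    # as eos_token_id and the pipeline warns about the fallback.
    generator("Hello I believe in", stop_sequence=" number one")

assert any("multiple token sequence" in str(w.message) for w in caught)
```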
tests/pipelines/test_pipelines_text_generation.py

```diff
@@ -147,6 +147,18 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
         return text_generator, ["This is a test", "Another test"]

+    def test_stop_sequence_stopping_criteria(self):
+        prompt = """Hello I believe in"""
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
+        output = text_generator(prompt)
+        self.assertEqual(
+            output,
+            [{"generated_text": "Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe"}],
+        )
+
+        output = text_generator(prompt, stop_sequence=" fe")
+        self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}])
+
     def run_pipeline_test(self, text_generator, _):
         model = text_generator.model
         tokenizer = text_generator.tokenizer
```