Commit e3963581 (unverified)
Add stop sequence to text generation pipeline (#18444)

Authored Sep 30, 2022 by Karim Foda; committed via GitHub on Sep 30, 2022.
Repository: chenpangpang/transformers. Parent commit: 582d085b.
Showing 5 changed files with 54 additions and 1 deletion (+54, -1):

src/transformers/generation_utils.py                (+0, -1)
src/transformers/pipelines/text2text_generation.py  (+11, -0)
src/transformers/pipelines/text_generation.py       (+11, -0)
tests/generation/test_generation_utils.py           (+20, -0)
tests/pipelines/test_pipelines_text_generation.py   (+12, -0)
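The net effect of the change is a new stop_sequence keyword on both generation pipelines. A minimal usage sketch assembled from the tests added in this commit (the hf-internal-testing checkpoints are tiny random models intended only for testing, so the outputs are nonsense by design):

from transformers import pipeline

generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")

# Without a stop sequence, the tiny model repeats itself until the length limit.
print(generator("Hello I believe in"))
# [{'generated_text': 'Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe'}]

# With stop_sequence, generation halts as soon as the stop token is produced.
print(generator("Hello I believe in", stop_sequence=" fe"))
# [{'generated_text': 'Hello I believe in fe'}]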
src/transformers/generation_utils.py

...
@@ -1343,7 +1343,6 @@ class GenerationMixin:
         stopping_criteria = self._get_stopping_criteria(
             max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
         )
-
         # 9. go into different generation modes
         if is_greedy_gen_mode:
             if num_return_sequences > 1:
...
src/transformers/pipelines/text2text_generation.py

 import enum
+import warnings

 from ..tokenization_utils import TruncationStrategy
 from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
...
@@ -59,6 +60,7 @@ class Text2TextGenerationPipeline(Pipeline):
         return_type=None,
         clean_up_tokenization_spaces=None,
         truncation=None,
+        stop_sequence=None,
         **generate_kwargs
     ):
         preprocess_params = {}
...
@@ -76,6 +78,15 @@ class Text2TextGenerationPipeline(Pipeline):
         if clean_up_tokenization_spaces is not None:
             postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

+        if stop_sequence is not None:
+            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
+            if len(stop_sequence_ids) > 1:
+                warnings.warn(
+                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
+                    " the stop sequence will be used as the stop sequence string in the interim."
+                )
+            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+
         return preprocess_params, forward_params, postprocess_params

     def check_inputs(self, input_length: int, min_length: int, max_length: int):
...
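Both pipelines use the same mechanism: the stop sequence is tokenized and only its first token id is forwarded to generate() as eos_token_id, which is why multi-token sequences draw the warning above. A short sketch of that mapping, assuming a gpt2 tokenizer purely for illustration:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint, illustration only

stop_sequence_ids = tokenizer.encode(" stop here", add_special_tokens=False)
# " stop here" encodes to more than one id, so the pipeline warns and keeps
# only stop_sequence_ids[0]; generation would then halt on " stop" alone.
eos_token_id = stop_sequence_ids[0]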
src/transformers/pipelines/text_generation.py

 import enum
+import warnings

 from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
...
@@ -80,6 +81,7 @@ class TextGenerationPipeline(Pipeline):
         clean_up_tokenization_spaces=None,
         prefix=None,
         handle_long_generation=None,
+        stop_sequence=None,
         **generate_kwargs
     ):
         preprocess_params = {}
...
@@ -121,6 +123,15 @@ class TextGenerationPipeline(Pipeline):
         if clean_up_tokenization_spaces is not None:
             postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

+        if stop_sequence is not None:
+            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
+            if len(stop_sequence_ids) > 1:
+                warnings.warn(
+                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
+                    " the stop sequence will be used as the stop sequence string in the interim."
+                )
+            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+
         return preprocess_params, forward_params, postprocess_params

     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
...
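The warning is explicit that only the first token of a multi-token stop sequence is honored. A possible workaround, not part of this commit, is a custom stopping criterion that compares the generated tail against the full sequence; a sketch using the existing StoppingCriteria/StoppingCriteriaList API (the StopSequenceCriteria class is hypothetical):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopSequenceCriteria(StoppingCriteria):
    # Hypothetical helper, not part of transformers: stop when the last
    # generated tokens match the complete stop sequence, not just its first id.
    def __init__(self, stop_sequence_ids):
        self.stop_sequence_ids = stop_sequence_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        n = len(self.stop_sequence_ids)
        if input_ids.shape[1] < n:
            return False
        # Checks the first sequence in the batch only; a batched version would
        # need to track completion per row.
        return input_ids[0, -n:].tolist() == self.stop_sequence_ids

# Usage sketch:
# model.generate(**inputs, stopping_criteria=StoppingCriteriaList([StopSequenceCriteria(ids)]))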
tests/generation/test_generation_utils.py

...
@@ -37,6 +37,7 @@ if is_torch_available():
         Speech2TextForConditionalGeneration,
         SpeechEncoderDecoderModel,
         VisionEncoderDecoderModel,
+        pipeline,
         top_k_top_p_filtering,
     )
     from transformers.generation_beam_constraints import DisjunctiveConstraint, PhrasalConstraint
...
@@ -1979,6 +1980,25 @@ class GenerationIntegrationTests(unittest.TestCase):
             [1, 18],
         )

+    def test_stop_sequence_stopping_criteria(self):
+        prompt = """Hello I believe in"""
+        generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-bart")
+        output = generator(prompt)
+        self.assertEqual(
+            output,
+            [
+                {
+                    "generated_text": (
+                        "Hello I believe in in in number number number number number number number number number"
+                    )
+                }
+            ],
+        )
+
+        output = generator(prompt, stop_sequence=" number")
+        self.assertEqual(output, [{"generated_text": "Hello I believe in in in number"}])
+
     def test_custom_logits_processor(self):
         bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
...
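Because the pipeline merely sets generate_kwargs["eos_token_id"], the same behavior is reachable from model.generate directly; a sketch assuming a gpt2 checkpoint for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello I believe in", return_tensors="pt")
stop_id = tokenizer.encode(" the", add_special_tokens=False)[0]

# Equivalent to stop_sequence=" the" on the pipeline: generate() treats the
# stop token id as an end-of-sequence id and halts when it is sampled.
output_ids = model.generate(**inputs, eos_token_id=stop_id, max_new_tokens=20)
print(tokenizer.decode(output_ids[0]))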
tests/pipelines/test_pipelines_text_generation.py

...
@@ -147,6 +147,18 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
         return text_generator, ["This is a test", "Another test"]

+    def test_stop_sequence_stopping_criteria(self):
+        prompt = """Hello I believe in"""
+        text_generator = pipeline("text-generation", model="hf-internal-testing/tiny-random-gpt2")
+        output = text_generator(prompt)
+        self.assertEqual(
+            output,
+            [{"generated_text": "Hello I believe in fe fe fe fe fe fe fe fe fe fe fe fe"}],
+        )
+
+        output = text_generator(prompt, stop_sequence=" fe")
+        self.assertEqual(output, [{"generated_text": "Hello I believe in fe"}])
+
     def run_pipeline_test(self, text_generator, _):
         model = text_generator.model
         tokenizer = text_generator.tokenizer
...
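The same keyword works on the Text2TextGenerationPipeline changed above; a hedged sketch (t5-small is an assumed checkpoint, and the exact output depends on the model):

from transformers import pipeline

text2text = pipeline("text2text-generation", model="t5-small")  # assumed checkpoint
output = text2text("translate English to German: How old are you?", stop_sequence="?")
# Generation halts once the first token of "?" is produced.
print(output[0]["generated_text"])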