Commit 228792a9 (Unverified)
Authored Mar 30, 2023 by Joao Gante; committed via GitHub on Mar 30, 2023
Generate: basic token streaming (#22449)
* haha tokens go brrrr
Parent: f0aeb1be
Showing 8 changed files with 230 additions and 7 deletions (+230, -7)
docs/source/en/generation_strategies.mdx         +23  -0
docs/source/en/internal/generation_utils.mdx      +4  -0
docs/source/en/main_classes/text_generation.mdx   +2  -1
src/transformers/__init__.py                      +2  -2
src/transformers/generation/__init__.py           +2  -2
src/transformers/generation/streamers.py        +104  -0
src/transformers/generation/utils.py             +49  -2
tests/generation/test_streamers.py               +44  -0
docs/source/en/generation_strategies.mdx
@@ -139,6 +139,29 @@ one for summarization with beam search). You must have the right Hub permissions
['Les fichiers de configuration sont faciles à utiliser !']
```
## Streaming
The `generate()` method supports streaming through its `streamer` input. The `streamer` input is compatible with any
instance of a class that has the `put()` and `end()` methods. Internally, `put()` is used to push new tokens and
`end()` is used to flag the end of text generation.

In practice, you can craft your own streaming class for all sorts of purposes! We also have basic streaming classes
ready for you to use. For example, you can use the [`TextStreamer`] class to stream the output of `generate()` to
your screen, one word at a time (a minimal sketch of a custom streamer follows this example):
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)
>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```
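If you want to roll your own streamer, the sketch below (hypothetical, not part of the library) shows the minimal interface: it implements the same `put()` and `end()` methods, but collects the generated ids and only decodes them once generation has finished:

```python
# Hypothetical custom streamer: stores the generated text instead of printing it.
class ListStreamer:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.token_ids = []
        self.text = None

    def put(self, value):
        # `generate()` pushes the prompt first (a 2D tensor), then one new token at a time.
        if value.dim() > 1:
            value = value[0]
        self.token_ids.extend(value.tolist())

    def end(self):
        # Called once, after the last token has been pushed.
        self.text = self.tokenizer.decode(self.token_ids)
```

Passing `streamer=ListStreamer(tok)` to `generate()` would then leave the full decoded text in `streamer.text` after the call returns.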
## Decoding strategies
Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
docs/source/en/internal/generation_utils.mdx
@@ -265,3 +265,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] top_k_top_p_filtering
[[autodoc]] tf_top_k_top_p_filtering
## Streamers
[[autodoc]] TextStreamer
docs/source/en/main_classes/text_generation.mdx
@@ -24,7 +24,8 @@ of the generation method.
To learn how to inspect a model's generation configuration, what the defaults are, how to change the parameters ad hoc,
and how to create and save a customized generation configuration, refer to the
-[text generation strategies guide](../generation_strategies).
+[text generation strategies guide](../generation_strategies). The guide also explains how to use related features,
+like token streaming.
## GenerationConfig
src/transformers/__init__.py
@@ -96,7 +96,7 @@ _import_structure = {
"feature_extraction_sequence_utils"
:
[
"SequenceFeatureExtractor"
],
"feature_extraction_utils"
:
[
"BatchFeature"
,
"FeatureExtractionMixin"
],
"file_utils"
:
[],
"generation"
:
[
"GenerationConfig"
],
"generation"
:
[
"GenerationConfig"
,
"TextStreamer"
],
"hf_argparser"
:
[
"HfArgumentParser"
],
"image_transforms"
:
[],
"integrations"
:
[
...
...
@@ -3769,7 +3769,7 @@ if TYPE_CHECKING:
    from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

    # Generation
-    from .generation import GenerationConfig
+    from .generation import GenerationConfig, TextStreamer
    from .hf_argparser import HfArgumentParser

    # Integrations
src/transformers/generation/__init__.py
@@ -17,8 +17,7 @@ from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available


-_import_structure = {"configuration_utils": ["GenerationConfig"]}
+_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]}


try:
    if not is_torch_available():
@@ -150,6 +149,7 @@ else:
if TYPE_CHECKING:
    from .configuration_utils import GenerationConfig
+    from .streamers import TextStreamer

    try:
        if not is_torch_available():
src/transformers/generation/streamers.py
new file mode 100644
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from ..models.auto import AutoTokenizer


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer"):
        self.tokenizer = tokenizer
        self.token_cache = []
        self.print_len = 0

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # Otherwise, print until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        print(printable_text, flush=True, end="")

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # Print a newline (and the remaining text, if any)
        print(printable_text, flush=True)
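To make the word-boundary heuristic in `put()` concrete, here is a small illustrative trace (it assumes the `gpt2` tokenizer can be downloaded; it is not part of the committed file):

```python
import torch

from transformers import AutoTokenizer, TextStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
streamer = TextStreamer(tok)

# "Hello wor" does not end in a space, so only "Hello " is printed and "wor" stays in the cache...
streamer.put(torch.tensor(tok("Hello wor")["input_ids"]))
# ...the next tokens complete the word, and the trailing "\n" flushes the cache: prints "world!" plus a newline.
streamer.put(torch.tensor(tok("ld!\n")["input_ids"]))
```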
src/transformers/generation/utils.py
@@ -18,7 +18,7 @@ import copy
import inspect
import warnings
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -72,6 +72,10 @@ from .stopping_criteria import (
)


+if TYPE_CHECKING:
+    from .streamers import BaseStreamer


logger = logging.get_logger(__name__)
@@ -1116,6 +1120,7 @@ class GenerationMixin:
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
+        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""
@@ -1165,6 +1170,9 @@ class GenerationMixin:
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
@@ -1295,6 +1303,9 @@ class GenerationMixin:
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

+        if streamer is not None:
+            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
@@ -1335,7 +1346,8 @@ class GenerationMixin:
        )
        is_contrastive_search_gen_mode = (
-            generation_config.top_k is not None
+            (generation_config.num_beams == 1)
+            and generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
@@ -1384,6 +1396,11 @@ class GenerationMixin:
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
)
if
streamer
is
not
None
and
(
generation_config
.
num_beams
>
1
):
raise
ValueError
(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)
if
self
.
device
.
type
!=
input_ids
.
device
.
type
:
warnings
.
warn
(
"You are calling .generate() with the `input_ids` being on a device type different"
...
...
@@ -1426,6 +1443,7 @@ class GenerationMixin:
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
+                streamer=streamer,
                **model_kwargs,
            )
@@ -1447,6 +1465,7 @@ class GenerationMixin:
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
+                streamer=streamer,
                **model_kwargs,
            )
@@ -1473,6 +1492,7 @@ class GenerationMixin:
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
+                streamer=streamer,
                **model_kwargs,
            )
@@ -1703,6 +1723,7 @@ class GenerationMixin:
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
+        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
        r"""
@@ -1750,6 +1771,9 @@ class GenerationMixin:
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2010,6 +2034,8 @@ class GenerationMixin:
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
@@ -2027,6 +2053,9 @@ class GenerationMixin:
            else:
                this_peer_finished = True

+        if streamer is not None:
+            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return ContrastiveSearchEncoderDecoderOutput(
@@ -2061,6 +2090,7 @@ class GenerationMixin:
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
+        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
@@ -2105,6 +2135,9 @@ class GenerationMixin:
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2256,6 +2289,8 @@ class GenerationMixin:
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
@@ -2273,6 +2308,9 @@ class GenerationMixin:
            else:
                this_peer_finished = True

+        if streamer is not None:
+            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
@@ -2308,6 +2346,7 @@ class GenerationMixin:
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
+        streamer: Optional["BaseStreamer"] = None,
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
@@ -2354,6 +2393,9 @@ class GenerationMixin:
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -2525,6 +2567,8 @@ class GenerationMixin:
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
@@ -2542,6 +2586,9 @@ class GenerationMixin:
            else:
                this_peer_finished = True

+        if streamer is not None:
+            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return SampleEncoderDecoderOutput(
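Taken together, the hunks above define the contract `generate()` keeps with a streamer: the prompt ids are pushed first, each newly selected token is pushed during decoding, and `end()` is called once decoding stops. The sketch below mimics that call sequence outside of `generate()`; the `RecordingStreamer` class and the token ids are made up for illustration:

```python
import torch


class RecordingStreamer:
    """Illustrative only: records each call so the generate() contract is visible."""

    def put(self, value):
        print("put:", value.tolist())

    def end(self):
        print("end of generation")


streamer = RecordingStreamer()
prompt_ids = torch.tensor([[1, 2, 3]])       # made-up prompt token ids, shape (batch, seq_len)
streamer.put(prompt_ids.cpu())               # generate() pushes the prompt first...
for step in range(3):
    next_tokens = torch.tensor([40 + step])  # stand-in for the token picked at this step
    streamer.put(next_tokens.cpu())          # ...then one token per decoding step...
streamer.end()                               # ...and finally signals that generation finished
```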
tests/generation/test_streamers.py
new file mode 100644
# coding=utf-8
# Copyright 2023 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from transformers import AutoTokenizer, TextStreamer, is_torch_available
from transformers.testing_utils import CaptureStdout, require_torch, torch_device

from ..test_modeling_common import ids_tensor


if is_torch_available():
    from transformers import AutoModelForCausalLM


@require_torch
class StreamerTester(unittest.TestCase):
    def test_text_streamer_stdout(self):
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
        model.config.eos_token_id = -1

        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
        greedy_text = tokenizer.decode(greedy_ids[0])

        with CaptureStdout() as cs:
            streamer = TextStreamer(tokenizer)
            model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)

        # The greedy text should be printed to stdout, except for the final "\n" in the streamer
        self.assertEqual(cs.out[:-1], greedy_text)