Commit a55a822a (Unverified)
Authored Apr 03, 2023 by Joao Gante, committed by GitHub on Apr 03, 2023
Generate: `TextIteratorStreamer` (streamer for gradio) (#22501)
* haha text go brrr (but in gradio)
Parent: 7d25c9c8
Showing 5 changed files with 125 additions and 8 deletions.
- docs/source/en/internal/generation_utils.mdx (+2, -0)
- src/transformers/__init__.py (+2, -2)
- src/transformers/generation/__init__.py (+5, -2)
- src/transformers/generation/streamers.py (+91, -0)
- tests/generation/test_streamers.py (+25, -4)
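The commit title flags Gradio as the target consumer. For context, below is a minimal sketch of how the new streamer could drive a streaming Gradio textbox; the `gr.Interface` wiring, the `gpt2` checkpoint, and the `max_new_tokens` value are illustrative assumptions, not part of this commit.

```python
# Sketch: wiring TextIteratorStreamer into a streaming Gradio demo.
# The gr.Interface setup, "gpt2" checkpoint, and max_new_tokens value are
# illustrative assumptions, not part of this commit.
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


def generate(prompt):
    inputs = tok([prompt], return_tensors="pt")
    streamer = TextIteratorStreamer(tok)
    # model.generate() blocks until done, so it runs on a worker thread
    # while this generator function drains the queue.
    kwargs = dict(inputs, streamer=streamer, max_new_tokens=50)
    Thread(target=model.generate, kwargs=kwargs).start()
    text = ""
    for new_text in streamer:  # yields word-sized chunks as they arrive
        text += new_text
        yield text  # each yield re-renders the output textbox


demo = gr.Interface(generate, inputs="text", outputs="text")
demo.queue().launch()  # generator outputs require the queue in Gradio 3.x
```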
docs/source/en/internal/generation_utils.mdx

```diff
@@ -269,3 +269,5 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 ## Streamers

 [[autodoc]] TextStreamer
+
+[[autodoc]] TextIteratorStreamer
```
src/transformers/__init__.py

```diff
@@ -96,7 +96,7 @@ _import_structure = {
     "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
     "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
     "file_utils": [],
-    "generation": ["GenerationConfig", "TextStreamer"],
+    "generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
     "hf_argparser": ["HfArgumentParser"],
     "image_transforms": [],
     "integrations": [
@@ -3770,7 +3770,7 @@ if TYPE_CHECKING:
     from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

     # Generation
-    from .generation import GenerationConfig, TextStreamer
+    from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer
     from .hf_argparser import HfArgumentParser

     # Integrations
```
src/transformers/generation/__init__.py

```diff
@@ -17,7 +17,10 @@ from typing import TYPE_CHECKING
 from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available


-_import_structure = {"configuration_utils": ["GenerationConfig"], "streamers": ["TextStreamer"]}
+_import_structure = {
+    "configuration_utils": ["GenerationConfig"],
+    "streamers": ["TextIteratorStreamer", "TextStreamer"],
+}

 try:
     if not is_torch_available():
@@ -149,7 +152,7 @@ else:

 if TYPE_CHECKING:
     from .configuration_utils import GenerationConfig
-    from .streamers import TextStreamer
+    from .streamers import TextIteratorStreamer, TextStreamer

     try:
         if not is_torch_available():
```
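The new class is registered twice in this file because `transformers` imports lazily: at runtime the package replaces the module object with a `_LazyModule` that resolves attributes from `_import_structure` on first access, while the `TYPE_CHECKING` branch exists only so static type checkers see concrete imports. A rough sketch of the module tail that implements this, paraphrasing the pattern used across transformers (exact arguments may differ between versions):

```python
# Sketch of the lazy-import tail such modules end with (paraphrased pattern).
# A symbol missing from _import_structure raises AttributeError at runtime
# even though its class exists in the source file, hence the twin edits above.
import sys

sys.modules[__name__] = _LazyModule(
    __name__, globals()["__file__"], _import_structure, module_spec=__spec__
)
```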
src/transformers/generation/streamers.py

````diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from queue import Queue
 from typing import TYPE_CHECKING
@@ -102,3 +103,93 @@ class TextStreamer(BaseStreamer):
         # Print a newline (and the remaining text, if any)
         print(printable_text, flush=True)
+
+
+class TextIteratorStreamer(BaseStreamer):
+    """
+    Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
+    useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio
+    demo).
+
+    Parameters:
+        tokenizer (`AutoTokenizer`):
+            The tokenizer used to decode the tokens.
+
+    Examples:
+
+        ```python
+        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+        >>> from threading import Thread
+
+        >>> tok = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
+        >>> streamer = TextIteratorStreamer(tok)
+
+        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
+        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
+        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        >>> thread.start()
+        >>> generated_text = ""
+        >>> for new_text in streamer:
+        ...     generated_text += new_text
+        >>> generated_text
+        'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
+        ```
+    """
+
+    def __init__(self, tokenizer: "AutoTokenizer"):
+        self.tokenizer = tokenizer
+        self.token_cache = []
+        self.print_len = 0
+        self.queue = Queue()
+        self.stop_signal = None
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.queue.get()
+        if value == self.stop_signal:
+            raise StopIteration()
+        else:
+            return value
+
+    def put(self, value):
+        """
+        Receives tokens, decodes them, and pushes text to the queue as soon as they form entire words.
+        """
+        if len(value.shape) > 1 and value.shape[0] > 1:
+            raise ValueError("TextIteratorStreamer only supports batch size 1")
+        elif len(value.shape) > 1:
+            value = value[0]
+
+        # Add the new token to the cache and decode the entire thing.
+        self.token_cache.extend(value.tolist())
+        text = self.tokenizer.decode(self.token_cache)
+
+        # After the symbol for a new line, we flush the cache.
+        if text.endswith("\n"):
+            printable_text = text[self.print_len :]
+            self.token_cache = []
+            self.print_len = 0
+        # Otherwise, push text up to the last space char (simple heuristic to avoid pushing incomplete words,
+        # which may change with the subsequent token -- there are probably smarter ways to do this!)
+        else:
+            printable_text = text[self.print_len : text.rfind(" ") + 1]
+            self.print_len += len(printable_text)
+
+        self.queue.put(printable_text)
+
+    def end(self):
+        """Flushes any remaining cache and puts the stop signal in the queue."""
+        # Flush the cache, if it exists
+        if len(self.token_cache) > 0:
+            text = self.tokenizer.decode(self.token_cache)
+            printable_text = text[self.print_len :]
+            self.token_cache = []
+            self.print_len = 0
+        else:
+            printable_text = ""
+
+        self.queue.put(printable_text)
+        self.queue.put(self.stop_signal)  # Put the stop signal
````
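The class is a small producer-consumer pipe: `put` and `end` run on the generation thread and feed a `queue.Queue`, while `__next__` blocks on `queue.get()` until either text or the `None` stop signal arrives. A self-contained sketch of that queue-plus-sentinel pattern (standalone illustration, not transformers code) shows why iteration blocks while waiting and terminates cleanly:

```python
# Minimal sketch of the queue-plus-sentinel iterator pattern used above
# (standalone illustration, not transformers code).
from queue import Queue
from threading import Thread


class QueueIterator:
    def __init__(self):
        self.queue = Queue()
        self.stop_signal = None  # sentinel marking the end of the stream

    def put(self, item):
        self.queue.put(item)

    def end(self):
        self.queue.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.queue.get()  # blocks until the producer pushes something
        if value == self.stop_signal:
            raise StopIteration()
        return value


stream = QueueIterator()


def producer():
    for word in ["one ", "two ", "three"]:
        stream.put(word)
    stream.end()


Thread(target=producer).start()
print("".join(stream))  # prints "one two three"
```

Using `None` as the sentinel is safe here because `put` only ever enqueues strings, so no real chunk can collide with the stop signal.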
tests/generation/test_streamers.py

```diff
@@ -14,8 +14,9 @@
 # limitations under the License.

 import unittest
+from threading import Thread

-from transformers import AutoTokenizer, TextStreamer, is_torch_available
+from transformers import AutoTokenizer, TextIteratorStreamer, TextStreamer, is_torch_available
 from transformers.testing_utils import CaptureStdout, require_torch, torch_device

 from ..test_modeling_common import ids_tensor
@@ -27,7 +28,7 @@ if is_torch_available():

 @require_torch
 class StreamerTester(unittest.TestCase):
-    def test_text_streamer_stdout(self):
+    def test_text_streamer_matches_non_streaming(self):
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
         model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
         model.config.eos_token_id = -1
@@ -39,6 +40,26 @@ class StreamerTester(unittest.TestCase):
         with CaptureStdout() as cs:
             streamer = TextStreamer(tokenizer)
             model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)

         # The greedy text should be printed to stdout, except for the final "\n" in the streamer
-        self.assertEqual(cs.out[:-1], greedy_text)
+        streamer_text = cs.out[:-1]
+        self.assertEqual(streamer_text, greedy_text)
+
+    def test_iterator_streamer_matches_non_streaming(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
+        greedy_text = tokenizer.decode(greedy_ids[0])
+
+        streamer = TextIteratorStreamer(tokenizer)
+        generation_kwargs = {"input_ids": input_ids, "max_new_tokens": 10, "do_sample": False, "streamer": streamer}
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        streamer_text = ""
+        for new_text in streamer:
+            streamer_text += new_text
+
+        self.assertEqual(streamer_text, greedy_text)
```