chenpangpang / transformers / Commits

Unverified commit 1905384f, authored Apr 04, 2023 by Joao Gante, committed by GitHub on Apr 04, 2023
Generate: Add text streamer decoding options (#22544)
Parent: 41a2f352

Showing 2 changed files with 77 additions and 56 deletions (+77 -56):
  src/transformers/generation/streamers.py   +38 -56
  tests/generation/test_streamers.py         +39 -0
src/transformers/generation/streamers.py (view file @ 1905384f)

````diff
@@ -42,6 +42,10 @@ class TextStreamer(BaseStreamer):
     Parameters:
         tokenizer (`AutoTokenizer`):
             The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.

     Examples:
````
````diff
@@ -59,10 +63,15 @@ class TextStreamer(BaseStreamer):
     ```
     """

-    def __init__(self, tokenizer: "AutoTokenizer"):
+    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
         self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.decode_kwargs = decode_kwargs
+
+        # variables used in the streaming process
         self.token_cache = []
         self.print_len = 0
+        self.next_tokens_are_prompt = True

     def put(self, value):
         """
````
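With the new signature, both options go straight into the constructor: `skip_prompt` is consumed by the streamer itself and everything else lands in `decode_kwargs`. A minimal usage sketch (my own, not part of the commit; the checkpoint and prompt are arbitrary placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # any causal LM checkpoint works here
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("A short prompt:", return_tensors="pt")

# skip_prompt=True: the prompt tokens are never echoed to stdout.
# skip_special_tokens=True is captured by **decode_kwargs and forwarded to tokenizer.decode().
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
model.generate(**inputs, max_new_tokens=20, do_sample=False, streamer=streamer)
```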
````diff
@@ -73,11 +82,15 @@ class TextStreamer(BaseStreamer):
         elif len(value.shape) > 1:
             value = value[0]

+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+
         # Add the new token to the cache and decodes the entire thing.
         self.token_cache.extend(value.tolist())
-        text = self.tokenizer.decode(self.token_cache)
+        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

-        # After symbol for a new line, we flush the cache.
+        # After the symbol for a new line, we flush the cache.
         if text.endswith("\n"):
             printable_text = text[self.print_len :]
             self.token_cache = []
````
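To see the early return above in isolation, here is a small sketch (mine, not from the commit) that drives `put()` by hand: with `skip_prompt=True` the first call carries the prompt ids and produces no output, while later calls are decoded and printed. Stdout is captured with the standard library's `redirect_stdout` rather than the repo's `CaptureStdout` helper.

```python
import io
from contextlib import redirect_stdout

import torch
from transformers import AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
streamer = TextStreamer(tokenizer, skip_prompt=True)

prompt_ids = torch.tensor([[10, 11, 12]])  # what generate() sends first: the whole prompt
new_token = torch.tensor([13])             # a subsequently generated token

buffer = io.StringIO()
with redirect_stdout(buffer):
    streamer.put(prompt_ids)  # swallowed: next_tokens_are_prompt was True
    streamer.put(new_token)   # cached, decoded with **decode_kwargs, printed up to a safe boundary
    streamer.end()            # flushes the remaining cache and prints the trailing newline

print(repr(buffer.getvalue()))  # contains only text decoded from token 13, never the prompt
```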
````diff
@@ -94,30 +107,34 @@ class TextStreamer(BaseStreamer):
         """Flushes any remaining cache and prints a newline to stdout."""
         # Flush the cache, if it exists
         if len(self.token_cache) > 0:
-            text = self.tokenizer.decode(self.token_cache)
+            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
             printable_text = text[self.print_len :]
             self.token_cache = []
             self.print_len = 0
         else:
             printable_text = ""

         # Print a newline (and the remaining text, if any)
+        self.next_tokens_are_prompt = True
         self.on_finalized_text(printable_text, stream_end=True)

-    def on_finalized_text(self, token: str, stream_end: bool = False):
-        """Prints the new text to stdout."""
-        print(token, flush=True, end="" if not stream_end else None)
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
+        print(text, flush=True, end="" if not stream_end else None)


-class TextIteratorStreamer(BaseStreamer):
+class TextIteratorStreamer(TextStreamer):
     """
     Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
-    useful for applications that want to use the generated text in a non-blocking way (e.g. in an interactive Gradio
-    demo).
+    useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an
+    interactive Gradio demo).

     Parameters:
         tokenizer (`AutoTokenizer`):
             The tokenizer used to decode the tokens.
+        skip_prompt (`bool`, *optional*, defaults to `False`):
+            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
+        decode_kwargs (`dict`, *optional*):
+            Additional keyword arguments to pass to the tokenizer's `decode` method.

     Examples:
````
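Since `on_finalized_text` is now the single hook a subclass overrides (exactly what the refactored `TextIteratorStreamer` below does), any custom sink can reuse the caching and decoding logic. A hedged sketch with a hypothetical `FileStreamer` that appends finalized text to a file instead of printing it:

```python
from transformers import TextStreamer


class FileStreamer(TextStreamer):
    """Hypothetical subclass: write each finalized chunk to a file instead of stdout."""

    def __init__(self, tokenizer, path, **kwargs):
        super().__init__(tokenizer, **kwargs)  # kwargs may include skip_prompt and decode kwargs
        self.path = path

    def on_finalized_text(self, text: str, stream_end: bool = False):
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(text)
            if stream_end:
                f.write("\n")


# Usage mirrors TextStreamer, e.g. model.generate(..., streamer=FileStreamer(tokenizer, "stream.txt"))
```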
````diff
@@ -142,58 +159,23 @@ class TextIteratorStreamer(BaseStreamer):
     ```
     """

-    def __init__(self, tokenizer: "AutoTokenizer"):
-        self.tokenizer = tokenizer
-        self.token_cache = []
-        self.print_len = 0
-        self.queue = Queue()
+    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.text_queue = Queue()
         self.stop_signal = None

+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        self.text_queue.put(text)
+        if stream_end:
+            self.text_queue.put(self.stop_signal)
+
     def __iter__(self):
         return self

     def __next__(self):
-        value = self.queue.get()
+        value = self.text_queue.get()
         if value == self.stop_signal:
             raise StopIteration()
         else:
             return value
-
-    def put(self, value):
-        """
-        Recives tokens, decodes them, and pushes text to the queue as soon as it form entire words.
-        """
-        if len(value.shape) > 1 and value.shape[0] > 1:
-            raise ValueError("TextStreamer only supports batch size 1")
-        elif len(value.shape) > 1:
-            value = value[0]
-
-        # Add the new token to the cache and decodes the entire thing.
-        self.token_cache.extend(value.tolist())
-        text = self.tokenizer.decode(self.token_cache)
-
-        # After symbol for a new line, we flush the cache.
-        if text.endswith("\n"):
-            printable_text = text[self.print_len :]
-            self.token_cache = []
-            self.print_len = 0
-        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
-        # which may change with the subsequent token -- there are probably smarter ways to do this!)
-        else:
-            printable_text = text[self.print_len : text.rfind(" ") + 1]
-            self.print_len += len(printable_text)
-
-        self.queue.put(printable_text)
-
-    def end(self):
-        """Flushes any remaining cache and puts the stop signal in the queue."""
-        # Flush the cache, if it exists
-        if len(self.token_cache) > 0:
-            text = self.tokenizer.decode(self.token_cache)
-            printable_text = text[self.print_len :]
-            self.token_cache = []
-            self.print_len = 0
-        else:
-            printable_text = ""
-
-        self.queue.put(printable_text)
-        self.queue.put(self.stop_signal)  # Put the stop signal
````
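With `put()` and `end()` now inherited from `TextStreamer`, the iterator class is reduced to queue management. The non-blocking pattern its docstring alludes to looks roughly like this (my sketch, not the elided docstring example; checkpoint and prompt are placeholders):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")

inputs = tokenizer("A short prompt:", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Generation runs in a background thread; the main thread consumes text as it is produced.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=20, do_sample=False, streamer=streamer),
)
thread.start()

generated = ""
for chunk in streamer:  # __next__ blocks on text_queue.get() until the stop signal arrives
    generated += chunk
thread.join()
print(generated)
```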
tests/generation/test_streamers.py (view file @ 1905384f)

````diff
@@ -23,6 +23,8 @@ from ..test_modeling_common import ids_tensor

 if is_torch_available():
+    import torch
+
     from transformers import AutoModelForCausalLM
````
````diff
@@ -63,3 +65,40 @@ class StreamerTester(unittest.TestCase):
             streamer_text += new_text

         self.assertEqual(streamer_text, greedy_text)
+
+    def test_text_streamer_skip_prompt(self):
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = ids_tensor((1, 5), vocab_size=model.config.vocab_size).to(torch_device)
+        greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False)
+        new_greedy_ids = greedy_ids[:, input_ids.shape[1] :]
+        new_greedy_text = tokenizer.decode(new_greedy_ids[0])
+
+        with CaptureStdout() as cs:
+            streamer = TextStreamer(tokenizer, skip_prompt=True)
+            model.generate(input_ids, max_new_tokens=10, do_sample=False, streamer=streamer)
+        # The greedy text should be printed to stdout, except for the final "\n" in the streamer
+        streamer_text = cs.out[:-1]
+
+        self.assertEqual(streamer_text, new_greedy_text)
+
+    def test_text_streamer_decode_kwargs(self):
+        # Tests that we can pass `decode_kwargs` to the streamer to control how the tokens are decoded. Must be
+        # tested with actual models -- the dummy models' tokenizers are not aligned with their models, and
+        # `skip_special_tokens=True` has no effect on them
+        tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+        model = AutoModelForCausalLM.from_pretrained("distilgpt2").to(torch_device)
+        model.config.eos_token_id = -1
+
+        input_ids = torch.ones((1, 5), device=torch_device).long() * model.config.bos_token_id
+        with CaptureStdout() as cs:
+            streamer = TextStreamer(tokenizer, skip_special_tokens=True)
+            model.generate(input_ids, max_new_tokens=1, do_sample=False, streamer=streamer)
+
+        # The prompt contains a special token, so the streamer should not print it. As such, the output text, when
+        # re-tokenized, must only contain one token
+        streamer_text = cs.out[:-1]  # Remove the final "\n"
+        streamer_text_tokenized = tokenizer(streamer_text, return_tensors="pt")
+        self.assertEqual(streamer_text_tokenized.input_ids.shape, (1, 1))
````
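Assuming pytest is installed and the working directory is the repository root, the two new tests can be selected by name; one possible invocation from Python:

```python
import pytest

# -k filters tests by name; both new streamer tests match this expression
pytest.main(["tests/generation/test_streamers.py", "-k", "skip_prompt or decode_kwargs", "-v"])
```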