Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
861ff890
Unverified
Commit
861ff890
authored
Apr 05, 2023
by
Joao Gante
Committed by
GitHub
Apr 05, 2023
Browse files
Generate: `TextIteratorStreamer` timeout (#22576)
parent
11fd2c77
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
5 deletions
+29
-5
src/transformers/generation/streamers.py
src/transformers/generation/streamers.py
+11
-5
tests/generation/test_streamers.py
tests/generation/test_streamers.py
+18
-0
No files found.
src/transformers/generation/streamers.py
View file @
861ff890
...
...
@@ -14,7 +14,7 @@
# limitations under the License.
from
queue
import
Queue
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Optional
if
TYPE_CHECKING
:
...
...
@@ -133,6 +133,9 @@ class TextIteratorStreamer(TextStreamer):
The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
timeout (`float`, *optional*):
The timeout for the text queue, in seconds. If `None`, the queue will block indefinitely. Useful to handle exceptions
in `.generate()`, when it is called in a separate thread: on expiry the iterator raises `queue.Empty` instead of waiting forever.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
...
...
@@ -159,22 +162,25 @@ class TextIteratorStreamer(TextStreamer):
```
"""
def
__init__
(
self
,
tokenizer
:
"AutoTokenizer"
,
skip_prompt
:
bool
=
False
,
**
decode_kwargs
):
def
__init__
(
self
,
tokenizer
:
"AutoTokenizer"
,
skip_prompt
:
bool
=
False
,
timeout
:
Optional
[
float
]
=
None
,
**
decode_kwargs
):
super
().
__init__
(
tokenizer
,
skip_prompt
,
**
decode_kwargs
)
self
.
text_queue
=
Queue
()
self
.
stop_signal
=
None
self
.
timeout
=
timeout
def
on_finalized_text
(
self
,
text
:
str
,
stream_end
:
bool
=
False
):
"""Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
self
.
text_queue
.
put
(
text
)
self
.
text_queue
.
put
(
text
,
timeout
=
self
.
timeout
)
if
stream_end
:
self
.
text_queue
.
put
(
self
.
stop_signal
)
self
.
text_queue
.
put
(
self
.
stop_signal
,
timeout
=
self
.
timeout
)
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
value
=
self
.
text_queue
.
get
()
value
=
self
.
text_queue
.
get
(
timeout
=
self
.
timeout
)
if
value
==
self
.
stop_signal
:
raise
StopIteration
()
else
:
...
...
tests/generation/test_streamers.py
View file @
861ff890
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
import
unittest
from
queue
import
Empty
from
threading
import
Thread
from
transformers
import
AutoTokenizer
,
TextIteratorStreamer
,
TextStreamer
,
is_torch_available
...
...
@@ -102,3 +103,20 @@ class StreamerTester(unittest.TestCase):
streamer_text
=
cs
.
out
[:
-
1
]
# Remove the final "\n"
streamer_text_tokenized
=
tokenizer
(
streamer_text
,
return_tensors
=
"pt"
)
self
.
assertEqual
(
streamer_text_tokenized
.
input_ids
.
shape
,
(
1
,
1
))
def
test_iterator_streamer_timeout
(
self
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
)
model
=
AutoModelForCausalLM
.
from_pretrained
(
"hf-internal-testing/tiny-random-gpt2"
).
to
(
torch_device
)
model
.
config
.
eos_token_id
=
-
1
input_ids
=
ids_tensor
((
1
,
5
),
vocab_size
=
model
.
config
.
vocab_size
).
to
(
torch_device
)
streamer
=
TextIteratorStreamer
(
tokenizer
,
timeout
=
0.001
)
generation_kwargs
=
{
"input_ids"
:
input_ids
,
"max_new_tokens"
:
10
,
"do_sample"
:
False
,
"streamer"
:
streamer
}
thread
=
Thread
(
target
=
model
.
generate
,
kwargs
=
generation_kwargs
)
thread
.
start
()
# The streamer will time out after 0.001 seconds, so a `queue.Empty` exception will be raised
with
self
.
assertRaises
(
Empty
):
streamer_text
=
""
for
new_text
in
streamer
:
streamer_text
+=
new_text
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment