OpenDAS / text-generation-inference · Commits · 4f9ac67c

Unverified commit 4f9ac67c, authored Jan 31, 2023 by OlivierDehaene, committed by GitHub on Jan 31, 2023
Revert "feat: Add token streaming using ServerSideEvents support" (#40)
Reverts huggingface/text-generation-inference#36
parent 7fbfbb0d
Showing 2 changed files with 24 additions and 55 deletions (+24 -55)

  server/text_generation/models/types.py   +12 -47
  server/text_generation/server.py          +12 -8
server/text_generation/models/types.py  (view file @ 4f9ac67c)

@@ -29,61 +29,26 @@ class Batch(ABC):
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
         raise NotImplementedError
 
     @abstractmethod
     def __len__(self):
         raise NotImplementedError
 
 
 @dataclass
 class GeneratedText:
-    text: str
+    request: generate_pb2.Request
+    output_text: str
     generated_tokens: int
-    finish_reason: str
+    tokens: List[str]
+    token_ids: List[int]
+    logprobs: List[float]
+    reason: str
     seed: Optional[int]
 
     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
-            text=self.text,
+            request=self.request,
+            output_text=self.output_text,
             generated_tokens=self.generated_tokens,
-            finish_reason=self.finish_reason,
+            tokens=self.tokens,
+            token_ids=self.token_ids,
+            logprobs=self.logprobs,
+            finish_reason=self.reason,
             seed=self.seed,
         )
-
-
-@dataclass
-class PrefillTokens:
-    token_ids: List[int]
-    logprobs: List[float]
-    texts: List[str]
-
-    def to_pb(self) -> generate_pb2.PrefillTokens:
-        return generate_pb2.PrefillTokens(
-            ids=self.token_ids, logprobs=self.logprobs, texts=self.texts
-        )
-
-    def __len__(self):
-        return len(self.token_ids)
-
-
-@dataclass
-class Generation:
-    request_id: int
-    prefill_tokens: Optional[PrefillTokens]
-    token_id: int
-    token_logprob: float
-    token_text: str
-    generated_text: Optional[GeneratedText]
-
-    def to_pb(self) -> generate_pb2.Generation:
-        return generate_pb2.Generation(
-            request_id=self.request_id,
-            prefill_tokens=self.prefill_tokens.to_pb()
-            if self.prefill_tokens is not None
-            else None,
-            token_id=self.token_id,
-            token_logprob=self.token_logprob,
-            token_text=self.token_text,
-            generated_text=self.generated_text.to_pb()
-            if self.generated_text is not None
-            else None,
-        )
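For reference, a minimal standalone sketch of the GeneratedText shape this revert restores. It mirrors the dataclass above but drops the generate_pb2 dependency (a plain request_id stands in for generate_pb2.Request here), so it runs anywhere; all values are illustrative, not taken from the repository.

    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class GeneratedText:
        request_id: int          # stand-in for generate_pb2.Request in the real file
        output_text: str
        generated_tokens: int
        tokens: List[str]
        token_ids: List[int]
        logprobs: List[float]
        reason: str
        seed: Optional[int]


    # With the streaming code reverted, the whole generation is materialised once,
    # when the request finishes, instead of being emitted token by token.
    result = GeneratedText(
        request_id=0,
        output_text="Hello world",
        generated_tokens=2,
        tokens=["Hello", " world"],
        token_ids=[15496, 995],
        logprobs=[-0.12, -0.34],
        reason="length",
        seed=None,
    )
    print(result.output_text, result.generated_tokens)

The contrast with the removed PrefillTokens and Generation classes is the point of the revert: one complete GeneratedText per finished request, rather than per-token records suitable for streaming.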
server/text_generation/server.py  (view file @ 4f9ac67c)

@@ -27,20 +27,22 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         self.cache.clear()
         return generate_pb2.ClearCacheResponse()
 
-    async def Prefill(self, request, context):
+    async def Generate(self, request, context):
         batch = self.model.batch_type.from_pb(
             request.batch, self.model.tokenizer, self.model.device
         )
 
-        generations, next_batch = self.model.generate_token(batch)
+        generated_texts, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
-        return generate_pb2.PrefillResponse(
-            generations=[generation.to_pb() for generation in generations],
+        return generate_pb2.GenerateResponse(
+            generated_texts=[
+                generated_text.to_pb() for generated_text in generated_texts
+            ],
             batch=next_batch.to_pb() if next_batch else None,
         )
 
-    async def Decode(self, request, context):
+    async def GenerateWithCache(self, request, context):
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")

@@ -56,11 +58,13 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         else:
             batch = batches[0]
 
-        generations, next_batch = self.model.generate_token(batch)
+        generated_texts, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
-        return generate_pb2.DecodeResponse(
-            generations=[generation.to_pb() for generation in generations],
+        return generate_pb2.GenerateWithCacheResponse(
+            generated_texts=[
+                generated_text.to_pb() for generated_text in generated_texts
+            ],
             batch=next_batch.to_pb() if next_batch else None,
         )
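For context, a rough sketch of the calling pattern this revert restores on the caller's side: one prefill call, then repeated cached-decode calls until the batch is drained, returning complete results instead of a per-token event stream. This is plain Python with no gRPC; "service", generate() and generate_with_cache() are hypothetical stand-ins for the Generate and GenerateWithCache RPCs shown above.

    def run_request(service, first_batch):
        # Prefill: first forward pass over the fresh batch (Generate RPC).
        generated_texts, next_batch = service.generate(first_batch)
        finished = list(generated_texts)

        # Decode: keep generating from the cached batch (GenerateWithCache RPC)
        # until every request in it has produced a GeneratedText.
        while next_batch is not None:
            generated_texts, next_batch = service.generate_with_cache([next_batch])
            finished.extend(generated_texts)

        # Complete results only; with the ServerSideEvents streaming code reverted
        # there is no per-token Generation stream to forward to clients.
        return finished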