OpenDAS / text-generation-inference · Commits · 895c5f15

Unverified commit 895c5f15, authored Jun 02, 2023 by OlivierDehaene, committed by GitHub on Jun 02, 2023

feat(server): only compute prefill logprobs when asked (#406)

Close #288

Parent: 83b84486

Showing 20 of 36 changed files, with 109 additions and 30 deletions (+109, -30).
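In short: the server now skips prefill (decoder input) logprob computation unless a request explicitly asks for it, and the Python client exposes this through a new `decoder_input_details` flag (default `False`). A minimal usage sketch against the updated client, shown before the per-file diffs; the server URL is a placeholder:

```python
from text_generation import Client

# Placeholder endpoint; point this at any running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# Opt in to decoder input (prefill) token ids, texts and logprobs.
response = client.generate(
    "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
)

print(response.generated_text)
# details.prefill is a list of InputToken; it stays empty when
# decoder_input_details is left at its default of False.
for token in response.details.prefill:
    print(token.id, token.text, token.logprob)
```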
Changed files on this page:

- Makefile (+1, -0)
- benchmark/src/generation.rs (+1, -0)
- clients/python/README.md (+40, -6)
- clients/python/pyproject.toml (+1, -1)
- clients/python/tests/test_client.py (+13, -9)
- clients/python/text_generation/client.py (+10, -0)
- clients/python/text_generation/types.py (+8, -6)
- integration-tests/conftest.py (+6, -3)
- integration-tests/models/test_bloom_560m.py (+2, -0)
- integration-tests/models/test_bloom_560m_sharded.py (+1, -0)
- integration-tests/models/test_flash_falcon.py (+2, -0)
- integration-tests/models/test_flash_llama.py (+4, -1)
- integration-tests/models/test_flash_neox.py (+1, -0)
- integration-tests/models/test_flash_neox_sharded.py (+1, -0)
- integration-tests/models/test_flash_santacoder.py (+3, -1)
- integration-tests/models/test_flash_starcoder.py (+9, -2)
- integration-tests/models/test_mt0_base.py (+2, -0)
- integration-tests/models/test_t5_sharded.py (+1, -0)
- integration-tests/requirements.txt (+1, -1)
- proto/generate.proto (+2, -0)
Makefile

@@ -3,6 +3,7 @@ install-server:
 install-integration-tests:
 	cd integration-tests && pip install -r requirements.txt
+	cd clients/python && pip install .
 
 install-router:
 	cd router && cargo install --path .
benchmark/src/generation.rs

@@ -136,6 +136,7 @@ async fn prefill(
     let requests = (0..batch_size)
         .map(|id| Request {
             id: id.into(),
+            prefill_logprobs: false,
             inputs: sequence.clone(),
             truncate: sequence_length,
             parameters: Some(parameters.clone()),
clients/python/README.md

@@ -107,8 +107,42 @@ print(text)
 ### Types
 
 ```python
-# Prompt tokens
-class PrefillToken:
+# Request Parameters
+class Parameters:
+    # Activate logits sampling
+    do_sample: bool
+    # Maximum number of generated tokens
+    max_new_tokens: int
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    repetition_penalty: Optional[float]
+    # Whether to prepend the prompt to the generated text
+    return_full_text: bool
+    # Stop generating tokens if a member of `stop_sequences` is generated
+    stop: List[str]
+    # Random sampling seed
+    seed: Optional[int]
+    # The value used to module the logits distribution.
+    temperature: Optional[float]
+    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    top_k: Optional[int]
+    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+    # higher are kept for generation.
+    top_p: Optional[float]
+    # truncate inputs tokens to the given size
+    truncate: Optional[int]
+    # Typical Decoding mass
+    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
+    typical_p: Optional[float]
+    # Generate best_of sequences and return the one if the highest token logprobs
+    best_of: Optional[int]
+    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+    watermark: bool
+    # Get decoder input token logprobs and ids
+    decoder_input_details: bool
+
+# Decoder input tokens
+class InputToken:
     # Token ID from the model tokenizer
     id: int
     # Token text

@@ -151,8 +185,8 @@ class BestOfSequence:
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]

@@ -165,8 +199,8 @@ class Details:
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
 # Additional sequences when using the `best_of` parameter
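The `Parameters` and `InputToken` models documented above are plain pydantic models and can also be built directly. A small sketch, assuming the field defaults defined in `clients/python/text_generation/types.py`:

```python
from text_generation.types import Parameters

# Sketch only: request parameters matching the README listing above.
params = Parameters(
    max_new_tokens=10,
    temperature=0.5,
    top_p=0.9,
    decoder_input_details=True,  # new field introduced by this commit
)
print(params.json())
```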
clients/python/pyproject.toml

 [tool.poetry]
 name = "text-generation"
-version = "0.5.2"
+version = "0.6.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
clients/python/tests/test_client.py

@@ -2,28 +2,30 @@ import pytest
 from text_generation import Client, AsyncClient
 from text_generation.errors import NotFoundError, ValidationError
-from text_generation.types import FinishReason, PrefillToken, Token
+from text_generation.types import FinishReason, InputToken
 
 
 def test_generate(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", max_new_tokens=1)
+    response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
 
     assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
+    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == ""
+    assert response.details.tokens[0].text == " "
     assert not response.details.tokens[0].special
 
 
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
-    response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
+    response = client.generate(
+        "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
+    )
 
     assert response.details.seed is not None
     assert response.details.best_of_sequences is not None

@@ -73,17 +75,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
 async def test_generate_async(flan_t5_xxl_url, hf_headers):
     client = AsyncClient(flan_t5_xxl_url, hf_headers)
-    response = await client.generate("test", max_new_tokens=1)
+    response = await client.generate(
+        "test", max_new_tokens=1, decoder_input_details=True
+    )
 
     assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
+    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0].id == 3
-    assert response.details.tokens[0].text == ""
+    assert response.details.tokens[0].text == " "
     assert not response.details.tokens[0].special

@@ -91,7 +95,7 @@ async def test_generate_async(flan_t5_xxl_url, hf_headers):
 async def test_generate_async_best_of(flan_t5_xxl_url, hf_headers):
     client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate(
-        "test", max_new_tokens=1, best_of=2, do_sample=True
+        "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
     )
 
     assert response.details.seed is not None
clients/python/text_generation/client.py

@@ -74,6 +74,7 @@ class Client:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        decoder_input_details: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text

@@ -110,6 +111,8 @@ class Client:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids
 
         Returns:
             Response: generated response

@@ -130,6 +133,7 @@ class Client:
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            decoder_input_details=decoder_input_details,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -202,6 +206,7 @@ class Client:
         parameters = Parameters(
             best_of=None,
             details=True,
+            decoder_input_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,

@@ -311,6 +316,7 @@ class AsyncClient:
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        decoder_input_details: bool = False,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously

@@ -347,6 +353,8 @@ class AsyncClient:
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids
 
         Returns:
             Response: generated response

@@ -355,6 +363,7 @@ class AsyncClient:
         parameters = Parameters(
             best_of=best_of,
             details=True,
+            decoder_input_details=decoder_input_details,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,

@@ -437,6 +446,7 @@ class AsyncClient:
         parameters = Parameters(
             best_of=None,
             details=True,
+            decoder_input_details=False,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
             repetition_penalty=repetition_penalty,
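Note that only the non-streaming `generate` path forwards the flag; both `generate_stream` call sites above construct `Parameters` with `decoder_input_details=False`, so streamed responses never carry prefill logprobs. A sketch of the asynchronous usage (placeholder URL):

```python
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    # Placeholder endpoint; substitute your own server URL.
    client = AsyncClient("http://127.0.0.1:8080")

    # Non-streaming generation can return decoder input details on request.
    response = await client.generate(
        "What is Deep Learning?", max_new_tokens=10, decoder_input_details=True
    )
    print([(t.id, t.logprob) for t in response.details.prefill])

    # Streaming never returns prefill tokens; the flag is hardcoded to False there.
    async for stream_response in client.generate_stream(
        "What is Deep Learning?", max_new_tokens=10
    ):
        print(stream_response.token.text, end="")


asyncio.run(main())
```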
clients/python/text_generation/types.py

@@ -37,6 +37,8 @@ class Parameters(BaseModel):
     watermark: bool = False
     # Get generation details
     details: bool = False
+    # Get decoder input token logprobs and ids
+    decoder_input_details: bool = False
 
     @validator("best_of")
     def valid_best_of(cls, field_value, values):

@@ -129,8 +131,8 @@ class Request(BaseModel):
         return field_value
 
 
-# Prompt tokens
-class PrefillToken(BaseModel):
+# Decoder input tokens
+class InputToken(BaseModel):
     # Token ID from the model tokenizer
     id: int
     # Token text

@@ -173,8 +175,8 @@ class BestOfSequence(BaseModel):
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]

@@ -187,8 +189,8 @@ class Details(BaseModel):
     generated_tokens: int
     # Sampling seed if sampling was activated
     seed: Optional[int]
-    # Prompt tokens
-    prefill: List[PrefillToken]
+    # Decoder input tokens, empty if decoder_input_details is False
+    prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
 # Additional sequences when using the `best_of` parameter
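Because `PrefillToken` is renamed rather than aliased, downstream code importing it from `text_generation.types` breaks on upgrade to 0.6.0. One possible compatibility shim (a sketch, not part of this commit):

```python
# Prefer the new name, fall back to the old one on clients < 0.6.0.
try:
    from text_generation.types import InputToken
except ImportError:
    from text_generation.types import PrefillToken as InputToken

# Fields are unchanged: id, text and an optional logprob.
token = InputToken(id=0, text="<pad>", logprob=None)
print(token)
```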
integration-tests/conftest.py

@@ -16,7 +16,7 @@ from syrupy.extensions.json import JSONSnapshotExtension
 from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
 
 from text_generation import AsyncClient
-from text_generation.types import Response, Details, PrefillToken, Token, BestOfSequence
+from text_generation.types import Response, Details, InputToken, Token, BestOfSequence
 
 DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", None)
 HUGGING_FACE_HUB_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN", None)

@@ -62,7 +62,7 @@ class ResponseComparator(JSONSnapshotExtension):
                 and token.special == other.special
             )
 
-        def eq_prefill_token(prefill_token: PrefillToken, other: PrefillToken) -> bool:
+        def eq_prefill_token(prefill_token: InputToken, other: InputToken) -> bool:
             try:
                 return (
                     prefill_token.id == other.id

@@ -332,7 +332,10 @@ def generate_load():
         client: AsyncClient, prompt: str, max_new_tokens: int, n: int
     ) -> List[Response]:
         futures = [
-            client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
+            client.generate(
+                prompt, max_new_tokens=max_new_tokens, decoder_input_details=True
+            )
+            for _ in range(n)
         ]
 
         return await asyncio.gather(*futures)
integration-tests/models/test_bloom_560m.py

@@ -19,6 +19,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
         "Pour déguster un ortolan, il faut tout d'abord",
         max_new_tokens=10,
         top_p=0.9,
+        decoder_input_details=True,
         seed=0,
     )

@@ -40,6 +41,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
+        decoder_input_details=True,
         seed=0,
     )
integration-tests/models/test_bloom_560m_sharded.py

@@ -19,6 +19,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
         "Pour déguster un ortolan, il faut tout d'abord",
         max_new_tokens=10,
         top_p=0.9,
+        decoder_input_details=True,
         seed=0,
     )
integration-tests/models/test_flash_falcon.py

@@ -19,6 +19,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
     response = await flash_falcon.generate(
         "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
         max_new_tokens=10,
+        decoder_input_details=True,
     )
 
     assert response.details.generated_tokens == 10

@@ -40,6 +41,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
+        decoder_input_details=True,
         seed=0,
     )
integration-tests/models/test_flash_llama.py

@@ -16,7 +16,9 @@ async def flash_llama(flash_llama_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama(flash_llama, response_snapshot):
-    response = await flash_llama.generate("Test request", max_new_tokens=10)
+    response = await flash_llama.generate(
+        "Test request", max_new_tokens=10, decoder_input_details=True
+    )
 
     assert response.details.generated_tokens == 10
     assert response == response_snapshot

@@ -37,6 +39,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
+        decoder_input_details=True,
         seed=0,
     )
integration-tests/models/test_flash_neox.py

@@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
     response = await flash_neox.generate(
         "<|USER|>What's your mood today?<|ASSISTANT|>",
         max_new_tokens=10,
+        decoder_input_details=True,
     )
 
     assert response.details.generated_tokens == 10
integration-tests/models/test_flash_neox_sharded.py

@@ -18,6 +18,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
     response = await flash_neox_sharded.generate(
         "<|prompter|>What is a meme, and what's the history behind this word?<|endoftext|><|assistant|>",
         max_new_tokens=10,
+        decoder_input_details=True,
     )
 
     assert response.details.generated_tokens == 10
integration-tests/models/test_flash_santacoder.py

@@ -15,7 +15,9 @@ async def flash_santacoder(flash_santacoder_handle):
 @pytest.mark.asyncio
 async def test_flash_santacoder(flash_santacoder, response_snapshot):
-    response = await flash_santacoder.generate("def print_hello", max_new_tokens=10)
+    response = await flash_santacoder.generate(
+        "def print_hello", max_new_tokens=10, decoder_input_details=True
+    )
 
     assert response.details.generated_tokens == 10
     assert response == response_snapshot
integration-tests/models/test_flash_starcoder.py

@@ -16,7 +16,9 @@ async def flash_starcoder(flash_starcoder_handle):
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder(flash_starcoder, response_snapshot):
-    response = await flash_starcoder.generate("def print_hello", max_new_tokens=10)
+    response = await flash_starcoder.generate(
+        "def print_hello", max_new_tokens=10, decoder_input_details=True
+    )
 
     assert response.details.generated_tokens == 10
     assert response == response_snapshot

@@ -26,7 +28,12 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
 @pytest.mark.private
 async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
     response = await flash_starcoder.generate(
-        "def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
+        "def print_hello",
+        max_new_tokens=60,
+        temperature=0.2,
+        top_p=0.95,
+        decoder_input_details=True,
+        seed=0,
     )
 
     assert response.details.generated_tokens == 60
integration-tests/models/test_mt0_base.py

@@ -19,6 +19,7 @@ async def test_mt0_base(mt0_base, response_snapshot):
         "Why is the sky blue?",
         max_new_tokens=10,
         top_p=0.9,
+        decoder_input_details=True,
         seed=0,
     )

@@ -40,6 +41,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
         truncate=5,
         typical_p=0.9,
         watermark=True,
+        decoder_input_details=True,
         seed=0,
     )
integration-tests/models/test_t5_sharded.py

@@ -18,6 +18,7 @@ async def test_t5_sharded(t5_sharded, response_snapshot):
     response = await t5_sharded.generate(
         "Please answer the following question. What is the boiling point of Nitrogen?",
         max_new_tokens=10,
+        decoder_input_details=True,
     )
 
     assert response == response_snapshot
integration-tests/requirements.txt

 syrupy
-text-generation==0.5.2
+text-generation
 pytest
 pytest-asyncio==0.17.2
 docker
\ No newline at end of file
proto/generate.proto

@@ -87,6 +87,8 @@ message Request {
   NextTokenChooserParameters parameters = 4;
   /// Stopping Criteria Parameters
   StoppingCriteriaParameters stopping_parameters = 5;
+  /// Return prefill logprobs
+  bool prefill_logprobs = 6;
 }
 
 message Batch {