OpenDAS / text-generation-inference / Commits

Commit d8dc8f1b (Unverified)
Authored Mar 09, 2023 by OlivierDehaene, committed by GitHub on Mar 09, 2023

feat(python-client): add new parameters (#118)

Parent: 55bd4fed
Showing 9 changed files with 278 additions and 40 deletions (+278, -40)
clients/python/README.md                      +18   -0
clients/python/pyproject.toml                  +1   -1
clients/python/tests/conftest.py               +0   -5
clients/python/tests/test_client.py           +26  -20
clients/python/tests/test_inference_api.py     +4   -4
clients/python/tests/test_types.py            +45   -2
clients/python/text_generation/__init__.py     +1   -1
clients/python/text_generation/client.py      +52   -0
clients/python/text_generation/types.py      +131   -7
clients/python/README.md

@@ -133,6 +133,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"

+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
 # `generate` details
 class Details:
     # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]

 # `generate` return value
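The README additions document the new `BestOfSequence` type and the `best_of_sequences` field on `Details`. Below is a minimal sketch of how those fields would be read from a response; the server URL is illustrative and not part of this commit.

```python
# Sketch only: assumes a text-generation-inference server running locally.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

# `best_of` requires sampling, which the new Parameters validator enforces.
response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)

print(response.generated_text)
# Extra candidates are exposed through the new `best_of_sequences` field.
if response.details.best_of_sequences is not None:
    for sequence in response.details.best_of_sequences:
        print(sequence.generated_text, sequence.seed, sequence.finish_reason)
```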
clients/python/pyproject.toml

 [tool.poetry]
 name = "text-generation"
-version = "0.2.1"
+version = "0.3.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]
clients/python/tests/conftest.py

@@ -4,11 +4,6 @@ from text_generation import __version__
 from huggingface_hub.utils import build_hf_headers

-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
-
-
 @pytest.fixture
 def flan_t5_xxl():
     return "google/flan-t5-xxl"
clients/python/tests/test_client.py

@@ -5,24 +5,32 @@ from text_generation.errors import NotFoundError, ValidationError
 from text_generation.types import FinishReason, PrefillToken, Token


-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
+    assert response.details.prefill[0] == PrefillToken(id=9234, text="test", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
+
+
+def test_generate_best_of(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
+    response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
+
+    assert response.details.seed is not None
+    assert response.details.best_of_sequences is not None
+    assert len(response.details.best_of_sequences) == 1
+    assert response.details.best_of_sequences[0].seed is not None


 def test_generate_not_found(fake_url, hf_headers):
     client = Client(fake_url, hf_headers)
     with pytest.raises(NotFoundError):
@@ -35,8 +43,8 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
         client.generate("test", max_new_tokens=10_000)


-def test_generate_stream(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -44,7 +52,7 @@ def test_generate_stream(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -63,21 +71,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
+    assert response.details.prefill[0] == PrefillToken(id=9234, text="test", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )
@@ -96,8 +102,8 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
 @pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]
@@ -105,7 +111,7 @@ async def test_generate_stream_async(bloom_url, hf_headers):
     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
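Besides switching the fixtures from bloom to flan-t5-xxl and adding `test_generate_best_of`, the streaming tests above are unchanged in shape. A hedged sketch of the same streaming call outside pytest, with the server URL standing in for the `flan_t5_xxl_url` fixture:

```python
# Sketch only: the URL stands in for the flan_t5_xxl_url fixture used in the tests.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

# generate_stream yields StreamResponse objects; `details` and `generated_text`
# are only populated on the final message.
for stream_response in client.generate_stream("test", max_new_tokens=1):
    print(stream_response.token.text)
    if stream_response.details is not None:
        print(stream_response.details.finish_reason, stream_response.details.generated_tokens)
```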
clients/python/tests/test_inference_api.py

@@ -14,8 +14,8 @@ def test_get_supported_models():
     assert isinstance(get_supported_models(), list)


-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
     assert isinstance(client, Client)
@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
         InferenceAPIClient(unsupported_model)


-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
     assert isinstance(client, AsyncClient)
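The Inference API tests now use `google/flan-t5-xxl` (from the `flan_t5_xxl` fixture) as the supported model. A minimal sketch of the corresponding client usage; authentication handling is omitted and assumed to be configured separately:

```python
# Sketch only: InferenceAPIClient targets the hosted Hugging Face Inference API,
# using the model name from the fixture above.
from text_generation import InferenceAPIClient

client = InferenceAPIClient("google/flan-t5-xxl")
response = client.generate("Why is the sky blue?", max_new_tokens=20)
print(response.generated_text)
```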
clients/python/tests/test_types.py
View file @
d8dc8f1b
import
pytest
import
pytest
from
text_generation.types
import
Parameters
from
text_generation.types
import
Parameters
,
Request
from
text_generation.errors
import
ValidationError
from
text_generation.errors
import
ValidationError
def
test_parameters_validation
():
def
test_parameters_validation
():
# Test best_of
Parameters
(
best_of
=
1
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
best_of
=
0
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
best_of
=-
1
)
Parameters
(
best_of
=
2
,
do_sample
=
True
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
best_of
=
2
)
# Test repetition_penalty
# Test repetition_penalty
Parameters
(
repetition_penalty
=
1
)
Parameters
(
repetition_penalty
=
1
)
with
pytest
.
raises
(
ValidationError
):
with
pytest
.
raises
(
ValidationError
):
...
@@ -32,8 +42,41 @@ def test_parameters_validation():
...
@@ -32,8 +42,41 @@ def test_parameters_validation():
Parameters
(
top_k
=-
1
)
Parameters
(
top_k
=-
1
)
# Test top_p
# Test top_p
Parameters
(
top_p
=
1
)
Parameters
(
top_p
=
0.5
)
with
pytest
.
raises
(
ValidationError
):
with
pytest
.
raises
(
ValidationError
):
Parameters
(
top_p
=
0
)
Parameters
(
top_p
=
0
)
with
pytest
.
raises
(
ValidationError
):
with
pytest
.
raises
(
ValidationError
):
Parameters
(
top_p
=-
1
)
Parameters
(
top_p
=-
1
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
top_p
=
1
)
# Test truncate
Parameters
(
truncate
=
1
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
truncate
=
0
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
truncate
=-
1
)
# Test typical_p
Parameters
(
typical_p
=
0.5
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
typical_p
=
0
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
typical_p
=-
1
)
with
pytest
.
raises
(
ValidationError
):
Parameters
(
typical_p
=
1
)
def
test_request_validation
():
Request
(
inputs
=
"test"
)
with
pytest
.
raises
(
ValidationError
):
Request
(
inputs
=
""
)
Request
(
inputs
=
"test"
,
stream
=
True
)
Request
(
inputs
=
"test"
,
parameters
=
Parameters
(
best_of
=
2
,
do_sample
=
True
))
with
pytest
.
raises
(
ValidationError
):
Request
(
inputs
=
"test"
,
parameters
=
Parameters
(
best_of
=
2
,
do_sample
=
True
),
stream
=
True
)
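These tests pin down the validation rules added in `types.py`: `best_of` must be strictly positive and requires sampling when greater than 1, `top_p` and `typical_p` must lie strictly between 0 and 1, `truncate` must be strictly positive, and `best_of > 1` cannot be combined with streaming. A short sketch of how those rules surface to a caller:

```python
# Sketch only: mirrors the constraints exercised by the tests above.
from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError

Parameters(best_of=2, do_sample=True)    # valid: best_of > 1 with sampling enabled
Parameters(typical_p=0.5, truncate=128)  # valid: typical_p in (0, 1), truncate > 0

try:
    Parameters(best_of=2)  # invalid: best_of > 1 without any sampling parameter
except ValidationError as err:
    print(err)

try:
    # invalid: best_of > 1 cannot be combined with streaming
    Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True)
except ValidationError as err:
    print(err)
```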
clients/python/text_generation/__init__.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.2.1"
+__version__ = "0.3.0"

 from text_generation.client import Client, AsyncClient
 from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
clients/python/text_generation/client.py

@@ -56,6 +56,7 @@ class Client:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -63,6 +64,8 @@ class Client:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> Response:
         """
@@ -75,6 +78,8 @@ class Client:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one if the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -91,6 +96,11 @@ class Client:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate inputs tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
@@ -99,6 +109,7 @@ class Client:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -109,6 +120,8 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -129,6 +142,7 @@ class Client:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -136,6 +150,8 @@ class Client:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> Iterator[StreamResponse]:
         """
@@ -148,6 +164,8 @@ class Client:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one if the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -164,6 +182,11 @@ class Client:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate inputs tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
@@ -172,6 +195,7 @@ class Client:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -182,6 +206,8 @@ class Client:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
@@ -261,6 +287,7 @@ class AsyncClient:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -268,6 +295,8 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> Response:
         """
@@ -280,6 +309,8 @@ class AsyncClient:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one if the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -296,6 +327,11 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate inputs tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
@@ -304,6 +340,7 @@ class AsyncClient:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -314,6 +351,8 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
@@ -331,6 +370,7 @@ class AsyncClient:
         prompt: str,
         do_sample: bool = False,
         max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -338,6 +378,8 @@ class AsyncClient:
         temperature: Optional[float] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
         watermark: bool = False,
     ) -> AsyncIterator[StreamResponse]:
         """
@@ -350,6 +392,8 @@ class AsyncClient:
                 Activate logits sampling
             max_new_tokens (`int`):
                 Maximum number of generated tokens
+            best_of (`int`):
+                Generate best_of sequences and return the one if the highest token logprobs
             repetition_penalty (`float`):
                 The parameter for repetition penalty. 1.0 means no penalty. See [this
                 paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
@@ -366,6 +410,11 @@ class AsyncClient:
             top_p (`float`):
                 If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                 higher are kept for generation.
+            truncate (`int`):
+                Truncate inputs tokens to the given size
+            typical_p (`float`):
+                Typical Decoding mass
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
@@ -374,6 +423,7 @@ class AsyncClient:
         """
         # Validate parameters
         parameters = Parameters(
+            best_of=best_of,
             details=True,
             do_sample=do_sample,
             max_new_tokens=max_new_tokens,
@@ -384,6 +434,8 @@ class AsyncClient:
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
             watermark=watermark,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
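Taken together, the client changes thread three new keyword arguments, `best_of`, `truncate`, and `typical_p`, through `Client.generate`, `Client.generate_stream`, and their `AsyncClient` counterparts. A hedged sketch of a call using all of them; the server URL is assumed:

```python
# Sketch only: illustrates the new keyword arguments added in this commit.
import asyncio

from text_generation import AsyncClient


async def main() -> None:
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate(
        "Why is the sky blue?",
        max_new_tokens=20,
        best_of=2,       # best_of > 1 requires sampling
        do_sample=True,
        truncate=512,    # truncate input tokens to the given size
        typical_p=0.9,   # typical decoding mass, must be in (0, 1)
    )
    print(response.generated_text)


if __name__ == "__main__":
    asyncio.run(main())
```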
clients/python/text_generation/types.py

@@ -6,27 +6,64 @@ from text_generation.errors import ValidationError
 class Parameters(BaseModel):
+    # Activate logits sampling
     do_sample: bool = False
+    # Maximum number of generated tokens
     max_new_tokens: int = 20
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
+    # Whether to prepend the prompt to the generated text
     return_full_text: bool = False
+    # Stop generating tokens if a member of `stop_sequences` is generated
     stop: List[str] = []
+    # Random sampling seed
     seed: Optional[int]
+    # The value used to module the logits distribution.
     temperature: Optional[float]
+    # The number of highest probability vocabulary tokens to keep for top-k-filtering.
     top_k: Optional[int]
+    # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+    # higher are kept for generation.
     top_p: Optional[float]
+    # truncate inputs tokens to the given size
+    truncate: Optional[int]
+    # Typical Decoding mass
+    # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
+    typical_p: Optional[float]
+    # Generate best_of sequences and return the one if the highest token logprobs
+    best_of: Optional[int]
+    # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool = False
+    # Get generation details
     details: bool = False

+    @validator("best_of")
+    def valid_best_of(cls, field_value, values):
+        if field_value is not None:
+            if field_value <= 0:
+                raise ValidationError("`best_of` must be strictly positive")
+            sampling = (
+                values["do_sample"]
+                | (values["temperature"] is not None)
+                | (values["top_k"] is not None)
+                | (values["top_p"] is not None)
+                | (values["typical_p"] is not None)
+            )
+            if field_value > 1 and not sampling:
+                raise ValidationError("you must use sampling when `best_of` is > 1")
+
+        return field_value
+
     @validator("repetition_penalty")
     def valid_repetition_penalty(cls, v):
-        if v is not None and v is v <= 0:
+        if v is not None and v <= 0:
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v

     @validator("seed")
     def valid_seed(cls, v):
-        if v is not None and v is v < 0:
+        if v is not None and v < 0:
             raise ValidationError("`seed` must be positive")
         return v
@@ -44,56 +81,143 @@ class Parameters(BaseModel):
     @validator("top_p")
     def valid_top_p(cls, v):
-        if v is not None and (v <= 0 or v > 1.0):
-            raise ValidationError("`top_p` must be > 0.0 and <= 1.0")
+        if v is not None and (v <= 0 or v >= 1.0):
+            raise ValidationError("`top_p` must be > 0.0 and < 1.0")
         return v

+    @validator("truncate")
+    def valid_truncate(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`truncate` must be strictly positive")
+        return v
+
+    @validator("typical_p")
+    def valid_typical_p(cls, v):
+        if v is not None and (v <= 0 or v >= 1.0):
+            raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
+        return v
+

 class Request(BaseModel):
+    # Prompt
     inputs: str
-    parameters: Parameters
+    # Generation parameters
+    parameters: Optional[Parameters]
+    # Whether to stream output tokens
     stream: bool = False

+    @validator("inputs")
+    def valid_input(cls, v):
+        if not v:
+            raise ValidationError("`inputs` cannot be empty")
+        return v
+
+    @validator("stream")
+    def valid_best_of_stream(cls, field_value, values):
+        parameters = values["parameters"]
+        if (
+            parameters is not None
+            and parameters.best_of is not None
+            and parameters.best_of > 1
+            and field_value
+        ):
+            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
+        return field_value
+

+# Prompt tokens
 class PrefillToken(BaseModel):
+    # Token ID from the model tokenizer
     id: int
+    # Token text
     text: str
+    # Logprob
+    # Optional since the logprob of the first token cannot be computed
     logprob: Optional[float]


+# Generated tokens
 class Token(BaseModel):
+    # Token ID from the model tokenizer
     id: int
+    # Token text
     text: str
+    # Logprob
     logprob: float
+    # Is the token a special token
+    # Can be used to ignore tokens when concatenating
     special: bool


+# Generation finish reason
 class FinishReason(Enum):
+    # number of generated tokens == `max_new_tokens`
     Length = "length"
+    # the model generated its end of sequence token
     EndOfSequenceToken = "eos_token"
+    # the model generated a text included in `stop_sequences`
     StopSequence = "stop_sequence"


-class Details(BaseModel):
+# Additional sequences when using the `best_of` parameter
+class BestOfSequence(BaseModel):
+    # Generated text
+    generated_text: str
+    # Generation finish reason
     finish_reason: FinishReason
+    # Number of generated tokens
     generated_tokens: int
+    # Sampling seed if sampling was activated
     seed: Optional[int]
+    # Prompt tokens
     prefill: List[PrefillToken]
+    # Generated tokens
     tokens: List[Token]


-class StreamDetails(BaseModel):
+# `generate` details
+class Details(BaseModel):
+    # Generation finish reason
     finish_reason: FinishReason
+    # Number of generated tokens
     generated_tokens: int
+    # Sampling seed if sampling was activated
     seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]


+# `generate` return value
 class Response(BaseModel):
+    # Generated text
     generated_text: str
+    # Generation details
     details: Details


+# `generate_stream` details
+class StreamDetails(BaseModel):
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+
+
+# `generate_stream` return value
 class StreamResponse(BaseModel):
+    # Generated token
     token: Token
+    # Complete generated text
+    # Only available when the generation is finished
     generated_text: Optional[str]
+    # Generation details
+    # Only available when the generation is finished
     details: Optional[StreamDetails]
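The reorganised types give `Details` the new `best_of_sequences` field and keep the stream-only fields in `StreamDetails`. A small sketch of building the request payload the client constructs internally, using only the types above:

```python
# Sketch only: builds the same Request payload the client constructs internally.
from text_generation.types import Parameters, Request

parameters = Parameters(
    best_of=2,
    do_sample=True,
    max_new_tokens=20,
    truncate=512,
    typical_p=0.9,
    details=True,
)
request = Request(inputs="Why is the sky blue?", stream=False, parameters=parameters)

# Serialises to the JSON body the client sends to the server.
print(request.json())
```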