OpenDAS / text-generation-inference

Commit 4126e0b0
authored Jun 26, 2024 by huangwb

add client native unit test

parent 6e6d3c1a

Showing 4 changed files with 337 additions and 0 deletions (+337 -0):
clients/python/tests_native/conftest.py      +64   -0
clients/python/tests_native/test_client.py   +125  -0
clients/python/tests_native/test_errors.py   +64   -0
clients/python/tests_native/test_types.py    +84   -0
clients/python/tests_native/conftest.py (new file, mode 100644)

import pytest

from text_generation import __version__
from huggingface_hub.utils import build_hf_headers


@pytest.fixture
def flan_t5_xxl():
    return "google/flan-t5-xxl"


@pytest.fixture
def llama_7b():
    return "meta-llama/Llama-2-7b-chat-hf"


@pytest.fixture
def fake_model():
    return "fake/model"


@pytest.fixture
def unsupported_model():
    return "gpt2"


@pytest.fixture
def base_url():
    return "https://api-inference.huggingface.co/models"


# NOTE: this fixture depends on a bloom_model fixture that is not defined
# in this file; requesting bloom_url will fail until one is added.
@pytest.fixture
def bloom_url(base_url, bloom_model):
    return f"{base_url}/{bloom_model}"


@pytest.fixture
def flan_t5_xxl_url(base_url, flan_t5_xxl):
    return f"{base_url}/{flan_t5_xxl}"


@pytest.fixture
def llama_7b_url(base_url, llama_7b):
    # return f"{base_url}/{llama_7b}"
    return "http://localhost:3001"


@pytest.fixture
def fake_url(base_url, fake_model):
    return f"{base_url}/{fake_model}"


@pytest.fixture
def unsupported_url(base_url, unsupported_model):
    return f"{base_url}/{unsupported_model}"


@pytest.fixture(scope="session")
def hf_headers():
    # return build_hf_headers(
    #     library_name="text-generation-tests", library_version=__version__
    # )
    header = {'content-type': 'application/json'}
    return header
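
Note that hf_headers replaces the commented-out build_hf_headers call (which would attach the library's user-agent and authorization headers) with a bare JSON content-type header, and llama_7b_url is pinned to a local server rather than the Inference API. A minimal sketch of making that endpoint configurable, assuming a hypothetical TGI_NATIVE_URL environment variable that is not part of this commit:

import os

import pytest


@pytest.fixture
def llama_7b_url():
    # Hypothetical override: read the target server from the environment,
    # falling back to the hard-coded local port used in this commit.
    return os.environ.get("TGI_NATIVE_URL", "http://localhost:3001")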
clients/python/tests_native/test_client.py (new file, mode 100644)

import pytest

from text_generation import Client, AsyncClient
from text_generation.errors import NotFoundError, ValidationError
from text_generation.types import FinishReason, InputToken


def test_generate(llama_7b_url, hf_headers):
    client = Client(llama_7b_url, hf_headers)
    response = client.generate("test", max_new_tokens=1, decoder_input_details=True)

    assert response.generated_text == "_"
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None
    assert len(response.details.prefill) == 2
    assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
    assert len(response.details.tokens) == 1
    assert response.details.tokens[0].id == 29918
    assert response.details.tokens[0].text == "_"
    assert not response.details.tokens[0].special


def test_generate_best_of(llama_7b_url, hf_headers):
    client = Client(llama_7b_url, hf_headers)
    response = client.generate(
        "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
    )

    assert response.details.seed is not None
    assert response.details.best_of_sequences is not None
    assert len(response.details.best_of_sequences) == 1
    assert response.details.best_of_sequences[0].seed is not None


def test_generate_validation_error(llama_7b_url, hf_headers):
    client = Client(llama_7b_url, hf_headers)
    with pytest.raises(ValidationError):
        client.generate("test", max_new_tokens=10_000)


def test_generate_stream(llama_7b_url, hf_headers):
    client = Client(llama_7b_url, hf_headers)
    responses = [
        response for response in client.generate_stream("test", max_new_tokens=1)
    ]

    assert len(responses) == 1
    response = responses[0]

    assert response.generated_text == "_"
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None


def test_generate_stream_validation_error(llama_7b_url, hf_headers):
    client = Client(llama_7b_url, hf_headers)
    with pytest.raises(ValidationError):
        list(client.generate_stream("test", max_new_tokens=10_000))


@pytest.mark.asyncio
async def test_generate_async(llama_7b_url, hf_headers):
    client = AsyncClient(llama_7b_url, hf_headers)
    response = await client.generate(
        "test", max_new_tokens=1, decoder_input_details=True
    )

    assert response.generated_text == "_"
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None
    assert len(response.details.prefill) == 2
    assert response.details.prefill[0] == InputToken(id=1, text="<s>", logprob=None)
    assert response.details.prefill[1] == InputToken(
        id=1243, text="test", logprob=-10.9375
    )
    assert len(response.details.tokens) == 1
    assert response.details.tokens[0].id == 29918
    assert response.details.tokens[0].text == "_"
    assert not response.details.tokens[0].special


@pytest.mark.asyncio
async def test_generate_async_best_of(llama_7b_url, hf_headers):
    client = AsyncClient(llama_7b_url, hf_headers)
    response = await client.generate(
        "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True
    )

    assert response.details.seed is not None
    assert response.details.best_of_sequences is not None
    assert len(response.details.best_of_sequences) == 1
    assert response.details.best_of_sequences[0].seed is not None


@pytest.mark.asyncio
async def test_generate_async_validation_error(llama_7b_url, hf_headers):
    client = AsyncClient(llama_7b_url, hf_headers)
    with pytest.raises(ValidationError):
        await client.generate("test", max_new_tokens=10_000)


@pytest.mark.asyncio
async def test_generate_stream_async(llama_7b_url, hf_headers):
    client = AsyncClient(llama_7b_url, hf_headers)
    responses = [
        response async for response in client.generate_stream("test", max_new_tokens=1)
    ]

    assert len(responses) == 1
    response = responses[0]

    assert response.generated_text == "_"
    assert response.details.finish_reason == FinishReason.Length
    assert response.details.generated_tokens == 1
    assert response.details.seed is None


@pytest.mark.asyncio
async def test_generate_stream_async_validation_error(llama_7b_url, hf_headers):
    client = AsyncClient(llama_7b_url, hf_headers)
    with pytest.raises(ValidationError):
        async for _ in client.generate_stream("test", max_new_tokens=10_000):
            pass
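
These tests drive Client and AsyncClient end to end against the server behind llama_7b_url, asserting the exact token id and logprob values the local Llama-2 model returns for the prompt "test". For a quick smoke check outside pytest, the same call can be issued directly; a sketch, assuming a text-generation-inference server is listening on http://localhost:3001 as in conftest.py:

from text_generation import Client

# Same endpoint and headers the test fixtures use.
client = Client("http://localhost:3001", {"content-type": "application/json"})
response = client.generate("test", max_new_tokens=1, decoder_input_details=True)
print(response.generated_text, response.details.finish_reason)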
clients/python/tests_native/test_errors.py (new file, mode 100644)

from text_generation.errors import (
    parse_error,
    GenerationError,
    IncompleteGenerationError,
    OverloadedError,
    ValidationError,
    BadRequestError,
    ShardNotReadyError,
    ShardTimeoutError,
    NotFoundError,
    RateLimitExceededError,
    UnknownError,
)


def test_generation_error():
    payload = {"error_type": "generation", "error": "test"}
    assert isinstance(parse_error(400, payload), GenerationError)


def test_incomplete_generation_error():
    payload = {"error_type": "incomplete_generation", "error": "test"}
    assert isinstance(parse_error(400, payload), IncompleteGenerationError)


def test_overloaded_error():
    payload = {"error_type": "overloaded", "error": "test"}
    assert isinstance(parse_error(400, payload), OverloadedError)


def test_validation_error():
    payload = {"error_type": "validation", "error": "test"}
    assert isinstance(parse_error(400, payload), ValidationError)


def test_bad_request_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(400, payload), BadRequestError)


def test_shard_not_ready_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(403, payload), ShardNotReadyError)
    assert isinstance(parse_error(424, payload), ShardNotReadyError)


def test_shard_timeout_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(504, payload), ShardTimeoutError)


def test_not_found_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(404, payload), NotFoundError)


def test_rate_limit_exceeded_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(429, payload), RateLimitExceededError)


def test_unknown_error():
    payload = {"error": "test"}
    assert isinstance(parse_error(500, payload), UnknownError)
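
Taken together, these cases pin down the dispatch in parse_error: for a 400 response the payload's error_type field selects the specific error (generation, incomplete_generation, overloaded, validation), a 400 without error_type is a BadRequestError, other status codes map directly (403/424 to ShardNotReadyError, 404 to NotFoundError, 429 to RateLimitExceededError, 504 to ShardTimeoutError), and anything unrecognized falls back to UnknownError. A usage sketch, with an assumed error message:

from text_generation.errors import parse_error, OverloadedError

# parse_error returns an exception instance rather than raising it,
# so callers can inspect it or raise it themselves.
err = parse_error(400, {"error_type": "overloaded", "error": "server busy"})
assert isinstance(err, OverloadedError)
print(type(err).__name__, err)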
clients/python/tests_native/test_types.py (new file, mode 100644)

import pytest

from text_generation.types import Parameters, Request
from text_generation.errors import ValidationError


def test_parameters_validation():
    # Test best_of
    Parameters(best_of=1)
    with pytest.raises(ValidationError):
        Parameters(best_of=0)
    with pytest.raises(ValidationError):
        Parameters(best_of=-1)
    Parameters(best_of=2, do_sample=True)
    with pytest.raises(ValidationError):
        Parameters(best_of=2)
    with pytest.raises(ValidationError):
        Parameters(best_of=2, seed=1)

    # Test repetition_penalty
    Parameters(repetition_penalty=1)
    with pytest.raises(ValidationError):
        Parameters(repetition_penalty=0)
    with pytest.raises(ValidationError):
        Parameters(repetition_penalty=-1)

    # Test seed
    Parameters(seed=1)
    with pytest.raises(ValidationError):
        Parameters(seed=-1)

    # Test temperature
    Parameters(temperature=1)
    with pytest.raises(ValidationError):
        Parameters(temperature=0)
    with pytest.raises(ValidationError):
        Parameters(temperature=-1)

    # Test top_k
    Parameters(top_k=1)
    with pytest.raises(ValidationError):
        Parameters(top_k=0)
    with pytest.raises(ValidationError):
        Parameters(top_k=-1)

    # Test top_p
    Parameters(top_p=0.5)
    with pytest.raises(ValidationError):
        Parameters(top_p=0)
    with pytest.raises(ValidationError):
        Parameters(top_p=-1)
    with pytest.raises(ValidationError):
        Parameters(top_p=1)

    # Test truncate
    Parameters(truncate=1)
    with pytest.raises(ValidationError):
        Parameters(truncate=0)
    with pytest.raises(ValidationError):
        Parameters(truncate=-1)

    # Test typical_p
    Parameters(typical_p=0.5)
    with pytest.raises(ValidationError):
        Parameters(typical_p=0)
    with pytest.raises(ValidationError):
        Parameters(typical_p=-1)
    with pytest.raises(ValidationError):
        Parameters(typical_p=1)


def test_request_validation():
    Request(inputs="test")

    with pytest.raises(ValidationError):
        Request(inputs="")

    Request(inputs="test", stream=True)
    Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))

    with pytest.raises(ValidationError):
        Request(
            inputs="test",
            parameters=Parameters(best_of=2, do_sample=True),
            stream=True,
        )
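
The constraints exercised here: best_of must be at least 1, and best_of > 1 additionally requires do_sample=True, forbids an explicit seed, and (per test_request_validation) cannot be combined with stream=True; repetition_penalty, temperature, top_k, and truncate must be strictly positive; seed must be non-negative; and top_p and typical_p must lie strictly between 0 and 1. A sketch of a request that satisfies all of them:

from text_generation.types import Parameters, Request

# Valid per the tests above: best_of > 1 together with do_sample=True,
# no explicit seed, and stream left at its default of False.
params = Parameters(best_of=2, do_sample=True)
request = Request(inputs="test", parameters=params)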