OpenDAS / text-generation-inference · Commits

Commit efd602c8, authored Oct 29, 2024 by xuxzh1
Commit message: last
Parent: f1b779fc
Changes: 214

Showing 20 changed files with 500 additions and 4 deletions (+500 -4)
Files shown in this view:

  integration-tests/models/test_flash_falcon.py                    +3    -0
  integration-tests/models/test_flash_gemma.py                     +4    -4
  integration-tests/models/test_flash_gemma_gptq.py                +67   -0
  integration-tests/models/test_flash_gpt2.py                      +46   -0
  integration-tests/models/test_flash_llama_exl2.py                +76   -0
  integration-tests/models/test_flash_llama_gptq.py                +3    -0
  integration-tests/models/test_flash_llama_marlin.py              +66   -0
  integration-tests/models/test_flash_neox.py                      +2    -0
  integration-tests/models/test_flash_neox_sharded.py              +2    -0
  integration-tests/models/test_flash_pali_gemma.py                +64   -0
  integration-tests/models/test_flash_phi.py                       +3    -0
  integration-tests/models/test_flash_qwen2.py                     +3    -0
  integration-tests/models/test_flash_santacoder.py                +2    -0
  integration-tests/models/test_flash_starcoder.py                 +3    -0
  integration-tests/models/test_flash_starcoder2.py                +3    -0
  integration-tests/models/test_flash_starcoder_gptq.py            +3    -0
  integration-tests/models/test_grammar_llama.py                   +1    -0
  integration-tests/models/test_grammar_response_format_llama.py   +103  -0
  integration-tests/models/test_idefics.py                         +23   -0
  integration-tests/models/test_idefics2.py                        +23   -0
Too many changes to show: to preserve performance, only 214 of 214+ files are displayed, and only the first 20 of them are rendered on this page.
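Note: the recurring change across the modified test files below is a new @pytest.mark.release marker on each test. The commit does not include the suite's conftest.py or pytest.ini, so the following is only a sketch, under that assumption, of how such a custom marker is typically registered and then selected with pytest:

# conftest.py -- hypothetical sketch, not part of this commit
def pytest_configure(config):
    # Register the custom markers used by the integration tests so that
    # marker-based selection and --strict-markers work without warnings.
    config.addinivalue_line(
        "markers", "release: long-running tests only run before a release"
    )
    config.addinivalue_line(
        "markers", "private: tests that require gated/private model access"
    )

# Example invocations (shell):
#   pytest integration-tests -m release          # run only release-marked tests
#   pytest integration-tests -m "not release"    # skip them in regular CI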
integration-tests/models/test_flash_falcon.py

@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
     return flash_falcon_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_falcon(flash_falcon, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
integration-tests/models/test_flash_gemma.py

@@ -3,7 +3,7 @@ import pytest


 @pytest.fixture(scope="module")
 def flash_gemma_handle(launcher):
-    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
+    with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle
@@ -13,7 +13,7 @@ async def flash_gemma(flash_gemma_handle):
     return flash_gemma_handle.client


-@pytest.mark.skip
+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma(flash_gemma, response_snapshot):
@@ -25,7 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
     assert response == response_snapshot


-@pytest.mark.skip
+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -49,7 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
     assert response == response_snapshot


-@pytest.mark.skip
+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
integration-tests/models/test_flash_gemma_gptq.py (new file, mode 100644)

import pytest


@pytest.fixture(scope="module")
def flash_gemma_gptq_handle(launcher):
    with launcher("TechxGenus/gemma-2b-GPTQ", num_shard=1, quantize="gptq") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma_gptq(flash_gemma_gptq_handle):
    await flash_gemma_gptq_handle.health(300)
    return flash_gemma_gptq_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
    response = await flash_gemma_gptq.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_all_params(
    flash_gemma_gptq, ignore_logprob_response_snapshot
):
    response = await flash_gemma_gptq.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_load(
    flash_gemma_gptq, generate_load, ignore_logprob_response_snapshot
):
    responses = await generate_load(
        flash_gemma_gptq, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == ignore_logprob_response_snapshot
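The generate_load helper used by the *_load tests above comes from the suite's shared conftest, which is not part of this diff. Purely as an illustration of what those tests exercise, a minimal sketch of such a helper could look like the following (the repository's actual fixture may differ):

# Hypothetical sketch of a generate_load-style helper; not part of this commit.
import asyncio

async def generate_load(client, prompt, max_new_tokens, n):
    # Fire n identical generate requests concurrently so the load tests can
    # check that concurrent/batched decoding stays deterministic.
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
    ]
    return await asyncio.gather(*futures)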
integration-tests/models/test_flash_gpt2.py (new file, mode 100644)

import pytest


@pytest.fixture(scope="module")
def flash_gpt2_handle(launcher):
    with launcher("openai-community/gpt2", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gpt2(flash_gpt2_handle):
    await flash_gpt2_handle.health(300)
    return flash_gpt2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
    response = await flash_gpt2.generate(
        "What is deep learning?",
        max_new_tokens=10,
        decoder_input_details=True,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
    responses = await generate_load(
        flash_gpt2,
        "What is deep learning?",
        max_new_tokens=10,
        n=4,
    )
    generated_texts = [r.generated_text for r in responses]

    assert len(generated_texts) == 4
    assert all(
        [text == generated_texts[0] for text in generated_texts]
    ), generated_texts

    assert responses == response_snapshot
integration-tests/models/test_flash_llama_exl2.py (new file, mode 100644)

import pytest


@pytest.fixture(scope="module")
def flash_llama_exl2_handle(launcher):
    with launcher(
        "turboderp/Llama-3-8B-Instruct-exl2",
        revision="2.5bpw",
        # Set max input length to avoid OOM due to extremely large
        # scratch buffer.
        max_input_length=1024,
        num_shard=1,
        quantize="exl2",
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_exl2(flash_llama_exl2_handle):
    await flash_llama_exl2_handle.health(300)
    return flash_llama_exl2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
    response = await flash_llama_exl2.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
    flash_llama_exl2, ignore_logprob_response_snapshot
):
    response = await flash_llama_exl2.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert (
        response.generated_text == 'Test request. The server responds with a "200 OK"'
    )
    assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(
    flash_llama_exl2, generate_load, ignore_logprob_response_snapshot
):
    responses = await generate_load(
        flash_llama_exl2, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == ignore_logprob_response_snapshot
integration-tests/models/test_flash_llama_gptq.py

@@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
     return flash_llama_gptq_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
@@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_gptq_load(
integration-tests/models/test_flash_llama_marlin.py (new file, mode 100644)

import pytest


@pytest.fixture(scope="module")
def flash_llama_marlin_handle(launcher):
    with launcher(
        "neuralmagic/llama-2-7b-chat-marlin", num_shard=2, quantize="marlin"
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_marlin(flash_llama_marlin_handle):
    await flash_llama_marlin_handle.health(300)
    return flash_llama_marlin_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
    response = await flash_llama_marlin.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
    response = await flash_llama_marlin.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(
    flash_llama_marlin, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_llama_marlin, "Test request", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
integration-tests/models/test_flash_neox.py

@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
     return flash_neox_handle.client


+@pytest.mark.release
 @pytest.mark.skip
 @pytest.mark.asyncio
 async def test_flash_neox(flash_neox, response_snapshot):
@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.skip
 @pytest.mark.asyncio
 async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
integration-tests/models/test_flash_neox_sharded.py

@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
     return flash_neox_sharded_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_neox(flash_neox_sharded, response_snapshot):
     response = await flash_neox_sharded.generate(
@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
     responses = await generate_load(
integration-tests/models/test_flash_pali_gemma.py (new file, mode 100644)

import pytest
import requests
import io
import base64


@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    with launcher(
        "google/paligemma-3b-pt-224",
        num_shard=1,
        revision="float16",
        max_input_length=4000,
        max_total_tokens=4096,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    await flash_pali_gemma_handle.health(300)
    return flash_pali_gemma_handle.client


def get_chicken():
    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


def get_cow_beach():
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    cow = get_cow_beach()
    inputs = f"![]({cow})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

    assert response.generated_text == "beach"
    assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
    chicken = get_chicken()
    cow_beach = get_cow_beach()
    response = await flash_pali_gemma.generate(
        f"caption![]({chicken})![]({cow_beach})\n",
        max_new_tokens=20,
    )
    # Is PaliGemma not able to handle two separate images? At least we
    # get output showing that both images are used.
    assert (
        response.generated_text == "image result for chicken on the beach"
    ), f"{repr(response.generated_text)}"
    assert response == response_snapshot
integration-tests/models/test_flash_phi.py

@@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle):
     return flash_phi_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_phi(flash_phi, response_snapshot):
     response = await flash_phi.generate(
@@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_phi_all_params(flash_phi, response_snapshot):
     response = await flash_phi.generate(
@@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
     responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
integration-tests/models/test_flash_qwen2.py

@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle):
     return flash_qwen2_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_qwen2(flash_qwen2, response_snapshot):
     response = await flash_qwen2.generate(
@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
     response = await flash_qwen2.generate(
@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
     responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)
integration-tests/models/test_flash_santacoder.py

@@ -13,6 +13,7 @@ async def flash_santacoder(flash_santacoder_handle):
     return flash_santacoder_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_santacoder(flash_santacoder, response_snapshot):
     response = await flash_santacoder.generate(
@@ -23,6 +24,7 @@ async def test_flash_santacoder(flash_santacoder, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_santacoder_load(
     flash_santacoder, generate_load, response_snapshot
integration-tests/models/test_flash_starcoder.py

@@ -13,6 +13,7 @@ async def flash_starcoder(flash_starcoder_handle):
     return flash_starcoder_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder(flash_starcoder, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder(flash_starcoder, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder_load(flash_starcoder, generate_load, response_snapshot):
integration-tests/models/test_flash_starcoder2.py

@@ -13,6 +13,7 @@ async def flash_starcoder2(flash_starcoder2_handle):
     return flash_starcoder2_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
@@ -24,6 +25,7 @@ async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
@@ -40,6 +42,7 @@ async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapsh
     assert response == response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_starcoder2_load(
integration-tests/models/test_flash_starcoder_gptq.py

@@ -13,6 +13,7 @@ async def flash_starcoder_gptq(flash_starcoder_gptq_handle):
     return flash_starcoder_gptq_handle.client


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snapshot):
     response = await flash_starcoder_gptq.generate(
@@ -24,6 +25,7 @@ async def test_flash_starcoder_gptq(flash_starcoder_gptq, generous_response_snap
     assert response == generous_response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_starcoder_gptq_default_params(
     flash_starcoder_gptq, generous_response_snapshot
@@ -40,6 +42,7 @@ async def test_flash_starcoder_gptq_default_params(
     assert response == generous_response_snapshot


+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_flash_starcoder_gptq_load(
     flash_starcoder_gptq, generate_load, generous_response_snapshot
integration-tests/models/test_grammar_llama.py

@@ -21,6 +21,7 @@ async def non_flash_llama_grammar(non_flash_llama_grammar_handle):
     return non_flash_llama_grammar_handle.client


+@pytest.mark.release
 @pytest.mark.skip
 @pytest.mark.asyncio
 async def test_non_flash_llama_grammar_json(non_flash_llama_grammar, response_snapshot):
integration-tests/models/test_grammar_response_format_llama.py (new file, mode 100644)

import pytest
import requests
from pydantic import BaseModel
from typing import List


@pytest.fixture(scope="module")
def llama_grammar_handle(launcher):
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        num_shard=1,
        disable_grammar_support=False,
        use_flash_attention=False,
        max_batch_prefill_tokens=3000,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def llama_grammar(llama_grammar_handle):
    await llama_grammar_handle.health(300)
    return llama_grammar_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):

    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    chat_completion = response.json()
    called = chat_completion["choices"][0]["message"]["content"]

    assert response.status_code == 200
    assert (
        called == '{\n"temperature": [\n35,\n34,\n36\n],\n"unit": "°c"\n}'
    )
    assert chat_completion == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_grammar_response_format_llama_error_if_tools_not_installed(
    llama_grammar,
):
    class Weather(BaseModel):
        unit: str
        temperature: List[int]

    # send the request
    response = requests.post(
        f"{llama_grammar.base_url}/v1/chat/completions",
        headers=llama_grammar.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
                },
                {
                    "role": "user",
                    "content": "What's the weather like the next 3 days in San Francisco, CA?",
                },
            ],
            "seed": 42,
            "max_tokens": 500,
            "tools": [],
            "response_format": {"type": "json_object", "value": Weather.schema()},
        },
    )

    # 422 means the server was unable to process the request because it contains invalid data.
    assert response.status_code == 422
    assert response.json() == {
        "error": "Grammar and tools are mutually exclusive",
        "error_type": "grammar and tools",
    }
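Not part of the diff, but as a usage note on the test above: the reason the test passes Weather.schema() as the response_format value is that the returned message content is then constrained to JSON matching that schema, so it can be loaded straight back into the Pydantic model. A minimal sketch, assuming the pydantic v1 API that Weather.schema() implies:

# Hypothetical follow-up to the assertions above: validate the constrained output.
weather = Weather.parse_raw(called)  # pydantic v1; use model_validate_json in v2
assert weather.unit == "°c"
assert len(weather.temperature) == 3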
integration-tests/models/test_idefics.py

@@ -23,6 +23,12 @@ def get_chicken():
     return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 @pytest.mark.asyncio
 async def test_idefics(idefics, response_snapshot):
     chicken = get_chicken()
@@ -39,6 +45,23 @@ async def test_idefics(idefics, response_snapshot):
     assert response == response_snapshot


+@pytest.mark.release
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_idefics_two_images(idefics, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await idefics.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text == " The cow and chicken are on a beach."
+    ), f"{repr(response.generated_text)}"
+    assert response == response_snapshot
+
+
+@pytest.mark.release
 @pytest.mark.asyncio
 async def test_idefics_load(idefics, generate_load, response_snapshot):
     chicken = get_chicken()
integration-tests/models/test_idefics2.py

@@ -9,6 +9,12 @@ def get_chicken():
     return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


+def get_cow_beach():
+    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
+        encoded_string = base64.b64encode(image_file.read())
+        return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
+
+
 @pytest.fixture(scope="module")
 def flash_idefics2_next_handle(launcher):
     with launcher(
@@ -38,6 +44,23 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
     assert response == response_snapshot


+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
+    chicken = get_chicken()
+    cow_beach = get_cow_beach()
+    response = await flash_idefics2_next.generate(
+        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
+        max_new_tokens=20,
+    )
+    assert (
+        response.generated_text
+        == " The cow is standing on the beach and the chicken is sitting on a pile of money."
+    ), f"{repr(response.generated_text)}"
+    assert response.details.generated_tokens == 20
+    assert response == response_snapshot
+
+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):