Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5ae5ed1e
Unverified
Commit
5ae5ed1e
authored
May 29, 2024
by
Cyrus Leung
Committed by
GitHub
May 28, 2024
Browse files
[Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
290f4ada
Changes
43
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
378 additions
and
59 deletions
+378
-59
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+6
-3
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+7
-4
docs/source/dev/offline_inference/llm.rst
docs/source/dev/offline_inference/llm.rst
+1
-1
docs/source/dev/offline_inference/llm_inputs.rst
docs/source/dev/offline_inference/llm_inputs.rst
+14
-0
docs/source/dev/offline_inference/offline_index.rst
docs/source/dev/offline_inference/offline_index.rst
+8
-0
docs/source/dev/sampling_params.rst
docs/source/dev/sampling_params.rst
+0
-0
docs/source/index.rst
docs/source/index.rst
+3
-8
docs/source/serving/openai_compatible_server.md
docs/source/serving/openai_compatible_server.md
+2
-2
examples/llava_example.py
examples/llava_example.py
+16
-9
pyproject.toml
pyproject.toml
+7
-0
tests/async_engine/test_async_llm_engine.py
tests/async_engine/test_async_llm_engine.py
+1
-1
tests/async_engine/test_openapi_server_ray.py
tests/async_engine/test_openapi_server_ray.py
+1
-1
tests/conftest.py
tests/conftest.py
+17
-6
tests/core/test_block_manager.py
tests/core/test_block_manager.py
+12
-3
tests/core/utils.py
tests/core/utils.py
+12
-3
tests/engine/test_skip_tokenizer_init.py
tests/engine/test_skip_tokenizer_init.py
+1
-1
tests/entrypoints/openai/test_serving_chat.py
tests/entrypoints/openai/test_serving_chat.py
+4
-0
tests/entrypoints/test_guided_processors.py
tests/entrypoints/test_guided_processors.py
+2
-0
tests/entrypoints/test_llm_encode.py
tests/entrypoints/test_llm_encode.py
+144
-0
tests/entrypoints/test_llm_generate.py
tests/entrypoints/test_llm_generate.py
+120
-17
No files found.
.buildkite/test-pipeline.yaml
View file @
5ae5ed1e
...
...
@@ -63,9 +63,9 @@ steps:
mirror_hardwares
:
[
amd
]
commands
:
# these tests have to be separated, because each one will allocate all posible GPU memor
y
-
pytest -v -s entrypoints -
-ignore=entrypoints/test_server_oot_registration.py
-
pytest -v -s entrypoints
/test_server_oot_registration.py
-
pytest -v -s test_inputs.p
y
-
pytest -v -s entrypoints -
m llm
-
pytest -v -s entrypoints
-m openai
-
label
:
Examples Test
working_dir
:
"
/vllm-workspace/examples"
...
...
@@ -110,6 +110,9 @@ steps:
mirror_hardwares
:
[
amd
]
command
:
pytest -v -s test_logits_processor.py
-
label
:
Utils Test
command
:
pytest -v -s test_utils.py
-
label
:
Worker Test
mirror_hardwares
:
[
amd
]
command
:
pytest -v -s worker
...
...
benchmarks/benchmark_latency.py
View file @
5ae5ed1e
...
...
@@ -3,13 +3,14 @@ import argparse
import
json
import
time
from
pathlib
import
Path
from
typing
import
Optional
from
typing
import
List
,
Optional
import
numpy
as
np
import
torch
from
tqdm
import
tqdm
from
vllm
import
LLM
,
SamplingParams
from
vllm.inputs
import
PromptStrictInputs
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
...
...
@@ -48,7 +49,9 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids
=
np
.
random
.
randint
(
10000
,
size
=
(
args
.
batch_size
,
args
.
input_len
))
dummy_prompt_token_ids
=
dummy_prompt_token_ids
.
tolist
()
dummy_inputs
:
List
[
PromptStrictInputs
]
=
[{
"prompt_token_ids"
:
batch
}
for
batch
in
dummy_prompt_token_ids
.
tolist
()]
def
run_to_completion
(
profile_dir
:
Optional
[
str
]
=
None
):
if
profile_dir
:
...
...
@@ -59,13 +62,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready
=
torch
.
profiler
.
tensorboard_trace_handler
(
str
(
profile_dir
)))
as
p
:
llm
.
generate
(
prompt_token_ids
=
dummy_prompt_token_id
s
,
llm
.
generate
(
dummy_input
s
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
print
(
p
.
key_averages
())
else
:
start_time
=
time
.
perf_counter
()
llm
.
generate
(
prompt_token_ids
=
dummy_prompt_token_id
s
,
llm
.
generate
(
dummy_input
s
,
sampling_params
=
sampling_params
,
use_tqdm
=
False
)
end_time
=
time
.
perf_counter
()
...
...
docs/source/offline_inference/llm.rst
→
docs/source/
dev/
offline_inference/llm.rst
View file @
5ae5ed1e
LLM Class
=========
=
=========
.. autoclass:: vllm.LLM
:members:
...
...
docs/source/dev/offline_inference/llm_inputs.rst
0 → 100644
View file @
5ae5ed1e
LLM Inputs
==========
.. autodata:: vllm.inputs.PromptStrictInputs
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
:members:
:member-order: bysource
.. autoclass:: vllm.inputs.TokensPrompt
:show-inheritance:
:members:
:member-order: bysource
docs/source/dev/offline_inference/offline_index.rst
0 → 100644
View file @
5ae5ed1e
Offline Inference
=================================
.. toctree::
:maxdepth: 1
llm
llm_inputs
docs/source/
offline_inference
/sampling_params.rst
→
docs/source/
dev
/sampling_params.rst
View file @
5ae5ed1e
File moved
docs/source/index.rst
View file @
5ae5ed1e
...
...
@@ -68,13 +68,6 @@ Documentation
getting_started/quickstart
getting_started/examples/examples_index
.. toctree::
:maxdepth: 1
:caption: Offline Inference
offline_inference/llm
offline_inference/sampling_params
.. toctree::
:maxdepth: 1
:caption: Serving
...
...
@@ -108,7 +101,9 @@ Documentation
.. toctree::
:maxdepth: 2
:caption: Developer Documentation
dev/sampling_params
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/dockerfile/dockerfile
...
...
docs/source/serving/openai_compatible_server.md
View file @
5ae5ed1e
...
...
@@ -48,7 +48,7 @@ completion = client.chat.completions.create(
```
### Extra Parameters for Chat API
The following
[
sampling parameters (click through to see documentation)
](
../
offline_inference
/sampling_params.rst
)
are supported.
The following
[
sampling parameters (click through to see documentation)
](
../
dev
/sampling_params.rst
)
are supported.
```
{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
...
...
@@ -65,7 +65,7 @@ The following extra parameters are supported:
```
### Extra Parameters for Completions API
The following
[
sampling parameters (click through to see documentation)
](
../
offline_inference
/sampling_params.rst
)
are supported.
The following
[
sampling parameters (click through to see documentation)
](
../
dev
/sampling_params.rst
)
are supported.
```
{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
...
...
examples/llava_example.py
View file @
5ae5ed1e
...
...
@@ -23,11 +23,15 @@ def run_llava_pixel_values():
"
\n
USER: What is the content of this image?
\n
ASSISTANT:"
)
# This should be provided by another online or offline component.
images
=
torch
.
load
(
"images/stop_sign_pixel_values.pt"
)
image
=
torch
.
load
(
"images/stop_sign_pixel_values.pt"
)
outputs
=
llm
.
generate
({
"prompt"
:
prompt
,
"multi_modal_data"
:
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
image
),
})
outputs
=
llm
.
generate
(
prompt
,
multi_modal_data
=
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
images
))
for
o
in
outputs
:
generated_text
=
o
.
outputs
[
0
].
text
print
(
generated_text
)
...
...
@@ -46,11 +50,14 @@ def run_llava_image_features():
"
\n
USER: What is the content of this image?
\n
ASSISTANT:"
)
# This should be provided by another online or offline component.
images
=
torch
.
load
(
"images/stop_sign_image_features.pt"
)
outputs
=
llm
.
generate
(
prompt
,
multi_modal_data
=
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
images
))
image
=
torch
.
load
(
"images/stop_sign_image_features.pt"
)
outputs
=
llm
.
generate
({
"prompt"
:
prompt
,
"multi_modal_data"
:
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
image
),
})
for
o
in
outputs
:
generated_text
=
o
.
outputs
[
0
].
text
print
(
generated_text
)
...
...
pyproject.toml
View file @
5ae5ed1e
...
...
@@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort]
use_parentheses
=
true
skip_gitignore
=
true
[tool.pytest.ini_options]
markers
=
[
"skip_global_cleanup"
,
"llm: run tests for vLLM API only"
,
"openai: run tests for OpenAI API only"
,
]
tests/async_engine/test_async_llm_engine.py
View file @
5ae5ed1e
...
...
@@ -25,7 +25,7 @@ class MockEngine:
return
[
RequestOutput
(
request_id
=
self
.
request_id
)]
if
self
.
request_id
else
[]
async
def
encode_request
_async
(
self
,
*
args
,
**
kwargs
):
async
def
process_model_inputs
_async
(
self
,
*
args
,
**
kwargs
):
pass
def
generate
(
self
,
request_id
):
...
...
tests/async_engine/test_openapi_server_ray.py
View file @
5ae5ed1e
...
...
@@ -29,7 +29,7 @@ def server():
ray
.
shutdown
()
@
pytest
.
fixture
(
scope
=
"
session
"
)
@
pytest
.
fixture
(
scope
=
"
module
"
)
def
client
():
client
=
openai
.
AsyncOpenAI
(
base_url
=
"http://localhost:8000/v1"
,
...
...
tests/conftest.py
View file @
5ae5ed1e
...
...
@@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
TokenizerPoolConfig
,
VisionLanguageConfig
from
vllm.distributed
import
destroy_model_parallel
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.sequence
import
MultiModalData
...
...
@@ -402,12 +403,22 @@ class VllmRunner:
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
if
images
is
not
None
:
assert
len
(
prompts
)
==
images
.
shape
[
0
]
req_outputs
=
self
.
model
.
generate
(
prompts
,
sampling_params
=
sampling_params
,
multi_modal_data
=
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
images
)
if
images
is
not
None
else
None
)
prompt_inputs
:
List
[
PromptInputs
]
=
[]
for
i
,
prompt
in
enumerate
(
prompts
):
image
=
None
if
images
is
None
else
images
[
i
:
i
+
1
]
mm_data
=
None
if
image
is
None
else
MultiModalData
(
type
=
MultiModalData
.
Type
.
IMAGE
,
data
=
image
,
)
prompt_inputs
.
append
({
"prompt"
:
prompt
,
"multi_modal_data"
:
mm_data
,
})
req_outputs
=
self
.
model
.
generate
(
prompt_inputs
,
sampling_params
=
sampling_params
)
outputs
=
[]
for
req_output
in
req_outputs
:
prompt_str
=
req_output
.
prompt
...
...
tests/core/test_block_manager.py
View file @
5ae5ed1e
...
...
@@ -133,8 +133,11 @@ def test_append_slot_cow():
# Allocate prompt to gpu block. There is one slot left in the block.
prompt
=
Sequence
(
seq_id
=
1
,
prompt
=
"one two three"
,
prompt_token_ids
=
[
1
,
2
,
3
],
inputs
=
{
"prompt"
:
"one two three"
,
"prompt_token_ids"
:
[
1
,
2
,
3
],
"multi_modal_data"
:
None
},
block_size
=
block_size
)
# Fork the sequence, such that a COW will be required when we append a new
...
...
@@ -304,7 +307,13 @@ def test_sliding_window_multi_seq():
assert
block_manager
.
get_num_free_gpu_blocks
()
==
num_gpu_blocks
parent
=
Sequence
(
1
,
"one two three"
,
[
0
,
1
,
2
],
block_size
)
parent
=
Sequence
(
seq_id
=
1
,
inputs
=
{
"prompt"
:
"one two three"
,
"prompt_token_ids"
:
[
0
,
1
,
2
],
"multi_modal_data"
:
None
},
block_size
=
block_size
)
seq_group
=
SequenceGroup
(
request_id
=
"1"
,
seqs
=
[
parent
],
arrival_time
=
time
.
time
(),
...
...
tests/core/utils.py
View file @
5ae5ed1e
...
...
@@ -21,7 +21,13 @@ def create_dummy_prompt(
# and prompt "0 ... block_size".
prompt_tokens
=
list
(
range
(
prompt_length
))
prompt_str
=
" "
.
join
([
str
(
t
)
for
t
in
prompt_tokens
])
prompt
=
Sequence
(
int
(
request_id
),
prompt_str
,
prompt_tokens
,
block_size
)
prompt
=
Sequence
(
int
(
request_id
),
inputs
=
{
"prompt"
:
prompt_str
,
"prompt_token_ids"
:
prompt_tokens
,
"multi_modal_data"
:
None
,
},
block_size
=
block_size
)
seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
[
prompt
],
arrival_time
=
time
.
time
(),
...
...
@@ -51,8 +57,11 @@ def create_seq_group(
for
seq_id_offset
,
output_len
in
enumerate
(
seq_output_lens
):
seq
=
Sequence
(
seq_id
=
seq_id_start
+
seq_id_offset
,
prompt
=
""
,
prompt_token_ids
=
prompt_token_ids
,
inputs
=
{
"prompt"
:
""
,
"prompt_token_ids"
:
prompt_token_ids
,
"multi_modal_data"
:
None
,
},
block_size
=
16
,
)
...
...
tests/engine/test_skip_tokenizer_init.py
View file @
5ae5ed1e
...
...
@@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str):
with
pytest
.
raises
(
ValueError
)
as
err
:
llm
.
generate
(
"abc"
,
sampling_params
)
assert
"prompts must be None if"
in
str
(
err
.
value
)
outputs
=
llm
.
generate
(
prompt_token_ids
=
[
[
1
,
2
,
3
]
]
,
outputs
=
llm
.
generate
(
{
"
prompt_token_ids
"
:
[
1
,
2
,
3
]
}
,
sampling_params
=
sampling_params
)
assert
len
(
outputs
)
>
0
completions
=
outputs
[
0
].
outputs
...
...
tests/entrypoints/openai/test_serving_chat.py
View file @
5ae5ed1e
import
asyncio
from
dataclasses
import
dataclass
import
pytest
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
MODEL_NAME
=
"openai-community/gpt2"
CHAT_TEMPLATE
=
"Dummy chat template for testing {}"
pytestmark
=
pytest
.
mark
.
openai
@
dataclass
class
MockModelConfig
:
...
...
tests/entrypoints/test_guided_processors.py
View file @
5ae5ed1e
...
...
@@ -52,6 +52,8 @@ TEST_SCHEMA = {
TEST_REGEX
=
(
r
"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r
"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
)
pytestmark
=
pytest
.
mark
.
openai
def
test_guided_logits_processors
():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
...
...
tests/entrypoints/test_llm_encode.py
0 → 100644
View file @
5ae5ed1e
import
weakref
from
typing
import
List
import
pytest
from
vllm
import
LLM
,
EmbeddingRequestOutput
,
PoolingParams
from
..conftest
import
cleanup
MODEL_NAME
=
"intfloat/e5-mistral-7b-instruct"
PROMPTS
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
TOKEN_IDS
=
[
# Using ID={0, 1, 2, 3} results in NaN values,
# so we add this offset of 1000
[
1000
],
[
1000
,
1001
],
[
1000
,
1002
,
1001
],
[
1000
,
1003
,
1001
,
1002
],
]
pytestmark
=
pytest
.
mark
.
llm
@
pytest
.
fixture
(
scope
=
"module"
)
def
llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
MODEL_NAME
,
max_num_batched_tokens
=
32768
,
tensor_parallel_size
=
1
,
gpu_memory_utilization
=
0.75
,
enforce_eager
=
True
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup
()
def
assert_outputs_equal
(
o1
:
List
[
EmbeddingRequestOutput
],
o2
:
List
[
EmbeddingRequestOutput
]):
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt'
,
PROMPTS
)
def
test_v1_v2_api_consistency_single_prompt_string
(
llm
:
LLM
,
prompt
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
encode
(
prompts
=
prompt
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
prompt
,
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
encode
({
"prompt"
:
prompt
},
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
prompt_token_ids
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompt_token_ids'"
):
v1_output
=
llm
.
encode
(
prompt_token_ids
=
prompt_token_ids
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
({
"prompt_token_ids"
:
prompt_token_ids
},
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_string
(
llm
:
LLM
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
encode
(
prompts
=
PROMPTS
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
PROMPTS
,
pooling_params
=
pooling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
encode
(
[{
"prompt"
:
p
}
for
p
in
PROMPTS
],
pooling_params
=
pooling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
pooling_params
=
PoolingParams
()
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompt_token_ids'"
):
v1_output
=
llm
.
encode
(
prompt_token_ids
=
TOKEN_IDS
,
pooling_params
=
pooling_params
)
v2_output
=
llm
.
encode
(
[{
"prompt_token_ids"
:
p
}
for
p
in
TOKEN_IDS
],
pooling_params
=
pooling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_multiple_pooling_params
(
llm
:
LLM
):
pooling_params
=
[
PoolingParams
(),
PoolingParams
(),
PoolingParams
(),
PoolingParams
(),
]
# Multiple PoolingParams should be matched with each prompt
outputs
=
llm
.
encode
(
PROMPTS
,
pooling_params
=
pooling_params
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
# Exception raised, if the size of params does not match the size of prompts
with
pytest
.
raises
(
ValueError
):
outputs
=
llm
.
encode
(
PROMPTS
,
pooling_params
=
pooling_params
[:
3
])
# Single PoolingParams should be applied to every prompt
single_pooling_params
=
PoolingParams
()
outputs
=
llm
.
encode
(
PROMPTS
,
pooling_params
=
single_pooling_params
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
# pooling_params is None, default params should be applied
outputs
=
llm
.
encode
(
PROMPTS
,
pooling_params
=
None
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
tests/entrypoints/test_llm_generate.py
View file @
5ae5ed1e
import
weakref
from
typing
import
List
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
RequestOutput
,
SamplingParams
from
..conftest
import
cleanup
MODEL_NAME
=
"facebook/opt-125m"
PROMPTS
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
TOKEN_IDS
=
[
[
0
],
[
0
,
1
],
[
0
,
2
,
1
],
[
0
,
3
,
1
,
2
],
]
def
test_multiple_sampling_params
():
pytestmark
=
pytest
.
mark
.
llm
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
@
pytest
.
fixture
(
scope
=
"module"
)
def
llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
MODEL_NAME
,
max_num_batched_tokens
=
4096
,
tensor_parallel_size
=
1
)
tensor_parallel_size
=
1
,
gpu_memory_utilization
=
0.10
,
enforce_eager
=
True
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup
()
def
assert_outputs_equal
(
o1
:
List
[
RequestOutput
],
o2
:
List
[
RequestOutput
]):
assert
[
o
.
outputs
for
o
in
o1
]
==
[
o
.
outputs
for
o
in
o2
]
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt'
,
PROMPTS
)
def
test_v1_v2_api_consistency_single_prompt_string
(
llm
:
LLM
,
prompt
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
generate
(
prompts
=
prompt
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
prompt
,
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
generate
({
"prompt"
:
prompt
},
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
'prompt_token_ids'
,
TOKEN_IDS
)
def
test_v1_v2_api_consistency_single_prompt_tokens
(
llm
:
LLM
,
prompt_token_ids
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompt_token_ids'"
):
v1_output
=
llm
.
generate
(
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
({
"prompt_token_ids"
:
prompt_token_ids
},
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_string
(
llm
:
LLM
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompts'"
):
v1_output
=
llm
.
generate
(
prompts
=
PROMPTS
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
sampling_params
)
assert_outputs_equal
(
v1_output
,
v2_output
)
v2_output
=
llm
.
generate
(
[{
"prompt"
:
p
}
for
p
in
PROMPTS
],
sampling_params
=
sampling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_v1_v2_api_consistency_multi_prompt_tokens
(
llm
:
LLM
):
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
top_p
=
1.0
)
with
pytest
.
warns
(
DeprecationWarning
,
match
=
"'prompt_token_ids'"
):
v1_output
=
llm
.
generate
(
prompt_token_ids
=
TOKEN_IDS
,
sampling_params
=
sampling_params
)
v2_output
=
llm
.
generate
(
[{
"prompt_token_ids"
:
p
}
for
p
in
TOKEN_IDS
],
sampling_params
=
sampling_params
,
)
assert_outputs_equal
(
v1_output
,
v2_output
)
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
@
pytest
.
mark
.
skip_global_cleanup
def
test_multiple_sampling_params
(
llm
:
LLM
):
sampling_params
=
[
SamplingParams
(
temperature
=
0.01
,
top_p
=
0.95
),
SamplingParams
(
temperature
=
0.3
,
top_p
=
0.95
),
...
...
@@ -24,18 +127,18 @@ def test_multiple_sampling_params():
]
# Multiple SamplingParams should be matched with each prompt
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
)
assert
len
(
prompts
)
==
len
(
outputs
)
outputs
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
sampling_params
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
# Exception raised, if the size of params does not match the size of prompts
with
pytest
.
raises
(
ValueError
):
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
sampling_params
[:
3
])
outputs
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
sampling_params
[:
3
])
# Single SamplingParams should be applied to every prompt
single_sampling_params
=
SamplingParams
(
temperature
=
0.3
,
top_p
=
0.95
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
single_sampling_params
)
assert
len
(
prompts
)
==
len
(
outputs
)
outputs
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
single_sampling_params
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
# sampling_params is None, default params should be applied
outputs
=
llm
.
generate
(
prompts
,
sampling_params
=
None
)
assert
len
(
prompts
)
==
len
(
outputs
)
\ No newline at end of file
outputs
=
llm
.
generate
(
PROMPTS
,
sampling_params
=
None
)
assert
len
(
PROMPTS
)
==
len
(
outputs
)
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment