Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf069aa8
Unverified
Commit
cf069aa8
authored
Mar 03, 2025
by
Harry Mellor
Committed by
GitHub
Mar 02, 2025
Browse files
Update deprecated Python 3.8 typing (#13971)
parent
bf33700e
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
134 additions
and
140 deletions
+134
-140
tests/tool_use/test_chat_completions.py
tests/tool_use/test_chat_completions.py
+2
-4
tests/tool_use/test_jamba_tool_parser.py
tests/tool_use/test_jamba_tool_parser.py
+7
-6
tests/tool_use/test_parallel_tool_calls.py
tests/tool_use/test_parallel_tool_calls.py
+5
-5
tests/tool_use/test_tool_calls.py
tests/tool_use/test_tool_calls.py
+5
-5
tests/tool_use/utils.py
tests/tool_use/utils.py
+13
-13
tests/tracing/test_tracing.py
tests/tracing/test_tracing.py
+3
-2
tests/utils.py
tests/utils.py
+21
-21
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+1
-2
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+3
-3
tests/v1/engine/conftest.py
tests/v1/engine/conftest.py
+2
-4
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+5
-5
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+1
-2
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+5
-5
tests/v1/engine/test_llm_engine.py
tests/v1/engine/test_llm_engine.py
+4
-4
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+6
-6
tests/v1/engine/utils.py
tests/v1/engine/utils.py
+25
-25
tests/v1/entrypoints/openai/test_completion.py
tests/v1/entrypoints/openai/test_completion.py
+6
-6
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+4
-5
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+3
-4
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+13
-13
No files found.
tests/tool_use/test_chat_completions.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
openai
import
openai
import
pytest
import
pytest
...
@@ -45,7 +43,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
...
@@ -45,7 +43,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
logprobs
=
False
,
logprobs
=
False
,
stream
=
True
,
stream
=
True
,
)
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
finish_reason_count
=
0
role_sent
:
bool
=
False
role_sent
:
bool
=
False
...
@@ -116,7 +114,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
...
@@ -116,7 +114,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
stream
=
True
,
stream
=
True
,
)
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
finish_reason_count
=
0
role_sent
:
bool
=
False
role_sent
:
bool
=
False
...
...
tests/tool_use/test_jamba_tool_parser.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
from
typing
import
Generator
,
List
,
Optional
from
collections.abc
import
Generator
from
typing
import
Optional
import
partial_json_parser
import
partial_json_parser
import
pytest
import
pytest
...
@@ -26,8 +27,8 @@ def jamba_tool_parser(jamba_tokenizer):
...
@@ -26,8 +27,8 @@ def jamba_tool_parser(jamba_tokenizer):
return
JambaToolParser
(
jamba_tokenizer
)
return
JambaToolParser
(
jamba_tokenizer
)
def
assert_tool_calls
(
actual_tool_calls
:
L
ist
[
ToolCall
],
def
assert_tool_calls
(
actual_tool_calls
:
l
ist
[
ToolCall
],
expected_tool_calls
:
L
ist
[
ToolCall
]):
expected_tool_calls
:
l
ist
[
ToolCall
]):
assert
len
(
actual_tool_calls
)
==
len
(
expected_tool_calls
)
assert
len
(
actual_tool_calls
)
==
len
(
expected_tool_calls
)
for
actual_tool_call
,
expected_tool_call
in
zip
(
actual_tool_calls
,
for
actual_tool_call
,
expected_tool_call
in
zip
(
actual_tool_calls
,
...
@@ -218,10 +219,10 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
...
@@ -218,10 +219,10 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
model_output
,
expected_tool_calls
,
model_output
,
expected_tool_calls
,
expected_content
):
expected_content
):
other_content
:
str
=
''
other_content
:
str
=
''
function_names
:
L
ist
[
str
]
=
[]
function_names
:
l
ist
[
str
]
=
[]
function_args_strs
:
L
ist
[
str
]
=
[]
function_args_strs
:
l
ist
[
str
]
=
[]
tool_call_idx
:
int
=
-
1
tool_call_idx
:
int
=
-
1
tool_call_ids
:
L
ist
[
Optional
[
str
]]
=
[]
tool_call_ids
:
l
ist
[
Optional
[
str
]]
=
[]
for
delta_message
in
stream_delta_message_generator
(
for
delta_message
in
stream_delta_message_generator
(
jamba_tool_parser
,
jamba_tokenizer
,
model_output
):
jamba_tool_parser
,
jamba_tokenizer
,
model_output
):
...
...
tests/tool_use/test_parallel_tool_calls.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
openai
import
openai
import
pytest
import
pytest
...
@@ -54,7 +54,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
...
@@ -54,7 +54,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
assert
isinstance
(
tool_call
.
function
.
arguments
,
str
)
assert
isinstance
(
tool_call
.
function
.
arguments
,
str
)
parsed_arguments
=
json
.
loads
(
tool_call
.
function
.
arguments
)
parsed_arguments
=
json
.
loads
(
tool_call
.
function
.
arguments
)
assert
isinstance
(
parsed_arguments
,
D
ict
)
assert
isinstance
(
parsed_arguments
,
d
ict
)
assert
isinstance
(
parsed_arguments
.
get
(
"city"
),
str
)
assert
isinstance
(
parsed_arguments
.
get
(
"city"
),
str
)
assert
isinstance
(
parsed_arguments
.
get
(
"state"
),
str
)
assert
isinstance
(
parsed_arguments
.
get
(
"state"
),
str
)
...
@@ -73,8 +73,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
...
@@ -73,8 +73,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
role_name
:
Optional
[
str
]
=
None
role_name
:
Optional
[
str
]
=
None
finish_reason_count
:
int
=
0
finish_reason_count
:
int
=
0
tool_call_names
:
L
ist
[
str
]
=
[]
tool_call_names
:
l
ist
[
str
]
=
[]
tool_call_args
:
L
ist
[
str
]
=
[]
tool_call_args
:
l
ist
[
str
]
=
[]
tool_call_idx
:
int
=
-
1
tool_call_idx
:
int
=
-
1
tool_call_id_count
:
int
=
0
tool_call_id_count
:
int
=
0
...
@@ -180,7 +180,7 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
...
@@ -180,7 +180,7 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
logprobs
=
False
,
logprobs
=
False
,
stream
=
True
)
stream
=
True
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
finish_reason_count
=
0
role_sent
:
bool
=
False
role_sent
:
bool
=
False
...
...
tests/tool_use/test_tool_calls.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
openai
import
openai
import
pytest
import
pytest
...
@@ -44,7 +44,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
...
@@ -44,7 +44,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# make sure the arguments parse properly
# make sure the arguments parse properly
parsed_arguments
=
json
.
loads
(
tool_calls
[
0
].
function
.
arguments
)
parsed_arguments
=
json
.
loads
(
tool_calls
[
0
].
function
.
arguments
)
assert
isinstance
(
parsed_arguments
,
D
ict
)
assert
isinstance
(
parsed_arguments
,
d
ict
)
assert
isinstance
(
parsed_arguments
.
get
(
"city"
),
str
)
assert
isinstance
(
parsed_arguments
.
get
(
"city"
),
str
)
assert
isinstance
(
parsed_arguments
.
get
(
"state"
),
str
)
assert
isinstance
(
parsed_arguments
.
get
(
"state"
),
str
)
assert
parsed_arguments
.
get
(
"city"
)
==
"Dallas"
assert
parsed_arguments
.
get
(
"city"
)
==
"Dallas"
...
@@ -117,7 +117,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
...
@@ -117,7 +117,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# validate arguments
# validate arguments
streamed_args
=
json
.
loads
(
function_args_str
)
streamed_args
=
json
.
loads
(
function_args_str
)
assert
isinstance
(
streamed_args
,
D
ict
)
assert
isinstance
(
streamed_args
,
d
ict
)
assert
isinstance
(
streamed_args
.
get
(
"city"
),
str
)
assert
isinstance
(
streamed_args
.
get
(
"city"
),
str
)
assert
isinstance
(
streamed_args
.
get
(
"state"
),
str
)
assert
isinstance
(
streamed_args
.
get
(
"state"
),
str
)
assert
streamed_args
.
get
(
"city"
)
==
"Dallas"
assert
streamed_args
.
get
(
"city"
)
==
"Dallas"
...
@@ -128,7 +128,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
...
@@ -128,7 +128,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
assert
choice
.
message
.
role
==
role_name
assert
choice
.
message
.
role
==
role_name
assert
choice
.
message
.
tool_calls
[
0
].
function
.
name
==
function_name
assert
choice
.
message
.
tool_calls
[
0
].
function
.
name
==
function_name
# compare streamed with non-streamed args
D
ict-wise, not string-wise
# compare streamed with non-streamed args
d
ict-wise, not string-wise
# because character-to-character comparison might not work e.g. the tool
# because character-to-character comparison might not work e.g. the tool
# call parser adding extra spaces or something like that. we care about the
# call parser adding extra spaces or something like that. we care about the
# dicts matching not byte-wise match
# dicts matching not byte-wise match
...
@@ -167,7 +167,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
...
@@ -167,7 +167,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
logprobs
=
False
,
logprobs
=
False
,
stream
=
True
)
stream
=
True
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
finish_reason_count
=
0
role_sent
:
bool
=
False
role_sent
:
bool
=
False
...
...
tests/tool_use/utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
copy
import
deepcopy
from
copy
import
deepcopy
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Optional
from
openai.types.chat
import
(
ChatCompletionMessageParam
,
from
openai.types.chat
import
(
ChatCompletionMessageParam
,
ChatCompletionToolParam
)
ChatCompletionToolParam
)
...
@@ -12,14 +12,14 @@ from tests.utils import VLLM_PATH
...
@@ -12,14 +12,14 @@ from tests.utils import VLLM_PATH
class
ServerConfig
(
TypedDict
,
total
=
False
):
class
ServerConfig
(
TypedDict
,
total
=
False
):
model
:
str
model
:
str
arguments
:
L
ist
[
str
]
arguments
:
l
ist
[
str
]
system_prompt
:
Optional
[
str
]
system_prompt
:
Optional
[
str
]
supports_parallel
:
Optional
[
bool
]
supports_parallel
:
Optional
[
bool
]
supports_rocm
:
Optional
[
bool
]
supports_rocm
:
Optional
[
bool
]
def
patch_system_prompt
(
messages
:
L
ist
[
D
ict
[
str
,
Any
]],
def
patch_system_prompt
(
messages
:
l
ist
[
d
ict
[
str
,
Any
]],
system_prompt
:
str
)
->
L
ist
[
D
ict
[
str
,
Any
]]:
system_prompt
:
str
)
->
l
ist
[
d
ict
[
str
,
Any
]]:
new_messages
=
deepcopy
(
messages
)
new_messages
=
deepcopy
(
messages
)
if
new_messages
[
0
][
"role"
]
==
"system"
:
if
new_messages
[
0
][
"role"
]
==
"system"
:
new_messages
[
0
][
"content"
]
=
system_prompt
new_messages
[
0
][
"content"
]
=
system_prompt
...
@@ -28,8 +28,8 @@ def patch_system_prompt(messages: List[Dict[str, Any]],
...
@@ -28,8 +28,8 @@ def patch_system_prompt(messages: List[Dict[str, Any]],
return
new_messages
return
new_messages
def
ensure_system_prompt
(
messages
:
L
ist
[
D
ict
[
str
,
Any
]],
def
ensure_system_prompt
(
messages
:
l
ist
[
d
ict
[
str
,
Any
]],
config
:
ServerConfig
)
->
L
ist
[
D
ict
[
str
,
Any
]]:
config
:
ServerConfig
)
->
l
ist
[
d
ict
[
str
,
Any
]]:
prompt
=
config
.
get
(
"system_prompt"
)
prompt
=
config
.
get
(
"system_prompt"
)
if
prompt
:
if
prompt
:
return
patch_system_prompt
(
messages
,
prompt
)
return
patch_system_prompt
(
messages
,
prompt
)
...
@@ -39,9 +39,9 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
...
@@ -39,9 +39,9 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
# universal args for all models go here. also good if you need to test locally
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
# and change type or KV cache quantization or something.
ARGS
:
L
ist
[
str
]
=
[
"--enable-auto-tool-choice"
,
"--max-model-len"
,
"1024"
]
ARGS
:
l
ist
[
str
]
=
[
"--enable-auto-tool-choice"
,
"--max-model-len"
,
"1024"
]
CONFIGS
:
D
ict
[
str
,
ServerConfig
]
=
{
CONFIGS
:
d
ict
[
str
,
ServerConfig
]
=
{
"hermes"
:
{
"hermes"
:
{
"model"
:
"model"
:
"NousResearch/Hermes-3-Llama-3.1-8B"
,
"NousResearch/Hermes-3-Llama-3.1-8B"
,
...
@@ -205,7 +205,7 @@ SEARCH_TOOL: ChatCompletionToolParam = {
...
@@ -205,7 +205,7 @@ SEARCH_TOOL: ChatCompletionToolParam = {
}
}
}
}
MESSAGES_WITHOUT_TOOLS
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_WITHOUT_TOOLS
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
...
@@ -222,14 +222,14 @@ MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
...
@@ -222,14 +222,14 @@ MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
"Can you tell me a joke please?"
"Can you tell me a joke please?"
}]
}]
MESSAGES_ASKING_FOR_TOOLS
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_ASKING_FOR_TOOLS
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
"What is the weather in Dallas, Texas in Fahrenheit?"
"What is the weather in Dallas, Texas in Fahrenheit?"
}]
}]
MESSAGES_WITH_TOOL_RESPONSE
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_WITH_TOOL_RESPONSE
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
...
@@ -258,7 +258,7 @@ MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
...
@@ -258,7 +258,7 @@ MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
"cloudy skies and a low chance of rain."
"cloudy skies and a low chance of rain."
}]
}]
MESSAGES_ASKING_FOR_PARALLEL_TOOLS
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_ASKING_FOR_PARALLEL_TOOLS
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
...
@@ -266,7 +266,7 @@ MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
...
@@ -266,7 +266,7 @@ MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
"Fahrenheit?"
"Fahrenheit?"
}]
}]
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
...
...
tests/tracing/test_tracing.py
View file @
cf069aa8
...
@@ -2,8 +2,9 @@
...
@@ -2,8 +2,9 @@
import
os
import
os
import
threading
import
threading
from
collections.abc
import
Iterable
from
concurrent
import
futures
from
concurrent
import
futures
from
typing
import
Callable
,
Dict
,
Iterable
,
Literal
from
typing
import
Callable
,
Literal
import
grpc
import
grpc
import
pytest
import
pytest
...
@@ -25,7 +26,7 @@ FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
...
@@ -25,7 +26,7 @@ FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
def
decode_value
(
value
:
AnyValue
):
def
decode_value
(
value
:
AnyValue
):
field_decoders
:
D
ict
[
FieldName
,
Callable
]
=
{
field_decoders
:
d
ict
[
FieldName
,
Callable
]
=
{
"bool_value"
:
(
lambda
v
:
v
.
bool_value
),
"bool_value"
:
(
lambda
v
:
v
.
bool_value
),
"string_value"
:
(
lambda
v
:
v
.
string_value
),
"string_value"
:
(
lambda
v
:
v
.
string_value
),
"int_value"
:
(
lambda
v
:
v
.
int_value
),
"int_value"
:
(
lambda
v
:
v
.
int_value
),
...
...
tests/utils.py
View file @
cf069aa8
...
@@ -11,7 +11,7 @@ import time
...
@@ -11,7 +11,7 @@ import time
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
openai
import
openai
import
pytest
import
pytest
...
@@ -73,9 +73,9 @@ class RemoteOpenAIServer:
...
@@ -73,9 +73,9 @@ class RemoteOpenAIServer:
def
__init__
(
self
,
def
__init__
(
self
,
model
:
str
,
model
:
str
,
vllm_serve_args
:
L
ist
[
str
],
vllm_serve_args
:
l
ist
[
str
],
*
,
*
,
env_dict
:
Optional
[
D
ict
[
str
,
str
]]
=
None
,
env_dict
:
Optional
[
d
ict
[
str
,
str
]]
=
None
,
auto_port
:
bool
=
True
,
auto_port
:
bool
=
True
,
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
if
auto_port
:
if
auto_port
:
...
@@ -183,7 +183,7 @@ def _test_completion(
...
@@ -183,7 +183,7 @@ def _test_completion(
client
:
openai
.
OpenAI
,
client
:
openai
.
OpenAI
,
model
:
str
,
model
:
str
,
prompt
:
str
,
prompt
:
str
,
token_ids
:
L
ist
[
int
],
token_ids
:
l
ist
[
int
],
):
):
results
=
[]
results
=
[]
...
@@ -400,10 +400,10 @@ def _test_image_text(
...
@@ -400,10 +400,10 @@ def _test_image_text(
def
compare_two_settings
(
model
:
str
,
def
compare_two_settings
(
model
:
str
,
arg1
:
L
ist
[
str
],
arg1
:
l
ist
[
str
],
arg2
:
L
ist
[
str
],
arg2
:
l
ist
[
str
],
env1
:
Optional
[
D
ict
[
str
,
str
]]
=
None
,
env1
:
Optional
[
d
ict
[
str
,
str
]]
=
None
,
env2
:
Optional
[
D
ict
[
str
,
str
]]
=
None
,
env2
:
Optional
[
d
ict
[
str
,
str
]]
=
None
,
*
,
*
,
method
:
str
=
"generate"
,
method
:
str
=
"generate"
,
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
...
@@ -429,8 +429,8 @@ def compare_two_settings(model: str,
...
@@ -429,8 +429,8 @@ def compare_two_settings(model: str,
def
compare_all_settings
(
model
:
str
,
def
compare_all_settings
(
model
:
str
,
all_args
:
L
ist
[
L
ist
[
str
]],
all_args
:
l
ist
[
l
ist
[
str
]],
all_envs
:
L
ist
[
Optional
[
D
ict
[
str
,
str
]]],
all_envs
:
l
ist
[
Optional
[
d
ict
[
str
,
str
]]],
*
,
*
,
method
:
str
=
"generate"
,
method
:
str
=
"generate"
,
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
...
@@ -470,7 +470,7 @@ def compare_all_settings(model: str,
...
@@ -470,7 +470,7 @@ def compare_all_settings(model: str,
prompt
=
"Hello, my name is"
prompt
=
"Hello, my name is"
token_ids
=
tokenizer
(
prompt
).
input_ids
token_ids
=
tokenizer
(
prompt
).
input_ids
ref_results
:
L
ist
=
[]
ref_results
:
l
ist
=
[]
for
i
,
(
args
,
env
)
in
enumerate
(
zip
(
all_args
,
all_envs
)):
for
i
,
(
args
,
env
)
in
enumerate
(
zip
(
all_args
,
all_envs
)):
if
can_force_load_format
:
if
can_force_load_format
:
# we are comparing the results and
# we are comparing the results and
...
@@ -481,7 +481,7 @@ def compare_all_settings(model: str,
...
@@ -481,7 +481,7 @@ def compare_all_settings(model: str,
# environment variable to force the load format,
# environment variable to force the load format,
# e.g. in quantization tests.
# e.g. in quantization tests.
args
=
args
+
[
"--load-format"
,
envs
.
VLLM_TEST_FORCE_LOAD_FORMAT
]
args
=
args
+
[
"--load-format"
,
envs
.
VLLM_TEST_FORCE_LOAD_FORMAT
]
compare_results
:
L
ist
=
[]
compare_results
:
l
ist
=
[]
results
=
ref_results
if
i
==
0
else
compare_results
results
=
ref_results
if
i
==
0
else
compare_results
with
RemoteOpenAIServer
(
model
,
with
RemoteOpenAIServer
(
model
,
args
,
args
,
...
@@ -582,7 +582,7 @@ def multi_process_parallel(
...
@@ -582,7 +582,7 @@ def multi_process_parallel(
@
contextmanager
@
contextmanager
def
error_on_warning
(
category
:
T
ype
[
Warning
]
=
Warning
):
def
error_on_warning
(
category
:
t
ype
[
Warning
]
=
Warning
):
"""
"""
Within the scope of this context manager, tests will fail if any warning
Within the scope of this context manager, tests will fail if any warning
of the given category is emitted.
of the given category is emitted.
...
@@ -604,7 +604,7 @@ def get_physical_device_indices(devices):
...
@@ -604,7 +604,7 @@ def get_physical_device_indices(devices):
@
_nvml
()
@
_nvml
()
def
wait_for_gpu_memory_to_clear
(
devices
:
L
ist
[
int
],
def
wait_for_gpu_memory_to_clear
(
devices
:
l
ist
[
int
],
threshold_bytes
:
int
,
threshold_bytes
:
int
,
timeout_s
:
float
=
120
)
->
None
:
timeout_s
:
float
=
120
)
->
None
:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# Use nvml instead of pytorch to reduce measurement error from torch cuda
...
@@ -612,8 +612,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
...
@@ -612,8 +612,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
devices
=
get_physical_device_indices
(
devices
)
devices
=
get_physical_device_indices
(
devices
)
start_time
=
time
.
time
()
start_time
=
time
.
time
()
while
True
:
while
True
:
output
:
D
ict
[
int
,
str
]
=
{}
output
:
d
ict
[
int
,
str
]
=
{}
output_raw
:
D
ict
[
int
,
float
]
=
{}
output_raw
:
d
ict
[
int
,
float
]
=
{}
for
device
in
devices
:
for
device
in
devices
:
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
dev_handle
=
amdsmi_get_processor_handles
()[
device
]
dev_handle
=
amdsmi_get_processor_handles
()[
device
]
...
@@ -758,13 +758,13 @@ def multi_gpu_test(*, num_gpus: int):
...
@@ -758,13 +758,13 @@ def multi_gpu_test(*, num_gpus: int):
async
def
completions_with_server_args
(
async
def
completions_with_server_args
(
prompts
:
L
ist
[
str
],
prompts
:
l
ist
[
str
],
model_name
:
str
,
model_name
:
str
,
server_cli_args
:
L
ist
[
str
],
server_cli_args
:
l
ist
[
str
],
num_logprobs
:
Optional
[
int
],
num_logprobs
:
Optional
[
int
],
max_wait_seconds
:
int
=
240
,
max_wait_seconds
:
int
=
240
,
max_tokens
:
Union
[
int
,
list
]
=
5
,
max_tokens
:
Union
[
int
,
list
]
=
5
,
)
->
L
ist
[
Completion
]:
)
->
l
ist
[
Completion
]:
'''Construct a remote OpenAI server, obtain an async client to the
'''Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.
server & invoke the completions API to obtain completions.
...
@@ -807,7 +807,7 @@ async def completions_with_server_args(
...
@@ -807,7 +807,7 @@ async def completions_with_server_args(
return
outputs
return
outputs
def
get_client_text_generations
(
completions
:
L
ist
[
Completion
])
->
L
ist
[
str
]:
def
get_client_text_generations
(
completions
:
l
ist
[
Completion
])
->
l
ist
[
str
]:
'''Extract generated tokens from the output of a
'''Extract generated tokens from the output of a
request made to an Open-AI-protocol completions endpoint.
request made to an Open-AI-protocol completions endpoint.
'''
'''
...
@@ -816,7 +816,7 @@ def get_client_text_generations(completions: List[Completion]) -> List[str]:
...
@@ -816,7 +816,7 @@ def get_client_text_generations(completions: List[Completion]) -> List[str]:
def
get_client_text_logprob_generations
(
def
get_client_text_logprob_generations
(
completions
:
L
ist
[
Completion
])
->
L
ist
[
TextTextLogprobs
]:
completions
:
l
ist
[
Completion
])
->
l
ist
[
TextTextLogprobs
]:
'''Operates on the output of a request made to an Open-AI-protocol
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
each :class:`SequenceGroup`
...
...
tests/v1/core/test_prefix_caching.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""Compare the with and without prefix caching."""
"""Compare the with and without prefix caching."""
from
typing
import
List
import
pytest
import
pytest
...
@@ -434,7 +433,7 @@ def test_cache_blocks():
...
@@ -434,7 +433,7 @@ def test_cache_blocks():
# Test that blocks are cached correctly for 2 full blocks from the start.
# Test that blocks are cached correctly for 2 full blocks from the start.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
block_hashes
:
L
ist
[
BlockHashType
]
=
[]
block_hashes
:
l
ist
[
BlockHashType
]
=
[]
block_pool
.
cache_full_blocks
(
block_pool
.
cache_full_blocks
(
request
=
req
,
request
=
req
,
...
...
tests/v1/core/test_scheduler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
typing
import
Optional
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
...
@@ -48,9 +48,9 @@ def create_scheduler(
...
@@ -48,9 +48,9 @@ def create_scheduler(
def
create_requests
(
def
create_requests
(
num_requests
:
int
,
num_requests
:
int
,
num_tokens
:
int
=
10
,
num_tokens
:
int
=
10
,
mm_positions
:
Optional
[
L
ist
[
PlaceholderRange
]]
=
None
,
mm_positions
:
Optional
[
l
ist
[
PlaceholderRange
]]
=
None
,
max_tokens
:
int
=
16
,
max_tokens
:
int
=
16
,
stop_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
):
):
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
...
...
tests/v1/engine/conftest.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Tuple
import
pytest
import
pytest
import
torch
import
torch
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -17,8 +15,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
...
@@ -17,8 +15,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from
tests.v1.engine.utils
import
FULL_STRINGS
# isort: skip
from
tests.v1.engine.utils
import
FULL_STRINGS
# isort: skip
EngineCoreSampleLogprobsType
=
L
ist
[
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]]
EngineCoreSampleLogprobsType
=
l
ist
[
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]]
EngineCorePromptLogprobsType
=
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]
EngineCorePromptLogprobsType
=
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]
def
_build_test_vectors_no_logprobs
()
->
DummyOutputProcessorTestVectors
:
def
_build_test_vectors_no_logprobs
()
->
DummyOutputProcessorTestVectors
:
...
...
tests/v1/engine/test_async_llm.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
asyncio
import
asyncio
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
pytest
...
@@ -47,7 +47,7 @@ async def generate(engine: AsyncLLM,
...
@@ -47,7 +47,7 @@ async def generate(engine: AsyncLLM,
prompt
:
PromptType
,
prompt
:
PromptType
,
output_kind
:
RequestOutputKind
,
output_kind
:
RequestOutputKind
,
max_tokens
:
int
,
max_tokens
:
int
,
prompt_logprobs
:
Optional
[
int
]
=
None
)
->
T
uple
[
int
,
str
]:
prompt_logprobs
:
Optional
[
int
]
=
None
)
->
t
uple
[
int
,
str
]:
# Ensure generate doesn't complete too fast for cancellation test.
# Ensure generate doesn't complete too fast for cancellation test.
await
asyncio
.
sleep
(
0.2
)
await
asyncio
.
sleep
(
0.2
)
...
@@ -114,7 +114,7 @@ async def test_async_llm_refuses_prompt_logprobs_with_apc(
...
@@ -114,7 +114,7 @@ async def test_async_llm_refuses_prompt_logprobs_with_apc(
(
VISION_ENGINE_ARGS
,
VISION_PROMPT
)])
(
VISION_ENGINE_ARGS
,
VISION_PROMPT
)])
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_load
(
monkeypatch
,
output_kind
:
RequestOutputKind
,
async
def
test_load
(
monkeypatch
,
output_kind
:
RequestOutputKind
,
engine_args_and_prompt
:
T
uple
[
AsyncEngineArgs
,
engine_args_and_prompt
:
t
uple
[
AsyncEngineArgs
,
PromptType
]):
PromptType
]):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the
# so that in the future when we switch, we don't have to change all the
...
@@ -160,7 +160,7 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind,
...
@@ -160,7 +160,7 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind,
(
VISION_ENGINE_ARGS
,
VISION_PROMPT
)])
(
VISION_ENGINE_ARGS
,
VISION_PROMPT
)])
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_abort
(
monkeypatch
,
output_kind
:
RequestOutputKind
,
async
def
test_abort
(
monkeypatch
,
output_kind
:
RequestOutputKind
,
engine_args_and_prompt
:
T
uple
[
AsyncEngineArgs
,
engine_args_and_prompt
:
t
uple
[
AsyncEngineArgs
,
PromptType
]):
PromptType
]):
with
monkeypatch
.
context
()
as
m
,
ExitStack
()
as
after
:
with
monkeypatch
.
context
()
as
m
,
ExitStack
()
as
after
:
...
@@ -177,7 +177,7 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
...
@@ -177,7 +177,7 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
request_ids
=
[
f
"request-
{
i
}
"
for
i
in
range
(
NUM_REQUESTS
)]
request_ids
=
[
f
"request-
{
i
}
"
for
i
in
range
(
NUM_REQUESTS
)]
# Create concurrent requests.
# Create concurrent requests.
tasks
:
L
ist
[
asyncio
.
Task
]
=
[]
tasks
:
l
ist
[
asyncio
.
Task
]
=
[]
for
request_id
in
request_ids
:
for
request_id
in
request_ids
:
tasks
.
append
(
tasks
.
append
(
asyncio
.
create_task
(
asyncio
.
create_task
(
...
...
tests/v1/engine/test_engine_core.py
View file @
cf069aa8
...
@@ -5,7 +5,6 @@ import threading
...
@@ -5,7 +5,6 @@ import threading
import
time
import
time
import
uuid
import
uuid
from
concurrent.futures
import
Future
from
concurrent.futures
import
Future
from
typing
import
List
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -213,7 +212,7 @@ def test_engine_core_concurrent_batches(monkeypatch):
...
@@ -213,7 +212,7 @@ def test_engine_core_concurrent_batches(monkeypatch):
class
DummyExecutor
(
UniProcExecutor
):
class
DummyExecutor
(
UniProcExecutor
):
def
initialize_from_config
(
def
initialize_from_config
(
self
,
kv_cache_configs
:
L
ist
[
KVCacheConfig
])
->
None
:
self
,
kv_cache_configs
:
l
ist
[
KVCacheConfig
])
->
None
:
super
().
initialize_from_config
(
kv_cache_configs
)
super
().
initialize_from_config
(
kv_cache_configs
)
# This executor actually can only run 1 batch at a time
# This executor actually can only run 1 batch at a time
...
...
tests/v1/engine/test_engine_core_client.py
View file @
cf069aa8
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
asyncio
import
asyncio
import
time
import
time
import
uuid
import
uuid
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -44,7 +44,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
...
@@ -44,7 +44,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
)
)
def
loop_until_done
(
client
:
EngineCoreClient
,
outputs
:
D
ict
):
def
loop_until_done
(
client
:
EngineCoreClient
,
outputs
:
d
ict
):
while
True
:
while
True
:
engine_core_outputs
=
client
.
get_output
().
outputs
engine_core_outputs
=
client
.
get_output
().
outputs
...
@@ -62,7 +62,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
...
@@ -62,7 +62,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
break
break
async
def
loop_until_done_async
(
client
:
EngineCoreClient
,
outputs
:
D
ict
):
async
def
loop_until_done_async
(
client
:
EngineCoreClient
,
outputs
:
d
ict
):
while
True
:
while
True
:
engine_core_outputs
=
(
await
client
.
get_output_async
()).
outputs
engine_core_outputs
=
(
await
client
.
get_output_async
()).
outputs
...
@@ -121,7 +121,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
...
@@ -121,7 +121,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client
.
add_request
(
request
)
client
.
add_request
(
request
)
time
.
sleep
(
0.01
)
time
.
sleep
(
0.01
)
outputs
:
D
ict
[
str
,
L
ist
]
=
{
req_id
:
[]
for
req_id
in
request_ids
}
outputs
:
d
ict
[
str
,
l
ist
]
=
{
req_id
:
[]
for
req_id
in
request_ids
}
loop_until_done
(
client
,
outputs
)
loop_until_done
(
client
,
outputs
)
for
req_id
in
request_ids
:
for
req_id
in
request_ids
:
...
@@ -207,7 +207,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
...
@@ -207,7 +207,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
await
client
.
add_request_async
(
request
)
await
client
.
add_request_async
(
request
)
await
asyncio
.
sleep
(
0.01
)
await
asyncio
.
sleep
(
0.01
)
outputs
:
D
ict
[
str
,
L
ist
]
=
{
req_id
:
[]
for
req_id
in
request_ids
}
outputs
:
d
ict
[
str
,
l
ist
]
=
{
req_id
:
[]
for
req_id
in
request_ids
}
await
loop_until_done_async
(
client
,
outputs
)
await
loop_until_done_async
(
client
,
outputs
)
for
req_id
in
request_ids
:
for
req_id
in
request_ids
:
...
...
tests/v1/engine/test_llm_engine.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
random
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
pytest
...
@@ -47,9 +47,9 @@ def vllm_model_apc(vllm_runner, monkeypatch):
...
@@ -47,9 +47,9 @@ def vllm_model_apc(vllm_runner, monkeypatch):
def
_get_test_sampling_params
(
def
_get_test_sampling_params
(
prompt_list
:
L
ist
[
str
],
prompt_list
:
l
ist
[
str
],
seed
:
Optional
[
int
]
=
42
,
seed
:
Optional
[
int
]
=
42
,
)
->
T
uple
[
L
ist
[
SamplingParams
],
L
ist
[
int
]]:
)
->
t
uple
[
l
ist
[
SamplingParams
],
l
ist
[
int
]]:
"""Generate random sampling params for a batch."""
"""Generate random sampling params for a batch."""
def
get_mostly_n_gt1
()
->
int
:
def
get_mostly_n_gt1
()
->
int
:
...
@@ -81,7 +81,7 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
...
@@ -81,7 +81,7 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
# Validate each request response
# Validate each request response
for
out
,
n
in
zip
(
outputs
,
n_list
):
for
out
,
n
in
zip
(
outputs
,
n_list
):
completion_counts
:
D
ict
[
str
,
int
]
=
{}
completion_counts
:
d
ict
[
str
,
int
]
=
{}
# Assert correct number of completions
# Assert correct number of completions
assert
len
(
out
.
outputs
)
==
n
,
(
assert
len
(
out
.
outputs
)
==
n
,
(
f
"
{
len
(
out
.
outputs
)
}
completions;
{
n
}
expected."
)
f
"
{
len
(
out
.
outputs
)
}
completions;
{
n
}
expected."
)
...
...
tests/v1/engine/test_output_processor.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
math
import
math
import
time
import
time
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
pytest
import
pytest
...
@@ -112,12 +112,12 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
...
@@ -112,12 +112,12 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
def
_validate_logprobs
(
def
_validate_logprobs
(
gen_tokens
:
D
ict
[
str
,
L
ist
[
int
]],
gen_tokens
:
d
ict
[
str
,
l
ist
[
int
]],
gen_logprobs
:
D
ict
[
str
,
Optional
[
SampleLogprobs
]],
gen_logprobs
:
d
ict
[
str
,
Optional
[
SampleLogprobs
]],
gen_prompt_logprobs
:
D
ict
[
str
,
Optional
[
PromptLogprobs
]],
gen_prompt_logprobs
:
d
ict
[
str
,
Optional
[
PromptLogprobs
]],
gen_cumulative_logprob
:
D
ict
[
str
,
float
],
gen_cumulative_logprob
:
d
ict
[
str
,
float
],
dtv
:
DummyOutputProcessorTestVectors
,
dtv
:
DummyOutputProcessorTestVectors
,
request_id_list
:
L
ist
[
str
],
request_id_list
:
l
ist
[
str
],
num_sample_logprobs
:
Optional
[
int
],
num_sample_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
)
->
None
:
)
->
None
:
...
...
tests/v1/engine/utils.py
View file @
cf069aa8
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
import
random
import
random
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
torch
import
torch
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
...
@@ -61,7 +61,7 @@ def _create_random_top_logprob_test_vector(
...
@@ -61,7 +61,7 @@ def _create_random_top_logprob_test_vector(
def
_create_random_top_logprob_test_matrix
(
def
_create_random_top_logprob_test_matrix
(
shape
:
T
uple
,
shape
:
t
uple
,
lower
:
float
,
lower
:
float
,
upper
:
float
,
upper
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -90,7 +90,7 @@ def _create_random_top_token_test_vector(
...
@@ -90,7 +90,7 @@ def _create_random_top_token_test_vector(
lower
:
int
,
lower
:
int
,
upper
:
int
,
upper
:
int
,
sampled_token_id
:
int
,
sampled_token_id
:
int
,
adjust_num_logprobs
:
bool
=
True
)
->
T
uple
[
torch
.
Tensor
,
int
]:
adjust_num_logprobs
:
bool
=
True
)
->
t
uple
[
torch
.
Tensor
,
int
]:
"""Create a random vector of top logprob token indices
"""Create a random vector of top logprob token indices
Use to create fake sample logprobs for testing. The sampled token
Use to create fake sample logprobs for testing. The sampled token
...
@@ -141,11 +141,11 @@ def _create_random_top_token_test_vector(
...
@@ -141,11 +141,11 @@ def _create_random_top_token_test_vector(
def
_create_random_top_token_test_matrix
(
def
_create_random_top_token_test_matrix
(
shape
:
T
uple
[
int
,
int
],
shape
:
t
uple
[
int
,
int
],
lower
:
int
,
lower
:
int
,
upper
:
int
,
upper
:
int
,
tokens_list
:
L
ist
[
int
],
tokens_list
:
l
ist
[
int
],
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Create a random matrix of top logprob token indices
"""Create a random matrix of top logprob token indices
Use to create fake prompt logprobs for testing.
Use to create fake prompt logprobs for testing.
...
@@ -160,7 +160,7 @@ def _create_random_top_token_test_matrix(
...
@@ -160,7 +160,7 @@ def _create_random_top_token_test_matrix(
upper: upper range of token ids
upper: upper range of token ids
Returns:
Returns:
T
uple containing:
t
uple containing:
- 2D num_tokens x num_logprobs+1 torch Tensor of token ids
- 2D num_tokens x num_logprobs+1 torch Tensor of token ids
- 1D tensor of ranks of prompt tokens in their respective
- 1D tensor of ranks of prompt tokens in their respective
rows, or random values
rows, or random values
...
@@ -206,10 +206,10 @@ def decode_token(
...
@@ -206,10 +206,10 @@ def decode_token(
def
generate_dummy_sample_logprobs
(
def
generate_dummy_sample_logprobs
(
sampled_tokens_list
:
L
ist
,
sampled_tokens_list
:
l
ist
,
num_logprobs
:
int
,
num_logprobs
:
int
,
tokenizer
:
PreTrainedTokenizer
,
tokenizer
:
PreTrainedTokenizer
,
)
->
L
ist
[
T
uple
[
L
ist
[
int
],
L
ist
[
float
],
int
]]:
)
->
l
ist
[
t
uple
[
l
ist
[
int
],
l
ist
[
float
],
int
]]:
"""Generate dummy sample logprobs
"""Generate dummy sample logprobs
Generate a test data structure which imitates the list of sample logprobs
Generate a test data structure which imitates the list of sample logprobs
...
@@ -221,7 +221,7 @@ def generate_dummy_sample_logprobs(
...
@@ -221,7 +221,7 @@ def generate_dummy_sample_logprobs(
tokenizer: model tokenizer to use for detokenization
tokenizer: model tokenizer to use for detokenization
Returns
Returns
L
ist of (top token ids vector, logprobs vector, sampled token rank)
l
ist of (top token ids vector, logprobs vector, sampled token rank)
Python lists tuples; in each tuple the logprobs and top token ids
Python lists tuples; in each tuple the logprobs and top token ids
vectors have the same length which is either `num_logprobs` or
vectors have the same length which is either `num_logprobs` or
`num_logprobs+1`. Sampled token rank is the rank (index+1) of the
`num_logprobs+1`. Sampled token rank is the rank (index+1) of the
...
@@ -253,7 +253,7 @@ def generate_dummy_sample_logprobs(
...
@@ -253,7 +253,7 @@ def generate_dummy_sample_logprobs(
def
generate_dummy_prompt_logprobs_tensors
(
def
generate_dummy_prompt_logprobs_tensors
(
prompt_tokens_list
:
L
ist
,
prompt_tokens_list
:
l
ist
,
num_logprobs
:
int
,
num_logprobs
:
int
,
tokenizer
:
PreTrainedTokenizer
,
tokenizer
:
PreTrainedTokenizer
,
)
->
LogprobsTensors
:
)
->
LogprobsTensors
:
...
@@ -269,7 +269,7 @@ def generate_dummy_prompt_logprobs_tensors(
...
@@ -269,7 +269,7 @@ def generate_dummy_prompt_logprobs_tensors(
tokenizer: model tokenizer to use for detokenization
tokenizer: model tokenizer to use for detokenization
Returns
Returns
Single
T
uple of (logprobs matrix, top token ids matrix) torch Tensor,
Single
t
uple of (logprobs matrix, top token ids matrix) torch Tensor,
where both matrices have dimensions
where both matrices have dimensions
num_prompt_tokens x num_logprobs
num_prompt_tokens x num_logprobs
"""
"""
...
@@ -301,19 +301,19 @@ class DummyOutputProcessorTestVectors:
...
@@ -301,19 +301,19 @@ class DummyOutputProcessorTestVectors:
tokenizer
:
GeneralTokenizerType
tokenizer
:
GeneralTokenizerType
tokenizer_group
:
BaseTokenizerGroup
tokenizer_group
:
BaseTokenizerGroup
vllm_config
:
EngineArgs
vllm_config
:
EngineArgs
full_tokens
:
L
ist
[
L
ist
[
int
]]
# Prompt + generated tokens
full_tokens
:
l
ist
[
l
ist
[
int
]]
# Prompt + generated tokens
prompt_tokens
:
L
ist
[
L
ist
[
int
]]
prompt_tokens
:
l
ist
[
l
ist
[
int
]]
generation_tokens
:
L
ist
[
L
ist
[
int
]]
generation_tokens
:
l
ist
[
l
ist
[
int
]]
# Each request is associated with a tuple of
# Each request is associated with a tuple of
# (top tokens, top logprobs, ranks) prompt logprobs tensors
# (top tokens, top logprobs, ranks) prompt logprobs tensors
prompt_logprobs
:
L
ist
[
LogprobsTensors
]
prompt_logprobs
:
l
ist
[
LogprobsTensors
]
# Each request is associated with a sample logprobs; a request's
# Each request is associated with a sample logprobs; a request's
# sample logprobs are a list of (top tokens, top logprobs, ranks)
# sample logprobs are a list of (top tokens, top logprobs, ranks)
# sample logprobs tensors at each sequence position
# sample logprobs tensors at each sequence position
generation_logprobs
:
L
ist
[
L
ist
[
T
uple
[
L
ist
[
int
],
L
ist
[
float
],
int
]]]
generation_logprobs
:
l
ist
[
l
ist
[
t
uple
[
l
ist
[
int
],
l
ist
[
float
],
int
]]]
prompt_strings
:
L
ist
[
str
]
prompt_strings
:
l
ist
[
str
]
prompt_strings_len
:
L
ist
[
int
]
prompt_strings_len
:
l
ist
[
int
]
generation_strings
:
L
ist
[
str
]
generation_strings
:
l
ist
[
str
]
class
MockEngineCore
:
class
MockEngineCore
:
...
@@ -321,18 +321,18 @@ class MockEngineCore:
...
@@ -321,18 +321,18 @@ class MockEngineCore:
def
__init__
(
def
__init__
(
self
,
self
,
tokens_list
:
L
ist
[
L
ist
[
int
]],
tokens_list
:
l
ist
[
l
ist
[
int
]],
# For each request, for each sampled token offset,
# For each request, for each sampled token offset,
# a tuple of
# a tuple of
# (list of topk token ids, list of sample logprob vals, rank)
# (list of topk token ids, list of sample logprob vals, rank)
generated_logprobs_raw
:
Optional
[
L
ist
[
L
ist
[
T
uple
[
L
ist
[
int
],
generated_logprobs_raw
:
Optional
[
l
ist
[
l
ist
[
t
uple
[
l
ist
[
int
],
L
ist
[
float
],
l
ist
[
float
],
int
]]]]
=
None
,
int
]]]]
=
None
,
# For each request, a tuple of
# For each request, a tuple of
# (prompt logprob val matrix, prompt logprob tok id matrix);
# (prompt logprob val matrix, prompt logprob tok id matrix);
# each matrix has dimensions
# each matrix has dimensions
# (num prompt toks) x (num prompt logprobs+1)
# (num prompt toks) x (num prompt logprobs+1)
prompt_logprobs_raw
:
Optional
[
L
ist
[
LogprobsTensors
]]
=
None
,
prompt_logprobs_raw
:
Optional
[
l
ist
[
LogprobsTensors
]]
=
None
,
)
->
None
:
)
->
None
:
self
.
tokens_list
=
tokens_list
self
.
tokens_list
=
tokens_list
self
.
current_idx
=
0
self
.
current_idx
=
0
...
@@ -341,7 +341,7 @@ class MockEngineCore:
...
@@ -341,7 +341,7 @@ class MockEngineCore:
self
.
prompt_logprobs_raw
=
prompt_logprobs_raw
self
.
prompt_logprobs_raw
=
prompt_logprobs_raw
self
.
do_prompt_logprobs
=
prompt_logprobs_raw
is
not
None
self
.
do_prompt_logprobs
=
prompt_logprobs_raw
is
not
None
def
get_outputs
(
self
)
->
L
ist
[
EngineCoreOutput
]:
def
get_outputs
(
self
)
->
l
ist
[
EngineCoreOutput
]:
do_logprobs
=
self
.
do_logprobs
do_logprobs
=
self
.
do_logprobs
do_prompt_logprobs
=
self
.
do_prompt_logprobs
do_prompt_logprobs
=
self
.
do_prompt_logprobs
token_idx
=
self
.
current_idx
token_idx
=
self
.
current_idx
...
...
tests/v1/entrypoints/openai/test_completion.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
re
import
re
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
openai
# use the official client for correctness check
import
openai
# use the official client for correctness check
import
pytest
import
pytest
...
@@ -193,7 +193,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
...
@@ -193,7 +193,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
async
def
test_prompt_logprobs_completion
(
client
:
openai
.
AsyncOpenAI
,
async
def
test_prompt_logprobs_completion
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
model_name
:
str
,
prompt_logprobs
:
Optional
[
int
]):
prompt_logprobs
:
Optional
[
int
]):
params
:
D
ict
=
{
params
:
d
ict
=
{
"prompt"
:
[
"A robot may not injure another robot"
,
"My name is"
],
"prompt"
:
[
"A robot may not injure another robot"
,
"My name is"
],
"model"
:
model_name
,
"model"
:
model_name
,
}
}
...
@@ -237,7 +237,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
...
@@ -237,7 +237,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
max_tokens
=
5
,
max_tokens
=
5
,
temperature
=
0.0
,
temperature
=
0.0
,
stream
=
True
)
stream
=
True
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
finish_reason_count
=
0
async
for
chunk
in
stream
:
async
for
chunk
in
stream
:
chunks
.
append
(
chunk
.
choices
[
0
].
text
)
chunks
.
append
(
chunk
.
choices
[
0
].
text
)
...
@@ -278,7 +278,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
...
@@ -278,7 +278,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
num_completions
=
len
(
completion
.
choices
)
num_completions
=
len
(
completion
.
choices
)
assert
num_completions
==
n
,
(
assert
num_completions
==
n
,
(
f
"Num completions
{
num_completions
}
but expected
{
n
}
."
)
f
"Num completions
{
num_completions
}
but expected
{
n
}
."
)
completion_repeats
:
D
ict
[
str
,
int
]
=
{}
completion_repeats
:
d
ict
[
str
,
int
]
=
{}
for
idx
,
choice
in
enumerate
(
completion
.
choices
):
for
idx
,
choice
in
enumerate
(
completion
.
choices
):
# Assert correct completion index & some finish reason.
# Assert correct completion index & some finish reason.
assert
choice
.
index
==
idx
,
(
assert
choice
.
index
==
idx
,
(
...
@@ -321,7 +321,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
...
@@ -321,7 +321,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
temperature
=
0.95
,
temperature
=
0.95
,
stream
=
True
,
stream
=
True
,
seed
=
42
)
seed
=
42
)
chunks
:
L
ist
[
L
ist
[
str
]]
=
[[]
for
i
in
range
(
n
)]
chunks
:
l
ist
[
l
ist
[
str
]]
=
[[]
for
i
in
range
(
n
)]
finish_reason_count
=
0
finish_reason_count
=
0
async
for
chunk
in
stream
:
async
for
chunk
in
stream
:
index
=
chunk
.
choices
[
0
].
index
index
=
chunk
.
choices
[
0
].
index
...
@@ -332,7 +332,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
...
@@ -332,7 +332,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
# Assert `n` completions with correct finish reasons
# Assert `n` completions with correct finish reasons
assert
finish_reason_count
==
n
,
(
assert
finish_reason_count
==
n
,
(
f
"Expected
{
n
}
completions with valid indices and finish_reason."
)
f
"Expected
{
n
}
completions with valid indices and finish_reason."
)
completion_repeats
:
D
ict
[
str
,
int
]
=
{}
completion_repeats
:
d
ict
[
str
,
int
]
=
{}
for
chunk
in
chunks
:
for
chunk
in
chunks
:
chunk_len
=
len
(
chunk
)
chunk_len
=
len
(
chunk
)
# Assert correct number of completion tokens
# Assert correct number of completion tokens
...
...
tests/v1/sample/test_logprobs.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
itertools
import
itertools
from
typing
import
List
,
Tuple
import
pytest
import
pytest
import
torch
import
torch
...
@@ -46,8 +45,8 @@ def hf_model(hf_runner):
...
@@ -46,8 +45,8 @@ def hf_model(hf_runner):
def
_repeat_logprob_config
(
def
_repeat_logprob_config
(
test_prompts
,
test_prompts
,
logprob_prompt_logprob_list
:
L
ist
[
T
uple
],
logprob_prompt_logprob_list
:
l
ist
[
t
uple
],
)
->
L
ist
[
T
uple
]:
)
->
l
ist
[
t
uple
]:
"""Ensure each test prompt has a logprob config.
"""Ensure each test prompt has a logprob config.
A logprob config specifies the optional (i.e.
A logprob config specifies the optional (i.e.
...
@@ -74,7 +73,7 @@ def _repeat_logprob_config(
...
@@ -74,7 +73,7 @@ def _repeat_logprob_config(
tuples
tuples
Returns:
Returns:
L
ist of
l
ist of
(optional num sample logprob,optional num prompt logprob)
(optional num sample logprob,optional num prompt logprob)
tuples which is either identical to
tuples which is either identical to
`logprob_prompt_logprob_list`, or else repeats
`logprob_prompt_logprob_list`, or else repeats
...
@@ -177,7 +176,7 @@ def _test_case_get_logprobs_and_prompt_logprobs(
...
@@ -177,7 +176,7 @@ def _test_case_get_logprobs_and_prompt_logprobs(
for
r
in
range
(
1
,
num_top_logprobs
+
1
))
for
r
in
range
(
1
,
num_top_logprobs
+
1
))
output_text
=
vllm_result
.
outputs
[
0
].
text
output_text
=
vllm_result
.
outputs
[
0
].
text
output_string_from_most_likely_tokens_lst
:
L
ist
[
str
]
=
[]
output_string_from_most_likely_tokens_lst
:
l
ist
[
str
]
=
[]
for
top_logprobs
in
vllm_result
.
outputs
[
0
].
logprobs
:
for
top_logprobs
in
vllm_result
.
outputs
[
0
].
logprobs
:
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
output_string_from_most_likely_tokens_lst
.
append
(
output_string_from_most_likely_tokens_lst
.
append
(
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
...
@@ -13,7 +12,7 @@ def sampler():
...
@@ -13,7 +12,7 @@ def sampler():
return
RejectionSampler
()
return
RejectionSampler
()
def
create_logits_tensor
(
token_ids
:
L
ist
[
int
],
def
create_logits_tensor
(
token_ids
:
l
ist
[
int
],
vocab_size
:
int
=
100
)
->
torch
.
Tensor
:
vocab_size
:
int
=
100
)
->
torch
.
Tensor
:
"""Helper function to create logits tensor that
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
will produce desired token ids on argmax"""
...
@@ -23,7 +22,7 @@ def create_logits_tensor(token_ids: List[int],
...
@@ -23,7 +22,7 @@ def create_logits_tensor(token_ids: List[int],
return
logits
return
logits
def
create_sampling_metadata
(
spec_tokens
:
L
ist
[
L
ist
[
int
]])
->
SamplingMetadata
:
def
create_sampling_metadata
(
spec_tokens
:
l
ist
[
l
ist
[
int
]])
->
SamplingMetadata
:
batch_size
=
len
(
spec_tokens
)
batch_size
=
len
(
spec_tokens
)
return
SamplingMetadata
(
return
SamplingMetadata
(
temperature
=
torch
.
tensor
([]),
temperature
=
torch
.
tensor
([]),
...
@@ -106,7 +105,7 @@ def test_single_token_sequence(sampler):
...
@@ -106,7 +105,7 @@ def test_single_token_sequence(sampler):
def
test_empty_sequence
(
sampler
):
def
test_empty_sequence
(
sampler
):
"""Test handling empty sequence of speculated tokens"""
"""Test handling empty sequence of speculated tokens"""
spec_tokens
:
L
ist
[
L
ist
[
int
]]
=
[[]]
spec_tokens
:
l
ist
[
l
ist
[
int
]]
=
[[]]
output_tokens
=
[
5
]
# Just the bonus token
output_tokens
=
[
5
]
# Just the bonus token
metadata
=
create_sampling_metadata
(
spec_tokens
)
metadata
=
create_sampling_metadata
(
spec_tokens
)
...
...
tests/v1/sample/test_sampler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Optional
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
...
@@ -32,7 +32,7 @@ def _create_penalty_tensor(batch_size: int, penalty_value: float,
...
@@ -32,7 +32,7 @@ def _create_penalty_tensor(batch_size: int, penalty_value: float,
def
_create_prompt_tokens_tensor
(
def
_create_prompt_tokens_tensor
(
prompt_token_ids
:
L
ist
[
L
ist
[
int
]],
prompt_token_ids
:
l
ist
[
l
ist
[
int
]],
vocab_size
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -49,8 +49,8 @@ def _create_logit_bias(
...
@@ -49,8 +49,8 @@ def _create_logit_bias(
batch_size
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
bias_value
:
float
,
bias_value
:
float
,
)
->
L
ist
[
Optional
[
D
ict
[
int
,
float
]]]:
)
->
l
ist
[
Optional
[
d
ict
[
int
,
float
]]]:
res
:
L
ist
[
Optional
[
D
ict
[
int
,
float
]]]
=
[]
res
:
l
ist
[
Optional
[
d
ict
[
int
,
float
]]]
=
[]
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
res
.
append
(
logit_bias
)
res
.
append
(
logit_bias
)
...
@@ -83,8 +83,8 @@ def _create_default_sampling_metadata(
...
@@ -83,8 +83,8 @@ def _create_default_sampling_metadata(
vocab_size
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
)
->
SamplingMetadata
:
)
->
SamplingMetadata
:
output_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
output_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
prompt_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
prompt_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
for
_
in
range
(
batch_size
):
for
_
in
range
(
batch_size
):
output_token_ids
.
append
(
output_token_ids
.
append
(
np
.
random
.
randint
(
0
,
vocab_size
,
size
=
num_output_tokens
).
tolist
())
np
.
random
.
randint
(
0
,
vocab_size
,
size
=
num_output_tokens
).
tolist
())
...
@@ -118,8 +118,8 @@ def _create_default_sampling_metadata(
...
@@ -118,8 +118,8 @@ def _create_default_sampling_metadata(
def
_generate_min_token_penalties_and_stop_tokens
(
def
_generate_min_token_penalties_and_stop_tokens
(
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
batch_indices_for_min_token_penalty
:
L
ist
[
int
]
batch_indices_for_min_token_penalty
:
l
ist
[
int
]
)
->
D
ict
[
int
,
T
uple
[
int
,
S
et
[
int
]]]:
)
->
d
ict
[
int
,
t
uple
[
int
,
s
et
[
int
]]]:
"""
"""
Generates and returns a dict of minimum token penalties and
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
...
@@ -130,7 +130,7 @@ def _generate_min_token_penalties_and_stop_tokens(
...
@@ -130,7 +130,7 @@ def _generate_min_token_penalties_and_stop_tokens(
and a random set of stop token IDs is created. Otherwise, a lower
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
"""
min_tokens
:
D
ict
[
int
,
T
uple
[
int
,
S
et
[
int
]]]
=
{}
min_tokens
:
d
ict
[
int
,
t
uple
[
int
,
s
et
[
int
]]]
=
{}
for
index
in
range
(
batch_size
):
for
index
in
range
(
batch_size
):
if
index
in
batch_indices_for_min_token_penalty
:
if
index
in
batch_indices_for_min_token_penalty
:
min_tokens
[
index
]
=
(
min_tokens
[
index
]
=
(
...
@@ -147,7 +147,7 @@ def _generate_min_token_penalties_and_stop_tokens(
...
@@ -147,7 +147,7 @@ def _generate_min_token_penalties_and_stop_tokens(
def
_create_weighted_output_token_list
(
def
_create_weighted_output_token_list
(
batch_size
:
int
,
batch_size
:
int
,
vocab_size
:
int
)
->
T
uple
[
L
ist
[
L
ist
[
int
]],
L
ist
[
L
ist
[
int
]]]:
vocab_size
:
int
)
->
t
uple
[
l
ist
[
l
ist
[
int
]],
l
ist
[
l
ist
[
int
]]]:
"""
"""
Creates an output token list where each token occurs a distinct
Creates an output token list where each token occurs a distinct
number of times.
number of times.
...
@@ -157,7 +157,7 @@ def _create_weighted_output_token_list(
...
@@ -157,7 +157,7 @@ def _create_weighted_output_token_list(
list, each with a different frequency.
list, each with a different frequency.
Returns:
Returns:
T
uple[
L
ist[
L
ist[int]],
L
ist[
L
ist[int]]]:
t
uple[
l
ist[
l
ist[int]],
l
ist[
l
ist[int]]]:
- The first element is the output token list, where each sublist
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
corresponds to a batch and contains tokens with weighted
frequencies.
frequencies.
...
@@ -165,8 +165,8 @@ def _create_weighted_output_token_list(
...
@@ -165,8 +165,8 @@ def _create_weighted_output_token_list(
batch, ordered by their frequency in the corresponding output
batch, ordered by their frequency in the corresponding output
list.
list.
"""
"""
output_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
output_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
sorted_token_ids_in_output
:
L
ist
[
L
ist
[
int
]]
=
[]
sorted_token_ids_in_output
:
l
ist
[
l
ist
[
int
]]
=
[]
for
_
in
range
(
batch_size
):
for
_
in
range
(
batch_size
):
distinct_token_ids
=
np
.
random
.
choice
(
vocab_size
,
distinct_token_ids
=
np
.
random
.
choice
(
vocab_size
,
size
=
np
.
random
.
randint
(
1
,
10
),
size
=
np
.
random
.
randint
(
1
,
10
),
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment