Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf069aa8
Unverified
Commit
cf069aa8
authored
Mar 03, 2025
by
Harry Mellor
Committed by
GitHub
Mar 02, 2025
Browse files
Update deprecated Python 3.8 typing (#13971)
parent
bf33700e
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
134 additions
and
140 deletions
+134
-140
tests/tool_use/test_chat_completions.py
tests/tool_use/test_chat_completions.py
+2
-4
tests/tool_use/test_jamba_tool_parser.py
tests/tool_use/test_jamba_tool_parser.py
+7
-6
tests/tool_use/test_parallel_tool_calls.py
tests/tool_use/test_parallel_tool_calls.py
+5
-5
tests/tool_use/test_tool_calls.py
tests/tool_use/test_tool_calls.py
+5
-5
tests/tool_use/utils.py
tests/tool_use/utils.py
+13
-13
tests/tracing/test_tracing.py
tests/tracing/test_tracing.py
+3
-2
tests/utils.py
tests/utils.py
+21
-21
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+1
-2
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+3
-3
tests/v1/engine/conftest.py
tests/v1/engine/conftest.py
+2
-4
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+5
-5
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+1
-2
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+5
-5
tests/v1/engine/test_llm_engine.py
tests/v1/engine/test_llm_engine.py
+4
-4
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+6
-6
tests/v1/engine/utils.py
tests/v1/engine/utils.py
+25
-25
tests/v1/entrypoints/openai/test_completion.py
tests/v1/entrypoints/openai/test_completion.py
+6
-6
tests/v1/sample/test_logprobs.py
tests/v1/sample/test_logprobs.py
+4
-5
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+3
-4
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+13
-13
No files found.
tests/tool_use/test_chat_completions.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
openai
import
pytest
...
...
@@ -45,7 +43,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
logprobs
=
False
,
stream
=
True
,
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
role_sent
:
bool
=
False
...
...
@@ -116,7 +114,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
stream
=
True
,
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
role_sent
:
bool
=
False
...
...
tests/tool_use/test_jamba_tool_parser.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
json
from
typing
import
Generator
,
List
,
Optional
from
collections.abc
import
Generator
from
typing
import
Optional
import
partial_json_parser
import
pytest
...
...
@@ -26,8 +27,8 @@ def jamba_tool_parser(jamba_tokenizer):
return
JambaToolParser
(
jamba_tokenizer
)
def
assert_tool_calls
(
actual_tool_calls
:
L
ist
[
ToolCall
],
expected_tool_calls
:
L
ist
[
ToolCall
]):
def
assert_tool_calls
(
actual_tool_calls
:
l
ist
[
ToolCall
],
expected_tool_calls
:
l
ist
[
ToolCall
]):
assert
len
(
actual_tool_calls
)
==
len
(
expected_tool_calls
)
for
actual_tool_call
,
expected_tool_call
in
zip
(
actual_tool_calls
,
...
...
@@ -218,10 +219,10 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
model_output
,
expected_tool_calls
,
expected_content
):
other_content
:
str
=
''
function_names
:
L
ist
[
str
]
=
[]
function_args_strs
:
L
ist
[
str
]
=
[]
function_names
:
l
ist
[
str
]
=
[]
function_args_strs
:
l
ist
[
str
]
=
[]
tool_call_idx
:
int
=
-
1
tool_call_ids
:
L
ist
[
Optional
[
str
]]
=
[]
tool_call_ids
:
l
ist
[
Optional
[
str
]]
=
[]
for
delta_message
in
stream_delta_message_generator
(
jamba_tool_parser
,
jamba_tokenizer
,
model_output
):
...
...
tests/tool_use/test_parallel_tool_calls.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
json
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
openai
import
pytest
...
...
@@ -54,7 +54,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
assert
isinstance
(
tool_call
.
function
.
arguments
,
str
)
parsed_arguments
=
json
.
loads
(
tool_call
.
function
.
arguments
)
assert
isinstance
(
parsed_arguments
,
D
ict
)
assert
isinstance
(
parsed_arguments
,
d
ict
)
assert
isinstance
(
parsed_arguments
.
get
(
"city"
),
str
)
assert
isinstance
(
parsed_arguments
.
get
(
"state"
),
str
)
...
...
@@ -73,8 +73,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
role_name
:
Optional
[
str
]
=
None
finish_reason_count
:
int
=
0
tool_call_names
:
L
ist
[
str
]
=
[]
tool_call_args
:
L
ist
[
str
]
=
[]
tool_call_names
:
l
ist
[
str
]
=
[]
tool_call_args
:
l
ist
[
str
]
=
[]
tool_call_idx
:
int
=
-
1
tool_call_id_count
:
int
=
0
...
...
@@ -180,7 +180,7 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
logprobs
=
False
,
stream
=
True
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
role_sent
:
bool
=
False
...
...
tests/tool_use/test_tool_calls.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
json
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
openai
import
pytest
...
...
@@ -44,7 +44,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# make sure the arguments parse properly
parsed_arguments
=
json
.
loads
(
tool_calls
[
0
].
function
.
arguments
)
assert
isinstance
(
parsed_arguments
,
D
ict
)
assert
isinstance
(
parsed_arguments
,
d
ict
)
assert
isinstance
(
parsed_arguments
.
get
(
"city"
),
str
)
assert
isinstance
(
parsed_arguments
.
get
(
"state"
),
str
)
assert
parsed_arguments
.
get
(
"city"
)
==
"Dallas"
...
...
@@ -117,7 +117,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
# validate arguments
streamed_args
=
json
.
loads
(
function_args_str
)
assert
isinstance
(
streamed_args
,
D
ict
)
assert
isinstance
(
streamed_args
,
d
ict
)
assert
isinstance
(
streamed_args
.
get
(
"city"
),
str
)
assert
isinstance
(
streamed_args
.
get
(
"state"
),
str
)
assert
streamed_args
.
get
(
"city"
)
==
"Dallas"
...
...
@@ -128,7 +128,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
assert
choice
.
message
.
role
==
role_name
assert
choice
.
message
.
tool_calls
[
0
].
function
.
name
==
function_name
# compare streamed with non-streamed args
D
ict-wise, not string-wise
# compare streamed with non-streamed args
d
ict-wise, not string-wise
# because character-to-character comparison might not work e.g. the tool
# call parser adding extra spaces or something like that. we care about the
# dicts matching not byte-wise match
...
...
@@ -167,7 +167,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
logprobs
=
False
,
stream
=
True
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
role_sent
:
bool
=
False
...
...
tests/tool_use/utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
copy
import
deepcopy
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Optional
from
openai.types.chat
import
(
ChatCompletionMessageParam
,
ChatCompletionToolParam
)
...
...
@@ -12,14 +12,14 @@ from tests.utils import VLLM_PATH
class
ServerConfig
(
TypedDict
,
total
=
False
):
model
:
str
arguments
:
L
ist
[
str
]
arguments
:
l
ist
[
str
]
system_prompt
:
Optional
[
str
]
supports_parallel
:
Optional
[
bool
]
supports_rocm
:
Optional
[
bool
]
def
patch_system_prompt
(
messages
:
L
ist
[
D
ict
[
str
,
Any
]],
system_prompt
:
str
)
->
L
ist
[
D
ict
[
str
,
Any
]]:
def
patch_system_prompt
(
messages
:
l
ist
[
d
ict
[
str
,
Any
]],
system_prompt
:
str
)
->
l
ist
[
d
ict
[
str
,
Any
]]:
new_messages
=
deepcopy
(
messages
)
if
new_messages
[
0
][
"role"
]
==
"system"
:
new_messages
[
0
][
"content"
]
=
system_prompt
...
...
@@ -28,8 +28,8 @@ def patch_system_prompt(messages: List[Dict[str, Any]],
return
new_messages
def
ensure_system_prompt
(
messages
:
L
ist
[
D
ict
[
str
,
Any
]],
config
:
ServerConfig
)
->
L
ist
[
D
ict
[
str
,
Any
]]:
def
ensure_system_prompt
(
messages
:
l
ist
[
d
ict
[
str
,
Any
]],
config
:
ServerConfig
)
->
l
ist
[
d
ict
[
str
,
Any
]]:
prompt
=
config
.
get
(
"system_prompt"
)
if
prompt
:
return
patch_system_prompt
(
messages
,
prompt
)
...
...
@@ -39,9 +39,9 @@ def ensure_system_prompt(messages: List[Dict[str, Any]],
# universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something.
ARGS
:
L
ist
[
str
]
=
[
"--enable-auto-tool-choice"
,
"--max-model-len"
,
"1024"
]
ARGS
:
l
ist
[
str
]
=
[
"--enable-auto-tool-choice"
,
"--max-model-len"
,
"1024"
]
CONFIGS
:
D
ict
[
str
,
ServerConfig
]
=
{
CONFIGS
:
d
ict
[
str
,
ServerConfig
]
=
{
"hermes"
:
{
"model"
:
"NousResearch/Hermes-3-Llama-3.1-8B"
,
...
...
@@ -205,7 +205,7 @@ SEARCH_TOOL: ChatCompletionToolParam = {
}
}
MESSAGES_WITHOUT_TOOLS
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_WITHOUT_TOOLS
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"user"
,
"content"
:
...
...
@@ -222,14 +222,14 @@ MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{
"Can you tell me a joke please?"
}]
MESSAGES_ASKING_FOR_TOOLS
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_ASKING_FOR_TOOLS
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"user"
,
"content"
:
"What is the weather in Dallas, Texas in Fahrenheit?"
}]
MESSAGES_WITH_TOOL_RESPONSE
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_WITH_TOOL_RESPONSE
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"user"
,
"content"
:
...
...
@@ -258,7 +258,7 @@ MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{
"cloudy skies and a low chance of rain."
}]
MESSAGES_ASKING_FOR_PARALLEL_TOOLS
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_ASKING_FOR_PARALLEL_TOOLS
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"user"
,
"content"
:
...
...
@@ -266,7 +266,7 @@ MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{
"Fahrenheit?"
}]
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE
:
L
ist
[
ChatCompletionMessageParam
]
=
[{
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE
:
l
ist
[
ChatCompletionMessageParam
]
=
[{
"role"
:
"user"
,
"content"
:
...
...
tests/tracing/test_tracing.py
View file @
cf069aa8
...
...
@@ -2,8 +2,9 @@
import
os
import
threading
from
collections.abc
import
Iterable
from
concurrent
import
futures
from
typing
import
Callable
,
Dict
,
Iterable
,
Literal
from
typing
import
Callable
,
Literal
import
grpc
import
pytest
...
...
@@ -25,7 +26,7 @@ FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
def
decode_value
(
value
:
AnyValue
):
field_decoders
:
D
ict
[
FieldName
,
Callable
]
=
{
field_decoders
:
d
ict
[
FieldName
,
Callable
]
=
{
"bool_value"
:
(
lambda
v
:
v
.
bool_value
),
"string_value"
:
(
lambda
v
:
v
.
string_value
),
"int_value"
:
(
lambda
v
:
v
.
int_value
),
...
...
tests/utils.py
View file @
cf069aa8
...
...
@@ -11,7 +11,7 @@ import time
import
warnings
from
contextlib
import
contextmanager
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
openai
import
pytest
...
...
@@ -73,9 +73,9 @@ class RemoteOpenAIServer:
def
__init__
(
self
,
model
:
str
,
vllm_serve_args
:
L
ist
[
str
],
vllm_serve_args
:
l
ist
[
str
],
*
,
env_dict
:
Optional
[
D
ict
[
str
,
str
]]
=
None
,
env_dict
:
Optional
[
d
ict
[
str
,
str
]]
=
None
,
auto_port
:
bool
=
True
,
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
if
auto_port
:
...
...
@@ -183,7 +183,7 @@ def _test_completion(
client
:
openai
.
OpenAI
,
model
:
str
,
prompt
:
str
,
token_ids
:
L
ist
[
int
],
token_ids
:
l
ist
[
int
],
):
results
=
[]
...
...
@@ -400,10 +400,10 @@ def _test_image_text(
def
compare_two_settings
(
model
:
str
,
arg1
:
L
ist
[
str
],
arg2
:
L
ist
[
str
],
env1
:
Optional
[
D
ict
[
str
,
str
]]
=
None
,
env2
:
Optional
[
D
ict
[
str
,
str
]]
=
None
,
arg1
:
l
ist
[
str
],
arg2
:
l
ist
[
str
],
env1
:
Optional
[
d
ict
[
str
,
str
]]
=
None
,
env2
:
Optional
[
d
ict
[
str
,
str
]]
=
None
,
*
,
method
:
str
=
"generate"
,
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
...
...
@@ -429,8 +429,8 @@ def compare_two_settings(model: str,
def
compare_all_settings
(
model
:
str
,
all_args
:
L
ist
[
L
ist
[
str
]],
all_envs
:
L
ist
[
Optional
[
D
ict
[
str
,
str
]]],
all_args
:
l
ist
[
l
ist
[
str
]],
all_envs
:
l
ist
[
Optional
[
d
ict
[
str
,
str
]]],
*
,
method
:
str
=
"generate"
,
max_wait_seconds
:
Optional
[
float
]
=
None
)
->
None
:
...
...
@@ -470,7 +470,7 @@ def compare_all_settings(model: str,
prompt
=
"Hello, my name is"
token_ids
=
tokenizer
(
prompt
).
input_ids
ref_results
:
L
ist
=
[]
ref_results
:
l
ist
=
[]
for
i
,
(
args
,
env
)
in
enumerate
(
zip
(
all_args
,
all_envs
)):
if
can_force_load_format
:
# we are comparing the results and
...
...
@@ -481,7 +481,7 @@ def compare_all_settings(model: str,
# environment variable to force the load format,
# e.g. in quantization tests.
args
=
args
+
[
"--load-format"
,
envs
.
VLLM_TEST_FORCE_LOAD_FORMAT
]
compare_results
:
L
ist
=
[]
compare_results
:
l
ist
=
[]
results
=
ref_results
if
i
==
0
else
compare_results
with
RemoteOpenAIServer
(
model
,
args
,
...
...
@@ -582,7 +582,7 @@ def multi_process_parallel(
@
contextmanager
def
error_on_warning
(
category
:
T
ype
[
Warning
]
=
Warning
):
def
error_on_warning
(
category
:
t
ype
[
Warning
]
=
Warning
):
"""
Within the scope of this context manager, tests will fail if any warning
of the given category is emitted.
...
...
@@ -604,7 +604,7 @@ def get_physical_device_indices(devices):
@
_nvml
()
def
wait_for_gpu_memory_to_clear
(
devices
:
L
ist
[
int
],
def
wait_for_gpu_memory_to_clear
(
devices
:
l
ist
[
int
],
threshold_bytes
:
int
,
timeout_s
:
float
=
120
)
->
None
:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
...
...
@@ -612,8 +612,8 @@ def wait_for_gpu_memory_to_clear(devices: List[int],
devices
=
get_physical_device_indices
(
devices
)
start_time
=
time
.
time
()
while
True
:
output
:
D
ict
[
int
,
str
]
=
{}
output_raw
:
D
ict
[
int
,
float
]
=
{}
output
:
d
ict
[
int
,
str
]
=
{}
output_raw
:
d
ict
[
int
,
float
]
=
{}
for
device
in
devices
:
if
current_platform
.
is_rocm
():
dev_handle
=
amdsmi_get_processor_handles
()[
device
]
...
...
@@ -758,13 +758,13 @@ def multi_gpu_test(*, num_gpus: int):
async
def
completions_with_server_args
(
prompts
:
L
ist
[
str
],
prompts
:
l
ist
[
str
],
model_name
:
str
,
server_cli_args
:
L
ist
[
str
],
server_cli_args
:
l
ist
[
str
],
num_logprobs
:
Optional
[
int
],
max_wait_seconds
:
int
=
240
,
max_tokens
:
Union
[
int
,
list
]
=
5
,
)
->
L
ist
[
Completion
]:
)
->
l
ist
[
Completion
]:
'''Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.
...
...
@@ -807,7 +807,7 @@ async def completions_with_server_args(
return
outputs
def
get_client_text_generations
(
completions
:
L
ist
[
Completion
])
->
L
ist
[
str
]:
def
get_client_text_generations
(
completions
:
l
ist
[
Completion
])
->
l
ist
[
str
]:
'''Extract generated tokens from the output of a
request made to an Open-AI-protocol completions endpoint.
'''
...
...
@@ -816,7 +816,7 @@ def get_client_text_generations(completions: List[Completion]) -> List[str]:
def
get_client_text_logprob_generations
(
completions
:
L
ist
[
Completion
])
->
L
ist
[
TextTextLogprobs
]:
completions
:
l
ist
[
Completion
])
->
l
ist
[
TextTextLogprobs
]:
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
...
...
tests/v1/core/test_prefix_caching.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
"""Compare the with and without prefix caching."""
from
typing
import
List
import
pytest
...
...
@@ -434,7 +433,7 @@ def test_cache_blocks():
# Test that blocks are cached correctly for 2 full blocks from the start.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
block_hashes
:
L
ist
[
BlockHashType
]
=
[]
block_hashes
:
l
ist
[
BlockHashType
]
=
[]
block_pool
.
cache_full_blocks
(
request
=
req
,
...
...
tests/v1/core/test_scheduler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
from
typing
import
Optional
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
...
...
@@ -48,9 +48,9 @@ def create_scheduler(
def
create_requests
(
num_requests
:
int
,
num_tokens
:
int
=
10
,
mm_positions
:
Optional
[
L
ist
[
PlaceholderRange
]]
=
None
,
mm_positions
:
Optional
[
l
ist
[
PlaceholderRange
]]
=
None
,
max_tokens
:
int
=
16
,
stop_token_ids
:
Optional
[
L
ist
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
l
ist
[
int
]]
=
None
,
):
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
...
...
tests/v1/engine/conftest.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Tuple
import
pytest
import
torch
from
transformers
import
AutoTokenizer
...
...
@@ -17,8 +15,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from
tests.v1.engine.utils
import
FULL_STRINGS
# isort: skip
EngineCoreSampleLogprobsType
=
L
ist
[
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]]
EngineCorePromptLogprobsType
=
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]
EngineCoreSampleLogprobsType
=
l
ist
[
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]]
EngineCorePromptLogprobsType
=
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]
def
_build_test_vectors_no_logprobs
()
->
DummyOutputProcessorTestVectors
:
...
...
tests/v1/engine/test_async_llm.py
View file @
cf069aa8
...
...
@@ -2,7 +2,7 @@
import
asyncio
from
contextlib
import
ExitStack
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
...
...
@@ -47,7 +47,7 @@ async def generate(engine: AsyncLLM,
prompt
:
PromptType
,
output_kind
:
RequestOutputKind
,
max_tokens
:
int
,
prompt_logprobs
:
Optional
[
int
]
=
None
)
->
T
uple
[
int
,
str
]:
prompt_logprobs
:
Optional
[
int
]
=
None
)
->
t
uple
[
int
,
str
]:
# Ensure generate doesn't complete too fast for cancellation test.
await
asyncio
.
sleep
(
0.2
)
...
...
@@ -114,7 +114,7 @@ async def test_async_llm_refuses_prompt_logprobs_with_apc(
(
VISION_ENGINE_ARGS
,
VISION_PROMPT
)])
@
pytest
.
mark
.
asyncio
async
def
test_load
(
monkeypatch
,
output_kind
:
RequestOutputKind
,
engine_args_and_prompt
:
T
uple
[
AsyncEngineArgs
,
engine_args_and_prompt
:
t
uple
[
AsyncEngineArgs
,
PromptType
]):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the
...
...
@@ -160,7 +160,7 @@ async def test_load(monkeypatch, output_kind: RequestOutputKind,
(
VISION_ENGINE_ARGS
,
VISION_PROMPT
)])
@
pytest
.
mark
.
asyncio
async
def
test_abort
(
monkeypatch
,
output_kind
:
RequestOutputKind
,
engine_args_and_prompt
:
T
uple
[
AsyncEngineArgs
,
engine_args_and_prompt
:
t
uple
[
AsyncEngineArgs
,
PromptType
]):
with
monkeypatch
.
context
()
as
m
,
ExitStack
()
as
after
:
...
...
@@ -177,7 +177,7 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
request_ids
=
[
f
"request-
{
i
}
"
for
i
in
range
(
NUM_REQUESTS
)]
# Create concurrent requests.
tasks
:
L
ist
[
asyncio
.
Task
]
=
[]
tasks
:
l
ist
[
asyncio
.
Task
]
=
[]
for
request_id
in
request_ids
:
tasks
.
append
(
asyncio
.
create_task
(
...
...
tests/v1/engine/test_engine_core.py
View file @
cf069aa8
...
...
@@ -5,7 +5,6 @@ import threading
import
time
import
uuid
from
concurrent.futures
import
Future
from
typing
import
List
import
pytest
from
transformers
import
AutoTokenizer
...
...
@@ -213,7 +212,7 @@ def test_engine_core_concurrent_batches(monkeypatch):
class
DummyExecutor
(
UniProcExecutor
):
def
initialize_from_config
(
self
,
kv_cache_configs
:
L
ist
[
KVCacheConfig
])
->
None
:
self
,
kv_cache_configs
:
l
ist
[
KVCacheConfig
])
->
None
:
super
().
initialize_from_config
(
kv_cache_configs
)
# This executor actually can only run 1 batch at a time
...
...
tests/v1/engine/test_engine_core_client.py
View file @
cf069aa8
...
...
@@ -3,7 +3,7 @@
import
asyncio
import
time
import
uuid
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
pytest
from
transformers
import
AutoTokenizer
...
...
@@ -44,7 +44,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
)
def
loop_until_done
(
client
:
EngineCoreClient
,
outputs
:
D
ict
):
def
loop_until_done
(
client
:
EngineCoreClient
,
outputs
:
d
ict
):
while
True
:
engine_core_outputs
=
client
.
get_output
().
outputs
...
...
@@ -62,7 +62,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
break
async
def
loop_until_done_async
(
client
:
EngineCoreClient
,
outputs
:
D
ict
):
async
def
loop_until_done_async
(
client
:
EngineCoreClient
,
outputs
:
d
ict
):
while
True
:
engine_core_outputs
=
(
await
client
.
get_output_async
()).
outputs
...
...
@@ -121,7 +121,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client
.
add_request
(
request
)
time
.
sleep
(
0.01
)
outputs
:
D
ict
[
str
,
L
ist
]
=
{
req_id
:
[]
for
req_id
in
request_ids
}
outputs
:
d
ict
[
str
,
l
ist
]
=
{
req_id
:
[]
for
req_id
in
request_ids
}
loop_until_done
(
client
,
outputs
)
for
req_id
in
request_ids
:
...
...
@@ -207,7 +207,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
await
client
.
add_request_async
(
request
)
await
asyncio
.
sleep
(
0.01
)
outputs
:
D
ict
[
str
,
L
ist
]
=
{
req_id
:
[]
for
req_id
in
request_ids
}
outputs
:
d
ict
[
str
,
l
ist
]
=
{
req_id
:
[]
for
req_id
in
request_ids
}
await
loop_until_done_async
(
client
,
outputs
)
for
req_id
in
request_ids
:
...
...
tests/v1/engine/test_llm_engine.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
...
...
@@ -47,9 +47,9 @@ def vllm_model_apc(vllm_runner, monkeypatch):
def
_get_test_sampling_params
(
prompt_list
:
L
ist
[
str
],
prompt_list
:
l
ist
[
str
],
seed
:
Optional
[
int
]
=
42
,
)
->
T
uple
[
L
ist
[
SamplingParams
],
L
ist
[
int
]]:
)
->
t
uple
[
l
ist
[
SamplingParams
],
l
ist
[
int
]]:
"""Generate random sampling params for a batch."""
def
get_mostly_n_gt1
()
->
int
:
...
...
@@ -81,7 +81,7 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None:
# Validate each request response
for
out
,
n
in
zip
(
outputs
,
n_list
):
completion_counts
:
D
ict
[
str
,
int
]
=
{}
completion_counts
:
d
ict
[
str
,
int
]
=
{}
# Assert correct number of completions
assert
len
(
out
.
outputs
)
==
n
,
(
f
"
{
len
(
out
.
outputs
)
}
completions;
{
n
}
expected."
)
...
...
tests/v1/engine/test_output_processor.py
View file @
cf069aa8
...
...
@@ -2,7 +2,7 @@
import
math
import
time
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
pytest
...
...
@@ -112,12 +112,12 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
def
_validate_logprobs
(
gen_tokens
:
D
ict
[
str
,
L
ist
[
int
]],
gen_logprobs
:
D
ict
[
str
,
Optional
[
SampleLogprobs
]],
gen_prompt_logprobs
:
D
ict
[
str
,
Optional
[
PromptLogprobs
]],
gen_cumulative_logprob
:
D
ict
[
str
,
float
],
gen_tokens
:
d
ict
[
str
,
l
ist
[
int
]],
gen_logprobs
:
d
ict
[
str
,
Optional
[
SampleLogprobs
]],
gen_prompt_logprobs
:
d
ict
[
str
,
Optional
[
PromptLogprobs
]],
gen_cumulative_logprob
:
d
ict
[
str
,
float
],
dtv
:
DummyOutputProcessorTestVectors
,
request_id_list
:
L
ist
[
str
],
request_id_list
:
l
ist
[
str
],
num_sample_logprobs
:
Optional
[
int
],
num_prompt_logprobs
:
Optional
[
int
],
)
->
None
:
...
...
tests/v1/engine/utils.py
View file @
cf069aa8
...
...
@@ -2,7 +2,7 @@
import
random
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
torch
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
...
...
@@ -61,7 +61,7 @@ def _create_random_top_logprob_test_vector(
def
_create_random_top_logprob_test_matrix
(
shape
:
T
uple
,
shape
:
t
uple
,
lower
:
float
,
upper
:
float
,
)
->
torch
.
Tensor
:
...
...
@@ -90,7 +90,7 @@ def _create_random_top_token_test_vector(
lower
:
int
,
upper
:
int
,
sampled_token_id
:
int
,
adjust_num_logprobs
:
bool
=
True
)
->
T
uple
[
torch
.
Tensor
,
int
]:
adjust_num_logprobs
:
bool
=
True
)
->
t
uple
[
torch
.
Tensor
,
int
]:
"""Create a random vector of top logprob token indices
Use to create fake sample logprobs for testing. The sampled token
...
...
@@ -141,11 +141,11 @@ def _create_random_top_token_test_vector(
def
_create_random_top_token_test_matrix
(
shape
:
T
uple
[
int
,
int
],
shape
:
t
uple
[
int
,
int
],
lower
:
int
,
upper
:
int
,
tokens_list
:
L
ist
[
int
],
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
tokens_list
:
l
ist
[
int
],
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Create a random matrix of top logprob token indices
Use to create fake prompt logprobs for testing.
...
...
@@ -160,7 +160,7 @@ def _create_random_top_token_test_matrix(
upper: upper range of token ids
Returns:
T
uple containing:
t
uple containing:
- 2D num_tokens x num_logprobs+1 torch Tensor of token ids
- 1D tensor of ranks of prompt tokens in their respective
rows, or random values
...
...
@@ -206,10 +206,10 @@ def decode_token(
def
generate_dummy_sample_logprobs
(
sampled_tokens_list
:
L
ist
,
sampled_tokens_list
:
l
ist
,
num_logprobs
:
int
,
tokenizer
:
PreTrainedTokenizer
,
)
->
L
ist
[
T
uple
[
L
ist
[
int
],
L
ist
[
float
],
int
]]:
)
->
l
ist
[
t
uple
[
l
ist
[
int
],
l
ist
[
float
],
int
]]:
"""Generate dummy sample logprobs
Generate a test data structure which imitates the list of sample logprobs
...
...
@@ -221,7 +221,7 @@ def generate_dummy_sample_logprobs(
tokenizer: model tokenizer to use for detokenization
Returns
L
ist of (top token ids vector, logprobs vector, sampled token rank)
l
ist of (top token ids vector, logprobs vector, sampled token rank)
Python lists tuples; in each tuple the logprobs and top token ids
vectors have the same length which is either `num_logprobs` or
`num_logprobs+1`. Sampled token rank is the rank (index+1) of the
...
...
@@ -253,7 +253,7 @@ def generate_dummy_sample_logprobs(
def
generate_dummy_prompt_logprobs_tensors
(
prompt_tokens_list
:
L
ist
,
prompt_tokens_list
:
l
ist
,
num_logprobs
:
int
,
tokenizer
:
PreTrainedTokenizer
,
)
->
LogprobsTensors
:
...
...
@@ -269,7 +269,7 @@ def generate_dummy_prompt_logprobs_tensors(
tokenizer: model tokenizer to use for detokenization
Returns
Single
T
uple of (logprobs matrix, top token ids matrix) torch Tensor,
Single
t
uple of (logprobs matrix, top token ids matrix) torch Tensor,
where both matrices have dimensions
num_prompt_tokens x num_logprobs
"""
...
...
@@ -301,19 +301,19 @@ class DummyOutputProcessorTestVectors:
tokenizer
:
GeneralTokenizerType
tokenizer_group
:
BaseTokenizerGroup
vllm_config
:
EngineArgs
full_tokens
:
L
ist
[
L
ist
[
int
]]
# Prompt + generated tokens
prompt_tokens
:
L
ist
[
L
ist
[
int
]]
generation_tokens
:
L
ist
[
L
ist
[
int
]]
full_tokens
:
l
ist
[
l
ist
[
int
]]
# Prompt + generated tokens
prompt_tokens
:
l
ist
[
l
ist
[
int
]]
generation_tokens
:
l
ist
[
l
ist
[
int
]]
# Each request is associated with a tuple of
# (top tokens, top logprobs, ranks) prompt logprobs tensors
prompt_logprobs
:
L
ist
[
LogprobsTensors
]
prompt_logprobs
:
l
ist
[
LogprobsTensors
]
# Each request is associated with a sample logprobs; a request's
# sample logprobs are a list of (top tokens, top logprobs, ranks)
# sample logprobs tensors at each sequence position
generation_logprobs
:
L
ist
[
L
ist
[
T
uple
[
L
ist
[
int
],
L
ist
[
float
],
int
]]]
prompt_strings
:
L
ist
[
str
]
prompt_strings_len
:
L
ist
[
int
]
generation_strings
:
L
ist
[
str
]
generation_logprobs
:
l
ist
[
l
ist
[
t
uple
[
l
ist
[
int
],
l
ist
[
float
],
int
]]]
prompt_strings
:
l
ist
[
str
]
prompt_strings_len
:
l
ist
[
int
]
generation_strings
:
l
ist
[
str
]
class
MockEngineCore
:
...
...
@@ -321,18 +321,18 @@ class MockEngineCore:
def
__init__
(
self
,
tokens_list
:
L
ist
[
L
ist
[
int
]],
tokens_list
:
l
ist
[
l
ist
[
int
]],
# For each request, for each sampled token offset,
# a tuple of
# (list of topk token ids, list of sample logprob vals, rank)
generated_logprobs_raw
:
Optional
[
L
ist
[
L
ist
[
T
uple
[
L
ist
[
int
],
L
ist
[
float
],
generated_logprobs_raw
:
Optional
[
l
ist
[
l
ist
[
t
uple
[
l
ist
[
int
],
l
ist
[
float
],
int
]]]]
=
None
,
# For each request, a tuple of
# (prompt logprob val matrix, prompt logprob tok id matrix);
# each matrix has dimensions
# (num prompt toks) x (num prompt logprobs+1)
prompt_logprobs_raw
:
Optional
[
L
ist
[
LogprobsTensors
]]
=
None
,
prompt_logprobs_raw
:
Optional
[
l
ist
[
LogprobsTensors
]]
=
None
,
)
->
None
:
self
.
tokens_list
=
tokens_list
self
.
current_idx
=
0
...
...
@@ -341,7 +341,7 @@ class MockEngineCore:
self
.
prompt_logprobs_raw
=
prompt_logprobs_raw
self
.
do_prompt_logprobs
=
prompt_logprobs_raw
is
not
None
def
get_outputs
(
self
)
->
L
ist
[
EngineCoreOutput
]:
def
get_outputs
(
self
)
->
l
ist
[
EngineCoreOutput
]:
do_logprobs
=
self
.
do_logprobs
do_prompt_logprobs
=
self
.
do_prompt_logprobs
token_idx
=
self
.
current_idx
...
...
tests/v1/entrypoints/openai/test_completion.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
re
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Optional
import
openai
# use the official client for correctness check
import
pytest
...
...
@@ -193,7 +193,7 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
async
def
test_prompt_logprobs_completion
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
prompt_logprobs
:
Optional
[
int
]):
params
:
D
ict
=
{
params
:
d
ict
=
{
"prompt"
:
[
"A robot may not injure another robot"
,
"My name is"
],
"model"
:
model_name
,
}
...
...
@@ -237,7 +237,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
max_tokens
=
5
,
temperature
=
0.0
,
stream
=
True
)
chunks
:
L
ist
[
str
]
=
[]
chunks
:
l
ist
[
str
]
=
[]
finish_reason_count
=
0
async
for
chunk
in
stream
:
chunks
.
append
(
chunk
.
choices
[
0
].
text
)
...
...
@@ -278,7 +278,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
num_completions
=
len
(
completion
.
choices
)
assert
num_completions
==
n
,
(
f
"Num completions
{
num_completions
}
but expected
{
n
}
."
)
completion_repeats
:
D
ict
[
str
,
int
]
=
{}
completion_repeats
:
d
ict
[
str
,
int
]
=
{}
for
idx
,
choice
in
enumerate
(
completion
.
choices
):
# Assert correct completion index & some finish reason.
assert
choice
.
index
==
idx
,
(
...
...
@@ -321,7 +321,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
temperature
=
0.95
,
stream
=
True
,
seed
=
42
)
chunks
:
L
ist
[
L
ist
[
str
]]
=
[[]
for
i
in
range
(
n
)]
chunks
:
l
ist
[
l
ist
[
str
]]
=
[[]
for
i
in
range
(
n
)]
finish_reason_count
=
0
async
for
chunk
in
stream
:
index
=
chunk
.
choices
[
0
].
index
...
...
@@ -332,7 +332,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
# Assert `n` completions with correct finish reasons
assert
finish_reason_count
==
n
,
(
f
"Expected
{
n
}
completions with valid indices and finish_reason."
)
completion_repeats
:
D
ict
[
str
,
int
]
=
{}
completion_repeats
:
d
ict
[
str
,
int
]
=
{}
for
chunk
in
chunks
:
chunk_len
=
len
(
chunk
)
# Assert correct number of completion tokens
...
...
tests/v1/sample/test_logprobs.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
itertools
from
typing
import
List
,
Tuple
import
pytest
import
torch
...
...
@@ -46,8 +45,8 @@ def hf_model(hf_runner):
def
_repeat_logprob_config
(
test_prompts
,
logprob_prompt_logprob_list
:
L
ist
[
T
uple
],
)
->
L
ist
[
T
uple
]:
logprob_prompt_logprob_list
:
l
ist
[
t
uple
],
)
->
l
ist
[
t
uple
]:
"""Ensure each test prompt has a logprob config.
A logprob config specifies the optional (i.e.
...
...
@@ -74,7 +73,7 @@ def _repeat_logprob_config(
tuples
Returns:
L
ist of
l
ist of
(optional num sample logprob,optional num prompt logprob)
tuples which is either identical to
`logprob_prompt_logprob_list`, or else repeats
...
...
@@ -177,7 +176,7 @@ def _test_case_get_logprobs_and_prompt_logprobs(
for
r
in
range
(
1
,
num_top_logprobs
+
1
))
output_text
=
vllm_result
.
outputs
[
0
].
text
output_string_from_most_likely_tokens_lst
:
L
ist
[
str
]
=
[]
output_string_from_most_likely_tokens_lst
:
l
ist
[
str
]
=
[]
for
top_logprobs
in
vllm_result
.
outputs
[
0
].
logprobs
:
top_logprob
=
next
(
iter
(
top_logprobs
.
values
()))
output_string_from_most_likely_tokens_lst
.
append
(
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
import
pytest
import
torch
...
...
@@ -13,7 +12,7 @@ def sampler():
return
RejectionSampler
()
def
create_logits_tensor
(
token_ids
:
L
ist
[
int
],
def
create_logits_tensor
(
token_ids
:
l
ist
[
int
],
vocab_size
:
int
=
100
)
->
torch
.
Tensor
:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
...
...
@@ -23,7 +22,7 @@ def create_logits_tensor(token_ids: List[int],
return
logits
def
create_sampling_metadata
(
spec_tokens
:
L
ist
[
L
ist
[
int
]])
->
SamplingMetadata
:
def
create_sampling_metadata
(
spec_tokens
:
l
ist
[
l
ist
[
int
]])
->
SamplingMetadata
:
batch_size
=
len
(
spec_tokens
)
return
SamplingMetadata
(
temperature
=
torch
.
tensor
([]),
...
...
@@ -106,7 +105,7 @@ def test_single_token_sequence(sampler):
def
test_empty_sequence
(
sampler
):
"""Test handling empty sequence of speculated tokens"""
spec_tokens
:
L
ist
[
L
ist
[
int
]]
=
[[]]
spec_tokens
:
l
ist
[
l
ist
[
int
]]
=
[[]]
output_tokens
=
[
5
]
# Just the bonus token
metadata
=
create_sampling_metadata
(
spec_tokens
)
...
...
tests/v1/sample/test_sampler.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Optional
import
numpy
as
np
import
pytest
...
...
@@ -32,7 +32,7 @@ def _create_penalty_tensor(batch_size: int, penalty_value: float,
def
_create_prompt_tokens_tensor
(
prompt_token_ids
:
L
ist
[
L
ist
[
int
]],
prompt_token_ids
:
l
ist
[
l
ist
[
int
]],
vocab_size
:
int
,
device
:
torch
.
device
,
)
->
torch
.
Tensor
:
...
...
@@ -49,8 +49,8 @@ def _create_logit_bias(
batch_size
:
int
,
vocab_size
:
int
,
bias_value
:
float
,
)
->
L
ist
[
Optional
[
D
ict
[
int
,
float
]]]:
res
:
L
ist
[
Optional
[
D
ict
[
int
,
float
]]]
=
[]
)
->
l
ist
[
Optional
[
d
ict
[
int
,
float
]]]:
res
:
l
ist
[
Optional
[
d
ict
[
int
,
float
]]]
=
[]
for
i
in
range
(
batch_size
):
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
res
.
append
(
logit_bias
)
...
...
@@ -83,8 +83,8 @@ def _create_default_sampling_metadata(
vocab_size
:
int
,
device
:
torch
.
device
,
)
->
SamplingMetadata
:
output_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
prompt_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
output_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
prompt_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
for
_
in
range
(
batch_size
):
output_token_ids
.
append
(
np
.
random
.
randint
(
0
,
vocab_size
,
size
=
num_output_tokens
).
tolist
())
...
...
@@ -118,8 +118,8 @@ def _create_default_sampling_metadata(
def
_generate_min_token_penalties_and_stop_tokens
(
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
batch_indices_for_min_token_penalty
:
L
ist
[
int
]
)
->
D
ict
[
int
,
T
uple
[
int
,
S
et
[
int
]]]:
batch_indices_for_min_token_penalty
:
l
ist
[
int
]
)
->
d
ict
[
int
,
t
uple
[
int
,
s
et
[
int
]]]:
"""
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
...
...
@@ -130,7 +130,7 @@ def _generate_min_token_penalties_and_stop_tokens(
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
min_tokens
:
D
ict
[
int
,
T
uple
[
int
,
S
et
[
int
]]]
=
{}
min_tokens
:
d
ict
[
int
,
t
uple
[
int
,
s
et
[
int
]]]
=
{}
for
index
in
range
(
batch_size
):
if
index
in
batch_indices_for_min_token_penalty
:
min_tokens
[
index
]
=
(
...
...
@@ -147,7 +147,7 @@ def _generate_min_token_penalties_and_stop_tokens(
def
_create_weighted_output_token_list
(
batch_size
:
int
,
vocab_size
:
int
)
->
T
uple
[
L
ist
[
L
ist
[
int
]],
L
ist
[
L
ist
[
int
]]]:
vocab_size
:
int
)
->
t
uple
[
l
ist
[
l
ist
[
int
]],
l
ist
[
l
ist
[
int
]]]:
"""
Creates an output token list where each token occurs a distinct
number of times.
...
...
@@ -157,7 +157,7 @@ def _create_weighted_output_token_list(
list, each with a different frequency.
Returns:
T
uple[
L
ist[
L
ist[int]],
L
ist[
L
ist[int]]]:
t
uple[
l
ist[
l
ist[int]],
l
ist[
l
ist[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
...
...
@@ -165,8 +165,8 @@ def _create_weighted_output_token_list(
batch, ordered by their frequency in the corresponding output
list.
"""
output_token_ids
:
L
ist
[
L
ist
[
int
]]
=
[]
sorted_token_ids_in_output
:
L
ist
[
L
ist
[
int
]]
=
[]
output_token_ids
:
l
ist
[
l
ist
[
int
]]
=
[]
sorted_token_ids_in_output
:
l
ist
[
l
ist
[
int
]]
=
[]
for
_
in
range
(
batch_size
):
distinct_token_ids
=
np
.
random
.
choice
(
vocab_size
,
size
=
np
.
random
.
randint
(
1
,
10
),
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment