Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99324e25
Commit
99324e25
authored
Jul 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.2' into v0.9.2-ori
parents
cc7f22a8
a5dd03c1
Changes
475
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1543 additions
and
205 deletions
+1543
-205
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+302
-3
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+2
-2
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+6
-6
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+2
-1
tests/v1/sample/test_logits_processors.py
tests/v1/sample/test_logits_processors.py
+627
-0
tests/v1/sample/test_logprobs_e2e.py
tests/v1/sample/test_logprobs_e2e.py
+4
-3
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+7
-7
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+5
-149
tests/v1/sample/test_topk_topp_sampler.py
tests/v1/sample/test_topk_topp_sampler.py
+7
-6
tests/v1/sample/utils.py
tests/v1/sample/utils.py
+80
-1
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+13
-10
tests/v1/test_async_llm_dp.py
tests/v1/test_async_llm_dp.py
+43
-8
tests/v1/test_external_lb_dp.py
tests/v1/test_external_lb_dp.py
+312
-0
tests/v1/test_oracle.py
tests/v1/test_oracle.py
+1
-8
tests/v1/test_request.py
tests/v1/test_request.py
+16
-0
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+37
-0
tests/v1/tpu/test_kv_cache_update_kernel.py
tests/v1/tpu/test_kv_cache_update_kernel.py
+75
-0
tests/v1/tpu/test_pallas.py
tests/v1/tpu/test_pallas.py
+2
-1
tests/v1/tpu/test_spmd_model_weight_loading.py
tests/v1/tpu/test_spmd_model_weight_loading.py
+1
-0
tests/v1/tpu/test_tpu_qkv_linear.py
tests/v1/tpu/test_tpu_qkv_linear.py
+1
-0
No files found.
Too many changes to show.
To preserve performance only
475 of 475+
files are displayed.
Plain diff
Email patch
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
uuid
from
collections
import
defaultdict
from
typing
import
Optional
from
unittest.mock
import
patch
import
pytest
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorMetadata
)
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
NixlConnectorWorker
)
from
vllm.forward_context
import
ForwardContext
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
def
test_basic_in
f
erface
():
def
test_basic_in
t
erface
():
"""Unit test for basic NixlConnector interface functionality."""
vllm_config
=
create_vllm_config
()
...
...
@@ -25,7 +35,7 @@ def test_basic_inferface():
scheduler
.
add_request
(
request
)
# Remote Prefill, triggers NixlConnectorMetdata.
# Remote Prefill, triggers NixlConnectorMet
a
data.
scheduler_output
=
scheduler
.
schedule
()
kv_connector_metadata
=
scheduler_output
.
kv_connector_metadata
assert
kv_connector_metadata
is
not
None
...
...
@@ -72,3 +82,292 @@ def test_prompt_less_than_block_size():
# This request should be scheduled regularly.
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
class
FakeNixlWrapper
:
"""Mock implementation of NixlWrapper for testing.
We don't inherit from nixl._api.nixl_agent because nixl may not be
installed.
"""
AGENT_METADATA
=
b
"fake_agent_metadata"
REMOTE_AGENT_NAME
=
"remote_agent"
def
__init__
(
self
,
agent_name
:
str
,
*
args
,
**
kwargs
):
self
.
_cycles_before_xfer_done
=
0
self
.
_check_xfer_state_cycles
:
defaultdict
[
int
,
int
]
=
defaultdict
(
lambda
:
0
)
def
get_reg_descs
(
self
,
caches_data
,
memory_type
:
str
)
->
list
:
return
[
str
(
uuid
.
uuid4
())
for
_
in
caches_data
]
def
register_memory
(
self
,
descs
)
->
None
:
pass
def
get_xfer_descs
(
self
,
blocks_data
,
memory_type
:
str
)
->
list
:
return
[
str
(
uuid
.
uuid4
())
for
_
in
blocks_data
]
def
prep_xfer_dlist
(
self
,
agent_name
:
str
,
descs
:
list
)
->
int
:
return
uuid
.
uuid4
().
int
def
get_agent_metadata
(
self
)
->
bytes
:
return
self
.
AGENT_METADATA
def
add_remote_agent
(
self
,
agent_metadata
:
bytes
)
->
str
:
return
self
.
REMOTE_AGENT_NAME
def
get_new_notifs
(
self
)
->
dict
[
str
,
list
[
bytes
]]:
# Used to collect done_sending, which we don't test yet.
return
{}
def
check_xfer_state
(
self
,
handle
:
int
)
->
str
:
if
self
.
_check_xfer_state_cycles
[
handle
]
>=
self
.
_cycles_before_xfer_done
:
return
"DONE"
self
.
_check_xfer_state_cycles
[
handle
]
+=
1
return
"PROC"
def
release_xfer_handle
(
self
,
handle
:
int
)
->
None
:
pass
def
send_notif
(
self
,
agent_name
:
str
,
notif_msg
:
bytes
)
->
None
:
pass
def
make_prepped_xfer
(
self
,
xfer_type
:
str
,
local_xfer_side_handle
:
int
,
local_block_descs_ids
:
list
[
int
],
remote_xfer_side_handle
:
int
,
remote_block_descs_ids
:
list
[
int
],
notif_msg
:
Optional
[
bytes
]
=
None
)
->
int
:
return
uuid
.
uuid4
().
int
def
transfer
(
self
,
handle
:
int
)
->
str
:
return
"PROC"
############################################################
# Follow are for changing the behavior during testing.
############################################################
def
set_cycles_before_xfer_done
(
self
,
cycles
:
int
):
"""Set the number of cycles before a transfer is considered done."""
self
.
_cycles_before_xfer_done
=
cycles
class
FakeNixlConnectorWorker
(
NixlConnectorWorker
):
REMOTE_ENGINE_ID
=
"remote_engine"
def
__init__
(
self
,
*
args
,
hand_shake_latency
:
float
=
1.8
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
_hand_shake_latency
=
hand_shake_latency
def
_nixl_handshake
(
self
,
host
:
str
,
port
:
int
,
remote_tp_size
:
int
)
->
dict
[
int
,
str
]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time
.
sleep
(
self
.
_hand_shake_latency
)
# These should've been done in register_kv_caches(), called by
# gpu_model_runner. Here we just hardcode some dummy values.
self
.
slot_size_bytes
=
4096
self
.
block_len
=
self
.
slot_size_bytes
*
self
.
block_size
self
.
num_blocks
=
1
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
remote_agent_name
=
self
.
add_remote_agent
(
NixlAgentMetadata
(
engine_id
=
self
.
REMOTE_ENGINE_ID
,
agent_metadata
=
FakeNixlWrapper
.
AGENT_METADATA
,
kv_caches_base_addr
=
[
0
],
num_blocks
=
1
,
block_len
=
self
.
block_len
,
attn_backend_name
=
self
.
backend_name
,
),
remote_tp_size
=
remote_tp_size
)
return
{
0
:
remote_agent_name
}
class
TestNixlHandshake
:
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
)
def
test_multi_xfer_one_engine
(
self
,
# dist_init is a fixture that initializes the distributed environment.
dist_init
):
"""Test case where multiple xfers are initiated to the same engine.
This test triggers the connector to load remote KV for the same
`request_id`. The transfer is not done immediately due to
`set_cycles_before_xfer_done`, so there is a state where there are
multiple transfer states for the same `request_id`, and `get_finished`
should handle it correctly (wait for all transfers to be done).
"""
vllm_config
=
create_vllm_config
()
request_id
=
"req_id"
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
assert
isinstance
(
connector
.
connector_worker
.
nixl_wrapper
,
FakeNixlWrapper
)
connector
.
connector_worker
.
nixl_wrapper
.
set_cycles_before_xfer_done
(
3
)
num_xfers
=
4
while
True
:
# For the same request_id, initiate multiple xfers across different
# round of `execute_model` calls.
metadata
=
NixlConnectorMetadata
()
if
num_xfers
>
0
:
num_xfers
-=
1
metadata
.
add_new_req
(
request_id
=
request_id
,
local_block_ids
=
[
num_xfers
+
1
,
num_xfers
+
2
,
num_xfers
+
3
],
kv_transfer_params
=
{
"remote_block_ids"
:
[
num_xfers
+
4
,
num_xfers
+
5
,
num_xfers
+
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
})
connector
.
bind_connector_metadata
(
metadata
)
# Mimic maybe_setup_kv_connector in gpu_model_runner.
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
)
_before_load
=
time
.
perf_counter
()
connector
.
start_load_kv
(
dummy_ctx
)
_after_load
=
time
.
perf_counter
()
assert
_after_load
-
_before_load
<
0.1
,
"start_load_kv took "
\
f
"
{
_after_load
-
_before_load
}
seconds"
# Mimic get_finished_kv_transfers in gpu_model_runner.
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
if
len
(
done_recving
)
>
0
:
assert
request_id
in
done_recving
break
connector
.
clear_connector_metadata
()
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
)
@
pytest
.
mark
.
parametrize
(
"decode_tp_size, prefill_tp_size"
,
[
(
1
,
1
),
(
2
,
1
),
(
4
,
2
),
(
4
,
4
),
])
def
test_async_load_kv
(
self
,
# Fixture that initializes the distributed environment.
dist_init
,
# Simulate consumer-producer TP sizes.
decode_tp_size
,
prefill_tp_size
):
"""Test that NixlConnector's start_load_kv should be non-blocking."""
vllm_config
=
create_vllm_config
()
vllm_config
.
parallel_config
.
tensor_parallel_size
=
decode_tp_size
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
)
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req
(
request_id
=
"id"
,
local_block_ids
=
[
1
,
2
,
3
],
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
prefill_tp_size
,
})
connector
.
bind_connector_metadata
(
metadata
)
timeout
=
2.5
start
=
time
.
perf_counter
()
while
time
.
perf_counter
()
-
start
<
timeout
:
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
)
_before_load
=
time
.
perf_counter
()
connector
.
start_load_kv
(
dummy_ctx
)
_after_load
=
time
.
perf_counter
()
assert
_after_load
-
_before_load
<
0.1
,
"start_load_kv took "
\
f
"
{
_after_load
-
_before_load
}
seconds"
time
.
sleep
(
0.5
)
# backoff for the async handshake to complete.
connector
.
bind_connector_metadata
(
NixlConnectorMetadata
())
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
if
len
(
done_recving
)
>
0
:
return
raise
TimeoutError
(
"Took too long to complete async handshake."
)
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
)
def
test_concurrent_load_kv
(
self
,
# dist_init is a fixture that initializes the distributed environment.
dist_init
):
"""Test that multiple start_load_kv calls should occur concurrently."""
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
)
metadata
=
NixlConnectorMetadata
()
total_reqs
=
5
for
i
in
range
(
total_reqs
):
metadata
.
add_new_req
(
request_id
=
f
"id_
{
i
}
"
,
local_block_ids
=
[
1
,
2
,
3
],
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
})
connector
.
bind_connector_metadata
(
metadata
)
timeout
=
2.5
*
total_reqs
cnt_finished_reqs
=
0
start
=
time
.
perf_counter
()
while
time
.
perf_counter
()
-
start
<
timeout
:
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
)
_before_load
=
time
.
perf_counter
()
connector
.
start_load_kv
(
dummy_ctx
)
_after_load
=
time
.
perf_counter
()
assert
_after_load
-
_before_load
<
0.1
,
"start_load_kv took "
\
f
"
{
_after_load
-
_before_load
}
seconds"
time
.
sleep
(
0.5
)
# backoff for the async handshake to complete.
connector
.
bind_connector_metadata
(
NixlConnectorMetadata
())
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
if
len
(
done_recving
)
>
0
:
cnt_finished_reqs
+=
len
(
done_recving
)
if
cnt_finished_reqs
==
total_reqs
:
return
raise
TimeoutError
(
"Took too long to complete async handshake."
)
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
99324e25
...
...
@@ -66,7 +66,7 @@ def test_basic_lifecycle():
assert
len
(
scheduler_output
.
finished_req_ids
)
==
1
assert
request_id
in
scheduler_output
.
finished_req_ids
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
0
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
# (2b): execute_model()
...
...
@@ -81,7 +81,7 @@ def test_basic_lifecycle():
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler_output
.
finished_req_ids
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
0
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
# (3b): execute_model()
...
...
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
View file @
99324e25
...
...
@@ -36,7 +36,7 @@ def test_basic_lifecycle():
# Nothing running and empty scheduler output.
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
0
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
0
assert
len
(
scheduler_output
.
num_scheduled_tokens
)
==
0
assert
scheduler_output
.
total_num_scheduled_tokens
==
0
...
...
@@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
1
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
1
model_runner_output
=
create_model_runner_output
(
[
request_local_a
,
request_local_b
])
...
...
@@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
2
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
2
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_local_a
,
request_local_b
])
...
...
@@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
2
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
2
# STEP 4: KVs arrive.
scheduler_output
=
scheduler
.
schedule
()
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
2
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
2
model_runner_output
=
create_model_runner_output
(
[
request_local_a
,
request_local_b
],
...
...
@@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
assert
len
(
scheduler
.
running
)
==
3
assert
len
(
scheduler
.
waiting
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
2
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
2
model_runner_output
=
create_model_runner_output
(
[
request_local_a
,
request_local_b
,
request_remote
])
...
...
tests/v1/kv_connector/unit/utils.py
View file @
99324e25
...
...
@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
assert
len
(
scheduler
.
finished_recving_kv_req_ids
)
==
0
assert
len
(
scheduler
.
_cached_reqs_data
)
==
0
# EncoderCacheManager.
assert
len
(
scheduler
.
encoder_cache_manager
.
freed
)
==
0
...
...
@@ -150,6 +149,7 @@ def create_request(
request_id
=
f
"id-
{
request_id
}
"
,
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
multi_modal_inputs
=
None
,
multi_modal_placeholders
=
None
,
multi_modal_hashes
=
None
,
...
...
@@ -183,6 +183,7 @@ def create_model_runner_output(
spec_token_ids
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
None
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
)
tests/v1/sample/test_logits_processors.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
from
collections.abc
import
Callable
from
typing
import
NamedTuple
,
Optional
,
Union
import
numpy
as
np
import
pytest
import
torch
from
tests.v1.sample.utils
import
(
LogitsprocsTestFakes
,
create_fake_logits
,
create_penalty_tensor
,
create_prompt_tokens_tensor
,
fake_apply_logitsprocs
,
fake_update_logitsprocs_state
)
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_pin_memory_available
# yapf: disable
from
vllm.v1.sample.logits_processor
import
(
BatchUpdate
,
BatchUpdateBuilder
,
LogitBiasLogitsProcessor
,
LogitsProcessor
,
MinPLogitsProcessor
,
MinTokensLogitsProcessor
,
MoveDirectionality
,
init_builtin_logitsprocs
)
# yapf: enable
from
vllm.v1.sample.metadata
import
SamplingMetadata
PIN_MEMORY_AVAILABLE
=
is_pin_memory_available
()
MAX_NUM_REQS
=
256
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
CUDA_DEVICES
=
[
f
"
{
current_platform
.
device_type
}
:
{
i
}
"
for
i
in
range
(
1
if
current_platform
.
device_count
()
==
1
else
2
)
]
MAX_NUM_PROMPT_TOKENS
=
64
MIN_TOKENS_LEN_THRESHOLD
=
5
REQS_PER_LOGITPROC
=
50
STR_NO_LOGITPROC
=
"none"
# LogitsProcessor subclass or "none"
LogitprocType
=
Union
[
type
[
LogitsProcessor
],
str
]
class
LogitsProcsRequestParams
:
"""Encapsulates key params for a single request in a batch.
Params can be customized based on the enabled logitproc
"""
workload_index
:
int
logitproc_type
:
LogitprocType
# Logitproc enabled, specified by str id
out_tokens
:
list
[
int
]
# Output tokens required for min tokens test
params
:
SamplingParams
# Settings customized for logitproc
def
__init__
(
self
,
workload_index
:
int
,
logitproc_type
:
LogitprocType
):
self
.
workload_index
=
workload_index
self
.
logitproc_type
=
logitproc_type
# Number of output tokens is randomly 0 or twice the min-tokens
# threshold which will be used in testing. Output token values
# don't matter *for these tests* so use 0 as a dummy value
self
.
out_tokens
=
([
0
]
*
(
MIN_TOKENS_LEN_THRESHOLD
*
random
.
randint
(
0
,
2
)))
self
.
params
=
_sampling_params_from_logitproc
(
logitproc_type
)
def
__str__
(
self
):
"""For debugging"""
summ
=
', '
.
join
(
f
'
{
k
}
=
{
v
}
'
for
k
,
v
in
vars
(
self
).
items
())
return
f
"MyClass(
{
summ
}
)"
def
_generate_fake_sampling_metadata
(
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
device
:
torch
.
device
,
)
->
SamplingMetadata
:
"""Generate fake sampling metadata with fake logitsprocs"""
output_token_ids
:
list
[
list
[
int
]]
=
[]
prompt_token_ids
:
list
[
list
[
int
]]
=
[]
for
_
in
range
(
batch_size
):
output_token_ids
.
append
(
np
.
random
.
randint
(
0
,
vocab_size
,
size
=
num_output_tokens
).
tolist
())
prompt_token_ids
.
append
(
np
.
random
.
randint
(
0
,
vocab_size
,
size
=
np
.
random
.
randint
(
1
,
MAX_NUM_PROMPT_TOKENS
)).
tolist
())
logitsprocs
=
init_builtin_logitsprocs
(
pin_memory_available
=
PIN_MEMORY_AVAILABLE
,
max_num_reqs
=
MAX_NUM_REQS
+
1
,
device
=
device
)
fake_sampling_metadata
=
SamplingMetadata
(
temperature
=
torch
.
full
((
batch_size
,
),
0.0
),
all_greedy
=
True
,
all_random
=
False
,
top_p
=
None
,
top_k
=
None
,
generators
=
{},
max_num_logprobs
=
0
,
prompt_token_ids
=
create_prompt_tokens_tensor
(
prompt_token_ids
,
vocab_size
,
device
),
output_token_ids
=
output_token_ids
,
frequency_penalties
=
create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
create_penalty_tensor
(
batch_size
,
1.0
,
device
),
no_penalties
=
True
,
allowed_token_ids_mask
=
None
,
bad_words_token_ids
=
{},
logitsprocs
=
logitsprocs
)
return
fake_sampling_metadata
def
_generate_test_fakes
(
batch_size
:
int
,
device
:
str
)
->
LogitsprocsTestFakes
:
"""Generate fake logits and sampling metadata"""
fake_logits
=
create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
# Create one dominant token per batch, to support min-p test
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
0
]
=
10.0
# High logit for first token
fake_logits
[
i
,
1
:]
=
1e-2
# Others remain low
sampling_metadata
=
_generate_fake_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
return
LogitsprocsTestFakes
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
,
)
def
_sampling_params_from_logitproc
(
logitproc_type
:
LogitprocType
)
->
SamplingParams
:
"""Customize request SamplingParams for a specified logitproc"""
# SamplingParams for req with no logitproc
kwargs
=
{
"min_p"
:
0.0
,
"logit_bias"
:
None
,
"min_tokens"
:
0
}
if
fxn
:
=
logitsprocs_test_mapping
[
logitproc_type
].
gen_request_fxn
:
fxn
(
kwargs
)
return
SamplingParams
(
**
kwargs
)
def
_generate_mixed_logitsprocs_batch_params
(
reqs_per_logitproc
:
int
,
logitsprocs_types
:
list
[
str
],
)
->
list
[
LogitsProcsRequestParams
]:
"""Define key params for a batch of requests with a different
logitproc enabled per request.
The batch will have `reqs_per_logitproc` repeats for all
`logitsprocs_types` under test, including the case where
no logitsproc is enabled. The batch is randomly shuffled. The
size of the batch is `reqs_per_logitproc` times
`n = len(logitsprocs_types)`
Args:
reqs_per_logitproc: number of requests using each logitproc
logitsprocs_types: logitsprocs under test
Returns:
List of per-request params which configure the engine for that request's
enabled logitproc
"""
batch_size
=
len
(
logitsprocs_types
)
*
reqs_per_logitproc
# Generate multiple repeats of key params for each logitproc;
# apply random inverse permutation to the iteration
# over logitsprocs, such that logitsprocs are shuffled.
batch_perm
=
random
.
sample
(
range
(
batch_size
),
k
=
batch_size
)
return
[
LogitsProcsRequestParams
(
workload_index
=
idx
,
logitproc_type
=
logitsprocs_types
[
pdx
//
reqs_per_logitproc
])
for
idx
,
pdx
in
enumerate
(
batch_perm
)
]
def
_raise_error_invalid
(
msg_suffix
:
str
,
batch_index
:
int
,
request_params
:
LogitsProcsRequestParams
,
step_idx
:
int
,
err_cls
:
type
[
Exception
]
=
ValueError
,
)
->
None
:
raise
err_cls
(
f
"Validation failed for step=
{
step_idx
}
, "
f
"batch_index=
{
batch_index
}
, "
f
"workload_index=
{
request_params
.
workload_index
}
, "
f
"req_params=
{
request_params
}
. Reason:
{
msg_suffix
}
"
)
def
_logit_bias_params
(
kwargs
:
dict
)
->
None
:
"""Logit bias config"""
kwargs
[
"logit_bias"
]
=
{
random
.
randint
(
0
,
VOCAB_SIZE
-
1
):
random
.
choice
([
-
0.1
,
0.2
])
}
def
_logit_bias_validate
(
test_fakes
:
LogitsprocsTestFakes
,
persistent_batch
:
list
[
LogitsProcsRequestParams
],
logits_new
:
torch
.
Tensor
,
batch_index
:
int
,
request_params
:
LogitsProcsRequestParams
,
step_idx
:
int
,
)
->
None
:
"""Validate logit bias logitproc applied correctly"""
logit_bias
=
request_params
.
params
.
logit_bias
logits_old
=
(
test_fakes
.
logits
[
persistent_batch
[
batch_index
].
workload_index
].
cpu
())
logits_new
=
logits_new
[
batch_index
].
cpu
()
for
token_id
in
range
(
VOCAB_SIZE
):
logit_old_value
=
logits_old
[
token_id
]
logit_new_value
=
logits_new
[
token_id
]
if
token_id
in
logit_bias
:
bias_value
=
logit_bias
[
token_id
]
exp_value
=
bias_value
+
logit_old_value
if
logit_new_value
!=
pytest
.
approx
(
exp_value
):
_raise_error_invalid
(
msg_suffix
=
(
f
"Biased token
{
token_id
}
logit value
{
logit_new_value
}
"
f
"does not match expected value
{
exp_value
}
"
f
"given bias
{
bias_value
}
"
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
else
:
if
logit_new_value
!=
pytest
.
approx
(
logit_old_value
):
_raise_error_invalid
(
msg_suffix
=
(
f
"Unbiased token
{
token_id
}
logit value
{
logit_new_value
}
"
f
"does not match expected value
{
logit_old_value
}
"
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
def
_min_p_params
(
kwargs
:
dict
)
->
None
:
"""Min-p logitproc config"""
kwargs
[
"min_p"
]
=
0.1
def
_min_p_validate
(
test_fakes
:
LogitsprocsTestFakes
,
persistent_batch
:
list
[
LogitsProcsRequestParams
],
logits_new
:
torch
.
Tensor
,
batch_index
:
int
,
request_params
:
LogitsProcsRequestParams
,
step_idx
:
int
,
)
->
None
:
"""Validate min-p logitproc applied correctly"""
for
token_id
in
range
(
VOCAB_SIZE
):
logits_for_token
=
logits_new
[
batch_index
][
token_id
]
if
token_id
==
0
:
# Dominant token should always be unmasked
if
logits_for_token
==
-
float
(
"inf"
):
_raise_error_invalid
(
msg_suffix
=
"Invalid: dominant token 0 masked (-inf)"
,
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
else
:
if
request_params
.
params
.
min_p
>
0.0
:
# Non-dominant tokens should be masked when min_p > 0
if
logits_for_token
!=
-
float
(
"inf"
):
_raise_error_invalid
(
msg_suffix
=
f
"Invalid: non-dominant token
{
token_id
}
not masked"
,
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
else
:
# No masking when min_p is 0
if
logits_for_token
==
-
float
(
"inf"
):
_raise_error_invalid
(
msg_suffix
=
f
"Invalid: token
{
token_id
}
masked when min_p=0.0"
,
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
def
_min_tokens_params
(
kwargs
:
dict
)
->
None
:
"""Min-tokens logitproc config"""
kwargs
[
"min_tokens"
]
=
MIN_TOKENS_LEN_THRESHOLD
kwargs
[
"stop_token_ids"
]
=
[
np
.
random
.
randint
(
0
,
VOCAB_SIZE
-
1
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
VOCAB_SIZE
))
]
def
_min_tokens_validate
(
test_fakes
:
LogitsprocsTestFakes
,
persistent_batch
:
list
[
LogitsProcsRequestParams
],
logits_new
:
torch
.
Tensor
,
batch_index
:
int
,
request_params
:
LogitsProcsRequestParams
,
step_idx
:
int
,
)
->
None
:
"""Validate min-tokens logitsproc applied correctly"""
ref_num_out_tokens
=
len
(
request_params
.
out_tokens
)
min_reached
=
ref_num_out_tokens
>=
MIN_TOKENS_LEN_THRESHOLD
ref_all_stop_token_ids
=
request_params
.
params
.
all_stop_token_ids
mt_lp
:
MinTokensLogitsProcessor
=
next
(
test_fakes
.
get_logitsprocs_by_cls
(
MinTokensLogitsProcessor
))
assert
isinstance
(
mt_lp
,
MinTokensLogitsProcessor
)
min_tok
=
mt_lp
.
min_toks
.
get
(
batch_index
,
None
)
# Validate min-token logits processor state
if
min_tok
:
(
_
,
out_tok
,
all_stop_token_ids
)
=
min_tok
num_out_tokens
=
len
(
out_tok
)
if
num_out_tokens
!=
ref_num_out_tokens
:
_raise_error_invalid
(
msg_suffix
=
(
"Number of output tokens in min-token logit processor "
f
"request metadata (
{
num_out_tokens
}
) does not match "
f
"reference (
{
ref_num_out_tokens
}
)."
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
if
ref_all_stop_token_ids
!=
all_stop_token_ids
:
_raise_error_invalid
(
msg_suffix
=
(
"Stop token ids do not match reference; all_stop_token_ids: "
f
"
{
sorted
(
all_stop_token_ids
)
}
, ref_all_stop_token_ids: "
f
"
{
sorted
(
ref_all_stop_token_ids
)
}
"
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
if
min_reached
:
_raise_error_invalid
(
msg_suffix
=
(
"Expected min-tokens request with min reached, but batch "
"index is recognized by min-tokens logits processor."
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
,
err_cls
=
RuntimeError
)
elif
not
min_reached
:
_raise_error_invalid
(
msg_suffix
=
(
"Expected min-tokens request with min not reached, but batch "
"index is not recognized by min-tokens logits processor."
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
,
err_cls
=
RuntimeError
)
# Validate min-token logits
for
token_id
in
range
(
VOCAB_SIZE
):
logits_for_token
=
logits_new
[
batch_index
][
token_id
]
if
token_id
in
ref_all_stop_token_ids
and
not
min_reached
:
if
logits_for_token
!=
-
float
(
"inf"
):
_raise_error_invalid
(
msg_suffix
=
(
f
"Token
{
token_id
}
is a stop token and "
"the sequence has not reached min length, "
"but the token is not masked "
f
"(logit=
{
logits_for_token
}
)"
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
else
:
if
logits_for_token
==
-
float
(
"inf"
):
_raise_error_invalid
(
msg_suffix
=
(
f
"Token
{
token_id
}
should not be masked but "
f
"is (output len=
{
ref_num_out_tokens
}
)"
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
def
_none_validate
(
test_fakes
:
LogitsprocsTestFakes
,
persistent_batch
:
list
[
LogitsProcsRequestParams
],
logits_new
:
torch
.
Tensor
,
batch_index
:
int
,
request_params
:
LogitsProcsRequestParams
,
step_idx
:
int
,
)
->
None
:
"""Validate that no logits processors are applied"""
logits
=
(
test_fakes
.
logits
[
persistent_batch
[
batch_index
].
workload_index
].
cpu
())
ref_logits
=
logits_new
[
batch_index
]
if
not
torch
.
all
(
ref_logits
==
logits
):
mismatch_toks
=
(
ref_logits
!=
logits
).
nonzero
(
as_tuple
=
True
)[
0
].
tolist
()
mismatch_strs
=
[]
for
token
in
mismatch_toks
:
val
=
float
(
logits
[
token
])
ref_val
=
float
(
ref_logits
[
token
])
mismatch_strs
.
append
(
f
"(
{
token
=
}
,
{
val
=
}
,
{
ref_val
=
}
)"
)
_raise_error_invalid
(
msg_suffix
=
(
f
"Unexpected modification of logits:
{
','
.
join
(
mismatch_strs
)
}
"
),
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
class
LogitsprocTestHelpers
(
NamedTuple
):
"""Supports setting up and validating logitsprocs unit tests."""
eval_fxn
:
Callable
gen_request_fxn
:
Optional
[
Callable
]
=
None
logitsprocs_test_mapping
=
{
STR_NO_LOGITPROC
:
LogitsprocTestHelpers
(
eval_fxn
=
_none_validate
),
LogitBiasLogitsProcessor
:
LogitsprocTestHelpers
(
gen_request_fxn
=
_logit_bias_params
,
eval_fxn
=
_logit_bias_validate
),
MinPLogitsProcessor
:
LogitsprocTestHelpers
(
gen_request_fxn
=
_min_p_params
,
eval_fxn
=
_min_p_validate
),
MinTokensLogitsProcessor
:
LogitsprocTestHelpers
(
gen_request_fxn
=
_min_tokens_params
,
eval_fxn
=
_min_tokens_validate
),
}
def
_get_test_cases
()
->
list
[
list
[
str
]]:
"""Each test case is a set of logitsprocs"""
logitsprocs_types
=
list
(
logitsprocs_test_mapping
.
keys
())
return
[[
STR_NO_LOGITPROC
]]
+
[[
logitproc_type
,
STR_NO_LOGITPROC
]
for
logitproc_type
in
logitsprocs_types
if
logitproc_type
!=
STR_NO_LOGITPROC
]
+
[
logitsprocs_types
]
def
_generate_fake_step_update
(
persistent_batch
:
list
[
LogitsProcsRequestParams
],
workload_params
:
list
[
LogitsProcsRequestParams
],
wdx
:
int
,
batch_update_builder
:
BatchUpdateBuilder
,
)
->
tuple
[
Optional
[
BatchUpdate
],
int
,
int
]:
batch_size
=
len
(
persistent_batch
)
workload_size
=
len
(
workload_params
)
workload_reqs_remaining
=
workload_size
-
wdx
max_add_remove_per_step
=
max
(
1
,
int
(
0.2
*
workload_size
))
# 50% of steps: add no reqs
# Other 50%: add a limited number of reqs (less than the number
# of workload reqs remaining, less than an arbitrary max)
# If no workload reqs remain: 100% of steps have 0 adds
num_step_add
=
random
.
choice
([
0
,
random
.
randint
(
1
,
min
(
max_add_remove_per_step
,
workload_reqs_remaining
))
])
if
workload_reqs_remaining
else
0
# 50% of steps: remove no requests
# Other 50%: remove a limited number of reqs (less than the number
# persistent batch reqs remaining, less than an arbitrary max)
# If persistent batch is empty: 100% of steps have 0 removals until
# more requests are added. Assume that removed requests are always
# drawn from the current batch, before new adds
num_step_remove
=
random
.
choice
([
0
,
random
.
randint
(
1
,
min
(
max_add_remove_per_step
,
batch_size
))
])
if
batch_size
else
0
num_step_add_replace
=
min
(
num_step_add
,
num_step_remove
)
# Generate fake removed request indices drawn from persistent batch indices
for
removal
in
random
.
sample
(
range
(
batch_size
),
num_step_remove
):
batch_update_builder
.
removed_append
(
removal
)
# Get added requests from workload
for
add_req_params
in
workload_params
[
wdx
:(
wdx
+
num_step_add_replace
)]:
# Replace as many removed requests as possible with added requests
add_remove_idx
=
batch_update_builder
.
pop_removed
()
batch_update_builder
.
added
.
append
(
(
add_remove_idx
,
add_req_params
.
params
,
add_req_params
.
out_tokens
))
persistent_batch
[
add_remove_idx
]
=
add_req_params
# Append remaining added requests to end of batch
add_reqs_append
=
workload_params
[(
wdx
+
num_step_add_replace
):(
wdx
+
num_step_add
)]
batch_update_builder
.
added
.
extend
([
(
adx
+
batch_size
,
add_req_params
.
params
,
add_req_params
.
out_tokens
)
for
adx
,
add_req_params
in
enumerate
(
add_reqs_append
)
])
persistent_batch
.
extend
(
add_reqs_append
)
pre_condense_batch_size
=
len
(
persistent_batch
)
wdx
+=
num_step_add
# Update workload offset
# Simulate condensing persistent batch
last_nonempty_index
=
pre_condense_batch_size
-
1
condensed_to_idxs
=
set
()
while
batch_update_builder
.
removed
:
if
(
last_nonempty_index
in
batch_update_builder
.
removed
or
last_nonempty_index
in
condensed_to_idxs
):
last_nonempty_index
-=
1
continue
# last_nonempty_index is the highest persistent batch index that was
# not removed
first_empty_index
=
batch_update_builder
.
peek_removed
()
assert
first_empty_index
is
not
None
if
first_empty_index
>
last_nonempty_index
:
break
# first_empty_index is the lowest removed persistent batch index
# that is less than last_nonempty_index
#
# move last_nonempty_index -> first_empty_index
batch_update_builder
.
pop_removed
()
condensed_to_idxs
.
add
(
first_empty_index
)
persistent_batch
[
first_empty_index
]
=
persistent_batch
[
last_nonempty_index
]
batch_update_builder
.
moved
.
append
(
(
last_nonempty_index
,
first_empty_index
,
MoveDirectionality
.
UNIDIRECTIONAL
))
last_nonempty_index
-=
1
# Now removed requests & gaps left by non-removed requests that got
# moved downward are grouped consecutively in the upper indices of
# the persistent batch. Truncate them to get condensed persistent batch
condensed_batch_size
=
batch_size
+
num_step_add
-
num_step_remove
persistent_batch
[:]
=
persistent_batch
[
0
:
condensed_batch_size
]
if
condensed_batch_size
>
1
:
# Simulate arbitrary reorder_batch() in the kernel backend
# Generate a random number k of non-overlapping swap tuples
k
=
random
.
randint
(
0
,
condensed_batch_size
//
2
)
idxs
=
list
(
range
(
condensed_batch_size
))
random
.
shuffle
(
idxs
)
swaps
=
[
tuple
(
sorted
([
idxs
[
2
*
i
],
idxs
[
2
*
i
+
1
]]))
for
i
in
range
(
k
)
]
batch_update_builder
.
moved
.
extend
([
(
sw
[
0
],
sw
[
1
],
MoveDirectionality
.
SWAP
)
for
sw
in
swaps
])
for
adx
,
bdx
in
swaps
:
persistent_batch
[
adx
],
persistent_batch
[
bdx
]
=
persistent_batch
[
bdx
],
persistent_batch
[
adx
]
return
(
batch_update_builder
.
get_and_reset
(
condensed_batch_size
),
wdx
,
workload_size
-
wdx
)
def
_assert_valid
(
batch_size
:
int
,
persistent_batch
:
list
[
LogitsProcsRequestParams
],
test_fakes
:
LogitsprocsTestFakes
,
slice_idxs
:
list
[
int
],
logits_w_lp
:
torch
.
Tensor
,
step_idx
:
int
,
)
->
None
:
if
not
slice_idxs
:
# Trivial case of empty persistent batch
assert
len
(
persistent_batch
)
==
0
if
logits_w_lp
.
shape
[
0
]
!=
0
:
raise
ValueError
(
"Fake persistent batch is empty but logitsprocs "
f
"output batch has shape
{
logits_w_lp
.
shape
}
"
)
return
# Validate logits for each fake request
for
batch_index
in
range
(
batch_size
):
request_params
=
persistent_batch
[
batch_index
]
# Invoke the appropriate validation function for
# the logitproc employed by this request
fxn
=
logitsprocs_test_mapping
[
request_params
.
logitproc_type
].
eval_fxn
fxn
(
test_fakes
=
test_fakes
,
persistent_batch
=
persistent_batch
,
logits_new
=
logits_w_lp
,
batch_index
=
batch_index
,
request_params
=
request_params
,
step_idx
=
step_idx
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"reqs_per_logitproc"
,
[
REQS_PER_LOGITPROC
])
@
pytest
.
mark
.
parametrize
(
"logitsprocs_under_test"
,
_get_test_cases
())
def
test_logitsprocs
(
device
:
str
,
reqs_per_logitproc
:
int
,
logitsprocs_under_test
:
list
[
str
]):
random
.
seed
(
40
)
torch
.
set_default_device
(
device
)
# Define a shuffled batch of requests which individually use a different
# logitproc, or no logitproc at all
workload_params
=
_generate_mixed_logitsprocs_batch_params
(
reqs_per_logitproc
=
reqs_per_logitproc
,
logitsprocs_types
=
logitsprocs_under_test
)
workload_size
=
len
(
workload_params
)
# Create fake test data structures for testing.
test_fakes
=
_generate_test_fakes
(
workload_size
,
device
)
wdx
=
0
# Next request index in workload to add
persistent_batch
:
list
[
LogitsProcsRequestParams
]
=
[
]
# Persistent batch state, as list of workload indices
# Generate fake removed request indices from current persistent
# batch before adds
batch_update_builder
=
BatchUpdateBuilder
()
# Break when entire workload has been added previously and persistent
# batch is empty
workload_reqs_remaining
=
workload_size
batch_size
=
0
step_idx
=
0
while
True
:
if
not
(
workload_reqs_remaining
or
batch_size
):
break
(
batch_update
,
wdx
,
workload_reqs_remaining
,
)
=
_generate_fake_step_update
(
persistent_batch
=
persistent_batch
,
workload_params
=
workload_params
,
wdx
=
wdx
,
batch_update_builder
=
batch_update_builder
,
)
batch_size
=
len
(
persistent_batch
)
# Apply fake batch update to logitsprocs
fake_update_logitsprocs_state
(
test_fakes
,
batch_update
)
# Emulate application of logits processors in engine
slice_idxs
=
[
req
.
workload_index
for
req
in
persistent_batch
]
logits_w_lp
=
fake_apply_logitsprocs
(
test_fakes
,
slice_idxs
).
cpu
()
_assert_valid
(
batch_size
=
batch_size
,
persistent_batch
=
persistent_batch
,
test_fakes
=
test_fakes
,
slice_idxs
=
slice_idxs
,
logits_w_lp
=
logits_w_lp
,
step_idx
=
step_idx
,
)
step_idx
+=
1
tests/v1/sample/test_logprobs_e2e.py
View file @
99324e25
...
...
@@ -13,9 +13,10 @@ EXPECTED_VALUE = 0.62
# FIXME(rob): enable prefix caching once supported.
MODEL
=
"meta-llama/Llama-3.2-1B-Instruct"
MODEL_ARGS
=
f
"pretrained=
{
MODEL
}
,enforce_eager=True,enable_prefix_caching=False"
# noqa: E501
MODEL_ARGS
=
f
"pretrained=
{
MODEL
}
,enforce_eager=True,enable_prefix_caching=False
,gpu_memory_utilization=0.8
"
# noqa: E501
SERVER_ARGS
=
[
"--enforce_eager"
,
"--no_enable_prefix_caching"
,
"--disable-log-requests"
"--enforce_eager"
,
"--no_enable_prefix_caching"
,
"--disable-log-requests"
,
"--gpu-memory-utilization=0.8"
]
NUM_CONCURRENT
=
100
...
...
@@ -32,7 +33,7 @@ def test_prompt_logprobs_e2e():
),
f
"Expected:
{
EXPECTED_VALUE
}
| Measured:
{
measured_value
}
"
def
test_promt_logprobs_e2e_server
():
def
test_prom
p
t_logprobs_e2e_server
():
with
RemoteOpenAIServer
(
MODEL
,
SERVER_ARGS
)
as
remote_server
:
url
=
f
"
{
remote_server
.
url_for
(
'v1'
)
}
/completions"
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
99324e25
...
...
@@ -6,12 +6,14 @@ import pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.logits_processor
import
LogitsProcessorManager
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
(
PLACEHOLDER_TOKEN_ID
,
RejectionSampler
)
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
DEVICE
=
"
cu
da"
DEVICE
=
cu
rrent_platform
.
device_type
@
pytest
.
fixture
...
...
@@ -21,7 +23,7 @@ def rejection_sampler():
def
create_logits_tensor
(
output_token_ids
:
list
[
list
[
int
]],
vocab_size
:
int
=
100
)
->
torch
.
Tensor
:
"""Helper function to create logits tensor that
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
token_ids
=
[
tokens
[:
-
1
]
for
tokens
in
output_token_ids
]
num_total_tokens
=
sum
(
len
(
tokens
)
for
tokens
in
token_ids
)
...
...
@@ -41,8 +43,8 @@ def create_sampling_metadata(
top_p
:
Optional
[
torch
.
Tensor
]
=
None
,
generators
:
Optional
[
dict
[
int
,
Any
]]
=
None
,
)
->
SamplingMetadata
:
"""Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling
"""Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling
is used.
"""
generators
=
generators
or
{}
...
...
@@ -57,7 +59,6 @@ def create_sampling_metadata(
all_random
=
not
all_greedy
,
top_p
=
top_p
,
top_k
=
top_k
,
min_p
=
torch
.
empty
(
1
,
),
generators
=
generators
,
max_num_logprobs
=
0
,
no_penalties
=
False
,
...
...
@@ -66,10 +67,9 @@ def create_sampling_metadata(
presence_penalties
=
torch
.
tensor
([]),
repetition_penalties
=
torch
.
tensor
([]),
output_token_ids
=
[],
min_tokens
=
{},
logit_bias
=
[
None
],
allowed_token_ids_mask
=
None
,
bad_words_token_ids
=
{},
logitsprocs
=
LogitsProcessorManager
(),
)
...
...
tests/v1/sample/test_sampler.py
View file @
99324e25
...
...
@@ -8,10 +8,13 @@ import pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.v1.sample.logits_processor
import
LogitsProcessorManager
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.sampler
import
Sampler
PIN_MEMORY_AVAILABLE
=
is_pin_memory_available
()
MAX_NUM_REQS
=
256
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
CUDA_DEVICES
=
[
...
...
@@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor(
)
def
_create_logit_bias
(
batch_size
:
int
,
vocab_size
:
int
,
bias_value
:
float
,
)
->
list
[
Optional
[
dict
[
int
,
float
]]]:
res
:
list
[
Optional
[
dict
[
int
,
float
]]]
=
[]
for
i
in
range
(
batch_size
):
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
res
.
append
(
logit_bias
)
return
res
def
_create_allowed_token_ids
(
batch_size
:
int
,
vocab_size
:
int
,
...
...
@@ -145,7 +136,6 @@ def _create_default_sampling_metadata(
all_random
=
False
,
top_p
=
None
,
top_k
=
None
,
min_p
=
None
,
generators
=
{},
max_num_logprobs
=
0
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
...
...
@@ -155,43 +145,13 @@ def _create_default_sampling_metadata(
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
no_penalties
=
True
,
min_tokens
=
{},
logit_bias
=
[
None
]
*
batch_size
,
allowed_token_ids_mask
=
None
,
bad_words_token_ids
=
{},
logitsprocs
=
LogitsProcessorManager
(),
)
return
fake_sampling_metadata
def
_generate_min_token_penalties_and_stop_tokens
(
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
batch_indices_for_min_token_penalty
:
list
[
int
]
)
->
dict
[
int
,
tuple
[
int
,
set
[
int
]]]:
"""
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
min_tokens
:
dict
[
int
,
tuple
[
int
,
set
[
int
]]]
=
{}
for
index
in
range
(
batch_size
):
if
index
in
batch_indices_for_min_token_penalty
:
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
num_output_tokens
+
1
,
2
*
num_output_tokens
),
set
(
np
.
random
.
randint
(
0
,
vocab_size
-
1
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
vocab_size
))))
else
:
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
0
,
num_output_tokens
),
set
())
return
min_tokens
def
_create_weighted_output_token_list
(
batch_size
:
int
,
vocab_size
:
int
)
->
tuple
[
list
[
list
[
int
]],
list
[
list
[
int
]]]:
...
...
@@ -227,36 +187,6 @@ def _create_weighted_output_token_list(
return
output_token_ids
,
sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
def
test_sampler_min_tokens_penalty
(
device
:
str
,
batch_size
:
int
):
"""
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
torch
.
set_default_device
(
device
)
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
batch_indices_for_min_token_penalty
=
np
.
random
.
randint
(
0
,
batch_size
-
1
,
size
=
np
.
random
.
randint
(
0
,
batch_size
)).
tolist
()
min_tokens
=
_generate_min_token_penalties_and_stop_tokens
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
batch_indices_for_min_token_penalty
)
sampling_metadata
.
min_tokens
=
min_tokens
sampler
=
Sampler
()
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
token_id
in
range
(
VOCAB_SIZE
):
_
,
stop_token_ids
=
min_tokens
.
get
(
batch_idx
,
(
0
,
set
()))
if
token_id
in
stop_token_ids
:
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"presence_penalty"
,
[
-
2.0
,
2.0
])
...
...
@@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
or
non_penalized_token_id
in
output_tokens
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"min_p"
,
[
0.0
,
0.1
])
def
test_sampler_min_p
(
device
:
str
,
batch_size
:
int
,
min_p
:
float
):
"""
Tests that when min_p is applied, tokens with probability below
min_p * max_prob are masked with -inf.
"""
torch
.
set_default_device
(
device
)
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
# Create one dominant token per batch
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
0
]
=
10.0
# High logit for first token
fake_logits
[
i
,
1
:]
=
1e-2
# Others remain low
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
# Configure min_p parameters
sampling_metadata
.
min_p
=
torch
.
full
((
batch_size
,
),
min_p
,
device
=
device
)
sampler
=
Sampler
()
logits
=
sampler
.
apply_min_p
(
fake_logits
,
sampling_metadata
.
min_p
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
token_id
in
range
(
VOCAB_SIZE
):
if
token_id
==
0
:
# Dominant token should always be unmasked
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
else
:
if
min_p
>
0.0
:
# Non-dominant tokens should be masked when min_p > 0
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
# No masking when min_p is 0
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"bias_value"
,
[
-
0.1
,
1.2
])
def
test_sampler_logit_bias
(
device
:
str
,
batch_size
:
int
,
bias_value
:
float
):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch
.
set_default_device
(
device
)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
sampling_metadata
.
logit_bias
=
_create_logit_bias
(
batch_size
=
batch_size
,
vocab_size
=
VOCAB_SIZE
,
bias_value
=
bias_value
,
)
sampler
=
Sampler
()
logits
=
sampler
.
apply_logits_bias
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
logits_for_req
=
logits
[
batch_idx
]
biased_index
=
min
(
batch_idx
,
VOCAB_SIZE
-
1
)
for
token_id
in
range
(
VOCAB_SIZE
):
if
biased_index
==
token_id
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
bias_value
+
1e-2
)
else
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
1e-2
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"num_allowed_token_ids"
,
[
0
,
1
,
2
])
...
...
tests/v1/sample/test_topk_topp_sampler.py
View file @
99324e25
...
...
@@ -2,25 +2,26 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
flashinfer.sampling
import
top_k_renorm_probs
,
top_p_renorm_probs
from
torch
import
Generator
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
apply_top_k_top_p
,
is_flashinfer_available
)
DEVICE
=
"
cu
da"
DEVICE
=
cu
rrent_platform
.
device_type
BATCH_SIZE
=
1024
VOCAB_SIZE
=
128
*
1024
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
if
is_flashinfer_available
:
from
flashinfer.sampling
import
top_k_renorm_probs
,
top_p_renorm_probs
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_default_device
():
"""
Explicitly set the default device, which can affect subsequent tests.
Explicitly set the default device, which can affect subsequent tests.
Adding this fixture helps avoid this problem.
"""
original_device
=
torch
.
get_default_device
()
...
...
@@ -28,7 +29,7 @@ def reset_default_device():
torch
.
set_default_device
(
original_device
)
def
test_topk_impl_equival
a
nce
():
def
test_topk_impl_equival
e
nce
():
torch
.
set_default_device
(
DEVICE
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
33
)
...
...
@@ -58,8 +59,8 @@ def test_flashinfer_sampler():
This test verifies that the FlashInfer top-k and top-p sampling
implementation produces the same results as the Python implementation.
NOTE: FlashInfer did not directly expose an interface for fused top-k and
top-p prob renorm (it did provide fused sampling but we cannot compare
NOTE: FlashInfer did not directly expose an interface for fused top-k and
top-p prob renorm (it did provide fused sampling but we cannot compare
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
'''
...
...
tests/v1/sample/utils.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterator
from
enum
import
Enum
from
typing
import
Optional
from
typing
import
NamedTuple
,
Optional
import
regex
as
re
import
torch
from
vllm
import
CompletionOutput
from
vllm.utils
import
make_tensor_with_pad
from
vllm.v1.sample.logits_processor
import
BatchUpdate
,
LogitsProcessor
from
vllm.v1.sample.metadata
import
SamplingMetadata
class
BatchLogprobsComposition
(
Enum
):
...
...
@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
logprobs
=
completion_output
.
logprobs
assert
logprobs
is
not
None
return
sum
([
lp
[
tok_id
].
logprob
for
tok_id
,
lp
in
zip
(
token_ids
,
logprobs
)])
def
create_fake_logits
(
batch_size
:
int
,
vocab_size
:
int
)
->
torch
.
Tensor
:
fake_logits
=
torch
.
full
((
batch_size
,
vocab_size
),
1e-2
,
dtype
=
torch
.
float
)
return
fake_logits
def
create_penalty_tensor
(
batch_size
:
int
,
penalty_value
:
float
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
return
torch
.
full
((
batch_size
,
),
fill_value
=
penalty_value
,
dtype
=
torch
.
float
,
device
=
device
)
def
create_prompt_tokens_tensor
(
prompt_token_ids
:
list
[
list
[
int
]],
vocab_size
:
int
,
device
:
torch
.
device
,
)
->
torch
.
Tensor
:
return
make_tensor_with_pad
(
prompt_token_ids
,
pad
=
vocab_size
,
device
=
device
,
dtype
=
torch
.
int64
,
pin_memory
=
False
,
)
class
LogitsprocsTestFakes
(
NamedTuple
):
"""Wraps fake data structures to support testing"""
logits
:
torch
.
Tensor
sampling_metadata
:
SamplingMetadata
def
get_logitsprocs_by_cls
(
self
,
cls
:
type
[
LogitsProcessor
],
)
->
Iterator
[
LogitsProcessor
]:
"""Yield logits processors of a specific class.
Args:
cls: :class:`LogitsProcessor` subclass
Returns:
Iterator over logits processors
"""
return
(
lp
for
lp
in
self
.
sampling_metadata
.
logitsprocs
.
all
if
isinstance
(
lp
,
cls
))
def
get_logitsprocs
(
self
)
->
Iterator
[
LogitsProcessor
]:
"""Iterator over all logits processors."""
return
self
.
sampling_metadata
.
logitsprocs
.
all
def
fake_update_logitsprocs_state
(
test_fakes
:
LogitsprocsTestFakes
,
batch_update
:
BatchUpdate
,
)
->
None
:
"""Imitate logits processors persistent batch state update
in engine core"""
for
logitproc
in
test_fakes
.
get_logitsprocs
():
logitproc
.
update_state
(
batch_update
)
def
fake_apply_logitsprocs
(
test_fakes
:
LogitsprocsTestFakes
,
slice_indices
:
list
[
int
],
)
->
torch
.
Tensor
:
"""Imitate application of logits processors in engine core"""
logits
=
test_fakes
.
logits
[
torch
.
tensor
(
slice_indices
,
dtype
=
torch
.
long
)].
clone
()
for
processor
in
test_fakes
.
get_logitsprocs
():
logits
=
processor
.
apply
(
logits
)
return
logits
tests/v1/spec_decode/test_eagle.py
View file @
99324e25
...
...
@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.v1.spec_decode.eagle
import
EagleProposer
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
...
...
@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
num_speculative_tokens
=
k
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
speculative_config
=
speculative_config
,
device_config
=
DeviceConfig
(
device
=
"cuda"
),
parallel_config
=
ParallelConfig
(),
load_config
=
LoadConfig
(),
scheduler_config
=
SchedulerConfig
())
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
speculative_config
=
speculative_config
,
device_config
=
DeviceConfig
(
device
=
current_platform
.
device_type
),
parallel_config
=
ParallelConfig
(),
load_config
=
LoadConfig
(),
scheduler_config
=
SchedulerConfig
())
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
'cuda'
)
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
current_platform
.
device_type
)
def
test_prepare_inputs
():
...
...
@@ -59,7 +62,7 @@ def test_prepare_inputs():
a, a + 1, ..., a + b - n2 - 1,
a + b, a + b + 1, ..., a + b + c - n3 - 1]
"""
device
=
torch
.
device
(
'
cu
da'
)
device
=
torch
.
device
(
cu
rrent_platform
.
device_type
)
# a = 4, b = 7, c = 5
# n1 = 1, n2 = 3, n3 = 2
...
...
@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
8
])
def
test_propose
(
num_speculative_tokens
):
# Use GPU device
device
=
torch
.
device
(
'
cu
da'
)
device
=
torch
.
device
(
cu
rrent_platform
.
device_type
)
# Setup test parameters
batch_size
=
2
...
...
tests/v1/test_async_llm_dp.py
View file @
99324e25
...
...
@@ -4,24 +4,30 @@
import
asyncio
import
os
from
contextlib
import
ExitStack
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
from
vllm
import
SamplingParams
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.inputs
import
PromptType
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.core_client
import
DPAsyncMPClient
from
vllm.v1.metrics.loggers
import
StatLoggerBase
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
DP_SIZE
=
int
(
os
.
getenv
(
"DP_SIZE"
,
2
))
engine_args
=
AsyncEngineArgs
(
model
=
"ibm-research/PowerMoE-3b"
,
enforce_eager
=
True
,
disable_log_requests
=
True
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"TP_SIZE"
,
1
)),
data_parallel_size
=
int
(
os
.
getenv
(
"
DP_SIZE
"
,
2
))
,
data_parallel_size
=
DP_SIZE
,
)
if
not
current_platform
.
supports_v1
(
engine_args
.
create_model_config
()):
...
...
@@ -74,12 +80,32 @@ async def generate(
async
def
test_load
(
output_kind
:
RequestOutputKind
,
data_parallel_backend
:
str
):
stats_loggers
=
{}
@
dataclass
class
SimpleStatsLogger
(
StatLoggerBase
):
init_count
:
int
=
0
finished_req_count
:
int
=
0
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
stats_loggers
[
engine_index
]
=
self
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
]):
if
iteration_stats
:
self
.
finished_req_count
+=
len
(
iteration_stats
.
finished_requests
)
def
log_engine_initialized
(
self
):
self
.
init_count
+=
1
with
ExitStack
()
as
after
:
prompt
=
"This is a test of data parallel"
engine_args
.
data_parallel_backend
=
data_parallel_backend
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
,
stat_loggers
=
[
SimpleStatsLogger
])
after
.
callback
(
engine
.
shutdown
)
NUM_REQUESTS
=
100
...
...
@@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind,
for
request_id
in
request_ids
:
tasks
.
append
(
asyncio
.
create_task
(
generate
(
engine
,
request_id
,
prompt
,
output_kind
,
NUM_EXPECTED_TOKENS
,
data_parallel_rank
=
0
)))
generate
(
engine
,
request_id
,
prompt
,
output_kind
,
NUM_EXPECTED_TOKENS
)))
# Short sleep to ensure that requests are distributed.
await
asyncio
.
sleep
(
0.01
)
# Confirm that we got all the EXPECTED tokens from the requests.
done
,
pending
=
await
asyncio
.
wait
(
tasks
,
return_when
=
asyncio
.
FIRST_EXCEPTION
)
...
...
@@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind,
assert
not
core_client
.
engines_running
assert
not
core_client
.
reqs_in_flight
# Check that requests were distributed between the engines
print
(
f
"Stats loggers after test:
{
stats_loggers
}
"
)
assert
len
(
stats_loggers
)
==
DP_SIZE
assert
stats_loggers
[
0
].
init_count
==
1
for
sl
in
stats_loggers
.
values
():
slogger
:
SimpleStatsLogger
=
sl
assert
slogger
.
finished_req_count
>
NUM_REQUESTS
//
(
DP_SIZE
+
1
),
f
"requests are imbalanced:
{
stats_loggers
}
"
tests/v1/test_external_lb_dp.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
os
import
threading
import
time
from
contextlib
import
AsyncExitStack
import
openai
# use the official client for correctness check
import
pytest
import
pytest_asyncio
from
tests.utils
import
RemoteOpenAIServer
from
vllm.platforms
import
Platform
MODEL_NAME
=
"ibm-research/PowerMoE-3b"
# Number of data parallel ranks for external LB testing
DP_SIZE
=
int
(
os
.
getenv
(
"DP_SIZE"
,
"2"
))
# Default tensor parallell size to use
TP_SIZE
=
int
(
os
.
getenv
(
"TP_SIZE"
,
"1"
))
class
ExternalLBServerManager
:
"""Manages data parallel vLLM server instances for external
load balancer testing."""
def
__init__
(
self
,
model_name
:
str
,
dp_size
:
int
,
api_server_count
:
int
,
base_server_args
:
list
,
tp_size
:
int
=
TP_SIZE
):
self
.
model_name
=
model_name
self
.
dp_size
=
dp_size
self
.
tp_size
=
tp_size
self
.
api_server_count
=
api_server_count
self
.
base_server_args
=
base_server_args
self
.
servers
:
list
[
tuple
[
RemoteOpenAIServer
,
list
[
str
]]]
=
[]
self
.
server_threads
:
list
[
threading
.
Thread
]
=
[]
def
__enter__
(
self
)
->
list
[
tuple
[
RemoteOpenAIServer
,
list
[
str
]]]:
"""Start all server instances for external LB mode."""
for
rank
in
range
(
self
.
dp_size
):
# Create server args for this specific rank
server_args
=
self
.
base_server_args
.
copy
()
# Add external LB specific arguments
server_args
.
extend
([
"--data-parallel-size"
,
str
(
self
.
dp_size
),
"--data-parallel-rank"
,
str
(
rank
),
"--data-parallel-size-local"
,
"1"
,
"--tensor-parallel-size"
,
str
(
self
.
tp_size
),
"--port"
,
str
(
8000
+
rank
),
# Different port for each rank
"--api-server-count"
,
str
(
self
.
api_server_count
),
])
# Use a thread to start each server to allow parallel initialization
def
start_server
(
r
:
int
,
sargs
:
list
[
str
]):
try
:
# Start the server
server
=
RemoteOpenAIServer
(
self
.
model_name
,
sargs
,
auto_port
=
False
,
env_dict
=
{
"CUDA_VISIBLE_DEVICES"
:
","
.
join
(
str
(
Platform
.
device_id_to_physical_device_id
(
i
))
for
i
in
range
(
r
*
TP_SIZE
,
(
r
+
1
)
*
TP_SIZE
))
})
server
.
__enter__
()
print
(
f
"Server rank
{
r
}
started successfully with "
f
"
{
self
.
api_server_count
}
API servers"
)
self
.
servers
.
append
((
server
,
sargs
))
except
Exception
as
e
:
print
(
f
"Failed to start server rank
{
r
}
:
{
e
}
"
)
raise
thread
=
threading
.
Thread
(
target
=
start_server
,
args
=
(
rank
,
server_args
))
thread
.
start
()
self
.
server_threads
.
append
(
thread
)
# Wait for all servers to start
for
thread
in
self
.
server_threads
:
thread
.
join
()
# Give servers additional time to fully initialize and coordinate
time
.
sleep
(
2
)
if
len
(
self
.
servers
)
!=
self
.
dp_size
:
raise
Exception
(
"Servers failed to start"
)
return
self
.
servers
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
"""Stop all server instances."""
while
self
.
servers
:
try
:
self
.
servers
.
pop
()[
0
].
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
except
Exception
as
e
:
print
(
f
"Error stopping server:
{
e
}
"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
default_server_args
():
return
[
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"bfloat16"
,
"--max-model-len"
,
"2048"
,
"--max-num-seqs"
,
"128"
,
"--enforce-eager"
,
]
@
pytest
.
fixture
(
scope
=
"module"
,
params
=
[
1
,
4
])
def
servers
(
request
,
default_server_args
):
api_server_count
=
request
.
param
with
ExternalLBServerManager
(
MODEL_NAME
,
DP_SIZE
,
api_server_count
,
default_server_args
)
as
server_list
:
yield
server_list
@
pytest_asyncio
.
fixture
async
def
clients
(
servers
:
list
[
tuple
[
RemoteOpenAIServer
,
list
[
str
]]]):
# Create a client for each server
async
with
AsyncExitStack
()
as
stack
:
yield
[
await
stack
.
enter_async_context
(
server
.
get_async_client
())
for
server
,
_
in
servers
]
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_external_lb_single_completion
(
clients
:
list
[
openai
.
AsyncOpenAI
],
servers
:
list
[
tuple
[
RemoteOpenAIServer
,
list
[
str
]]],
model_name
:
str
)
->
None
:
async
def
make_request
(
client
:
openai
.
AsyncOpenAI
):
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
"Hello, my name is"
,
max_tokens
=
10
,
temperature
=
1.0
)
assert
completion
.
id
is
not
None
assert
completion
.
choices
is
not
None
and
len
(
completion
.
choices
)
==
1
choice
=
completion
.
choices
[
0
]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert
len
(
choice
.
text
)
>=
1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert
choice
.
finish_reason
in
(
"length"
,
"stop"
)
# Token counts can also vary, so we check they are positive.
assert
completion
.
usage
.
completion_tokens
>
0
assert
completion
.
usage
.
prompt_tokens
>
0
assert
completion
.
usage
.
total_tokens
>
0
return
completion
# Test single request to each server
for
i
,
client
in
enumerate
(
clients
):
result
=
await
make_request
(
client
)
assert
result
is
not
None
print
(
f
"Server
{
i
}
handled single completion request successfully"
)
await
asyncio
.
sleep
(
0.5
)
# Send requests to all servers in round-robin fashion
num_requests_per_server
=
25
# Total 50 requests across 2 servers
all_tasks
=
[]
for
i
,
client
in
enumerate
(
clients
):
tasks
=
[
make_request
(
client
)
for
_
in
range
(
num_requests_per_server
)]
all_tasks
.
extend
(
tasks
)
results
=
await
asyncio
.
gather
(
*
all_tasks
)
assert
len
(
results
)
==
num_requests_per_server
*
len
(
clients
)
assert
all
(
completion
is
not
None
for
completion
in
results
)
await
asyncio
.
sleep
(
0.5
)
# Second burst of requests
all_tasks
=
[]
for
i
,
client
in
enumerate
(
clients
):
tasks
=
[
make_request
(
client
)
for
_
in
range
(
num_requests_per_server
)]
all_tasks
.
extend
(
tasks
)
results
=
await
asyncio
.
gather
(
*
all_tasks
)
assert
len
(
results
)
==
num_requests_per_server
*
len
(
clients
)
assert
all
(
completion
is
not
None
for
completion
in
results
)
_
,
server_args
=
servers
[
0
]
api_server_count
=
(
server_args
.
count
(
'--api-server-count'
)
and
server_args
[
server_args
.
index
(
'--api-server-count'
)
+
1
]
or
1
)
print
(
f
"Successfully completed external LB test with
{
len
(
clients
)
}
servers "
f
"(API server count:
{
api_server_count
}
)"
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
],
)
async
def
test_external_lb_completion_streaming
(
clients
:
list
[
openai
.
AsyncOpenAI
],
servers
:
list
[
tuple
[
RemoteOpenAIServer
,
list
[
str
]]],
model_name
:
str
)
->
None
:
prompt
=
"What is an LLM?"
async
def
make_streaming_request
(
client
:
openai
.
AsyncOpenAI
):
# Perform a non-streaming request to get the expected full output
single_completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
max_tokens
=
5
,
temperature
=
0.0
,
)
single_output
=
single_completion
.
choices
[
0
].
text
# Perform the streaming request
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
max_tokens
=
5
,
temperature
=
0.0
,
stream
=
True
)
chunks
:
list
[
str
]
=
[]
finish_reason_count
=
0
last_chunk
=
None
async
for
chunk
in
stream
:
chunks
.
append
(
chunk
.
choices
[
0
].
text
)
if
chunk
.
choices
[
0
].
finish_reason
is
not
None
:
finish_reason_count
+=
1
last_chunk
=
chunk
# Keep track of the last chunk
# finish reason should only return in the last block for OpenAI API
assert
finish_reason_count
==
1
,
(
"Finish reason should appear exactly once."
)
assert
last_chunk
is
not
None
,
(
"Stream should have yielded at least one chunk."
)
assert
last_chunk
.
choices
[
0
].
finish_reason
==
"length"
,
"Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert
""
.
join
(
chunks
)
==
single_output
,
"Streamed output should match non-streamed output."
return
True
# Indicate success for this request
# Test single request to each server
for
i
,
client
in
enumerate
(
clients
):
result
=
await
make_streaming_request
(
client
)
assert
result
is
not
None
print
(
f
"Server
{
i
}
handled single streaming request successfully"
)
await
asyncio
.
sleep
(
0.5
)
# Send streaming requests to all servers in round-robin fashion
num_requests_per_server
=
25
# Total 50 requests across 2 servers
all_tasks
=
[]
for
i
,
client
in
enumerate
(
clients
):
tasks
=
[
make_streaming_request
(
client
)
for
_
in
range
(
num_requests_per_server
)
]
all_tasks
.
extend
(
tasks
)
results
=
await
asyncio
.
gather
(
*
all_tasks
)
assert
len
(
results
)
==
num_requests_per_server
*
len
(
clients
)
assert
all
(
results
),
"Not all streaming requests completed successfully."
await
asyncio
.
sleep
(
0.5
)
# Second burst of streaming requests
all_tasks
=
[]
for
i
,
client
in
enumerate
(
clients
):
tasks
=
[
make_streaming_request
(
client
)
for
_
in
range
(
num_requests_per_server
)
]
all_tasks
.
extend
(
tasks
)
results
=
await
asyncio
.
gather
(
*
all_tasks
)
assert
len
(
results
)
==
num_requests_per_server
*
len
(
clients
)
assert
all
(
results
),
"Not all streaming requests completed successfully."
_
,
server_args
=
servers
[
0
]
api_server_count
=
(
server_args
.
count
(
'--api-server-count'
)
and
server_args
[
server_args
.
index
(
'--api-server-count'
)
+
1
]
or
1
)
print
(
f
"Successfully completed external LB streaming test with "
f
"
{
len
(
clients
)
}
servers (API server count:
{
api_server_count
}
)"
)
tests/v1/test_oracle.py
View file @
99324e25
...
...
@@ -12,8 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1
=
[
"openai/whisper-large-v3"
,
# transcription
"facebook/bart-large-cnn"
,
# encoder decoder
"mistralai/Mamba-Codestral-7B-v0.1"
,
# mamba
"hmellor/tiny-random-BambaForCausalLM"
,
# hybrid
"state-spaces/mamba-130m-hf"
,
# mamba1
"BAAI/bge-m3"
,
# embedding
]
...
...
@@ -74,12 +73,6 @@ def test_unsupported_configs(monkeypatch):
disable_async_output_proc
=
True
,
).
create_engine_config
()
with
pytest
.
raises
(
NotImplementedError
):
AsyncEngineArgs
(
model
=
MODEL
,
scheduling_policy
=
"priority"
,
).
create_engine_config
()
with
pytest
.
raises
(
NotImplementedError
):
AsyncEngineArgs
(
model
=
MODEL
,
...
...
tests/v1/test_request.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.v1.request
import
RequestStatus
def
test_request_status_fmt_str
():
"""Test that the string representation of RequestStatus is correct."""
assert
f
"
{
RequestStatus
.
WAITING
}
"
==
"WAITING"
assert
f
"
{
RequestStatus
.
WAITING_FOR_FSM
}
"
==
"WAITING_FOR_FSM"
assert
f
"
{
RequestStatus
.
WAITING_FOR_REMOTE_KVS
}
"
==
"WAITING_FOR_REMOTE_KVS"
assert
f
"
{
RequestStatus
.
RUNNING
}
"
==
"RUNNING"
assert
f
"
{
RequestStatus
.
PREEMPTED
}
"
==
"PREEMPTED"
assert
f
"
{
RequestStatus
.
FINISHED_STOPPED
}
"
==
"FINISHED_STOPPED"
assert
f
"
{
RequestStatus
.
FINISHED_LENGTH_CAPPED
}
"
==
"FINISHED_LENGTH_CAPPED"
assert
f
"
{
RequestStatus
.
FINISHED_ABORTED
}
"
==
"FINISHED_ABORTED"
assert
f
"
{
RequestStatus
.
FINISHED_IGNORED
}
"
==
"FINISHED_IGNORED"
tests/v1/tpu/test_basic.py
View file @
99324e25
...
...
@@ -67,6 +67,43 @@ def test_basic(
assert
"1024"
in
output
or
"0, 1"
in
output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This is a basic test for TPU only"
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"max_num_seqs"
,
[
16
])
def
test_phi3
(
vllm_runner
:
type
[
VllmRunner
],
monkeypatch
:
pytest
.
MonkeyPatch
,
max_tokens
:
int
,
max_num_seqs
:
int
,
)
->
None
:
prompts
=
[
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
"The greatest glory in living lies not in never falling,"
,
]
answers
=
[
" or, by violating privacy"
,
" what is essential is love."
,
" but in rising every time we fall."
,
]
# test head dim = 96
model
=
"microsoft/Phi-3-mini-128k-instruct"
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
vllm_runner
(
model
,
max_num_batched_tokens
=
256
,
max_num_seqs
=
max_num_seqs
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompts
,
max_tokens
)
# vllm_outputs is a list of tuples whose first element is the token id
# and the second element is the output (including the prompt).
for
output
,
answer
in
zip
(
vllm_outputs
,
answers
):
generated_text
=
output
[
1
]
assert
answer
in
generated_text
TP_SIZE_8
=
8
...
...
tests/v1/tpu/test_kv_cache_update_kernel.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
pytest
import
torch
import
torch_xla
import
vllm.v1.attention.backends.pallas
# noqa: F401
from
vllm.platforms
import
current_platform
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This is a test for TPU only"
)
@
pytest
.
mark
.
parametrize
(
"page_size"
,
[
32
,
33
])
@
pytest
.
mark
.
parametrize
(
"combined_kv_head_num"
,
[
2
,
16
])
@
pytest
.
mark
.
parametrize
(
"head_dim"
,
[
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"num_slices_per_block"
,
[
4
,
8
])
def
test_kv_cache_update_kernel
(
page_size
:
int
,
combined_kv_head_num
:
int
,
head_dim
:
int
,
num_slices_per_block
:
int
):
page_num
=
1000
padded_num_tokens
=
128
kv_cache_cpu
=
torch
.
zeros
(
(
page_num
*
page_size
,
combined_kv_head_num
,
head_dim
),
dtype
=
torch
.
bfloat16
,
device
=
"cpu"
)
kv_cache_xla
=
kv_cache_cpu
.
to
(
torch_xla
.
device
())
new_kv_cpu
=
torch
.
randn
(
(
padded_num_tokens
,
combined_kv_head_num
,
head_dim
),
dtype
=
torch
.
bfloat16
,
device
=
"cpu"
)
new_kv_xla
=
new_kv_cpu
.
to
(
torch_xla
.
device
())
slice_lens
=
np
.
array
([
7
,
page_size
,
page_size
,
1
,
1
,
1
,
9
],
dtype
=
np
.
int32
)
num_kv_update_slices
=
len
(
slice_lens
)
kv_cache_start_indices
=
np
.
array
([
page_size
*
2
-
7
,
page_size
*
2
,
page_size
*
3
,
page_size
*
4
+
6
,
page_size
*
5
+
7
,
page_size
*
6
+
8
,
page_size
*
15
+
3
],
dtype
=
np
.
int32
)
new_kv_cache_indices
=
np
.
concatenate
(
[
np
.
array
([
0
],
dtype
=
np
.
int32
),
np
.
cumsum
(
slice_lens
[:
-
1
])])
slot_mapping
=
np
.
stack
(
[
kv_cache_start_indices
,
new_kv_cache_indices
,
slice_lens
],
axis
=
1
)
padded_size
=
(
slot_mapping
.
shape
[
0
]
+
num_slices_per_block
-
1
)
//
num_slices_per_block
*
num_slices_per_block
slot_mapping
=
np
.
pad
(
slot_mapping
,
[[
0
,
padded_size
-
slot_mapping
.
shape
[
0
]],
[
0
,
0
]],
constant_values
=
0
)
slot_mapping
=
np
.
transpose
(
slot_mapping
)
slot_mapping_cpu
=
torch
.
tensor
(
slot_mapping
,
device
=
"cpu"
,
dtype
=
torch
.
int32
)
slot_mapping_xla
=
slot_mapping_cpu
.
to
(
torch_xla
.
device
())
num_kv_update_slices_xla
=
torch
.
tensor
([
num_kv_update_slices
],
device
=
torch_xla
.
device
(),
dtype
=
torch
.
int32
)
torch_xla
.
sync
()
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
kv_cache_xla
,
True
)
new_kv_cache_xla
=
torch
.
ops
.
xla
.
kv_cache_update_op
(
new_kv_xla
,
slot_mapping_xla
,
kv_cache_xla
,
num_kv_update_slices_xla
,
page_size
,
num_slices_per_block
)
kv_cache_xla
.
copy_
(
new_kv_cache_xla
)
torch_xla
.
sync
()
for
ni
,
ci
,
sl
in
zip
(
new_kv_cache_indices
,
kv_cache_start_indices
,
slice_lens
):
kv_cache_cpu
[
ci
:
ci
+
sl
,
:,
:]
=
new_kv_cpu
[
ni
:
ni
+
sl
,
:,
:]
assert
torch
.
allclose
(
kv_cache_xla
.
cpu
(),
kv_cache_cpu
,
atol
=
1e-4
,
rtol
=
1e-4
)
tests/v1/tpu/test_pallas.py
View file @
99324e25
...
...
@@ -47,7 +47,7 @@ def test_ragged_paged_attention():
key
=
torch
.
zeros
(
num_tokens
,
num_kv_heads
*
head_size
)
value
=
torch
.
zeros
(
num_tokens
,
num_kv_heads
*
head_size
)
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
head_size
)
slot_mapping
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int64
)
slot_mapping
=
torch
.
zeros
(
(
3
,
num_tokens
)
,
dtype
=
torch
.
int64
)
max_num_reqs
=
8
max_num_blocks_per_req
=
8
block_tables
=
torch
.
zeros
((
max_num_reqs
,
max_num_blocks_per_req
),
...
...
@@ -65,6 +65,7 @@ def test_ragged_paged_attention():
context_lens
=
context_lens
,
query_start_loc
=
query_start_loc
,
num_seqs
=
num_seqs
,
num_slices_per_kv_cache_update_block
=
8
,
)
with
patch
(
"torch.ops.xla.ragged_paged_attention"
...
...
tests/v1/tpu/test_spmd_model_weight_loading.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
tempfile
...
...
tests/v1/tpu/test_tpu_qkv_linear.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tempfile
import
numpy
as
np
...
...
Prev
1
…
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment