Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99324e25
Commit
99324e25
authored
Jul 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.2' into v0.9.2-ori
parents
cc7f22a8
a5dd03c1
Changes
475
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1543 additions
and
205 deletions
+1543
-205
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+302
-3
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+2
-2
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+6
-6
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+2
-1
tests/v1/sample/test_logits_processors.py
tests/v1/sample/test_logits_processors.py
+627
-0
tests/v1/sample/test_logprobs_e2e.py
tests/v1/sample/test_logprobs_e2e.py
+4
-3
tests/v1/sample/test_rejection_sampler.py
tests/v1/sample/test_rejection_sampler.py
+7
-7
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+5
-149
tests/v1/sample/test_topk_topp_sampler.py
tests/v1/sample/test_topk_topp_sampler.py
+7
-6
tests/v1/sample/utils.py
tests/v1/sample/utils.py
+80
-1
tests/v1/spec_decode/test_eagle.py
tests/v1/spec_decode/test_eagle.py
+13
-10
tests/v1/test_async_llm_dp.py
tests/v1/test_async_llm_dp.py
+43
-8
tests/v1/test_external_lb_dp.py
tests/v1/test_external_lb_dp.py
+312
-0
tests/v1/test_oracle.py
tests/v1/test_oracle.py
+1
-8
tests/v1/test_request.py
tests/v1/test_request.py
+16
-0
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+37
-0
tests/v1/tpu/test_kv_cache_update_kernel.py
tests/v1/tpu/test_kv_cache_update_kernel.py
+75
-0
tests/v1/tpu/test_pallas.py
tests/v1/tpu/test_pallas.py
+2
-1
tests/v1/tpu/test_spmd_model_weight_loading.py
tests/v1/tpu/test_spmd_model_weight_loading.py
+1
-0
tests/v1/tpu/test_tpu_qkv_linear.py
tests/v1/tpu/test_tpu_qkv_linear.py
+1
-0
No files found.
Too many changes to show.
To preserve performance only
475 of 475+
files are displayed.
Plain diff
Email patch
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
time
import
uuid
from
collections
import
defaultdict
from
typing
import
Optional
from
unittest.mock
import
patch
import
pytest
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
NixlConnectorMetadata
)
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
NixlConnectorWorker
)
from
vllm.forward_context
import
ForwardContext
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
def
test_basic_in
f
erface
():
def
test_basic_in
t
erface
():
"""Unit test for basic NixlConnector interface functionality."""
"""Unit test for basic NixlConnector interface functionality."""
vllm_config
=
create_vllm_config
()
vllm_config
=
create_vllm_config
()
...
@@ -25,7 +35,7 @@ def test_basic_inferface():
...
@@ -25,7 +35,7 @@ def test_basic_inferface():
scheduler
.
add_request
(
request
)
scheduler
.
add_request
(
request
)
# Remote Prefill, triggers NixlConnectorMetdata.
# Remote Prefill, triggers NixlConnectorMet
a
data.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
kv_connector_metadata
=
scheduler_output
.
kv_connector_metadata
kv_connector_metadata
=
scheduler_output
.
kv_connector_metadata
assert
kv_connector_metadata
is
not
None
assert
kv_connector_metadata
is
not
None
...
@@ -72,3 +82,292 @@ def test_prompt_less_than_block_size():
...
@@ -72,3 +82,292 @@ def test_prompt_less_than_block_size():
# This request should be scheduled regularly.
# This request should be scheduled regularly.
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
class FakeNixlWrapper:
    """Stand-in for NixlWrapper used by the connector unit tests.

    This class intentionally does not derive from nixl._api.nixl_agent,
    because nixl may not be installed where the tests run.
    """

    AGENT_METADATA = b"fake_agent_metadata"
    REMOTE_AGENT_NAME = "remote_agent"

    def __init__(self, agent_name: str, *args, **kwargs):
        # How many times a handle reports "PROC" before it reports "DONE".
        self._cycles_before_xfer_done = 0
        # Per-handle poll counter; unseen handles start at zero.
        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
            int)

    @staticmethod
    def _random_descs(items) -> list:
        """Return one random UUID string per element of ``items``."""
        return [str(uuid.uuid4()) for _ in items]

    def get_reg_descs(self, caches_data, memory_type: str) -> list:
        """Return a fake registration descriptor for each cache entry."""
        return self._random_descs(caches_data)

    def register_memory(self, descs) -> None:
        """Accept the descriptors and do nothing."""
        return None

    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
        """Return a fake transfer descriptor for each block entry."""
        return self._random_descs(blocks_data)

    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
        """Return a random integer acting as a prepared-list handle."""
        return uuid.uuid4().int

    def get_agent_metadata(self) -> bytes:
        """Return the fixed fake agent metadata."""
        return self.AGENT_METADATA

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        """Ignore the metadata and return the fixed remote agent name."""
        return self.REMOTE_AGENT_NAME

    def get_new_notifs(self) -> dict[str, list[bytes]]:
        """Return no notifications.

        Notifications feed done_sending, which is not tested yet.
        """
        return {}

    def check_xfer_state(self, handle: int) -> str:
        """Report "PROC" until the handle was polled enough, then "DONE"."""
        polls = self._check_xfer_state_cycles
        if polls[handle] < self._cycles_before_xfer_done:
            polls[handle] += 1
            return "PROC"
        return "DONE"

    def release_xfer_handle(self, handle: int) -> None:
        """Accept the handle and do nothing."""
        return None

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        """Accept the notification and do nothing."""
        return None

    def make_prepped_xfer(self,
                          xfer_type: str,
                          local_xfer_side_handle: int,
                          local_block_descs_ids: list[int],
                          remote_xfer_side_handle: int,
                          remote_block_descs_ids: list[int],
                          notif_msg: Optional[bytes] = None) -> int:
        """Return a random integer acting as a transfer handle."""
        return uuid.uuid4().int

    def transfer(self, handle: int) -> str:
        """Always report the transfer as in progress."""
        return "PROC"

    ############################################################
    # The following hooks change the fake's behavior during testing.
    ############################################################
    def set_cycles_before_xfer_done(self, cycles: int):
        """Set how many polls a transfer needs before it counts as done."""
        self._cycles_before_xfer_done = cycles
class FakeNixlConnectorWorker(NixlConnectorWorker):
    """NixlConnectorWorker whose handshake is simulated in-process."""

    REMOTE_ENGINE_ID = "remote_engine"

    def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
        """Build the real worker and remember the simulated latency.

        Args:
            hand_shake_latency: seconds `_nixl_handshake` sleeps for.
        """
        super().__init__(*args, **kwargs)
        self._hand_shake_latency = hand_shake_latency

    def _nixl_handshake(self, host: str, port: int,
                        remote_tp_size: int) -> dict[int, str]:
        """Sleep, fill in dummy cache geometry and register a fake remote.

        Returns:
            A dict mapping key 0 to the name returned by `add_remote_agent`.
        """
        # Mimic a slow _nixl_handshake and bypass zmq communication.
        time.sleep(self._hand_shake_latency)

        # register_kv_caches(), called by gpu_model_runner, would normally
        # set these. Dummy values are hardcoded here instead.
        self.slot_size_bytes = 4096
        self.block_len = self.slot_size_bytes * self.block_size
        self.num_blocks = 1
        self.dst_num_blocks[self.engine_id] = self.num_blocks

        remote_metadata = NixlAgentMetadata(
            engine_id=self.REMOTE_ENGINE_ID,
            agent_metadata=FakeNixlWrapper.AGENT_METADATA,
            kv_caches_base_addr=[0],
            num_blocks=1,
            block_len=self.block_len,
            attn_backend_name=self.backend_name,
        )
        agent_name = self.add_remote_agent(remote_metadata,
                                           remote_tp_size=remote_tp_size)
        return {0: agent_name}
class TestNixlHandshake:
    """Handshake and KV-load tests that run NixlConnector against fakes."""

    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    def test_multi_xfer_one_engine(
        self,
        # dist_init is a fixture that initializes the distributed environment.
        dist_init):
        """Several xfers to the same engine for one request must all finish.

        The connector is asked to load remote KV for the same `request_id`
        over several rounds. Because of `set_cycles_before_xfer_done` the
        transfers do not complete at once, so several transfer states
        coexist for that `request_id`; `get_finished` has to wait for all
        of them before reporting the request.
        """
        config = create_vllm_config()
        request_id = "req_id"

        # Worker role, as on the decode server.
        connector = NixlConnector(config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            config, connector.engine_id, hand_shake_latency=0)
        fake_wrapper = connector.connector_worker.nixl_wrapper
        assert isinstance(fake_wrapper, FakeNixlWrapper)
        fake_wrapper.set_cycles_before_xfer_done(3)

        remaining_xfers = 4
        while True:
            # Start a new xfer for the same request_id in each of the first
            # rounds of simulated `execute_model` calls.
            metadata = NixlConnectorMetadata()
            if remaining_xfers > 0:
                remaining_xfers -= 1
                base = remaining_xfers
                metadata.add_new_req(
                    request_id=request_id,
                    local_block_ids=[base + 1, base + 2, base + 3],
                    kv_transfer_params={
                        "remote_block_ids": [base + 4, base + 5, base + 6],
                        "remote_engine_id":
                        FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                        "remote_host": "localhost",
                        "remote_port": 1234,
                        "remote_tp_size": 1,
                    })
            connector.bind_connector_metadata(metadata)

            # Mimic maybe_setup_kv_connector in gpu_model_runner.
            fwd_ctx = ForwardContext(
                no_compile_layers={},
                attn_metadata={},
                virtual_engine=0,
            )
            load_started = time.perf_counter()
            connector.start_load_kv(fwd_ctx)
            load_elapsed = time.perf_counter() - load_started
            assert load_elapsed < 0.1, \
                f"start_load_kv took {load_elapsed} seconds"

            # Mimic get_finished_kv_transfers in gpu_model_runner.
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if done_recving:
                assert request_id in done_recving
                break

            connector.clear_connector_metadata()

    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    @pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
        (1, 1),
        (2, 1),
        (4, 2),
        (4, 4),
    ])
    def test_async_load_kv(
            self,
            # Fixture that initializes the distributed environment.
            dist_init,
            # Simulate consumer-producer TP sizes.
            decode_tp_size,
            prefill_tp_size):
        """NixlConnector.start_load_kv must not block on the handshake."""
        config = create_vllm_config()
        config.parallel_config.tensor_parallel_size = decode_tp_size

        # Worker role, as on the decode server.
        connector = NixlConnector(config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            config, connector.engine_id)

        metadata = NixlConnectorMetadata()
        metadata.add_new_req(request_id="id",
                             local_block_ids=[1, 2, 3],
                             kv_transfer_params={
                                 "remote_block_ids": [4, 5, 6],
                                 "remote_engine_id":
                                 FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                 "remote_host": "localhost",
                                 "remote_port": 1234,
                                 "remote_tp_size": prefill_tp_size,
                             })
        connector.bind_connector_metadata(metadata)

        timeout = 2.5
        started = time.perf_counter()
        while time.perf_counter() - started < timeout:
            fwd_ctx = ForwardContext(
                no_compile_layers={},
                attn_metadata={},
                virtual_engine=0,
            )
            load_started = time.perf_counter()
            connector.start_load_kv(fwd_ctx)
            load_elapsed = time.perf_counter() - load_started
            assert load_elapsed < 0.1, \
                f"start_load_kv took {load_elapsed} seconds"

            # Back off so the async handshake can complete.
            time.sleep(0.5)
            connector.bind_connector_metadata(NixlConnectorMetadata())
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if done_recving:
                return
        raise TimeoutError("Took too long to complete async handshake.")

    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    def test_concurrent_load_kv(
        self,
        # dist_init is a fixture that initializes the distributed environment.
        dist_init):
        """Several requests bound at once must all finish loading in time."""
        config = create_vllm_config()

        # Worker role, as on the decode server.
        connector = NixlConnector(config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            config, connector.engine_id)

        metadata = NixlConnectorMetadata()
        total_reqs = 5
        for req_idx in range(total_reqs):
            metadata.add_new_req(request_id=f"id_{req_idx}",
                                 local_block_ids=[1, 2, 3],
                                 kv_transfer_params={
                                     "remote_block_ids": [4, 5, 6],
                                     "remote_engine_id":
                                     FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                     "remote_host": "localhost",
                                     "remote_port": 1234,
                                     "remote_tp_size": 1,
                                 })
        connector.bind_connector_metadata(metadata)

        timeout = 2.5 * total_reqs
        finished_so_far = 0
        started = time.perf_counter()
        while time.perf_counter() - started < timeout:
            fwd_ctx = ForwardContext(
                no_compile_layers={},
                attn_metadata={},
                virtual_engine=0,
            )
            load_started = time.perf_counter()
            connector.start_load_kv(fwd_ctx)
            load_elapsed = time.perf_counter() - load_started
            assert load_elapsed < 0.1, \
                f"start_load_kv took {load_elapsed} seconds"

            # Back off so the async handshake can complete.
            time.sleep(0.5)
            connector.bind_connector_metadata(NixlConnectorMetadata())
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if done_recving:
                finished_so_far += len(done_recving)
                if finished_so_far == total_reqs:
                    return
        raise TimeoutError("Took too long to complete async handshake.")
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
99324e25
...
@@ -66,7 +66,7 @@ def test_basic_lifecycle():
...
@@ -66,7 +66,7 @@ def test_basic_lifecycle():
assert
len
(
scheduler_output
.
finished_req_ids
)
==
1
assert
len
(
scheduler_output
.
finished_req_ids
)
==
1
assert
request_id
in
scheduler_output
.
finished_req_ids
assert
request_id
in
scheduler_output
.
finished_req_ids
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
0
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
# (2b): execute_model()
# (2b): execute_model()
...
@@ -81,7 +81,7 @@ def test_basic_lifecycle():
...
@@ -81,7 +81,7 @@ def test_basic_lifecycle():
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler_output
.
finished_req_ids
)
==
0
assert
len
(
scheduler_output
.
finished_req_ids
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
0
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
# (3b): execute_model()
# (3b): execute_model()
...
...
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
View file @
99324e25
...
@@ -36,7 +36,7 @@ def test_basic_lifecycle():
...
@@ -36,7 +36,7 @@ def test_basic_lifecycle():
# Nothing running and empty scheduler output.
# Nothing running and empty scheduler output.
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
0
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
0
assert
len
(
scheduler_output
.
num_scheduled_tokens
)
==
0
assert
len
(
scheduler_output
.
num_scheduled_tokens
)
==
0
assert
scheduler_output
.
total_num_scheduled_tokens
==
0
assert
scheduler_output
.
total_num_scheduled_tokens
==
0
...
@@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
...
@@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
1
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
1
model_runner_output
=
create_model_runner_output
(
model_runner_output
=
create_model_runner_output
(
[
request_local_a
,
request_local_b
])
[
request_local_a
,
request_local_b
])
...
@@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
...
@@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
2
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
2
model_runner_output
=
create_model_runner_output
(
model_runner_output
=
create_model_runner_output
(
reqs
=
[
request_local_a
,
request_local_b
])
reqs
=
[
request_local_a
,
request_local_b
])
...
@@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
...
@@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
2
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
2
# STEP 4: KVs arrive.
# STEP 4: KVs arrive.
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
running
)
==
2
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler
.
waiting
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
0
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
2
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
2
model_runner_output
=
create_model_runner_output
(
model_runner_output
=
create_model_runner_output
(
[
request_local_a
,
request_local_b
],
[
request_local_a
,
request_local_b
],
...
@@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
...
@@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
assert
len
(
scheduler
.
running
)
==
3
assert
len
(
scheduler
.
running
)
==
3
assert
len
(
scheduler
.
waiting
)
==
0
assert
len
(
scheduler
.
waiting
)
==
0
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
assert
len
(
scheduler_output
.
scheduled_new_reqs
)
==
1
assert
len
(
scheduler_output
.
scheduled_cached_reqs
)
==
2
assert
scheduler_output
.
scheduled_cached_reqs
.
num_reqs
==
2
model_runner_output
=
create_model_runner_output
(
model_runner_output
=
create_model_runner_output
(
[
request_local_a
,
request_local_b
,
request_remote
])
[
request_local_a
,
request_local_b
,
request_remote
])
...
...
tests/v1/kv_connector/unit/utils.py
View file @
99324e25
...
@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
...
@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
running
)
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
assert
len
(
scheduler
.
finished_req_ids
)
==
0
assert
len
(
scheduler
.
finished_recving_kv_req_ids
)
==
0
assert
len
(
scheduler
.
finished_recving_kv_req_ids
)
==
0
assert
len
(
scheduler
.
_cached_reqs_data
)
==
0
# EncoderCacheManager.
# EncoderCacheManager.
assert
len
(
scheduler
.
encoder_cache_manager
.
freed
)
==
0
assert
len
(
scheduler
.
encoder_cache_manager
.
freed
)
==
0
...
@@ -150,6 +149,7 @@ def create_request(
...
@@ -150,6 +149,7 @@ def create_request(
request_id
=
f
"id-
{
request_id
}
"
,
request_id
=
f
"id-
{
request_id
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
pooling_params
=
None
,
multi_modal_inputs
=
None
,
multi_modal_inputs
=
None
,
multi_modal_placeholders
=
None
,
multi_modal_placeholders
=
None
,
multi_modal_hashes
=
None
,
multi_modal_hashes
=
None
,
...
@@ -183,6 +183,7 @@ def create_model_runner_output(
...
@@ -183,6 +183,7 @@ def create_model_runner_output(
spec_token_ids
=
None
,
spec_token_ids
=
None
,
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
prompt_logprobs_dict
=
{},
pooler_output
=
None
,
finished_sending
=
finished_sending
,
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
finished_recving
=
finished_recving
,
)
)
tests/v1/sample/test_logits_processors.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
from
collections.abc
import
Callable
from
typing
import
NamedTuple
,
Optional
,
Union
import
numpy
as
np
import
pytest
import
torch
from
tests.v1.sample.utils
import
(
LogitsprocsTestFakes
,
create_fake_logits
,
create_penalty_tensor
,
create_prompt_tokens_tensor
,
fake_apply_logitsprocs
,
fake_update_logitsprocs_state
)
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
is_pin_memory_available
# yapf: disable
from
vllm.v1.sample.logits_processor
import
(
BatchUpdate
,
BatchUpdateBuilder
,
LogitBiasLogitsProcessor
,
LogitsProcessor
,
MinPLogitsProcessor
,
MinTokensLogitsProcessor
,
MoveDirectionality
,
init_builtin_logitsprocs
)
# yapf: enable
from
vllm.v1.sample.metadata
import
SamplingMetadata
# Result of is_pin_memory_available(), evaluated once at import time and
# forwarded to init_builtin_logitsprocs() when building fake metadata.
PIN_MEMORY_AVAILABLE
=
is_pin_memory_available
()
# Request-count limit; the fake metadata builder passes MAX_NUM_REQS + 1 as
# max_num_reqs to init_builtin_logitsprocs().
MAX_NUM_REQS
=
256
# Vocabulary size used for the fake logits and the random token ids.
VOCAB_SIZE
=
1024
# Number of random output tokens generated per request in fake metadata.
NUM_OUTPUT_TOKENS
=
20
# Device strings "<device_type>:<i>" to test on: index 0 only when the
# platform reports exactly one device, otherwise indices 0 and 1.
CUDA_DEVICES
=
[
f
"
{
current_platform
.
device_type
}
:
{
i
}
"
for
i
in
range
(
1
if
current_platform
.
device_count
()
==
1
else
2
)
]
# Exclusive upper bound on the random prompt length in fake metadata.
MAX_NUM_PROMPT_TOKENS
=
64
# Base length used to size each request's dummy out_tokens list.
MIN_TOKENS_LEN_THRESHOLD
=
5
# NOTE(review): presumably the number of requests generated per logitproc
# type; its use is outside the visible part of the file - confirm.
REQS_PER_LOGITPROC
=
50
# String id standing for "no logitproc enabled".
STR_NO_LOGITPROC
=
"none"
# Type alias: a LogitsProcessor subclass, or a str id such as "none".
LogitprocType
=
Union
[
type
[
LogitsProcessor
],
str
]
class LogitsProcsRequestParams:
    """Encapsulates key params for a single request in a batch.

    Params can be customized based on the enabled logitproc.
    """
    # Position of this request in the generated workload.
    workload_index: int
    # Logitproc enabled for this request: a LogitsProcessor subclass or a
    # str id.
    logitproc_type: LogitprocType
    # Dummy output tokens, needed by the min-tokens test.
    out_tokens: list[int]
    # SamplingParams customized for the enabled logitproc.
    params: SamplingParams

    def __init__(self, workload_index: int, logitproc_type: LogitprocType):
        """Record the request identity and derive its randomized settings.

        Args:
            workload_index: index of this request in the workload.
            logitproc_type: logitproc enabled for this request.
        """
        self.workload_index = workload_index
        self.logitproc_type = logitproc_type
        # The number of output tokens is randomly 0, 1x or 2x the
        # min-tokens threshold, because randint(0, 2) is inclusive at both
        # ends. Token values don't matter *for these tests*, so 0 is used
        # as a dummy value.
        self.out_tokens = ([0] *
                           (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
        self.params = _sampling_params_from_logitproc(logitproc_type)

    def __str__(self):
        """Return a debug summary of all instance attributes."""
        summ = ', '.join(f'{k}={v}' for k, v in vars(self).items())
        # Use the real class name; the summary used to say "MyClass".
        return f"{type(self).__name__}({summ})"
def _generate_fake_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    """Build greedy, penalty-free SamplingMetadata with builtin logitsprocs.

    Args:
        num_output_tokens: number of random output tokens per request.
        batch_size: number of requests to generate.
        vocab_size: exclusive upper bound for the random token ids.
        device: device passed to the tensor helpers and the logitsprocs.

    Returns:
        SamplingMetadata populated with random prompt and output tokens.
    """
    output_token_ids: list[list[int]] = []
    prompt_token_ids: list[list[int]] = []
    # The draws stay interleaved per request (output tokens, prompt length,
    # prompt tokens) so the order of RNG consumption is unchanged.
    for _ in range(batch_size):
        sampled_output = np.random.randint(0,
                                           vocab_size,
                                           size=num_output_tokens)
        output_token_ids.append(sampled_output.tolist())

        prompt_len = np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
        sampled_prompt = np.random.randint(0, vocab_size, size=prompt_len)
        prompt_token_ids.append(sampled_prompt.tolist())

    builtin_logitsprocs = init_builtin_logitsprocs(
        pin_memory_available=PIN_MEMORY_AVAILABLE,
        max_num_reqs=MAX_NUM_REQS + 1,
        device=device)

    return SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids,
                                                     vocab_size, device),
        output_token_ids=output_token_ids,
        frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=builtin_logitsprocs)
def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
    """Create fake logits plus matching fake sampling metadata.

    Args:
        batch_size: number of rows (requests) to generate.
        device: device string, converted with torch.device().

    Returns:
        LogitsprocsTestFakes bundling the logits and the sampling metadata.
    """
    logits = create_fake_logits(batch_size, VOCAB_SIZE)
    # Give each of the first `batch_size` rows one dominant token, to
    # support the min-p test: token 0 is high, all other tokens stay low.
    logits[:batch_size, 0] = 10.0
    logits[:batch_size, 1:] = 1e-2

    metadata = _generate_fake_sampling_metadata(NUM_OUTPUT_TOKENS, batch_size,
                                                VOCAB_SIZE,
                                                torch.device(device))
    return LogitsprocsTestFakes(
        logits=logits,
        sampling_metadata=metadata,
    )
def _sampling_params_from_logitproc(
        logitproc_type: LogitprocType) -> SamplingParams:
    """Build the SamplingParams of a request that uses `logitproc_type`.

    Starts from settings with no logitproc enabled. If the test mapping
    entry for `logitproc_type` has a truthy `gen_request_fxn`, that
    function is called to adjust the settings in place.
    """
    # Baseline: request with no logitproc.
    request_kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0}
    customize = logitsprocs_test_mapping[logitproc_type].gen_request_fxn
    if customize:
        customize(request_kwargs)
    return SamplingParams(**request_kwargs)
def _generate_mixed_logitsprocs_batch_params(
    reqs_per_logitproc: int,
    logitsprocs_types: list[str],
) -> list[LogitsProcsRequestParams]:
    """Create per-request params for a shuffled, mixed-logitproc batch.

    Every entry of `logitsprocs_types` (including the case where no
    logitsproc is enabled, if listed) is used by `reqs_per_logitproc`
    requests, and the assignment is randomly shuffled. The batch therefore
    has `reqs_per_logitproc * len(logitsprocs_types)` requests.

    Args:
      reqs_per_logitproc: number of requests using each logitproc
      logitsprocs_types: logitsprocs under test

    Returns:
      List of per-request params which configure the engine for that
      request's enabled logitproc
    """
    batch_size = len(logitsprocs_types) * reqs_per_logitproc
    # A random permutation of the slot numbers shuffles which logitproc
    # each workload index gets; slot // reqs_per_logitproc picks the type.
    shuffled_slots = random.sample(range(batch_size), k=batch_size)
    batch_params = []
    for workload_idx, slot in enumerate(shuffled_slots):
        chosen_type = logitsprocs_types[slot // reqs_per_logitproc]
        batch_params.append(
            LogitsProcsRequestParams(workload_index=workload_idx,
                                     logitproc_type=chosen_type))
    return batch_params
def
_raise_error_invalid
(
msg_suffix
:
str
,
batch_index
:
int
,
request_params
:
LogitsProcsRequestParams
,
step_idx
:
int
,
err_cls
:
type
[
Exception
]
=
ValueError
,
)
->
None
:
raise
err_cls
(
f
"Validation failed for step=
{
step_idx
}
, "
f
"batch_index=
{
batch_index
}
, "
f
"workload_index=
{
request_params
.
workload_index
}
, "
f
"req_params=
{
request_params
}
. Reason:
{
msg_suffix
}
"
)
def _logit_bias_params(kwargs: dict) -> None:
    """Enable logit bias in `kwargs` for one randomly chosen token.

    Sets `kwargs["logit_bias"]` to a single-entry dict which maps a
    random token id to a bias of either -0.1 or 0.2.
    """
    # Draw the token id before the bias value (same order in which a
    # dict display evaluates its key and value).
    biased_token_id = random.randint(0, VOCAB_SIZE - 1)
    bias = random.choice([-0.1, 0.2])
    kwargs["logit_bias"] = {biased_token_id: bias}
def _logit_bias_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Check that logit bias was applied to exactly the biased tokens.

    For the request at `batch_index`, each biased token's new logit must
    approximately equal its original logit plus the bias, and every other
    token's logit must be approximately unchanged. The first mismatch is
    reported through `_raise_error_invalid`.
    """
    logit_bias = request_params.params.logit_bias
    workload_index = persistent_batch[batch_index].workload_index
    # Original (pre-logitsprocs) logits row for this request
    original_row = test_fakes.logits[workload_index].cpu()
    processed_row = logits_new[batch_index].cpu()
    for token_id in range(VOCAB_SIZE):
        before = original_row[token_id]
        after = processed_row[token_id]
        if token_id in logit_bias:
            bias_value = logit_bias[token_id]
            expected = bias_value + before
            if after != pytest.approx(expected):
                _raise_error_invalid(
                    msg_suffix=(
                        f"Biased token {token_id} logit value {after} "
                        f"does not match expected value {expected} "
                        f"given bias {bias_value}"),
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx)
        elif after != pytest.approx(before):
            _raise_error_invalid(
                msg_suffix=(
                    f"Unbiased token {token_id} logit value {after} "
                    f"does not match expected value {before}"),
                batch_index=batch_index,
                request_params=request_params,
                step_idx=step_idx)
def
_min_p_params
(
kwargs
:
dict
)
->
None
:
"""Min-p logitproc config"""
kwargs
[
"min_p"
]
=
0.1
def _min_p_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Check the min-p masking pattern for the request at `batch_index`.

    Token 0 must never be masked with -inf. Every other token must be
    masked when the request's `min_p` is greater than 0, and must not be
    masked otherwise. The first violation is reported through
    `_raise_error_invalid`.
    """
    neg_inf = -float("inf")
    for token_id in range(VOCAB_SIZE):
        logit = logits_new[batch_index][token_id]
        if token_id == 0:
            # Dominant token should always be unmasked
            failure = ("Invalid: dominant token 0 masked (-inf)"
                       if logit == neg_inf else None)
        elif request_params.params.min_p > 0.0:
            # Non-dominant tokens should be masked when min_p > 0
            failure = (f"Invalid: non-dominant token {token_id} not masked"
                       if logit != neg_inf else None)
        else:
            # No masking when min_p is 0
            failure = (f"Invalid: token {token_id} masked when min_p=0.0"
                       if logit == neg_inf else None)
        if failure is not None:
            _raise_error_invalid(msg_suffix=failure,
                                 batch_index=batch_index,
                                 request_params=request_params,
                                 step_idx=step_idx)
def _min_tokens_params(kwargs: dict) -> None:
    """Enable the min-tokens logitproc in `kwargs`.

    Sets `min_tokens` to `MIN_TOKENS_LEN_THRESHOLD` and `stop_token_ids`
    to a randomly sized (possibly empty) list of random token ids, which
    may contain duplicates.
    """
    kwargs["min_tokens"] = MIN_TOKENS_LEN_THRESHOLD
    # NOTE: draws from numpy's global RNG, not the stdlib `random` module
    num_stop_tokens = np.random.randint(0, VOCAB_SIZE)
    stop_token_ids = []
    for _ in range(num_stop_tokens):
        stop_token_ids.append(np.random.randint(0, VOCAB_SIZE - 1))
    kwargs["stop_token_ids"] = stop_token_ids
def _min_tokens_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Check min-tokens logitproc state and logits for one request.

    First validates the `min_toks` entry which the min-tokens logits
    processor holds for `batch_index`: it must exist exactly while the
    request has fewer than `MIN_TOKENS_LEN_THRESHOLD` output tokens, and
    its output-token count and stop token ids must match the request.
    Then validates that stop tokens are masked with -inf only while the
    minimum length has not been reached, and that no other token is
    masked.

    Raises:
      ValueError: on metadata or logits mismatch
      RuntimeError: if presence/absence of processor state contradicts
        whether the minimum length has been reached
    """

    def fail(reason: str, err_cls: type[Exception] = ValueError) -> None:
        # Report a failure for the request being validated
        _raise_error_invalid(msg_suffix=reason,
                             batch_index=batch_index,
                             request_params=request_params,
                             step_idx=step_idx,
                             err_cls=err_cls)

    expected_out_len = len(request_params.out_tokens)
    min_reached = expected_out_len >= MIN_TOKENS_LEN_THRESHOLD
    expected_stop_ids = request_params.params.all_stop_token_ids
    processor: MinTokensLogitsProcessor = next(
        test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor))
    assert isinstance(processor, MinTokensLogitsProcessor)
    tracked = processor.min_toks.get(batch_index, None)

    # Validate min-token logits processor state
    if tracked:
        (_, tracked_out_tokens, tracked_stop_ids) = tracked
        tracked_out_len = len(tracked_out_tokens)
        if tracked_out_len != expected_out_len:
            fail("Number of output tokens in min-token logit processor "
                 f"request metadata ({tracked_out_len}) does not match "
                 f"reference ({expected_out_len}).")
        if expected_stop_ids != tracked_stop_ids:
            fail("Stop token ids do not match reference; "
                 "all_stop_token_ids: "
                 f"{sorted(tracked_stop_ids)}, ref_all_stop_token_ids: "
                 f"{sorted(expected_stop_ids)}")
        if min_reached:
            fail(("Expected min-tokens request with min reached, but batch "
                  "index is recognized by min-tokens logits processor."),
                 err_cls=RuntimeError)
    elif not min_reached:
        fail(("Expected min-tokens request with min not reached, but batch "
              "index is not recognized by min-tokens logits processor."),
             err_cls=RuntimeError)

    # Validate min-token logits
    neg_inf = -float("inf")
    for token_id in range(VOCAB_SIZE):
        logit = logits_new[batch_index][token_id]
        must_be_masked = token_id in expected_stop_ids and not min_reached
        if must_be_masked:
            if logit != neg_inf:
                fail(f"Token {token_id} is a stop token and "
                     "the sequence has not reached min length, "
                     "but the token is not masked "
                     f"(logit={logit})")
        elif logit == neg_inf:
            fail(f"Token {token_id} should not be masked but "
                 f"is (output len={expected_out_len})")
def _none_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Check that the logits of a no-logitproc request are untouched.

    Compares the request's row of `logits_new` against its original row
    in `test_fakes.logits`; on any difference, every mismatching token is
    listed in the error raised via `_raise_error_invalid`.
    """
    workload_index = persistent_batch[batch_index].workload_index
    logits = test_fakes.logits[workload_index].cpu()
    ref_logits = logits_new[batch_index]
    if torch.all(ref_logits == logits):
        return

    differs = ref_logits != logits
    mismatch_strs = []
    # NOTE: the local names `token`, `val` and `ref_val` appear verbatim
    # in the message because of the `{name=}` f-string specifiers.
    for token in differs.nonzero(as_tuple=True)[0].tolist():
        val = float(logits[token])
        ref_val = float(ref_logits[token])
        mismatch_strs.append(f"({token=},{val=},{ref_val=})")
    mismatches = ','.join(mismatch_strs)
    _raise_error_invalid(
        msg_suffix=f"Unexpected modification of logits: {mismatches}",
        batch_index=batch_index,
        request_params=request_params,
        step_idx=step_idx)
class LogitsprocTestHelpers(NamedTuple):
    """Supports setting up and validating logitsprocs unit tests."""
    # Validation callback for requests using this logitproc; invoked by
    # `_assert_valid` with keyword arguments `test_fakes`,
    # `persistent_batch`, `logits_new`, `batch_index`, `request_params`
    # and `step_idx`.
    eval_fxn: Callable
    # Optional hook called with the SamplingParams kwargs dict, which it
    # customizes in place to enable this logitproc; None means the
    # default kwargs are used unchanged.
    gen_request_fxn: Optional[Callable] = None
# Test helpers for each logitproc under test, keyed by logitproc type.
# `STR_NO_LOGITPROC` stands for requests with no logitproc enabled.
# Insertion order matters: `_get_test_cases()` derives the parametrized
# test cases from the key order.
logitsprocs_test_mapping = {
    STR_NO_LOGITPROC:
    LogitsprocTestHelpers(eval_fxn=_none_validate),
    LogitBiasLogitsProcessor:
    LogitsprocTestHelpers(eval_fxn=_logit_bias_validate,
                          gen_request_fxn=_logit_bias_params),
    MinPLogitsProcessor:
    LogitsprocTestHelpers(eval_fxn=_min_p_validate,
                          gen_request_fxn=_min_p_params),
    MinTokensLogitsProcessor:
    LogitsprocTestHelpers(eval_fxn=_min_tokens_validate,
                          gen_request_fxn=_min_tokens_params),
}
def _get_test_cases() -> list[list[str]]:
    """Enumerate the sets of logitsprocs to test together.

    Returns, in order: the no-logitproc case alone; one case per
    logitproc pairing it with the no-logitproc case; and finally one case
    with every key of `logitsprocs_test_mapping`.
    """
    all_types = list(logitsprocs_test_mapping)
    test_cases = [[STR_NO_LOGITPROC]]
    for logitproc_type in all_types:
        if logitproc_type != STR_NO_LOGITPROC:
            test_cases.append([logitproc_type, STR_NO_LOGITPROC])
    test_cases.append(all_types)
    return test_cases
def _generate_fake_step_update(
    persistent_batch: list[LogitsProcsRequestParams],
    workload_params: list[LogitsProcsRequestParams],
    wdx: int,
    batch_update_builder: BatchUpdateBuilder,
) -> tuple[Optional[BatchUpdate], int, int]:
    """Simulate one engine step of persistent batch adds/removes/moves.

    Randomly chooses how many requests to remove from `persistent_batch`
    and how many to add from `workload_params[wdx:]`, records these
    changes in `batch_update_builder`, then simulates condensing the
    batch and a random reordering via pairwise swaps.

    `persistent_batch` is modified in place so that it mirrors the batch
    state described by the returned update.

    Args:
      persistent_batch: current persistent batch (mutated in place)
      workload_params: all requests in the workload, in add order
      wdx: index of the next workload request to add
      batch_update_builder: builder which accumulates the step's changes

    Returns:
      Tuple of (result of `batch_update_builder.get_and_reset()` for the
      condensed batch size, updated `wdx`, number of workload requests
      not yet added)
    """
    batch_size = len(persistent_batch)
    workload_size = len(workload_params)
    workload_reqs_remaining = workload_size - wdx
    # Cap on adds/removes per step: 20% of the workload, but at least 1
    max_add_remove_per_step = max(1, int(0.2 * workload_size))

    # 50% of steps: add no reqs
    # Other 50%: add a limited number of reqs (less than the number
    # of workload reqs remaining, less than an arbitrary max)
    # If no workload reqs remain: 100% of steps have 0 adds
    num_step_add = random.choice([
        0,
        random.randint(1, min(max_add_remove_per_step,
                              workload_reqs_remaining))
    ]) if workload_reqs_remaining else 0

    # 50% of steps: remove no requests
    # Other 50%: remove a limited number of reqs (less than the number
    # persistent batch reqs remaining, less than an arbitrary max)
    # If persistent batch is empty: 100% of steps have 0 removals until
    # more requests are added. Assume that removed requests are always
    # drawn from the current batch, before new adds
    num_step_remove = random.choice([
        0, random.randint(1, min(max_add_remove_per_step, batch_size))
    ]) if batch_size else 0

    # Number of added requests which fill slots vacated by removals
    num_step_add_replace = min(num_step_add, num_step_remove)

    # Generate fake removed request indices drawn from persistent batch indices
    for removal in random.sample(range(batch_size), num_step_remove):
        batch_update_builder.removed_append(removal)

    # Get added requests from workload
    for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]:
        # Replace as many removed requests as possible with added requests
        # NOTE(review): presumably pop_removed() yields the lowest removed
        # index - confirm against BatchUpdateBuilder
        add_remove_idx = batch_update_builder.pop_removed()
        batch_update_builder.added.append(
            (add_remove_idx, add_req_params.params,
             add_req_params.out_tokens))
        persistent_batch[add_remove_idx] = add_req_params

    # Append remaining added requests to end of batch
    add_reqs_append = workload_params[(wdx +
                                       num_step_add_replace):(wdx +
                                                              num_step_add)]
    batch_update_builder.added.extend([
        (adx + batch_size, add_req_params.params, add_req_params.out_tokens)
        for adx, add_req_params in enumerate(add_reqs_append)
    ])
    persistent_batch.extend(add_reqs_append)
    pre_condense_batch_size = len(persistent_batch)
    wdx += num_step_add  # Update workload offset

    # Simulate condensing persistent batch
    last_nonempty_index = pre_condense_batch_size - 1
    # Indices which already received a request moved down during condense
    condensed_to_idxs = set()
    while batch_update_builder.removed:
        if (last_nonempty_index in batch_update_builder.removed
                or last_nonempty_index in condensed_to_idxs):
            # Slot is empty or was already filled by a move; look lower
            last_nonempty_index -= 1
            continue
        # last_nonempty_index is the highest persistent batch index that was
        # not removed
        first_empty_index = batch_update_builder.peek_removed()
        assert first_empty_index is not None
        if first_empty_index > last_nonempty_index:
            # All remaining empty slots lie above every live request
            break
        # first_empty_index is the lowest removed persistent batch index
        # that is less than last_nonempty_index
        #
        # move last_nonempty_index -> first_empty_index
        batch_update_builder.pop_removed()
        condensed_to_idxs.add(first_empty_index)
        persistent_batch[first_empty_index] = persistent_batch[
            last_nonempty_index]
        batch_update_builder.moved.append(
            (last_nonempty_index, first_empty_index,
             MoveDirectionality.UNIDIRECTIONAL))

        last_nonempty_index -= 1

    # Now removed requests & gaps left by non-removed requests that got
    # moved downward are grouped consecutively in the upper indices of
    # the persistent batch. Truncate them to get condensed persistent batch
    condensed_batch_size = batch_size + num_step_add - num_step_remove
    # Slice-assign so that the caller's list object is updated in place
    persistent_batch[:] = persistent_batch[0:condensed_batch_size]

    if condensed_batch_size > 1:
        # Simulate arbitrary reorder_batch() in the kernel backend
        # Generate a random number k of non-overlapping swap tuples
        k = random.randint(0, condensed_batch_size // 2)
        idxs = list(range(condensed_batch_size))
        random.shuffle(idxs)
        # Consecutive pairs of the shuffled indices never share an index
        swaps = [
            tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)
        ]
        batch_update_builder.moved.extend([
            (sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps
        ])
        for adx, bdx in swaps:
            persistent_batch[adx], persistent_batch[bdx] = persistent_batch[
                bdx], persistent_batch[adx]

    return (batch_update_builder.get_and_reset(condensed_batch_size), wdx,
            workload_size - wdx)
def _assert_valid(
    batch_size: int,
    persistent_batch: list[LogitsProcsRequestParams],
    test_fakes: LogitsprocsTestFakes,
    slice_idxs: list[int],
    logits_w_lp: torch.Tensor,
    step_idx: int,
) -> None:
    """Validate logitsprocs output for every request in the batch.

    Each of the first `batch_size` requests in `persistent_batch` is
    checked with the validation function registered for its logitproc
    type in `logitsprocs_test_mapping`. When `slice_idxs` is empty, the
    persistent batch and the logitsprocs output must both be empty.

    Raises:
      ValueError: if the batch is empty but `logits_w_lp` has rows
    """
    if slice_idxs:
        # Validate logits for each fake request
        for batch_index in range(batch_size):
            request_params = persistent_batch[batch_index]
            # Dispatch to the validator for this request's logitproc
            helpers = logitsprocs_test_mapping[request_params.logitproc_type]
            helpers.eval_fxn(test_fakes=test_fakes,
                             persistent_batch=persistent_batch,
                             logits_new=logits_w_lp,
                             batch_index=batch_index,
                             request_params=request_params,
                             step_idx=step_idx)
        return

    # Trivial case of empty persistent batch
    assert len(persistent_batch) == 0
    if logits_w_lp.shape[0] != 0:
        raise ValueError("Fake persistent batch is empty but logitsprocs "
                         f"output batch has shape {logits_w_lp.shape}")
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs(device: str, reqs_per_logitproc: int,
                     logitsprocs_under_test: list[str]):
    """Drive logitsprocs through a randomized fake persistent batch.

    A shuffled workload of requests, each with at most one logitproc
    enabled, is gradually added to and removed from a fake persistent
    batch. After every simulated step, the batch update is applied to the
    logitsprocs, the logitsprocs are applied to the fake logits, and the
    result is validated per request.
    """
    random.seed(40)
    torch.set_default_device(device)

    # Shuffled workload of requests, each using one logitproc or none
    workload_params = _generate_mixed_logitsprocs_batch_params(
        reqs_per_logitproc=reqs_per_logitproc,
        logitsprocs_types=logitsprocs_under_test)
    workload_size = len(workload_params)

    # Fake logits & sampling metadata sized for the entire workload
    test_fakes = _generate_test_fakes(workload_size, device)

    next_workload_idx = 0  # Next request index in workload to add
    # Persistent batch state, as list of per-request params
    persistent_batch: list[LogitsProcsRequestParams] = []
    # Accumulates the added/removed/moved requests of each step
    batch_update_builder = BatchUpdateBuilder()

    workload_reqs_remaining = workload_size
    batch_size = 0
    step_idx = 0
    # Run until the entire workload has been added and the persistent
    # batch has drained
    while workload_reqs_remaining or batch_size:
        step_update = _generate_fake_step_update(
            persistent_batch=persistent_batch,
            workload_params=workload_params,
            wdx=next_workload_idx,
            batch_update_builder=batch_update_builder,
        )
        batch_update, next_workload_idx, workload_reqs_remaining = step_update
        batch_size = len(persistent_batch)

        # Apply fake batch update to logitsprocs
        fake_update_logitsprocs_state(test_fakes, batch_update)

        # Emulate application of logits processors in engine
        slice_idxs = [req.workload_index for req in persistent_batch]
        logits_w_lp = fake_apply_logitsprocs(test_fakes, slice_idxs).cpu()

        _assert_valid(
            batch_size=batch_size,
            persistent_batch=persistent_batch,
            test_fakes=test_fakes,
            slice_idxs=slice_idxs,
            logits_w_lp=logits_w_lp,
            step_idx=step_idx,
        )

        step_idx += 1
tests/v1/sample/test_logprobs_e2e.py
View file @
99324e25
...
@@ -13,9 +13,10 @@ EXPECTED_VALUE = 0.62
...
@@ -13,9 +13,10 @@ EXPECTED_VALUE = 0.62
# FIXME(rob): enable prefix caching once supported.
# FIXME(rob): enable prefix caching once supported.
MODEL
=
"meta-llama/Llama-3.2-1B-Instruct"
MODEL
=
"meta-llama/Llama-3.2-1B-Instruct"
MODEL_ARGS
=
f
"pretrained=
{
MODEL
}
,enforce_eager=True,enable_prefix_caching=False"
# noqa: E501
MODEL_ARGS
=
f
"pretrained=
{
MODEL
}
,enforce_eager=True,enable_prefix_caching=False
,gpu_memory_utilization=0.8
"
# noqa: E501
SERVER_ARGS
=
[
SERVER_ARGS
=
[
"--enforce_eager"
,
"--no_enable_prefix_caching"
,
"--disable-log-requests"
"--enforce_eager"
,
"--no_enable_prefix_caching"
,
"--disable-log-requests"
,
"--gpu-memory-utilization=0.8"
]
]
NUM_CONCURRENT
=
100
NUM_CONCURRENT
=
100
...
@@ -32,7 +33,7 @@ def test_prompt_logprobs_e2e():
...
@@ -32,7 +33,7 @@ def test_prompt_logprobs_e2e():
),
f
"Expected:
{
EXPECTED_VALUE
}
| Measured:
{
measured_value
}
"
),
f
"Expected:
{
EXPECTED_VALUE
}
| Measured:
{
measured_value
}
"
def
test_promt_logprobs_e2e_server
():
def
test_prom
p
t_logprobs_e2e_server
():
with
RemoteOpenAIServer
(
MODEL
,
SERVER_ARGS
)
as
remote_server
:
with
RemoteOpenAIServer
(
MODEL
,
SERVER_ARGS
)
as
remote_server
:
url
=
f
"
{
remote_server
.
url_for
(
'v1'
)
}
/completions"
url
=
f
"
{
remote_server
.
url_for
(
'v1'
)
}
/completions"
...
...
tests/v1/sample/test_rejection_sampler.py
View file @
99324e25
...
@@ -6,12 +6,14 @@ import pytest
...
@@ -6,12 +6,14 @@ import pytest
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.logits_processor
import
LogitsProcessorManager
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
(
PLACEHOLDER_TOKEN_ID
,
from
vllm.v1.sample.rejection_sampler
import
(
PLACEHOLDER_TOKEN_ID
,
RejectionSampler
)
RejectionSampler
)
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
DEVICE
=
"
cu
da"
DEVICE
=
cu
rrent_platform
.
device_type
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -21,7 +23,7 @@ def rejection_sampler():
...
@@ -21,7 +23,7 @@ def rejection_sampler():
def
create_logits_tensor
(
output_token_ids
:
list
[
list
[
int
]],
def
create_logits_tensor
(
output_token_ids
:
list
[
list
[
int
]],
vocab_size
:
int
=
100
)
->
torch
.
Tensor
:
vocab_size
:
int
=
100
)
->
torch
.
Tensor
:
"""Helper function to create logits tensor that
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
will produce desired token ids on argmax"""
token_ids
=
[
tokens
[:
-
1
]
for
tokens
in
output_token_ids
]
token_ids
=
[
tokens
[:
-
1
]
for
tokens
in
output_token_ids
]
num_total_tokens
=
sum
(
len
(
tokens
)
for
tokens
in
token_ids
)
num_total_tokens
=
sum
(
len
(
tokens
)
for
tokens
in
token_ids
)
...
@@ -41,8 +43,8 @@ def create_sampling_metadata(
...
@@ -41,8 +43,8 @@ def create_sampling_metadata(
top_p
:
Optional
[
torch
.
Tensor
]
=
None
,
top_p
:
Optional
[
torch
.
Tensor
]
=
None
,
generators
:
Optional
[
dict
[
int
,
Any
]]
=
None
,
generators
:
Optional
[
dict
[
int
,
Any
]]
=
None
,
)
->
SamplingMetadata
:
)
->
SamplingMetadata
:
"""Create a v1 sampling metadata object with all_greedy set
"""Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling
to the given value. Either all greedy or all random sampling
is used.
is used.
"""
"""
generators
=
generators
or
{}
generators
=
generators
or
{}
...
@@ -57,7 +59,6 @@ def create_sampling_metadata(
...
@@ -57,7 +59,6 @@ def create_sampling_metadata(
all_random
=
not
all_greedy
,
all_random
=
not
all_greedy
,
top_p
=
top_p
,
top_p
=
top_p
,
top_k
=
top_k
,
top_k
=
top_k
,
min_p
=
torch
.
empty
(
1
,
),
generators
=
generators
,
generators
=
generators
,
max_num_logprobs
=
0
,
max_num_logprobs
=
0
,
no_penalties
=
False
,
no_penalties
=
False
,
...
@@ -66,10 +67,9 @@ def create_sampling_metadata(
...
@@ -66,10 +67,9 @@ def create_sampling_metadata(
presence_penalties
=
torch
.
tensor
([]),
presence_penalties
=
torch
.
tensor
([]),
repetition_penalties
=
torch
.
tensor
([]),
repetition_penalties
=
torch
.
tensor
([]),
output_token_ids
=
[],
output_token_ids
=
[],
min_tokens
=
{},
logit_bias
=
[
None
],
allowed_token_ids_mask
=
None
,
allowed_token_ids_mask
=
None
,
bad_words_token_ids
=
{},
bad_words_token_ids
=
{},
logitsprocs
=
LogitsProcessorManager
(),
)
)
...
...
tests/v1/sample/test_sampler.py
View file @
99324e25
...
@@ -8,10 +8,13 @@ import pytest
...
@@ -8,10 +8,13 @@ import pytest
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
is_pin_memory_available
,
make_tensor_with_pad
from
vllm.v1.sample.logits_processor
import
LogitsProcessorManager
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.sampler
import
Sampler
from
vllm.v1.sample.sampler
import
Sampler
PIN_MEMORY_AVAILABLE
=
is_pin_memory_available
()
MAX_NUM_REQS
=
256
VOCAB_SIZE
=
1024
VOCAB_SIZE
=
1024
NUM_OUTPUT_TOKENS
=
20
NUM_OUTPUT_TOKENS
=
20
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
...
@@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor(
...
@@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor(
)
)
def
_create_logit_bias
(
batch_size
:
int
,
vocab_size
:
int
,
bias_value
:
float
,
)
->
list
[
Optional
[
dict
[
int
,
float
]]]:
res
:
list
[
Optional
[
dict
[
int
,
float
]]]
=
[]
for
i
in
range
(
batch_size
):
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
res
.
append
(
logit_bias
)
return
res
def
_create_allowed_token_ids
(
def
_create_allowed_token_ids
(
batch_size
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
...
@@ -145,7 +136,6 @@ def _create_default_sampling_metadata(
...
@@ -145,7 +136,6 @@ def _create_default_sampling_metadata(
all_random
=
False
,
all_random
=
False
,
top_p
=
None
,
top_p
=
None
,
top_k
=
None
,
top_k
=
None
,
min_p
=
None
,
generators
=
{},
generators
=
{},
max_num_logprobs
=
0
,
max_num_logprobs
=
0
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
...
@@ -155,43 +145,13 @@ def _create_default_sampling_metadata(
...
@@ -155,43 +145,13 @@ def _create_default_sampling_metadata(
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
no_penalties
=
True
,
no_penalties
=
True
,
min_tokens
=
{},
logit_bias
=
[
None
]
*
batch_size
,
allowed_token_ids_mask
=
None
,
allowed_token_ids_mask
=
None
,
bad_words_token_ids
=
{},
bad_words_token_ids
=
{},
logitsprocs
=
LogitsProcessorManager
(),
)
)
return
fake_sampling_metadata
return
fake_sampling_metadata
def
_generate_min_token_penalties_and_stop_tokens
(
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
batch_indices_for_min_token_penalty
:
list
[
int
]
)
->
dict
[
int
,
tuple
[
int
,
set
[
int
]]]:
"""
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
min_tokens
:
dict
[
int
,
tuple
[
int
,
set
[
int
]]]
=
{}
for
index
in
range
(
batch_size
):
if
index
in
batch_indices_for_min_token_penalty
:
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
num_output_tokens
+
1
,
2
*
num_output_tokens
),
set
(
np
.
random
.
randint
(
0
,
vocab_size
-
1
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
vocab_size
))))
else
:
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
0
,
num_output_tokens
),
set
())
return
min_tokens
def
_create_weighted_output_token_list
(
def
_create_weighted_output_token_list
(
batch_size
:
int
,
batch_size
:
int
,
vocab_size
:
int
)
->
tuple
[
list
[
list
[
int
]],
list
[
list
[
int
]]]:
vocab_size
:
int
)
->
tuple
[
list
[
list
[
int
]],
list
[
list
[
int
]]]:
...
@@ -227,36 +187,6 @@ def _create_weighted_output_token_list(
...
@@ -227,36 +187,6 @@ def _create_weighted_output_token_list(
return
output_token_ids
,
sorted_token_ids_in_output
return
output_token_ids
,
sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
def
test_sampler_min_tokens_penalty
(
device
:
str
,
batch_size
:
int
):
"""
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
torch
.
set_default_device
(
device
)
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
batch_indices_for_min_token_penalty
=
np
.
random
.
randint
(
0
,
batch_size
-
1
,
size
=
np
.
random
.
randint
(
0
,
batch_size
)).
tolist
()
min_tokens
=
_generate_min_token_penalties_and_stop_tokens
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
batch_indices_for_min_token_penalty
)
sampling_metadata
.
min_tokens
=
min_tokens
sampler
=
Sampler
()
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
token_id
in
range
(
VOCAB_SIZE
):
_
,
stop_token_ids
=
min_tokens
.
get
(
batch_idx
,
(
0
,
set
()))
if
token_id
in
stop_token_ids
:
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"presence_penalty"
,
[
-
2.0
,
2.0
])
@
pytest
.
mark
.
parametrize
(
"presence_penalty"
,
[
-
2.0
,
2.0
])
...
@@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
...
@@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
or
non_penalized_token_id
in
output_tokens
)
or
non_penalized_token_id
in
output_tokens
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"min_p"
,
[
0.0
,
0.1
])
def
test_sampler_min_p
(
device
:
str
,
batch_size
:
int
,
min_p
:
float
):
"""
Tests that when min_p is applied, tokens with probability below
min_p * max_prob are masked with -inf.
"""
torch
.
set_default_device
(
device
)
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
# Create one dominant token per batch
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
0
]
=
10.0
# High logit for first token
fake_logits
[
i
,
1
:]
=
1e-2
# Others remain low
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
# Configure min_p parameters
sampling_metadata
.
min_p
=
torch
.
full
((
batch_size
,
),
min_p
,
device
=
device
)
sampler
=
Sampler
()
logits
=
sampler
.
apply_min_p
(
fake_logits
,
sampling_metadata
.
min_p
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
token_id
in
range
(
VOCAB_SIZE
):
if
token_id
==
0
:
# Dominant token should always be unmasked
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
else
:
if
min_p
>
0.0
:
# Non-dominant tokens should be masked when min_p > 0
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
# No masking when min_p is 0
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"bias_value"
,
[
-
0.1
,
1.2
])
def
test_sampler_logit_bias
(
device
:
str
,
batch_size
:
int
,
bias_value
:
float
):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch
.
set_default_device
(
device
)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
sampling_metadata
.
logit_bias
=
_create_logit_bias
(
batch_size
=
batch_size
,
vocab_size
=
VOCAB_SIZE
,
bias_value
=
bias_value
,
)
sampler
=
Sampler
()
logits
=
sampler
.
apply_logits_bias
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
logits_for_req
=
logits
[
batch_idx
]
biased_index
=
min
(
batch_idx
,
VOCAB_SIZE
-
1
)
for
token_id
in
range
(
VOCAB_SIZE
):
if
biased_index
==
token_id
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
bias_value
+
1e-2
)
else
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
1e-2
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"num_allowed_token_ids"
,
[
0
,
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"num_allowed_token_ids"
,
[
0
,
1
,
2
])
...
...
tests/v1/sample/test_topk_topp_sampler.py
View file @
99324e25
...
@@ -2,25 +2,26 @@
...
@@ -2,25 +2,26 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
pytest
import
torch
import
torch
from
flashinfer.sampling
import
top_k_renorm_probs
,
top_p_renorm_probs
from
torch
import
Generator
from
torch
import
Generator
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
apply_top_k_top_p
,
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
apply_top_k_top_p
,
is_flashinfer_available
)
is_flashinfer_available
)
DEVICE
=
"
cu
da"
DEVICE
=
cu
rrent_platform
.
device_type
BATCH_SIZE
=
1024
BATCH_SIZE
=
1024
VOCAB_SIZE
=
128
*
1024
VOCAB_SIZE
=
128
*
1024
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
FLASHINFER_ENABLED
=
current_platform
.
is_cuda
()
and
is_flashinfer_available
if
is_flashinfer_available
:
from
flashinfer.sampling
import
top_k_renorm_probs
,
top_p_renorm_probs
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
reset_default_device
():
def
reset_default_device
():
"""
"""
Explicitly set the default device, which can affect subsequent tests.
Explicitly set the default device, which can affect subsequent tests.
Adding this fixture helps avoid this problem.
Adding this fixture helps avoid this problem.
"""
"""
original_device
=
torch
.
get_default_device
()
original_device
=
torch
.
get_default_device
()
...
@@ -28,7 +29,7 @@ def reset_default_device():
...
@@ -28,7 +29,7 @@ def reset_default_device():
torch
.
set_default_device
(
original_device
)
torch
.
set_default_device
(
original_device
)
def
test_topk_impl_equival
a
nce
():
def
test_topk_impl_equival
e
nce
():
torch
.
set_default_device
(
DEVICE
)
torch
.
set_default_device
(
DEVICE
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
33
)
generator
=
Generator
(
device
=
DEVICE
).
manual_seed
(
33
)
...
@@ -58,8 +59,8 @@ def test_flashinfer_sampler():
...
@@ -58,8 +59,8 @@ def test_flashinfer_sampler():
This test verifies that the FlashInfer top-k and top-p sampling
This test verifies that the FlashInfer top-k and top-p sampling
implementation produces the same results as the Python implementation.
implementation produces the same results as the Python implementation.
NOTE: FlashInfer did not directly expose an interface for fused top-k and
NOTE: FlashInfer did not directly expose an interface for fused top-k and
top-p prob renorm (it did provide fused sampling but we cannot compare
top-p prob renorm (it did provide fused sampling but we cannot compare
sampling results due to randomness), so we will compare the probability
sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation.
renormed consequently by top-k and then top-p of FlashInfer implementation.
'''
'''
...
...
tests/v1/sample/utils.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterator
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Optional
from
typing
import
NamedTuple
,
Optional
import
regex
as
re
import
regex
as
re
import
torch
from
vllm
import
CompletionOutput
from
vllm
import
CompletionOutput
from
vllm.utils
import
make_tensor_with_pad
from
vllm.v1.sample.logits_processor
import
BatchUpdate
,
LogitsProcessor
from
vllm.v1.sample.metadata
import
SamplingMetadata
class
BatchLogprobsComposition
(
Enum
):
class
BatchLogprobsComposition
(
Enum
):
...
@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
...
@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
logprobs
=
completion_output
.
logprobs
logprobs
=
completion_output
.
logprobs
assert
logprobs
is
not
None
assert
logprobs
is
not
None
return
sum
([
lp
[
tok_id
].
logprob
for
tok_id
,
lp
in
zip
(
token_ids
,
logprobs
)])
return
sum
([
lp
[
tok_id
].
logprob
for
tok_id
,
lp
in
zip
(
token_ids
,
logprobs
)])
def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
    """Build a constant logits tensor for sampler tests.

    Args:
        batch_size: Number of rows in the returned tensor.
        vocab_size: Number of columns in the returned tensor.

    Returns:
        A ``(batch_size, vocab_size)`` ``torch.float`` tensor in which every
        entry equals ``1e-2``.
    """
    shape = (batch_size, vocab_size)
    return torch.full(shape, fill_value=1e-2, dtype=torch.float)
def create_penalty_tensor(batch_size: int, penalty_value: float,
                          device: torch.device) -> torch.Tensor:
    """Build a 1-D tensor holding the same penalty for every request.

    Args:
        batch_size: Length of the returned tensor.
        penalty_value: Value written to every element.
        device: Device on which the tensor is allocated.

    Returns:
        A ``(batch_size,)`` ``torch.float`` tensor filled with
        ``penalty_value``.
    """
    shape = (batch_size, )
    return torch.full(shape,
                      penalty_value,
                      device=device,
                      dtype=torch.float)
def create_prompt_tokens_tensor(
    prompt_token_ids: list[list[int]],
    vocab_size: int,
    device: torch.device,
) -> torch.Tensor:
    """Convert per-request prompt token id lists into a single tensor.

    The conversion is delegated to ``make_tensor_with_pad``; ``vocab_size``
    is handed to it as the ``pad`` value.

    Args:
        prompt_token_ids: One list of token ids per request.
        vocab_size: Value passed as ``pad`` to ``make_tensor_with_pad``.
        device: Device on which the tensor is created.

    Returns:
        The ``torch.int64`` tensor produced by ``make_tensor_with_pad``
        (requested with ``pin_memory=False``).
    """
    tensor_options = {
        "pad": vocab_size,
        "device": device,
        "dtype": torch.int64,
        "pin_memory": False,
    }
    return make_tensor_with_pad(prompt_token_ids, **tensor_options)
class LogitsprocsTestFakes(NamedTuple):
    """Bundle of fake logits and sampling metadata used by the tests."""
    # Fake logits tensor the processors are applied to.
    logits: torch.Tensor
    # Sampling metadata; its ``logitsprocs.all`` supplies the processors.
    sampling_metadata: SamplingMetadata

    def get_logitsprocs_by_cls(
        self,
        cls: type[LogitsProcessor],
    ) -> Iterator[LogitsProcessor]:
        """Yield the logits processors that are instances of ``cls``.

        Args:
          cls: :class:`LogitsProcessor` subclass to filter on

        Returns:
          Lazy iterator over the matching logits processors
        """
        every_processor = self.sampling_metadata.logitsprocs.all
        return (candidate for candidate in every_processor
                if isinstance(candidate, cls))

    def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
        """Return all logits processors held by the sampling metadata."""
        metadata = self.sampling_metadata
        return metadata.logitsprocs.all
def fake_update_logitsprocs_state(
    test_fakes: LogitsprocsTestFakes,
    batch_update: BatchUpdate,
) -> None:
    """Imitate logits processors persistent batch state update
    in engine core.

    Every processor returned by ``test_fakes.get_logitsprocs()`` receives
    the same ``batch_update`` through its ``update_state`` method.
    """
    all_processors = test_fakes.get_logitsprocs()
    for processor in all_processors:
        processor.update_state(batch_update)
def fake_apply_logitsprocs(
    test_fakes: LogitsprocsTestFakes,
    slice_indices: list[int],
) -> torch.Tensor:
    """Imitate application of logits processors in engine core.

    Args:
        test_fakes: Fakes providing the logits tensor and the processors.
        slice_indices: Row indices of ``test_fakes.logits`` to select.

    Returns:
        A copy of the selected rows after each processor's ``apply`` has
        been called on it in turn; ``test_fakes.logits`` is indexed and
        cloned, not written to, by this function.
    """
    row_selector = torch.tensor(slice_indices, dtype=torch.long)
    selected_rows = test_fakes.logits[row_selector]
    result = selected_rows.clone()
    for logitproc in test_fakes.get_logitsprocs():
        result = logitproc.apply(result)
    return result
tests/v1/spec_decode/test_eagle.py
View file @
99324e25
...
@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
...
@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
VllmConfig
)
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.platforms
import
current_platform
from
vllm.v1.spec_decode.eagle
import
EagleProposer
from
vllm.v1.spec_decode.eagle
import
EagleProposer
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
model_dir
=
"meta-llama/Llama-3.1-8B-Instruct"
...
@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
...
@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
num_speculative_tokens
=
k
,
num_speculative_tokens
=
k
,
)
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
vllm_config
=
VllmConfig
(
cache_config
=
CacheConfig
(),
model_config
=
model_config
,
speculative_config
=
speculative_config
,
cache_config
=
CacheConfig
(),
device_config
=
DeviceConfig
(
device
=
"cuda"
),
speculative_config
=
speculative_config
,
parallel_config
=
ParallelConfig
(),
device_config
=
DeviceConfig
(
device
=
current_platform
.
device_type
),
load_config
=
LoadConfig
(),
parallel_config
=
ParallelConfig
(),
scheduler_config
=
SchedulerConfig
())
load_config
=
LoadConfig
(),
scheduler_config
=
SchedulerConfig
())
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
'cuda'
)
return
EagleProposer
(
vllm_config
=
vllm_config
,
device
=
current_platform
.
device_type
)
def
test_prepare_inputs
():
def
test_prepare_inputs
():
...
@@ -59,7 +62,7 @@ def test_prepare_inputs():
...
@@ -59,7 +62,7 @@ def test_prepare_inputs():
a, a + 1, ..., a + b - n2 - 1,
a, a + 1, ..., a + b - n2 - 1,
a + b, a + b + 1, ..., a + b + c - n3 - 1]
a + b, a + b + 1, ..., a + b + c - n3 - 1]
"""
"""
device
=
torch
.
device
(
'
cu
da'
)
device
=
torch
.
device
(
cu
rrent_platform
.
device_type
)
# a = 4, b = 7, c = 5
# a = 4, b = 7, c = 5
# n1 = 1, n2 = 3, n3 = 2
# n1 = 1, n2 = 3, n3 = 2
...
@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
...
@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
8
])
def
test_propose
(
num_speculative_tokens
):
def
test_propose
(
num_speculative_tokens
):
# Use GPU device
# Use GPU device
device
=
torch
.
device
(
'
cu
da'
)
device
=
torch
.
device
(
cu
rrent_platform
.
device_type
)
# Setup test parameters
# Setup test parameters
batch_size
=
2
batch_size
=
2
...
...
tests/v1/test_async_llm_dp.py
View file @
99324e25
...
@@ -4,24 +4,30 @@
...
@@ -4,24 +4,30 @@
import
asyncio
import
asyncio
import
os
import
os
from
contextlib
import
ExitStack
from
contextlib
import
ExitStack
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
import
pytest
import
pytest
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.inputs
import
PromptType
from
vllm.inputs
import
PromptType
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.async_llm
import
AsyncLLM
from
vllm.v1.engine.core_client
import
DPAsyncMPClient
from
vllm.v1.engine.core_client
import
DPAsyncMPClient
from
vllm.v1.metrics.loggers
import
StatLoggerBase
from
vllm.v1.metrics.stats
import
IterationStats
,
SchedulerStats
DP_SIZE
=
int
(
os
.
getenv
(
"DP_SIZE"
,
2
))
engine_args
=
AsyncEngineArgs
(
engine_args
=
AsyncEngineArgs
(
model
=
"ibm-research/PowerMoE-3b"
,
model
=
"ibm-research/PowerMoE-3b"
,
enforce_eager
=
True
,
enforce_eager
=
True
,
disable_log_requests
=
True
,
disable_log_requests
=
True
,
tensor_parallel_size
=
int
(
os
.
getenv
(
"TP_SIZE"
,
1
)),
tensor_parallel_size
=
int
(
os
.
getenv
(
"TP_SIZE"
,
1
)),
data_parallel_size
=
int
(
os
.
getenv
(
"
DP_SIZE
"
,
2
))
,
data_parallel_size
=
DP_SIZE
,
)
)
if
not
current_platform
.
supports_v1
(
engine_args
.
create_model_config
()):
if
not
current_platform
.
supports_v1
(
engine_args
.
create_model_config
()):
...
@@ -74,12 +80,32 @@ async def generate(
...
@@ -74,12 +80,32 @@ async def generate(
async
def
test_load
(
output_kind
:
RequestOutputKind
,
async
def
test_load
(
output_kind
:
RequestOutputKind
,
data_parallel_backend
:
str
):
data_parallel_backend
:
str
):
stats_loggers
=
{}
@
dataclass
class
SimpleStatsLogger
(
StatLoggerBase
):
init_count
:
int
=
0
finished_req_count
:
int
=
0
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
engine_index
:
int
=
0
):
stats_loggers
[
engine_index
]
=
self
def
record
(
self
,
scheduler_stats
:
Optional
[
SchedulerStats
],
iteration_stats
:
Optional
[
IterationStats
]):
if
iteration_stats
:
self
.
finished_req_count
+=
len
(
iteration_stats
.
finished_requests
)
def
log_engine_initialized
(
self
):
self
.
init_count
+=
1
with
ExitStack
()
as
after
:
with
ExitStack
()
as
after
:
prompt
=
"This is a test of data parallel"
prompt
=
"This is a test of data parallel"
engine_args
.
data_parallel_backend
=
data_parallel_backend
engine_args
.
data_parallel_backend
=
data_parallel_backend
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
,
stat_loggers
=
[
SimpleStatsLogger
])
after
.
callback
(
engine
.
shutdown
)
after
.
callback
(
engine
.
shutdown
)
NUM_REQUESTS
=
100
NUM_REQUESTS
=
100
...
@@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind,
...
@@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind,
for
request_id
in
request_ids
:
for
request_id
in
request_ids
:
tasks
.
append
(
tasks
.
append
(
asyncio
.
create_task
(
asyncio
.
create_task
(
generate
(
engine
,
generate
(
engine
,
request_id
,
prompt
,
output_kind
,
request_id
,
NUM_EXPECTED_TOKENS
)))
prompt
,
# Short sleep to ensure that requests are distributed.
output_kind
,
await
asyncio
.
sleep
(
0.01
)
NUM_EXPECTED_TOKENS
,
data_parallel_rank
=
0
)))
# Confirm that we got all the EXPECTED tokens from the requests.
# Confirm that we got all the EXPECTED tokens from the requests.
done
,
pending
=
await
asyncio
.
wait
(
tasks
,
done
,
pending
=
await
asyncio
.
wait
(
tasks
,
return_when
=
asyncio
.
FIRST_EXCEPTION
)
return_when
=
asyncio
.
FIRST_EXCEPTION
)
...
@@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind,
...
@@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind,
assert
not
core_client
.
engines_running
assert
not
core_client
.
engines_running
assert
not
core_client
.
reqs_in_flight
assert
not
core_client
.
reqs_in_flight
# Check that requests were distributed between the engines
print
(
f
"Stats loggers after test:
{
stats_loggers
}
"
)
assert
len
(
stats_loggers
)
==
DP_SIZE
assert
stats_loggers
[
0
].
init_count
==
1
for
sl
in
stats_loggers
.
values
():
slogger
:
SimpleStatsLogger
=
sl
assert
slogger
.
finished_req_count
>
NUM_REQUESTS
//
(
DP_SIZE
+
1
),
f
"requests are imbalanced:
{
stats_loggers
}
"
tests/v1/test_external_lb_dp.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
os
import
threading
import
time
from
contextlib
import
AsyncExitStack
import
openai
# use the official client for correctness check
import
pytest
import
pytest_asyncio
from
tests.utils
import
RemoteOpenAIServer
from
vllm.platforms
import
Platform
MODEL_NAME
=
"ibm-research/PowerMoE-3b"
# Number of data parallel ranks for external LB testing
DP_SIZE
=
int
(
os
.
getenv
(
"DP_SIZE"
,
"2"
))
# Default tensor parallel size to use
TP_SIZE
=
int
(
os
.
getenv
(
"TP_SIZE"
,
"1"
))
class ExternalLBServerManager:
    """Manages data parallel vLLM server instances for external
    load balancer testing.

    Used as a context manager: entering starts one ``RemoteOpenAIServer``
    per data parallel rank (each on its own port and GPU slice) and returns
    the list of ``(server, server_args)`` tuples; exiting stops them all.
    """

    def __init__(self,
                 model_name: str,
                 dp_size: int,
                 api_server_count: int,
                 base_server_args: list,
                 tp_size: int = TP_SIZE):
        """Store the launch configuration; no server is started here.

        Args:
            model_name: Model passed to every ``RemoteOpenAIServer``.
            dp_size: Number of data parallel ranks (one server per rank).
            api_server_count: Value passed as ``--api-server-count``.
            base_server_args: CLI args shared by all ranks (copied per rank).
            tp_size: Tensor parallel size of each rank.
        """
        self.model_name = model_name
        self.dp_size = dp_size
        self.tp_size = tp_size
        self.api_server_count = api_server_count
        self.base_server_args = base_server_args
        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
        self.server_threads: list[threading.Thread] = []

    def _start_server(self, r: int, sargs: list[str]) -> None:
        """Start the server for rank ``r`` and record it in ``self.servers``.

        Runs in a worker thread; a failure is printed and re-raised inside
        that thread, and is detected by ``__enter__`` through the missing
        entry in ``self.servers``.
        """
        try:
            # Give each rank its own contiguous slice of devices. The slice
            # width must match the --tensor-parallel-size passed to the
            # server, hence self.tp_size rather than the module default.
            visible_devices = ",".join(
                str(Platform.device_id_to_physical_device_id(i))
                for i in range(r * self.tp_size, (r + 1) * self.tp_size))
            # Start the server
            server = RemoteOpenAIServer(
                self.model_name,
                sargs,
                auto_port=False,
                env_dict={"CUDA_VISIBLE_DEVICES": visible_devices})
            server.__enter__()
            print(f"Server rank {r} started successfully with "
                  f"{self.api_server_count} API servers")
            self.servers.append((server, sargs))
        except Exception as e:
            print(f"Failed to start server rank {r}: {e}")
            raise

    def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
        """Start all server instances for external LB mode.

        Returns:
            The list of ``(server, server_args)`` tuples, in the order the
            servers finished starting.

        Raises:
            RuntimeError: If fewer than ``dp_size`` servers came up. Servers
                that did start are shut down first.
        """
        for rank in range(self.dp_size):
            # Create server args for this specific rank
            server_args = self.base_server_args.copy()

            # Add external LB specific arguments
            server_args.extend([
                "--data-parallel-size",
                str(self.dp_size),
                "--data-parallel-rank",
                str(rank),
                "--data-parallel-size-local",
                "1",
                "--tensor-parallel-size",
                str(self.tp_size),
                "--port",
                str(8000 + rank),  # Different port for each rank
                "--api-server-count",
                str(self.api_server_count),
            ])

            # Use a thread to start each server to allow parallel
            # initialization
            thread = threading.Thread(target=self._start_server,
                                      args=(rank, server_args))
            thread.start()
            self.server_threads.append(thread)

        # Wait for all servers to start
        for thread in self.server_threads:
            thread.join()

        # Give servers additional time to fully initialize and coordinate
        time.sleep(2)

        if len(self.servers) != self.dp_size:
            # __exit__ is not invoked when __enter__ raises, so stop the
            # servers that did start before reporting the failure.
            self.__exit__(None, None, None)
            raise RuntimeError("Servers failed to start")

        return self.servers

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop all server instances.

        Shutdown is best effort: an error while stopping one server is
        printed and the remaining servers are still stopped.
        """
        while self.servers:
            try:
                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
            except Exception as e:
                print(f"Error stopping server: {e}")
@pytest.fixture(scope="module")
def default_server_args():
    """Module-scoped fixture with the CLI args shared by every server."""
    # use half precision for speed and memory savings in CI environment
    precision_args = ["--dtype", "bfloat16"]
    limit_args = ["--max-model-len", "2048", "--max-num-seqs", "128"]
    return precision_args + limit_args + ["--enforce-eager"]
@pytest.fixture(scope="module", params=[1, 4])
def servers(request, default_server_args):
    """Start the external-LB servers once per module and per param.

    The fixture parameter is used as the API server count. Yields the list
    of ``(server, server_args)`` tuples returned by the manager.
    """
    manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE, request.param,
                                      default_server_args)
    with manager as started_servers:
        yield started_servers
@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
    """Yield one async OpenAI client per server, in ``servers`` order.

    All clients are registered on a single ``AsyncExitStack`` so they are
    closed together when the fixture is torn down.
    """
    # Create a client for each server
    async with AsyncExitStack() as stack:
        opened_clients = []
        for server, _ in servers:
            client = await stack.enter_async_context(
                server.get_async_client())
            opened_clients.append(client)
        yield opened_clients
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_external_lb_single_completion(clients: list[
    openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
                                             model_name: str) -> None:
    """Check non-streaming completions against every external-LB server.

    Sends one request to each server sequentially, then two concurrent
    bursts of requests spread over all servers, validating each response.
    """

    async def make_request(client: openai.AsyncOpenAI):
        """Issue one completion request and validate the response."""
        completion = await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=10,
            temperature=1.0)

        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        # The exact number of tokens can vary slightly with temperature=1.0,
        # so we check for a reasonable minimum length.
        assert len(choice.text) >= 1
        # Finish reason might not always be 'length' if the model finishes
        # early or due to other reasons, especially with high temperature.
        # So, we'll accept 'length' or 'stop'.
        assert choice.finish_reason in ("length", "stop")

        # Token counts can also vary, so we check they are positive.
        assert completion.usage.completion_tokens > 0
        assert completion.usage.prompt_tokens > 0
        assert completion.usage.total_tokens > 0
        return completion

    # Test single request to each server
    for i, client in enumerate(clients):
        result = await make_request(client)
        assert result is not None
        print(f"Server {i} handled single completion request successfully")

    # Two bursts, each preceded by a short pause. Every burst sends
    # num_requests_per_server requests to each server concurrently
    # (num_requests_per_server * len(clients) requests in total).
    num_requests_per_server = 25
    for _ in range(2):
        await asyncio.sleep(0.5)

        all_tasks = [
            make_request(client) for client in clients
            for _ in range(num_requests_per_server)
        ]
        results = await asyncio.gather(*all_tasks)
        assert len(results) == num_requests_per_server * len(clients)
        assert all(completion is not None for completion in results)

    # Report the configured API server count (informational only); fall
    # back to 1 when the flag is absent from the server args.
    _, server_args = servers[0]
    api_server_count = (
        server_args[server_args.index('--api-server-count') + 1]
        if '--api-server-count' in server_args else 1)
    print(
        f"Successfully completed external LB test with {len(clients)} servers "
        f"(API server count: {api_server_count})")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_external_lb_completion_streaming(clients: list[
    openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
                                                model_name: str) -> None:
    """Check streaming completions against every external-LB server.

    For each request the streamed text (temperature 0) must equal the text
    of an identical non-streaming request. One request is sent to each
    server sequentially, followed by two concurrent bursts.
    """
    prompt = "What is an LLM?"

    async def make_streaming_request(client: openai.AsyncOpenAI):
        """Compare one streamed completion with its non-streamed twin."""
        # Perform a non-streaming request to get the expected full output
        single_completion = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
        )
        single_output = single_completion.choices[0].text

        # Perform the streaming request
        stream = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=5,
                                                 temperature=0.0,
                                                 stream=True)
        chunks: list[str] = []
        finish_reason_count = 0
        last_chunk = None
        async for chunk in stream:
            chunks.append(chunk.choices[0].text)
            if chunk.choices[0].finish_reason is not None:
                finish_reason_count += 1
            last_chunk = chunk  # Keep track of the last chunk

        # finish reason should only return in the last block for OpenAI API
        assert finish_reason_count == 1, (
            "Finish reason should appear exactly once.")
        assert last_chunk is not None, (
            "Stream should have yielded at least one chunk.")
        assert last_chunk.choices[
            0].finish_reason == "length", "Finish reason should be 'length'."
        # Check that the combined text matches the non-streamed version.
        assert "".join(
            chunks
        ) == single_output, "Streamed output should match non-streamed output."
        return True  # Indicate success for this request

    # Test single request to each server
    for i, client in enumerate(clients):
        result = await make_streaming_request(client)
        assert result is not None
        print(f"Server {i} handled single streaming request successfully")

    # Two bursts, each preceded by a short pause. Every burst sends
    # num_requests_per_server streaming requests to each server concurrently
    # (num_requests_per_server * len(clients) requests in total).
    num_requests_per_server = 25
    for _ in range(2):
        await asyncio.sleep(0.5)

        all_tasks = [
            make_streaming_request(client) for client in clients
            for _ in range(num_requests_per_server)
        ]
        results = await asyncio.gather(*all_tasks)
        assert len(results) == num_requests_per_server * len(clients)
        assert all(results), "Not all streaming requests completed successfully."

    # Report the configured API server count (informational only); fall
    # back to 1 when the flag is absent from the server args.
    _, server_args = servers[0]
    api_server_count = (
        server_args[server_args.index('--api-server-count') + 1]
        if '--api-server-count' in server_args else 1)
    print(f"Successfully completed external LB streaming test with "
          f"{len(clients)} servers (API server count: {api_server_count})")
tests/v1/test_oracle.py
View file @
99324e25
...
@@ -12,8 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
...
@@ -12,8 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1
=
[
UNSUPPORTED_MODELS_V1
=
[
"openai/whisper-large-v3"
,
# transcription
"openai/whisper-large-v3"
,
# transcription
"facebook/bart-large-cnn"
,
# encoder decoder
"facebook/bart-large-cnn"
,
# encoder decoder
"mistralai/Mamba-Codestral-7B-v0.1"
,
# mamba
"state-spaces/mamba-130m-hf"
,
# mamba1
"hmellor/tiny-random-BambaForCausalLM"
,
# hybrid
"BAAI/bge-m3"
,
# embedding
"BAAI/bge-m3"
,
# embedding
]
]
...
@@ -74,12 +73,6 @@ def test_unsupported_configs(monkeypatch):
...
@@ -74,12 +73,6 @@ def test_unsupported_configs(monkeypatch):
disable_async_output_proc
=
True
,
disable_async_output_proc
=
True
,
).
create_engine_config
()
).
create_engine_config
()
with
pytest
.
raises
(
NotImplementedError
):
AsyncEngineArgs
(
model
=
MODEL
,
scheduling_policy
=
"priority"
,
).
create_engine_config
()
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
AsyncEngineArgs
(
AsyncEngineArgs
(
model
=
MODEL
,
model
=
MODEL
,
...
...
tests/v1/test_request.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.v1.request
import
RequestStatus
def test_request_status_fmt_str():
    """Test that the string representation of RequestStatus is correct.

    Formatting a member inside an f-string must produce exactly the
    member's name.
    """
    member_names = (
        "WAITING",
        "WAITING_FOR_FSM",
        "WAITING_FOR_REMOTE_KVS",
        "RUNNING",
        "PREEMPTED",
        "FINISHED_STOPPED",
        "FINISHED_LENGTH_CAPPED",
        "FINISHED_ABORTED",
        "FINISHED_IGNORED",
    )
    for member_name in member_names:
        status = getattr(RequestStatus, member_name)
        assert f"{status}" == member_name
tests/v1/tpu/test_basic.py
View file @
99324e25
...
@@ -67,6 +67,43 @@ def test_basic(
...
@@ -67,6 +67,43 @@ def test_basic(
assert
"1024"
in
output
or
"0, 1"
in
output
assert
"1024"
in
output
or
"0, 1"
in
output
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    max_tokens: int,
    max_num_seqs: int,
) -> None:
    """Greedy-decode three prompts with Phi-3 and check the continuations.

    Each expected answer must appear as a substring of the corresponding
    generated text.
    """
    # (prompt, expected continuation) pairs
    cases = [
        ("A robot may not injure a human being",
         " or, by violating privacy"),
        ("It is only with the heart that one can see rightly;",
         " what is essential is love."),
        ("The greatest glory in living lies not in never falling,",
         " but in rising every time we fall."),
    ]
    prompts = [prompt for prompt, _ in cases]
    answers = [answer for _, answer in cases]

    # test head dim = 96
    model = "microsoft/Phi-3-mini-128k-instruct"

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        runner = vllm_runner(model,
                             max_num_batched_tokens=256,
                             max_num_seqs=max_num_seqs)
        with runner as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)

        # vllm_outputs is a list of tuples whose first element is the token id
        # and the second element is the output (including the prompt).
        for output, answer in zip(vllm_outputs, answers):
            generated_text = output[1]
            assert answer in generated_text
TP_SIZE_8
=
8
TP_SIZE_8
=
8
...
...
tests/v1/tpu/test_kv_cache_update_kernel.py
0 → 100644
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
pytest
import
torch
import
torch_xla
import
vllm.v1.attention.backends.pallas
# noqa: F401
from
vllm.platforms
import
current_platform
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
                                head_dim: int, num_slices_per_block: int):
    """Compare ``torch.ops.xla.kv_cache_update_op`` with a CPU reference.

    A zero-initialised KV cache and random new KV rows are built on CPU and
    copied to the XLA device. The op is run on the XLA copies with a slot
    mapping describing seven slices; the same slices are then copied with
    plain tensor indexing on CPU and both caches must match.
    """
    page_num = 1000
    padded_num_tokens = 128
    # Cache is laid out as (page_num * page_size, kv heads, head_dim).
    kv_cache_cpu = torch.zeros(
        (page_num * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    kv_cache_xla = kv_cache_cpu.to(torch_xla.device())
    new_kv_cpu = torch.randn(
        (padded_num_tokens, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    new_kv_xla = new_kv_cpu.to(torch_xla.device())
    # Length of each slice to copy.
    slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9],
                          dtype=np.int32)
    num_kv_update_slices = len(slice_lens)
    # Destination start row in the cache for each slice.
    kv_cache_start_indices = np.array([
        page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
        page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
    ],
                                      dtype=np.int32)
    # Source start row in new_kv for each slice: slices are packed back to
    # back, so each start is the running sum of the preceding lengths.
    new_kv_cache_indices = np.concatenate(
        [np.array([0], dtype=np.int32),
         np.cumsum(slice_lens[:-1])])
    # One row per slice: (cache start, new_kv start, slice length).
    slot_mapping = np.stack(
        [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1)
    # Round the number of slices up to a multiple of num_slices_per_block
    # and pad the extra rows with zeros.
    padded_size = (slot_mapping.shape[0] + num_slices_per_block -
                   1) // num_slices_per_block * num_slices_per_block
    slot_mapping = np.pad(
        slot_mapping, [[0, padded_size - slot_mapping.shape[0]], [0, 0]],
        constant_values=0)
    # Transpose to (3, padded_size) before handing it to the op.
    slot_mapping = np.transpose(slot_mapping)
    slot_mapping_cpu = torch.tensor(slot_mapping,
                                    device="cpu",
                                    dtype=torch.int32)
    slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device())
    # The number of real (unpadded) slices is passed as a 1-element tensor.
    num_kv_update_slices_xla = torch.tensor([num_kv_update_slices],
                                            device=torch_xla.device(),
                                            dtype=torch.int32)
    torch_xla.sync()

    # NOTE(review): presumably marks kv_cache_xla as a donated buffer so the
    # op can reuse its storage - confirm against torch_xla documentation.
    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True)
    new_kv_cache_xla = torch.ops.xla.kv_cache_update_op(
        new_kv_xla, slot_mapping_xla, kv_cache_xla, num_kv_update_slices_xla,
        page_size, num_slices_per_block)
    kv_cache_xla.copy_(new_kv_cache_xla)
    torch_xla.sync()

    # CPU reference: copy each slice with plain indexing.
    for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices,
                          slice_lens):
        kv_cache_cpu[ci:ci + sl, :, :] = new_kv_cpu[ni:ni + sl, :, :]

    assert torch.allclose(kv_cache_xla.cpu(),
                          kv_cache_cpu,
                          atol=1e-4,
                          rtol=1e-4)
tests/v1/tpu/test_pallas.py
View file @
99324e25
...
@@ -47,7 +47,7 @@ def test_ragged_paged_attention():
...
@@ -47,7 +47,7 @@ def test_ragged_paged_attention():
key
=
torch
.
zeros
(
num_tokens
,
num_kv_heads
*
head_size
)
key
=
torch
.
zeros
(
num_tokens
,
num_kv_heads
*
head_size
)
value
=
torch
.
zeros
(
num_tokens
,
num_kv_heads
*
head_size
)
value
=
torch
.
zeros
(
num_tokens
,
num_kv_heads
*
head_size
)
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
head_size
)
kv_cache
=
torch
.
zeros
(
num_blocks
,
block_size
,
num_kv_heads
*
2
,
head_size
)
slot_mapping
=
torch
.
zeros
(
num_tokens
,
dtype
=
torch
.
int64
)
slot_mapping
=
torch
.
zeros
(
(
3
,
num_tokens
)
,
dtype
=
torch
.
int64
)
max_num_reqs
=
8
max_num_reqs
=
8
max_num_blocks_per_req
=
8
max_num_blocks_per_req
=
8
block_tables
=
torch
.
zeros
((
max_num_reqs
,
max_num_blocks_per_req
),
block_tables
=
torch
.
zeros
((
max_num_reqs
,
max_num_blocks_per_req
),
...
@@ -65,6 +65,7 @@ def test_ragged_paged_attention():
...
@@ -65,6 +65,7 @@ def test_ragged_paged_attention():
context_lens
=
context_lens
,
context_lens
=
context_lens
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
num_seqs
=
num_seqs
,
num_seqs
=
num_seqs
,
num_slices_per_kv_cache_update_block
=
8
,
)
)
with
patch
(
"torch.ops.xla.ragged_paged_attention"
with
patch
(
"torch.ops.xla.ragged_paged_attention"
...
...
tests/v1/tpu/test_spmd_model_weight_loading.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
gc
import
tempfile
import
tempfile
...
...
tests/v1/tpu/test_tpu_qkv_linear.py
View file @
99324e25
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tempfile
import
tempfile
import
numpy
as
np
import
numpy
as
np
...
...
Prev
1
…
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment