Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0a049c7d
Unverified
Commit
0a049c7d
authored
Mar 25, 2025
by
yarongmu-google
Committed by
GitHub
Mar 25, 2025
Browse files
[CI/Build] Add tests for the V1 tpu_model_runner. (#14843)
Signed-off-by:
Yarong Mu
<
ymu@google.com
>
parent
d0cfec7a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
310 additions
and
1 deletion
+310
-1
.buildkite/run-tpu-v1-test.sh
.buildkite/run-tpu-v1-test.sh
+3
-1
tests/v1/tpu/worker/__init__.py
tests/v1/tpu/worker/__init__.py
+0
-0
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+307
-0
No files found.
.buildkite/run-tpu-v1-test.sh
View file @
0a049c7d
...
...
@@ -30,7 +30,9 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_4
\
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py
\
&& echo TEST_5
\
&& python3 /workspace/vllm/examples/offline_inference/tpu.py"
\
&& python3 /workspace/vllm/examples/offline_inference/tpu.py
\
&& echo TEST_6
\
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py"
\
# TODO: This test fails because it uses RANDOM_SEED sampling
...
...
tests/v1/tpu/worker/__init__.py
0 → 100644
View file @
0a049c7d
tests/v1/tpu/worker/test_tpu_model_runner.py
0 → 100644
View file @
0a049c7d
# SPDX-License-Identifier: Apache-2.0
import
unittest.mock
as
mock
import
pytest
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.tpu_model_runner
import
TPUModelRunner
# Mock torch_xla module since it may not be available in the test environments
torch_xla_patcher
=
mock
.
patch
.
dict
(
"sys.modules"
,
{
"torch_xla"
:
mock
.
MagicMock
(),
"torch_xla.core.xla_model"
:
mock
.
MagicMock
(),
"torch_xla.runtime"
:
mock
.
MagicMock
(),
})
torch_xla_patcher
.
start
()
# Mock the PallasAttentionBackend
pallas_attention_backend_patcher
=
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.PallasAttentionBackend"
,
)
pallas_attention_backend_patcher
.
start
()
@
pytest
.
fixture
def
model_runner
():
# Patchers have already been started at module level.
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
10
,
max_num_batched_tokens
=
512
,
max_model_len
=
512
,
)
model_config
=
ModelConfig
(
model
=
"facebook/opt-125m"
,
task
=
"generate"
,
tokenizer
=
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
# TPUs typically use bfloat16
seed
=
42
,
)
cache_config
=
CacheConfig
(
block_size
=
16
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
)
device
=
"xla:0"
# Mocking TPU device
with
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.torch"
),
\
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.xm"
),
\
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.xr"
):
return
TPUModelRunner
(
vllm_config
,
device
)
@
pytest
.
fixture
(
autouse
=
True
,
scope
=
"session"
)
def
cleanup_patches
():
yield
torch_xla_patcher
.
stop
()
pallas_attention_backend_patcher
.
stop
()
def
_schedule_new_request
(
*
req_ids
:
str
)
->
SchedulerOutput
:
new_reqs
=
[]
num_scheduled_tokens
=
{}
total_num_scheduled_tokens
=
0
for
req_id
in
req_ids
:
new_reqs
.
append
(
NewRequestData
(
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_hashes
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
block_ids
=
[
0
],
num_computed_tokens
=
0
,
lora_request
=
None
,
))
num_scheduled_tokens
[
req_id
]
=
3
total_num_scheduled_tokens
+=
num_scheduled_tokens
[
req_id
]
return
SchedulerOutput
(
scheduled_new_reqs
=
new_reqs
,
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
def
_is_req_scheduled
(
model_runner
,
req_id
:
str
)
->
bool
:
return
req_id
in
model_runner
.
input_batch
.
req_id_to_index
def
_is_req_added
(
model_runner
,
req_id
:
str
)
->
bool
:
return
req_id
in
model_runner
.
requests
def
_is_sampling_metadata_changed
(
model_runner
,
sampling_metadata_before
:
SamplingMetadata
):
return
model_runner
.
input_batch
.
sampling_metadata
is
not
(
sampling_metadata_before
)
def
_is_req_state_block_table_match
(
model_runner
,
req_id
:
str
)
->
bool
:
req_index
=
model_runner
.
input_batch
.
req_id_to_index
[
req_id
]
block_table
=
model_runner
.
input_batch
.
block_table
req_state
=
model_runner
.
requests
[
req_id
]
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_state
.
block_ids
):
return
False
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
return
(
block_table
.
block_table_np
[
req_index
,
:
num_blocks
]
==
req_state
.
block_ids
).
all
()
def
test_update_states_new_request
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_request_finished
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# finish req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{
req_id
},
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
not
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
def
test_update_states_request_resumed
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# unschedule req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
# resume req
cached_req_data
=
CachedRequestData
(
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[],
num_computed_tokens
=
0
,
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[
cached_req_data
],
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_no_changes
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# schedule req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
not
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_request_unscheduled
(
model_runner
):
req_ids
=
(
"req_0"
,
"req_1"
)
# new reqs
scheduler_output
=
_schedule_new_request
(
*
req_ids
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
assert
_is_req_added
(
model_runner
,
req_ids
[
1
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
1
])
# unschedule req_1
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_ids
[
0
]:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
metadata_before
=
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
assert
_is_req_added
(
model_runner
,
req_ids
[
1
])
assert
not
_is_req_scheduled
(
model_runner
,
req_ids
[
1
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment