Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fba02e3b
Unverified
Commit
fba02e3b
authored
May 30, 2025
by
Carol Zheng
Committed by
GitHub
May 30, 2025
Browse files
[Bugfix][TPU] Fix tpu model runner testcase failure (#18810)
Signed-off-by:
Carol Zheng
<
cazheng@google.com
>
parent
4577fc9a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
16 deletions
+50
-16
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+26
-5
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+24
-11
No files found.
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
fba02e3b
...
@@ -81,7 +81,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -81,7 +81,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
sampling_params
=
SamplingParams
(),
block_ids
=
[
0
],
block_ids
=
[
[
0
]],
# block_ids should be list[list[int]]
num_computed_tokens
=
0
,
num_computed_tokens
=
0
,
lora_request
=
None
,
lora_request
=
None
,
))
))
...
@@ -112,14 +112,35 @@ def _is_req_added(model_runner, req_id: str) -> bool:
...
@@ -112,14 +112,35 @@ def _is_req_added(model_runner, req_id: str) -> bool:
def
_is_req_state_block_table_match
(
model_runner
,
req_id
:
str
)
->
bool
:
def
_is_req_state_block_table_match
(
model_runner
,
req_id
:
str
)
->
bool
:
"""Check if the request state block IDs match the block table.
This function handles both legacy BlockTable and new MultiGroupBlockTable
structures for backward compatibility.
"""
req_index
=
model_runner
.
input_batch
.
req_id_to_index
[
req_id
]
req_index
=
model_runner
.
input_batch
.
req_id_to_index
[
req_id
]
block_table
=
model_runner
.
input_batch
.
block_table
multi_group_
block_table
=
model_runner
.
input_batch
.
block_table
req_state
=
model_runner
.
requests
[
req_id
]
req_state
=
model_runner
.
requests
[
req_id
]
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_state
.
block_ids
):
# Access the first block table from MultiGroupBlockTable
# This is safe since we currently only use single KV cache groups
block_table
=
multi_group_block_table
[
0
]
# req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
# Extract the first group's block IDs
if
isinstance
(
req_state
.
block_ids
[
0
],
list
):
# New format: list[list[int]] - extract first group
req_block_ids
=
req_state
.
block_ids
[
0
]
else
:
# Legacy format: list[int] - use directly
req_block_ids
=
req_state
.
block_ids
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_block_ids
):
return
False
return
False
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
return
(
block_table
.
block_table_np
[
req_index
,
:
num_blocks
]
==
block_table_values
=
block_table
.
block_table_np
[
req_index
,
:
num_blocks
]
req_state
.
block_ids
).
all
()
return
(
block_table_values
==
req_
block_ids
).
all
()
def
test_update_states_new_request
(
model_runner
):
def
test_update_states_new_request
(
model_runner
):
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
fba02e3b
...
@@ -175,11 +175,21 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -175,11 +175,21 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
# req_id -> (input_id -> encoder_output)
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
# self.input_batch: InputBatch # Persistent batch.
# Request states.
# Request states.
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
self
.
requests
:
dict
[
str
,
CachedRequestState
]
=
{}
# Initialize input batch early to avoid AttributeError in _update_states
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
model_config
.
get_vocab_size
(),
block_size
=
self
.
block_size
,
)
# Cached torch/numpy tensor
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
# The pytorch tensor and numpy array share the same buffer.
# Sometimes the numpy op is faster so we create both.
# Sometimes the numpy op is faster so we create both.
...
@@ -1286,6 +1296,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1286,6 +1296,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"Hybrid models with more than one KV cache type are not "
"supported yet."
)
"supported yet."
)
if
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
.
block_size
!=
self
.
block_size
:
self
.
input_batch
=
InputBatch
(
self
.
input_batch
=
InputBatch
(
max_num_reqs
=
self
.
max_num_reqs
,
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
...
@@ -1296,6 +1308,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1296,6 +1308,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
block_size
=
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
.
block_size
=
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
.
block_size
,
block_size
,
)
)
# Verify dtype compatibility between block_table_cpu and input_batch
assert
self
.
block_table_cpu
.
dtype
==
self
.
input_batch
.
block_table
[
assert
self
.
block_table_cpu
.
dtype
==
self
.
input_batch
.
block_table
[
0
].
get_cpu_tensor
().
dtype
0
].
get_cpu_tensor
().
dtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment