Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e7bd944e
Unverified
Commit
e7bd944e
authored
Mar 01, 2025
by
Chen Zhang
Committed by
GitHub
Feb 28, 2025
Browse files
[v1] Cleanup the BlockTable in InputBatch (#13977)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
c3b6559a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
17 deletions
+25
-17
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+14
-0
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+6
-7
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-4
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+2
-4
No files found.
tests/v1/worker/test_gpu_model_runner.py
View file @
e7bd944e
...
@@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner,
...
@@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner,
sampling_metadata_before
)
sampling_metadata_before
)
def
_is_req_state_block_table_match
(
model_runner
,
req_id
:
str
)
->
bool
:
req_index
=
model_runner
.
input_batch
.
req_id_to_index
[
req_id
]
block_table
=
model_runner
.
input_batch
.
block_table
req_state
=
model_runner
.
requests
[
req_id
]
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_state
.
block_ids
):
return
False
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
return
(
block_table
.
block_table_np
[
req_index
,
:
num_blocks
]
==
req_state
.
block_ids
).
all
()
def
test_update_states_new_request
(
model_runner
):
def
test_update_states_new_request
(
model_runner
):
req_id
=
"req_0"
req_id
=
"req_0"
...
@@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner):
...
@@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner):
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_request_finished
(
model_runner
):
def
test_update_states_request_finished
(
model_runner
):
...
@@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner):
...
@@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner):
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_no_changes
(
model_runner
):
def
test_update_states_no_changes
(
model_runner
):
...
@@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner):
...
@@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner):
assert
not
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
not
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_request_unscheduled
(
model_runner
):
def
test_update_states_request_unscheduled
(
model_runner
):
...
...
vllm/v1/worker/block_table.py
View file @
e7bd944e
...
@@ -15,13 +15,11 @@ class BlockTable:
...
@@ -15,13 +15,11 @@ class BlockTable:
def
__init__
(
def
__init__
(
self
,
self
,
max_num_reqs
:
int
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_blocks_per_req
:
int
,
max_num_blocks_per_req
:
int
,
pin_memory
:
bool
,
pin_memory
:
bool
,
device
:
torch
.
device
,
device
:
torch
.
device
,
):
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_num_blocks_per_req
=
max_num_blocks_per_req
self
.
max_num_blocks_per_req
=
max_num_blocks_per_req
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
self
.
device
=
device
self
.
device
=
device
...
@@ -42,18 +40,19 @@ class BlockTable:
...
@@ -42,18 +40,19 @@ class BlockTable:
def
append_row
(
def
append_row
(
self
,
self
,
row_idx
:
int
,
start
:
int
,
block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
row_idx
:
int
,
)
->
None
:
)
->
None
:
if
not
block_ids
:
if
not
block_ids
:
return
return
num_blocks
=
len
(
block_ids
)
num_blocks
=
len
(
block_ids
)
start
=
self
.
num_blocks_per_row
[
row_idx
]
self
.
num_blocks_per_row
[
row_idx
]
+=
num_blocks
self
.
block_table_np
[
row_idx
,
start
:
start
+
num_blocks
]
=
block_ids
self
.
block_table_np
[
row_idx
,
start
:
start
+
num_blocks
]
=
block_ids
self
.
num_blocks_per_row
[
row_idx
]
=
start
+
num_blocks
def
add_row
(
self
,
row_idx
:
int
,
block_ids
:
List
[
int
])
->
None
:
def
add_row
(
self
,
block_ids
:
List
[
int
],
row_idx
:
int
)
->
None
:
self
.
append_row
(
row_idx
,
0
,
block_ids
)
self
.
num_blocks_per_row
[
row_idx
]
=
0
self
.
append_row
(
block_ids
,
row_idx
)
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
num_blocks
=
self
.
num_blocks_per_row
[
src
]
num_blocks
=
self
.
num_blocks_per_row
[
src
]
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
e7bd944e
...
@@ -92,7 +92,6 @@ class InputBatch:
...
@@ -92,7 +92,6 @@ class InputBatch:
# Block table.
# Block table.
self
.
block_table
=
BlockTable
(
self
.
block_table
=
BlockTable
(
max_num_reqs
=
max_num_reqs
,
max_num_reqs
=
max_num_reqs
,
max_model_len
=
max_model_len
,
max_num_blocks_per_req
=
max_num_blocks_per_req
,
max_num_blocks_per_req
=
max_num_blocks_per_req
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
device
=
device
,
device
=
device
,
...
@@ -249,7 +248,7 @@ class InputBatch:
...
@@ -249,7 +248,7 @@ class InputBatch:
self
.
num_tokens_no_spec
[
req_index
]
=
request
.
num_tokens
self
.
num_tokens_no_spec
[
req_index
]
=
request
.
num_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
block_table
.
add_row
(
req_index
,
request
.
block_ids
)
self
.
block_table
.
add_row
(
request
.
block_ids
,
req_index
)
sampling_params
=
request
.
sampling_params
sampling_params
=
request
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
:
if
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
e7bd944e
...
@@ -399,10 +399,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -399,10 +399,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the persistent batch.
# Update the persistent batch.
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
(
num_computed_tokens
)
num_computed_tokens
)
start_index
=
(
len
(
req_state
.
block_ids
)
-
self
.
input_batch
.
block_table
.
append_row
(
req_data
.
new_block_ids
,
len
(
req_data
.
new_block_ids
))
req_index
)
self
.
input_batch
.
block_table
.
append_row
(
req_index
,
start_index
,
req_data
.
new_block_ids
)
# Add new_token_ids to token_ids_cpu.
# Add new_token_ids to token_ids_cpu.
start_token_index
=
num_computed_tokens
start_token_index
=
num_computed_tokens
end_token_index
=
num_computed_tokens
+
len
(
req_data
.
new_token_ids
)
end_token_index
=
num_computed_tokens
+
len
(
req_data
.
new_token_ids
)
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
e7bd944e
...
@@ -247,10 +247,8 @@ class TPUModelRunner:
...
@@ -247,10 +247,8 @@ class TPUModelRunner:
# Update the persistent batch.
# Update the persistent batch.
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
(
self
.
input_batch
.
num_computed_tokens_cpu
[
req_index
]
=
(
req_data
.
num_computed_tokens
)
req_data
.
num_computed_tokens
)
start_index
=
len
(
req_state
.
block_ids
)
-
len
(
self
.
input_batch
.
block_table
.
append_row
(
req_data
.
new_block_ids
,
req_data
.
new_block_ids
)
req_index
)
self
.
input_batch
.
block_table
.
append_row
(
req_index
,
start_index
,
req_data
.
new_block_ids
)
# Add the new or resumed requests to the persistent batch.
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
# The smaller empty indices are filled first.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment