Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
06bfb519
Unverified
Commit
06bfb519
authored
Jan 06, 2025
by
Woosuk Kwon
Committed by
GitHub
Jan 06, 2025
Browse files
[V1] Add BlockTable class (#11693)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
408e5600
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
94 additions
and
25 deletions
+94
-25
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+78
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+9
-16
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+7
-9
No files found.
vllm/v1/worker/block_table.py
0 → 100644
View file @
06bfb519
from
typing
import
List
import
numpy
as
np
import
torch
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
BlockTable
:
def
__init__
(
self
,
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_blocks_per_req
:
int
,
pin_memory
:
bool
,
device
:
torch
.
device
,
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_num_blocks_per_req
=
max_num_blocks_per_req
self
.
pin_memory
=
pin_memory
self
.
device
=
device
self
.
block_table
=
torch
.
zeros
(
(
max_num_reqs
,
max_num_blocks_per_req
),
device
=
self
.
device
,
dtype
=
torch
.
int32
,
)
self
.
block_table_cpu
=
torch
.
zeros
(
(
max_num_reqs
,
max_num_blocks_per_req
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
)
self
.
block_table_np
=
self
.
block_table_cpu
.
numpy
()
self
.
num_blocks_per_row
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
def
append_row
(
self
,
row_idx
:
int
,
start
:
int
,
block_ids
:
List
[
int
],
)
->
None
:
num_blocks
=
len
(
block_ids
)
self
.
block_table_np
[
row_idx
,
start
:
start
+
num_blocks
]
=
block_ids
self
.
num_blocks_per_row
[
row_idx
]
=
start
+
num_blocks
def
add_row
(
self
,
row_idx
:
int
,
block_ids
:
List
[
int
])
->
None
:
self
.
append_row
(
row_idx
,
0
,
block_ids
)
def
move_row
(
self
,
src
:
int
,
tgt
:
int
)
->
None
:
num_blocks
=
self
.
num_blocks_per_row
[
src
]
self
.
block_table_np
[
tgt
,
:
num_blocks
]
=
self
.
block_table_np
[
src
,
:
num_blocks
]
self
.
num_blocks_per_row
[
tgt
]
=
num_blocks
def
commit
(
self
,
num_reqs
:
int
)
->
None
:
self
.
block_table
[:
num_reqs
].
copy_
(
self
.
block_table_cpu
[:
num_reqs
],
non_blocking
=
True
)
def
clear
(
self
)
->
None
:
self
.
block_table
.
fill_
(
0
)
self
.
block_table_cpu
.
fill_
(
0
)
def
get_device_tensor
(
self
)
->
torch
.
Tensor
:
"""Ruturns the device tensor of the block table."""
return
self
.
block_table
def
get_cpu_tensor
(
self
)
->
torch
.
Tensor
:
"""Returns the CPU tensor of the block table."""
return
self
.
block_table_cpu
def
get_numpy_array
(
self
)
->
np
.
ndarray
:
"""Returns the numpy array of the block table."""
return
self
.
block_table_np
vllm/v1/worker/gpu_input_batch.py
View file @
06bfb519
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.multimodal.inputs
import
PlaceholderRange
...
@@ -70,19 +71,14 @@ class InputBatch:
...
@@ -70,19 +71,14 @@ class InputBatch:
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_prompt_tokens
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu
=
np
.
empty
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
num_computed_tokens_cpu
=
np
.
empty
(
max_num_reqs
,
dtype
=
np
.
int32
)
# Attention-related.
# Block table.
self
.
block_table
=
torch
.
zeros
(
self
.
block_table
=
BlockTable
(
(
max_num_reqs
,
max_num_blocks_per_req
),
max_num_reqs
=
max_num_reqs
,
device
=
self
.
device
,
max_model_len
=
max_model_len
,
dtype
=
torch
.
int32
,
max_num_blocks_per_req
=
max_num_blocks_per_req
,
)
self
.
block_table_cpu_tensor
=
torch
.
zeros
(
(
max_num_reqs
,
max_num_blocks_per_req
),
device
=
"cpu"
,
dtype
=
torch
.
int32
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
device
=
device
,
)
)
self
.
block_table_cpu
=
self
.
block_table_cpu_tensor
.
numpy
()
# Sampling-related.
# Sampling-related.
self
.
temperature
=
torch
.
empty
((
max_num_reqs
,
),
self
.
temperature
=
torch
.
empty
((
max_num_reqs
,
),
...
@@ -193,8 +189,7 @@ class InputBatch:
...
@@ -193,8 +189,7 @@ class InputBatch:
self
.
num_tokens
[
req_index
]
=
request
.
num_tokens
self
.
num_tokens
[
req_index
]
=
request
.
num_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
self
.
num_computed_tokens_cpu
[
req_index
]
=
request
.
num_computed_tokens
num_blocks
=
len
(
request
.
block_ids
)
self
.
block_table
.
add_row
(
req_index
,
request
.
block_ids
)
self
.
block_table_cpu
[
req_index
,
:
num_blocks
]
=
request
.
block_ids
sampling_params
=
request
.
sampling_params
sampling_params
=
request
.
sampling_params
self
.
temperature_cpu
[
req_index
]
=
sampling_params
.
temperature
self
.
temperature_cpu
[
req_index
]
=
sampling_params
.
temperature
...
@@ -300,9 +295,7 @@ class InputBatch:
...
@@ -300,9 +295,7 @@ class InputBatch:
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_computed_tokens_cpu
[
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
# TODO(woosuk): Optimize the copy of block_table_cpu.
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
self
.
block_table_cpu
[
empty_index
]
=
self
.
block_table_cpu
[
last_req_index
]
self
.
temperature_cpu
[
empty_index
]
=
self
.
temperature_cpu
[
self
.
temperature_cpu
[
empty_index
]
=
self
.
temperature_cpu
[
last_req_index
]
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
06bfb519
...
@@ -211,10 +211,9 @@ class GPUModelRunner:
...
@@ -211,10 +211,9 @@ class GPUModelRunner:
if
num_new_blocks
==
0
:
if
num_new_blocks
==
0
:
continue
continue
start_index
=
len
(
req_state
.
block_ids
)
start_index
=
len
(
req_state
.
block_ids
)
end_index
=
start_index
+
num_new_blocks
req_state
.
block_ids
.
extend
(
req_data
.
new_block_ids
)
req_state
.
block_ids
.
extend
(
req_data
.
new_block_ids
)
self
.
input_batch
.
block_table
_cpu
[
self
.
input_batch
.
block_table
.
append_row
(
req_index
,
start_index
,
req_index
,
start_index
:
end_index
]
=
req_data
.
new_block_ids
req_data
.
new_block_ids
)
req_ids_to_add
:
List
[
str
]
=
[]
req_ids_to_add
:
List
[
str
]
=
[]
# Add new requests to the cached states.
# Add new requests to the cached states.
...
@@ -275,9 +274,7 @@ class GPUModelRunner:
...
@@ -275,9 +274,7 @@ class GPUModelRunner:
# OPTIMIZATION: Start copying the block table first.
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
# This way, we can overlap the copy with the following CPU operations.
self
.
input_batch
.
block_table
[:
num_reqs
].
copy_
(
self
.
input_batch
.
block_table
.
commit
(
num_reqs
)
self
.
input_batch
.
block_table_cpu_tensor
[:
num_reqs
],
non_blocking
=
True
)
# Get the number of scheduled tokens for each request.
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
# TODO: The Python loop can be slow. Optimize.
...
@@ -333,8 +330,8 @@ class GPUModelRunner:
...
@@ -333,8 +330,8 @@ class GPUModelRunner:
# NOTE(woosuk): We use torch.index_select instead of np.take here
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# because torch.index_select is much faster than np.take for large
# tensors.
# tensors.
block_
numbers
=
(
self
.
input_batch
.
block_table_cpu_tensor
.
flatten
()
block_
table_cpu
=
self
.
input_batch
.
block_table
.
get
_cpu_tensor
()
[
block_table_indices
].
numpy
()
)
block_numbers
=
block_table_cpu
.
flatten
()
[
block_table_indices
].
numpy
()
block_offsets
=
positions_np
%
self
.
block_size
block_offsets
=
positions_np
%
self
.
block_size
np
.
add
(
block_numbers
*
self
.
block_size
,
np
.
add
(
block_numbers
*
self
.
block_size
,
block_offsets
,
block_offsets
,
...
@@ -450,7 +447,8 @@ class GPUModelRunner:
...
@@ -450,7 +447,8 @@ class GPUModelRunner:
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
seq_start_loc
=
seq_start_loc
,
seq_start_loc
=
seq_start_loc
,
block_table
=
self
.
input_batch
.
block_table
[:
num_reqs
],
block_table
=
(
self
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
]),
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment