Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
950751a9
Unverified
Commit
950751a9
authored
May 11, 2025
by
Chen Zhang
Committed by
GitHub
May 10, 2025
Browse files
[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (#17483)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
4c31218f
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
132 additions
and
68 deletions
+132
-68
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+3
-0
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+20
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+30
-17
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+20
-15
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+15
-8
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+7
-4
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+5
-2
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+11
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+3
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+9
-12
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+9
-9
No files found.
tests/v1/worker/test_gpu_input_batch.py
View file @
950751a9
...
...
@@ -221,6 +221,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
...
...
@@ -310,6 +311,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
...
...
@@ -318,6 +320,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
max_num_batched_tokens
=
1024
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
950751a9
# SPDX-License-Identifier: Apache-2.0
import
weakref
import
pytest
import
torch
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
def
initialize_kv_cache
(
runner
:
GPUModelRunner
):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
kv_cache_spec
=
FullAttentionSpec
(
block_size
=
16
,
num_kv_heads
=
1
,
head_size
=
64
,
dtype
=
torch
.
float16
,
use_mla
=
False
)
runner
.
attn_metadata_builder
=
runner
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
runner
),
kv_cache_spec
,
runner
.
input_batch
.
block_table
)
@
pytest
.
fixture
def
model_runner
():
scheduler_config
=
SchedulerConfig
(
...
...
@@ -38,7 +55,9 @@ def model_runner():
)
device
=
"cuda"
return
GPUModelRunner
(
vllm_config
,
device
)
runner
=
GPUModelRunner
(
vllm_config
,
device
)
initialize_kv_cache
(
runner
)
return
runner
def
_schedule_new_request
(
*
req_ids
:
str
)
->
SchedulerOutput
:
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
950751a9
...
...
@@ -19,6 +19,8 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -167,7 +169,7 @@ def make_local_attention_virtual_batches(
query_start_loc_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
Tensor
,
page
_size
:
int
=
0
,
block
_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
Tensor
]:
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
...
...
@@ -238,14 +240,14 @@ def make_local_attention_virtual_batches(
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts
=
k_seqstarts_absolute
//
page
_size
assert
attn_chunk_size
%
page
_size
==
0
,
\
block_starts
=
k_seqstarts_absolute
//
block
_size
assert
attn_chunk_size
%
block
_size
==
0
,
\
f
"attn_chunk_size
{
attn_chunk_size
}
is not "
\
f
"divisible by
page
_size
{
page
_size
}
"
pages_per_local_batch
=
attn_chunk_size
//
page
_size
f
"divisible by
block
_size
{
block
_size
}
"
pages_per_local_batch
=
attn_chunk_size
//
block
_size
# Create a block_table for the local attention blocks
# For our example if we have a block-table like (assuming
page
_size=2):
# For our example if we have a block-table like (assuming
block
_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
...
...
@@ -289,7 +291,8 @@ def _get_sliding_window_configs(
class
FlashAttentionMetadataBuilder
:
def
__init__
(
self
,
runner
:
"GPUModelRunner"
):
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
model_config
=
runner
.
model_config
compilation_config
=
runner
.
vllm_config
.
compilation_config
...
...
@@ -299,7 +302,9 @@ class FlashAttentionMetadataBuilder:
self
.
num_heads_kv
=
model_config
.
get_num_kv_heads
(
runner
.
parallel_config
)
self
.
headdim
=
model_config
.
get_head_size
()
self
.
page_size
=
self
.
runner
.
block_size
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_table
=
block_table
if
get_flash_attn_version
()
==
3
:
self
.
aot_schedule
=
not
compilation_config
.
full_cuda_graph
...
...
@@ -323,9 +328,17 @@ class FlashAttentionMetadataBuilder:
max_seq_len
=
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
()
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping
[:
num_actual_tokens
]
block_table
=
self
.
block_table
block_table_tensor
=
block_table
.
get_device_tensor
()[:
num_reqs
]
block_table
.
slot_mapping
[:
num_actual_tokens
].
copy_
(
block_table
.
slot_mapping_cpu
[:
num_actual_tokens
],
non_blocking
=
True
)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table
.
slot_mapping
[
num_actual_tokens
:].
fill_
(
-
1
)
slot_mapping
=
block_table
.
slot_mapping
[:
num_actual_tokens
]
if
self
.
aot_sliding_window
is
None
:
self
.
aot_sliding_window
=
(
-
1
,
-
1
)
...
...
@@ -354,7 +367,7 @@ class FlashAttentionMetadataBuilder:
num_heads_q
=
self
.
num_heads_q
,
num_heads_kv
=
self
.
num_heads_kv
,
headdim
=
self
.
headdim
,
page_size
=
self
.
page
_size
,
page_size
=
self
.
block
_size
,
cu_seqlens_q
=
cu_query_lens
,
causal
=
causal
,
window_size
=
self
.
aot_sliding_window
,
...
...
@@ -365,12 +378,12 @@ class FlashAttentionMetadataBuilder:
local_attn_metadata
=
None
if
self
.
runner
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table
=
make_local_attention_virtual_batches
(
virt_block_table
_tensor
=
make_local_attention_virtual_batches
(
self
.
runner
.
attention_chunk_size
,
self
.
runner
.
query_start_loc_np
[:
num_reqs
+
1
],
self
.
runner
.
seq_lens_np
[:
num_reqs
],
block_table
,
self
.
runner
.
block_size
,
block_table
_tensor
,
self
.
block_size
,
)
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
...
...
@@ -389,7 +402,7 @@ class FlashAttentionMetadataBuilder:
local_attn_metadata
=
FlashAttentionMetadata
.
LocalAttentionMetadata
(
local_query_start_loc
=
local_query_start_loc
,
local_seqused_k
=
local_seqused_k
,
local_block_table
=
virt_block_table
,
local_block_table
=
virt_block_table
_tensor
,
local_max_query_len
=
local_max_query_len
,
local_max_seq_len
=
local_max_seq_len
,
local_scheduler_metadata
=
local_scheduler_metadata
,
...
...
@@ -440,7 +453,7 @@ class FlashAttentionMetadataBuilder:
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
block_table
,
block_table
=
block_table
_tensor
,
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
950751a9
...
...
@@ -19,6 +19,8 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -202,7 +204,8 @@ class FlashInferMetadata:
class
FlashInferMetadataBuilder
:
def
__init__
(
self
,
runner
:
GPUModelRunner
):
def
__init__
(
self
,
runner
:
GPUModelRunner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
self
.
runner
=
runner
self
.
_workspace_buffer
=
None
self
.
_prefill_wrapper
=
None
# Wrapper for prefill/append
...
...
@@ -213,6 +216,8 @@ class FlashInferMetadataBuilder:
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
get_current_vllm_config
()
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_table
=
block_table
def
reorder_batch
(
self
,
input_batch
:
InputBatch
,
scheduler_output
:
SchedulerOutput
)
->
bool
:
...
...
@@ -400,13 +405,12 @@ class FlashInferMetadataBuilder:
assert
self
.
_num_decodes
+
self
.
_num_prefills
==
num_reqs
assert
(
self
.
_num_decode_tokens
+
self
.
_num_prefill_tokens
==
num_actual_tokens
)
page_size
=
self
.
runner
.
block_size
page_size
=
self
.
kv_cache_spec
.
block_size
device
=
self
.
runner
.
device
qo_indptr
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
])
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
block_table_tensor
=
self
.
block_table
.
get_device_tensor
()[:
num_reqs
]
slot_mapping
=
self
.
block_table
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
self
.
runner
.
device
,
non_blocking
=
True
).
long
()
block_table_bounds
=
(
seq_lens
+
page_size
-
1
)
//
page_size
...
...
@@ -422,12 +426,13 @@ class FlashInferMetadataBuilder:
shared_kv_page_indptr
=
torch
.
tensor
([
0
,
num_common_kv_blocks
],
dtype
=
torch
.
int32
,
device
=
device
)
shared_kv_page_indices
=
block_table
[
0
,
:
num_common_kv_blocks
]
shared_kv_page_indices
=
block_table_tensor
[
0
,
:
num_common_kv_blocks
]
shared_kv_last_page_len
=
torch
.
tensor
([
page_size
],
dtype
=
torch
.
int32
,
device
=
device
)
# Remove the blocks of the shared prefix from all requests.
block_table
=
block_table
[:,
num_common_kv_blocks
:]
block_table
_tensor
=
block_table
_tensor
[:,
num_common_kv_blocks
:]
block_table_bounds
-=
num_common_kv_blocks
else
:
shared_qo_indptr
=
None
...
...
@@ -435,11 +440,11 @@ class FlashInferMetadataBuilder:
shared_kv_page_indices
=
None
shared_kv_last_page_len
=
None
mask
=
(
torch
.
arange
(
block_table
.
size
(
1
),
dtype
=
block_table
.
dtype
,
device
=
block_table
.
device
).
unsqueeze
(
0
)
mask
=
(
torch
.
arange
(
block_table
_tensor
.
size
(
1
),
dtype
=
block_table
_tensor
.
dtype
,
device
=
block_table
_tensor
.
device
).
unsqueeze
(
0
)
<
block_table_bounds
.
unsqueeze
(
1
))
paged_kv_indices
=
block_table
[
mask
]
paged_kv_indices
=
block_table
_tensor
[
mask
]
paged_kv_indptr
=
torch
.
cat
([
torch
.
zeros
(
1
,
...
...
@@ -459,10 +464,10 @@ class FlashInferMetadataBuilder:
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len
=
paged_kv_last_page_len
,
num_qo_heads
=
self
.
runner
.
num_query_heads
,
num_kv_heads
=
self
.
runner
.
num_kv_heads
,
head_dim
=
self
.
runner
.
head_size
,
num_kv_heads
=
self
.
kv_cache_spec
.
num_kv_heads
,
head_dim
=
self
.
kv_cache_spec
.
head_size
,
page_size
=
page_size
,
data_type
=
self
.
runner
.
kv_cache_dtype
,
data_type
=
self
.
kv_cache_
spec
.
dtype
,
q_data_type
=
self
.
runner
.
dtype
,
slot_mapping
=
slot_mapping
,
num_decodes
=
self
.
_num_decodes
,
...
...
@@ -481,7 +486,7 @@ class FlashInferMetadataBuilder:
return
attn_metadata
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
if
self
.
runner
.
kv_cache_dtype
!=
self
.
runner
.
model_config
.
dtype
:
if
self
.
kv_cache_
spec
.
dtype
!=
self
.
runner
.
model_config
.
dtype
:
# TODO: The cascade wrapper currently does not support setting
# kv cache dtype to something different from query dtype.
return
False
...
...
vllm/v1/attention/backends/mla/common.py
View file @
950751a9
...
...
@@ -207,6 +207,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
@@ -334,6 +336,8 @@ class MLACommonMetadataBuilder(Generic[M]):
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
,
metadata_cls
:
Optional
[
type
[
M
]]
=
None
):
self
.
metadata_cls
=
metadata_cls
\
if
metadata_cls
is
not
None
else
MLACommonMetadata
...
...
@@ -346,10 +350,11 @@ class MLACommonMetadataBuilder(Generic[M]):
runner
.
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
model_config
)
self
.
aot_schedule
=
is_vllm_fa
and
(
get_flash_attn_version
()
==
3
)
self
.
kv_cache_spec
=
kv_cache_spec
# Don't try to access the runner on AMD
if
self
.
aot_schedule
:
self
.
page_size
=
self
.
runner
.
block_size
self
.
page_size
=
self
.
kv_cache_spec
.
block_size
if
self
.
chunked_prefill_enabled
:
self
.
chunked_prefill_workspace_size
=
min
(
...
...
@@ -375,6 +380,7 @@ class MLACommonMetadataBuilder(Generic[M]):
dtype
=
model_config
.
dtype
,
device
=
runner
.
device
,
)
self
.
block_table
=
block_table
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
...
...
@@ -436,9 +442,10 @@ class MLACommonMetadataBuilder(Generic[M]):
return
modified_batch
def
_build_decode
(
self
,
block_table
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
):
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
):
return
MLACommonDecodeMetadata
(
block_table
=
block_table
,
block_table
=
block_table
_tensor
,
seq_lens
=
seq_lens
,
)
...
...
@@ -451,9 +458,9 @@ class MLACommonMetadataBuilder(Generic[M]):
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device
=
self
.
runner
.
device
block_table
=
(
self
.
runner
.
input_batch
.
block_table
.
get_device_tensor
()[:
num_reqs
]
)
slot_mapping
=
self
.
runner
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
block_table
=
self
.
block_table
block_table_tensor
=
block_table
.
get_device_tensor
()[:
num_reqs
]
slot_mapping
=
block_table
.
slot_mapping_cpu
[:
num_actual_tokens
].
to
(
device
,
non_blocking
=
True
).
long
()
query_start_loc
=
common_attn_metadata
.
query_start_loc
...
...
@@ -530,7 +537,7 @@ class MLACommonMetadataBuilder(Generic[M]):
self
.
chunked_prefill_workspace_size
prefill_metadata
=
MLACommonPrefillMetadata
(
block_table
=
block_table
[
reqs_start
:,
...],
block_table
=
block_table
_tensor
[
reqs_start
:,
...],
query_start_loc
=
prefill_query_start_loc
,
max_query_len
=
max_query_len
,
chunked_context
=
chunked_context_metadata
,
...
...
@@ -539,7 +546,7 @@ class MLACommonMetadataBuilder(Generic[M]):
decode_metadata
=
None
if
self
.
_num_decodes
>
0
:
decode_metadata
=
self
.
_build_decode
(
block_table
=
block_table
[:
self
.
_num_decodes
,
...],
block_table
_tensor
=
block_table
_tensor
[:
self
.
_num_decodes
,
...],
seq_lens
=
seq_lens
[:
self
.
_num_decodes
],
)
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
950751a9
...
...
@@ -16,6 +16,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
logger
=
init_logger
(
__name__
)
...
...
@@ -52,13 +54,14 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
def
__init__
(
self
,
runner
):
super
().
__init__
(
runner
)
def
__init__
(
self
,
runner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
)
self
.
num_q_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
)
def
_build_decode
(
self
,
block_table
:
torch
.
Tensor
,
def
_build_decode
(
self
,
block_table
_tensor
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
tile_scheduler_metadata
,
num_splits
=
\
get_mla_metadata
(
...
...
@@ -68,7 +71,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
)
return
FlashMLADecodeMetadata
(
block_table
=
block_table
,
block_table
=
block_table
_tensor
,
seq_lens
=
seq_lens
,
tile_scheduler_metadata
=
tile_scheduler_metadata
,
num_splits
=
num_splits
,
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
950751a9
...
...
@@ -14,6 +14,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
# yapf: enable
...
...
@@ -59,8 +61,9 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
def
__init__
(
self
,
runner
):
super
().
__init__
(
runner
)
def
__init__
(
self
,
runner
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
)
max_model_len
=
self
.
runner
.
model_config
.
max_model_len
assert
max_model_len
==
32768
,
\
"AITER MLA requires max_model_len=32768"
...
...
vllm/v1/worker/block_table.py
View file @
950751a9
...
...
@@ -14,11 +14,13 @@ class BlockTable:
self
,
max_num_reqs
:
int
,
max_num_blocks_per_req
:
int
,
max_num_batched_tokens
:
int
,
pin_memory
:
bool
,
device
:
torch
.
device
,
):
self
.
max_num_reqs
=
max_num_reqs
self
.
max_num_blocks_per_req
=
max_num_blocks_per_req
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
pin_memory
=
pin_memory
self
.
device
=
device
...
...
@@ -36,6 +38,15 @@ class BlockTable:
self
.
block_table_np
=
self
.
block_table_cpu
.
numpy
()
self
.
num_blocks_per_row
=
np
.
zeros
(
max_num_reqs
,
dtype
=
np
.
int32
)
self
.
slot_mapping_cpu
=
torch
.
zeros
(
self
.
max_num_batched_tokens
,
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
self
.
slot_mapping
=
torch
.
zeros
(
self
.
max_num_batched_tokens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
def
append_row
(
self
,
block_ids
:
list
[
int
],
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
950751a9
...
...
@@ -59,6 +59,7 @@ class InputBatch:
max_num_reqs
:
int
,
max_model_len
:
int
,
max_num_blocks_per_req
:
int
,
max_num_batched_tokens
:
int
,
device
:
torch
.
device
,
pin_memory
:
bool
,
vocab_size
:
int
,
...
...
@@ -66,6 +67,7 @@ class InputBatch:
self
.
max_num_reqs
=
max_num_reqs
self
.
max_model_len
=
max_model_len
self
.
max_num_blocks_per_req
=
max_num_blocks_per_req
self
.
max_num_batched_tokens
=
max_num_batched_tokens
self
.
device
=
device
self
.
pin_memory
=
pin_memory
self
.
vocab_size
=
vocab_size
...
...
@@ -100,6 +102,7 @@ class InputBatch:
self
.
block_table
=
BlockTable
(
max_num_reqs
=
max_num_reqs
,
max_num_blocks_per_req
=
max_num_blocks_per_req
,
max_num_batched_tokens
=
max_num_batched_tokens
,
pin_memory
=
pin_memory
,
device
=
device
,
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
950751a9
...
...
@@ -150,8 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
f
"FA3. Current attention backend is
{
attn_backend_name
}
, "
f
"FlashAttention version is
{
flash_attn_version
}
."
)
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
self
))
self
.
cascade_attn_enabled
=
not
self
.
model_config
.
disable_cascade_attn
# Multi-modal data support
...
...
@@ -174,6 +172,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Initialize in initialize_kv_cache
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
# self.kv_cache_config: KVCacheConfig
# self.attn_metadata_builder: type[AttentionMetadataBuilder]
# req_id -> (input_id -> encoder_output)
self
.
encoder_cache
:
dict
[
str
,
dict
[
int
,
torch
.
Tensor
]]
=
{}
...
...
@@ -203,6 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
model_config
.
get_vocab_size
(),
...
...
@@ -291,11 +291,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
positions_np
=
self
.
positions_cpu
.
numpy
()
self
.
slot_mapping_cpu
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
"cpu"
,
pin_memory
=
self
.
pin_memory
)
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
self
.
query_start_loc_cpu
=
torch
.
zeros
(
self
.
max_num_reqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
...
...
@@ -586,7 +581,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_offsets
=
positions_np
%
self
.
block_size
np
.
add
(
block_numbers
*
self
.
block_size
,
block_offsets
,
out
=
self
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
out
=
self
.
input_batch
.
block_table
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
self
.
query_start_loc_np
[
0
]
=
0
...
...
@@ -614,12 +610,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
query_start_loc_cpu
[:
num_reqs
+
1
],
non_blocking
=
True
)
self
.
seq_lens
[:
num_reqs
].
copy_
(
self
.
seq_lens_cpu
[:
num_reqs
],
non_blocking
=
True
)
self
.
slot_mapping
[:
total_num_scheduled_tokens
].
copy_
(
self
.
slot_mapping_cpu
[:
total_num_scheduled_tokens
],
non_blocking
=
True
)
# Fill unused with -1. Needed for reshape_and_cache
self
.
slot_mapping
[
total_num_scheduled_tokens
:].
fill_
(
-
1
)
self
.
seq_lens
[
num_reqs
:].
fill_
(
0
)
self
.
query_start_loc
[
num_reqs
+
1
:].
fill_
(
-
1
)
...
...
@@ -1821,6 +1813,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
self
.
kv_caches
)
self
.
attn_metadata_builder
=
self
.
attn_backend
.
get_builder_cls
()(
weakref
.
proxy
(
self
),
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
,
self
.
input_batch
.
block_table
)
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
950751a9
...
...
@@ -179,6 +179,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs
=
self
.
max_num_reqs
,
max_model_len
=
self
.
max_model_len
,
max_num_blocks_per_req
=
self
.
max_num_blocks_per_req
,
max_num_batched_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
vocab_size
=
self
.
vocab_size
,
...
...
@@ -197,10 +198,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
device
=
"cpu"
)
self
.
positions_np
=
self
.
positions_cpu
.
numpy
()
self
.
slot_mapping_cpu
=
torch
.
zeros
(
self
.
max_num_tokens
,
dtype
=
torch
.
int64
,
device
=
"cpu"
)
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
self
.
block_table_cpu
=
torch
.
zeros
(
(
self
.
max_num_reqs
,
self
.
max_num_blocks_per_req
),
dtype
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
().
dtype
,
...
...
@@ -533,7 +530,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
block_offsets
=
positions_np
%
self
.
block_size
np
.
add
(
block_numbers
*
self
.
block_size
,
block_offsets
,
out
=
self
.
slot_mapping_np
[:
total_num_scheduled_tokens
])
out
=
self
.
input_batch
.
block_table
.
slot_mapping_cpu
[:
total_num_scheduled_tokens
])
# Prepare the attention metadata.
self
.
query_start_loc_np
[
0
]
=
0
...
...
@@ -557,10 +555,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self
.
position_ids
=
self
.
positions_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
)
self
.
slot_mapping_cpu
[
total_num_scheduled_tokens
:]
=
_PAD_SLOT_ID
slot_mapping
=
self
.
slot_mapping_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
)
self
.
input_batch
.
block_table
.
slot_mapping_cpu
[
total_num_scheduled_tokens
:]
=
_PAD_SLOT_ID
slot_mapping
=
(
self
.
input_batch
.
block_table
.
slot_mapping_cpu
[:
padded_total_num_scheduled_tokens
].
to
(
self
.
device
))
block_tables
=
self
.
block_table_cpu
[:
self
.
max_num_reqs
]
block_tables
[:
num_reqs
,
:
self
.
max_num_blocks_per_req
]
=
(
self
.
input_batch
.
block_table
.
get_cpu_tensor
()[:
num_reqs
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment