Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5c057e06
Unverified
Commit
5c057e06
authored
Oct 04, 2025
by
Li, Jiang
Committed by
GitHub
Oct 04, 2025
Browse files
[CPU] Refine batch reorder of CPU attention backend (#26096)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
ed3aeb25
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
128 deletions
+44
-128
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+42
-87
vllm/v1/worker/cpu_model_runner.py
vllm/v1/worker/cpu_model_runner.py
+2
-41
No files found.
vllm/v1/attention/backends/cpu_attn.py
View file @
5c057e06
...
@@ -14,10 +14,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -14,10 +14,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
,
from
vllm.v1.core.sched.output
import
SchedulerOutput
split_decodes_and_prefills
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
try
:
try
:
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
...
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
...
@@ -102,16 +101,16 @@ class TorchSDPAMetadata(AttentionMetadata):
"""Metadata for PagedAttention."""
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
# sequence.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
decode_
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_
decode_seq_len
:
int
decode_
max_
seq_len
:
int
# (batch_size, max_blocks_per_seq).
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
decode_
block_tables
:
Optional
[
torch
.
Tensor
]
"""Metadata for TorchSDPABackend.
"""Metadata for TorchSDPABackend.
"""
"""
# Currently, input sequences can only contain all prompts
# Currently, input sequences can only contain all prompts
...
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):
...
@@ -121,9 +120,9 @@ class TorchSDPAMetadata(AttentionMetadata):
# For chunked prefill only
# For chunked prefill only
max_query_len
:
Optional
[
int
]
=
None
max_query_len
:
Optional
[
int
]
=
None
max_
kv
_len
:
Optional
[
int
]
=
None
prefill_
max_
seq
_len
:
Optional
[
int
]
=
None
prefill_query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
prefill_query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
kv
_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
prefill_seq
_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
prefill_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
prefill_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# For V1 logits index only
# For V1 logits index only
...
@@ -307,8 +306,8 @@ class TorchSDPAMetadata(AttentionMetadata):
...
@@ -307,8 +306,8 @@ class TorchSDPAMetadata(AttentionMetadata):
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
# Decoder self-attention
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
# Choose max_seq_len based on whether we are in prompt_run
return
(
self
.
seq_lens_tensor
,
self
.
max_
decode_seq_len
,
return
(
self
.
decode_
seq_lens_tensor
,
self
.
decode_
max_
seq_len
,
self
.
block_tables
)
self
.
decode_
block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Enc/dec cross-attention KVs match encoder sequence length;
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
# cross-attention utilizes special "cross" block tables
...
@@ -323,19 +322,14 @@ class TorchSDPAMetadata(AttentionMetadata):
...
@@ -323,19 +322,14 @@ class TorchSDPAMetadata(AttentionMetadata):
class
TorchSDPAMetadataBuilderV1
(
AttentionMetadataBuilder
[
TorchSDPAMetadata
]):
class
TorchSDPAMetadataBuilderV1
(
AttentionMetadataBuilder
[
TorchSDPAMetadata
]):
reorder_batch_threshold
:
int
=
1
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
layer_names
:
list
[
str
],
vllm_config
:
VllmConfig
,
device
:
torch
.
device
)
->
None
:
vllm_config
:
VllmConfig
,
device
:
torch
.
device
)
->
None
:
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
_init_reorder_batch_threshold
(
1
,
False
)
# For reorder
self
.
reorder_prompt_req_index_list
=
np
.
empty
(
vllm_config
.
scheduler_config
.
max_num_seqs
,
dtype
=
np
.
int64
)
self
.
reorder_decode_req_index_list
=
np
.
empty
(
vllm_config
.
scheduler_config
.
max_num_seqs
,
dtype
=
np
.
int64
)
self
.
num_prompt_req
:
int
=
0
self
.
seq_start_loc_cpu
=
torch
.
zeros
(
self
.
seq_start_loc_cpu
=
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
1
,
vllm_config
.
scheduler_config
.
max_num_seqs
+
1
,
...
@@ -344,50 +338,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
...
@@ -344,50 +338,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
)
)
self
.
seq_start_loc_np
=
self
.
seq_start_loc_cpu
.
numpy
()
self
.
seq_start_loc_np
=
self
.
seq_start_loc_cpu
.
numpy
()
def
reorder_batch
(
self
,
input_batch
:
InputBatch
,
scheduler_output
:
SchedulerOutput
)
->
bool
:
prompt_list_idx
=
0
decode_list_idx
=
0
for
req_index
in
range
(
input_batch
.
num_reqs
):
if
input_batch
.
num_computed_tokens_cpu
[
req_index
]
<
input_batch
.
num_prompt_tokens
[
req_index
]:
# prompt stage
self
.
reorder_prompt_req_index_list
[
prompt_list_idx
]
=
req_index
prompt_list_idx
+=
1
else
:
# decode stage
self
.
reorder_decode_req_index_list
[
decode_list_idx
]
=
req_index
decode_list_idx
+=
1
assert
decode_list_idx
+
prompt_list_idx
==
input_batch
.
num_reqs
# Update prompt requests number
self
.
num_prompt_req
=
prompt_list_idx
reorder_req_num
=
0
for
req_index
in
range
(
decode_list_idx
):
if
self
.
reorder_decode_req_index_list
[
req_index
]
<
prompt_list_idx
:
reorder_req_num
+=
1
else
:
break
if
reorder_req_num
==
0
:
return
False
reorder_prompt_list
=
(
self
.
reorder_prompt_req_index_list
[:
prompt_list_idx
]
[
-
reorder_req_num
:])
reorder_decode_list
=
(
self
.
reorder_decode_req_index_list
[:
decode_list_idx
]
[:
reorder_req_num
])
assert
reorder_decode_list
.
size
==
reorder_prompt_list
.
size
for
idx
in
range
(
reorder_req_num
):
prompt_req_index
=
reorder_prompt_list
[
idx
].
item
()
decode_req_index
=
reorder_decode_list
[
idx
].
item
()
input_batch
.
swap_states
(
prompt_req_index
,
decode_req_index
)
return
True
def
build
(
self
,
def
build
(
self
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
...
@@ -397,41 +347,46 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
...
@@ -397,41 +347,46 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_np
=
seq_lens_cpu
.
numpy
()
seq_lens_np
=
seq_lens_cpu
.
numpy
()
num_prompt_req
=
self
.
num_prompt_req
max_prefill_seq_len
=
seq_lens_np
[:
num_prompt_req
].
max
().
item
(
)
if
num_prompt_req
>
0
else
0
max_decode_seq_len
=
seq_lens_np
[
num_prompt_req
:
num_reqs
].
max
().
item
(
)
if
num_prompt_req
<
num_reqs
else
0
self
.
seq_start_loc_np
[
0
]
=
0
np
.
cumsum
(
seq_lens_np
,
out
=
self
.
seq_start_loc_np
[
1
:
num_reqs
+
1
])
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
num_prefill_tokens
=
int
(
query_start_loc_cpu
[
num_prompt_req
].
item
())
query_start_loc_np
=
query_start_loc_cpu
.
numpy
()
num_decode_tokens
=
int
(
query_start_loc_cpu
[
num_reqs
].
item
()
-
num_prefill_tokens
)
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
\
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
,
require_uniform
=
True
)
max_prefill_seq_len
=
seq_lens_np
[
num_decodes
:
num_reqs
].
max
().
item
(
)
if
num_prefills
>
0
else
0
max_decode_seq_len
=
seq_lens_np
[:
num_decodes
].
max
().
item
(
)
if
num_prefills
<
num_reqs
else
0
self
.
seq_start_loc_np
[
0
]
=
0
np
.
cumsum
(
seq_lens_np
,
out
=
self
.
seq_start_loc_np
[
1
:
num_reqs
+
1
])
slot_mapping
=
common_attn_metadata
.
slot_mapping
.
long
()
slot_mapping
=
common_attn_metadata
.
slot_mapping
.
long
()
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
query_start_loc_np
=
query_start_loc_cpu
.
numpy
()
query_start_loc_np
[
num_decodes
:
num_reqs
+
1
]
-=
num_decode_tokens
attn_metadata
=
TorchSDPAMetadata
(
attn_metadata
=
TorchSDPAMetadata
(
num_prefills
=
num_pr
ompt_req
,
num_prefills
=
num_pr
efills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
# to ensure inference when chunked_prefill is disabled
# to ensure inference when chunked_prefill is disabled
seq_lens
=
seq_lens_cpu
.
tolist
(),
seq_lens
=
seq_lens_cpu
.
tolist
(),
seq_lens_tensor
=
seq_lens_cpu
[
num_
prompt_req
:
num_req
s
],
# decode
decode_
seq_lens_tensor
=
seq_lens_cpu
[
:
num_
decode
s
],
# decode
max_
decode_seq_len
=
max_decode_seq_len
,
# decode
decode_
max_
seq_len
=
max_decode_seq_len
,
# decode
block_tables
=
block_table_tensor
[
num_
prompt_req
:
num_req
s
],
# decode
decode_
block_tables
=
block_table_tensor
[
:
num_
decode
s
],
# decode
chunked_prefill
=
self
.
scheduler_config
.
chunked_prefill_enabled
,
chunked_prefill
=
self
.
scheduler_config
.
chunked_prefill_enabled
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_
kv
_len
=
max_prefill_seq_len
,
prefill_
max_
seq
_len
=
max_prefill_seq_len
,
prefill_query_start_loc
=
query_start_loc_cpu
[
:
num_
prompt
_req
+
prefill_query_start_loc
=
query_start_loc_cpu
[
num_
decodes
:
num
_req
s
+
1
],
# prefill
1
],
# prefill
kv
_start_loc
=
self
.
seq_start_loc_cpu
[
:
num_
prompt
_req
+
prefill_seq
_start_loc
=
self
.
seq_start_loc_cpu
[
num_
decodes
:
num
_req
s
+
1
],
# prefill
1
],
# prefill
prefill_block_tables
=
block_table_tensor
[
:
prefill_block_tables
=
block_table_tensor
[
num_
prompt
_req
],
# prefill
num_
decodes
:
num
_req
s
],
# prefill
query_start_loc
=
query_start_loc_cpu
[:
num_reqs
+
query_start_loc
=
query_start_loc_cpu
[:
num_reqs
+
1
],
# for logits index
1
],
# for logits index
)
)
...
@@ -596,14 +551,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -596,14 +551,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
ipex_modules
.
PagedAttention
.
flash_attn_varlen_func
(
ipex_modules
.
PagedAttention
.
flash_attn_varlen_func
(
output
[
:
prefill_meta
.
num_
prefill
_tokens
,
:,
:],
output
[
prefill_meta
.
num_
decode
_tokens
:
,
:,
:],
query
[
:
prefill_meta
.
num_
prefill
_tokens
,
:,
:],
query
[
prefill_meta
.
num_
decode
_tokens
:
,
:,
:],
key_cache
,
key_cache
,
value_cache
,
value_cache
,
prefill_meta
.
prefill_query_start_loc
,
prefill_meta
.
prefill_query_start_loc
,
prefill_meta
.
kv
_start_loc
,
prefill_meta
.
prefill_seq
_start_loc
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_
kv
_len
,
prefill_meta
.
prefill_
max_
seq
_len
,
self
.
scale
,
self
.
scale
,
True
,
True
,
prefill_meta
.
prefill_block_tables
,
prefill_meta
.
prefill_block_tables
,
...
@@ -621,8 +576,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -621,8 +576,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
)
=
decode_meta
.
get_seq_len_block_table_args
(
attn_type
)
)
=
decode_meta
.
get_seq_len_block_table_args
(
attn_type
)
self
.
paged_attn_impl
.
forward_decode
(
self
.
paged_attn_impl
.
forward_decode
(
output
[
attn_metadata
.
num_
prefill
_tokens
:
,
:,
:],
output
[
:
attn_metadata
.
num_
decode
_tokens
,
:,
:],
query
[
attn_metadata
.
num_
prefill
_tokens
:
,
:,
:],
query
[
:
attn_metadata
.
num_
decode
_tokens
,
:,
:],
key_cache
,
key_cache
,
value_cache
,
value_cache
,
block_tables_arg
,
block_tables_arg
,
...
...
vllm/v1/worker/cpu_model_runner.py
View file @
5c057e06
...
@@ -9,7 +9,6 @@ import torch.nn as nn
...
@@ -9,7 +9,6 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.v1.attention.backends.cpu_attn
import
TorchSDPAMetadataBuilderV1
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
@@ -33,50 +32,12 @@ class CPUModelRunner(GPUModelRunner):
...
@@ -33,50 +32,12 @@ class CPUModelRunner(GPUModelRunner):
self
.
_postprocess_tensors
()
self
.
_postprocess_tensors
()
# Note: Remove the override after new attention backend finished
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
scheduler_output: The scheduler output.
"""
# Attention free models have zero kv_cache_groups, however models
# like Mamba are also attention free but use the kv_cache for
# keeping its internal state. This is why we check the number
# of kv_cache groups instead of solely checking
# for self.model_config.is_attention_free.
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
0
:
return
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
>
1
:
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
>
1
:
raise
ValueError
(
"Multiple KVCacheGroups is not"
raise
ValueError
(
"Multiple KVCacheGroups is not"
"currently supported with CPU model runner."
)
"currently supported with CPU model runner."
)
super
().
_may_reorder_batch
(
scheduler_output
)
# Guard against encoder-only / pooling models where `attn_groups`
# may be empty or lack the expected metadata_builder.
# Without this check, accessing `attn_groups[0][0]` would trigger
# an AssertionError on CPU backend.
if
not
hasattr
(
self
,
"attn_groups"
)
or
not
self
.
attn_groups
:
return
if
not
self
.
attn_groups
[
0
]:
return
mb
=
getattr
(
self
.
attn_groups
[
0
][
0
],
"metadata_builders"
,
None
)
if
isinstance
(
mb
,
list
):
if
not
isinstance
(
mb
[
0
],
TorchSDPAMetadataBuilderV1
):
return
mb
[
0
].
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
return
elif
not
isinstance
(
mb
,
TorchSDPAMetadataBuilderV1
):
# Encoder-only / rerank models do not benefit from reordering,
# so we safely skip here.
return
# Safe path for decoder/attention-heavy models
mb
.
reorder_batch
(
self
.
input_batch
,
scheduler_output
)
def
_postprocess_tensors
(
self
)
->
None
:
def
_postprocess_tensors
(
self
)
->
None
:
# Note: replace device tensors with cpu tensors
# Note: replace device tensors with cpu tensors
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment