Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7f829be7
Unverified
Commit
7f829be7
authored
Nov 12, 2025
by
Li, Jiang
Committed by
GitHub
Nov 12, 2025
Browse files
[CPU] Refactor CPU attention backend (#27954)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
e1710393
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
953 additions
and
776 deletions
+953
-776
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+3
-3
tests/kernels/attention/test_cpu_attn.py
tests/kernels/attention/test_cpu_attn.py
+575
-0
tests/kernels/test_onednn.py
tests/kernels/test_onednn.py
+0
-1
tests/models/language/generation/test_common.py
tests/models/language/generation/test_common.py
+10
-7
tests/models/language/pooling/test_embedding.py
tests/models/language/pooling/test_embedding.py
+1
-2
tests/models/registry.py
tests/models/registry.py
+3
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+82
-0
vllm/attention/backends/registry.py
vllm/attention/backends/registry.py
+2
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+0
-3
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+11
-26
vllm/utils/__init__.py
vllm/utils/__init__.py
+0
-1
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+264
-717
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+1
-1
vllm/v1/worker/cpu_model_runner.py
vllm/v1/worker/cpu_model_runner.py
+1
-13
No files found.
tests/kernels/attention/test_attention_selector.py
View file @
7f829be7
...
...
@@ -35,7 +35,7 @@ DEVICE_MLA_BACKENDS = {
DEVICE_REGULAR_ATTN_BACKENDS
=
{
"cuda"
:
[
"XFORMERS"
,
"FLASHINFER"
,
"FLASH_ATTN"
],
"hip"
:
[
"ROCM_ATTN"
],
"cpu"
:
[
"
TORCH_SDPA
"
],
"cpu"
:
[
"
CPU_ATTN
"
],
}
DEVICE_MLA_BLOCK_SIZES
=
{
...
...
@@ -86,7 +86,7 @@ def test_env(
if
device
==
"cpu"
:
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
)
assert
backend
.
get_name
()
==
"
TORCH_SDPA
"
assert
backend
.
get_name
()
==
"
CPU_ATTN
"
elif
device
==
"hip"
:
with
patch
(
"vllm.platforms.current_platform"
,
RocmPlatform
()):
...
...
@@ -224,7 +224,7 @@ def test_fp32_fallback(device: str):
if
device
==
"cpu"
:
with
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
backend
=
get_attn_backend
(
16
,
torch
.
float32
,
None
,
16
)
assert
backend
.
get_name
()
==
"
TORCH_SDPA
"
assert
backend
.
get_name
()
==
"
CPU_ATTN
"
elif
device
==
"cuda"
:
with
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()):
...
...
tests/kernels/attention/test_cpu_attn.py
0 → 100644
View file @
7f829be7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
functools
import
math
import
pytest
import
torch
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_cpu
():
pytest
.
skip
(
"skipping CPU-only tests"
,
allow_module_level
=
True
)
from
vllm._custom_ops
import
(
cpu_attention_with_kv_cache
,
cpu_attn_get_scheduler_metadata
,
cpu_attn_reshape_and_cache
,
)
NUM_HEADS
=
[
(
4
,
4
),
(
8
,
2
),
(
9
,
3
),
]
HEAD_SIZES
=
[
96
,
128
]
QTYPES
=
[
torch
.
bfloat16
,
torch
.
half
,
torch
.
float32
]
SLIDING_WINDOWS
=
[
None
,
256
]
NUM_BLOCKS
=
[
1024
,
]
SEQ_LENS
=
[
# (q_len, kv_len)
[(
1
,
213
),
(
1
,
1
),
(
1
,
312
),
(
1
,
7
),
(
1
,
7812
)],
# decode batch
[(
2345
,
2345
),
(
5
,
5
),
(
3
,
16
),
(
134
,
5131
)],
# prefill batch
[(
992
,
2456
),
(
1
,
1234
),
(
98
,
1145
),
(
1
,
4162
),
(
2345
,
2345
)],
# mixed batch
]
# rand number generation takes too much time, cache rand tensors
@
functools
.
lru_cache
(
maxsize
=
128
,
typed
=
False
)
def
tensor_cache
(
elem_num
:
int
,
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
tensor
=
torch
.
randn
(
elem_num
,
dtype
=
dtype
)
return
tensor
def
_get_alibi_slopes
(
total_num_heads
:
int
)
->
torch
.
Tensor
:
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
total_num_heads
))
base
=
torch
.
tensor
(
2
**
(
-
(
2
**
-
(
math
.
log2
(
closest_power_of_2
)
-
3
))),
dtype
=
torch
.
float32
,
)
powers
=
torch
.
arange
(
1
,
1
+
closest_power_of_2
,
dtype
=
torch
.
int32
)
slopes
=
torch
.
pow
(
base
,
powers
)
if
closest_power_of_2
!=
total_num_heads
:
extra_base
=
torch
.
tensor
(
2
**
(
-
(
2
**
-
(
math
.
log2
(
2
*
closest_power_of_2
)
-
3
))),
dtype
=
torch
.
float32
,
)
num_remaining_heads
=
min
(
closest_power_of_2
,
total_num_heads
-
closest_power_of_2
)
extra_powers
=
torch
.
arange
(
start
=
1
,
end
=
1
+
2
*
num_remaining_heads
,
step
=
2
,
dtype
=
torch
.
int32
)
slopes
=
torch
.
cat
([
slopes
,
torch
.
pow
(
extra_base
,
extra_powers
)],
dim
=
0
)
return
slopes
.
float
()
def
ref_paged_attn
(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
query_lens
:
list
[
int
],
kv_lens
:
list
[
int
],
block_tables
:
torch
.
Tensor
,
scale
:
float
,
sliding_window
:
int
|
None
=
None
,
soft_cap
:
float
|
None
=
None
,
alibi_slopes
:
torch
.
Tensor
|
None
=
None
,
s_aux
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
num_seqs
=
len
(
query_lens
)
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
dtype
=
query
.
dtype
outputs
:
list
[
torch
.
Tensor
]
=
[]
start_idx
=
0
if
alibi_slopes
is
not
None
:
alibi_slopes
=
alibi_slopes
[:,
None
,
None
]
if
s_aux
is
not
None
:
s_aux
=
s_aux
.
float
()
s_aux
=
s_aux
[:,
None
,
None
]
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
kv_len
=
kv_lens
[
i
]
q
=
query
[
start_idx
:
start_idx
+
query_len
].
float
()
q
*=
scale
num_kv_blocks
=
(
kv_len
+
block_size
-
1
)
//
block_size
block_indices
=
block_tables
[
i
,
:
num_kv_blocks
]
k
=
key_cache
[
block_indices
].
view
(
-
1
,
num_kv_heads
,
head_size
)
k
=
k
[:
kv_len
].
float
()
v
=
value_cache
[
block_indices
].
view
(
-
1
,
num_kv_heads
,
head_size
)
v
=
v
[:
kv_len
].
float
()
if
q
.
shape
[
1
]
!=
k
.
shape
[
1
]:
k
=
torch
.
repeat_interleave
(
k
,
q
.
shape
[
1
]
//
k
.
shape
[
1
],
dim
=
1
)
v
=
torch
.
repeat_interleave
(
v
,
q
.
shape
[
1
]
//
v
.
shape
[
1
],
dim
=
1
)
attn
=
torch
.
einsum
(
"qhd,khd->hqk"
,
q
,
k
).
float
()
empty_mask
=
torch
.
ones
(
query_len
,
kv_len
)
mask
=
torch
.
triu
(
empty_mask
,
diagonal
=
kv_len
-
query_len
+
1
).
bool
()
if
sliding_window
is
not
None
:
sliding_window_mask
=
(
torch
.
triu
(
empty_mask
,
diagonal
=
kv_len
-
(
query_len
+
sliding_window
)
+
1
)
.
bool
()
.
logical_not
()
)
mask
|=
sliding_window_mask
if
soft_cap
is
not
None
:
attn
=
soft_cap
*
torch
.
tanh
(
attn
/
soft_cap
)
if
alibi_slopes
is
not
None
:
q_start_pos
=
kv_len
-
query_len
q_pos
=
q_start_pos
+
torch
.
arange
(
0
,
query_len
)[
None
,
:,
None
]
kv_pos
=
torch
.
arange
(
0
,
kv_len
)[
None
,
None
,
:]
dist
=
q_pos
-
kv_pos
alibi_bias
=
-
alibi_slopes
*
dist
attn
+=
alibi_bias
attn
.
masked_fill_
(
mask
,
float
(
"-inf"
))
if
s_aux
is
not
None
:
s_aux_ext
=
s_aux
.
repeat
(
1
,
query_len
,
1
)
attn
=
torch
.
cat
((
s_aux_ext
,
attn
),
dim
=-
1
)
attn
=
torch
.
softmax
(
attn
,
dim
=-
1
)
if
s_aux
is
not
None
:
attn
=
attn
[:,
:,
1
:]
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
v
).
to
(
dtype
=
dtype
)
outputs
.
append
(
out
)
start_idx
+=
query_len
return
torch
.
cat
(
outputs
,
dim
=
0
)
@
torch
.
inference_mode
()
def
varlen_with_paged_kv
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
int
|
None
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
float
|
None
,
num_blocks
:
int
,
use_alibi
:
bool
,
use_sink
:
bool
,
isa
:
str
,
)
->
None
:
current_platform
.
seed_everything
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
window_size
=
(
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
)
scale
=
head_size
**-
0.5
token_num
=
sum
(
query_lens
)
# for n heads the set of slopes is the geometric sequence that starts
# 2^(-8/n)
alibi_slopes
=
_get_alibi_slopes
(
num_query_heads
)
if
use_alibi
else
None
s_aux
=
(
15
*
torch
.
rand
((
num_query_heads
,),
dtype
=
torch
.
bfloat16
)
if
use_sink
else
None
)
query
=
tensor_cache
(
elem_num
=
token_num
*
num_query_heads
*
head_size
,
dtype
=
dtype
,
)
query
=
query
.
view
(
token_num
,
num_query_heads
,
head_size
,
)
key_value
=
tensor_cache
(
elem_num
=
2
*
num_blocks
*
num_kv_heads
*
block_size
*
head_size
,
dtype
=
dtype
,
)
key_value
=
key_value
.
view
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
,
)
key_cache
,
value_cache
=
key_value
.
unbind
(
0
)
# KV cache for CPU attention
packed_key_cache
=
torch
.
empty
(
num_blocks
,
num_kv_heads
,
block_size
,
head_size
,
dtype
=
dtype
)
packed_value_cache
=
torch
.
empty_like
(
packed_key_cache
)
cu_query_lens
=
torch
.
tensor
([
0
]
+
query_lens
,
dtype
=
torch
.
int32
).
cumsum
(
dim
=
0
,
dtype
=
torch
.
int32
)
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
# use reshape_and_cache to pack key_cache and value_cache
slot_mapping
=
torch
.
arange
(
0
,
num_blocks
*
block_size
,
dtype
=
torch
.
int64
)
cpu_attn_reshape_and_cache
(
key
=
key_cache
.
view
(
-
1
,
num_kv_heads
,
head_size
),
value
=
value_cache
.
view
(
-
1
,
num_kv_heads
,
head_size
),
key_cache
=
packed_key_cache
,
value_cache
=
packed_value_cache
,
slot_mapping
=
slot_mapping
,
isa
=
isa
,
)
metadata
=
cpu_attn_get_scheduler_metadata
(
num_reqs
=
num_seqs
,
num_heads
=
num_query_heads
,
num_kv_heads
=
num_kv_heads
,
head_dim
=
head_size
,
seq_lens
=
kv_lens_tensor
,
dtype
=
dtype
,
query_start_loc
=
cu_query_lens
,
causal
=
True
,
sliding_window_size
=
sliding_window
if
sliding_window
is
not
None
else
-
1
,
isa
=
isa
,
enable_kv_split
=
False
,
)
out_without_split
=
torch
.
empty_like
(
query
)
cpu_attention_with_kv_cache
(
query
=
query
,
key_cache
=
packed_key_cache
,
value_cache
=
packed_value_cache
,
output
=
out_without_split
,
query_start_loc
=
cu_query_lens
,
seq_lens
=
kv_lens_tensor
,
scale
=
scale
,
causal
=
True
,
alibi_slopes
=
alibi_slopes
,
sliding_window
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
scheduler_metadata
=
metadata
,
s_aux
=
s_aux
,
)
metadata
=
cpu_attn_get_scheduler_metadata
(
num_reqs
=
num_seqs
,
num_heads
=
num_query_heads
,
num_kv_heads
=
num_kv_heads
,
head_dim
=
head_size
,
seq_lens
=
kv_lens_tensor
,
dtype
=
dtype
,
query_start_loc
=
cu_query_lens
,
causal
=
True
,
sliding_window_size
=
sliding_window
if
sliding_window
is
not
None
else
-
1
,
isa
=
isa
,
enable_kv_split
=
True
,
)
out_with_split
=
torch
.
empty_like
(
query
)
cpu_attention_with_kv_cache
(
query
=
query
,
key_cache
=
packed_key_cache
,
value_cache
=
packed_value_cache
,
output
=
out_with_split
,
query_start_loc
=
cu_query_lens
,
seq_lens
=
kv_lens_tensor
,
scale
=
scale
,
causal
=
True
,
alibi_slopes
=
alibi_slopes
,
sliding_window
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
scheduler_metadata
=
metadata
,
s_aux
=
s_aux
,
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
query_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
sliding_window
=
sliding_window
,
soft_cap
=
soft_cap
,
alibi_slopes
=
alibi_slopes
,
s_aux
=
s_aux
,
)
atol
,
rtol
=
1.5e-2
,
1e-2
(
torch
.
testing
.
assert_close
(
out_with_split
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
),
f
"
{
torch
.
max
(
torch
.
abs
(
out_with_split
-
ref_output
))
}
"
,
)
(
torch
.
testing
.
assert_close
(
out_without_split
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
),
f
"
{
torch
.
max
(
torch
.
abs
(
out_without_split
-
ref_output
))
}
"
,
)
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
96
,
128
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOWS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
QTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"vec"
])
def
test_varlen_with_paged_kv_normal_vec
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
int
|
None
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
float
|
None
,
num_blocks
:
int
,
use_alibi
:
bool
,
use_sink
:
bool
,
isa
:
str
,
)
->
None
:
varlen_with_paged_kv
(
seq_lens
=
seq_lens
,
num_heads
=
num_heads
,
head_size
=
head_size
,
sliding_window
=
sliding_window
,
dtype
=
dtype
,
block_size
=
block_size
,
soft_cap
=
soft_cap
,
num_blocks
=
num_blocks
,
use_alibi
=
use_alibi
,
use_sink
=
use_sink
,
isa
=
isa
,
)
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
96
,
128
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOWS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"amx"
])
@
pytest
.
mark
.
skipif
(
not
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
(),
reason
=
"no AMX support."
)
def
test_varlen_with_paged_kv_normal_amx
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
int
|
None
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
float
|
None
,
num_blocks
:
int
,
use_alibi
:
bool
,
use_sink
:
bool
,
isa
:
str
,
)
->
None
:
varlen_with_paged_kv
(
seq_lens
=
seq_lens
,
num_heads
=
num_heads
,
head_size
=
head_size
,
sliding_window
=
sliding_window
,
dtype
=
dtype
,
block_size
=
block_size
,
soft_cap
=
soft_cap
,
num_blocks
=
num_blocks
,
use_alibi
=
use_alibi
,
use_sink
=
use_sink
,
isa
=
isa
,
)
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
48
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOWS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"vec16"
])
def
test_varlen_with_paged_kv_normal_vec16
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
int
|
None
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
float
|
None
,
num_blocks
:
int
,
use_alibi
:
bool
,
use_sink
:
bool
,
isa
:
str
,
)
->
None
:
varlen_with_paged_kv
(
seq_lens
=
seq_lens
,
num_heads
=
num_heads
,
head_size
=
head_size
,
sliding_window
=
sliding_window
,
dtype
=
dtype
,
block_size
=
block_size
,
soft_cap
=
soft_cap
,
num_blocks
=
num_blocks
,
use_alibi
=
use_alibi
,
use_sink
=
use_sink
,
isa
=
isa
,
)
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOWS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
50
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"amx"
]
if
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
else
[
"vec"
]
)
def
test_varlen_with_paged_kv_softcap
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
int
|
None
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
float
|
None
,
num_blocks
:
int
,
use_alibi
:
bool
,
use_sink
:
bool
,
isa
:
str
,
)
->
None
:
varlen_with_paged_kv
(
seq_lens
=
seq_lens
,
num_heads
=
num_heads
,
head_size
=
head_size
,
sliding_window
=
sliding_window
,
dtype
=
dtype
,
block_size
=
block_size
,
soft_cap
=
soft_cap
,
num_blocks
=
num_blocks
,
use_alibi
=
use_alibi
,
use_sink
=
use_sink
,
isa
=
isa
,
)
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOWS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"amx"
]
if
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
else
[
"vec"
]
)
def
test_varlen_with_paged_kv_alibi
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
int
|
None
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
float
|
None
,
num_blocks
:
int
,
use_alibi
:
bool
,
use_sink
:
bool
,
isa
:
str
,
)
->
None
:
varlen_with_paged_kv
(
seq_lens
=
seq_lens
,
num_heads
=
num_heads
,
head_size
=
head_size
,
sliding_window
=
sliding_window
,
dtype
=
dtype
,
block_size
=
block_size
,
soft_cap
=
soft_cap
,
num_blocks
=
num_blocks
,
use_alibi
=
use_alibi
,
use_sink
=
use_sink
,
isa
=
isa
,
)
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
SLIDING_WINDOWS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"use_alibi"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_sink"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"isa"
,
[
"amx"
]
if
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
else
[
"vec"
]
)
def
test_varlen_with_paged_kv_sink
(
seq_lens
:
list
[
tuple
[
int
,
int
]],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
int
|
None
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
float
|
None
,
num_blocks
:
int
,
use_alibi
:
bool
,
use_sink
:
bool
,
isa
:
str
,
)
->
None
:
varlen_with_paged_kv
(
seq_lens
=
seq_lens
,
num_heads
=
num_heads
,
head_size
=
head_size
,
sliding_window
=
sliding_window
,
dtype
=
dtype
,
block_size
=
block_size
,
soft_cap
=
soft_cap
,
num_blocks
=
num_blocks
,
use_alibi
=
use_alibi
,
use_sink
=
use_sink
,
isa
=
isa
,
)
tests/kernels/test_onednn.py
View file @
7f829be7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""
import
pytest
import
torch
...
...
tests/models/language/generation/test_common.py
View file @
7f829be7
...
...
@@ -38,7 +38,11 @@ AITER_MODEL_LIST = [
[
pytest
.
param
(
"bigscience/bloom-560m"
,
# bloom - testing alibi slopes
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
slow_test
],
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
slow_test
,
pytest
.
mark
.
cpu_model
,
],
),
pytest
.
param
(
"openai-community/gpt2"
,
# gpt2
...
...
@@ -55,6 +59,10 @@ AITER_MODEL_LIST = [
pytest
.
mark
.
slow_test
,
],
),
pytest
.
param
(
"google/gemma-2-2b-it"
,
# test hybrid attention
marks
=
[
pytest
.
mark
.
cpu_model
],
),
pytest
.
param
(
"zai-org/chatglm3-6b"
,
# chatglm (text-only)
),
...
...
@@ -64,7 +72,6 @@ AITER_MODEL_LIST = [
),
pytest
.
param
(
"openbmb/MiniCPM3-4B"
,
# fused_moe not supported on CPU
marks
=
[
pytest
.
mark
.
core_model
,
large_gpu_mark
(
min_gb
=
32
)],
),
pytest
.
param
(
...
...
@@ -93,11 +100,7 @@ AITER_MODEL_LIST = [
pytest
.
param
(
"bigcode/starcoder2-3b"
),
# starcoder2
pytest
.
param
(
"TitanML/tiny-mixtral"
,
# mixtral
marks
=
[
pytest
.
mark
.
core_model
],
),
pytest
.
param
(
"allenai/OLMoE-1B-7B-0924-Instruct"
,
marks
=
[
pytest
.
mark
.
cpu_model
],
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
pytest
.
param
(
"swiss-ai/Apertus-8B-Instruct-2509"
),
# apertus
],
...
...
tests/models/language/pooling/test_embedding.py
View file @
7f829be7
...
...
@@ -23,8 +23,7 @@ from ...utils import check_embeddings_close
),
pytest
.
param
(
"intfloat/e5-mistral-7b-instruct"
,
# CPU v1 doesn't support sliding window
marks
=
[
pytest
.
mark
.
core_model
],
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
),
pytest
.
param
(
"ssmits/Qwen2-7B-Instruct-embed-base"
,
marks
=
[
pytest
.
mark
.
cpu_model
]
...
...
tests/models/registry.py
View file @
7f829be7
...
...
@@ -243,7 +243,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"FalconH1ForCausalLM"
:
_HfExamplesInfo
(
"tiiuae/Falcon-H1-0.5B-Base"
),
"FlexOlmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/Flex-reddit-2x7B-1T"
),
"GemmaForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-1.1-2b-it"
),
"Gemma2ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2-9b"
),
"Gemma2ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-2-9b"
,
extras
=
{
"tiny"
:
"google/gemma-2-2b-it"
}
),
"Gemma3ForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-3-1b-it"
),
"Gemma3nForCausalLM"
:
_HfExamplesInfo
(
"google/gemma-3n-E2B-it"
),
"GlmForCausalLM"
:
_HfExamplesInfo
(
"zai-org/glm-4-9b-chat-hf"
),
...
...
vllm/_custom_ops.py
View file @
7f829be7
...
...
@@ -2583,6 +2583,88 @@ def onednn_scaled_mm(
return
output
def
cpu_attn_get_scheduler_metadata
(
num_reqs
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
head_dim
:
int
,
seq_lens
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
query_start_loc
:
torch
.
Tensor
,
causal
:
bool
,
sliding_window_size
:
int
,
isa
:
str
,
enable_kv_split
:
bool
,
)
->
torch
.
Tensor
:
sheduler_metadata
=
torch
.
ops
.
_C
.
get_scheduler_metadata
(
num_reqs
,
num_heads
,
num_kv_heads
,
head_dim
,
seq_lens
,
dtype
,
query_start_loc
,
causal
,
sliding_window_size
,
isa
,
enable_kv_split
,
)
return
sheduler_metadata
def
cpu_attn_reshape_and_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
isa
:
str
,
)
->
None
:
torch
.
ops
.
_C
.
cpu_attn_reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
isa
,
)
def
cpu_attention_with_kv_cache
(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
scale
:
float
,
causal
:
bool
,
alibi_slopes
:
torch
.
Tensor
|
None
,
sliding_window
:
tuple
[
int
,
int
],
block_table
:
torch
.
Tensor
,
softcap
:
float
,
scheduler_metadata
:
torch
.
Tensor
,
s_aux
:
torch
.
Tensor
|
None
,
)
->
None
:
torch
.
ops
.
_C
.
cpu_attention_with_kv_cache
(
query
,
key_cache
,
value_cache
,
output
,
query_start_loc
,
seq_lens
,
scale
,
causal
,
alibi_slopes
,
sliding_window
[
0
],
sliding_window
[
1
],
block_table
,
softcap
,
scheduler_metadata
,
s_aux
,
)
if
hasattr
(
torch
.
ops
.
_qutlass_C
,
"matmul_mxf4_bf16_tn"
):
@
register_fake
(
"_qutlass_C::matmul_mxf4_bf16_tn"
)
...
...
vllm/attention/backends/registry.py
View file @
7f829be7
...
...
@@ -49,7 +49,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
ROCM_AITER_FA
=
(
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
)
TORCH_SDPA
=
"
vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
TORCH_SDPA
=
"
"
# this tag is only used for ViT
FLASHINFER
=
"vllm.v1.attention.backends.flashinfer.FlashInferBackend"
FLASHINFER_MLA
=
(
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
...
...
@@ -70,6 +70,7 @@ class AttentionBackendEnum(enum.Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.rocm_aiter_unified_attn."
"RocmAiterUnifiedAttentionBackend"
)
CPU_ATTN
=
"vllm.v1.attention.backends.cpu_attn.CPUAttentionBackend"
# Placeholder for third-party/custom backends - must be registered before use
CUSTOM
=
""
...
...
vllm/engine/arg_utils.py
View file @
7f829be7
...
...
@@ -1726,9 +1726,6 @@ class EngineArgs:
)
_raise_unsupported_error
(
feature_name
=
name
)
if
current_platform
.
is_cpu
()
and
model_config
.
get_sliding_window
()
is
not
None
:
_raise_unsupported_error
(
feature_name
=
"sliding window (CPU backend)"
)
def
_set_default_args
(
self
,
usage_context
:
UsageContext
,
model_config
:
ModelConfig
)
->
None
:
...
...
vllm/platforms/cpu.py
View file @
7f829be7
...
...
@@ -8,7 +8,6 @@ import platform
import
subprocess
import
sys
from
dataclasses
import
dataclass
from
importlib.util
import
find_spec
from
typing
import
TYPE_CHECKING
import
regex
as
re
...
...
@@ -139,16 +138,15 @@ class CpuPlatform(Platform):
)
->
str
:
from
vllm.attention.backends.registry
import
AttentionBackendEnum
if
selected_backend
and
selected_backend
!=
AttentionBackendEnum
.
TORCH_SDPA
:
if
selected_backend
and
selected_backend
!=
AttentionBackendEnum
.
CPU_ATTN
:
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
if
use_mla
:
raise
NotImplementedError
(
"MLA is not supported on CPU."
)
if
use_sparse
:
raise
NotImplementedError
(
"Sparse Attention is not supported on CPU."
)
logger
.
info
(
"Using Torch SDPA backend."
)
if
not
use_v1
:
raise
ValueError
(
"CPU backend only supports V1."
)
return
AttentionBackendEnum
.
TORCH_SDPA
.
get_path
()
return
AttentionBackendEnum
.
CPU_ATTN
.
get_path
()
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
...
...
@@ -186,15 +184,13 @@ class CpuPlatform(Platform):
cache_config
=
vllm_config
.
cache_config
ipex_available
=
find_spec
(
"intel_extension_for_pytorch"
)
is
not
None
if
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
128
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
128
if
ipex_available
else
16
if
not
ipex_available
and
cache_config
.
block_size
!=
16
:
raise
RuntimeError
(
f
"--block-size=
{
cache_config
.
block_size
}
requires"
" intel_extension_for_pytorch"
if
cache_config
.
block_size
%
32
!=
0
:
logger
.
warning
(
"CPU backend prefers block_size is multiples of 32, "
"otherwise the performance is not optimized."
)
scheduler_config
=
vllm_config
.
scheduler_config
...
...
@@ -207,22 +203,11 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache."
)
if
cache_config
.
cache_dtype
==
"fp8_e4m3"
:
cache_config
.
cache_dtype
=
"fp8_e5m2"
logger
.
warning
(
"CPU backend doesn't support fp8_e4m3 KV cache type, cast to fp8_e5m2."
)
if
(
cache_config
.
cache_dtype
!=
"auto"
and
model_config
is
not
None
and
model_config
.
dtype
==
torch
.
half
):
if
cache_config
.
cache_dtype
!=
"auto"
:
logger
.
warning
(
"FP8 KV cache on the CPU backend only does not"
" support fp16 for now, cast to bf16."
"CPU backend doesn't support KV cache quantization fallback to auto."
)
model
_config
.
dtype
=
torch
.
bfloat16
cache
_config
.
cache_
dtype
=
"auto"
cache_config
.
cpu_kvcache_space_bytes
=
CpuPlatform
.
get_device_total_memory
()
...
...
vllm/utils/__init__.py
View file @
7f829be7
...
...
@@ -57,7 +57,6 @@ STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL
:
str
=
"FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL
:
str
=
"TORCH_SDPA"
STR_XFORMERS_ATTN_VAL
:
str
=
"XFORMERS"
STR_FLASH_ATTN_VAL
:
str
=
"FLASH_ATTN"
STR_INVALID_VAL
:
str
=
"INVALID"
...
...
vllm/v1/attention/backends/cpu_attn.py
View file @
7f829be7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
ClassVar
,
Optional
from
typing
import
ClassVar
import
numpy
as
np
import
torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
,
)
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
...
...
@@ -24,44 +23,38 @@ from vllm.v1.attention.backends.utils import (
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
try
:
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
_use_ipex
=
True
# AttributeError is to handle a bug in ipex
# https://github.com/intel/intel-extension-for-pytorch/pull/813
except
(
ImportError
,
AttributeError
):
_use_ipex
=
False
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
_CPU_ARCH_PREFER_MIXED_BATCH
=
(
CpuArchEnum
.
X86
,)
class
TorchSDPA
Backend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
Fals
e
class
CPUAttention
Backend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
Tru
e
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
,
]
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
attn_impl
=
_get_paged_attn_impl
()
return
attn_impl
.
get_supported_head_sizes
()
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
@
staticmethod
def
get_name
()
->
str
:
return
"
TORCH_SDPA
"
return
"
CPU_ATTN
"
@
staticmethod
def
get_impl_cls
()
->
type
[
"
TorchSDPA
BackendImpl"
]:
return
TorchSDPA
BackendImpl
def
get_impl_cls
()
->
type
[
"
CPUAttention
BackendImpl"
]:
return
CPUAttention
BackendImpl
@
staticmethod
def
get_builder_cls
()
->
type
[
"
TorchSDPA
MetadataBuilder
V1
"
]:
return
TorchSDPA
MetadataBuilder
V1
def
get_builder_cls
()
->
type
[
"
CPUAttention
MetadataBuilder"
]:
return
CPUAttention
MetadataBuilder
@
staticmethod
def
get_kv_cache_shape
(
...
...
@@ -71,9 +64,7 @@ class TorchSDPABackend(AttentionBackend):
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
return
_get_paged_attn_impl
().
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
2
,
num_blocks
,
num_kv_heads
,
block_size
,
head_size
@
staticmethod
def
use_cascade_attention
(
*
args
,
**
kwargs
)
->
bool
:
...
...
@@ -81,264 +72,26 @@ class TorchSDPABackend(AttentionBackend):
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadata
):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills
:
int
# Number of prefill tokens.
num_prefill_tokens
:
int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens
:
int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
class
CPUAttentionMetadata
:
isa
:
str
num_actual_tokens
:
int
# Number of tokens excluding padding.
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
seq_lens
:
torch
.
Tensor
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
decode_seq_lens_tensor
:
torch
.
Tensor
|
None
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
decode_max_seq_len
:
int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
decode_block_tables
:
torch
.
Tensor
|
None
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
chunked_prefill
:
bool
seq_lens
:
list
[
int
]
|
None
=
None
# For non-chunked prefill
# For chunked prefill only
max_query_len
:
int
|
None
=
None
prefill_max_seq_len
:
int
|
None
=
None
prefill_query_start_loc
:
torch
.
Tensor
|
None
=
None
prefill_seq_start_loc
:
torch
.
Tensor
|
None
=
None
prefill_block_tables
:
torch
.
Tensor
|
None
=
None
# For V1 logits index only
query_start_loc
:
torch
.
Tensor
|
None
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
list
[
int
]
|
None
=
None
encoder_seq_lens_tensor
:
torch
.
Tensor
|
None
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
int
|
None
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
int
|
None
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
torch
.
Tensor
|
None
=
None
cross_block_tables
:
torch
.
Tensor
|
None
=
None
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
list
[
torch
.
Tensor
]
|
None
=
None
self
.
encoder_attn_bias
:
list
[
torch
.
Tensor
]
|
None
=
None
self
.
cross_attn_bias
:
list
[
torch
.
Tensor
]
|
None
=
None
@
property
def
is_all_encoder_attn_metadata_set
(
self
):
"""
All attention metadata required for encoder attention is set.
"""
return
(
(
self
.
encoder_seq_lens
is
not
None
)
and
(
self
.
encoder_seq_lens_tensor
is
not
None
)
and
(
self
.
max_encoder_seq_len
is
not
None
)
)
@
property
def
is_all_cross_attn_metadata_set
(
self
):
"""
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
"""
return
(
self
.
is_all_encoder_attn_metadata_set
and
(
self
.
cross_slot_mapping
is
not
None
)
and
(
self
.
cross_block_tables
is
not
None
)
)
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
if
self
.
num_prefill_tokens
==
0
:
return
None
return
self
@
property
def
decode_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
return
self
def
get_seq_lens
(
self
,
attn_type
:
str
,
):
"""
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
"""
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
seq_lens_q
=
self
.
seq_lens
seq_lens_kv
=
self
.
seq_lens
elif
attn_type
==
AttentionType
.
ENCODER
:
seq_lens_q
=
self
.
encoder_seq_lens
seq_lens_kv
=
self
.
encoder_seq_lens
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
seq_lens_q
=
self
.
seq_lens
seq_lens_kv
=
self
.
encoder_seq_lens
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
return
seq_lens_q
,
seq_lens_kv
def
get_attn_bias
(
self
,
attn_type
:
str
,
)
->
list
[
torch
.
Tensor
]
|
None
:
"""
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
"""
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
return
self
.
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
return
self
.
encoder_attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
return
self
.
cross_attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
set_attn_bias
(
self
,
attn_bias
:
list
[
torch
.
Tensor
],
attn_type
:
str
,
)
->
None
:
"""
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
"""
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
self
.
attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
self
.
encoder_attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
self
.
cross_attn_bias
=
attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
get_seq_len_block_table_args
(
self
,
attn_type
:
str
,
)
->
tuple
:
"""
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
"""
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return
(
self
.
decode_seq_lens_tensor
,
self
.
decode_max_seq_len
,
self
.
decode_block_tables
,
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return
(
self
.
encoder_seq_lens_tensor
,
self
.
max_encoder_seq_len
,
self
.
cross_block_tables
,
)
elif
attn_type
==
AttentionType
.
ENCODER
:
# No block tables associated with encoder attention
return
(
self
.
encoder_seq_lens_tensor
,
self
.
max_encoder_seq_len
,
None
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
scheduler_metadata
:
torch
.
Tensor
|
None
causal
:
bool
=
True
# can be removed after deprecate sdpa
use_sdpa_prefill
:
bool
=
False
num_decode_tokens
:
int
=
0
sdpa_attn_masks
:
list
[
torch
.
Tensor
|
None
]
|
None
=
None
sdpa_start_loc
:
torch
.
Tensor
|
None
=
None
class
TorchSDPAMetadataBuilderV1
(
AttentionMetadataBuilder
[
TorchSDPAMetadata
]):
reorder_batch_threshold
:
int
=
1
class
CPUAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
CPUAttentionMetadata
]):
def
__init__
(
self
,
kv_cache_spec
:
AttentionSpec
,
...
...
@@ -348,80 +101,104 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
)
->
None
:
super
().
__init__
(
kv_cache_spec
,
layer_names
,
vllm_config
,
device
)
self
.
scheduler_config
=
vllm_config
.
scheduler_config
self
.
_init_reorder_batch_threshold
(
1
,
False
)
self
.
use_sdpa_prefill
=
False
reorder_batch_threshold
=
None
if
current_platform
.
get_cpu_architecture
()
not
in
_CPU_ARCH_PREFER_MIXED_BATCH
:
# in this case, decode seqs are reordered to the front of prefill seqs
# to split decode and prefill. Then use SDPA for prefill and
# cpu_attention_with_kv_cache for decode
reorder_batch_threshold
=
1
self
.
use_sdpa_prefill
=
True
self
.
seq_start_loc_cpu
=
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_seqs
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
self
.
_init_reorder_batch_threshold
(
reorder_batch_threshold
,
False
)
self
.
kv_cache_spec
=
kv_cache_spec
self
.
vllm_config
=
vllm_config
parallel_config
=
vllm_config
.
parallel_config
self
.
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
num_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
parallel_config
)
self
.
seq_start_loc_np
=
self
.
seq_start_loc_cpu
.
numpy
()
self
.
head_dim
=
kv_cache_spec
.
head_size
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
window_size
=
getattr
(
kv_cache_spec
,
"sliding_window"
,
-
1
)
if
self
.
window_size
is
None
:
self
.
window_size
=
-
1
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
isa
=
_get_attn_isa
(
self
.
dtype
,
self
.
block_size
)
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
TorchSDPA
Metadata
:
)
->
CPUAttention
Metadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_np
=
seq_lens_cpu
.
numpy
()
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
query_start_loc_np
=
query_start_loc_cpu
.
numpy
()
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
,
require_uniform
=
True
,
)
)
max_prefill_seq_len
=
(
seq_lens_np
[
num_decodes
:
num_reqs
].
max
().
item
()
if
num_prefills
>
0
else
0
)
max_decode_seq_len
=
(
seq_lens_np
[:
num_decodes
].
max
().
item
()
if
num_prefills
<
num_reqs
else
0
)
self
.
seq_start_loc_np
[
0
]
=
0
np
.
cumsum
(
seq_lens_np
,
out
=
self
.
seq_start_loc_np
[
1
:
num_reqs
+
1
])
slot_mapping
=
common_attn_metadata
.
slot_mapping
.
long
()
max_seq_len
=
common_attn_metadata
.
max_seq_len
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
query_start_loc_np
=
query_start_loc_cpu
.
numpy
()
query_start_loc_np
[
num_decodes
:
num_reqs
+
1
]
-=
num_decode_tokens
slot_mapping
=
common_attn_metadata
.
slot_mapping
causal
=
common_attn_metadata
.
causal
sdpa_start_loc
=
query_start_loc
num_decode_tokens
=
0
if
self
.
use_sdpa_prefill
and
causal
:
# Decoder, need reorder and truncate
assert
self
.
reorder_batch_threshold
(
num_decodes
,
num_prefills
,
num_decode_tokens
,
num_prefill_tokens
)
=
(
split_decodes_and_prefills
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
,
require_uniform
=
True
,
)
)
num_reqs
=
num_decodes
sdpa_start_loc
=
sdpa_start_loc
[
num_decodes
:]
-
num_decode_tokens
seq_lens
=
seq_lens
[:
num_decodes
]
query_start_loc
=
query_start_loc
[:
num_decodes
+
1
]
block_table_tensor
=
block_table_tensor
[:
num_decodes
]
sheduler_metadata
=
None
if
causal
:
# for decode batch, use the custom kernel
sheduler_metadata
=
ops
.
cpu_attn_get_scheduler_metadata
(
num_reqs
=
num_reqs
,
num_heads
=
self
.
num_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
head_dim
=
self
.
head_dim
,
seq_lens
=
seq_lens
,
dtype
=
self
.
dtype
,
query_start_loc
=
query_start_loc
,
causal
=
causal
,
sliding_window_size
=
self
.
window_size
,
isa
=
self
.
isa
,
enable_kv_split
=
True
,
)
attn_metadata
=
TorchSDPAMetadata
(
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
# to ensure inference when chunked_prefill is disabled
seq_lens
=
seq_lens_cpu
.
tolist
()[
num_decodes
:],
# prefill
decode_seq_lens_tensor
=
seq_lens_cpu
[:
num_decodes
],
# decode
decode_max_seq_len
=
max_decode_seq_len
,
# decode
decode_block_tables
=
block_table_tensor
[:
num_decodes
],
# decode
chunked_prefill
=
self
.
scheduler_config
.
chunked_prefill_enabled
,
attn_metadata
=
CPUAttentionMetadata
(
isa
=
self
.
isa
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
prefill_max_seq_len
=
max_prefill_seq_len
,
prefill_query_start_loc
=
query_start_loc_cpu
[
num_decodes
:
num_reqs
+
1
],
# prefill
prefill_seq_start_loc
=
self
.
seq_start_loc_cpu
[
num_decodes
:
num_reqs
+
1
],
# prefill
prefill_block_tables
=
block_table_tensor
[
num_decodes
:
num_reqs
],
# prefill
query_start_loc
=
query_start_loc_cpu
[:
num_reqs
+
1
],
# for logits index
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
block_table_tensor
,
slot_mapping
=
slot_mapping
,
scheduler_metadata
=
sheduler_metadata
,
causal
=
causal
,
use_sdpa_prefill
=
self
.
use_sdpa_prefill
,
num_decode_tokens
=
num_decode_tokens
,
sdpa_start_loc
=
sdpa_start_loc
,
)
return
attn_metadata
class
TorchSDPA
BackendImpl
(
AttentionImpl
[
TorchSDPAMetadata
]
):
class
CPUAttention
BackendImpl
(
AttentionImpl
):
def
__init__
(
self
,
num_heads
:
int
,
...
...
@@ -434,37 +211,48 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
logits_soft_cap
:
float
|
None
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
str
|
None
=
None
,
sinks
:
torch
.
Tensor
|
None
=
None
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
logits_soft_cap
is
not
None
:
logger
.
warning_once
(
"Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off."
)
self
.
paged_attn_impl
=
_get_paged_attn_impl
()
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
if
logits_soft_cap
is
not
None
and
attn_type
in
(
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
,
):
logger
.
warning_once
(
"CPU_ATTN does not support logits softcap for"
" ENCODER and ENCODER_ONLY, outputs may be slightly off"
)
if
logits_soft_cap
is
None
:
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
if
sliding_window
is
None
:
self
.
sliding_window
=
(
-
1
,
-
1
)
elif
attn_type
==
AttentionType
.
ENCODER_ONLY
:
self
.
sliding_window
=
(
sliding_window
-
1
,
sliding_window
-
1
)
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
if
is_quantized_kv_cache
(
kv_cache_dtype
)
and
not
_use_ipex
:
raise
NotImplementedError
(
"Torch SDPA backend FP8 KV cache requires "
"intel_extension_for_pytorch support."
)
if
is_quantized_kv_cache
(
kv_cache_dtype
):
raise
NotImplementedError
(
"FP8 KV cache is unsupported in CPU_ATTN"
)
self
.
attn_type
=
attn_type
self
.
sinks
=
sinks
if
self
.
sinks
is
not
None
:
assert
self
.
sinks
.
shape
[
0
]
==
num_heads
,
(
"Sinks must have the same number of heads as the number of "
"heads in the layer"
)
def
forward
(
self
,
layer
:
AttentionLayer
,
...
...
@@ -472,196 +260,130 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
attn_metadata
:
CPUAttentionMetadata
|
None
,
output
:
torch
.
Tensor
|
None
=
None
,
output_scale
:
torch
.
Tensor
|
None
=
None
,
output_block_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass
with torch SDPA and PagedAttention
.
"""Forward pass
for CPU attention backend
.
Args:
query: shape = [num_tokens, num_heads
*
head_size]
key: shape = [num_tokens, num_kv_heads
*
head_size]
value: shape = [num_tokens, num_kv_heads
*
head_size]
query: shape = [num_tokens, num_heads
,
head_size]
key: shape = [num_tokens, num_kv_heads
,
head_size]
value: shape = [num_tokens, num_kv_heads
,
head_size]
kv_cache: shape =
[2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
[2, num_blocks, num_kv_heads, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for
TorchSDPA
BackendImpl"
" for
CPUAttention
BackendImpl"
)
# For warming-up
if
attn_metadata
is
None
:
return
query
attn_type
=
self
.
attn_type
if
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
):
raise
AttributeError
(
"Encoder attention requires setting encoder metadata attributes."
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
return
output
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if
self
.
attn_type
in
(
AttentionType
.
ENCODER_ONLY
,
AttentionType
.
ENCODER
):
# For encoder attention,
return
self
.
_run_sdpa_forward
(
query
[:
num_actual_tokens
],
key
[:
num_actual_tokens
],
value
[:
num_actual_tokens
],
output
[:
num_actual_tokens
],
attn_metadata
,
self
.
attn_type
,
)
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
assert
value
is
None
if
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
.
numel
()
>
0
:
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache
,
value_cache
=
self
.
paged_attn_impl
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
# For decoder and cross-attention, use KV cache, size are
# [num_blocks, num_kv_heads, block_size, head_size]
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
self
.
paged_attn_impl
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if
(
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
ops
.
cpu_attn_reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
isa
,
)
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
if
attn_metadata
.
use_sdpa_prefill
:
assert
self
.
sinks
is
None
,
"Attention sink is unsupported in SDPA prefill"
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
if
attn_type
==
AttentionType
.
DECODER
:
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
not
prefill_meta
.
prefill_metadata
.
chunked_prefill
:
# type: ignore
assert
attn_metadata
.
seq_lens
is
not
None
self
.
_run_sdpa_forward
(
output
,
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
else
:
# prefix-enabled attention
assert
not
self
.
need_mask
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
output
=
torch
.
empty_like
(
query
)
ipex_modules
.
PagedAttention
.
flash_attn_varlen_func
(
output
[
prefill_meta
.
num_decode_tokens
:,
:,
:],
query
[
prefill_meta
.
num_decode_tokens
:,
:,
:],
key_cache
,
value_cache
,
prefill_meta
.
prefill_query_start_loc
,
prefill_meta
.
prefill_seq_start_loc
,
prefill_meta
.
max_query_len
,
prefill_meta
.
prefill_max_seq_len
,
self
.
scale
,
True
,
prefill_meta
.
prefill_block_tables
,
self
.
alibi_slopes
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
"Encoder-only models should not have decode metadata."
self
.
_run_sdpa_forward
(
query
[
num_decode_tokens
:
num_actual_tokens
],
key
[
num_decode_tokens
:
num_actual_tokens
],
value
[
num_decode_tokens
:
num_actual_tokens
],
output
[
num_decode_tokens
:
num_actual_tokens
],
attn_metadata
,
self
.
attn_type
,
)
# Decoding run.
(
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
decode_meta
.
get_seq_len_block_table_args
(
attn_type
)
self
.
paged_attn_impl
.
forward_decode
(
output
[:
attn_metadata
.
num_decode_tokens
,
:,
:],
query
[:
attn_metadata
.
num_decode_tokens
,
:,
:],
key_cache
,
value_cache
,
block_tables_arg
,
seq_lens_arg
,
max_seq_len_arg
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
layer
.
_k_scale
,
layer
.
_v_scale
,
num_actual_tokens
=
num_decode_tokens
if
num_actual_tokens
>
0
:
ops
.
cpu_attention_with_kv_cache
(
query
=
query
[:
num_actual_tokens
],
key_cache
=
key_cache
,
value_cache
=
value_cache
,
output
=
output
[:
num_actual_tokens
],
# type: ignore
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
attn_metadata
.
seq_lens
,
scale
=
self
.
scale
,
causal
=
attn_metadata
.
causal
,
alibi_slopes
=
self
.
alibi_slopes
,
# type: ignore
sliding_window
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
s_aux
=
self
.
sinks
,
)
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
return
output
def
_run_sdpa_forward
(
self
,
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
attn_masks
=
attn_metadata
.
get_attn_bias
(
attn_type
)
output
:
torch
.
Tensor
,
attn_metadata
:
CPUAttentionMetadata
,
attn_type
:
str
,
)
->
torch
.
Tensor
:
attn_masks
=
attn_metadata
.
sdpa_attn_masks
if
attn_masks
is
None
:
if
self
.
alibi_slopes
is
not
None
:
attn_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
s
eq_lens
,
# type: ignore
attn_metadata
.
s
dpa_start_loc
,
)
elif
self
.
sliding_window
is
not
None
:
elif
self
.
sliding_window
[
0
]
!=
-
1
or
self
.
sliding_window
[
1
]
!=
-
1
:
assert
attn_metadata
.
seq_lens
is
not
None
attn_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
self
.
sliding_window
,
query
.
dtype
attn_metadata
.
sdpa_start_loc
,
self
.
sliding_window
[
0
],
self
.
sliding_window
[
1
],
query
.
dtype
,
)
else
:
seq_lens
,
_
=
attn_metadata
.
get_seq_lens
(
attn_type
)
attn_masks
=
[
None
]
*
len
(
seq_lens
)
attn_metadata
.
set_attn_bias
(
attn_masks
,
attn_type
)
attn_masks
=
[
None
]
*
(
attn_metadata
.
sdpa_start_loc
.
size
(
0
)
-
1
)
# type: ignore
attn_metadata
.
sdpa_attn_masks
=
attn_masks
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
...
...
@@ -673,21 +395,16 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
causal_attn
=
attn_type
==
AttentionType
.
DECODER
seq_lens_q
,
seq_lens_kv
=
attn_metadata
.
get_seq_lens
(
attn_type
)
# Incoming Q and KV contain decoded tokens as well, hence start at an offset
# equal to num_decode_tokens since decode requests appear first
start_q
,
start_kv
=
(
attn_metadata
.
num_decode_tokens
,
attn_metadata
.
num_decode_tokens
,
)
for
seq_len_q
,
seq_len_kv
,
mask
in
zip
(
seq_lens_q
,
seq_lens_kv
,
attn_masks
):
end_q
=
start_q
+
seq_len_q
end_kv
=
start_kv
+
seq_len_kv
sdpa_start_loc
=
attn_metadata
.
sdpa_start_loc
.
numpy
()
# type: ignore
for
i
in
range
(
len
(
attn_masks
)):
mask
=
attn_masks
[
i
]
start_q
=
sdpa_start_loc
[
i
]
end_q
=
sdpa_start_loc
[
i
+
1
]
sub_out
=
(
scaled_dot_product_attention
(
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query
[
None
,
:,
start_q
:
end_q
,
:],
key
[
None
,
:,
start_
kv
:
end_
kv
,
:],
value
[
None
,
:,
start_
kv
:
end_
kv
,
:],
key
[
None
,
:,
start_
q
:
end_
q
,
:],
value
[
None
,
:,
start_
q
:
end_
q
,
:],
attn_mask
=
mask
,
dropout_p
=
0.0
,
is_causal
=
causal_attn
and
mask
is
None
,
...
...
@@ -697,17 +414,20 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
.
movedim
(
query
.
dim
()
-
2
,
0
)
)
output
[
start_q
:
end_q
,
:,
:]
=
sub_out
start_q
,
start_kv
=
end_q
,
end_kv
return
output
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
s
eq_lens
:
list
[
int
]
,
s
dpa_start_loc
:
torch
.
Tensor
,
)
->
list
[
torch
.
Tensor
]:
attn_biases
:
list
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
seq_num
=
sdpa_start_loc
.
size
(
0
)
-
1
sdpa_start_loc
=
sdpa_start_loc
.
numpy
()
# type: ignore
for
i
in
range
(
seq_num
):
seq_len
=
sdpa_start_loc
[
i
+
1
]
-
sdpa_start_loc
[
i
]
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# type: ignore
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
...
...
@@ -719,7 +439,7 @@ def _make_alibi_bias(
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
]).
unsqueeze_
(
0
)
inf_mask
=
(
torch
.
empty
((
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
)
torch
.
empty
((
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
)
# type: ignore
.
fill_
(
-
torch
.
inf
)
.
triu_
(
diagonal
=
1
)
)
...
...
@@ -729,210 +449,37 @@ def _make_alibi_bias(
def
_make_sliding_window_bias
(
seq_lens
:
list
[
int
],
window_size
:
int
|
None
,
sdpa_start_loc
:
torch
.
Tensor
,
left_window_size
:
int
,
right_window_size
:
int
,
dtype
:
torch
.
dtype
,
)
->
list
[
torch
.
Tensor
]:
attn_biases
:
list
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
dtype
=
dtype
,
seq_num
=
sdpa_start_loc
.
size
(
0
)
-
1
sdpa_start_loc
=
sdpa_start_loc
.
numpy
()
# type: ignore
for
i
in
range
(
seq_num
):
seq_len
=
sdpa_start_loc
[
i
+
1
]
-
sdpa_start_loc
[
i
]
mask
=
torch
.
full
(
# type: ignore
(
1
,
seq_len
,
seq_len
),
# type: ignore
fill_value
=
1
,
dtype
=
dtype
,
)
shift
=
0
mask
=
torch
.
tril
(
tensor
,
diagonal
=
shift
).
to
(
dtype
)
# type: ignore
if
window_size
is
not
None
:
mask
=
torch
.
triu
(
mask
,
diagonal
=
shift
-
window_size
+
1
)
if
right_window_size
!=
-
1
:
mask
=
torch
.
tril
(
mask
,
diagonal
=
right_window_size
)
if
left_window_size
!=
-
1
:
mask
=
torch
.
triu
(
mask
,
diagonal
=-
left_window_size
)
mask
=
torch
.
log
(
mask
)
attn_biases
.
append
(
mask
.
to
(
dtype
)
)
attn_biases
.
append
(
mask
)
return
attn_biases
class
_PagedAttention
:
@
staticmethod
def
get_supported_head_sizes
()
->
list
[
int
]:
return
[
32
,
64
,
80
,
96
,
112
,
128
,
192
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
tuple
[
int
,
...]:
return
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
=
16
//
kv_cache
.
element_size
()
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
//
x
,
-
1
,
x
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
@
staticmethod
def
forward_decode
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
torch
.
Tensor
|
None
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
tp_rank
:
int
=
0
blocksparse_local_blocks
:
int
=
0
blocksparse_vert_stride
:
int
=
0
blocksparse_block_size
:
int
=
64
blocksparse_head_sliding_step
:
int
=
0
block_size
=
value_cache
.
shape
[
3
]
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
class
_IPEXPagedAttention
(
_PagedAttention
):
@
staticmethod
def
get_supported_head_sizes
()
->
list
[
int
]:
return
[]
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
ipex_modules
.
PagedAttention
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
().
int
()
)
@
staticmethod
def
forward_decode
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
torch
.
Tensor
|
None
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
block_size
=
value_cache
.
shape
[
2
]
head_mapping
=
(
torch
.
arange
(
0
,
num_kv_heads
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
)
.
view
(
num_kv_heads
,
1
)
.
repeat_interleave
(
query
.
size
(
1
)
//
num_kv_heads
)
.
flatten
()
)
ipex_modules
.
PagedAttention
.
single_query_cached_kv_attention
(
output
,
query
.
contiguous
(),
key_cache
,
value_cache
,
head_mapping
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
,
)
def
_get_paged_attn_impl
():
if
_use_ipex
:
return
_IPEXPagedAttention
def
_get_attn_isa
(
dtype
:
torch
.
dtype
,
block_size
:
int
)
->
str
:
supports_amx
=
torch
.
_C
.
_cpu
.
_is_amx_tile_supported
()
if
supports_amx
and
dtype
in
(
torch
.
bfloat16
,)
and
block_size
%
32
==
0
:
return
"amx"
elif
block_size
%
32
==
0
:
return
"vec"
else
:
return
_PagedAttention
return
"vec16"
vllm/v1/attention/backends/utils.py
View file @
7f829be7
...
...
@@ -265,7 +265,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
def
_init_reorder_batch_threshold
(
self
,
reorder_batch_threshold
:
int
=
1
,
reorder_batch_threshold
:
int
|
None
=
1
,
supports_spec_as_decode
:
bool
=
False
,
supports_dcp_with_varlen
:
bool
=
False
,
)
->
None
:
...
...
vllm/v1/worker/cpu_model_runner.py
View file @
7f829be7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
Any
import
torch
import
torch.nn
as
nn
...
...
@@ -12,9 +12,6 @@ from vllm.model_executor.model_loader import get_model
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
logger
=
init_logger
(
__name__
)
...
...
@@ -31,15 +28,6 @@ class CPUModelRunner(GPUModelRunner):
self
.
_postprocess_tensors
()
# Note: Remove the override after new attention backend finished
def
_may_reorder_batch
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
>
1
:
raise
ValueError
(
"Multiple KVCacheGroups is not"
"currently supported with CPU model runner."
)
super
().
_may_reorder_batch
(
scheduler_output
)
def
_postprocess_tensors
(
self
)
->
None
:
# Note: replace device tensors with cpu tensors
def
replace_tensor
(
obj
:
Any
,
cpu_attr_name
:
str
,
device_attr_name
)
->
None
:
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment