Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7721ef17
Unverified
Commit
7721ef17
authored
Jul 08, 2025
by
Li, Jiang
Committed by
GitHub
Jul 07, 2025
Browse files
[CI/Build][CPU] Fix CPU CI and remove all CPU V0 files (#20560)
Signed-off-by:
jiang1.li
<
jiang1.li@intel.com
>
parent
8369b7c2
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
785 additions
and
839 deletions
+785
-839
.buildkite/scripts/hardware_ci/run-cpu-test.sh
.buildkite/scripts/hardware_ci/run-cpu-test.sh
+12
-12
tests/basic_correctness/test_chunked_prefill.py
tests/basic_correctness/test_chunked_prefill.py
+0
-58
tests/models/language/generation/test_common.py
tests/models/language/generation/test_common.py
+6
-2
tests/models/language/pooling/test_embedding.py
tests/models/language/pooling/test_embedding.py
+11
-12
tests/models/language/pooling/test_reward.py
tests/models/language/pooling/test_reward.py
+5
-0
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+2
-1
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+0
-546
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+0
-195
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+749
-13
No files found.
.buildkite/scripts/hardware_ci/run-cpu-test.sh
View file @
7721ef17
...
...
@@ -48,10 +48,16 @@ function cpu_tests() {
# Run basic model test
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
set -e
pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
pytest -v -s tests/models/language/generation -m cpu_model
VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model
# Note: disable until supports V1
# pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
# pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
# Note: disable Bart until supports V1
pytest -v -s tests/models/language/generation -m cpu_model
\
--ignore=tests/models/language/generation/test_bart.py
VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model
\
--ignore=tests/models/language/generation/test_bart.py
pytest -v -s tests/models/language/pooling -m cpu_model
pytest -v -s tests/models/multimodal/generation
\
--ignore=tests/models/multimodal/generation/test_mllama.py
\
...
...
@@ -62,21 +68,15 @@ function cpu_tests() {
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
set -e
pytest -s -v
\
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup
\
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]"
# Note: disable it until supports V1
# Run AWQ test
# docker exec cpu-test-"$NUMA_NODE" bash -c "
# set -e
# VLLM_USE_V1=0 pytest -s -v \
# tests/quantization/test_ipex_quant.py"
# Run chunked-prefill and prefix-cache test
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
set -e
pytest -s -v -k cpu_model
\
tests/basic_correctness/test_chunked_prefill.py"
# online serving
docker
exec
cpu-test-
"
$NUMA_NODE
"
bash
-c
"
set -e
...
...
tests/basic_correctness/test_chunked_prefill.py
View file @
7721ef17
...
...
@@ -294,61 +294,3 @@ def test_with_prefix_caching(
name_0
=
"w/o prefix caching"
,
name_1
=
"with prefix caching"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"facebook/opt-125m"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
,
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"chunked_prefill_token_size"
,
[
1
,
4
,
16
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"attention_backend"
,
[
"TORCH_SDPA"
])
@
pytest
.
mark
.
cpu_model
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cpu
(),
reason
=
"CPU only"
)
def
test_models_cpu
(
hf_runner
:
HfRunner
,
vllm_runner
:
VllmRunner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
chunked_prefill_token_size
:
int
,
enforce_eager
:
bool
,
attention_backend
:
str
,
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
,
dtype
,
max_tokens
,
chunked_prefill_token_size
,
enforce_eager
,
1
,
attention_backend
,
monkeypatch
,
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"chunk_size"
,
[
30
,
32
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
,
"half"
])
@
pytest
.
mark
.
cpu_model
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cpu
(),
reason
=
"CPU only"
)
def
test_with_prefix_caching_cpu
(
vllm_runner
:
VllmRunner
,
max_tokens
:
int
,
enforce_eager
:
bool
,
chunk_size
:
int
,
dtype
:
str
,
)
->
None
:
test_with_prefix_caching
(
vllm_runner
,
max_tokens
,
enforce_eager
,
chunk_size
,
1
,
dtype
,
)
tests/models/language/generation/test_common.py
View file @
7721ef17
...
...
@@ -39,7 +39,7 @@ AITER_MODEL_LIST = [
[
pytest
.
param
(
"bigscience/bloom-560m"
,
# bloom - testing alibi slopes
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
marks
=
[
pytest
.
mark
.
core_model
],
),
pytest
.
param
(
"openai-community/gpt2"
,
# gpt2
...
...
@@ -87,7 +87,11 @@ AITER_MODEL_LIST = [
pytest
.
param
(
"bigcode/starcoder2-3b"
),
# starcoder2
pytest
.
param
(
"TitanML/tiny-mixtral"
,
# mixtral
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
],
marks
=
[
pytest
.
mark
.
core_model
],
),
pytest
.
param
(
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
,
marks
=
[
pytest
.
mark
.
cpu_model
],
)
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
...
...
tests/models/language/pooling/test_embedding.py
View file @
7721ef17
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
Optional
import
pytest
...
...
@@ -29,8 +28,10 @@ def v1(run_with_both_engines):
# [Decoder-only]
pytest
.
param
(
"BAAI/bge-multilingual-gemma2"
,
marks
=
[
pytest
.
mark
.
core_model
]),
pytest
.
param
(
"intfloat/e5-mistral-7b-instruct"
,
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
]),
pytest
.
param
(
"intfloat/e5-mistral-7b-instruct"
,
# CPU v1 doesn't support sliding window
marks
=
[
pytest
.
mark
.
core_model
]),
# the qwen models interfere with each other (see PR
# https://github.com/vllm-project/vllm/pull/18720).
# To avoid this problem, for now we skip v0 since it will be
...
...
@@ -38,11 +39,13 @@ def v1(run_with_both_engines):
pytest
.
param
(
"ssmits/Qwen2-7B-Instruct-embed-base"
,
marks
=
[
pytest
.
mark
.
skip_v0
,
pytest
.
mark
.
cpu_model
]),
# [Encoder-only]
pytest
.
param
(
"BAAI/bge-base-en-v1.5"
,
marks
=
[
pytest
.
mark
.
core_model
,
pytest
.
mark
.
cpu_model
,
pytest
.
mark
.
skip_v1
]),
pytest
.
param
(
"BAAI/bge-base-en-v1.5"
,
marks
=
[
# CPU only supports V1
pytest
.
mark
.
core_model
,
pytest
.
mark
.
skip_v1
]),
pytest
.
param
(
"sentence-transformers/all-MiniLM-L12-v2"
,
marks
=
[
pytest
.
mark
.
skip_v1
]),
pytest
.
param
(
"intfloat/multilingual-e5-small"
,
...
...
@@ -61,10 +64,6 @@ def test_models(
model
,
monkeypatch
,
)
->
None
:
if
model
==
"intfloat/e5-mistral-7b-instruct"
and
current_platform
.
is_cpu
(
)
and
os
.
environ
.
get
(
"VLLM_USE_V1"
,
"0"
)
==
"1"
:
pytest
.
skip
(
"CPU V1 doesn't support sliding window"
)
if
model
==
"BAAI/bge-multilingual-gemma2"
and
current_platform
.
is_rocm
():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
...
...
tests/models/language/pooling/test_reward.py
View file @
7721ef17
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
pytest
import
torch
import
torch.nn.functional
as
F
...
...
@@ -84,6 +86,9 @@ def test_prm_models(
dtype
:
str
,
monkeypatch
,
)
->
None
:
if
current_platform
.
is_cpu
()
and
os
.
environ
.
get
(
"VLLM_USE_V1"
,
"0"
)
==
"0"
:
pytest
.
skip
(
"CPU only supports V1"
)
if
current_platform
.
is_rocm
():
# ROCm Triton FA does not currently support sliding window attention
# switch to use ROCm CK FA backend
...
...
tests/quantization/test_compressed_tensors.py
View file @
7721ef17
...
...
@@ -45,7 +45,8 @@ def use_v0_only(monkeypatch):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
if
not
current_platform
.
is_cpu
():
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
parametrize
(
...
...
vllm/attention/backends/torch_sdpa.py
deleted
100644 → 0
View file @
8369b7c2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.functional
import
scaled_dot_product_attention
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.attention.backends.abstract
import
(
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
# yapf: enable
from
vllm.attention.ops.ipex_attn
import
PagedAttention
,
_use_ipex
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
chunked_prefill
:
bool
seq_lens
:
Optional
[
List
[
int
]]
=
None
# For non-chunked prefill
# For chunked prefill only
max_query_len
:
Optional
[
int
]
=
None
max_kv_len
:
Optional
[
int
]
=
None
prefill_query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
kv_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
prefill_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# For V1 logits index only
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
encoder_attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
cross_attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@
property
def
is_all_encoder_attn_metadata_set
(
self
):
'''
All attention metadata required for encoder attention is set.
'''
return
((
self
.
encoder_seq_lens
is
not
None
)
and
(
self
.
encoder_seq_lens_tensor
is
not
None
)
and
(
self
.
max_encoder_seq_len
is
not
None
))
@
property
def
is_all_cross_attn_metadata_set
(
self
):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return
(
self
.
is_all_encoder_attn_metadata_set
and
(
self
.
cross_slot_mapping
is
not
None
)
and
(
self
.
cross_block_tables
is
not
None
))
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
if
self
.
num_prefill_tokens
==
0
:
return
None
return
self
@
property
def
decode_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
return
self
def
get_seq_lens
(
self
,
attn_type
:
str
,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
seq_lens_q
=
self
.
seq_lens
seq_lens_kv
=
self
.
seq_lens
elif
attn_type
==
AttentionType
.
ENCODER
:
seq_lens_q
=
self
.
encoder_seq_lens
seq_lens_kv
=
self
.
encoder_seq_lens
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
seq_lens_q
=
self
.
seq_lens
seq_lens_kv
=
self
.
encoder_seq_lens
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
return
seq_lens_q
,
seq_lens_kv
def
get_attn_bias
(
self
,
attn_type
:
str
,
)
->
Optional
[
List
[
torch
.
Tensor
]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
return
self
.
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
return
self
.
encoder_attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
return
self
.
cross_attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
set_attn_bias
(
self
,
attn_bias
:
List
[
torch
.
Tensor
],
attn_type
:
str
,
)
->
None
:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
self
.
attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
self
.
encoder_attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
self
.
cross_attn_bias
=
attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
get_seq_len_block_table_args
(
self
,
attn_type
:
str
,
)
->
tuple
:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return
(
self
.
seq_lens_tensor
,
self
.
max_decode_seq_len
,
self
.
block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return
(
self
.
encoder_seq_lens_tensor
,
self
.
max_encoder_seq_len
,
self
.
cross_block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER
:
# No block tables associated with encoder attention
return
(
self
.
encoder_seq_lens_tensor
,
self
.
max_encoder_seq_len
,
None
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class
TorchSDPABackendImpl
(
AttentionImpl
[
TorchSDPAMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
logger
.
warning_once
(
"Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off."
)
if
use_irope
:
logger
.
warning_once
(
"Using irope in Torch SPDA is not supported yet, it will fall"
" back to global attention for long context."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
is_quantized_kv_cache
(
kv_cache_dtype
)
and
not
_use_ipex
:
raise
NotImplementedError
(
"Torch SDPA backend FP8 KV cache requires "
"intel_extension_for_pytorch support."
)
self
.
attn_type
=
attn_type
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl"
)
# For warming-up
if
attn_metadata
is
None
:
return
query
attn_type
=
self
.
attn_type
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
"encoder metadata attributes."
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
assert
value
is
None
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
.
numel
()
>
0
):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
)
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
if
attn_type
==
AttentionType
.
DECODER
:
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
not
prefill_meta
.
prefill_metadata
.
chunked_prefill
:
# type: ignore
assert
attn_metadata
.
seq_lens
is
not
None
self
.
_run_sdpa_forward
(
output
,
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
else
:
# prefix-enabled attention
assert
not
self
.
need_mask
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
output
=
torch
.
empty_like
(
query
)
ipex_modules
.
PagedAttention
.
flash_attn_varlen_func
(
output
[:
prefill_meta
.
num_prefill_tokens
,
:,
:],
query
[:
prefill_meta
.
num_prefill_tokens
,
:,
:],
key_cache
,
value_cache
,
prefill_meta
.
prefill_query_start_loc
,
prefill_meta
.
kv_start_loc
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_kv_len
,
self
.
scale
,
True
,
prefill_meta
.
prefill_block_tables
,
self
.
alibi_slopes
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
"Encoder-only models should not have decode metadata."
)
# Decoding run.
(
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
decode_meta
.
get_seq_len_block_table_args
(
attn_type
)
PagedAttention
.
forward_decode
(
output
[
attn_metadata
.
num_prefill_tokens
:,
:,
:],
query
[
attn_metadata
.
num_prefill_tokens
:,
:,
:],
key_cache
,
value_cache
,
block_tables_arg
,
seq_lens_arg
,
max_seq_len_arg
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
_run_sdpa_forward
(
self
,
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
attn_masks
=
attn_metadata
.
get_attn_bias
(
attn_type
)
if
attn_masks
is
None
:
if
self
.
alibi_slopes
is
not
None
:
attn_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
assert
attn_metadata
.
seq_lens
is
not
None
attn_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
seq_lens
,
_
=
attn_metadata
.
get_seq_lens
(
attn_type
)
attn_masks
=
[
None
]
*
len
(
seq_lens
)
attn_metadata
.
set_attn_bias
(
attn_masks
,
attn_type
)
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
causal_attn
=
(
attn_type
==
AttentionType
.
DECODER
)
seq_lens_q
,
seq_lens_kv
=
attn_metadata
.
get_seq_lens
(
attn_type
)
start_q
,
start_kv
=
0
,
0
for
seq_len_q
,
seq_len_kv
,
mask
in
zip
(
seq_lens_q
,
seq_lens_kv
,
attn_masks
):
end_q
=
start_q
+
seq_len_q
end_kv
=
start_kv
+
seq_len_kv
sub_out
=
scaled_dot_product_attention
(
query
[
None
,
:,
start_q
:
end_q
,
:],
key
[
None
,
:,
start_kv
:
end_kv
,
:],
value
[
None
,
:,
start_kv
:
end_kv
,
:],
attn_mask
=
mask
,
dropout_p
=
0.0
,
is_causal
=
causal_attn
and
mask
is
None
,
scale
=
self
.
scale
).
squeeze
(
0
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start_q
:
end_q
,
:,
:]
=
sub_out
start_q
,
start_kv
=
end_q
,
end_kv
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
:
List
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
]).
unsqueeze_
(
0
)
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
return
attn_biases
def
_make_sliding_window_bias
(
seq_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
:
List
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
shift
=
0
mask
=
torch
.
tril
(
tensor
,
diagonal
=
shift
).
to
(
dtype
)
# type: ignore
if
window_size
is
not
None
:
mask
=
torch
.
triu
(
mask
,
diagonal
=
shift
-
window_size
+
1
)
mask
=
torch
.
log
(
mask
)
attn_biases
.
append
(
mask
.
to
(
dtype
))
return
attn_biases
vllm/attention/ops/ipex_attn.py
deleted
100644 → 0
View file @
8369b7c2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
,
Tuple
try
:
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
_use_ipex
=
True
# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813
except
(
ImportError
,
AttributeError
):
_use_ipex
=
False
import
torch
from
vllm
import
_custom_ops
as
ops
class
_PagedAttention
:
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
32
,
64
,
80
,
96
,
112
,
128
,
192
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
Tuple
[
int
,
...]:
return
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
=
16
//
kv_cache
.
element_size
()
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
//
x
,
-
1
,
x
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
@
staticmethod
def
forward_decode
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
tp_rank
:
int
=
0
blocksparse_local_blocks
:
int
=
0
blocksparse_vert_stride
:
int
=
0
blocksparse_block_size
:
int
=
64
blocksparse_head_sliding_step
:
int
=
0
block_size
=
value_cache
.
shape
[
3
]
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
*
args
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
class
_IPEXPagedAttention
(
_PagedAttention
):
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
ipex_modules
.
PagedAttention
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
().
int
())
@
staticmethod
def
forward_decode
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
block_size
=
value_cache
.
shape
[
2
]
head_mapping
=
torch
.
arange
(
0
,
num_kv_heads
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
query
.
size
(
1
)
//
num_kv_heads
).
flatten
()
ipex_modules
.
PagedAttention
.
single_query_cached_kv_attention
(
output
,
query
.
contiguous
(),
key_cache
,
value_cache
,
head_mapping
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
)
PagedAttention
=
_IPEXPagedAttention
if
_use_ipex
else
_PagedAttention
vllm/v1/attention/backends/cpu_attn.py
View file @
7721ef17
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
import
numpy
as
np
import
torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
Attention
Metadata
)
from
vllm.attention.backends.torch_sdpa
import
(
TorchSDPABackendImpl
,
TorchSDPAMetadata
)
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
Attention
Layer
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.
attention.ops.ipex_attn
import
PagedAttention
from
vllm.
logger
import
init_logger
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -17,18 +21,28 @@ from vllm.v1.worker.block_table import BlockTable
from
vllm.v1.worker.cpu_model_runner
import
CPUModelRunner
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
try
:
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
_use_ipex
=
True
# AttributeError is to handle a bug in ipex
# https://github.com/intel/intel-extension-for-pytorch/pull/813
except
(
ImportError
,
AttributeError
):
_use_ipex
=
False
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
class
TorchSDPABackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
False
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
PagedAttention
.
get_supported_head_sizes
()
@
classmethod
def
validate_head_size
(
cls
,
head_size
:
int
)
->
None
:
supported_head_sizes
=
cls
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
attn_impl
=
_get_paged_attn_impl
()
is_valid
,
supported_head_sizes
=
attn_impl
.
validate_head_size
(
head_size
)
if
not
is_valid
:
attn_type
=
cls
.
__name__
.
removesuffix
(
"Backend"
)
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by
{
attn_type
}
. "
...
...
@@ -63,14 +77,239 @@ class TorchSDPABackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
_get_paged_attn_impl
().
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
use_cascade_attention
(
*
args
,
**
kwargs
)
->
bool
:
return
False
@
dataclass
class
TorchSDPAMetadata
(
AttentionMetadata
):
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len
:
int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
chunked_prefill
:
bool
seq_lens
:
Optional
[
list
[
int
]]
=
None
# For non-chunked prefill
# For chunked prefill only
max_query_len
:
Optional
[
int
]
=
None
max_kv_len
:
Optional
[
int
]
=
None
prefill_query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
kv_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
prefill_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# For V1 logits index only
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
list
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
self
.
encoder_attn_bias
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
self
.
cross_attn_bias
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
@
property
def
is_all_encoder_attn_metadata_set
(
self
):
'''
All attention metadata required for encoder attention is set.
'''
return
((
self
.
encoder_seq_lens
is
not
None
)
and
(
self
.
encoder_seq_lens_tensor
is
not
None
)
and
(
self
.
max_encoder_seq_len
is
not
None
))
@
property
def
is_all_cross_attn_metadata_set
(
self
):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return
(
self
.
is_all_encoder_attn_metadata_set
and
(
self
.
cross_slot_mapping
is
not
None
)
and
(
self
.
cross_block_tables
is
not
None
))
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
if
self
.
num_prefill_tokens
==
0
:
return
None
return
self
@
property
def
decode_metadata
(
self
)
->
Optional
[
"TorchSDPAMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
return
self
def
get_seq_lens
(
self
,
attn_type
:
str
,
):
'''
Extract appropriate sequence lengths from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence lengths tensor for query
* Appropriate sequence lengths tensor for key & value
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
seq_lens_q
=
self
.
seq_lens
seq_lens_kv
=
self
.
seq_lens
elif
attn_type
==
AttentionType
.
ENCODER
:
seq_lens_q
=
self
.
encoder_seq_lens
seq_lens_kv
=
self
.
encoder_seq_lens
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
seq_lens_q
=
self
.
seq_lens
seq_lens_kv
=
self
.
encoder_seq_lens
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
return
seq_lens_q
,
seq_lens_kv
def
get_attn_bias
(
self
,
attn_type
:
str
,
)
->
Optional
[
list
[
torch
.
Tensor
]]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
return
self
.
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
return
self
.
encoder_attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
return
self
.
cross_attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
set_attn_bias
(
self
,
attn_bias
:
list
[
torch
.
Tensor
],
attn_type
:
str
,
)
->
None
:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
self
.
attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
self
.
encoder_attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
self
.
cross_attn_bias
=
attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
get_seq_len_block_table_args
(
self
,
attn_type
:
str
,
)
->
tuple
:
'''
The particular choice of sequence-length- and block-table-related
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* is_prompt: True if prefill, False otherwise
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate sequence-lengths tensor
* Appropriate max sequence-length scalar
* Appropriate block tables (or None)
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return
(
self
.
seq_lens_tensor
,
self
.
max_decode_seq_len
,
self
.
block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return
(
self
.
encoder_seq_lens_tensor
,
self
.
max_encoder_seq_len
,
self
.
cross_block_tables
)
elif
attn_type
==
AttentionType
.
ENCODER
:
# No block tables associated with encoder attention
return
(
self
.
encoder_seq_lens_tensor
,
self
.
max_encoder_seq_len
,
None
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class
TorchSDPAMetadataBuilderV1
(
AttentionMetadataBuilder
[
TorchSDPAMetadata
]):
def
__init__
(
self
,
runner
:
CPUModelRunner
,
kv_cache_spec
:
AttentionSpec
,
...
...
@@ -182,3 +421,500 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
)
return
attn_metadata
class
TorchSDPABackendImpl
(
AttentionImpl
[
TorchSDPAMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
list
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"Torch SPDA does not support block-sparse attention."
)
if
logits_soft_cap
is
not
None
:
logger
.
warning_once
(
"Torch SPDA does not support logits soft cap. "
"Outputs may be slightly off."
)
if
use_irope
:
logger
.
warning_once
(
"Using irope in Torch SPDA is not supported yet, it will fall"
" back to global attention for long context."
)
self
.
paged_attn_impl
=
_get_paged_attn_impl
()
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
need_mask
=
(
self
.
alibi_slopes
is
not
None
or
self
.
sliding_window
is
not
None
)
if
is_quantized_kv_cache
(
kv_cache_dtype
)
and
not
_use_ipex
:
raise
NotImplementedError
(
"Torch SDPA backend FP8 KV cache requires "
"intel_extension_for_pytorch support."
)
self
.
attn_type
=
attn_type
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for TorchSDPABackendImpl"
)
# For warming-up
if
attn_metadata
is
None
:
return
query
attn_type
=
self
.
attn_type
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
"encoder metadata attributes."
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
assert
value
is
None
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
.
numel
()
>
0
):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache
,
value_cache
=
self
.
paged_attn_impl
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
self
.
paged_attn_impl
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
)
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
if
attn_type
==
AttentionType
.
DECODER
:
# Only enforce this shape-constraint for decoder
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
not
prefill_meta
.
prefill_metadata
.
chunked_prefill
:
# type: ignore
assert
attn_metadata
.
seq_lens
is
not
None
self
.
_run_sdpa_forward
(
output
,
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
else
:
# prefix-enabled attention
assert
not
self
.
need_mask
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
output
=
torch
.
empty_like
(
query
)
ipex_modules
.
PagedAttention
.
flash_attn_varlen_func
(
output
[:
prefill_meta
.
num_prefill_tokens
,
:,
:],
query
[:
prefill_meta
.
num_prefill_tokens
,
:,
:],
key_cache
,
value_cache
,
prefill_meta
.
prefill_query_start_loc
,
prefill_meta
.
kv_start_loc
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_kv_len
,
self
.
scale
,
True
,
prefill_meta
.
prefill_block_tables
,
self
.
alibi_slopes
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
"Encoder-only models should not have decode metadata."
)
# Decoding run.
(
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
decode_meta
.
get_seq_len_block_table_args
(
attn_type
)
self
.
paged_attn_impl
.
forward_decode
(
output
[
attn_metadata
.
num_prefill_tokens
:,
:,
:],
query
[
attn_metadata
.
num_prefill_tokens
:,
:,
:],
key_cache
,
value_cache
,
block_tables_arg
,
seq_lens_arg
,
max_seq_len_arg
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
_run_sdpa_forward
(
self
,
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
None
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=
1
)
attn_masks
=
attn_metadata
.
get_attn_bias
(
attn_type
)
if
attn_masks
is
None
:
if
self
.
alibi_slopes
is
not
None
:
attn_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
# type: ignore
elif
self
.
sliding_window
is
not
None
:
assert
attn_metadata
.
seq_lens
is
not
None
attn_masks
=
_make_sliding_window_bias
(
attn_metadata
.
seq_lens
,
self
.
sliding_window
,
query
.
dtype
)
# type: ignore
else
:
seq_lens
,
_
=
attn_metadata
.
get_seq_lens
(
attn_type
)
attn_masks
=
[
None
]
*
len
(
seq_lens
)
attn_metadata
.
set_attn_bias
(
attn_masks
,
attn_type
)
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
causal_attn
=
(
attn_type
==
AttentionType
.
DECODER
)
seq_lens_q
,
seq_lens_kv
=
attn_metadata
.
get_seq_lens
(
attn_type
)
start_q
,
start_kv
=
0
,
0
for
seq_len_q
,
seq_len_kv
,
mask
in
zip
(
seq_lens_q
,
seq_lens_kv
,
attn_masks
):
end_q
=
start_q
+
seq_len_q
end_kv
=
start_kv
+
seq_len_kv
sub_out
=
scaled_dot_product_attention
(
query
[
None
,
:,
start_q
:
end_q
,
:],
key
[
None
,
:,
start_kv
:
end_kv
,
:],
value
[
None
,
:,
start_kv
:
end_kv
,
:],
attn_mask
=
mask
,
dropout_p
=
0.0
,
is_causal
=
causal_attn
and
mask
is
None
,
scale
=
self
.
scale
).
squeeze
(
0
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start_q
:
end_q
,
:,
:]
=
sub_out
start_q
,
start_kv
=
end_q
,
end_kv
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
list
[
int
],
)
->
list
[
torch
.
Tensor
]:
attn_biases
:
list
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
]).
unsqueeze_
(
0
)
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
return
attn_biases
def
_make_sliding_window_bias
(
seq_lens
:
list
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
list
[
torch
.
Tensor
]:
attn_biases
:
list
[
torch
.
Tensor
]
=
[]
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
shift
=
0
mask
=
torch
.
tril
(
tensor
,
diagonal
=
shift
).
to
(
dtype
)
# type: ignore
if
window_size
is
not
None
:
mask
=
torch
.
triu
(
mask
,
diagonal
=
shift
-
window_size
+
1
)
mask
=
torch
.
log
(
mask
)
attn_biases
.
append
(
mask
.
to
(
dtype
))
return
attn_biases
class
_PagedAttention
:
@
staticmethod
def
validate_head_size
(
head_size
:
int
)
->
tuple
[
bool
,
list
[
int
]]:
SUPPORT_HS
=
[
32
,
64
,
80
,
96
,
112
,
128
,
192
,
256
]
return
head_size
in
SUPPORT_HS
,
SUPPORT_HS
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
tuple
[
int
,
...]:
return
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
x
=
16
//
kv_cache
.
element_size
()
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
//
x
,
-
1
,
x
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
head_size
,
-
1
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
)
@
staticmethod
def
forward_decode
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
tp_rank
:
int
=
0
blocksparse_local_blocks
:
int
=
0
blocksparse_vert_stride
:
int
=
0
blocksparse_block_size
:
int
=
64
blocksparse_head_sliding_step
:
int
=
0
block_size
=
value_cache
.
shape
[
3
]
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
list
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
*
args
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
class
_IPEXPagedAttention
(
_PagedAttention
):
@
staticmethod
def
validate_head_size
(
head_size
:
int
)
->
tuple
[
bool
,
list
[
int
]]:
return
True
,
[]
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
ipex_modules
.
PagedAttention
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
().
int
())
@
staticmethod
def
forward_decode
(
output
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
*
args
,
)
->
None
:
block_size
=
value_cache
.
shape
[
2
]
head_mapping
=
torch
.
arange
(
0
,
num_kv_heads
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
query
.
size
(
1
)
//
num_kv_heads
).
flatten
()
ipex_modules
.
PagedAttention
.
single_query_cached_kv_attention
(
output
,
query
.
contiguous
(),
key_cache
,
value_cache
,
head_mapping
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
)
def
_get_paged_attn_impl
():
if
_use_ipex
:
return
_IPEXPagedAttention
else
:
return
_PagedAttention
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment