Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d5cdd53
Commit
8d5cdd53
authored
Jan 26, 2026
by
zhuwenwen
Browse files
remove unused code
parent
b8025f24
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
7 additions
and
4584 deletions
+7
-4584
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+3
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+3
-3
vllm/transformers_utils/configs/deepseek_v3.py
vllm/transformers_utils/configs/deepseek_v3.py
+0
-101
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+0
-100
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-135
vllm/v1/attention/cpu_mla.py
vllm/v1/attention/cpu_mla.py
+0
-307
vllm/v1/attention/ipex_attn.py
vllm/v1/attention/ipex_attn.py
+0
-403
vllm/v1/attention/mla/common.py
vllm/v1/attention/mla/common.py
+0
-1312
vllm/v1/attention/ops/common.py
vllm/v1/attention/ops/common.py
+0
-205
vllm/v1/attention/pallas.py
vllm/v1/attention/pallas.py
+0
-356
vllm/v1/attention/rocm_flash_attn.py
vllm/v1/attention/rocm_flash_attn.py
+0
-953
vllm/v1/attention/torch_sdpa.py
vllm/v1/attention/torch_sdpa.py
+0
-707
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+0
-1
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
8d5cdd53
...
@@ -1030,7 +1030,7 @@ def zero_experts_compute_triton(
...
@@ -1030,7 +1030,7 @@ def zero_experts_compute_triton(
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def
get_config_file_name
(
def
get_config_file_name
(
E
:
int
,
N
:
int
,
dtype
:
str
|
None
,
block_shape
:
list
[
int
]
|
None
=
None
,
E
:
int
,
N
:
int
,
dtype
:
str
|
None
,
block_shape
:
list
[
int
]
|
None
=
None
)
->
str
:
)
->
str
:
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
# Set device_name to H200 if a device from the H200 family is detected
# Set device_name to H200 if a device from the H200 family is detected
...
@@ -1042,6 +1042,7 @@ def get_config_file_name(
...
@@ -1042,6 +1042,7 @@ def get_config_file_name(
).
replace
(
" "
,
""
)
).
replace
(
" "
,
""
)
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}{
dtype_selector
}{
block_shape_selector
}
.json"
# noqa: E501
return
f
"E=
{
E
}
,N=
{
N
}
,device_name=
{
device_name
}{
dtype_selector
}{
block_shape_selector
}
.json"
# noqa: E501
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@
functools
.
lru_cache
@
functools
.
lru_cache
def
get_moe_configs
(
def
get_moe_configs
(
...
@@ -1907,6 +1908,7 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
...
@@ -1907,6 +1908,7 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
return
torch_vllm_inplace_fused_experts
return
torch_vllm_inplace_fused_experts
return
torch_vllm_outplace_fused_experts
return
torch_vllm_outplace_fused_experts
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
# torch ops.
# torch ops.
def
fused_experts
(
def
fused_experts
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
8d5cdd53
...
@@ -1038,7 +1038,7 @@ class FusedMoE(CustomOp):
...
@@ -1038,7 +1038,7 @@ class FusedMoE(CustomOp):
shard_size
=
expert_data
.
shape
[
shard_dim
]
shard_size
=
expert_data
.
shape
[
shard_dim
]
if
not
load_full
:
if
not
load_full
:
loaded_weight
=
loaded_weight
.
narrow
(
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
)
# Narrow parameter and load.
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
# w1, gate_proj: Load into first logical weight of w13.
...
@@ -1064,7 +1064,7 @@ class FusedMoE(CustomOp):
...
@@ -1064,7 +1064,7 @@ class FusedMoE(CustomOp):
shard_size
=
expert_data
.
shape
[
shard_dim
]
shard_size
=
expert_data
.
shape
[
shard_dim
]
if
not
load_full
:
if
not
load_full
:
loaded_weight
=
loaded_weight
.
narrow
(
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
)
# w2, down_proj: Load into only logical weight of w2.
# w2, down_proj: Load into only logical weight of w2.
expert_data
.
copy_
(
loaded_weight
)
expert_data
.
copy_
(
loaded_weight
)
...
@@ -1222,7 +1222,7 @@ class FusedMoE(CustomOp):
...
@@ -1222,7 +1222,7 @@ class FusedMoE(CustomOp):
# is_transposed: if the dim to shard the weight
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
# should be whatever dimension intermediate_size_per_partition is
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
is_transposed
=
getattr
(
param
,
"is_transposed"
,
False
)
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
shard_dim
=
SHARD_ID_TO_SHARDED_DIM
[
shard_id
]
if
is_transposed
:
if
is_transposed
:
shard_dim
=
int
(
not
shard_dim
)
shard_dim
=
int
(
not
shard_dim
)
...
...
vllm/transformers_utils/configs/deepseek_v3.py
deleted
100644 → 0
View file @
b8025f24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
class
DeepseekV3Config
(
PretrainedConfig
):
model_type
=
"deepseek_v3"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
vocab_size
=
129280
,
hidden_size
=
7168
,
intermediate_size
=
18432
,
moe_intermediate_size
=
2048
,
num_hidden_layers
=
61
,
num_nextn_predict_layers
=
1
,
num_attention_heads
=
128
,
num_key_value_heads
=
128
,
n_shared_experts
=
1
,
n_routed_experts
=
256
,
ep_size
=
1
,
routed_scaling_factor
=
2.5
,
kv_lora_rank
=
512
,
q_lora_rank
=
1536
,
qk_rope_head_dim
=
64
,
v_head_dim
=
128
,
qk_nope_head_dim
=
128
,
topk_method
=
'noaux_tc'
,
n_group
=
8
,
topk_group
=
4
,
num_experts_per_tok
=
8
,
moe_layer_freq
=
1
,
first_k_dense_replace
=
3
,
norm_topk_prob
=
True
,
scoring_func
=
'sigmoid'
,
hidden_act
=
"silu"
,
max_position_embeddings
=
4096
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-6
,
use_cache
=
True
,
pad_token_id
=
None
,
bos_token_id
=
0
,
eos_token_id
=
1
,
tie_word_embeddings
=
False
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
**
kwargs
,
):
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
moe_intermediate_size
=
moe_intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_nextn_predict_layers
=
num_nextn_predict_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
n_shared_experts
=
n_shared_experts
self
.
n_routed_experts
=
n_routed_experts
self
.
ep_size
=
ep_size
self
.
routed_scaling_factor
=
routed_scaling_factor
self
.
kv_lora_rank
=
kv_lora_rank
self
.
q_lora_rank
=
q_lora_rank
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
topk_method
=
topk_method
self
.
n_group
=
n_group
self
.
topk_group
=
topk_group
self
.
num_experts_per_tok
=
num_experts_per_tok
self
.
moe_layer_freq
=
moe_layer_freq
self
.
first_k_dense_replace
=
first_k_dense_replace
self
.
norm_topk_prob
=
norm_topk_prob
self
.
scoring_func
=
scoring_func
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
use_cache
=
use_cache
self
.
rope_theta
=
rope_theta
self
.
rope_scaling
=
rope_scaling
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
super
().
__init__
(
pad_token_id
=
pad_token_id
,
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
tie_word_embeddings
=
tie_word_embeddings
,
**
kwargs
,
)
vllm/utils/deep_gemm.py
View file @
8d5cdd53
...
@@ -189,12 +189,6 @@ def get_mk_alignment_for_contiguous_layout() -> list[int]:
...
@@ -189,12 +189,6 @@ def get_mk_alignment_for_contiguous_layout() -> list[int]:
return
[
mk_align_size
,
mk_align_size
]
return
[
mk_align_size
,
mk_align_size
]
def
get_num_sms
()
->
int
:
_lazy_init
()
_dg
=
importlib
.
import_module
(
"deep_gemm"
)
return
int
(
_dg
.
get_num_sms
())
def
get_col_major_tma_aligned_tensor
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_col_major_tma_aligned_tensor
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
"""Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
_lazy_init
()
_lazy_init
()
...
@@ -338,100 +332,6 @@ def fp8_paged_mqa_logits(
...
@@ -338,100 +332,6 @@ def fp8_paged_mqa_logits(
)
)
def
fp8_mqa_logits
(
q
:
torch
.
Tensor
,
kv
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
weights
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
_lazy_init
()
if
_fp8_mqa_logits_impl
is
None
:
return
_missing
()
return
_fp8_mqa_logits_impl
(
q
,
kv
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
def
get_paged_mqa_logits_metadata
(
context_lens
:
torch
.
Tensor
,
block_size
:
int
,
num_sms
:
int
)
->
torch
.
Tensor
:
"""Build scheduling metadata for paged MQA logits.
Args:
context_lens: Tensor of shape [B], dtype int32; effective context length
per batch element.
block_size: KV-cache block size in tokens (e.g., 64).
num_sms: Number of SMs available. 132 for Hopper
Returns:
Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
schedule work across SMs.
"""
_lazy_init
()
if
_get_paged_mqa_logits_metadata_impl
is
None
:
return
_missing
()
return
_get_paged_mqa_logits_metadata_impl
(
context_lens
,
block_size
,
num_sms
)
def
fp8_paged_mqa_logits
(
q_fp8
:
torch
.
Tensor
,
kv_cache_fp8
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
schedule_metadata
:
torch
.
Tensor
,
max_model_len
:
int
,
)
->
torch
.
Tensor
:
"""Compute FP8 MQA logits using paged KV-cache.
Args:
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
4 bytes per (block,pos) store the `float` dequant scale.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
block indices to physical blocks in the paged cache.
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
_lazy_init
()
if
_fp8_paged_mqa_logits_impl
is
None
:
return
_missing
()
return
_fp8_paged_mqa_logits_impl
(
q_fp8
,
kv_cache_fp8
,
weights
,
context_lens
,
block_tables
,
schedule_metadata
,
max_model_len
,
clean_logits
=
True
)
def
_ceil_to_ue8m0
(
x
:
torch
.
Tensor
):
def
_ceil_to_ue8m0
(
x
:
torch
.
Tensor
):
return
torch
.
pow
(
2.0
,
torch
.
ceil
(
torch
.
log2
(
x
.
abs
())))
return
torch
.
pow
(
2.0
,
torch
.
ceil
(
torch
.
log2
(
x
.
abs
())))
...
...
vllm/v1/attention/backends/mla/common.py
View file @
8d5cdd53
...
@@ -1314,138 +1314,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1314,138 +1314,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
get_layer_weight
(
layer
):
WEIGHT_NAMES
=
(
"weight"
,
"qweight"
,
"weight_packed"
)
for
attr
in
WEIGHT_NAMES
:
if
hasattr
(
layer
,
attr
):
return
getattr
(
layer
,
attr
)
raise
AttributeError
(
f
"Layer '
{
layer
}
' has no recognized weight attribute:"
f
"
{
WEIGHT_NAMES
}
."
)
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
not
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye
=
torch
.
eye
(
layer
.
input_size_per_partition
,
dtype
=
act_dtype
,
device
=
get_layer_weight
(
layer
).
device
)
dequant_weights
=
layer
.
quant_method
.
apply
(
layer
,
eye
,
bias
=
None
)
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
self
.
num_heads
*
(
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)),
(
f
"
{
kv_b_proj_weight
.
shape
=
}
, "
f
"
{
self
.
kv_lora_rank
=
}
, "
f
"
{
self
.
num_heads
=
}
, "
f
"
{
self
.
qk_nope_head_dim
=
}
, "
f
"
{
self
.
v_head_dim
=
}
"
)
kv_b_proj_weight
=
kv_b_proj_weight
.
view
(
self
.
kv_lora_rank
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
,
)
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
if
is_rocm_aiter_fp8bmm_enabled
():
W_K
=
W_UK
.
transpose
(
0
,
1
)
# 16 512 128
W_V
=
W_UV
.
permute
(
1
,
2
,
0
)
# 16 128 512
self
.
W_K
,
self
.
W_K_scale
=
dynamic_per_batched_tensor_quant
(
W_K
,
dtype
=
current_platform
.
fp8_dtype
())
self
.
W_V
,
self
.
W_V_scale
=
dynamic_per_batched_tensor_quant
(
W_V
,
dtype
=
current_platform
.
fp8_dtype
())
# The kernel operates on non-padded inputs. Hence, pre-compiling
# triton kernel to avoid runtime compilation for unseen batch sizes
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
# On DS-R1, this step adds roughly 50s to the model loading time.
max_batch_size
=
1024
# [ToDo] Find the optimal upper limit
pre_compilation_list
=
list
(
range
(
1
,
max_batch_size
+
1
))
if
is_global_first_rank
():
pre_compilation_list
=
tqdm
(
pre_compilation_list
,
desc
=
"[Aiter Triton] Pre-compiling fp8 BMM kernel"
,
total
=
max_batch_size
,
)
for
m
in
pre_compilation_list
:
x
=
torch
.
empty
((
self
.
W_K
.
shape
[
0
],
m
,
self
.
W_K
.
shape
[
2
]),
dtype
=
torch
.
bfloat16
,
device
=
self
.
W_K
.
device
)
aiter_triton_fp8_bmm
(
x
,
self
.
W_K
,
self
.
W_K_scale
,
group_size
=
128
,
transpose_bm
=
True
)
x
=
torch
.
empty
((
self
.
W_V
.
shape
[
0
],
m
,
self
.
W_V
.
shape
[
2
]),
dtype
=
torch
.
bfloat16
,
device
=
self
.
W_V
.
device
)
aiter_triton_fp8_bmm
(
x
,
self
.
W_V
,
self
.
W_V_scale
,
group_size
=
128
,
transpose_bm
=
True
)
else
:
# Convert from (L, N, V) to (N, L, V)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
def
_v_up_proj
(
self
,
x
:
torch
.
Tensor
,
out
:
torch
.
Tensor
):
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
if
is_rocm_aiter_fp8bmm_enabled
():
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x
=
aiter_triton_fp8_bmm
(
x
,
self
.
W_V
,
self
.
W_V_scale
,
group_size
=
128
,
transpose_bm
=
True
)
# Convert from (B, N, V) to (B, N * V)
x
=
x
.
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
# Copy result
out
.
copy_
(
x
)
else
:
# Convert from (B, N * V) to (N, B, V)
out
=
out
.
view
(
-
1
,
self
.
num_heads
,
self
.
v_head_dim
).
transpose
(
0
,
1
)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
torch
.
bmm
(
x
,
self
.
W_UV
,
out
=
out
)
# Reuse "out" to make it "hot"
# Convert from (N, B, V) to (B, N * V)
out_new
=
out
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
# Adjust output buffer shape back to the original (B, N * V)
N
,
B
,
V
=
out
.
shape
out
.
resize_
((
B
,
N
*
V
))
out
.
copy_
(
out_new
)
# Copy result
class
MLACommonImpl
(
MLACommonBaseImpl
[
M
],
Generic
[
M
]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
if
use_flashinfer_prefill
():
if
use_flashinfer_prefill
():
logger
.
debug_once
(
"Using FlashInfer prefill for MLA"
)
logger
.
debug_once
(
"Using FlashInfer prefill for MLA"
)
self
.
_run_prefill_context_chunk
=
self
.
_run_prefill_context_chunk_fi
self
.
_run_prefill_context_chunk
=
self
.
_run_prefill_context_chunk_fi
...
@@ -1627,9 +1496,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1627,9 +1496,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Convert from (q_len, num_heads) to (num_heads, q_len)
# Convert from (q_len, num_heads) to (num_heads, q_len)
return
attn_out
,
lse
.
transpose
(
0
,
1
).
contiguous
()
return
attn_out
,
lse
.
transpose
(
0
,
1
).
contiguous
()
# Convert from (q_len, num_heads) to (num_heads, q_len)
return
attn_out
,
lse
.
transpose
(
0
,
1
).
contiguous
()
def
_run_prefill_context_chunk_cudnn
(
def
_run_prefill_context_chunk_cudnn
(
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
):
...
...
vllm/v1/attention/cpu_mla.py
deleted
100644 → 0
View file @
b8025f24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
vllm._custom_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadataBuilder
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.mla.common
import
MLACommonImpl
,
MLACommonState
from
vllm.attention.backends.torch_sdpa
import
TorchSDPAMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.cpu_model_runner
import
ModelInputForCPUBuilder
class
CPUMLABackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"CPU_MLA"
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"CPUMLAMetadata"
]:
return
CPUMLAMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"CPUMLAMetadataBuilder"
]:
return
CPUMLAMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"MLACommonState"
]:
return
MLACommonState
@
staticmethod
def
get_impl_cls
()
->
Type
[
"CPUMLAImpl"
]:
return
CPUMLAImpl
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
# assumed to be 1 for MLA
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
ops
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
ops
.
copy_blocks_mla
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
576
]
@
dataclass
class
CPUMLAMetadata
(
TorchSDPAMetadata
):
# New for MLA
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions
:
torch
.
Tensor
=
None
# required by MLACommonImpl
is_profile_run
:
bool
=
False
class
CPUMLAMetadataBuilder
(
AttentionMetadataBuilder
[
CPUMLAMetadata
]):
def
__init__
(
self
,
input_builder
:
ModelInputForCPUBuilder
)
->
None
:
self
.
chunked_prefill
=
input_builder
.
chunked_prefill
self
.
input_builder
=
input_builder
assert
not
self
.
chunked_prefill
,
\
"chunked prefill is currently not supported"
def
prepare
(
self
):
self
.
input_data
=
self
.
input_builder
.
input_data
def
build
(
self
,
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
):
input_data
=
self
.
input_data
prefill_seq_lens
=
seq_lens
[
0
:
input_data
.
num_prefills
]
prefill_query_lens
=
query_lens
[
0
:
input_data
.
num_prefills
]
slot_mapping
=
torch
.
tensor
(
input_data
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
"cpu"
)
# metadata for prefill
if
input_data
.
num_prefills
>
0
:
query_lens_tensor
=
torch
.
tensor
(
prefill_query_lens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
kv_lens_tensor
=
torch
.
tensor
(
prefill_seq_lens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
query_start_loc
=
torch
.
zeros
(
input_data
.
num_prefills
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
kv_start_loc
=
torch
.
zeros
(
input_data
.
num_prefills
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
torch
.
int32
,
out
=
query_start_loc
[
1
:])
torch
.
cumsum
(
kv_lens_tensor
,
dim
=
0
,
dtype
=
torch
.
int32
,
out
=
kv_start_loc
[
1
:])
max_query_len
=
max
(
prefill_query_lens
)
max_kv_len
=
max
(
prefill_seq_lens
)
# for chunked-prefill
if
self
.
chunked_prefill
:
prefill_block_tables
=
make_tensor_with_pad
(
self
.
input_data
.
prefill_block_tables
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
)
else
:
prefill_block_tables
=
None
else
:
query_start_loc
=
None
kv_start_loc
=
None
max_query_len
=
None
max_kv_len
=
None
prefill_block_tables
=
None
# metadata for decode
if
input_data
.
num_decode_tokens
!=
0
:
seq_lens_tensor
=
torch
.
tensor
(
input_data
.
seq_lens
[
input_data
.
num_prefills
:],
dtype
=
torch
.
int32
,
device
=
"cpu"
,
)
block_tables
=
make_tensor_with_pad
(
self
.
input_data
.
decode_block_tables
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
)
else
:
block_tables
=
torch
.
tensor
([])
seq_lens_tensor
=
torch
.
tensor
(
input_data
.
seq_lens
[:
input_data
.
num_prefills
],
dtype
=
torch
.
int32
,
device
=
"cpu"
,
)
# For multi-modal models
placeholder_index_maps
=
None
if
len
(
input_data
.
multi_modal_inputs_list
)
!=
0
:
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
input_data
.
multi_modal_placeholder_maps
.
items
()
}
return
CPUMLAMetadata
(
chunked_prefill
=
self
.
chunked_prefill
,
seq_lens
=
prefill_seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_kv_len
=
max_kv_len
,
prefill_query_start_loc
=
query_start_loc
,
kv_start_loc
=
kv_start_loc
,
max_decode_seq_len
=
input_data
.
max_decode_seq_len
,
num_prefills
=
input_data
.
num_prefills
,
num_prefill_tokens
=
input_data
.
num_prefill_tokens
,
num_decode_tokens
=
input_data
.
num_decode_tokens
,
block_tables
=
block_tables
,
prefill_block_tables
=
prefill_block_tables
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
enable_kv_scales_calculation
=
False
,
input_positions
=
torch
.
tensor
([
self
.
input_data
.
input_positions
]))
class
CPUMLAImpl
(
MLACommonImpl
[
CPUMLAMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"CPUMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CPUMLAImpl"
)
# states is implemented.
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"CPUMLAImpl with FP8 KV cache not yet supported"
)
def
_forward_prefill
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
CPUMLAMetadata
,
# type: ignore[override]
)
->
torch
.
Tensor
:
prefill_metadata
=
attn_metadata
.
prefill_metadata
assert
prefill_metadata
is
not
None
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
output
=
torch
.
empty_like
(
q
)
ipex_ops
.
varlen_attention
(
query
=
q
,
key
=
k
,
value
=
v_padded
,
out
=
output
,
seqlen_q
=
prefill_metadata
.
prefill_query_start_loc
,
seqlen_k
=
prefill_metadata
.
prefill_query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
max_query_len
,
pdropout
=
0.0
,
softmax_scale
=
self
.
scale
,
zero_tensors
=
False
,
is_causal
=
True
,
return_softmax
=
False
,
gen_
=
None
,
logits_soft_cap
=
0.0
,
window_size_left
=-
1
,
window_size_right
=-
1
,
alibi_slopes
=
None
,
)
# remove padding
output
=
output
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
return
output
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
CPUMLAMetadata
,
# type: ignore[override]
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
decode_meta
=
attn_metadata
.
decode_metadata
assert
decode_meta
is
not
None
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
o
=
q
.
new_empty
(
q
.
shape
[
0
],
self
.
num_heads
,
self
.
kv_lora_rank
)
# Run MQA
ops
.
mla_decode_kvcache_cpu
(
o
,
q
,
kv_c_and_k_pe_cache
,
self
.
scale
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
)
return
self
.
_v_up_proj
(
o
)
vllm/v1/attention/ipex_attn.py
deleted
100644 → 0
View file @
b8025f24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
_PARTITION_SIZE
=
512
class
IpexAttnBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"IPEX"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"IpexAttnBackendImpl"
]:
return
IpexAttnBackendImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"IpexAttnMetadata"
]:
return
IpexAttnMetadata
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
ops
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
dataclass
class
IpexAttnMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
"""Metadata for IpexAttnBackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
slot_mapping
:
torch
.
Tensor
seq_lens
:
Optional
[
List
[
int
]]
seqlen_q
:
Optional
[
torch
.
Tensor
]
max_seqlen
:
Optional
[
int
]
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"IpexAttnMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_decode_tokens
==
0
:
assert
self
.
num_prefills
>
0
return
self
return
None
@
property
def
decode_metadata
(
self
)
->
Optional
[
"IpexAttnMetadata"
]:
# Currently chunked prefill is not supported
if
self
.
num_prefills
>
0
:
assert
self
.
num_decode_tokens
==
0
return
None
return
self
class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
    """Attention implementation that dispatches to ``ipex_ops`` kernels.

    Prompt batches are served by ``ipex_ops.varlen_attention``; decode
    batches are served by ``ipex_ops.paged_attention_v1`` or
    ``ipex_ops.paged_attention_v2``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Validate the layer configuration and store it on the instance.

        Raises:
            NotImplementedError: if KV sharing is requested, if
                ``is_quantized_kv_cache(kv_cache_dtype)`` is true, or if
                ``attn_type`` is not ``AttentionType.DECODER``.
            ValueError: if ``blocksparse_params`` is given or ``head_size``
                is not among ``PagedAttention.get_supported_head_sizes()``.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            logger.warning_once(
                "Using irope in Ipex is not supported yet, it will fall"
                " back to global attention for long context.")
        if blocksparse_params is not None:
            raise ValueError(
                "IPEX backend does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        # Slopes are kept as a float32 tensor (or None when ALiBi is unused).
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))
        self.sliding_window = sliding_window
        self.kv_cache_dtype = kv_cache_dtype

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.need_mask = self.sliding_window is not None
        # -1 is the value handed to the kernel when no soft cap was given.
        self.logits_soft_cap = (-1 if logits_soft_cap is None else
                                logits_soft_cap)

        allowed_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in allowed_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {allowed_head_sizes}.")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError(
                "IPEX backend does not support FP8 KV cache. "
                "Please use xFormers backend instead.")
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "IpexAttnBackendImpl")

    def split_kv_cache(
        self,
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the combined cache into key and value views.

        ``kv_cache[0]`` is viewed as
        ``[num_blocks, num_kv_heads, head_size, -1, 1]`` and ``kv_cache[1]``
        as ``[num_blocks, num_kv_heads, head_size, -1]``, where
        ``num_blocks = kv_cache.shape[1]``.
        """
        num_blocks = kv_cache.shape[1]
        # The key layout uses a trailing packing factor of 1.
        keys = kv_cache[0].view(num_blocks, num_kv_heads, head_size, -1, 1)
        values = kv_cache[1].view(num_blocks, num_kv_heads, head_size, -1)
        return keys, values

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: IpexAttnMetadata,  # type: ignore
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with IPEX varlen_attention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if ``output_scale`` is given.
            RuntimeError: for a prompt batch whose KV cache and block
                tables are both non-empty (prefix attention).
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for IpexAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        # Unpacking also enforces that the incoming query is 2-D.
        num_tokens, _ = query.shape
        # Split the flattened hidden dimension into (heads, head_size).
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache.numel() > 0:
            key_cache, value_cache = self.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
            # Store the new keys/values at the slots given by the metadata.
            ipex_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
                layer._k_scale_float,
                layer._v_scale_float,
            )

        if attn_metadata.is_prompt:
            assert attn_metadata.seq_lens is not None
            no_prefix = (kv_cache.numel() == 0
                         or attn_metadata.block_tables.numel() == 0)
            if not no_prefix:
                # prefix-enabled attention
                raise RuntimeError(
                    "IPEX backend doesn't support prefix decoding.")

            if self.num_kv_heads != self.num_heads:
                # Expand KV heads so each query head has a matching KV head.
                key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=1)

            if attn_metadata.attn_bias is None:
                # ``self.sliding_window`` may be None, which yields a plain
                # causal bias.
                attn_metadata.attn_bias = _make_sliding_window_bias(
                    attn_metadata.seq_lens,
                    self.sliding_window,
                    dtype=query.dtype)  # type: ignore

            output = torch.empty(
                (num_tokens, self.num_heads, self.head_size),
                dtype=query.dtype,
                device=query.device)
            ipex_ops.varlen_attention(
                query,
                key,
                value,
                output,
                attn_metadata.seqlen_q,
                attn_metadata.seqlen_q,
                self.alibi_slopes,
                attn_metadata.max_seqlen,
                attn_metadata.max_seqlen,
                pdropout=0.0,
                softmax_scale=self.scale,
                zero_tensors=False,
                is_causal=True,
                return_softmax=False,
                gen_=None,
                window_size_left=-1,
                window_size_right=-1,
                logits_soft_cap=self.logits_soft_cap,
            )
        else:
            # Decoding run.
            max_seq_len = attn_metadata.max_decode_seq_len
            output = torch.empty_like(query)
            block_size = value_cache.shape[3]
            num_seqs, num_heads, head_size = query.shape
            max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                                  _PARTITION_SIZE)
            # NOTE(woosuk): We use a simple heuristic to decide whether to use
            # PagedAttention V1 or V2. If the number of partitions is 1, we use
            # V1 to avoid the overhead of reduction. Also, if the number of
            # sequences or heads is large, we use V1 since there is enough work
            # to parallelize.
            # TODO(woosuk): Tune this heuristic.
            # For context len > 8192, use V2 kernel to avoid shared memory
            # shortage.
            single_pass = max_num_partitions == 1 or num_seqs * num_heads > 512
            use_v1 = max_seq_len <= 8192 and single_pass

            # Trailing arguments shared by both paged-attention kernels.
            paged_args = (
                query,
                key_cache,
                value_cache,
                self.num_kv_heads,
                self.scale,
                attn_metadata.block_tables,
                attn_metadata.seq_lens_tensor,
                block_size,
                max_seq_len,
                self.alibi_slopes,
                self.kv_cache_dtype,
                layer._k_scale_float,
                layer._v_scale_float,
            )

            if use_v1:
                # Run PagedAttention V1.
                ipex_ops.paged_attention_v1(output, *paged_args)
            else:
                # Run PagedAttention V2.
                assert _PARTITION_SIZE % block_size == 0
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                ipex_ops.paged_attention_v2(output, exp_sums, max_logits,
                                            tmp_output, *paged_args)

        # Reshape the output tensor.
        return output.view(-1, self.num_heads * self.head_size)
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
,
device
=
alibi_slopes
.
device
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
,
device
=
alibi_slopes
.
device
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
return
attn_biases
def
_make_sliding_window_bias
(
seq_lens
:
List
[
int
],
window_size
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
for
seq_len
in
seq_lens
:
tensor
=
torch
.
full
(
(
1
,
seq_len
,
seq_len
),
dtype
=
dtype
,
fill_value
=
1
,
)
shift
=
0
mask
=
torch
.
tril
(
tensor
,
diagonal
=
shift
).
to
(
dtype
)
# type: ignore
if
window_size
is
not
None
:
mask
=
torch
.
triu
(
mask
,
diagonal
=
shift
-
window_size
+
1
)
mask
=
torch
.
log
(
mask
)
attn_biases
.
append
(
mask
.
to
(
dtype
))
return
attn_biases
vllm/v1/attention/mla/common.py
deleted
100644 → 0
View file @
b8025f24
This diff is collapsed.
Click to expand it.
vllm/v1/attention/ops/common.py
View file @
8d5cdd53
...
@@ -467,208 +467,3 @@ def unpack_seq_triton(
...
@@ -467,208 +467,3 @@ def unpack_seq_triton(
out
=
out
.
reshape
(
output_shape
)
out
=
out
.
reshape
(
output_shape
)
return
out
return
out
@triton.jit
def _pack_seq_kernel(
        x_ptr,  # [N, D]
        out_ptr,  # [B, Lmax, D]
        lengths_ptr,  # *i32, [B]
        N: tl.constexpr,
        D: tl.constexpr,
        Lmax: tl.constexpr,
        PAD_VALUE: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    """Copy variable-length rows of ``x`` into a padded ``[B, Lmax, D]`` buffer.

    Each program handles one (batch, time-block, feature-block) tile. It
    first writes ``PAD_VALUE`` to every tile position with ``t < Lmax`` and
    ``d < D``, then overwrites the positions with ``t < seq_len`` using the
    matching rows of ``x``. ``N`` is accepted but not referenced in the body.
    """
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Compute start index and sequence length from cumulative lengths
    # (the start is the sum of the lengths of all earlier batch entries).
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax

    # compute input row indices for valid (b, t)
    in_row = in_start + off_t
    valid_row = (off_t < seq_len) & t_mask

    # Pointers
    # x_ptr: row-major [N, D]
    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]

    # out_ptr: row-major [B, Lmax, D]
    out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :]

    # Initialize with PAD (cast will occur as needed based on out_ptr dtype)
    d_mask = off_d[None, :] < D
    pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)

    # Load & write only where within seq_len; this second store must follow
    # the pad store so real data replaces the padding.
    x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
def pack_seq_triton(x: torch.Tensor,
                    lengths: torch.Tensor,
                    pad_value: float = -float('inf'),
                    block_t: int = 64,
                    block_d: int = 64) -> torch.Tensor:
    """
    Pack sequences of different lengths into a batched, padded tensor.

    Args:
        x: [N, ...] - input tensor where N is total number of tokens
        lengths: [B] - sequence lengths for each batch
        pad_value: value written to positions past each sequence's length
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        packed: [B, Lmax, ...] - packed tensor, where Lmax = lengths.max()
    """
    in_shape = x.shape
    has_extra_dims = len(in_shape) > 2

    # The kernel works on 2-D input, so trailing dimensions are flattened.
    if has_extra_dims:
        num_tokens = in_shape[0]
        x_2d = x.reshape(num_tokens, -1)
        feat_dim = x_2d.shape[1]
    else:
        num_tokens, feat_dim = in_shape
        x_2d = x

    batch = lengths.numel()
    max_len = int(lengths.max().item())

    # Starts are computed inside the kernel from lengths
    packed = torch.empty((batch, max_len, feat_dim),
                         device=x.device,
                         dtype=x.dtype)

    # One program per (batch entry, time block, feature block).
    launch_grid = (
        batch,
        triton.cdiv(max_len, block_t),
        triton.cdiv(feat_dim, block_d),
    )
    _pack_seq_kernel[launch_grid](x_2d,
                                  packed,
                                  lengths.int(),
                                  num_tokens,
                                  feat_dim,
                                  max_len,
                                  PAD_VALUE=float(pad_value),
                                  BLOCK_T=block_t,
                                  BLOCK_D=block_d,
                                  num_warps=4,
                                  num_stages=2)

    if not has_extra_dims:
        return packed
    # Restore the trailing dimensions that were flattened above.
    return packed.reshape((batch, max_len) + in_shape[1:])
@triton.jit
def _unpack_seq_triton_kernel(
        packed_ptr,  # [B, Lmax, D]
        out_ptr,  # [N, D]
        lengths_ptr,  # *i32, [B]
        B: tl.constexpr,
        Lmax: tl.constexpr,
        D: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    """Copy the valid rows of a padded ``[B, Lmax, D]`` buffer into ``[N, D]``.

    Each program handles one (batch, time-block, feature-block) tile and
    copies only positions with ``t < seq_len`` (and ``t < Lmax``, ``d < D``);
    padded positions are neither read nor written. ``B`` is accepted but not
    referenced in the body.
    """
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # bounds: compute start from cumulative lengths
    # (the start is the sum of the lengths of all earlier batch entries).
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax
    valid_row = (off_t < seq_len) & t_mask

    # compute output row indices for valid (b, t)
    out_row = in_start + off_t

    # Pointers
    # packed_ptr: row-major [B, Lmax, D]
    packed_row_ptr = packed_ptr + (pid_b * Lmax +
                                   off_t)[:, None] * D + off_d[None, :]

    # out_ptr: row-major [N, D]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # Load from packed tensor and store to output
    d_mask = off_d[None, :] < D
    packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
def unpack_seq_triton(packed_tensor: torch.Tensor,
                      lengths: torch.Tensor,
                      block_t: int = 64,
                      block_d: int = 64) -> torch.Tensor:
    """
    Unpack a padded, batched tensor back into a flat token-major tensor.

    Args:
        packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
        lengths: [B] - sequence lengths for each batch
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        unpacked_tensor: [N, ...] where N = sum(lengths)
    """
    packed_shape = packed_tensor.shape
    has_extra_dims = len(packed_shape) > 3

    # The kernel works on 3-D input, so trailing dimensions are flattened.
    if has_extra_dims:
        batch, max_len = packed_shape[:2]
        packed_3d = packed_tensor.reshape(batch, max_len, -1)
        feat_dim = packed_3d.shape[2]
    else:
        batch, max_len, feat_dim = packed_shape
        packed_3d = packed_tensor

    # Total number of valid tokens across the batch.
    total_tokens = int(lengths.sum().item())

    flat = torch.empty((total_tokens, feat_dim),
                       device=packed_tensor.device,
                       dtype=packed_tensor.dtype)

    # One program per (batch entry, time block, feature block).
    launch_grid = (
        batch,
        triton.cdiv(max_len, block_t),
        triton.cdiv(feat_dim, block_d),
    )
    _unpack_seq_triton_kernel[launch_grid](packed_3d,
                                           flat,
                                           lengths.int(),
                                           batch,
                                           max_len,
                                           feat_dim,
                                           BLOCK_T=block_t,
                                           BLOCK_D=block_d,
                                           num_warps=4,
                                           num_stages=2)

    if not has_extra_dims:
        return flat
    # Restore the trailing dimensions that were flattened above.
    return flat.reshape((total_tokens, ) + packed_shape[2:])
vllm/v1/attention/pallas.py
deleted
100644 → 0
View file @
b8025f24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor that wires up the Pallas attention classes."""

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        """Return the attention metadata class."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the cache shape, with the KV-head dimension first."""
        shape = (num_kv_heads, num_blocks, block_size, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Always raise: block swapping is not used by this backend."""
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy cache blocks in place along dimension 1 of every cache.

        ``src_to_dists`` is a ``(source indices, destination indices)`` pair;
        for each key and value cache the source blocks are written to the
        destination blocks.
        """
        sources, targets = src_to_dists
        for k_cache, v_cache in kv_caches:
            # Key cache first, then value cache, each marked as a buffer
            # donor before it is updated.
            for cache in (k_cache, v_cache):
                torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
                cache[:, targets] = cache[:, sources]
@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata for the Pallas backend.

    Currently, input sequences can only contain all prefills or all decoding.
    """

    block_tables: Optional[torch.Tensor] = None
    context_lens: Optional[torch.Tensor] = None
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, ``None`` when no prefills."""
        if self.num_prefills != 0:
            # Prefill batches must not carry decode tokens.
            assert self.num_decode_tokens == 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, ``None`` when no decode tokens."""
        if self.num_decode_tokens != 0:
            # Decode batches must be prefill-free and carry paging info.
            assert self.num_prefills == 0
            assert self.num_prefill_tokens == 0
            assert self.block_tables is not None
            assert self.context_lens is not None
            return self
        return None
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to ``torch.ops.xla`` kernels.

    ``forward`` uses ``flash_attention`` for prefill without block tables,
    ``multi_queries_paged_attention`` for prefill with block tables, and the
    module-level ``paged_attention`` wrapper for decoding.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Validate the layer configuration and choose the megacore mode.

        Raises:
            NotImplementedError: for KV sharing, a head size that is not a
                multiple of 128, ALiBi slopes, a sliding window, a KV cache
                dtype for which ``is_quantized_kv_cache`` is true,
                block-sparse parameters, a TPU version below 4, or an
                ``attn_type`` other than ``AttentionType.DECODER``.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            logger.warning_once(
                "Using irope in Pallas is not supported yet, it will fall back "
                "to global attention for long context.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.logits_soft_cap = logits_soft_cap
        if head_size % 128 != 0:
            raise NotImplementedError(
                f"Head size must be a multiple of 128, found {head_size}.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        # The accelerator type is read from the first of these keys that
        # holds a truthy value.
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        # Megacore mode is enabled only when the type string contains
        # neither "lite" nor "v6".
        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        Raises:
            NotImplementedError: if ``output_scale`` is given.
        """
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionImpl")

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        # Split the flattened hidden dimension into (heads, head_size).
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Skipped when the cache is empty (the profiling run, per the
        # docstring note); key_cache/value_cache are only bound here.
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        # The softmax scale is applied to the query up front.
        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                # Undo the permutation on the kernel output.
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            # NOTE(review): the factor 4 is presumably bytes per block-table
            # entry - confirm against the block table dtype.
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                    attn_logits_soft_cap=self.logits_soft_cap,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                # Ceiling division: number of chunks covering the batch.
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Write new keys and values into the caches at the given slots.

    Both caches are first marked as buffer donors. The first three
    dimensions of every tensor are then flattened and the flattened
    key/value rows are copied into the flattened caches at the row indices
    in ``slot_mapping``.
    """
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    flat_key = key.flatten(0, 2)
    flat_value = value.flatten(0, 2)
    flat_key_cache = key_cache.flatten(0, 2)
    flat_value_cache = value_cache.flatten(0, 2)

    # Row i of the new data lands in cache row slot_mapping[i].
    flat_key_cache.index_copy_(0, slot_mapping, flat_key)
    flat_value_cache.index_copy_(0, slot_mapping, flat_value)
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    """Call ``torch.ops.xla.paged_attention`` with an adjusted megacore mode.

    ``megacore_mode`` is forwarded unchanged, except that ``"batch"`` is
    replaced by ``None`` when the batch size (``query.shape[0]``) is odd.
    """
    num_seqs = query.shape[0]
    batch_mode_unusable = megacore_mode == "batch" and num_seqs % 2 != 0
    effective_mode = None if batch_mode_unusable else megacore_mode

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=effective_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )
vllm/v1/attention/rocm_flash_attn.py
deleted
100644 → 0
View file @
b8025f24
This diff is collapsed.
Click to expand it.
vllm/v1/attention/torch_sdpa.py
deleted
100644 → 0
View file @
b8025f24
This diff is collapsed.
Click to expand it.
vllm/v1/spec_decode/eagle.py
View file @
8d5cdd53
...
@@ -47,7 +47,6 @@ from vllm.v1.spec_decode.utils import (
...
@@ -47,7 +47,6 @@ from vllm.v1.spec_decode.utils import (
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.ubatching
import
dbo_current_ubatch_id
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment