Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
76e4dcf2
Unverified
Commit
76e4dcf2
authored
Nov 11, 2025
by
Lukas Geiger
Committed by
GitHub
Nov 11, 2025
Browse files
[Misc] Remove unused attention prefix prefill ops functions (#26971)
Signed-off-by:
Lukas Geiger
<
lukas.geiger94@gmail.com
>
parent
d5edcb86
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
0 additions
and
213 deletions
+0
-213
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+0
-210
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+0
-3
No files found.
vllm/attention/ops/prefix_prefill.py
View file @
76e4dcf2
...
...
@@ -335,216 +335,6 @@ def _fwd_kernel(
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
    Q,
    K,
    V,
    K_cache,
    V_cache,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    block_size,
    x,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    stride_k_cache_bs,
    stride_k_cache_h,
    stride_k_cache_d,
    stride_k_cache_bl,
    stride_k_cache_x,
    stride_v_cache_bs,
    stride_v_cache_h,
    stride_v_cache_d,
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Attention forward pass for one (batch, head, query-tile) program.
    #
    # Each program attends a BLOCK_M-row tile of Q against two key/value
    # sources, in this order:
    #   1. the first `cur_batch_ctx_len` positions, read from K_cache/V_cache
    #      through the block table B_Loc;
    #   2. the K/V tensors themselves, under a causal mask.
    # Softmax statistics are kept as a running row maximum (m_i) and a running
    # row sum (l_i); the accumulator is rescaled by exp(m_old - m_new) whenever
    # the maximum changes.
    #
    # Per-batch scalars are read from B_Start_Loc (row offset of this batch in
    # Q/K/V/Out), B_Seqlen and B_Ctxlen. Only rows with index below
    # (B_Seqlen - B_Ctxlen) are loaded from Q and stored to Out.
    #
    # NOTE(review): the final division by l_i is commented out near the end of
    # the kernel, so the value written to Out looks like the unnormalized
    # softmax-weighted sum. Confirm whether a caller normalizes afterwards
    # before reusing this kernel.
    # NOTE(review): `x` is used as offs_d // x and offs_d % x when indexing
    # K_cache; presumably the head dimension is split into groups of x
    # elements in that cache layout - confirm against the cache allocator.

    # Grid axes: 0 -> batch entry, 1 -> query head, 2 -> query tile.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    # Map the query head to the KV head index used for K/V and both caches.
    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

    # First query row handled by this program (relative to the batch entry).
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :] * stride_qd
    )

    # Rows past the end of this batch entry's query range are read as 0.0.
    q = tl.load(
        Q + off_q,
        mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
        other=0.0,
    )

    # initialize running max (m_i), running sum (l_i) and output accumulator
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # Phase 1: attend to the cached context, BLOCK_N positions at a time.
    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        # Look up the cache block number for each context position; positions
        # at or past cur_batch_ctx_len fall back to block 0 (they are masked
        # again in the K/V loads below).
        bn = tl.load(
            B_Loc
            + cur_batch * stride_b_loc_b
            + ((start_n + offs_n) // block_size) * stride_b_loc_s,
            mask=(start_n + offs_n) < cur_batch_ctx_len,
            other=0,
        ).to(tl.int64)
        # K is addressed as [BLOCK_DMODEL, BLOCK_N] (d along rows) so it can be
        # used directly as the right-hand operand of tl.dot(q, k).
        off_k = (
            bn[None, :] * stride_k_cache_bs
            + cur_kv_head * stride_k_cache_h
            + (offs_d[:, None] // x) * stride_k_cache_d
            + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl
            + (offs_d[:, None] % x) * stride_k_cache_x
        )
        # V is addressed as [BLOCK_N, BLOCK_DMODEL]. `%` and `*` share
        # precedence and associate left to right, so the last term is
        # (position % block_size) * stride_v_cache_bl.
        off_v = (
            bn[:, None] * stride_v_cache_bs
            + cur_kv_head * stride_v_cache_h
            + offs_d[None, :] * stride_v_cache_d
            + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl
        )
        k = tl.load(
            K_cache + off_k,
            mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
            other=0.0,
        )

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        # Columns past the context length get -inf so exp() below yields 0.
        qk = tl.where(
            (start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")
        )
        qk *= sm_scale

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        # alpha rescales the previous statistics to the new running maximum.
        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            V_cache + off_v,
            mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
            other=0.0,
        )

        # Cast the probabilities to V's dtype before the second matmul.
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # Phase 2 addressing: offsets into K/V, relative to the start of this
    # batch entry (the per-iteration base is added inside the loop).
    off_k = (
        offs_n[None, :] * stride_kbs
        + cur_kv_head * stride_kh
        + offs_d[:, None] * stride_kd
    )
    off_v = (
        offs_n[:, None] * stride_vbs
        + cur_kv_head * stride_vh
        + offs_d[None, :] * stride_vd
    )
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # 0 when this query tile starts past the end of the query range, which
    # makes the loop bound below 0 and skips phase 2 entirely.
    block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    # Phase 2: attend to K/V up to the end of this query tile (causal bound).
    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(
            k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0,
        )

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # Causal mask: query row i only sees key columns j with j <= i.
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        m_i_new = tl.maximum(m_i, m_ij)
        p = tl.math.exp(qk - m_i_new[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i

        alpha = tl.math.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + l_ij
        # -- update output accumulator --
        # scale p
        # scale acc
        acc_scale = alpha
        # acc_scale = l_i / l_i_new * alpha
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(
            v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len,
            other=0.0,
        )

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # NOTE(review): normalization by the running sum is disabled here; l_i is
    # maintained above but never applied to acc.
    # acc /= l_i[:, None]

    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
        + cur_head * stride_oh
        + offs_d[None, :] * stride_od
    )
    out_ptrs = Out + off_o
    # Only rows inside this batch entry's query range are written.
    tl.store(
        out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len
    )
    return
@
triton
.
jit
def
_fwd_kernel_alibi
(
Q
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
76e4dcf2
...
...
@@ -98,9 +98,6 @@ __all__ = [
class
CompressedTensorsMoEMethod
(
FusedMoEMethodBase
):
def __init_(self, moe: FusedMoEConfig):
    # Forwards `moe` to the parent class initializer.
    #
    # NOTE(review): the method name is `__init_` (a single trailing
    # underscore), not `__init__`, so Python does not invoke it when the class
    # is instantiated; it only runs if called explicitly by that name. It
    # looks like a misspelled constructor - confirm whether it should be
    # renamed or removed.
    super().__init__(moe)
@
staticmethod
def
get_moe_method
(
quant_config
:
"CompressedTensorsConfig"
,
# type: ignore # noqa E501
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment