Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4eefbf96
Unverified
Commit
4eefbf96
authored
Apr 02, 2026
by
Jiangyun Zhu
Committed by
GitHub
Apr 02, 2026
Browse files
[Perf] fuse kernels in gdn (#37813)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
551b3fb3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
495 additions
and
10 deletions
+495
-10
tests/kernels/test_fused_gdn_post_conv.py
tests/kernels/test_fused_gdn_post_conv.py
+209
-0
vllm/model_executor/layers/fla/ops/__init__.py
vllm/model_executor/layers/fla/ops/__init__.py
+2
-0
vllm/model_executor/layers/fla/ops/fused_gdn_prefill_post_conv.py
...el_executor/layers/fla/ops/fused_gdn_prefill_post_conv.py
+248
-0
vllm/model_executor/layers/mamba/gdn_linear_attn.py
vllm/model_executor/layers/mamba/gdn_linear_attn.py
+36
-10
No files found.
tests/kernels/test_fused_gdn_post_conv.py
0 → 100644
View file @
4eefbf96
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for fused_gdn_prefill_post_conv kernel.
Verifies that the fused kernel matches the reference:
split → rearrange → contiguous → l2norm → gating
"""
import
pytest
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor.layers.fla.ops.fused_gdn_prefill_post_conv
import
(
fused_post_conv_prep
,
)
def
reference_post_conv
(
conv_output
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
A_log
:
torch
.
Tensor
,
dt_bias
:
torch
.
Tensor
,
H
:
int
,
K
:
int
,
V
:
int
,
apply_l2norm
:
bool
=
True
,
output_g_exp
:
bool
=
False
,
):
"""Reference implementation using individual ops."""
L
=
conv_output
.
shape
[
0
]
HV
=
A_log
.
shape
[
0
]
# Split
q_flat
,
k_flat
,
v_flat
=
torch
.
split
(
conv_output
,
[
H
*
K
,
H
*
K
,
HV
*
V
],
dim
=-
1
)
# Rearrange + contiguous
q
=
q_flat
.
view
(
L
,
H
,
K
).
contiguous
()
k
=
k_flat
.
view
(
L
,
H
,
K
).
contiguous
()
v
=
v_flat
.
view
(
L
,
HV
,
V
).
contiguous
()
# L2 norm
if
apply_l2norm
:
q
=
F
.
normalize
(
q
.
float
(),
p
=
2
,
dim
=-
1
,
eps
=
1e-6
).
to
(
conv_output
.
dtype
)
k
=
F
.
normalize
(
k
.
float
(),
p
=
2
,
dim
=-
1
,
eps
=
1e-6
).
to
(
conv_output
.
dtype
)
# Gating
x
=
a
.
float
()
+
dt_bias
.
float
()
sp
=
F
.
softplus
(
x
,
beta
=
1.0
,
threshold
=
20.0
)
g
=
-
torch
.
exp
(
A_log
.
float
())
*
sp
if
output_g_exp
:
g
=
torch
.
exp
(
g
)
beta_out
=
torch
.
sigmoid
(
b
.
float
())
return
q
,
k
,
v
,
g
,
beta_out
# Qwen3.5-35B config: H=16, HV=32, K=128, V=128
# Qwen3.5-397B config: H=16, HV=64, K=128, V=128
@
pytest
.
mark
.
parametrize
(
"H, HV, K, V"
,
[
(
16
,
32
,
128
,
128
),
# 35B
(
16
,
64
,
128
,
128
),
# 397B
(
4
,
8
,
64
,
64
),
# small
],
)
@
pytest
.
mark
.
parametrize
(
"L"
,
[
1
,
16
,
128
,
512
,
2048
])
@
pytest
.
mark
.
parametrize
(
"apply_l2norm"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"output_g_exp"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
def
test_fused_post_conv_correctness
(
H
,
HV
,
K
,
V
,
L
,
apply_l2norm
,
output_g_exp
,
dtype
):
"""Test fused kernel matches reference for all configs."""
torch
.
manual_seed
(
42
)
device
=
"cuda"
qkv_dim
=
2
*
H
*
K
+
HV
*
V
conv_output
=
torch
.
randn
(
L
,
qkv_dim
,
dtype
=
dtype
,
device
=
device
)
a
=
torch
.
randn
(
L
,
HV
,
dtype
=
dtype
,
device
=
device
)
b
=
torch
.
randn
(
L
,
HV
,
dtype
=
dtype
,
device
=
device
)
A_log
=
torch
.
randn
(
HV
,
dtype
=
torch
.
float32
,
device
=
device
)
-
2.0
dt_bias
=
torch
.
randn
(
HV
,
dtype
=
torch
.
float32
,
device
=
device
)
*
0.1
# Reference
ref_q
,
ref_k
,
ref_v
,
ref_g
,
ref_beta
=
reference_post_conv
(
conv_output
,
a
,
b
,
A_log
,
dt_bias
,
H
,
K
,
V
,
apply_l2norm
,
output_g_exp
,
)
# Fused kernel
fused_q
,
fused_k
,
fused_v
,
fused_g
,
fused_beta
=
fused_post_conv_prep
(
conv_output
,
a
,
b
,
A_log
,
dt_bias
,
num_k_heads
=
H
,
head_k_dim
=
K
,
head_v_dim
=
V
,
apply_l2norm
=
apply_l2norm
,
output_g_exp
=
output_g_exp
,
)
# Check shapes
assert
fused_q
.
shape
==
(
L
,
H
,
K
),
f
"q shape:
{
fused_q
.
shape
}
"
assert
fused_k
.
shape
==
(
L
,
H
,
K
),
f
"k shape:
{
fused_k
.
shape
}
"
assert
fused_v
.
shape
==
(
L
,
HV
,
V
),
f
"v shape:
{
fused_v
.
shape
}
"
assert
fused_g
.
shape
==
(
L
,
HV
),
f
"g shape:
{
fused_g
.
shape
}
"
assert
fused_beta
.
shape
==
(
L
,
HV
),
f
"beta shape:
{
fused_beta
.
shape
}
"
# Check dtypes
assert
fused_q
.
dtype
==
dtype
assert
fused_k
.
dtype
==
dtype
assert
fused_v
.
dtype
==
dtype
assert
fused_g
.
dtype
==
torch
.
float32
assert
fused_beta
.
dtype
==
torch
.
float32
# Check contiguity
assert
fused_q
.
is_contiguous
()
assert
fused_k
.
is_contiguous
()
assert
fused_v
.
is_contiguous
()
# Check values
atol_qkv
=
1e-2
if
apply_l2norm
else
1e-3
rtol_qkv
=
1e-2
if
apply_l2norm
else
1e-3
torch
.
testing
.
assert_close
(
fused_q
,
ref_q
,
atol
=
atol_qkv
,
rtol
=
rtol_qkv
)
torch
.
testing
.
assert_close
(
fused_k
,
ref_k
,
atol
=
atol_qkv
,
rtol
=
rtol_qkv
)
torch
.
testing
.
assert_close
(
fused_v
,
ref_v
,
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_close
(
fused_g
,
ref_g
,
atol
=
1e-4
,
rtol
=
1e-4
)
torch
.
testing
.
assert_close
(
fused_beta
,
ref_beta
,
atol
=
1e-4
,
rtol
=
1e-4
)
@
pytest
.
mark
.
parametrize
(
"L"
,
[
1
,
64
,
256
])
def
test_fused_post_conv_sanity
(
L
):
"""Sanity checks: no NaN, unit-norm q/k, beta in (0,1)."""
torch
.
manual_seed
(
0
)
device
=
"cuda"
H
,
HV
,
K
,
V
=
16
,
32
,
128
,
128
qkv_dim
=
2
*
H
*
K
+
HV
*
V
conv_output
=
torch
.
randn
(
L
,
qkv_dim
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
a
=
torch
.
randn
(
L
,
HV
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
b
=
torch
.
randn
(
L
,
HV
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
A_log
=
torch
.
randn
(
HV
,
dtype
=
torch
.
float32
,
device
=
device
)
-
2.0
dt_bias
=
torch
.
randn
(
HV
,
dtype
=
torch
.
float32
,
device
=
device
)
q
,
k
,
v
,
g
,
beta
=
fused_post_conv_prep
(
conv_output
,
a
,
b
,
A_log
,
dt_bias
,
num_k_heads
=
H
,
head_k_dim
=
K
,
head_v_dim
=
V
,
)
# Basic sanity
assert
not
torch
.
isnan
(
q
).
any
(),
"NaN in q"
assert
not
torch
.
isnan
(
k
).
any
(),
"NaN in k"
assert
not
torch
.
isnan
(
v
).
any
(),
"NaN in v"
assert
not
torch
.
isnan
(
g
).
any
(),
"NaN in g"
assert
not
torch
.
isnan
(
beta
).
any
(),
"NaN in beta"
# L2 norm check: each head vector should have unit norm
q_norms
=
torch
.
norm
(
q
.
float
(),
dim
=-
1
)
k_norms
=
torch
.
norm
(
k
.
float
(),
dim
=-
1
)
torch
.
testing
.
assert_close
(
q_norms
,
torch
.
ones_like
(
q_norms
),
atol
=
1e-3
,
rtol
=
1e-3
)
torch
.
testing
.
assert_close
(
k_norms
,
torch
.
ones_like
(
k_norms
),
atol
=
1e-3
,
rtol
=
1e-3
)
# Beta should be in (0, 1)
assert
(
beta
>=
0
).
all
()
and
(
beta
<=
1
).
all
(),
"beta out of range"
def
test_fused_post_conv_l0
():
"""Test L=0 edge case."""
device
=
"cuda"
H
,
HV
,
K
,
V
=
16
,
32
,
128
,
128
qkv_dim
=
2
*
H
*
K
+
HV
*
V
conv_output
=
torch
.
empty
(
0
,
qkv_dim
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
a
=
torch
.
empty
(
0
,
HV
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
b
=
torch
.
empty
(
0
,
HV
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
A_log
=
torch
.
randn
(
HV
,
dtype
=
torch
.
float32
,
device
=
device
)
dt_bias
=
torch
.
randn
(
HV
,
dtype
=
torch
.
float32
,
device
=
device
)
q
,
k
,
v
,
g
,
beta
=
fused_post_conv_prep
(
conv_output
,
a
,
b
,
A_log
,
dt_bias
,
num_k_heads
=
H
,
head_k_dim
=
K
,
head_v_dim
=
V
,
)
assert
q
.
shape
==
(
0
,
H
,
K
)
assert
g
.
shape
==
(
0
,
HV
)
vllm/model_executor/layers/fla/ops/__init__.py
View file @
4eefbf96
...
...
@@ -7,6 +7,7 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from
.chunk
import
chunk_gated_delta_rule
from
.fused_gdn_prefill_post_conv
import
fused_post_conv_prep
from
.fused_recurrent
import
(
fused_recurrent_gated_delta_rule
,
fused_recurrent_gated_delta_rule_packed_decode
,
...
...
@@ -19,5 +20,6 @@ __all__ = [
"chunk_gated_delta_rule"
,
"fused_recurrent_gated_delta_rule"
,
"fused_recurrent_gated_delta_rule_packed_decode"
,
"fused_post_conv_prep"
,
"fused_sigmoid_gating_delta_rule_update"
,
]
vllm/model_executor/layers/fla/ops/fused_gdn_prefill_post_conv.py
0 → 100644
View file @
4eefbf96
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused post-conv1d preparation for GDN prefill.
Replaces the chain:
split → rearrange → contiguous * 3 → l2norm * 2 → gating
with a **single Triton kernel** that reads the conv'd mixed_qkv output
and writes directly to q/k/v/g/beta in the target contiguous layout.
"""
from
__future__
import
annotations
import
torch
from
vllm.triton_utils
import
tl
,
triton
@
triton
.
jit
def
_fused_post_conv_kernel
(
# ---- inputs ----
mixed_qkv_ptr
,
# [L, qkv_dim] conv'd output (contiguous)
a_ptr
,
# [L, HV]
b_ptr
,
# [L, HV]
# ---- params ----
A_log_ptr
,
# [HV]
dt_bias_ptr
,
# [HV]
# ---- outputs ----
q_ptr
,
# [L, H, K] contiguous
k_ptr
,
# [L, H, K] contiguous
v_ptr
,
# [L, HV, V] contiguous
g_ptr
,
# [L, HV] float32
beta_ptr
,
# [L, HV] float32
# ---- strides ----
stride_x_tok
,
# qkv_dim
stride_a_tok
,
# HV
stride_b_tok
,
# HV
stride_q_tok
,
# H * K
stride_k_tok
,
# H * K
stride_v_tok
,
# HV * V
# ---- dims ----
L
,
H
:
tl
.
constexpr
,
HV
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
APPLY_L2NORM
:
tl
.
constexpr
,
L2NORM_EPS
:
tl
.
constexpr
,
OUTPUT_G_EXP
:
tl
.
constexpr
,
SOFTPLUS_THRESHOLD
:
tl
.
constexpr
,
BLOCK_T
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
):
"""Single fused kernel for post-conv1d preparation.
Grid: (ceil(L, BLOCK_T), H + HV)
- program_id(1) in [0, H): Q/K head processing + l2norm
- program_id(1) in [H, H+HV): V head processing + gating
"""
i_tb
=
tl
.
program_id
(
0
)
i_head
=
tl
.
program_id
(
1
)
HK
:
tl
.
constexpr
=
H
*
K
offs_t
=
i_tb
*
BLOCK_T
+
tl
.
arange
(
0
,
BLOCK_T
)
# [BLOCK_T]
mask_t
=
offs_t
<
L
if
i_head
<
H
:
# ============ Q/K head processing ============
i_h
=
i_head
offs_k
=
tl
.
arange
(
0
,
BK
)
# [BK]
mask_k
=
offs_k
<
K
mask_2d
=
mask_t
[:,
None
]
&
mask_k
[
None
,
:]
# [BLOCK_T, BK]
# Load Q features: mixed_qkv[t, i_h*K + k]
q_offsets
=
offs_t
[:,
None
]
*
stride_x_tok
+
i_h
*
K
+
offs_k
[
None
,
:]
q_f32
=
tl
.
load
(
mixed_qkv_ptr
+
q_offsets
,
mask
=
mask_2d
,
other
=
0
).
to
(
tl
.
float32
)
# Load K features: mixed_qkv[t, HK + i_h*K + k]
k_offsets
=
offs_t
[:,
None
]
*
stride_x_tok
+
HK
+
i_h
*
K
+
offs_k
[
None
,
:]
k_f32
=
tl
.
load
(
mixed_qkv_ptr
+
k_offsets
,
mask
=
mask_2d
,
other
=
0
).
to
(
tl
.
float32
)
if
APPLY_L2NORM
:
q_sq_sum
=
tl
.
sum
(
q_f32
*
q_f32
,
axis
=
1
)
# [BLOCK_T]
q_inv
=
1.0
/
tl
.
sqrt
(
q_sq_sum
+
L2NORM_EPS
)
q_f32
=
q_f32
*
q_inv
[:,
None
]
k_sq_sum
=
tl
.
sum
(
k_f32
*
k_f32
,
axis
=
1
)
k_inv
=
1.0
/
tl
.
sqrt
(
k_sq_sum
+
L2NORM_EPS
)
k_f32
=
k_f32
*
k_inv
[:,
None
]
# Store Q
q_out
=
offs_t
[:,
None
]
*
stride_q_tok
+
i_h
*
K
+
offs_k
[
None
,
:]
tl
.
store
(
q_ptr
+
q_out
,
q_f32
.
to
(
q_ptr
.
dtype
.
element_ty
),
mask
=
mask_2d
,
)
# Store K
k_out
=
offs_t
[:,
None
]
*
stride_k_tok
+
i_h
*
K
+
offs_k
[
None
,
:]
tl
.
store
(
k_ptr
+
k_out
,
k_f32
.
to
(
k_ptr
.
dtype
.
element_ty
),
mask
=
mask_2d
,
)
else
:
# ============ V head + gating processing ============
i_hv
=
i_head
-
H
offs_v
=
tl
.
arange
(
0
,
BV
)
# [BV]
mask_v
=
offs_v
<
V
mask_2d
=
mask_t
[:,
None
]
&
mask_v
[
None
,
:]
# [BLOCK_T, BV]
V_OFFSET
:
tl
.
constexpr
=
2
*
H
*
K
# Load V features: mixed_qkv[t, 2*H*K + i_hv*V + v]
v_offsets
=
(
offs_t
[:,
None
]
*
stride_x_tok
+
V_OFFSET
+
i_hv
*
V
+
offs_v
[
None
,
:]
)
v_vals
=
tl
.
load
(
mixed_qkv_ptr
+
v_offsets
,
mask
=
mask_2d
,
other
=
0
)
# Store V
v_out
=
offs_t
[:,
None
]
*
stride_v_tok
+
i_hv
*
V
+
offs_v
[
None
,
:]
tl
.
store
(
v_ptr
+
v_out
,
v_vals
,
mask
=
mask_2d
)
# Gating: one scalar per (token, v-head)
A_log_val
=
tl
.
load
(
A_log_ptr
+
i_hv
).
to
(
tl
.
float32
)
dt_bias_val
=
tl
.
load
(
dt_bias_ptr
+
i_hv
).
to
(
tl
.
float32
)
a_offsets
=
offs_t
*
stride_a_tok
+
i_hv
b_offsets
=
offs_t
*
stride_b_tok
+
i_hv
a_vals
=
tl
.
load
(
a_ptr
+
a_offsets
,
mask
=
mask_t
,
other
=
0
).
to
(
tl
.
float32
)
b_vals
=
tl
.
load
(
b_ptr
+
b_offsets
,
mask
=
mask_t
,
other
=
0
).
to
(
tl
.
float32
)
# g = -exp(A_log) * softplus(a + dt_bias)
x
=
a_vals
+
dt_bias_val
sp
=
tl
.
where
(
x
>
0
,
x
+
tl
.
log
(
1.0
+
tl
.
exp
(
-
x
)),
tl
.
log
(
1.0
+
tl
.
exp
(
x
)))
sp
=
tl
.
where
(
x
<=
SOFTPLUS_THRESHOLD
,
sp
,
x
)
g_vals
=
-
tl
.
exp
(
A_log_val
)
*
sp
if
OUTPUT_G_EXP
:
g_vals
=
tl
.
exp
(
g_vals
)
beta_vals
=
tl
.
sigmoid
(
b_vals
)
gb_offsets
=
offs_t
*
HV
+
i_hv
tl
.
store
(
g_ptr
+
gb_offsets
,
g_vals
,
mask
=
mask_t
)
tl
.
store
(
beta_ptr
+
gb_offsets
,
beta_vals
,
mask
=
mask_t
)
def
fused_post_conv_prep
(
conv_output
:
torch
.
Tensor
,
# [L, qkv_dim] conv'd mixed_qkv
a
:
torch
.
Tensor
,
# [L, HV]
b
:
torch
.
Tensor
,
# [L, HV]
A_log
:
torch
.
Tensor
,
# [HV]
dt_bias
:
torch
.
Tensor
,
# [HV]
num_k_heads
:
int
,
head_k_dim
:
int
,
head_v_dim
:
int
,
apply_l2norm
:
bool
=
True
,
output_g_exp
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Fused post-conv1d prep: split + l2norm + gating in one kernel.
Args:
conv_output: [L, qkv_dim] contiguous conv'd mixed_qkv
a: [L, HV] gating input
b: [L, HV] gating input
A_log: [HV] log decay parameter
dt_bias: [HV] dt bias parameter
num_k_heads: number of K heads (H)
head_k_dim: dimension per K head (K)
head_v_dim: dimension per V head (V)
apply_l2norm: whether to L2-normalize q and k
output_g_exp: if True, output exp(g) instead of g (for FlashInfer)
Returns:
q: [L, H, K] contiguous, optionally l2-normalized
k: [L, H, K] contiguous, optionally l2-normalized
v: [L, HV, V] contiguous
g: [L, HV] float32
beta: [L, HV] float32
"""
L
=
conv_output
.
shape
[
0
]
qkv_dim
=
conv_output
.
shape
[
1
]
H
=
num_k_heads
K
=
head_k_dim
V
=
head_v_dim
HV
=
A_log
.
shape
[
0
]
dtype
=
conv_output
.
dtype
device
=
conv_output
.
device
assert
qkv_dim
==
2
*
H
*
K
+
HV
*
V
,
(
f
"qkv_dim=
{
qkv_dim
}
!= 2*H*K + HV*V =
{
2
*
H
*
K
+
HV
*
V
}
"
)
# Allocate outputs in target contiguous layout
q
=
torch
.
empty
(
L
,
H
,
K
,
dtype
=
dtype
,
device
=
device
)
k
=
torch
.
empty
(
L
,
H
,
K
,
dtype
=
dtype
,
device
=
device
)
v
=
torch
.
empty
(
L
,
HV
,
V
,
dtype
=
dtype
,
device
=
device
)
g
=
torch
.
empty
(
L
,
HV
,
dtype
=
torch
.
float32
,
device
=
device
)
beta
=
torch
.
empty
(
L
,
HV
,
dtype
=
torch
.
float32
,
device
=
device
)
if
L
==
0
:
return
q
,
k
,
v
,
g
,
beta
# ---- Kernel config ----
BK
=
triton
.
next_power_of_2
(
K
)
BV
=
triton
.
next_power_of_2
(
V
)
BLOCK_T
=
16
# tokens per block
# Single kernel: blocks [0,H) do Q/K, blocks [H, H+HV) do V+gating
grid
=
(
triton
.
cdiv
(
L
,
BLOCK_T
),
H
+
HV
)
_fused_post_conv_kernel
[
grid
](
mixed_qkv_ptr
=
conv_output
,
a_ptr
=
a
,
b_ptr
=
b
,
A_log_ptr
=
A_log
,
dt_bias_ptr
=
dt_bias
,
q_ptr
=
q
,
k_ptr
=
k
,
v_ptr
=
v
,
g_ptr
=
g
,
beta_ptr
=
beta
,
stride_x_tok
=
conv_output
.
stride
(
0
),
stride_a_tok
=
a
.
stride
(
0
),
stride_b_tok
=
b
.
stride
(
0
),
stride_q_tok
=
q
.
stride
(
0
),
stride_k_tok
=
k
.
stride
(
0
),
stride_v_tok
=
v
.
stride
(
0
),
L
=
L
,
H
=
H
,
HV
=
HV
,
K
=
K
,
V
=
V
,
APPLY_L2NORM
=
apply_l2norm
,
L2NORM_EPS
=
1e-6
,
OUTPUT_G_EXP
=
output_g_exp
,
SOFTPLUS_THRESHOLD
=
20.0
,
BLOCK_T
=
BLOCK_T
,
BK
=
BK
,
BV
=
BV
,
num_warps
=
4
,
num_stages
=
2
,
)
return
q
,
k
,
v
,
g
,
beta
vllm/model_executor/layers/mamba/gdn_linear_attn.py
View file @
4eefbf96
...
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule
as
fla_chunk_gated_delta_rule
,
)
from
vllm.model_executor.layers.fla.ops
import
(
fused_post_conv_prep
,
fused_recurrent_gated_delta_rule_packed_decode
,
fused_sigmoid_gating_delta_rule_update
,
)
...
...
@@ -774,19 +775,44 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
mixed_qkv_non_spec
=
None
query_spec
,
key_spec
,
value_spec
=
self
.
rearrange_mixed_qkv
(
mixed_qkv_spec
)
query_non_spec
,
key_non_spec
,
value_non_spec
=
self
.
rearrange_mixed_qkv
(
mixed_qkv_non_spec
)
if
attn_metadata
.
num_prefills
>
0
:
g
,
beta
=
fused_gdn_gating
(
self
.
A_log
,
a
,
b
,
self
.
dt_bias
)
assert
mixed_qkv_non_spec
is
not
None
,
(
"mixed_qkv_non_spec must be provided for prefill path"
)
if
spec_sequence_masks
is
not
None
:
g
_non_spec
=
g
.
index_select
(
1
,
non_spec_token_indx
)
b
eta
_non_spec
=
b
eta
.
index_select
(
1
,
non_spec_token_indx
)
a
_non_spec
=
a
.
index_select
(
0
,
non_spec_token_indx
)
b_non_spec
=
b
.
index_select
(
0
,
non_spec_token_indx
)
else
:
g_non_spec
=
g
beta_non_spec
=
beta
a_non_spec
=
a
b_non_spec
=
b
(
query_non_spec
,
key_non_spec
,
value_non_spec
,
g_non_spec
,
beta_non_spec
,
)
=
fused_post_conv_prep
(
conv_output
=
mixed_qkv_non_spec
,
a
=
a_non_spec
,
b
=
b_non_spec
,
A_log
=
self
.
A_log
,
dt_bias
=
self
.
dt_bias
,
num_k_heads
=
self
.
num_k_heads
//
self
.
tp_size
,
head_k_dim
=
self
.
head_k_dim
,
head_v_dim
=
self
.
head_v_dim
,
apply_l2norm
=
True
,
output_g_exp
=
False
,
)
query_non_spec
=
query_non_spec
.
unsqueeze
(
0
)
key_non_spec
=
key_non_spec
.
unsqueeze
(
0
)
value_non_spec
=
value_non_spec
.
unsqueeze
(
0
)
g_non_spec
=
g_non_spec
.
unsqueeze
(
0
)
beta_non_spec
=
beta_non_spec
.
unsqueeze
(
0
)
else
:
query_non_spec
,
key_non_spec
,
value_non_spec
=
self
.
rearrange_mixed_qkv
(
mixed_qkv_non_spec
)
g_non_spec
=
None
beta_non_spec
=
None
...
...
@@ -832,7 +858,7 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
initial_state
=
initial_state
,
output_final_state
=
True
,
cu_seqlens
=
non_spec_query_start_loc
,
use_qk_l2norm_in_kernel
=
Tru
e
,
use_qk_l2norm_in_kernel
=
Fals
e
,
)
# Init cache
ssm_state
[
non_spec_state_indices_tensor
]
=
last_recurrent_state
.
to
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment