sglang / Commits / ec52464d

Unverified commit ec52464d, authored Dec 05, 2024 by Ke Bao, committed by GitHub on Dec 05, 2024. Parent: eb0c1f53.

MLA prefill w/o weight absorption (#2349)

Showing 8 changed files with 166 additions and 36 deletions (+166 -36).
python/sglang/srt/layers/attention/__init__.py                     +5   -2
python/sglang/srt/layers/attention/double_sparsity_backend.py      +22  -8
python/sglang/srt/layers/attention/flashinfer_backend.py           +20  -5
python/sglang/srt/layers/attention/torch_native_backend.py         +22  -8
python/sglang/srt/layers/attention/triton_backend.py               +22  -8
python/sglang/srt/layers/attention/triton_ops/extend_attention.py  +3   -0
python/sglang/srt/layers/radix_attention.py                        +4   -2
python/sglang/srt/models/deepseek_v2.py                            +68  -3
python/sglang/srt/layers/attention/__init__.py

@@ -52,12 +52,13 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, forward_batch)
+            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch)
+            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)

     def forward_decode(
         self,
@@ -66,6 +67,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for decode."""
         raise NotImplementedError()
@@ -77,6 +79,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
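For orientation, a minimal sketch of a backend that follows the updated contract is shown below. It is not part of the commit: ToyAttnBackend and _toy_attention are made-up names, the attention math is deliberately naive, and only the method signatures and the save_kv_cache gating mirror the interface above.

# Hypothetical sketch only: ToyAttnBackend and _toy_attention are illustrative
# names; only the signatures and the save_kv_cache gating follow this commit.
import torch

from sglang.srt.layers.attention import AttentionBackend


class ToyAttnBackend(AttentionBackend):
    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        # The caller may have written its own cache entry already (e.g. the MLA
        # latent cache in deepseek_v2.py) and passes save_kv_cache=False so the
        # backend does not overwrite it with full-size K/V.
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        return self._toy_attention(q, k, v, layer)

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        if save_kv_cache:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        return self._toy_attention(q, k, v, layer)

    @staticmethod
    def _toy_attention(q, k, v, layer):
        # Naive attention over the current tokens only (no KV-cache read, no
        # causal mask, assumes tp_q_head_num == tp_k_head_num); real backends
        # gather the cached prefix from the KV pool.
        q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim).transpose(0, 1)
        k = k.view(-1, layer.tp_k_head_num, layer.qk_head_dim).transpose(0, 1)
        v = v.view(-1, layer.tp_v_head_num, layer.v_head_dim).transpose(0, 1)
        scores = torch.softmax((q @ k.transpose(-2, -1)) * layer.scaling, dim=-1)
        out = (scores @ v).transpose(0, 1)  # (tokens, heads, v_head_dim)
        return out.reshape(out.shape[0], -1)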
python/sglang/srt/layers/attention/double_sparsity_backend.py

@@ -165,7 +165,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return 1

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )

         (
             start_loc,
@@ -212,7 +219,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -242,9 +255,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )

         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
         # and set a minimum value for sparse_decode
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -221,7 +221,13 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
@@ -237,7 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -270,12 +277,19 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
         cache_loc = (
@@ -286,7 +300,8 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
python/sglang/srt/layers/attention/torch_native_backend.py

@@ -216,16 +216,23 @@ class TorchNativeAttnBackend(AttentionBackend):
         return output

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
@@ -249,7 +256,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -260,9 +273,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
python/sglang/srt/layers/attention/triton_backend.py

@@ -114,7 +114,13 @@ class TritonAttnBackend(AttentionBackend):
         return 1

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -122,9 +128,10 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
@@ -146,7 +153,13 @@ class TritonAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -160,9 +173,10 @@ class TritonAttnBackend(AttentionBackend):
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -284,6 +284,9 @@ def extend_attention_fwd(
     elif Lq == 288:
         BLOCK_DMODEL = 256
         BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
     else:
         BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
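The new Lq == 192 branch matches the per-head query width used by the MHA prefill path added in deepseek_v2.py below: qk_nope_head_dim + qk_rope_head_dim. A quick sanity check of the split (the 128/64 head dims are the published DeepSeek-V2 values, assumed here rather than taken from this diff):

# Sanity check for the new Lq == 192 branch (DeepSeek-V2 head dims assumed).
qk_nope_head_dim = 128   # non-RoPE part of each query/key head
qk_rope_head_dim = 64    # RoPE part of each query/key head
Lq = qk_nope_head_dim + qk_rope_head_dim   # 192, not a power of two
BLOCK_DMODEL, BLOCK_DPE = 128, 64          # two power-of-two tiles instead
assert BLOCK_DMODEL + BLOCK_DPE == Lq == 192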
python/sglang/srt/layers/radix_attention.py

@@ -48,11 +48,13 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention

-    def forward(self, q, k, v, forward_batch: ForwardBatch):
+    def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
+        return forward_batch.attn_backend.forward(
+            q, k, v, self, forward_batch, save_kv_cache
+        )
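A small usage sketch of the new flag (not from the commit; layer, q, k, v, forward_batch, and latent_cache are assumed to exist in the surrounding model code). It mirrors the pattern forward_normal uses in deepseek_v2.py below: write a custom cache entry first, then ask the backend not to save k/v again.

# Hypothetical call site; `layer`, `q`, `k`, `v`, `forward_batch`, and
# `latent_cache` are assumed to be provided by the surrounding model code.

# Default behaviour: the attention backend writes k/v into the KV pool.
out = layer(q, k, v, forward_batch)

# MLA-style usage: the model stores a compact (latent) cache entry itself and
# tells the backend not to save full-size k/v on top of it.
forward_batch.token_to_kv_pool.set_kv_buffer(
    layer, forward_batch.out_cache_loc, latent_cache, None
)
out = layer(q, k, v, forward_batch, save_kv_cache=False)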
python/sglang/srt/models/deepseek_v2.py

@@ -453,7 +453,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

-        self.attn = RadixAttention(
+        self.attn_mqa = RadixAttention(
             self.num_local_heads,
             self.kv_lora_rank + self.qk_rope_head_dim,
             self.scaling,
@@ -462,6 +462,15 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )

+        self.attn_mha = RadixAttention(
+            self.num_local_heads,
+            self.qk_nope_head_dim + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+            v_head_dim=self.v_head_dim,
+        )
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -471,6 +480,63 @@ class DeepseekV2AttentionMLA(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Use normal computation for prefill and use weight absorption for extend/decode
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.extend_prefix_lens.sum() == 0
+        ):
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward_absorb(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -508,7 +574,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe

-        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -835,7 +901,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
             if hasattr(self_attn.kv_b_proj, "weight_scale"):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj

 EntryClass = DeepseekV2ForCausalLM
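For context on why prefill skips absorption: forward_absorb folds the two halves of the kv_b_proj weight (w_kc and w_vc, prepared in load_weights above) into the query and the output so decode can attend directly over the compressed latent, whereas forward_normal decompresses K/V once and runs ordinary multi-head attention, which is the cheaper choice when a chunk has no cached prefix. Below is a simplified, self-contained sketch of the absorption algebra; the shapes and the torch.bmm formulation are illustrative assumptions, not the literal sglang code (this diff shows only fragments of forward_absorb).

# Illustrative sketch of MLA weight absorption; shapes are made up and the
# attention step is replaced by a stand-in tensor.
import torch

heads, tokens, nope, rope, lora, v_dim = 16, 4, 128, 64, 512, 128

# w_kc / w_vc are per-head slices of the kv_b_proj weight.
w_kc = torch.randn(heads, nope, lora)   # absorbs the K "nope" projection into q
w_vc = torch.randn(heads, lora, v_dim)  # absorbs the V projection into the output

q_nope = torch.randn(heads, tokens, nope)
q_pe = torch.randn(heads, tokens, rope)

# Query side: project q_nope into latent space so it can be compared against
# the cached latents directly, then append the RoPE part.
q_input = torch.cat([torch.bmm(q_nope, w_kc), q_pe], dim=-1)  # (heads, tokens, lora + rope)

# Attention (attn_mqa) runs over cached latents of width lora + rope and
# returns per-head outputs of width lora.
attn_latent = torch.randn(heads, tokens, lora)  # stand-in for the attn_mqa output

# Output side: lift the latent-width result back to v_head_dim.
attn_output = torch.bmm(attn_latent, w_vc)      # (heads, tokens, v_dim)
print(q_input.shape, attn_output.shape)

During a pure prefill, forward_normal instead writes only the compact latent_cache into the KV pool and calls self.attn_mha with save_kv_cache=False, so later decode steps can still take the absorbed path over the same cache.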