sglang · Commits · e1eae1fd

Unverified commit e1eae1fd, authored Aug 05, 2024 by Ke Bao, committed by GitHub on Aug 05, 2024
Support MLA for DeepSeek-V2 with Triton - step 1 (#905)
Parent: f4d9953d

Showing 10 changed files with 439 additions and 78 deletions (+439, -78)
benchmark/gsm8k/download_data.sh                    +0    -0
python/sglang/srt/layers/extend_attention.py        +59   -7
python/sglang/srt/layers/radix_attention.py         +22   -9
python/sglang/srt/layers/token_attention.py         +28   -2
python/sglang/srt/managers/schedule_batch.py        +4    -3
python/sglang/srt/mem_cache/memory_pool.py          +65   -24
python/sglang/srt/model_config.py                   +11   -0
python/sglang/srt/model_executor/model_runner.py    +46   -17
python/sglang/srt/models/deepseek_v2.py             +198  -16
python/sglang/srt/server_args.py                    +6    -0
benchmark/gsm8k/download_data.sh

File mode changed from 100644 to 100755.
python/sglang/srt/layers/extend_attention.py  (+59, -7)

@@ -57,6 +57,8 @@ def _fwd_kernel(
     stride_buf_vh,
     stride_req_to_tokens_b,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@ def _fwd_kernel(
     cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -85,10 +89,20 @@ def _fwd_kernel(
     )
     q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
     # stage1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)

-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
     deno = tl.zeros([BLOCK_M], dtype=tl.float32)
     e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -110,6 +124,18 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale

         if logit_cap > 0:
@@ -125,7 +151,7 @@ def _fwd_kernel(
         offs_buf_v = (
             offs_kv_loc[:, None] * stride_buf_vbs
             + cur_kv_head * stride_buf_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -150,6 +176,21 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
         qk *= sm_scale

         if logit_cap > 0:
@@ -169,7 +210,7 @@ def _fwd_kernel(
         offs_v = (
             (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -181,7 +222,7 @@ def _fwd_kernel(
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
-        + offs_d[None, :]
+        + offs_dv[None, :]
     )
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
@@ -217,8 +258,17 @@ def extend_attention_fwd(
         o_extend.shape[-1],
     )
-    assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256}
+    assert Lq == Lk and Lv == Lo
+    assert Lq in {16, 32, 64, 128, 256, 576}
+    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lq
+        BLOCK_DPE = 0
+    BLOCK_DV = Lv

     if CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        BLOCK_DMODEL=Lq,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
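The new BLOCK_DPE/BLOCK_DV plumbing rests on a simple identity: splitting the 576-dim MLA head into a 512-dim no-PE part and a 64-dim RoPE part and summing the two partial dot products gives the same score as one full dot product, which is why the kernel accumulates tl.dot(q, k) and tl.dot(qpe, kpe) separately into qk. A minimal PyTorch sketch of that identity (random data, double precision only so the comparison is exact):

import torch

BLOCK_DMODEL, BLOCK_DPE = 512, 64          # the split chosen when Lq == 576
q = torch.randn(BLOCK_DMODEL + BLOCK_DPE, dtype=torch.float64)
k = torch.randn(BLOCK_DMODEL + BLOCK_DPE, dtype=torch.float64)

full_score = q @ k                          # one 576-dim dot product
split_score = (
    q[:BLOCK_DMODEL] @ k[:BLOCK_DMODEL]     # "nope" part: tl.dot(q, k) in the kernel
    + q[BLOCK_DMODEL:] @ k[BLOCK_DMODEL:]   # RoPE part: tl.dot(qpe, kpe)
)
assert torch.allclose(full_score, split_score)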
python/sglang/srt/layers/radix_attention.py  (+22, -9)

@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id

-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
         return o

     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
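RadixAttention now distinguishes the query/key head dim from the value head dim, so when the two differ (576 vs. 512 in the MLA layer) the output buffer has to be sized by the value dim rather than mirroring q. A standalone sketch of that allocation rule (the head count and token count below are arbitrary illustration values, not taken from the diff):

import torch

tp_q_head_num, qk_head_dim, v_head_dim = 16, 576, 512   # illustrative MLA-style shapes
num_tokens = 4
q = torch.randn(num_tokens, tp_q_head_num * qk_head_dim)

# Mirrors the new branch in extend_forward_triton / decode_forward_triton.
if qk_head_dim != v_head_dim:
    o = q.new_empty((q.shape[0], tp_q_head_num * v_head_dim))
else:
    o = torch.empty_like(q)

assert o.view(-1, tp_q_head_num, v_head_dim).shape == (num_tokens, tp_q_head_num, v_head_dim)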
python/sglang/srt/layers/token_attention.py  (+28, -2)

@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
+
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

     block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[None, :]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[:, None] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale

         if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0

     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -220,7 +245,8 @@ def _token_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
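Both attention launchers now pick the Triton block sizes from the head dim: only the 576-wide MLA head is split into a 512-dim main block plus a 64-dim RoPE block, and every other supported head dim keeps BLOCK_DPE at 0 so the extra loads become dead code for the compiler. A minimal helper restating that selection (a paraphrase of the logic above, not an API in the repo):

def pick_blocks(head_dim: int) -> tuple[int, int]:
    """Return (BLOCK_DMODEL, BLOCK_DPE) the way _token_att_m_fwd now does."""
    assert head_dim in {16, 32, 64, 128, 256, 576}
    if head_dim == 576:          # MLA: kv_lora_rank (512) + qk_rope_head_dim (64)
        return 512, 64
    return head_dim, 0           # regular MHA heads: no separate RoPE block


assert pick_blocks(576) == (512, 64)
assert pick_blocks(128) == (128, 0)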
python/sglang/srt/managers/schedule_batch.py  (+4, -3)

@@ -29,7 +29,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixCache

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,6 +39,7 @@ global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
     "attention_reduce_in_fp32": False,
+    "enable_mla": False,
 }
@@ -289,7 +290,7 @@ class Batch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
     tree_cache: RadixCache

     # Batched arguments to model runner
@@ -780,7 +781,7 @@ class InputMetadata:
     seq_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool

     # For extend
     extend_seq_lens: torch.Tensor
python/sglang/srt/mem_cache/memory_pool.py  (+65, -24)

@@ -57,32 +57,18 @@ class ReqToTokenPool:
         self.can_use_mem_size = len(self.mem_state)


-class TokenToKVPool:
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
         self,
         size: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
-        layer_num: int,
     ):
         self.size = size

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-
         # Prefetch buffer
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
@@ -90,15 +76,6 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()

-    def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
-
-    def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
-
-    def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
-
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
@@ -139,3 +116,67 @@ class TokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
+
+
+class MHATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+
+class MLATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        self.kv_lora_rank = kv_lora_rank
+        self.kv_buffer = [
+            torch.empty(
+                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                dtype=dtype,
+                device="cuda",
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
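The new MLATokenToKVPool keeps a single compressed buffer per layer: each token stores one kv_lora_rank + qk_rope_head_dim vector, the key buffer is the whole row, and the value buffer is a view of its first kv_lora_rank channels, so K and V share storage instead of being two full per-head tensors. A CPU-only sketch of that layout (the 512 + 64 split matches the 576-dim kernel assertions elsewhere in this commit; the tiny pool size is arbitrary):

import torch

kv_lora_rank, qk_rope_head_dim, size = 512, 64, 8
# One row per token slot (+1 dummy slot), a single "head" of compressed KV.
kv_buffer = torch.zeros(size + 1, 1, kv_lora_rank + qk_rope_head_dim)

key_view = kv_buffer                          # what get_key_buffer returns
value_view = kv_buffer[..., :kv_lora_rank]    # what get_value_buffer returns

kv_buffer[3, 0, 0] = 1.0
assert value_view[3, 0, 0] == 1.0             # a view of the same storage, not a copy
assert key_view.shape[-1] == 576 and value_view.shape[-1] == 512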
python/sglang/srt/model_config.py  (+11, -0)

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from enum import IntEnum, auto
 from typing import Optional

 from transformers import PretrainedConfig
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import get_config, get_context_length


+class AttentionArch(IntEnum):
+    MLA = auto()
+    MHA = auto()
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -55,6 +61,11 @@ class ModelConfig:
         # FIXME: temporary special judge for deepseek v2 MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        else:
+            self.attention_arch = AttentionArch.MHA

         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
python/sglang/srt/model_executor/model_runner.py  (+46, -17)

@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.model_config import AttentionArch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +91,7 @@ class ModelRunner:
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )
@@ -193,15 +199,23 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = (
-            head_num
-            * head_dim
-            * self.model_config.num_hidden_layers
-            * 2
-            * torch._utils._element_size(self.dtype)
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -241,13 +255,28 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_tokens,
-            dtype=self.dtype,
-            head_num=self.model_config.get_num_kv_heads(self.tp_size),
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementaion, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
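The profiling path now sizes the per-token KV cache cell differently for the two pools: the MLA pool stores one (kv_lora_rank + qk_rope_head_dim)-wide vector per layer, while the MHA pool stores K and V for every KV head. A rough, standalone version of that arithmetic; the DeepSeek-V2 numbers used below (60 layers, 128 KV heads of dim 256, rank 512, RoPE dim 64, 2-byte elements) are assumptions for illustration, not values read from this diff:

def mla_cell_size(kv_lora_rank: int, qk_rope_head_dim: int, num_layers: int, elem_size: int = 2) -> int:
    # One compressed vector per token per layer; K and V share it.
    return (kv_lora_rank + qk_rope_head_dim) * num_layers * elem_size


def mha_cell_size(num_kv_heads: int, head_dim: int, num_layers: int, elem_size: int = 2) -> int:
    # Separate K and V tensors per head per layer, hence the extra factor of 2.
    return num_kv_heads * head_dim * num_layers * 2 * elem_size


print(mla_cell_size(512, 64, 60))    # 69120 bytes per token
print(mha_cell_size(128, 256, 60))   # 7864320 bytes per token, roughly 114x larger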
python/sglang/srt/models/deepseek_v2.py  (+198, -16)

@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.model_runner import InputMetadata
@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
         return output


+class DeepseekV2AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        rope_scaling["type"] = "deepseek_yarn"
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+        )
+
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        kv_b_proj = self.kv_b_proj
+        w_kc, w_vc = kv_b_proj.weight.unflatten(
+            0, (-1, qk_nope_head_dim + v_head_dim)
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        self.w_kc = w_kc
+        self.w_vc = w_vc
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        q_nope_out = q_input[..., : self.kv_lora_rank]
+        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
+        k_pe = k_input[..., self.kv_lora_rank :]
+        v_input = k_input[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous())
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+        attn_bmm_output = attn_output.new_empty(
+            q_len, self.num_local_heads, self.v_head_dim
+        )
+        torch.bmm(
+            attn_output.transpose(0, 1),
+            self.w_vc.transpose(1, 2).contiguous(),
+            out=attn_bmm_output.transpose(0, 1),
+        )
+        attn_output = attn_bmm_output.flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
 class DeepseekV2DecoderLayer(nn.Module):
     def __init__(
@@ -326,22 +486,44 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.self_attn = DeepseekV2Attention(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            layer_id=layer_id,
-        )
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = DeepseekV2AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.self_attn = DeepseekV2Attention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
         if (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
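The part of the new attention class that is easiest to misread is the pair of torch.bmm calls: instead of expanding the compressed latent into per-head keys and values, the up-projection weights w_kc and w_vc (sliced out of kv_b_proj.weight) are folded into the query before attention and into the output after it. A small CPU check that folding w_kc into the query leaves the attention scores unchanged (shapes follow the class above; the concrete sizes are made up for the test):

import torch

num_heads, qk_nope_head_dim, kv_lora_rank, seq = 4, 16, 8, 5
w_kc = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)  # per-head slice of kv_b_proj.weight
q_nope = torch.randn(seq, num_heads, qk_nope_head_dim)
latent = torch.randn(seq, kv_lora_rank)                        # compressed KV, one vector per token

# Naive path: expand the latent into per-head k_nope, then dot with q_nope.
k_nope = torch.einsum("tr,hdr->thd", latent, w_kc)             # [seq, heads, qk_nope]
scores_naive = torch.einsum("qhd,khd->hqk", q_nope, k_nope)

# Absorbed path: fold w_kc into the query (what torch.bmm(q_nope.transpose(0, 1), self.w_kc) does),
# then dot the absorbed query directly with the latent.
q_absorbed = torch.bmm(q_nope.transpose(0, 1), w_kc)           # [heads, seq, rank]
scores_absorbed = torch.einsum("hqr,kr->hqk", q_absorbed, latent)

assert torch.allclose(scores_naive, scores_absorbed, atol=1e-5)

The same argument applies on the value side: w_vc is applied to the attention output after the kernel, so the kernel itself only ever reads and writes the kv_lora_rank-wide latent.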
python/sglang/srt/server_args.py  (+6, -0)

@@ -80,6 +80,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
+    enable_mla: bool = False
     attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--enable-mla",
+            action="store_true",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
            action="store_true",
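With this flag the MLA path is opt-in at launch time; presumably a DeepSeek-V2 server would be started with something like the following, where the model path and the other flags are only illustrative and only --enable-mla comes from this diff:

python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2 --trust-remote-code --enable-mla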