Commit e1eae1fd (unverified)
Authored Aug 05, 2024 by Ke Bao; committed by GitHub on Aug 05, 2024.
Parent: f4d9953d

Support MLA for DeepSeek-V2 with Triton - step 1 (#905)
Showing 10 changed files with 439 additions and 78 deletions (+439, -78).
Changed files:
- benchmark/gsm8k/download_data.sh (+0, -0)
- python/sglang/srt/layers/extend_attention.py (+59, -7)
- python/sglang/srt/layers/radix_attention.py (+22, -9)
- python/sglang/srt/layers/token_attention.py (+28, -2)
- python/sglang/srt/managers/schedule_batch.py (+4, -3)
- python/sglang/srt/mem_cache/memory_pool.py (+65, -24)
- python/sglang/srt/model_config.py (+11, -0)
- python/sglang/srt/model_executor/model_runner.py (+46, -17)
- python/sglang/srt/models/deepseek_v2.py (+198, -16)
- python/sglang/srt/server_args.py (+6, -0)
benchmark/gsm8k/download_data.sh

File mode changed from 100644 to 100755.
python/sglang/srt/layers/extend_attention.py

@@ -57,6 +57,8 @@ def _fwd_kernel(
     stride_buf_vh,
     stride_req_to_tokens_b,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@ def _fwd_kernel(
     cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
 
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -85,10 +89,20 @@ def _fwd_kernel(
     )
     q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
 
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
     # stage1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)
 
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
     deno = tl.zeros([BLOCK_M], dtype=tl.float32)
     e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -110,6 +124,18 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale
 
         if logit_cap > 0:
@@ -125,7 +151,7 @@ def _fwd_kernel(
         offs_buf_v = (
             offs_kv_loc[:, None] * stride_buf_vbs
             + cur_kv_head * stride_buf_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -150,6 +176,21 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale
 
         if logit_cap > 0:
@@ -169,7 +210,7 @@ def _fwd_kernel(
         offs_v = (
             (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -181,7 +222,7 @@ def _fwd_kernel(
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
-        + offs_d[None, :]
+        + offs_dv[None, :]
     )
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
@@ -217,8 +258,17 @@ def extend_attention_fwd(
         o_extend.shape[-1],
     )
-    assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256}
+    assert Lq == Lk and Lv == Lo
+    assert Lq in {16, 32, 64, 128, 256, 576}
+    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lq
+        BLOCK_DPE = 0
+    BLOCK_DV = Lv
 
     if CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        BLOCK_DMODEL=Lq,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
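
For context, a minimal sketch (not part of this commit) of why the kernel gains BLOCK_DPE and BLOCK_DV: with MLA the query/key width is 576, split into a 512-wide compressed part (BLOCK_DMODEL) and a 64-wide rotary part (BLOCK_DPE), and the kernel accumulates the two partial dot products separately. The dimensions below mirror DeepSeek-V2-style values (512 + 64) and are assumed for illustration.

```python
import torch

BLOCK_DMODEL, BLOCK_DPE = 512, 64
q = torch.randn(8, BLOCK_DMODEL + BLOCK_DPE)   # [M, 576] query rows
k = torch.randn(16, BLOCK_DMODEL + BLOCK_DPE)  # [N, 576] key rows

# Two-part accumulation, as in the Triton kernel above.
qk = q[:, :BLOCK_DMODEL] @ k[:, :BLOCK_DMODEL].T
qk += q[:, BLOCK_DMODEL:] @ k[:, BLOCK_DMODEL:].T

# Equivalent to a single dot product over the full 576 dims.
assert torch.allclose(qk, q @ k.T, atol=1e-4)
```

The separate BLOCK_DV (value width, 512) lets the accumulator and output use the value dimension rather than the query/key dimension.
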
python/sglang/srt/layers/radix_attention.py

@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
 
-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
         self.store_kv_cache(k, v, input_metadata)
 
         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
         return o
 
     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
 
         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
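
A small standalone sketch (not part of the commit) of the output-buffer sizing logic introduced above: with MLA the queries/keys per head are wider than the attention output per head, so the output can no longer be allocated with torch.empty_like(q). The head counts and dims below are assumed example values.

```python
import torch

tp_q_head_num, qk_head_dim, v_head_dim = 16, 576, 512
num_tokens = 4
q = torch.randn(num_tokens, tp_q_head_num * qk_head_dim)

# Same branch as in extend_forward_triton / decode_forward_triton above.
if qk_head_dim != v_head_dim:
    o = q.new_empty((q.shape[0], tp_q_head_num * v_head_dim))
else:
    o = torch.empty_like(q)

print(q.view(-1, tp_q_head_num, qk_head_dim).shape)  # torch.Size([4, 16, 576])
print(o.view(-1, tp_q_head_num, v_head_dim).shape)   # torch.Size([4, 16, 512])
```
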
python/sglang/srt/layers/token_attention.py

@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
+
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
 
     block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[None, :]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[:, None] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale
 
         if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -220,7 +245,8 @@ def _token_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
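
To make the special-casing in _token_att_m_fwd explicit, here is a tiny helper sketch (illustration only, not part of the commit). The 576 = 512 + 64 split corresponds to a DeepSeek-V2-style kv_lora_rank + qk_rope_head_dim layout, assumed here as an example.

```python
def pick_block_dims(head_dim: int) -> tuple[int, int]:
    """Return (BLOCK_DMODEL, BLOCK_DPE) for a given qk head dim."""
    if head_dim == 576:
        return 512, 64      # compressed-KV block, rotary (positional) block
    return head_dim, 0      # ordinary MHA/GQA head: no separate rope block


assert pick_block_dims(128) == (128, 0)
assert pick_block_dims(576) == (512, 64)
```
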
python/sglang/srt/managers/schedule_batch.py

@@ -29,7 +29,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,6 +39,7 @@ global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
     "attention_reduce_in_fp32": False,
+    "enable_mla": False,
 }
@@ -289,7 +290,7 @@ class Batch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
     tree_cache: RadixCache
 
     # Batched arguments to model runner
@@ -780,7 +781,7 @@ class InputMetadata:
     seq_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
 
     # For extend
     extend_seq_lens: torch.Tensor
python/sglang/srt/mem_cache/memory_pool.py

@@ -57,32 +57,18 @@ class ReqToTokenPool:
         self.can_use_mem_size = len(self.mem_state)
 
 
-class TokenToKVPool:
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
-        layer_num: int,
     ):
         self.size = size
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
 
-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-
         # Prefetch buffer
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
@@ -90,15 +76,6 @@ class BaseTokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()
 
-    def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
-
-    def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
-
-    def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
-
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
@@ -139,3 +116,67 @@ class BaseTokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
+
+
+class MHATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+
+class MLATokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        self.kv_lora_rank = kv_lora_rank
+        self.kv_buffer = [
+            torch.empty(
+                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                dtype=dtype,
+                device="cuda",
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
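
A minimal CPU sketch (device changed from "cuda" purely for illustration; not part of the commit) of what MLATokenToKVPool stores: one joint [kv_lora_rank + qk_rope_head_dim] vector per token per layer, with the "value" buffer being a zero-copy view of the first kv_lora_rank channels of the same storage. The 512/64 values are assumed DeepSeek-V2-like numbers.

```python
import torch

size, layer_num = 8, 2
kv_lora_rank, qk_rope_head_dim = 512, 64  # assumed example values

kv_buffer = [
    torch.empty((size + 1, 1, kv_lora_rank + qk_rope_head_dim))
    for _ in range(layer_num)
]

key_view = kv_buffer[0]                        # full 576-wide latent + rope
value_view = kv_buffer[0][..., :kv_lora_rank]  # first 512 channels, no copy

assert value_view.data_ptr() == key_view.data_ptr()  # same underlying storage
print(key_view.shape, value_view.shape)
# torch.Size([9, 1, 576]) torch.Size([9, 1, 512])
```
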
python/sglang/srt/model_config.py

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+from enum import IntEnum, auto
 from typing import Optional
 
 from transformers import PretrainedConfig
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
 
+class AttentionArch(IntEnum):
+    MLA = auto()
+    MHA = auto()
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -55,6 +61,11 @@ class ModelConfig:
         # FIXME: temporary special judge for deepseek v2 MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        else:
+            self.attention_arch = AttentionArch.MHA
 
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
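
A standalone sketch (not part of the commit) of the architecture detection added above, using a stub object in place of a real transformers PretrainedConfig; the kv_lora_rank and qk_rope_head_dim values are assumed for illustration.

```python
from enum import IntEnum, auto
from types import SimpleNamespace


class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


def detect_attention_arch(hf_config):
    # Mirrors the DeepseekV2ForCausalLM special case in ModelConfig.__init__.
    if "DeepseekV2ForCausalLM" in hf_config.architectures:
        return AttentionArch.MLA, hf_config.kv_lora_rank, hf_config.qk_rope_head_dim
    return AttentionArch.MHA, None, None


stub = SimpleNamespace(
    architectures=["DeepseekV2ForCausalLM"], kv_lora_rank=512, qk_rope_head_dim=64
)
print(detect_attention_arch(stub))  # (<AttentionArch.MLA: 1>, 512, 64)
```
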
python/sglang/srt/model_executor/model_runner.py

@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.model_config import AttentionArch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +91,7 @@ class ModelRunner:
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )
@@ -193,15 +199,23 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = (
-            head_num
-            * head_dim
-            * self.model_config.num_hidden_layers
-            * 2
-            * torch._utils._element_size(self.dtype)
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
+            cell_size = (
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
+                * self.model_config.num_hidden_layers
+                * 2
+                * torch._utils._element_size(self.dtype)
+            )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -241,13 +255,28 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_tokens,
-            dtype=self.dtype,
-            head_num=self.model_config.get_num_kv_heads(self.tp_size),
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementaion, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
         logger.info(
             f"[gpu={self.gpu_id}] Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
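
A back-of-the-envelope sketch (not part of the commit) of the per-token KV-cache cost computed above: the MLA path stores kv_lora_rank + qk_rope_head_dim elements per token per layer, while the MHA path stores K and V for every head. All concrete numbers below (512 + 64, 60 layers, fp16, 128 heads of dim 192) are assumed stand-ins, not values read from a real model config.

```python
import torch

dtype = torch.float16
elem = torch._utils._element_size(dtype)  # bytes per element (same helper as above)
num_layers = 60

# MLA path: one joint latent + rope vector per token per layer.
mla_cell = (512 + 64) * num_layers * elem

# MHA path: K and V per head per layer (head_num / head_dim assumed).
head_num, head_dim = 128, 192
mha_cell = head_num * head_dim * num_layers * 2 * elem

print(f"MLA: {mla_cell / 1024:.1f} KiB/token, MHA: {mha_cell / 1024:.1f} KiB/token")
```
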
python/sglang/srt/models/deepseek_v2.py

@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.model_runner import InputMetadata
@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
         return output
 
 
+class DeepseekV2AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        rope_scaling["type"] = "deepseek_yarn"
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+        )
+
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        kv_b_proj = self.kv_b_proj
+        w_kc, w_vc = kv_b_proj.weight.unflatten(
+            0, (-1, qk_nope_head_dim + v_head_dim)
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        self.w_kc = w_kc
+        self.w_vc = w_vc
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope_out = q_input[..., : self.kv_lora_rank]
+        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
+        k_pe = k_input[..., self.kv_lora_rank :]
+        v_input = k_input[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous())
+        k_input[..., : self.kv_lora_rank] = v_input
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+        attn_bmm_output = attn_output.new_empty(
+            q_len, self.num_local_heads, self.v_head_dim
+        )
+        torch.bmm(
+            attn_output.transpose(0, 1),
+            self.w_vc.transpose(1, 2).contiguous(),
+            out=attn_bmm_output.transpose(0, 1),
+        )
+        attn_output = attn_bmm_output.flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
 class DeepseekV2DecoderLayer(nn.Module):
     def __init__(
@@ -326,22 +486,44 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.self_attn = DeepseekV2Attention(
-            config=config,
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            layer_id=layer_id,
-        )
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = DeepseekV2AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
+            self.self_attn = DeepseekV2Attention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
         if (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
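
For context, a shape walk-through (illustration only, not part of the commit) of the weight split used by DeepseekV2AttentionMLA above: kv_b_proj's weight is unflattened into a per-head w_kc and w_vc, q_nope is projected into the compressed kv_lora_rank space before attention, and the latent attention output is expanded back to v_head_dim afterwards. H, q_len, and the head dimensions are assumed example values (DeepSeek-V2-like).

```python
import torch

H = 4                   # num_local_heads (assumed)
qk_nope_head_dim = 128
v_head_dim = 128
kv_lora_rank = 512
q_len = 3

# kv_b_proj.weight has shape [H * (qk_nope + v), kv_lora_rank].
w = torch.randn(H * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
print(w_kc.shape, w_vc.shape)  # [H, 128, 512], [H, 128, 512]

# Absorb w_kc into the query: q_nope is mapped into the kv_lora_rank space.
q_nope = torch.randn(q_len, H, qk_nope_head_dim)
q_latent = torch.bmm(q_nope.transpose(0, 1), w_kc).transpose(0, 1)
print(q_latent.shape)          # [q_len, H, 512]

# After attention over the latent cache, map back to v_head_dim with w_vc.
attn_latent = torch.randn(q_len, H, kv_lora_rank)
attn_out = torch.bmm(attn_latent.transpose(0, 1), w_vc.transpose(1, 2).contiguous())
print(attn_out.transpose(0, 1).shape)  # [q_len, H, 128]
```

This is why RadixAttention is constructed with head_dim = kv_lora_rank + qk_rope_head_dim and v_head_dim = kv_lora_rank: the attention itself runs entirely in the compressed latent space with a single KV head.
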
python/sglang/srt/server_args.py

@@ -80,6 +80,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
+    enable_mla: bool = False
     attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--enable-mla",
+            action="store_true",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
            action="store_true",