sglang / Commits

Commit 061e5463 (unverified)
Support double sparsity (#1459)
Authored Oct 14, 2024 by Shuo Yang; committed via GitHub on Oct 14, 2024.
Parent: 0c1e8796
Showing 8 changed files with 1269 additions and 1 deletion (+1269 −1).
Changed files:
    python/sglang/srt/layers/attention/double_sparsity_backend.py               +281  −0
    python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py  +772  −0
    python/sglang/srt/mem_cache/memory_pool.py                                    +58  −0
    python/sglang/srt/model_executor/model_runner.py                              +49  −1
    python/sglang/srt/server_args.py                                              +45  −0
    test/srt/Llama-3.1-8B-Instruct.json                                            +1  −0
    test/srt/run_suite.py                                                          +1  −0
    test/srt/test_double_sparsity.py                                              +62  −0
python/sglang/srt/layers/attention/double_sparsity_backend.py (new file, 0 → 100644)
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class DoubleSparseAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
            flash_decode_attention_fwd,
            flash_decode_sparse_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = flash_decode_attention_fwd
        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = model_runner.model_config.num_attention_heads

        self.head_dim = model_runner.model_config.hidden_size // self.num_head
        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num

        self.sorted_channels = model_runner.sorted_channels
        self.sparse_decode_thresold = (
            model_runner.server_args.ds_sparse_decode_threshold
        )
        self.att_out_approx: torch.Tensor = None
        self.mid_out: torch.Tensor = None
        self.mid_o_logexpsum: torch.Tensor = None

        # TODO: Change the hard-coded block_seq_num
        self.BLOCK_SEQ = 128

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        self.forward_metadata = None

        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        if forward_batch.forward_mode.is_decode():
            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device="cuda",
            )

            max_seq_len = torch.max(forward_batch.seq_lens).item()
            min_seq_len = torch.min(forward_batch.seq_lens).item()
            max_extend_len = None
            # NOTE: Align sequence order with req_to_token order
            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices
            ]

            bsz = forward_batch.seq_lens.shape[0]
            att_out_approx = torch.empty(
                [self.num_head, bsz, max_seq_len],
                dtype=self.reduce_dtype,
                device="cuda",
            )

            block_seq_num = (
                self.heavy_token_num + self.BLOCK_SEQ - 1
            ) // self.BLOCK_SEQ

            mid_out = torch.empty(
                [bsz, self.num_head, block_seq_num, self.head_dim],
                dtype=torch.float32,
                device="cuda",
            )
            mid_o_logexpsum = torch.empty(
                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
            )
            self.att_out_approx = att_out_approx
            self.mid_out = mid_out
            self.mid_o_logexpsum = mid_o_logexpsum

        else:
            start_loc = attn_logits = max_seq_len = min_seq_len = None
            prefix_lens = forward_batch.extend_prefix_lens
            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
            ds_req_to_token = None

        self.forward_metadata = (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        )

    def init_cuda_graph_state(self, max_bs: int):
        # TODO(Andy): Support CUDA graph for double sparse attention
        raise ValueError(
            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
        )
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_attn_logits = torch.empty(
            (
                self.num_head,
                self.cuda_graph_max_total_num_tokens,
            ),
            dtype=self.reduce_dtype,
            device="cuda",
        )

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,
            self.cuda_graph_attn_logits,
            self.cuda_graph_max_seq_len,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        k_label = torch.gather(
            k,
            2,
            self.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
        )

        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        # TODO: Add min seqlen
        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata

        k_label = torch.gather(
            k,
            2,
            self.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
        )

        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
        # and set a minimum value for sparse_decode
        if (
            min_seq_len < self.heavy_token_num
            or max_seq_len < self.sparse_decode_thresold
        ):
            self.decode_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                forward_batch.req_to_token_pool.req_to_token,
                forward_batch.req_pool_indices,
                start_loc,
                forward_batch.seq_lens,
                attn_logits,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
            )
        else:
            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
            q_label = torch.gather(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                2,
                self.sorted_channels[layer.layer_id]
                .unsqueeze(0)
                .expand(q.shape[0], -1, -1),
            )
            self.decode_sparse_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                q_label,
                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
                ds_req_to_token,
                forward_batch.seq_lens,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
                self.heavy_token_num,
                self.att_out_approx,
                self.mid_out,
                self.mid_o_logexpsum,
                self.BLOCK_SEQ,
            )
        return o
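For reference, the torch.gather calls in forward_extend and forward_decode above select only the pre-sorted "heavy" channels of each key head before the label cache is written. A minimal CPU-only sketch of that indexing, with made-up sizes (not taken from any real model):

import torch

# Hypothetical sizes: 6 cached tokens, 2 KV heads, head_dim 8, 4 heavy channels.
num_tokens, num_kv_heads, head_dim, heavy_channel_num = 6, 2, 8, 4

k = torch.randn(num_tokens, num_kv_heads, head_dim)
# One row of sorted channel indices per KV head (what sorted_channels[layer_id] holds).
sorted_channels = torch.stack(
    [torch.randperm(head_dim)[:heavy_channel_num] for _ in range(num_kv_heads)]
)

# Same expression as in the backend above: pick the heavy channels of every key vector.
k_label = torch.gather(
    k, 2, sorted_channels.unsqueeze(0).expand(k.shape[0], -1, -1)
)
assert k_label.shape == (num_tokens, num_kv_heads, heavy_channel_num)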
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py (new file, 0 → 100644)
import torch
import triton
import triton.language as tl

from sglang.srt.managers.schedule_batch import global_server_args_dict

if global_server_args_dict.get("attention_reduce_in_fp32", False):
    REDUCE_TRITON_TYPE = tl.float32
    REDUCE_TORCH_TYPE = torch.float32
else:
    REDUCE_TRITON_TYPE = tl.float16
    REDUCE_TORCH_TYPE = torch.float16
@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel_flash_decode_stage1(
    Q,
    K,
    V,
    sm_scale,
    Req_to_tokens,
    B_req_idx,
    B_Seqlen,
    Mid_O,  # [batch, head, seq_block_num, head_dim]
    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
    stride_req_to_tokens_b,
    stride_req_to_tokens_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_mid_od,
    stride_mid_o_eb,
    stride_mid_o_eh,
    stride_mid_o_es,
    gqa_group_size,
    BLOCK_SEQ: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    seq_start_block = tl.program_id(2)
    cur_kv_head = cur_head // gqa_group_size

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    cur_batch_start_index = seq_start_block * BLOCK_SEQ
    cur_batch_end_index = tl.minimum(
        cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ
    )

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

    block_n_size = (
        tl.where(
            cur_batch_end_index - cur_batch_start_index <= 0,
            0,
            cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,
        )
        // BLOCK_N
    )

    offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)

    q = tl.load(Q + off_q)

    sum_exp = 0.0
    max_logic = -float("inf")
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, block_n_size, 1):
        offs_n_new = start_n * BLOCK_N + offs_n
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]
        k = tl.load(
            K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
        )
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale
        att_value = tl.where(
            offs_n_new < cur_batch_end_index, att_value, float("-inf")
        )
        v = tl.load(
            V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
        )

        cur_max_logic = tl.max(att_value, axis=0)
        new_max_logic = tl.maximum(cur_max_logic, max_logic)

        exp_logic = tl.exp(att_value - new_max_logic)
        logic_scale = tl.exp(max_logic - new_max_logic)
        acc *= logic_scale
        acc += tl.sum(exp_logic[:, None] * v, axis=0)

        sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)
        max_logic = new_max_logic

    need_store = tl.where(block_n_size == 0, 0, 1)
    for _ in range(0, need_store, 1):
        off_mid_o = (
            cur_batch * stride_mid_ob
            + cur_head * stride_mid_oh
            + seq_start_block * stride_mid_os
            + offs_d
        )
        off_mid_o_logexpsum = (
            cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block
        )
        tl.store(Mid_O + off_mid_o, acc / sum_exp)
        tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))
    return
@triton.jit
def _fwd_kernel_flash_decode_stage2(
    B_Seqlen,
    Mid_O,  # [batch, head, seq_block_num, head_dim]
    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
    O,  # [batch, head, head_dim]
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_mid_od,
    stride_mid_o_eb,
    stride_mid_o_eh,
    stride_mid_o_es,
    stride_obs,
    stride_oh,
    stride_od,
    BLOCK_SEQ: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    block_n_size = (
        tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1)
        // BLOCK_SEQ
    )

    sum_exp = 0.0
    max_logic = -float("inf")
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh
    for block_seq_n in range(0, block_n_size, 1):
        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
        new_max_logic = tl.maximum(tlogic, max_logic)

        old_scale = tl.exp(max_logic - new_max_logic)
        acc *= old_scale
        exp_logic = tl.exp(tlogic - new_max_logic)
        acc += exp_logic * tv

        sum_exp = sum_exp * old_scale + exp_logic
        max_logic = new_max_logic

    tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)
    return
@torch.no_grad()
def flash_decode_stage1(
    q,
    k,
    v,
    Req_to_tokens,
    B_req_idx,
    B_Seqlen,
    max_len_in_batch,
    mid_out,
    mid_out_logsumexp,
    block_seq,
):
    BLOCK_SEQ = block_seq
    BLOCK_N = 16
    assert BLOCK_SEQ % BLOCK_N == 0
    # shape constraints
    Lq, Lk = q.shape[-1], k.shape[-1]
    assert Lq == Lk
    assert Lk in {16, 32, 64, 128}
    sm_scale = 1.0 / (Lk**0.5)
    batch, head_num = B_req_idx.shape[0], q.shape[1]
    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))
    gqa_group_size = q.shape[1] // k.shape[1]

    _fwd_kernel_flash_decode_stage1[grid](
        q,
        k,
        v,
        sm_scale,
        Req_to_tokens,
        B_req_idx,
        B_Seqlen,
        mid_out,
        mid_out_logsumexp,
        Req_to_tokens.stride(0),
        Req_to_tokens.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        mid_out.stride(0),
        mid_out.stride(1),
        mid_out.stride(2),
        mid_out.stride(3),
        mid_out_logsumexp.stride(0),
        mid_out_logsumexp.stride(1),
        mid_out_logsumexp.stride(2),
        gqa_group_size,
        BLOCK_SEQ=BLOCK_SEQ,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK_N,
        num_warps=1,
        num_stages=2,
    )
    return
@torch.no_grad()
def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
    Lk = mid_out.shape[-1]
    assert Lk in {16, 32, 64, 128}
    batch, head_num = mid_out.shape[0], mid_out.shape[1]
    grid = (batch, head_num)

    _fwd_kernel_flash_decode_stage2[grid](
        B_Seqlen,
        mid_out,
        mid_out_logexpsum,
        O,
        mid_out.stride(0),
        mid_out.stride(1),
        mid_out.stride(2),
        mid_out.stride(3),
        mid_out_logexpsum.stride(0),
        mid_out_logexpsum.stride(1),
        mid_out_logexpsum.stride(2),
        O.stride(0),
        O.stride(1),
        O.stride(2),
        BLOCK_SEQ=block_seq,
        BLOCK_DMODEL=Lk,
        num_warps=4,
        num_stages=2,
    )
    return
import torch


def flash_decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    attn_logits,
    max_len_in_batch,
    sm_scale,
    logit_cap=0.0,
):
    BLOCK_SEQ = 256
    kv_group_num = q.shape[1] // v_buffer.shape[1]
    # batch_size = q.shape[0]

    block_seq_num = (max_len_in_batch + BLOCK_SEQ - 1) // BLOCK_SEQ

    mid_o = torch.empty(
        [q.shape[0], q.shape[1], block_seq_num, q.shape[-1]],
        dtype=torch.float32,
        device="cuda",
    )
    mid_o_logexpsum = torch.empty(
        [q.shape[0], q.shape[1], block_seq_num], dtype=torch.float32, device="cuda"
    )
    flash_decode_stage1(
        q,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        b_seq_len,
        max_len_in_batch,
        mid_o,
        mid_o_logexpsum,
        BLOCK_SEQ,
    )
    flash_decode_stage2(mid_o, mid_o_logexpsum, b_seq_len, o, BLOCK_SEQ)
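The two stages above implement a split-sequence flash decode: stage 1 emits, per BLOCK_SEQ chunk, the chunk's normalized partial output acc / sum_exp together with its log-sum-exp max_logic + log(sum_exp), and stage 2 merges the chunks with a running-max rescale. A small single-head PyTorch check of that merge rule (illustrative values only, not the kernel itself):

import torch

torch.manual_seed(0)
D, S, BLOCK = 8, 40, 16
q = torch.randn(D)
k = torch.randn(S, D)
v = torch.randn(S, D)
scores = k @ q  # logits for one head (scaling omitted)

# Reference: full softmax attention over all S tokens.
ref = torch.softmax(scores, dim=0) @ v

# Stage 1: per-chunk partial output and log-sum-exp, like Mid_O / Mid_O_LogExpSum.
chunks = []
for s in range(0, S, BLOCK):
    sc = scores[s : s + BLOCK]
    m = sc.max()
    e = torch.exp(sc - m)
    chunks.append((e @ v[s : s + BLOCK] / e.sum(), m + torch.log(e.sum())))

# Stage 2: merge the chunks with a running max and rescaling.
acc = torch.zeros(D)
sum_exp, max_logic = 0.0, float("-inf")
for tv, tlogic in chunks:
    new_max_logic = max(max_logic, float(tlogic))
    old_scale = torch.exp(torch.tensor(max_logic - new_max_logic))
    exp_logic = torch.exp(tlogic - new_max_logic)
    acc = acc * old_scale + exp_logic * tv
    sum_exp = sum_exp * old_scale + exp_logic
    max_logic = new_max_logic

assert torch.allclose(acc / sum_exp, ref, atol=1e-4)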
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage1(
    # Double Sparsity's approximate attention
    Q_Label,
    K_Label_Buffer,
    sm_scale,
    Req_to_tokens,  # shape: [B, S]
    B_Seqlen,
    Att_Out,  # shape: [H, B, S] easier for topk
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    att_stride_h,
    att_stride_b,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    logit_cap: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_n = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_start_index = 0
    cur_batch_end_index = cur_batch_seq_len

    min_val = -float("inf")
    att_value = tl.full([BLOCK_N], min_val, dtype=tl.float32)

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

    block_index = start_n * BLOCK_N
    block_mask = tl.where(block_index < cur_batch_seq_len, 1, 0)

    for start_mark in range(0, block_mask, 1):
        q = tl.load(Q_Label + off_q + start_mark).to(REDUCE_TRITON_TYPE)
        offs_n_new = cur_batch_start_index + offs_n
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch + offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        offs_buf_k = (
            k_loc[:, None] * stride_buf_kbs
            + cur_kv_head * stride_buf_kh
            + offs_d[None, :]
        )
        k = tl.load(
            K_Label_Buffer + offs_buf_k,
            mask=offs_n_new[:, None] < cur_batch_end_index,
            other=0.0,
        ).to(REDUCE_TRITON_TYPE)

        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale

        if logit_cap > 0:
            att_value = logit_cap * tanh(att_value / logit_cap)

    att_value = tl.where(offs_n < cur_batch_end_index, att_value, min_val)
    off_o = cur_head * att_stride_h + (cur_batch * att_stride_b + offs_n)
    tl.store(Att_Out + off_o, att_value)
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage2(
    Q,
    K,
    V,
    sm_scale,
    Req_to_tokens,  # shape: [B, S]
    Topk_token_indices,  # shape: [H, B, k]
    Mid_O,  # [batch, head, seq_block_num, head_dim]
    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
    Heavy_token_num,  # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future
    stride_req_to_tokens_b,
    stride_topk_token_indices_h,
    stride_topk_token_indices_b,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_mid_o_eb,
    stride_mid_o_eh,
    gqa_group_size,
    BLOCK_SEQ: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    seq_start_block = tl.program_id(2)
    cur_kv_head = cur_head // gqa_group_size

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_start_index = seq_start_block * BLOCK_SEQ
    cur_batch_end_index = tl.minimum(
        Heavy_token_num, cur_batch_start_index + BLOCK_SEQ
    )

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

    block_n_size = (
        tl.where(
            cur_batch_end_index - cur_batch_start_index <= 0,
            0,
            cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,
        )
        // BLOCK_N
    )

    # offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)
    offs_n = tl.arange(0, BLOCK_N)

    q = tl.load(Q + off_q)

    sum_exp = 0.0
    max_logic = -float("inf")
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(cur_batch_start_index, cur_batch_end_index, BLOCK_N):
        # for start_n in range(0, block_n_size, 1):
        # offs_n_new = start_n * BLOCK_N + offs_n
        offs_n_new = start_n + offs_n
        # offs_n_new = cur_batch_start_index + start_n * BLOCK_N + offs_n
        topk_token_indices = tl.load(
            Topk_token_indices
            + stride_topk_token_indices_h * cur_head
            + stride_topk_token_indices_b * cur_batch
            + offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch + topk_token_indices,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]
        k = tl.load(
            K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
        )
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale
        att_value = tl.where(
            offs_n_new < cur_batch_end_index, att_value, float("-inf")
        )
        v = tl.load(
            V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
        )

        cur_max_logic = tl.max(att_value, axis=0)
        new_max_logic = tl.maximum(cur_max_logic, max_logic)

        exp_logic = tl.exp(att_value - new_max_logic)
        logic_scale = tl.exp(max_logic - new_max_logic)
        acc *= logic_scale
        acc += tl.sum(exp_logic[:, None] * v, axis=0)

        sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)
        max_logic = new_max_logic

    # need_store = tl.where(block_n_size == 0, 0, 1)
    need_store = 1
    for _ in range(0, need_store, 1):
        off_mid_o = (
            cur_batch * stride_mid_ob
            + cur_head * stride_mid_oh
            + seq_start_block * stride_mid_os
            + offs_d
        )
        off_mid_o_logexpsum = (
            cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block
        )
        tl.store(Mid_O + off_mid_o, acc / sum_exp)
        tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))
    return
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage3(
    Mid_O,  # [batch, head, seq_block_num, head_dim]
    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
    O,  # [batch, head, head_dim]
    seq_len,  # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_mid_o_eb,
    stride_mid_o_eh,
    stride_obs,
    stride_oh,
    BLOCK_SEQ: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    offs_d = tl.arange(0, BLOCK_DMODEL)

    block_n_size = tl.where(seq_len <= 0, 0, seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ

    sum_exp = 0.0
    max_logic = -float("inf")
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh
    for block_seq_n in range(0, block_n_size, 1):
        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
        new_max_logic = tl.maximum(tlogic, max_logic)

        old_scale = tl.exp(max_logic - new_max_logic)
        acc *= old_scale
        exp_logic = tl.exp(tlogic - new_max_logic)
        acc += exp_logic * tv

        sum_exp = sum_exp * old_scale + exp_logic
        max_logic = new_max_logic

    tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)
    return
def sparse_flash_decode_stage1(
    q_label,
    k_label_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    max_len_in_batch,
    sm_scale,
    logit_cap,
):
    BLOCK = 32
    # shape constraints
    Lq, Lk = q_label.shape[-1], k_label_buffer.shape[-1]
    assert Lq == Lk
    assert Lk in {16, 32, 64, 128, 256, 576}

    BLOCK_DMODEL = Lk

    batch, head_num = q_label.shape[0], q_label.shape[1]

    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
    kv_group_num = q_label.shape[1] // k_label_buffer.shape[1]

    if kv_group_num == 1:
        num_warps = 4
    else:
        num_warps = 2

    _sparse_fwd_kernel_flash_decode_stage1[grid](
        q_label,
        k_label_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q_label.stride(0),
        q_label.stride(1),
        k_label_buffer.stride(0),
        k_label_buffer.stride(1),
        att_out.stride(0),
        att_out.stride(1),
        kv_group_num,
        BLOCK_DMODEL,
        BLOCK,
        logit_cap,
        num_warps=num_warps,
        num_stages=1,
    )
@torch.no_grad()
def sparse_flash_decode_stage2(
    q,
    k,
    v,
    Req_to_tokens,
    Topk_token_indices,
    heavy_token_num,
    mid_out,
    mid_out_logsumexp,
    block_seq,
    sm_scale,
):
    BLOCK_SEQ = block_seq
    BLOCK_N = 16
    assert BLOCK_SEQ % BLOCK_N == 0
    # shape constraints
    Lq, Lk = q.shape[-1], k.shape[-1]
    assert Lq == Lk
    assert Lk in {16, 32, 64, 128}
    assert heavy_token_num == Topk_token_indices.shape[-1]
    # sm_scale = 1.0 / (Lk ** 0.5)
    batch, head_num = q.shape[0], q.shape[1]
    grid = (batch, head_num, triton.cdiv(heavy_token_num, BLOCK_SEQ))
    gqa_group_size = q.shape[1] // k.shape[1]

    _sparse_fwd_kernel_flash_decode_stage2[grid](
        q,
        k,
        v,
        sm_scale,
        Req_to_tokens,
        Topk_token_indices,
        mid_out,
        mid_out_logsumexp,
        heavy_token_num,
        Req_to_tokens.stride(0),
        Topk_token_indices.stride(0),
        Topk_token_indices.stride(1),
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        mid_out.stride(0),
        mid_out.stride(1),
        mid_out.stride(2),
        mid_out_logsumexp.stride(0),
        mid_out_logsumexp.stride(1),
        gqa_group_size,
        BLOCK_SEQ=BLOCK_SEQ,
        BLOCK_DMODEL=Lk,
        BLOCK_N=BLOCK_N,
        num_warps=1,
        num_stages=2,
    )
    return
@torch.no_grad()
def sparse_flash_decode_stage3(Seqlen, mid_out, mid_out_logexpsum, O, block_seq):
    Lk = mid_out.shape[-1]
    assert Lk in {16, 32, 64, 128}
    batch, head_num = mid_out.shape[0], mid_out.shape[1]
    grid = (batch, head_num)

    _sparse_fwd_kernel_flash_decode_stage3[grid](
        mid_out,
        mid_out_logexpsum,
        O,
        Seqlen,
        mid_out.stride(0),
        mid_out.stride(1),
        mid_out.stride(2),
        mid_out_logexpsum.stride(0),
        mid_out_logexpsum.stride(1),
        O.stride(0),
        O.stride(1),
        BLOCK_SEQ=block_seq,
        BLOCK_DMODEL=Lk,
        num_warps=4,
        num_stages=2,
    )
    return
def flash_decode_sparse_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    q_label,
    k_label_buffer,
    req_to_token,
    b_seq_len,
    max_len_in_batch,
    sm_scale,
    logit_cap,
    heavy_token_num=32,
    att_out_approx=None,
    mid_out=None,
    mid_o_logexpsum=None,
    BLOCK_SEQ=256,
):
    # TODO(Andy): Tune BLOCK_SEQ & BLOCK_D
    kv_group_num = q.shape[1] // v_buffer.shape[1]
    # batch_size = q.shape[0]

    # Step 1: BGEMV approximate attention (page implementation)
    if att_out_approx is None:
        att_out_approx = torch.empty(
            [q.shape[1], q.shape[0], max_len_in_batch],
            dtype=REDUCE_TORCH_TYPE,
            device=q.device,
        )

    if mid_out is None:
        block_seq_num = (heavy_token_num + BLOCK_SEQ - 1) // BLOCK_SEQ
        mid_out = torch.empty(
            [q.shape[0], q.shape[1], block_seq_num, q.shape[-1]],
            dtype=torch.float32,
            device=q.device,
        )
        mid_o_logexpsum = torch.empty(
            [q.shape[0], q.shape[1], block_seq_num],
            dtype=torch.float32,
            device=q.device,
        )

    sparse_flash_decode_stage1(
        q_label,
        k_label_buffer,
        att_out_approx,
        req_to_token,
        b_seq_len,
        max_len_in_batch,
        sm_scale,
        logit_cap,
    )

    # Step 2: TopK token selection
    # NOTE(Andy): Apply sparse decoding when min > heavy_token_num and max > sparse decoding threshold
    # TODO(Andy): Change a faster topk implementation
    topk_token_indices = torch.topk(att_out_approx, heavy_token_num, dim=-1).indices
    # topk_token_indices: [H, B, k], Req_to_tokens: [B, S]
    # topk_token_indices = torch.arange(0, heavy_token_num, device=q.device).unsqueeze(0).unsqueeze(0).expand(q.shape[1], q.shape[0], -1)

    sparse_flash_decode_stage2(
        q,
        k_buffer,
        v_buffer,
        req_to_token,
        topk_token_indices,
        heavy_token_num,
        mid_out,
        mid_o_logexpsum,
        BLOCK_SEQ,
        sm_scale,
    )

    sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
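To summarize what the three sparse stages compute, here is a dense, single-request PyTorch reference of the same pipeline: approximate per-head scores from the label channels, top-k token selection, then exact attention restricted to the selected tokens. It is a CPU sketch for understanding the math (GQA grouping and paged req_to_token indexing are deliberately omitted), not a replacement for the kernels, and all names and shapes are illustrative:

import torch


def sparse_decode_reference(q, k_cache, v_cache, q_label, k_label_cache, heavy_token_num, sm_scale):
    # q: [H, D], q_label: [H, d_label]
    # k_cache, v_cache: [S, H, D]; k_label_cache: [S, H, d_label]
    H, D = q.shape
    S = k_cache.shape[0]
    k = min(heavy_token_num, S)

    # Stage 1: approximate per-head scores using only the heavy label channels.
    approx = torch.einsum("hd,shd->hs", q_label, k_label_cache) * sm_scale  # [H, S]

    # Stage 2: keep the top-k tokens per head.
    topk = torch.topk(approx, k, dim=-1).indices  # [H, k]

    # Stage 3: exact attention restricted to the selected tokens.
    out = torch.empty(H, D)
    for h in range(H):
        ks = k_cache[topk[h], h, :]      # [k, D]
        vs = v_cache[topk[h], h, :]      # [k, D]
        scores = (ks @ q[h]) * sm_scale  # [k]
        probs = torch.softmax(scores, dim=-1)
        out[h] = probs @ vs
    return out


# Toy usage with made-up sizes.
H, D, d_label, S = 4, 16, 4, 32
q, q_label = torch.randn(H, D), torch.randn(H, d_label)
k_cache, v_cache = torch.randn(S, H, D), torch.randn(S, H, D)
k_label_cache = torch.randn(S, H, d_label)
o = sparse_decode_reference(
    q, k_cache, v_cache, q_label, k_label_cache, heavy_token_num=8, sm_scale=D**-0.5
)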
python/sglang/srt/mem_cache/memory_pool.py

@@ -231,3 +231,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
        else:
            self.kv_buffer[layer_id][loc] = cache_k


class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
    ):
        super().__init__(size, dtype, device)

        # [size, head_num, head_dim] for each layer
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]

        # [size, head_num, heavy_channel_num] for each layer
        self.label_buffer = [
            torch.empty(
                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label
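As a quick illustration of the extra state this pool carries: besides the usual per-layer K/V caches, it stores, per cached token, only the heavy_channel_num label channels of each key, which stage 1 of the sparse decode kernel later scans. A standalone sketch of that layout for a single layer, with plain tensors and made-up sizes (the extra "+1" slot mirrors the buffers above; treating slot 0 as a padding slot is an assumption about how the pool is indexed):

import torch

size, head_num, head_dim, heavy_channel_num = 16, 2, 8, 4
dtype, device = torch.float32, "cpu"

k_buffer = torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
v_buffer = torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
label_buffer = torch.empty(
    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
)

# Write three new tokens at cache locations 1, 2, 3, as set_kv_buffer does for one layer.
loc = torch.tensor([1, 2, 3])
cache_k = torch.randn(3, head_num, head_dim, dtype=dtype)
cache_v = torch.randn(3, head_num, head_dim, dtype=dtype)
cache_label = cache_k[..., :heavy_channel_num]  # stand-in for the gathered heavy channels

k_buffer[loc] = cache_k
v_buffer[loc] = cache_v
label_buffer[loc] = cache_label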
python/sglang/srt/model_executor/model_runner.py

@@ -18,6 +18,7 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache
...
@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
...
@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
...
@@ -99,6 +102,20 @@ class ModelRunner:
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
...
@@ -439,6 +456,16 @@ class ModelRunner:
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
...
@@ -475,12 +502,33 @@ class ModelRunner:
                     "Cross attention is not supported in the triton attention backend. "
                     "Please use `--attention-backend flashinfer`."
                 )
-            self.attn_backend = TritonAttnBackend(self)
+            if self.server_args.enable_double_sparsity:
+                self.attn_backend = DoubleSparseAttnBackend(self)
+            else:
+                self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
             )

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
...
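For context, init_double_sparsity_channel_config above expects the JSON produced offline by DoubleSparse's channel-profiling script (linked from the test below): one entry per attention projection, holding per-head channel indices with the most important channels listed first. A hedged sketch of building sorted_channels from such a file; the key format and the slicing follow the code above, while the config contents here are fabricated placeholders:

import torch

# Fabricated stand-in for the file at --ds-channel-config-path:
# 2 layers, 2 heads, head_dim 8; each value lists channel indices per head.
channel_config = {
    f"model.layers.{i}.self_attn.k_proj": [
        [5, 1, 7, 0, 2, 3, 4, 6],
        [2, 6, 0, 3, 1, 7, 5, 4],
    ]
    for i in range(2)
}
ds_heavy_channel_num = 4

sorted_channels = []
for i in range(2):
    key = "model.layers." + str(i) + ".self_attn" + ".k_proj"
    sorted_channels.append(
        torch.tensor(channel_config[key])[:, :ds_heavy_channel_num].contiguous()
    )
# Each entry has shape [num_heads, ds_heavy_channel_num] and is the gather index
# used to build k_label / q_label in the attention backend.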
python/sglang/srt/server_args.py

@@ -86,6 +86,14 @@ class ServerArgs:
    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: str = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
...
@@ -443,6 +451,43 @@ class ServerArgs:
            default=ServerArgs.json_model_override_args,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The minimum sequence length required before sparse decoding is applied in double sparsity attention",
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
...
test/srt/Llama-3.1-8B-Instruct.json (new file, 0 → 100644)

Source diff not shown: the file is too large to display inline.
test/srt/run_suite.py

@@ -11,6 +11,7 @@ suites = {
        "models/test_reward_models.py",
        "sampling/penaltylib",
        "test_chunked_prefill.py",
        "test_double_sparsity.py",
        "test_embedding_openai_server.py",
        "test_eval_accuracy_mini.py",
        "test_json_constrained.py",
...
test/srt/test_double_sparsity.py (new file, 0 → 100644)

import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestDoubleSparsity(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        dirpath = os.path.dirname(__file__)
        config_file = os.path.join(dirpath, "Llama-3.1-8B-Instruct.json")
        # NOTE: Generate the config file by running https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-double-sparsity",
                "--ds-channel-config-path",
                config_file,
                "--ds-heavy-channel-num",
                "32",
                "--ds-heavy-channel-type",
                "k",
                "--ds-heavy-token-num",
                "512",
                "--ds-sparse-decode-threshold",
                "0",
                "--max-total-tokens",
                "200000",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.65


if __name__ == "__main__":
    unittest.main()