Unverified commit 061e5463, authored Oct 14, 2024 by Shuo Yang; committed by GitHub on Oct 14, 2024

Support double sparsity (#1459)
Parent: 0c1e8796

Showing 8 changed files with 1269 additions and 1 deletion (+1269, -1)
python/sglang/srt/layers/attention/double_sparsity_backend.py (+281, -0)
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py (+772, -0)
python/sglang/srt/mem_cache/memory_pool.py (+58, -0)
python/sglang/srt/model_executor/model_runner.py (+49, -1)
python/sglang/srt/server_args.py (+45, -0)
test/srt/Llama-3.1-8B-Instruct.json (+1, -0)
test/srt/run_suite.py (+1, -0)
test/srt/test_double_sparsity.py (+62, -0)
python/sglang/srt/layers/attention/double_sparsity_backend.py (new file, mode 100644)
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner


class DoubleSparseAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
            flash_decode_attention_fwd,
            flash_decode_sparse_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = flash_decode_attention_fwd
        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = model_runner.model_config.num_attention_heads

        self.head_dim = model_runner.model_config.hidden_size // self.num_head
        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num

        self.sorted_channels = model_runner.sorted_channels
        self.sparse_decode_thresold = (
            model_runner.server_args.ds_sparse_decode_threshold
        )
        self.att_out_approx: torch.Tensor = None
        self.mid_out: torch.Tensor = None
        self.mid_o_logexpsum: torch.Tensor = None

        # TODO: Change the hard-coded block_seq_num
        self.BLOCK_SEQ = 128

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        self.forward_metadata = None

        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        if forward_batch.forward_mode.is_decode():
            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device="cuda",
            )

            max_seq_len = torch.max(forward_batch.seq_lens).item()
            min_seq_len = torch.min(forward_batch.seq_lens).item()
            max_extend_len = None
            # NOTE: Align sequence order with req_to_token order
            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices
            ]

            bsz = forward_batch.seq_lens.shape[0]

            att_out_approx = torch.empty(
                [self.num_head, bsz, max_seq_len],
                dtype=self.reduce_dtype,
                device="cuda",
            )

            block_seq_num = (
                self.heavy_token_num + self.BLOCK_SEQ - 1
            ) // self.BLOCK_SEQ

            mid_out = torch.empty(
                [bsz, self.num_head, block_seq_num, self.head_dim],
                dtype=torch.float32,
                device="cuda",
            )
            mid_o_logexpsum = torch.empty(
                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
            )
            self.att_out_approx = att_out_approx
            self.mid_out = mid_out
            self.mid_o_logexpsum = mid_o_logexpsum

        else:
            start_loc = attn_logits = max_seq_len = min_seq_len = None
            prefix_lens = forward_batch.extend_prefix_lens
            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
            ds_req_to_token = None

        self.forward_metadata = (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        )

    def init_cuda_graph_state(self, max_bs: int):
        # TODO(Andy): Support CUDA graph for double sparse attention
        raise ValueError(
            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
        )
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_attn_logits = torch.empty(
            (
                self.num_head,
                self.cuda_graph_max_total_num_tokens,
            ),
            dtype=self.reduce_dtype,
            device="cuda",
        )

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,
            self.cuda_graph_attn_logits,
            self.cuda_graph_max_seq_len,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        k_label = torch.gather(
            k,
            2,
            self.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
        )

        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o

    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        # TODO: Add min seqlen
        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata

        k_label = torch.gather(
            k,
            2,
            self.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
        )

        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
        # and set a minimum value for sparse_decode
        if (
            min_seq_len < self.heavy_token_num
            or max_seq_len < self.sparse_decode_thresold
        ):
            self.decode_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                forward_batch.req_to_token_pool.req_to_token,
                forward_batch.req_pool_indices,
                start_loc,
                forward_batch.seq_lens,
                attn_logits,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
            )
        else:
            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
            q_label = torch.gather(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                2,
                self.sorted_channels[layer.layer_id]
                .unsqueeze(0)
                .expand(q.shape[0], -1, -1),
            )
            self.decode_sparse_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                q_label,
                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
                ds_req_to_token,
                forward_batch.seq_lens,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
                self.heavy_token_num,
                self.att_out_approx,
                self.mid_out,
                self.mid_o_logexpsum,
                self.BLOCK_SEQ,
            )

        return o
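Both forward_extend and forward_decode above compress each incoming key into a low-dimensional "label" by gathering only the pre-selected heavy channels. The following standalone sketch illustrates just that indexing pattern with made-up shapes; it is not part of the commit.

import torch

num_tokens, num_heads, head_dim, heavy_channel_num = 4, 8, 128, 16

# k: one key vector per token and head, as passed into forward_extend/forward_decode
k = torch.randn(num_tokens, num_heads, head_dim)

# sorted_channels[layer_id]: per-head indices of the heavy channels,
# shape [num_heads, heavy_channel_num]
sorted_channels = torch.randint(0, head_dim, (num_heads, heavy_channel_num))

# Same gather as in the backend: broadcast the per-head channel indices over the
# token dimension and pick those channels out of the last (head_dim) axis.
k_label = torch.gather(k, 2, sorted_channels.unsqueeze(0).expand(num_tokens, -1, -1))
assert k_label.shape == (num_tokens, num_heads, heavy_channel_num)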
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py (new file, mode 100644)

This diff is collapsed and is not shown here.
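The collapsed file holds the Triton kernels the backend dispatches to (flash_decode_attention_fwd and flash_decode_sparse_attention_fwd). As a rough orientation only, the sketch below restates the double-sparsity decode idea in plain PyTorch under our own shape assumptions: score every cached token approximately using the heavy-channel labels, keep the heavy_token_num best tokens per head, then run exact attention over that subset. The real kernel additionally splits the selected tokens into BLOCK_SEQ-sized blocks and reduces them through mid_out / mid_o_logexpsum, which this reference does not model, and its exact signature differs.

import torch

def sparse_decode_reference(q, k_cache, v_cache, q_label, k_label_cache,
                            heavy_token_num, scaling):
    # q: [num_heads, head_dim]; q_label: [num_heads, heavy_channel_num]
    # k_cache / v_cache: [seq_len, num_heads, head_dim]
    # k_label_cache: [seq_len, num_heads, heavy_channel_num]
    num_heads = q.shape[0]
    seq_len = k_cache.shape[0]

    # 1) Approximate attention scores from the heavy channels only.
    approx = torch.einsum("hc,shc->hs", q_label, k_label_cache) * scaling

    # 2) Keep the heavy_token_num most promising cached tokens per head.
    topk = torch.topk(approx, min(heavy_token_num, seq_len), dim=-1).indices

    # 3) Exact attention restricted to the selected tokens.
    out = torch.empty(num_heads, v_cache.shape[-1])
    for h in range(num_heads):
        k_sel = k_cache[topk[h], h]            # [heavy_token_num, head_dim]
        v_sel = v_cache[topk[h], h]            # [heavy_token_num, head_dim]
        scores = (k_sel @ q[h]) * scaling      # [heavy_token_num]
        out[h] = torch.softmax(scores, dim=0) @ v_sel
    return out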
python/sglang/srt/mem_cache/memory_pool.py

@@ -231,3 +231,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
        else:
            self.kv_buffer[layer_id][loc] = cache_k


class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
    ):
        super().__init__(size, dtype, device)

        # [size, head_num, head_dim] for each layer
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]

        # [size, head_num, heavy_channel_num] for each layer
        self.label_buffer = [
            torch.empty(
                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label
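Compared with MHATokenToKVPool, the new pool carries a third per-layer buffer holding the heavy-channel labels of every cached key, so sparse decode can score tokens without reading full keys. A minimal sketch of the buffer layout and of what set_kv_buffer boils down to for one layer (toy shapes, not part of the commit):

import torch

size, head_num, head_dim, heavy_channel_num = 16, 2, 8, 4

# Per-layer buffers, mirroring DoubleSparseTokenToKVPool for a single layer.
k_buffer = torch.empty(size + 1, head_num, head_dim)
v_buffer = torch.empty(size + 1, head_num, head_dim)
label_buffer = torch.empty(size + 1, head_num, heavy_channel_num)

# out_cache_loc: the slots assigned to the three tokens of this batch.
loc = torch.tensor([3, 7, 8])
cache_k = torch.randn(3, head_num, head_dim)
cache_v = torch.randn(3, head_num, head_dim)
cache_label = cache_k[..., :heavy_channel_num]  # stand-in for the gathered k_label

# set_kv_buffer is an indexed write of k, v, and label at those slots.
k_buffer[loc] = cache_k
v_buffer[loc] = cache_v
label_buffer[loc] = cache_label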
python/sglang/srt/model_executor/model_runner.py

@@ -18,6 +18,7 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache

@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput

@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,

@@ -99,6 +102,20 @@ class ModelRunner:
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."

@@ -439,6 +456,16 @@ class ModelRunner:
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,

@@ -475,12 +502,33 @@ class ModelRunner:
                    "Cross attention is not supported in the triton attention backend. "
                    "Please use `--attention-backend flashinfer`."
                )

            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
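init_double_sparsity_channel_config above expects the JSON file given by --ds-channel-config-path to map keys of the form "model.layers.{i}.self_attn.{q,k,qk}_proj" to a per-head ranking of channel indices, of which only the first ds_heavy_channel_num per head are kept. The sketch below replays that loading step on a toy in-memory config; the [num_heads, head_dim] shape of each entry is our assumption, inferred from the slicing and the later unsqueeze/expand, and the real config is produced by the DoubleSparse group_channel_config.py script referenced in the new test.

import torch

num_layers, num_heads, head_dim, ds_heavy_channel_num = 2, 4, 8, 2

# Toy stand-in for the JSON config: for every layer, each head lists its
# channel indices sorted by importance (assumed schema).
channel_config = {
    f"model.layers.{i}.self_attn.k_proj": [
        torch.randperm(head_dim).tolist() for _ in range(num_heads)
    ]
    for i in range(num_layers)
}

selected_channel = ".k_proj"  # corresponds to --ds-heavy-channel-type k
sorted_channels = []
for i in range(num_layers):
    key = "model.layers." + str(i) + ".self_attn" + selected_channel
    sorted_channels.append(
        torch.tensor(channel_config[key])[:, :ds_heavy_channel_num].contiguous()
    )

# The real code additionally moves each tensor to the GPU with .cuda().
assert sorted_channels[0].shape == (num_heads, ds_heavy_channel_num)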
python/sglang/srt/server_args.py

@@ -86,6 +86,14 @@ class ServerArgs:
    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: str = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

@@ -443,6 +451,43 @@ class ServerArgs:
            default=ServerArgs.json_model_override_args,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The minimum sequence length needed to use sparse decoding in double sparsity attention",
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
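Taken together with the backend above, the two decode-time knobs work as follows: a batch falls back to the dense decode kernel whenever its shortest sequence is still below ds_heavy_token_num or its longest sequence is below ds_sparse_decode_threshold; otherwise the sparse path attends only to the selected heavy tokens. A small sketch of that rule (the helper name is ours, not part of the commit):

def use_sparse_decode(min_seq_len: int, max_seq_len: int,
                      ds_heavy_token_num: int = 256,
                      ds_sparse_decode_threshold: int = 4096) -> bool:
    # Mirrors the condition in DoubleSparseAttnBackend.forward_decode: stay on
    # the dense kernel when sequences are too short for sparsity to pay off.
    return not (
        min_seq_len < ds_heavy_token_num or max_seq_len < ds_sparse_decode_threshold
    )

# With the defaults above, a batch whose sequences span 300-5000 tokens takes the
# sparse path, while one capped at 2000 tokens stays dense.
assert use_sparse_decode(300, 5000) is True
assert use_sparse_decode(300, 2000) is False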
test/srt/Llama-3.1-8B-Instruct.json (new file, mode 100644)

This diff is collapsed and is not shown here.
test/srt/run_suite.py

@@ -11,6 +11,7 @@ suites = {
        "models/test_reward_models.py",
        "sampling/penaltylib",
        "test_chunked_prefill.py",
        "test_double_sparsity.py",
        "test_embedding_openai_server.py",
        "test_eval_accuracy_mini.py",
        "test_json_constrained.py",
test/srt/test_double_sparsity.py (new file, mode 100644)

import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestDoubleSparsity(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        dirpath = os.path.dirname(__file__)
        config_file = os.path.join(dirpath, "Llama-3.1-8B-Instruct.json")
        # NOTE: Generate the config file by running https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-double-sparsity",
                "--ds-channel-config-path",
                config_file,
                "--ds-heavy-channel-num",
                "32",
                "--ds-heavy-channel-type",
                "k",
                "--ds-heavy-token-num",
                "512",
                "--ds-sparse-decode-threshold",
                "0",
                "--max-total-tokens",
                "200000",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.65


if __name__ == "__main__":
    unittest.main()