sglang · Commit 061e5463 (unverified)

Support double sparsity (#1459)

Authored Oct 14, 2024 by Shuo Yang; committed via GitHub on Oct 14, 2024.
Parent: 0c1e8796
Showing 8 changed files with 1269 additions and 1 deletion (+1269, −1):
  python/sglang/srt/layers/attention/double_sparsity_backend.py               +281  −0
  python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py  +772  −0
  python/sglang/srt/mem_cache/memory_pool.py                                   +58  −0
  python/sglang/srt/model_executor/model_runner.py                             +49  −1
  python/sglang/srt/server_args.py                                             +45  −0
  test/srt/Llama-3.1-8B-Instruct.json                                           +1  −0
  test/srt/run_suite.py                                                         +1  −0
  test/srt/test_double_sparsity.py                                             +62  −0
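In the DoubleSparse approach this commit integrates (see the script referenced in test/srt/test_double_sparsity.py), decode-time attention is approximated in two steps: a small set of offline-selected "heavy" channels of Q/K is used to estimate attention scores cheaply, and full attention is then computed only over the top-scoring "heavy" tokens. The files below add a dedicated attention backend, Triton kernels, a KV pool variant that also stores the per-token heavy-channel "labels", new server flags, and a test.

A rough launch sketch (hedged: the `python -m sglang.launch_server` entrypoint and the model path are ordinary sglang conventions rather than part of this diff; the double-sparsity flag values mirror the test):

python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --enable-double-sparsity \
  --ds-channel-config-path Llama-3.1-8B-Instruct.json \
  --ds-heavy-channel-num 32 \
  --ds-heavy-channel-type k \
  --ds-heavy-token-num 512 \
  --ds-sparse-decode-threshold 0 \
  --max-total-tokens 200000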
python/sglang/srt/layers/attention/double_sparsity_backend.py (new file, mode 100644)

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner
class DoubleSparseAttnBackend(AttentionBackend):
    def __init__(self, model_runner: ModelRunner):
        # Lazy import to avoid the initialization of cuda context
        from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
            flash_decode_attention_fwd,
            flash_decode_sparse_attention_fwd,
        )
        from sglang.srt.layers.attention.triton_ops.extend_attention import (
            extend_attention_fwd,
        )

        super().__init__()

        self.decode_attention_fwd = flash_decode_attention_fwd
        self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd
        self.num_head = model_runner.model_config.num_attention_heads
        self.head_dim = model_runner.model_config.hidden_size // self.num_head
        self.heavy_token_num = model_runner.server_args.ds_heavy_token_num

        self.sorted_channels = model_runner.sorted_channels
        self.sparse_decode_thresold = (
            model_runner.server_args.ds_sparse_decode_threshold
        )
        self.att_out_approx: torch.Tensor = None
        self.mid_out: torch.Tensor = None
        self.mid_o_logexpsum: torch.Tensor = None

        # TODO: Change the hard-coded block_seq_num
        self.BLOCK_SEQ = 128

        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
            self.reduce_dtype = torch.float32
        else:
            self.reduce_dtype = torch.float16

        self.forward_metadata = None

        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Init auxiliary variables for triton attention backend."""

        if forward_batch.forward_mode.is_decode():
            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
            attn_logits = torch.empty(
                (self.num_head, total_num_tokens),
                dtype=self.reduce_dtype,
                device="cuda",
            )

            max_seq_len = torch.max(forward_batch.seq_lens).item()
            min_seq_len = torch.min(forward_batch.seq_lens).item()
            max_extend_len = None
            # NOTE: Align sequence order with req_to_token order
            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices
            ]

            bsz = forward_batch.seq_lens.shape[0]

            att_out_approx = torch.empty(
                [self.num_head, bsz, max_seq_len],
                dtype=self.reduce_dtype,
                device="cuda",
            )

            block_seq_num = (
                self.heavy_token_num + self.BLOCK_SEQ - 1
            ) // self.BLOCK_SEQ

            mid_out = torch.empty(
                [bsz, self.num_head, block_seq_num, self.head_dim],
                dtype=torch.float32,
                device="cuda",
            )
            mid_o_logexpsum = torch.empty(
                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
            )
            self.att_out_approx = att_out_approx
            self.mid_out = mid_out
            self.mid_o_logexpsum = mid_o_logexpsum

        else:
            start_loc = attn_logits = max_seq_len = min_seq_len = None
            prefix_lens = forward_batch.extend_prefix_lens
            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
            ds_req_to_token = None

        self.forward_metadata = (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        )
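For orientation, a small back-of-envelope sketch of the decode-time scratch buffers allocated above. It is illustrative only: the batch-shape numbers are made up, the defaults come from server_args.py further down, and the trailing comments reflect a reading of the buffer names rather than documented semantics.

heavy_token_num = 256          # default ds_heavy_token_num
BLOCK_SEQ = 128                # hard-coded block size above
block_seq_num = (heavy_token_num + BLOCK_SEQ - 1) // BLOCK_SEQ  # ceil(256 / 128) = 2

num_head, bsz, max_seq_len, head_dim = 32, 8, 4096, 128         # illustrative batch shape

att_out_approx_shape = (num_head, bsz, max_seq_len)       # approximate scores over every cached token
mid_out_shape = (bsz, num_head, block_seq_num, head_dim)  # partial outputs, one per block of heavy tokens
mid_o_logexpsum_shape = (bsz, num_head, block_seq_num)    # per-block log-sum-exp used for the final merge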
    def init_cuda_graph_state(self, max_bs: int):
        # TODO(Andy): Support CUDA graph for double sparse attention
        raise ValueError(
            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
        )
        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len

        self.cuda_graph_start_loc = torch.zeros(
            (max_bs,), dtype=torch.int32, device="cuda"
        )
        self.cuda_graph_attn_logits = torch.empty(
            (
                self.num_head,
                self.cuda_graph_max_total_num_tokens,
            ),
            dtype=self.reduce_dtype,
            device="cuda",
        )

    def init_forward_metadata_capture_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.forward_metadata = (
            self.cuda_graph_start_loc,
            self.cuda_graph_attn_logits,
            self.cuda_graph_max_seq_len,
            None,
        )

    def init_forward_metadata_replay_cuda_graph(
        self, bs: int, req_pool_indices, seq_lens
    ):
        self.cuda_graph_start_loc.zero_()
        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

    def get_cuda_graph_seq_len_fill_value(self):
        return 1
    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        k_label = torch.gather(
            k,
            2,
            self.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
        )

        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata
        self.extend_attention_fwd(
            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
            k.contiguous(),
            v.contiguous(),
            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
            forward_batch.req_to_token_pool.req_to_token,
            forward_batch.req_pool_indices,
            forward_batch.seq_lens,
            forward_batch.extend_seq_lens,
            forward_batch.extend_start_loc,
            max_extend_len,
            layer.scaling,
            layer.logit_cap,
        )
        return o
    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:
            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
        else:
            o = torch.empty_like(q)

        # TODO: Add min seqlen
        (
            start_loc,
            attn_logits,
            max_seq_len,
            min_seq_len,
            max_extend_len,
            ds_req_to_token,
        ) = self.forward_metadata

        k_label = torch.gather(
            k,
            2,
            self.sorted_channels[layer.layer_id]
            .unsqueeze(0)
            .expand(k.shape[0], -1, -1),
        )

        forward_batch.token_to_kv_pool.set_kv_buffer(
            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
        )

        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
        # and set a minimum value for sparse_decode
        if (
            min_seq_len < self.heavy_token_num
            or max_seq_len < self.sparse_decode_thresold
        ):
            self.decode_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
                forward_batch.req_to_token_pool.req_to_token,
                forward_batch.req_pool_indices,
                start_loc,
                forward_batch.seq_lens,
                attn_logits,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
            )
        else:
            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
            q_label = torch.gather(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                2,
                self.sorted_channels[layer.layer_id]
                .unsqueeze(0)
                .expand(q.shape[0], -1, -1),
            )
            self.decode_sparse_attention_fwd(
                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
                q_label,
                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
                ds_req_to_token,
                forward_batch.seq_lens,
                max_seq_len,
                layer.scaling,
                layer.logit_cap,
                self.heavy_token_num,
                self.att_out_approx,
                self.mid_out,
                self.mid_o_logexpsum,
                self.BLOCK_SEQ,
            )

        return o
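Both forward paths build `k_label` by gathering the per-head heavy channels out of the freshly computed K. A minimal, self-contained sketch of that gather (shapes and values are illustrative, not taken from the commit):

import torch

num_tokens, head_num, head_dim, heavy_channel_num = 4, 8, 128, 32

k = torch.randn(num_tokens, head_num, head_dim)
# One row of indices per head: which channels of K to keep as the "label".
sorted_channels = torch.randint(0, head_dim, (head_num, heavy_channel_num))

k_label = torch.gather(
    k,
    2,
    sorted_channels.unsqueeze(0).expand(k.shape[0], -1, -1),
)
assert k_label.shape == (num_tokens, head_num, heavy_channel_num)

The sparse decode branch applies the same pattern to Q to produce `q_label`.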
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py (new file, mode 100644): diff collapsed in this view (+772 lines, not shown).
python/sglang/srt/mem_cache/memory_pool.py

@@ -231,3 +231,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
        else:
            self.kv_buffer[layer_id][loc] = cache_k


class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
    ):
        super().__init__(size, dtype, device)

        # [size, head_num, head_dim] for each layer
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
            for _ in range(layer_num)
        ]

        # [size, head_num, heavy_channel_num] for each layer
        self.label_buffer = [
            torch.empty(
                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label
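A hedged usage sketch of the new pool (sizes are illustrative; it assumes `BaseTokenToKVPool.__init__(size, dtype, device)` accepts `device="cpu"` for a toy run, which this diff does not show):

import torch
from sglang.srt.mem_cache.memory_pool import DoubleSparseTokenToKVPool

pool = DoubleSparseTokenToKVPool(
    size=1024,
    dtype=torch.float16,
    head_num=8,
    head_dim=128,
    layer_num=2,
    device="cpu",
    heavy_channel_num=32,
)

loc = torch.tensor([0, 1, 2])
k = torch.randn(3, 8, 128, dtype=torch.float16)
v = torch.randn(3, 8, 128, dtype=torch.float16)
k_label = torch.randn(3, 8, 32, dtype=torch.float16)

# Alongside K/V, the pool stores the per-token "label" slice used for approximate scoring.
pool.set_kv_buffer(layer_id=0, loc=loc, cache_k=k, cache_v=v, cache_label=k_label)
assert pool.get_label_buffer(0).shape == (1024 + 1, 8, 32)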
python/sglang/srt/model_executor/model_runner.py

@@ -18,6 +18,7 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache

@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput

@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
    DoubleSparseTokenToKVPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,

@@ -99,6 +102,20 @@ class ModelRunner:
            logger.info("MLA optimization is turned on. Use triton backend.")
            self.server_args.attention_backend = "triton"

        if self.server_args.enable_double_sparsity:
            logger.info(
                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
            )
            self.server_args.attention_backend = "triton"
            self.server_args.disable_cuda_graph = True
            if self.server_args.ds_heavy_channel_type is None:
                raise ValueError(
                    "Please specify the heavy channel type for double sparsity optimization."
                )
            self.init_double_sparsity_channel_config(
                self.server_args.ds_heavy_channel_type
            )

        if self.is_multimodal_model:
            logger.info(
                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."

@@ -439,6 +456,16 @@ class ModelRunner:
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
            )
        elif self.server_args.enable_double_sparsity:
            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(self.tp_size),
                head_dim=self.model_config.head_dim,
                layer_num=self.model_config.num_hidden_layers,
                device=self.device,
                heavy_channel_num=self.server_args.ds_heavy_channel_num,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,

@@ -475,12 +502,33 @@ class ModelRunner:
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
            )
            if self.server_args.enable_double_sparsity:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
            )

    def init_double_sparsity_channel_config(self, selected_channel):
        selected_channel = "." + selected_channel + "_proj"
        self.sorted_channels = []
        # load channel config
        with open(self.server_args.ds_channel_config_path, "r") as f:
            channel_config = json.load(f)

        for i in range(self.model_config.num_hidden_layers):
            key = "model.layers." + str(i) + ".self_attn" + selected_channel
            self.sorted_channels.append(
                torch.tensor(channel_config[key])[
                    :, : self.server_args.ds_heavy_channel_num
                ]
                .contiguous()
                .cuda()
            )

    def init_cuda_graphs(self):
        """Capture cuda graphs."""
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
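The channel config consumed here is generated offline (the test points to DoubleSparse's group_channel_config.py). From the way it is read, each key `model.layers.<i>.self_attn.<type>_proj` maps to a nested list whose rows are per-head channel indices ordered by importance, of which only the first `ds_heavy_channel_num` per head are kept. A hand-written, illustrative fragment (not a real calibration output):

channel_config = {
    "model.layers.0.self_attn.k_proj": [
        [57, 3, 121, 88],   # head 0: most important K channels first
        [12, 99, 45, 7],    # head 1
        # one row per head, each with at least ds_heavy_channel_num entries
    ],
    "model.layers.1.self_attn.k_proj": [
        # ...
    ],
    # keys exist for every layer and for the chosen --ds-heavy-channel-type ("q", "k", or "qk")
}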
python/sglang/srt/server_args.py

@@ -86,6 +86,14 @@ class ServerArgs:
    # Model override args in JSON
    json_model_override_args: str = "{}"

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: str = None
    ds_heavy_channel_num: int = 32
    ds_heavy_token_num: int = 256
    ds_heavy_channel_type: str = "qk"
    ds_sparse_decode_threshold: int = 4096

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
@@ -443,6 +451,43 @@ class ServerArgs:
            default=ServerArgs.json_model_override_args,
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",
            action="store_true",
            help="Enable double sparsity attention",
        )
        parser.add_argument(
            "--ds-channel-config-path",
            type=str,
            default=ServerArgs.ds_channel_config_path,
            help="The path of the double sparsity channel config",
        )
        parser.add_argument(
            "--ds-heavy-channel-num",
            type=int,
            default=ServerArgs.ds_heavy_channel_num,
            help="The number of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-token-num",
            type=int,
            default=ServerArgs.ds_heavy_token_num,
            help="The number of heavy tokens in double sparsity attention",
        )
        parser.add_argument(
            "--ds-heavy-channel-type",
            type=str,
            default=ServerArgs.ds_heavy_channel_type,
            help="The type of heavy channels in double sparsity attention",
        )
        parser.add_argument(
            "--ds-sparse-decode-threshold",
            type=int,
            default=ServerArgs.ds_sparse_decode_threshold,
            help="The minimum sequence length required to use sparse decoding in double sparsity attention",
        )

        # LoRA
        parser.add_argument(
            "--lora-paths",
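For reference, the two thresholds above interact in DoubleSparseAttnBackend.forward_decode: the sparse kernel is used only when every sequence in the batch already holds at least ds_heavy_token_num tokens and the longest sequence has reached ds_sparse_decode_threshold; otherwise decoding falls back to the dense kernel. A small paraphrase of that rule (values illustrative):

# Paraphrase of the dispatch check in forward_decode; names mirror the backend, values are made up.
def use_sparse_decode(min_seq_len, max_seq_len, ds_heavy_token_num=256, ds_sparse_decode_threshold=4096):
    dense = min_seq_len < ds_heavy_token_num or max_seq_len < ds_sparse_decode_threshold
    return not dense

assert use_sparse_decode(min_seq_len=8192, max_seq_len=8192) is True
assert use_sparse_decode(min_seq_len=100, max_seq_len=8192) is False   # batch contains a short sequence
assert use_sparse_decode(min_seq_len=512, max_seq_len=1024) is False   # below the decode threshold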
test/srt/Llama-3.1-8B-Instruct.json (new file, mode 100644): diff collapsed in this view (+1 line, not shown).
test/srt/run_suite.py

@@ -11,6 +11,7 @@ suites = {
        "models/test_reward_models.py",
        "sampling/penaltylib",
        "test_chunked_prefill.py",
        "test_double_sparsity.py",
        "test_embedding_openai_server.py",
        "test_eval_accuracy_mini.py",
        "test_json_constrained.py",
test/srt/test_double_sparsity.py (new file, mode 100644)

import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestDoubleSparsity(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        dirpath = os.path.dirname(__file__)
        config_file = os.path.join(dirpath, "Llama-3.1-8B-Instruct.json")
        # NOTE: Generate the config file by running https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-double-sparsity",
                "--ds-channel-config-path",
                config_file,
                "--ds-heavy-channel-num",
                "32",
                "--ds-heavy-channel-type",
                "k",
                "--ds-heavy-token-num",
                "512",
                "--ds-sparse-decode-threshold",
                "0",
                "--max-total-tokens",
                "200000",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.65


if __name__ == "__main__":
    unittest.main()
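The test launches a real server with the flags shown above and runs a 64-question MMLU eval against it, so it needs a GPU, the model weights, and the channel-config JSON generated as the NOTE describes. With those in place it can be run directly, for example:

python3 test/srt/test_double_sparsity.py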