change / sglang · Commits

Commit e119f042 (Unverified)
Authored Apr 01, 2025 by Zhiqiang Xie; committed by GitHub on Apr 01, 2025
Parent: 9eb49e87

Large page size aligned hierarchical caching (#4581)

Showing 8 changed files with 242 additions and 71 deletions (+242, -71):
python/sglang/srt/managers/cache_controller.py  +34 -11
python/sglang/srt/managers/scheduler.py          +1  -1
python/sglang/srt/mem_cache/hiradix_cache.py    +60 -50
python/sglang/srt/mem_cache/memory_pool.py      +66  -6
python/sglang/srt/mem_cache/paged_allocator.py  +24  -0
test/srt/test_hicache.py                         +4  -2
test/srt/test_hicache_mla.py                     +4  -1
test/srt/test_hicache_page.py                   +49  -0
python/sglang/srt/managers/cache_controller.py

@@ -149,6 +149,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
+        page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
@@ -156,6 +157,7 @@ class HiCacheController:
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
+        self.page_size = page_size

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -184,7 +186,12 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()

         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -205,7 +212,12 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()

         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -260,10 +272,12 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_device.get_flat_data(
-                    operation.device_indices
-                )
-                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
                     if node_id != 0:
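A note on the synchronize() added above: write_page_all_layers (added in
memory_pool.py below) issues its per-page copies with non_blocking=True, so
the write thread must wait on the stream before complete_io marks the host
indices as ready, assuming the copies are enqueued on write_stream. A
minimal, self-contained sketch of this pattern (not the sglang API):

import torch

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    src = torch.ones(1024, device="cuda")
    dst = torch.empty(1024, pin_memory=True)
    with torch.cuda.stream(stream):
        # Async device-to-host copy; returns before the data has landed.
        dst.copy_(src, non_blocking=True)
    # Block until the copy has finished before consuming dst.
    stream.synchronize()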
@@ -320,12 +334,21 @@ class HiCacheController:
                 self.layer_done_counter.reset()
                 for i in range(self.mem_pool_host.layer_num):
-                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                        batch_operation.host_indices, i
-                    )
-                    self.mem_pool_device.transfer_per_layer(
-                        batch_operation.device_indices, flat_data, i
-                    )
+                    if self.page_size == 1:
+                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                            batch_operation.host_indices, i
+                        )
+                        self.mem_pool_device.transfer_per_layer(
+                            batch_operation.device_indices, flat_data, i
+                        )
+                    else:
+                        self.mem_pool_host.load_page_per_layer(
+                            batch_operation.host_indices,
+                            batch_operation.device_indices,
+                            self.mem_pool_device,
+                            i,
+                        )
                     self.load_stream.synchronize()
                     self.layer_done_counter.increment()

                 self.mem_pool_host.complete_io(batch_operation.host_indices)
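Taken together, the controller changes thread page_size through from the
constructor to the worker threads: page_size == 1 keeps the original
double-buffered writer and flat-data loads, while any larger page size
switches to the direct page-copy paths. The worker is chosen once, at thread
creation time, rather than branching per operation; a runnable sketch of
that selection pattern with simplified stand-in functions:

import threading

def write_thread_func_buffer():
    pass  # stand-in for the double-buffered path (page_size == 1)

def write_thread_func_direct():
    pass  # stand-in for the direct page-copy path (page_size > 1)

page_size = 32
write_thread = threading.Thread(
    target=(write_thread_func_buffer if page_size == 1 else write_thread_func_direct),
    daemon=True,
)
write_thread.start()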
python/sglang/srt/managers/scheduler.py

@@ -1282,7 +1282,7 @@ class Scheduler(
             ]

         if self.enable_hierarchical_cache:
-            self.tree_cache.read_to_load_cache()
+            self.tree_cache.ready_to_load_cache()

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
python/sglang/srt/mem_cache/hiradix_cache.py

@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
-from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

 logger = logging.getLogger(__name__)
@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
         page_size: int,
         hicache_ratio: float,
     ):
-        if page_size != 1:
-            raise ValueError(
-                "Page size larger than 1 is not yet supported in HiRadixCache."
-            )
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         else:
-            raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
+            raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
-        self.page_size = page_size

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
+            page_size,
             load_cache_event=self.load_cache_event,
         )
@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = 1
         self.load_back_threshold = 10
         super().__init__(
-            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
         )

     def reset(self):
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):
         return last_node, prefix_indices

-    def read_to_load_cache(self):
+    def ready_to_load_cache(self):
         self.load_cache_event.set()

     def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
-        if self.disable:
-            return [], self.root_node
+        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
+        if self.disable or len(key) == 0:
+            if include_evicted:
+                return empty_value, self.root_node, self.root_node
+            else:
+                return empty_value, self.root_node
+
+        if self.page_size != 1:
+            page_aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:page_aligned_len]

         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.cat(value)
         else:
-            value = torch.tensor([], dtype=torch.int64)
+            value = empty_value

         last_node_global = last_node
         while last_node.evicted:
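The alignment step above only matters for page_size > 1: prefix matching is
restricted to whole pages, so any trailing partial page of the key is dropped
before walking the tree. A small numeric illustration (values arbitrary):

page_size = 32
key = list(range(70))  # 70 tokens: two full pages plus 6 leftover tokens
page_aligned_len = len(key) // page_size * page_size  # 64
key = key[:page_aligned_len]  # only the two complete pages are matched
assert len(key) == 64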
@@ -323,11 +326,13 @@ class HiRadixCache(RadixCache):
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
+        child_key = self.get_child_key_fn(key)

         value = []
-        while len(key) > 0 and key[0] in node.children.keys():
-            child = node.children[key[0]]
+        while len(key) > 0 and child_key in node.children.keys():
+            child = node.children[child_key]
             child.last_access_time = time.time()
-            prefix_len = _key_match(child.key, key)
+            prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 if not new_node.evicted:
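With paging, children are keyed by their first page rather than their first
token, and key matching proceeds in whole-page units. The hunk relies on the
page-size-aware helpers that RadixCache selects (get_child_key_fn and
key_match_fn); a simplified, hypothetical sketch of what such helpers could
look like for page_size > 1, purely for intuition:

page_size = 4

def get_child_key_fn(key):
    # The first page of the key identifies the child edge.
    return tuple(key[:page_size])

def key_match_fn(key0, key1):
    # Count the matched prefix in whole-page steps; a trailing partial
    # page never counts as a match.
    i = 0
    while (
        len(key0[i : i + page_size]) == page_size
        and key0[i : i + page_size] == key1[i : i + page_size]
    ):
        i += page_size
    return i

assert key_match_fn([1, 2, 3, 4, 9], [1, 2, 3, 4, 5]) == 4  # one full page
assert key_match_fn([1, 2, 3], [1, 2, 3]) == 0              # partial page only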
@@ -339,12 +344,16 @@ class HiRadixCache(RadixCache):
                 value.append(child.value)
             node = child
             key = key[prefix_len:]
+            if len(key):
+                child_key = self.get_child_key_fn(key)

         return value, node

     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len]: child}
+        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -361,7 +370,7 @@ class HiRadixCache(RadixCache):
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
         child.key = child.key[split_len:]
-        new_node.parent.children[key[0]] = new_node
+        new_node.parent.children[self.get_child_key_fn(key)] = new_node
         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
@@ -369,52 +378,53 @@ class HiRadixCache(RadixCache):
         if len(key) == 0:
             return 0

-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
+        child_key = self.get_child_key_fn(key)
+
+        total_prefix_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.time()
+            prefix_len = self.key_match_fn(node.key, key)

-            if prefix_len == len(child.key):
-                if child.evicted:
+            if prefix_len == len(node.key):
+                if node.evicted:
                     # change the reference if the node is evicted
                     # this often happens in the case of KV cache recomputation
-                    child.value = value[:prefix_len]
-                    self.token_to_kv_pool_host.update_synced(child.host_value)
-                    self.evictable_size_ += len(value[:prefix_len])
-                    return self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
+                    node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(node.host_value)
+                    self.evictable_size_ += len(node.value)
                 else:
-                    self.inc_hit_count(child)
-                    return prefix_len + self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
-
-            # partial match, split the node
-            new_node = self._split_node(child.key, child, prefix_len)
-            if new_node.evicted:
-                new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
-                self.evictable_size_ += len(new_node.value)
-                return self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+                    self.inc_hit_count(node)
+                    total_prefix_length += prefix_len
             else:
-                self.inc_hit_count(new_node)
-                return prefix_len + self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+                # partial match, split the node
+                new_node = self._split_node(node.key, node, prefix_len)
+                if new_node.evicted:
+                    new_node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                    self.evictable_size_ += len(new_node.value)
+                else:
+                    self.inc_hit_count(new_node)
+                    total_prefix_length += prefix_len
+                node = new_node
+
+            key = key[prefix_len:]
+            value = value[prefix_len:]
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)

         if len(key):
             new_node = TreeNode()
             new_node.parent = node
             new_node.key = key
             new_node.value = value
-            node.children[key[0]] = new_node
+            node.children[child_key] = new_node
             self.evictable_size_ += len(value)
             if self.cache_controller.write_policy == "write_through":
                 self.write_backup(new_node)
-        return 0
+        return total_prefix_length

     def _collect_leaves_device(self):
         def is_leaf(node):
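Two things change in _insert_helper at once: the per-token recursion becomes
an iterative walk that descends one matched child per loop turn, and the
return value now reports total_prefix_length. Reading the hunk, that counter
accumulates prefix_len only in the non-evicted branches, so it appears to
report how many of the inserted tokens were already resident on device along
the walked path, replacing the recursive prefix sum of the old code.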
python/sglang/srt/mem_cache/memory_pool.py

@@ -608,8 +608,9 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
-        device: str = "cpu",
+        pin_memory: bool,
+        device: str,
+        page_size: int,
     ):
         assert (
             host_to_device_ratio >= 1
@@ -620,8 +621,11 @@ class HostKVCache(abc.ABC):
         self.host_to_device_ratio = host_to_device_ratio
         self.pin_memory = pin_memory
         self.device = device
+        self.page_size = page_size

         self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
         self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
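The new alignment line guarantees the host pool holds a whole number of
pages. A numeric illustration with made-up sizes:

page_size = 32
size = int(100003 * 2.0)          # device size * host_to_device_ratio -> 200006
size = size - (size % page_size)  # -> 200000, a multiple of 32
assert size % page_size == 0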
@@ -775,10 +779,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.head_num = self.device_pool.head_num
@@ -811,16 +818,48 @@ class MHATokenToKVPoolHost(HostKVCache):
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data

+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+

 class MLATokenToKVPoolHost(HostKVCache):
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.kv_lora_rank = self.device_pool.kv_lora_rank
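Both page helpers rely on the same invariant: the device indices handed to
them are page-contiguous, so taking every page_size-th element recovers the
first token slot of each page. A self-contained illustration of that
indexing (page ids chosen arbitrarily):

import torch

page_size = 4
pages = torch.tensor([7, 2])  # two page ids from the allocator
token_indices = (
    pages[:, None] * page_size + torch.arange(page_size)
).reshape(-1)
# token_indices == tensor([28, 29, 30, 31,  8,  9, 10, 11])
first_of_each_page = token_indices[::page_size]
# first_of_each_page == tensor([28, 8]); each page is then copied as one
# contiguous slice of length page_size.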
@@ -857,3 +896,24 @@ class MLATokenToKVPoolHost(HostKVCache):
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
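The MLA variants mirror the MHA ones, with one structural difference implied
by the indexing: the MHA host buffer is addressed as kv_buffer[0/1, layer, ...]
(separate K and V planes), while the MLA host buffer is addressed as
kv_buffer[layer, ...], so MLA needs a single copy per layer where MHA issues
one for K and one for V.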
python/sglang/srt/mem_cache/paged_allocator.py

@@ -190,6 +190,30 @@ class PagedTokenToKVPoolAllocator:
     def available_size(self):
         return len(self.free_pages) * self.page_size

+    def get_kvcache(self):
+        return self._kvcache
+
+    def alloc(self, need_size: int):
+        # page-aligned allocation, returning contiguous indices of pages
+        if self.debug_mode:
+            assert (
+                need_size % self.page_size == 0
+            ), "The allocation size should be page-aligned"
+
+        num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            return None
+
+        out_pages = self.free_pages[:num_pages]
+        self.free_pages = self.free_pages[num_pages:]
+
+        out_indices = (
+            out_pages[:, None] * self.page_size
+            + torch.arange(self.page_size, device=self.device)
+        ).reshape(-1)
+        return out_indices
+
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
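A self-contained sketch of the allocation math in alloc(), with free_pages
assumed to be a 1-D tensor of free page ids (values chosen arbitrarily):

import torch

page_size = 4
free_pages = torch.tensor([5, 0, 9])

need_size = 8                       # must be a multiple of page_size
num_pages = need_size // page_size  # 2 pages
out_pages = free_pages[:num_pages]  # tensor([5, 0])
free_pages = free_pages[num_pages:] # tensor([9]) remains free

out_indices = (
    out_pages[:, None] * page_size + torch.arange(page_size)
).reshape(-1)
# out_indices == tensor([20, 21, 22, 23,  0,  1,  2,  3]): token slots are
# contiguous within each page, which is the invariant the host cache page
# helpers above rely on.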
test/srt/test_hicache.py

@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
 )


-class TestPageSize(CustomTestCase):
+class TestHiCache(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -21,7 +21,9 @@ class TestPageSize(CustomTestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-hierarchical-cache"],
+            other_args=[
+                "--enable-hierarchical-cache",
+            ],
         )

     @classmethod
test/srt/test_hicache_mla.py

@@ -21,7 +21,10 @@ class TestHierarchicalMLA(CustomTestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--trust-remote-code", "--enable-hierarchical-cache"],
+            other_args=[
+                "--trust-remote-code",
+                "--enable-hierarchical-cache",
+            ],
         )

     @classmethod
test/srt/test_hicache_page.py (new file, 0 → 100644)

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestHiCachePage(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-hierarchical-cache",
                "--page-size",
                "32",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)


if __name__ == "__main__":
    unittest.main()
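The new test exercises the page-aligned path end to end: it launches the
server with --page-size 32 alongside --enable-hierarchical-cache and requires
an MMLU score of at least 0.65 over 64 examples, so a correctness regression
in the paged host-cache transfers would surface as an accuracy drop.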