sglang · Commit 0e7409ad (Unverified)
Authored Dec 06, 2024 by Lianmin Zheng; committed by GitHub on Dec 06, 2024
Fix the overlap for xgrammar (#2377)
Parent: 3cde5eb6

Showing 7 changed files with 145 additions and 133 deletions:
- docs/references/supported_models.md (+1, -1)
- python/sglang/srt/constrained/outlines_backend.py (+5, -0)
- python/sglang/srt/constrained/xgrammar_backend.py (+5, -5)
- python/sglang/srt/managers/scheduler.py (+69, -65)
- python/sglang/srt/managers/tp_worker_overlap_thread.py (+2, -1)
- python/sglang/srt/sampling/sampling_batch_info.py (+9, -8)
- test/srt/test_json_constrained.py (+54, -53)
docs/references/supported_models.md
@@ -106,4 +106,4 @@ def import_new_model_classes():
     ModelRegistry.models.update(import_new_model_classes())

     launch_server(server_args)
-```
+```
\ No newline at end of file
python/sglang/srt/constrained/outlines_backend.py
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False

     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
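Both grammar backends now carry a `finished` flag (set by the scheduler once the request is done) and expose a `move_vocab_mask` hook; for the Outlines backend the mask is already allocated on the target device, so the hook simply returns the mask unchanged. Below is a minimal standalone sketch of how the flag is meant to be used; the toy class is illustrative only and is not sglang code.

```python
# Toy illustration (not sglang code) of the new `finished` flag: once the
# scheduler marks a request's grammar as finished, the next vocab-mask update
# skips it instead of calling fill_vocab_mask on a completed matcher.
class ToyGrammar:
    def __init__(self):
        self.state = 0
        self.finished = False

    def accept_token(self, token: int) -> None:
        # Stand-in for guide.get_next_state(self.state, token).
        self.state += 1


grammar = ToyGrammar()
grammar.accept_token(42)
grammar.finished = True  # the scheduler does: req.grammar.finished = req.finished()

if grammar and not grammar.finished:
    print("would call grammar.fill_vocab_mask(mask, idx)")
else:
    print("skipped: request already finished")
```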
python/sglang/srt/constrained/xgrammar_backend.py
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False

     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)

     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if vocab_mask.device.type != logits.device.type:
-            # vocab_mask must then be on the same device as logits
-            # when applying the token bitmask, so we check and move if needed
-            vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
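The xgrammar backend previously checked inside `apply_vocab_mask` whether the bitmask was on the logits' device and copied it there if not; the device transfer is now its own `move_vocab_mask` step, so the batched mask is copied once per step with `non_blocking=True` before it is applied. A hedged sketch of that split follows, using a plain boolean mask rather than xgrammar's packed bitmask and only PyTorch calls; the function names mirror the new interface but the bodies are stand-ins.

```python
import torch


def build_mask_on_cpu(batch_size: int, vocab_size: int) -> torch.Tensor:
    # Stand-in for matcher.fill_next_token_bitmask(...): the grammar matcher
    # writes blocked tokens into a CPU tensor (here a simple bool mask).
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
    mask[:, ::2] = True  # pretend every second token is blocked
    return mask


def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    # One asynchronous copy per batch, instead of a device check inside
    # every apply call (pinned memory would make the copy truly async).
    return vocab_mask.to(device, non_blocking=True)


def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
    # Stand-in for apply_token_bitmask_inplace: assumes the mask already
    # lives on logits.device.
    logits.masked_fill_(vocab_mask, float("-inf"))


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mask = move_vocab_mask(build_mask_on_cpu(2, 8), device)
    logits = torch.zeros(2, 8, device=device)
    apply_vocab_mask(logits, mask)
    print(logits)
```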
python/sglang/srt/managers/scheduler.py
@@ -114,9 +114,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics

-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)
@@ -259,6 +256,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -356,6 +357,7 @@ class Scheduler:
         )

     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
@@ -433,61 +435,6 @@ class Scheduler:
         self.last_batch = batch

-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -993,7 +940,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1049,13 +996,14 @@ class Scheduler:
                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.get_device_module(self.device).current_stream().synchronize()
+                self.current_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
@@ -1127,10 +1075,11 @@ class Scheduler:
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()

         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs)
@@ -1328,6 +1277,61 @@ class Scheduler:
             )
         )

+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
@@ -1469,10 +1473,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1482,6 +1482,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()

     parent_process = psutil.Process().parent()
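Two small patterns recur in this file: the scheduler records a request's completion on its grammar (`req.grammar.finished = req.finished()`) right after accepting the last token, and it resolves the device's current stream once at init so later code can call `self.current_stream.synchronize()` instead of re-resolving `torch.get_device_module(self.device).current_stream()` at each use. A minimal sketch of the caching pattern follows, assuming a recent PyTorch that provides `torch.get_device_module`; the class is illustrative, not the sglang Scheduler.

```python
import torch


class StreamCachingSketch:
    """Toy illustration (not the sglang Scheduler) of caching the current stream."""

    def __init__(self, device: str):
        self.device = device
        # torch.get_device_module("cuda") returns the torch.cuda module;
        # the stream handle is resolved once here and reused afterwards.
        self.current_stream = torch.get_device_module(device).current_stream()

    def wait_for_mask_update(self) -> None:
        # Previously: torch.get_device_module(self.device).current_stream().synchronize()
        self.current_stream.synchronize()


if __name__ == "__main__":
    if torch.cuda.is_available():
        StreamCachingSketch("cuda").wait_for_mask_update()
```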
python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -80,6 +80,7 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()

     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -191,7 +192,7 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.get_device_module(self.device).current_stream().synchronize()
+        self.scheduler_stream.synchronize()

         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
python/sglang/srt/sampling/sampling_batch_info.py
@@ -158,22 +158,23 @@ class SamplingBatchInfo:
             return

         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar)
+        first_grammar = next(grammar for grammar in self.grammars if grammar)

         # maybe we can reuse the existing mask?
-        self.vocab_mask = grammar.allocate_vocab_mask(
+        self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+        self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method

         # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar is not None:
-                try:
-                    grammar.fill_vocab_mask(self.vocab_mask, i)
-                except RuntimeError:
-                    continue
+            if grammar and not grammar.finished:
+                grammar.fill_vocab_mask(self.vocab_mask, i)
+
+        # Move the mask to the device if needed
+        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
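Putting the pieces together: the revised `update_regex_vocab_mask` picks the first non-empty grammar to allocate the mask and to supply the `apply` / `move` static methods, fills a row only for grammars that are not finished (the old try/except RuntimeError fallback is gone), and moves the finished mask to the target device in one call. Below is a self-contained sketch of that flow with a toy grammar class; it returns the mask and apply function instead of storing them on a SamplingBatchInfo object, and the Outlines-style convention (True = blocked token) is an assumption for illustration.

```python
import torch


class ToyGrammar:
    """Toy grammar object mirroring the method names used by sglang's backends."""

    def __init__(self, allowed_tokens, finished=False):
        self.allowed_tokens = allowed_tokens
        self.finished = finished

    @staticmethod
    def allocate_vocab_mask(vocab_size, batch_size, device):
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask, idx):
        vocab_mask[idx].fill_(True)                  # block everything ...
        vocab_mask[idx][self.allowed_tokens] = False  # ... except the allowed tokens

    @staticmethod
    def move_vocab_mask(vocab_mask, device):
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits, vocab_mask):
        logits.masked_fill_(vocab_mask, float("-inf"))


def update_regex_vocab_mask(grammars, vocab_size, device):
    if not any(grammars):
        return None, None

    # find a grammar from the list
    first_grammar = next(grammar for grammar in grammars if grammar)

    vocab_mask = first_grammar.allocate_vocab_mask(
        vocab_size=vocab_size, batch_size=len(grammars), device=device
    )
    apply_mask = first_grammar.apply_vocab_mask  # static method, safe to keep

    # Fill rows only for requests whose grammar is still active.
    for i, grammar in enumerate(grammars):
        if grammar and not grammar.finished:
            grammar.fill_vocab_mask(vocab_mask, i)

    # Move the mask to the target device in a single call.
    return apply_mask, first_grammar.move_vocab_mask(vocab_mask, device)


if __name__ == "__main__":
    grammars = [ToyGrammar([1, 3]), None, ToyGrammar([0], finished=True)]
    apply_mask, mask = update_regex_vocab_mask(grammars, vocab_size=6, device="cpu")
    logits = torch.zeros(3, 6)
    apply_mask(logits, mask)
    print(logits)  # only row 0 is constrained; rows 1 and 2 are left untouched
```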
test/srt/test_json_constrained.py
 """
-python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
 """

 import json
@@ -11,38 +12,50 @@ import requests
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )


+def setup_class(cls, backend: str, disable_overlap: bool):
+    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+    cls.base_url = DEFAULT_URL_FOR_TEST
+    cls.json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string", "pattern": "^[\\w]+$"},
+                "population": {"type": "integer"},
+            },
+            "required": ["name", "population"],
+        }
+    )
+
+    other_args = [
+        "--max-running-requests",
+        "10",
+        "--grammar-backend",
+        backend,
+    ]
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]
+
+    cls.process = popen_launch_server(
+        cls.model,
+        cls.base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_args,
+    )
+
+
 class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string", "pattern": "^[\\w]+$"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "outlines",
-            ],
-        )
+        setup_class(cls, backend="outlines", disable_overlap=False)
+        cls.check_jump_forward = False

     @classmethod
     def tearDownClass(cls):
@@ -83,11 +96,13 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)

         # Make sure jump forward is triggered
-        # NOTE: This is skipped because overlap scheduler does not support jump forward
-        # self.assertGreater(
-        #     ret["meta_info"]["completion_tokens"],
-        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        # )
+        # NOTE: The overlap scheduler does not support jump forward so we only do this test
+        # when --disable-overlap-schedule is set.
+        if self.check_jump_forward:
+            self.assertGreater(
+                ret["meta_info"]["completion_tokens"],
+                ret["meta_info"]["completion_tokens_wo_jump_forward"],
+            )

     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
@@ -126,32 +141,18 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         list(executor.map(self.run_decode, json_schemas))


+class TestJumpForwardOutlinesBackend(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        setup_class(cls, backend="outlines", disable_overlap=True)
+        cls.check_jump_forward = True
+
+
 class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "xgrammar",
-            ],
-        )
+        setup_class(cls, backend="xgrammar", disable_overlap=False)
+        cls.check_jump_forward = False


 if __name__ == "__main__":
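The test refactor replaces two copies of server-launch boilerplate with a module-level `setup_class(cls, backend, disable_overlap)` helper, so each test class only declares its backend, whether the overlap schedule is disabled, and whether the jump-forward assertion should run. A stripped-down sketch of the same parameterized-setUpClass pattern follows; the real helper launches an sglang server, while this toy version only records the CLI arguments.

```python
# Hedged sketch of the parameterized setUpClass pattern (toy version: no server launch).
import unittest


def setup_class(cls, backend: str, disable_overlap: bool):
    cls.backend = backend
    cls.other_args = ["--grammar-backend", backend]
    if disable_overlap:
        cls.other_args += ["--disable-overlap-schedule"]


class TestOutlinesLikeBackend(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="outlines", disable_overlap=False)
        cls.check_jump_forward = False

    def test_config(self):
        self.assertIn("--grammar-backend", self.other_args)
        if self.check_jump_forward:
            # Jump-forward is only exercised when the overlap schedule is disabled.
            self.assertIn("--disable-overlap-schedule", self.other_args)


class TestJumpForwardLikeBackend(TestOutlinesLikeBackend):
    @classmethod
    def setUpClass(cls):
        setup_class(cls, backend="outlines", disable_overlap=True)
        cls.check_jump_forward = True


if __name__ == "__main__":
    unittest.main()
```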