Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
acb34072
Unverified
Commit
acb34072
authored
Dec 26, 2024
by
Adarsh Shirawalmath
Committed by
GitHub
Dec 26, 2024
Browse files
[Feature] Support new parameter - EBNF in xgrammar (#2526)
parent
08effbff
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
384 additions
and
2 deletions
+384
-2
python/sglang/lang/backend/openai.py
python/sglang/lang/backend/openai.py
+10
-0
python/sglang/srt/constrained/xgrammar_backend.py
python/sglang/srt/constrained/xgrammar_backend.py
+6
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+3
-0
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+19
-0
python/sglang/srt/openai_api/protocol.py
python/sglang/srt/openai_api/protocol.py
+2
-0
python/sglang/srt/sampling/sampling_params.py
python/sglang/srt/sampling/sampling_params.py
+9
-2
test/srt/test_ebnf_constrained.py
test/srt/test_ebnf_constrained.py
+247
-0
test/srt/test_openai_server.py
test/srt/test_openai_server.py
+88
-0
No files found.
python/sglang/lang/backend/openai.py
View file @
acb34072
...
@@ -366,6 +366,11 @@ class OpenAI(BaseBackend):
...
@@ -366,6 +366,11 @@ class OpenAI(BaseBackend):
def
openai_completion
(
def
openai_completion
(
client
,
token_usage
,
is_chat
=
None
,
retries
=
3
,
prompt
=
None
,
**
kwargs
client
,
token_usage
,
is_chat
=
None
,
retries
=
3
,
prompt
=
None
,
**
kwargs
):
):
# if "ebnf" is in kwargs, warn and remove
if
"ebnf"
in
kwargs
:
warnings
.
warn
(
"EBNF is not officially supported by OpenAI endpoints. Ignoring."
)
del
kwargs
[
"ebnf"
]
for
attempt
in
range
(
retries
):
for
attempt
in
range
(
retries
):
try
:
try
:
if
is_chat
:
if
is_chat
:
...
@@ -398,6 +403,11 @@ def openai_completion(
...
@@ -398,6 +403,11 @@ def openai_completion(
def
openai_completion_stream
(
def
openai_completion_stream
(
client
,
token_usage
,
is_chat
=
None
,
retries
=
3
,
prompt
=
None
,
**
kwargs
client
,
token_usage
,
is_chat
=
None
,
retries
=
3
,
prompt
=
None
,
**
kwargs
):
):
# if "ebnf" is in kwargs, warn and remove
if
"ebnf"
in
kwargs
:
warnings
.
warn
(
"EBNF is not officially supported by OpenAI endpoints. Ignoring."
)
del
kwargs
[
"ebnf"
]
for
attempt
in
range
(
retries
):
for
attempt
in
range
(
retries
):
try
:
try
:
if
is_chat
:
if
is_chat
:
...
...
python/sglang/srt/constrained/xgrammar_backend.py
View file @
acb34072
...
@@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
...
@@ -126,6 +126,12 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
f
"Skip invalid json_schema: json_schema=
{
key_string
}
,
{
e
=
}
"
f
"Skip invalid json_schema: json_schema=
{
key_string
}
,
{
e
=
}
"
)
)
return
None
return
None
elif
key_type
==
"ebnf"
:
try
:
ctx
=
self
.
grammar_compiler
.
compile_grammar
(
key_string
)
except
RuntimeError
as
e
:
logging
.
warning
(
f
"Skip invalid ebnf: ebnf=
{
key_string
}
,
{
e
=
}
"
)
return
None
elif
key_type
==
"regex"
:
elif
key_type
==
"regex"
:
logger
.
warning
(
logger
.
warning
(
"regex hasn't been supported by xgrammar yet. This is skipped."
"regex hasn't been supported by xgrammar yet. This is skipped."
...
...
python/sglang/srt/managers/scheduler.py
View file @
acb34072
...
@@ -589,12 +589,15 @@ class Scheduler:
...
@@ -589,12 +589,15 @@ class Scheduler:
if
(
if
(
req
.
sampling_params
.
json_schema
is
not
None
req
.
sampling_params
.
json_schema
is
not
None
or
req
.
sampling_params
.
regex
is
not
None
or
req
.
sampling_params
.
regex
is
not
None
or
req
.
sampling_params
.
ebnf
is
not
None
):
):
assert
self
.
grammar_backend
is
not
None
assert
self
.
grammar_backend
is
not
None
if
req
.
sampling_params
.
json_schema
is
not
None
:
if
req
.
sampling_params
.
json_schema
is
not
None
:
key
=
(
"json"
,
req
.
sampling_params
.
json_schema
)
key
=
(
"json"
,
req
.
sampling_params
.
json_schema
)
elif
req
.
sampling_params
.
regex
is
not
None
:
elif
req
.
sampling_params
.
regex
is
not
None
:
key
=
(
"regex"
,
req
.
sampling_params
.
regex
)
key
=
(
"regex"
,
req
.
sampling_params
.
regex
)
elif
req
.
sampling_params
.
ebnf
is
not
None
:
key
=
(
"ebnf"
,
req
.
sampling_params
.
ebnf
)
req
.
grammar
=
self
.
grammar_backend
.
get_cached_value
(
key
)
req
.
grammar
=
self
.
grammar_backend
.
get_cached_value
(
key
)
if
not
req
.
grammar
:
if
not
req
.
grammar
:
...
...
python/sglang/srt/openai_api/adapter.py
View file @
acb34072
...
@@ -517,6 +517,7 @@ def v1_generate_request(
...
@@ -517,6 +517,7 @@ def v1_generate_request(
"repetition_penalty"
:
request
.
repetition_penalty
,
"repetition_penalty"
:
request
.
repetition_penalty
,
"regex"
:
request
.
regex
,
"regex"
:
request
.
regex
,
"json_schema"
:
request
.
json_schema
,
"json_schema"
:
request
.
json_schema
,
"ebnf"
:
request
.
ebnf
,
"n"
:
request
.
n
,
"n"
:
request
.
n
,
"no_stop_trim"
:
request
.
no_stop_trim
,
"no_stop_trim"
:
request
.
no_stop_trim
,
"ignore_eos"
:
request
.
ignore_eos
,
"ignore_eos"
:
request
.
ignore_eos
,
...
@@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
...
@@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
async
def
v1_completions
(
tokenizer_manager
,
raw_request
:
Request
):
async
def
v1_completions
(
tokenizer_manager
,
raw_request
:
Request
):
request_json
=
await
raw_request
.
json
()
request_json
=
await
raw_request
.
json
()
if
"extra_body"
in
request_json
:
extra
=
request_json
[
"extra_body"
]
if
"ebnf"
in
extra
:
request_json
[
"ebnf"
]
=
extra
[
"ebnf"
]
if
"regex"
in
extra
:
request_json
[
"regex"
]
=
extra
[
"regex"
]
# remove extra_body to avoid pydantic conflict
del
request_json
[
"extra_body"
]
all_requests
=
[
CompletionRequest
(
**
request_json
)]
all_requests
=
[
CompletionRequest
(
**
request_json
)]
adapted_request
,
request
=
v1_generate_request
(
all_requests
)
adapted_request
,
request
=
v1_generate_request
(
all_requests
)
...
@@ -936,6 +945,7 @@ def v1_chat_generate_request(
...
@@ -936,6 +945,7 @@ def v1_chat_generate_request(
"frequency_penalty"
:
request
.
frequency_penalty
,
"frequency_penalty"
:
request
.
frequency_penalty
,
"repetition_penalty"
:
request
.
repetition_penalty
,
"repetition_penalty"
:
request
.
repetition_penalty
,
"regex"
:
request
.
regex
,
"regex"
:
request
.
regex
,
"ebnf"
:
request
.
ebnf
,
"n"
:
request
.
n
,
"n"
:
request
.
n
,
"no_stop_trim"
:
request
.
no_stop_trim
,
"no_stop_trim"
:
request
.
no_stop_trim
,
"ignore_eos"
:
request
.
ignore_eos
,
"ignore_eos"
:
request
.
ignore_eos
,
...
@@ -1108,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
...
@@ -1108,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
async
def
v1_chat_completions
(
tokenizer_manager
,
raw_request
:
Request
):
async
def
v1_chat_completions
(
tokenizer_manager
,
raw_request
:
Request
):
request_json
=
await
raw_request
.
json
()
request_json
=
await
raw_request
.
json
()
if
"extra_body"
in
request_json
:
extra
=
request_json
[
"extra_body"
]
# For example, if 'ebnf' is given:
if
"ebnf"
in
extra
:
request_json
[
"ebnf"
]
=
extra
[
"ebnf"
]
if
"regex"
in
extra
:
request_json
[
"regex"
]
=
extra
[
"regex"
]
# remove extra_body to avoid pydantic conflict
del
request_json
[
"extra_body"
]
all_requests
=
[
ChatCompletionRequest
(
**
request_json
)]
all_requests
=
[
ChatCompletionRequest
(
**
request_json
)]
adapted_request
,
request
=
v1_chat_generate_request
(
all_requests
,
tokenizer_manager
)
adapted_request
,
request
=
v1_chat_generate_request
(
all_requests
,
tokenizer_manager
)
...
...
python/sglang/srt/openai_api/protocol.py
View file @
acb34072
...
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
...
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
ignore_eos
:
bool
=
False
ignore_eos
:
bool
=
False
skip_special_tokens
:
bool
=
True
skip_special_tokens
:
bool
=
True
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
ebnf
:
Optional
[
str
]
=
None
class
CompletionResponseChoice
(
BaseModel
):
class
CompletionResponseChoice
(
BaseModel
):
...
@@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel):
...
@@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel):
ignore_eos
:
bool
=
False
ignore_eos
:
bool
=
False
skip_special_tokens
:
bool
=
True
skip_special_tokens
:
bool
=
True
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
ebnf
:
Optional
[
str
]
=
None
class
ChatMessage
(
BaseModel
):
class
ChatMessage
(
BaseModel
):
...
...
python/sglang/srt/sampling/sampling_params.py
View file @
acb34072
...
@@ -36,6 +36,7 @@ class SamplingParams:
...
@@ -36,6 +36,7 @@ class SamplingParams:
regex
:
Optional
[
str
]
=
None
,
regex
:
Optional
[
str
]
=
None
,
n
:
int
=
1
,
n
:
int
=
1
,
json_schema
:
Optional
[
str
]
=
None
,
json_schema
:
Optional
[
str
]
=
None
,
ebnf
:
Optional
[
str
]
=
None
,
no_stop_trim
:
bool
=
False
,
no_stop_trim
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
...
@@ -60,6 +61,7 @@ class SamplingParams:
...
@@ -60,6 +61,7 @@ class SamplingParams:
self
.
regex
=
regex
self
.
regex
=
regex
self
.
n
=
n
self
.
n
=
n
self
.
json_schema
=
json_schema
self
.
json_schema
=
json_schema
self
.
ebnf
=
ebnf
self
.
no_stop_trim
=
no_stop_trim
self
.
no_stop_trim
=
no_stop_trim
# Process some special cases
# Process some special cases
...
@@ -111,8 +113,13 @@ class SamplingParams:
...
@@ -111,8 +113,13 @@ class SamplingParams:
f
"min_new_tokens must be in (0, max_new_tokens(
{
self
.
max_new_tokens
}
)], got "
f
"min_new_tokens must be in (0, max_new_tokens(
{
self
.
max_new_tokens
}
)], got "
f
"
{
self
.
min_new_tokens
}
."
f
"
{
self
.
min_new_tokens
}
."
)
)
if
self
.
regex
is
not
None
and
self
.
json_schema
is
not
None
:
grammars
=
[
raise
ValueError
(
"regex and json_schema cannot be both set."
)
self
.
json_schema
,
self
.
regex
,
self
.
ebnf
,
]
# since mutually exclusive, only one can be set
if
sum
(
x
is
not
None
for
x
in
grammars
)
>
1
:
raise
ValueError
(
"Only one of regex, json_schema, or ebnf can be set."
)
def
normalize
(
self
,
tokenizer
):
def
normalize
(
self
,
tokenizer
):
# Process stop strings
# Process stop strings
...
...
test/srt/test_ebnf_constrained.py
0 → 100644
View file @
acb34072
"""
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email
python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting
"""
import
json
import
unittest
import
requests
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
def
setup_class
(
cls
,
disable_overlap
:
bool
):
cls
.
model
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
ebnf_grammar
=
'root ::= "test"'
# Default grammar
other_args
=
[
"--max-running-requests"
,
"10"
,
"--grammar-backend"
,
"xgrammar"
,
]
if
disable_overlap
:
other_args
+=
[
"--disable-overlap-schedule"
]
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
other_args
,
)
class
TestEBNFConstrained
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
setup_class
(
cls
,
disable_overlap
=
False
)
cls
.
check_jump_forward
=
False
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
run_decode
(
self
,
ebnf
,
expected_patterns
,
prompt
,
return_logprob
=
False
,
top_logprobs_num
=
0
,
n
=
1
,
):
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
prompt
,
"sampling_params"
:
{
"temperature"
:
0
if
n
==
1
else
0.5
,
"max_new_tokens"
:
128
,
"n"
:
n
,
"ebnf"
:
ebnf
,
},
"stream"
:
False
,
"return_logprob"
:
return_logprob
,
"top_logprobs_num"
:
top_logprobs_num
,
"logprob_start_len"
:
0
,
},
)
ret
=
response
.
json
()
print
(
json
.
dumps
(
ret
,
indent
=
2
))
print
(
"="
*
100
)
if
not
isinstance
(
ret
,
list
):
self
.
fail
(
f
"Expected response to be a list, but got
{
type
(
ret
)
}
"
)
for
item
in
ret
:
text
=
item
.
get
(
"text"
,
""
).
strip
()
if
not
text
:
self
.
fail
(
"Generated text is empty."
)
match
=
False
for
pattern
in
expected_patterns
:
if
self
.
regex_match
(
text
,
pattern
):
match
=
True
break
if
not
match
:
self
.
fail
(
f
"Text '
{
text
}
' does not match any of the allowed patterns."
)
def
regex_match
(
self
,
text
,
pattern
):
import
re
return
re
.
match
(
pattern
,
text
)
is
not
None
def
test_ebnf_generate_email
(
self
):
self
.
__class__
.
ebnf_grammar
=
'root ::= "user@example.com"'
allowed_patterns
=
[
r
"^user@example\.com$"
]
prompt
=
"Generate an email address:"
self
.
run_decode
(
ebnf
=
self
.
__class__
.
ebnf_grammar
,
expected_patterns
=
allowed_patterns
,
prompt
=
prompt
,
n
=
3
,
)
def
test_ebnf_generate_greeting
(
self
):
self
.
__class__
.
ebnf_grammar
=
'root ::= "Hello" | "Hi" | "Hey"'
allowed_patterns
=
[
r
"^(Hello|Hi|Hey)$"
]
prompt
=
"Generate a greeting:"
self
.
run_decode
(
ebnf
=
self
.
__class__
.
ebnf_grammar
,
expected_patterns
=
allowed_patterns
,
prompt
=
prompt
,
n
=
3
,
)
def
test_ebnf_generate_number
(
self
):
self
.
__class__
.
ebnf_grammar
=
"""
root ::= digit digit digit
digit ::= [0-9]
"""
allowed_patterns
=
[
r
"^\d{3}$"
]
prompt
=
"Generate a three-digit number:"
self
.
run_decode
(
ebnf
=
self
.
__class__
.
ebnf_grammar
,
expected_patterns
=
allowed_patterns
,
prompt
=
prompt
,
n
=
3
,
)
def
test_ebnf_generate_phone
(
self
):
self
.
__class__
.
ebnf_grammar
=
"""
root ::= "(" area ")" " " prefix "-" line
area ::= [0-9] [0-9] [0-9]
prefix ::= [0-9] [0-9] [0-9]
line ::= [0-9] [0-9] [0-9] [0-9]
"""
allowed_patterns
=
[
r
"^\(\d{3}\) \d{3}-\d{4}$"
]
prompt
=
"Generate a phone number:"
self
.
run_decode
(
ebnf
=
self
.
__class__
.
ebnf_grammar
,
expected_patterns
=
allowed_patterns
,
prompt
=
prompt
,
n
=
3
,
)
def
test_ebnf_generate_date
(
self
):
self
.
__class__
.
ebnf_grammar
=
"""
root ::= year "-" month "-" day
year ::= "2024"
month ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" | "11" | "12"
day ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" |
"11" | "12" | "13" | "14" | "15" | "16" | "17" | "18" | "19" | "20" |
"21" | "22" | "23" | "24" | "25" | "26" | "27" | "28" | "29" | "30" | "31"
"""
allowed_patterns
=
[
r
"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"
]
prompt
=
"Generate a date in YYYY-MM-DD format:"
self
.
run_decode
(
ebnf
=
self
.
__class__
.
ebnf_grammar
,
expected_patterns
=
allowed_patterns
,
prompt
=
prompt
,
n
=
3
,
)
def
test_ebnf_generate_hex_color
(
self
):
self
.
__class__
.
ebnf_grammar
=
"""
root ::= "#" hex hex hex hex hex hex
hex ::= [0-9] | [A-F]
"""
allowed_patterns
=
[
r
"^#[0-9A-F]{6}$"
]
prompt
=
"Generate a hex color code:"
self
.
run_decode
(
ebnf
=
self
.
__class__
.
ebnf_grammar
,
expected_patterns
=
allowed_patterns
,
prompt
=
prompt
,
n
=
3
,
)
def
test_ebnf_generate_complex_json
(
self
):
self
.
__class__
.
ebnf_grammar
=
"""
root ::= object
object ::= "{" ws pair (ws "," ws pair)* ws "}"
pair ::= "
\\
"name
\\
"" ws ":" ws value |
"
\\
"age
\\
"" ws ":" ws number |
"
\\
"city
\\
"" ws ":" ws string
value ::= string | number
string ::= "
\\
"" [a-zA-Z0-9 ]+ "
\\
""
number ::= [1-9] [0-9]*
ws ::= [ ]*
"""
allowed_patterns
=
[
r
'^{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*}$'
,
]
prompt
=
"Generate a simple JSON with name, age, and city:"
self
.
run_decode
(
ebnf
=
self
.
__class__
.
ebnf_grammar
,
expected_patterns
=
allowed_patterns
,
prompt
=
prompt
,
n
=
3
,
)
def
test_ebnf_generate_custom_log_format
(
self
):
self
.
__class__
.
ebnf_grammar
=
"""
root ::= logentry
logentry ::= "[" datetime "] " level ": System.process - " message
datetime ::= "2024-01-01T12:00:00Z"
level ::= "INFO"
message ::= "Operation " [a-z]+ " successfully"
"""
allowed_patterns
=
[
r
"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
]
prompt
=
"Generate a log entry:"
self
.
run_decode
(
ebnf
=
self
.
__class__
.
ebnf_grammar
,
expected_patterns
=
allowed_patterns
,
prompt
=
prompt
,
n
=
3
,
)
class
TestJumpForward
(
TestEBNFConstrained
):
@
classmethod
def
setUpClass
(
cls
):
setup_class
(
cls
,
disable_overlap
=
True
)
cls
.
check_jump_forward
=
True
if
__name__
==
"__main__"
:
unittest
.
main
()
test/srt/test_openai_server.py
View file @
acb34072
...
@@ -5,6 +5,7 @@ python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
...
@@ -5,6 +5,7 @@ python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
"""
"""
import
json
import
json
import
re
import
time
import
time
import
unittest
import
unittest
...
@@ -535,5 +536,92 @@ The SmartHome Mini is a compact smart home assistant available in black or white
...
@@ -535,5 +536,92 @@ The SmartHome Mini is a compact smart home assistant available in black or white
)
)
# -------------------------------------------------------------------------
# EBNF Test Class: TestOpenAIServerEBNF
# Launches the server with xgrammar, has only EBNF tests
# -------------------------------------------------------------------------
class
TestOpenAIServerEBNF
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
api_key
=
"sk-123456"
# passing xgrammar specifically
other_args
=
[
"--grammar-backend"
,
"xgrammar"
]
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
other_args
=
other_args
,
)
cls
.
base_url
+=
"/v1"
cls
.
tokenizer
=
get_tokenizer
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_ebnf
(
self
):
"""
Ensure we can pass `ebnf` to the local openai server
and that it enforces the grammar.
"""
client
=
openai
.
Client
(
api_key
=
self
.
api_key
,
base_url
=
self
.
base_url
)
ebnf_grammar
=
r
"""
root ::= "Hello" | "Hi" | "Hey"
"""
pattern
=
re
.
compile
(
r
"^(Hello|Hi|Hey)[.!?]*\s*$"
)
response
=
client
.
chat
.
completions
.
create
(
model
=
self
.
model
,
messages
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful EBNF test bot."
},
{
"role"
:
"user"
,
"content"
:
"Say a greeting (Hello, Hi, or Hey)."
},
],
temperature
=
0
,
max_tokens
=
32
,
extra_body
=
{
"ebnf"
:
ebnf_grammar
},
)
text
=
response
.
choices
[
0
].
message
.
content
.
strip
()
print
(
"EBNF test output:"
,
repr
(
text
))
self
.
assertTrue
(
len
(
text
)
>
0
,
"Got empty text from EBNF generation"
)
self
.
assertRegex
(
text
,
pattern
,
f
"Text '
{
text
}
' doesn't match EBNF choices"
)
def
test_ebnf_strict_json
(
self
):
"""
A stricter EBNF that produces exactly {"name":"Alice"} format
with no trailing punctuation or extra fields.
"""
client
=
openai
.
Client
(
api_key
=
self
.
api_key
,
base_url
=
self
.
base_url
)
ebnf_grammar
=
r
"""
root ::= "{" pair "}"
pair ::= "\"name\"" ":" string
string ::= "\"" [A-Za-z]+ "\""
"""
pattern
=
re
.
compile
(
r
'^\{"name":"[A-Za-z]+"\}$'
)
response
=
client
.
chat
.
completions
.
create
(
model
=
self
.
model
,
messages
=
[
{
"role"
:
"system"
,
"content"
:
"EBNF mini-JSON generator."
},
{
"role"
:
"user"
,
"content"
:
"Generate single key JSON with only letters."
,
},
],
temperature
=
0
,
max_tokens
=
64
,
extra_body
=
{
"ebnf"
:
ebnf_grammar
},
)
text
=
response
.
choices
[
0
].
message
.
content
.
strip
()
print
(
"EBNF strict JSON test output:"
,
repr
(
text
))
self
.
assertTrue
(
len
(
text
)
>
0
,
"Got empty text from EBNF strict JSON test"
)
self
.
assertRegex
(
text
,
pattern
,
f
"Text '
{
text
}
' not matching the EBNF strict JSON shape"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment