Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f8acd01f
Unverified
Commit
f8acd01f
authored
Apr 26, 2025
by
Russell Bryant
Committed by
GitHub
Apr 26, 2025
Browse files
[V1] Add `structural_tag` support using xgrammar (#17085)
parent
c48334d4
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
270 additions
and
15 deletions
+270
-15
examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py
...enai_chat_completion_structured_outputs_structural_tag.py
+85
-0
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+101
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+3
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+39
-7
vllm/model_executor/guided_decoding/guided_fields.py
vllm/model_executor/guided_decoding/guided_fields.py
+6
-5
vllm/sampling_params.py
vllm/sampling_params.py
+5
-2
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+3
-0
vllm/v1/structured_output/backend_types.py
vllm/v1/structured_output/backend_types.py
+1
-0
vllm/v1/structured_output/backend_xgrammar.py
vllm/v1/structured_output/backend_xgrammar.py
+25
-0
vllm/v1/structured_output/request.py
vllm/v1/structured_output/request.py
+2
-0
No files found.
examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py
0 → 100644
View file @
f8acd01f
# SPDX-License-Identifier: Apache-2.0
from
openai
import
OpenAI
# This example demonstrates the `structural_tag` response format.
# It can be used to specify a structured output format that occurs between
# specific tags in the response. This example shows how it could be used
# to enforce the format of a tool call response, but it could be used for
# any structured output within a subset of the response.
def main():
    """Send one chat completion that uses the `structural_tag` response format.

    Creates an OpenAI client pointed at ``http://localhost:8000/v1``, asks a
    weather question whose prompt describes a ``<function=...>`` call syntax,
    and prints the raw response object returned by the client.
    """
    # "-" is passed as a placeholder API key for the local endpoint.
    api = OpenAI(base_url="http://localhost:8000/v1", api_key="-")

    # Prompt text is kept verbatim; it describes the tool and the expected
    # call syntax to the model.
    user_prompt = """
You have access to the following function to retrieve the weather in a city:
{
"name": "get_weather",
"parameters": {
"city": {
"param_type": "string",
"description": "The city to get the weather for",
"required": True
}
}
}
If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function
argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.
Given the previous instructions, what is the weather in New York City, Boston,
and San Francisco?
"""
    conversation = [{"role": "user", "content": user_prompt}]

    # JSON schema for the text between the begin and end tags: an object
    # with a string "city" property.
    city_schema = {
        "type": "object",
        "properties": {
            "city": {
                "type": "string"
            }
        },
    }
    # One structure: the begin tag, the schema for its body, and the end tag.
    weather_call = {
        "begin": "<function=get_weather>",
        "schema": city_schema,
        "end": "</function>",
    }
    tag_format = {
        "type": "structural_tag",
        "structures": [weather_call],
        "triggers": ["<function="],
    }

    completion = api.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=conversation,
        response_format=tag_format,
    )
    print(completion)
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if
__name__
==
"__main__"
:
main
()
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
f8acd01f
...
@@ -350,6 +350,7 @@ def test_structured_output(
...
@@ -350,6 +350,7 @@ def test_structured_output(
temperature
=
1.0
,
temperature
=
1.0
,
max_tokens
=
1000
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
json_schema
))
guided_decoding
=
GuidedDecodingParams
(
json
=
json_schema
))
outputs
=
llm
.
generate
(
outputs
=
llm
.
generate
(
prompts
=
"Generate a description of a frog using 50 characters."
,
prompts
=
"Generate a description of a frog using 50 characters."
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
...
@@ -368,6 +369,106 @@ def test_structured_output(
...
@@ -368,6 +369,106 @@ def test_structured_output(
output_json
=
json
.
loads
(
generated_text
)
output_json
=
json
.
loads
(
generated_text
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
json_schema
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
json_schema
)
#
# Test 11: Generate structured output using structural_tag format
#
structural_tag_config
=
{
"type"
:
"structural_tag"
,
"structures"
:
[{
"begin"
:
"<function=get_weather>"
,
"schema"
:
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
}
}
},
"end"
:
"</function>"
}],
"triggers"
:
[
"<function="
]
}
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
100
,
guided_decoding
=
GuidedDecodingParams
(
structural_tag
=
json
.
dumps
(
structural_tag_config
)))
prompt
=
"""
You have access to the following function to retrieve the weather in a city:
{
"name": "get_weather",
"parameters": {
"city": {
"param_type": "string",
"description": "The city to get the weather for",
"required": True
}
}
}
If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where
start_tag => `<function`
parameters => a JSON dict with the function argument name
as key and function argument value as value.
end_tag => `</function>`
Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>
Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query
You are a helpful assistant.
Given the previous instructions, what is the weather in New York City?
"""
# Change this once other backends support structural_tag
if
guided_decoding_backend
.
startswith
(
"xgrammar"
):
outputs
=
llm
.
generate
(
prompts
=
prompt
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
assert
outputs
is
not
None
else
:
outputs
=
[]
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
# Search for function call pattern in the response
function_call_pattern
=
r
'<function=get_weather>(.*?)</function>'
matches
=
re
.
findall
(
function_call_pattern
,
generated_text
)
if
not
matches
:
print
(
f
"Warning: No function calls found in response: "
f
"
{
generated_text
!
r
}
"
)
continue
# Take the first function call if multiple are found
json_str
=
matches
[
0
]
try
:
json_content
=
json
.
loads
(
json_str
)
assert
"city"
in
json_content
assert
isinstance
(
json_content
[
"city"
],
str
)
print
(
f
"Found valid function call:
{
generated_text
!
r
}
"
)
except
(
json
.
JSONDecodeError
,
AssertionError
)
as
e
:
pytest
.
fail
(
"Invalid function call format: "
f
"
{
generated_text
!
r
}
\n
Error:
{
str
(
e
)
}
"
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"model_name, tokenizer_mode"
,
@
pytest
.
mark
.
parametrize
(
"model_name, tokenizer_mode"
,
...
...
vllm/entrypoints/llm.py
View file @
f8acd01f
...
@@ -1396,7 +1396,9 @@ class LLM:
...
@@ -1396,7 +1396,9 @@ class LLM:
grammar
=
guided_options
.
guided_grammar
,
grammar
=
guided_options
.
guided_grammar
,
json_object
=
guided_options
.
guided_json_object
,
json_object
=
guided_options
.
guided_json_object
,
backend
=
guided_options
.
guided_decoding_backend
,
backend
=
guided_options
.
guided_decoding_backend
,
whitespace_pattern
=
guided_options
.
guided_whitespace_pattern
)
whitespace_pattern
=
guided_options
.
guided_whitespace_pattern
,
structural_tag
=
guided_options
.
structural_tag
,
)
return
params
return
params
def
_run_engine
(
def
_run_engine
(
...
...
vllm/entrypoints/openai/protocol.py
View file @
f8acd01f
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# Adapted from
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import
json
import
re
import
re
import
time
import
time
from
argparse
import
Namespace
from
argparse
import
Namespace
...
@@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
...
@@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
strict
:
Optional
[
bool
]
=
None
strict
:
Optional
[
bool
]
=
None
class StructuralTag(OpenAIBaseModel):
    # One tagged structure: a begin string, an optional schema for the
    # content in between, and an end string.

    # String that opens the structure.
    begin: str

    # The wire field is called "schema", but that name conflicts with
    # pydantic, so the attribute is structural_tag_schema and "schema" is
    # attached as its alias.
    structural_tag_schema: Optional[dict[str, Any]] = Field(
        default=None,
        alias="schema",
    )

    # String that closes the structure.
    end: str
class StructuralTagResponseFormat(OpenAIBaseModel):
    # response_format payload whose "type" is "structural_tag".

    # Discriminator; only the literal "structural_tag" is accepted.
    type: Literal["structural_tag"]

    # The tagged structures (begin / schema / end) carried by the request.
    structures: list[StructuralTag]

    # Trigger strings supplied alongside the structures.
    triggers: list[str]
class
ResponseFormat
(
OpenAIBaseModel
):
class
ResponseFormat
(
OpenAIBaseModel
):
# type must be "json_schema", "json_object" or "text"
# type must be "json_schema", "json_object"
,
or "text"
type
:
Literal
[
"text"
,
"json_object"
,
"json_schema"
]
type
:
Literal
[
"text"
,
"json_object"
,
"json_schema"
]
json_schema
:
Optional
[
JsonSchemaResponseFormat
]
=
None
json_schema
:
Optional
[
JsonSchemaResponseFormat
]
=
None
# Alias covering both response_format shapes: the text/json_object/json_schema
# ResponseFormat and the structural_tag variant.
AnyResponseFormat
=
Union
[
ResponseFormat
,
StructuralTagResponseFormat
]
class
StreamOptions
(
OpenAIBaseModel
):
class
StreamOptions
(
OpenAIBaseModel
):
include_usage
:
Optional
[
bool
]
=
True
include_usage
:
Optional
[
bool
]
=
True
continuous_usage_stats
:
Optional
[
bool
]
=
False
continuous_usage_stats
:
Optional
[
bool
]
=
False
...
@@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
max_completion_tokens
:
Optional
[
int
]
=
None
max_completion_tokens
:
Optional
[
int
]
=
None
n
:
Optional
[
int
]
=
1
n
:
Optional
[
int
]
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
presence_penalty
:
Optional
[
float
]
=
0.0
response_format
:
Optional
[
ResponseFormat
]
=
None
response_format
:
Optional
[
Any
ResponseFormat
]
=
None
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
stop
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
Field
(
default_factory
=
list
)
stop
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
stream
:
Optional
[
bool
]
=
False
...
@@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
description
=
(
description
=
(
"If specified, the output will follow the context free grammar."
),
"If specified, the output will follow the context free grammar."
),
)
)
structural_tag
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the structural tag schema."
),
)
guided_decoding_backend
:
Optional
[
str
]
=
Field
(
guided_decoding_backend
:
Optional
[
str
]
=
Field
(
default
=
None
,
default
=
None
,
description
=
(
description
=
(
...
@@ -476,6 +500,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -476,6 +500,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema
=
self
.
response_format
.
json_schema
json_schema
=
self
.
response_format
.
json_schema
assert
json_schema
is
not
None
assert
json_schema
is
not
None
self
.
guided_json
=
json_schema
.
json_schema
self
.
guided_json
=
json_schema
.
json_schema
elif
self
.
response_format
.
type
==
"structural_tag"
:
structural_tag
=
self
.
response_format
assert
structural_tag
is
not
None
and
isinstance
(
structural_tag
,
StructuralTagResponseFormat
)
s_tag_obj
=
structural_tag
.
model_dump
(
by_alias
=
True
)
self
.
structural_tag
=
json
.
dumps
(
s_tag_obj
)
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
...
@@ -485,6 +515,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -485,6 +515,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_object
=
guided_json_object
,
json_object
=
guided_json_object
,
backend
=
self
.
guided_decoding_backend
,
backend
=
self
.
guided_decoding_backend
,
whitespace_pattern
=
self
.
guided_whitespace_pattern
,
whitespace_pattern
=
self
.
guided_whitespace_pattern
,
structural_tag
=
self
.
structural_tag
,
)
)
return
SamplingParams
.
from_optional
(
return
SamplingParams
.
from_optional
(
...
@@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel):
"If true (the default), special tokens (e.g. BOS) will be added to "
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
"the prompt."
),
)
)
response_format
:
Optional
[
ResponseFormat
]
=
Field
(
response_format
:
Optional
[
Any
ResponseFormat
]
=
Field
(
default
=
None
,
default
=
None
,
description
=
description
=
(
(
"Similar to chat completion, this parameter specifies the format of "
"Similar to chat completion, this parameter specifies the format "
"output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
"{'type': 'text' } is supported."
),
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
)
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
default
=
None
,
default
=
None
,
...
...
vllm/model_executor/guided_decoding/guided_fields.py
View file @
f8acd01f
...
@@ -27,14 +27,15 @@ class GuidedDecodingRequest:
...
@@ -27,14 +27,15 @@ class GuidedDecodingRequest:
guided_decoding_backend
:
Optional
[
str
]
=
None
guided_decoding_backend
:
Optional
[
str
]
=
None
guided_whitespace_pattern
:
Optional
[
str
]
=
None
guided_whitespace_pattern
:
Optional
[
str
]
=
None
guided_json_object
:
Optional
[
bool
]
=
None
guided_json_object
:
Optional
[
bool
]
=
None
structural_tag
:
Optional
[
str
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
"""Validate that some fields are mutually exclusive."""
"""Validate that some fields are mutually exclusive."""
guide_count
=
sum
(
[
guide_count
=
sum
(
x
is
not
None
self
.
guided_json
is
not
None
,
self
.
guided_regex
is
not
None
,
for
x
in
(
self
.
guided_json
,
self
.
guided_regex
,
self
.
guided_choice
is
not
None
,
self
.
guided_grammar
is
not
None
,
self
.
guided_choice
,
self
.
guided_grammar
,
self
.
guided_json_object
is
not
None
self
.
guided_json_object
,
]
)
self
.
structural_tag
)
)
if
guide_count
>
1
:
if
guide_count
>
1
:
raise
ValueError
(
raise
ValueError
(
"You can only use one kind of guided decoding but multiple are "
"You can only use one kind of guided decoding but multiple are "
...
...
vllm/sampling_params.py
View file @
f8acd01f
...
@@ -38,6 +38,7 @@ class GuidedDecodingParams:
...
@@ -38,6 +38,7 @@ class GuidedDecodingParams:
"""These are other options that can be set"""
"""These are other options that can be set"""
backend
:
Optional
[
str
]
=
None
backend
:
Optional
[
str
]
=
None
whitespace_pattern
:
Optional
[
str
]
=
None
whitespace_pattern
:
Optional
[
str
]
=
None
structural_tag
:
Optional
[
str
]
=
None
@
staticmethod
@
staticmethod
def
from_optional
(
def
from_optional
(
...
@@ -48,9 +49,10 @@ class GuidedDecodingParams:
...
@@ -48,9 +49,10 @@ class GuidedDecodingParams:
json_object
:
Optional
[
bool
]
=
None
,
json_object
:
Optional
[
bool
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
whitespace_pattern
:
Optional
[
str
]
=
None
,
whitespace_pattern
:
Optional
[
str
]
=
None
,
structural_tag
:
Optional
[
str
]
=
None
,
)
->
Optional
[
"GuidedDecodingParams"
]:
)
->
Optional
[
"GuidedDecodingParams"
]:
if
all
(
arg
is
None
if
all
(
arg
is
None
for
arg
in
(
json
,
regex
,
choice
,
grammar
,
for
arg
in
(
json
,
regex
,
choice
,
grammar
,
json_object
)):
json_object
,
structural_tag
)):
return
None
return
None
# Extract json schemas from pydantic models
# Extract json schemas from pydantic models
if
isinstance
(
json
,
(
BaseModel
,
type
(
BaseModel
))):
if
isinstance
(
json
,
(
BaseModel
,
type
(
BaseModel
))):
...
@@ -63,6 +65,7 @@ class GuidedDecodingParams:
...
@@ -63,6 +65,7 @@ class GuidedDecodingParams:
json_object
=
json_object
,
json_object
=
json_object
,
backend
=
backend
,
backend
=
backend
,
whitespace_pattern
=
whitespace_pattern
,
whitespace_pattern
=
whitespace_pattern
,
structural_tag
=
structural_tag
,
)
)
@
property
@
property
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
f8acd01f
...
@@ -194,6 +194,9 @@ def serialize_guidance_grammar(
...
@@ -194,6 +194,9 @@ def serialize_guidance_grammar(
tp
=
"grammar"
tp
=
"grammar"
elif
request_type
==
StructuredOutputOptions
.
CHOICE
:
elif
request_type
==
StructuredOutputOptions
.
CHOICE
:
tp
=
"choice"
tp
=
"choice"
elif
request_type
==
StructuredOutputOptions
.
STRUCTURAL_TAG
:
raise
ValueError
(
"Structural tag is not supported "
"for guidance backend yet"
)
else
:
else
:
logger
.
error
(
"Validation should have already occurred. "
logger
.
error
(
"Validation should have already occurred. "
"Please file an issue."
)
"Please file an issue."
)
...
...
vllm/v1/structured_output/backend_types.py
View file @
f8acd01f
...
@@ -12,6 +12,7 @@ class StructuredOutputOptions(enum.Enum):
...
@@ -12,6 +12,7 @@ class StructuredOutputOptions(enum.Enum):
REGEX
=
enum
.
auto
()
REGEX
=
enum
.
auto
()
GRAMMAR
=
enum
.
auto
()
GRAMMAR
=
enum
.
auto
()
CHOICE
=
enum
.
auto
()
CHOICE
=
enum
.
auto
()
STRUCTURAL_TAG
=
enum
.
auto
()
StructuredOutputKey
=
tuple
[
StructuredOutputOptions
,
str
]
StructuredOutputKey
=
tuple
[
StructuredOutputOptions
,
str
]
...
...
vllm/v1/structured_output/backend_xgrammar.py
View file @
f8acd01f
...
@@ -108,6 +108,16 @@ class XgrammarBackend(StructuredOutputBackend):
...
@@ -108,6 +108,16 @@ class XgrammarBackend(StructuredOutputBackend):
ctx
=
self
.
compiler
.
compile_grammar
(
grammar_spec
)
ctx
=
self
.
compiler
.
compile_grammar
(
grammar_spec
)
elif
request_type
==
StructuredOutputOptions
.
REGEX
:
elif
request_type
==
StructuredOutputOptions
.
REGEX
:
ctx
=
self
.
compiler
.
compile_regex
(
grammar_spec
)
ctx
=
self
.
compiler
.
compile_regex
(
grammar_spec
)
elif
request_type
==
StructuredOutputOptions
.
STRUCTURAL_TAG
:
s_tag
=
json
.
loads
(
grammar_spec
)
tags
=
[
xgr
.
StructuralTagItem
(
begin
=
s
[
"begin"
],
schema
=
json
.
dumps
(
s
[
"schema"
]),
end
=
s
[
"end"
],
)
for
s
in
s_tag
[
"structures"
]
]
ctx
=
self
.
compiler
.
compile_structural_tag
(
tags
,
s_tag
[
"triggers"
])
else
:
else
:
logger
.
error
(
logger
.
error
(
"Validation should have already occurred. Please file an issue."
"Validation should have already occurred. Please file an issue."
...
@@ -272,3 +282,18 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
...
@@ -272,3 +282,18 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
xgr
.
Grammar
.
from_ebnf
(
gd_params
.
grammar
)
xgr
.
Grammar
.
from_ebnf
(
gd_params
.
grammar
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
ValueError
(
"Invalid grammar specification."
)
from
e
raise
ValueError
(
"Invalid grammar specification."
)
from
e
return
if
gd_params
.
structural_tag
:
try
:
s_tag
=
json
.
loads
(
gd_params
.
structural_tag
)
tags
=
[
xgr
.
StructuralTagItem
(
begin
=
s
[
"begin"
],
schema
=
json
.
dumps
(
s
[
"schema"
]),
end
=
s
[
"end"
],
)
for
s
in
s_tag
[
"structures"
]
]
xgr
.
Grammar
.
from_structural_tag
(
tags
,
s_tag
[
"triggers"
])
except
Exception
as
e
:
raise
ValueError
(
"Invalid structural tag specification."
)
from
e
vllm/v1/structured_output/request.py
View file @
f8acd01f
...
@@ -78,5 +78,7 @@ def get_structured_output_key(
...
@@ -78,5 +78,7 @@ def get_structured_output_key(
return
(
StructuredOutputOptions
.
CHOICE
,
json_str
)
return
(
StructuredOutputOptions
.
CHOICE
,
json_str
)
elif
params
.
grammar
is
not
None
:
elif
params
.
grammar
is
not
None
:
return
(
StructuredOutputOptions
.
GRAMMAR
,
params
.
grammar
)
return
(
StructuredOutputOptions
.
GRAMMAR
,
params
.
grammar
)
elif
params
.
structural_tag
is
not
None
:
return
(
StructuredOutputOptions
.
STRUCTURAL_TAG
,
params
.
structural_tag
)
else
:
else
:
raise
ValueError
(
"No valid structured output parameter found"
)
raise
ValueError
(
"No valid structured output parameter found"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment