Unverified Commit 50afed4e authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

Support extra field regex in OpenAI API (#172)

parent 4d303c4f
......@@ -36,6 +36,9 @@ class CompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
regex: Optional[str] = None
class CompletionResponseChoice(BaseModel):
index: int
......@@ -119,6 +122,9 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None
best_of: Optional[int] = None
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
regex: Optional[str] = None
class ChatMessage(BaseModel):
role: Optional[str] = None
......
......@@ -151,6 +151,7 @@ async def v1_completions(raw_request: Request):
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
},
return_logprob=request.logprobs is not None,
stream=request.stream,
......@@ -304,6 +305,7 @@ async def v1_chat_completions(raw_request: Request):
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
},
stream=request.stream,
)
......
......@@ -14,6 +14,7 @@ The capital of Japan is Tokyo
"""
import argparse
import json
import openai
......@@ -151,6 +152,29 @@ def test_chat_completion_stream(args):
print()
def test_regex(args):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
regex = (r"""\{\n"""
+ r""" "name": "[\w]+",\n"""
+ r""" "population": "[\w\d\s]+"\n"""
+ r"""\}"""
)
response = client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=128,
extra_body={"regex": regex},
)
text = response.choices[0].message.content
print(json.loads(text))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
......@@ -169,5 +193,6 @@ if __name__ == "__main__":
test_completion_stream(args, echo=True, logprobs=True)
test_chat_completion(args)
test_chat_completion_stream(args)
test_regex(args)
if args.test_image:
test_chat_completion_image(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment