Unverified commit 50afed4e, authored by Cody Yu and committed by GitHub
Browse files

Support extra field regex in OpenAI API (#172)

parent 4d303c4f
...@@ -36,6 +36,9 @@ class CompletionRequest(BaseModel): ...@@ -36,6 +36,9 @@ class CompletionRequest(BaseModel):
logit_bias: Optional[Dict[str, float]] = None logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None user: Optional[str] = None
# Extra parameters for the SRT backend only; they will be ignored by OpenAI models.
regex: Optional[str] = None
class CompletionResponseChoice(BaseModel): class CompletionResponseChoice(BaseModel):
index: int index: int
...@@ -119,6 +122,9 @@ class ChatCompletionRequest(BaseModel): ...@@ -119,6 +122,9 @@ class ChatCompletionRequest(BaseModel):
user: Optional[str] = None user: Optional[str] = None
best_of: Optional[int] = None best_of: Optional[int] = None
# Extra parameters for the SRT backend only; they will be ignored by OpenAI models.
regex: Optional[str] = None
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: Optional[str] = None role: Optional[str] = None
......
...@@ -151,6 +151,7 @@ async def v1_completions(raw_request: Request): ...@@ -151,6 +151,7 @@ async def v1_completions(raw_request: Request):
"top_p": request.top_p, "top_p": request.top_p,
"presence_penalty": request.presence_penalty, "presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty, "frequency_penalty": request.frequency_penalty,
"regex": request.regex,
}, },
return_logprob=request.logprobs is not None, return_logprob=request.logprobs is not None,
stream=request.stream, stream=request.stream,
...@@ -304,6 +305,7 @@ async def v1_chat_completions(raw_request: Request): ...@@ -304,6 +305,7 @@ async def v1_chat_completions(raw_request: Request):
"top_p": request.top_p, "top_p": request.top_p,
"presence_penalty": request.presence_penalty, "presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty, "frequency_penalty": request.frequency_penalty,
"regex": request.regex,
}, },
stream=request.stream, stream=request.stream,
) )
......
...@@ -14,6 +14,7 @@ The capital of Japan is Tokyo ...@@ -14,6 +14,7 @@ The capital of Japan is Tokyo
""" """
import argparse import argparse
import json
import openai import openai
...@@ -151,6 +152,29 @@ def test_chat_completion_stream(args): ...@@ -151,6 +152,29 @@ def test_chat_completion_stream(args):
print() print()
def test_regex(args):
    """Request a regex-constrained chat completion and print it as parsed JSON.

    The pattern is sent to the server as the ``regex`` field of
    ``extra_body``. The content of the first returned choice is passed to
    ``json.loads``, so a reply that is not valid JSON raises here.

    Args:
        args: Parsed command-line arguments; only ``args.base_url`` is read.
    """
    # Adjacent raw literals are joined at compile time into a single pattern
    # describing a JSON object with a "name" and a "population" field.
    pattern = (
        r"""\{\n"""
        r""" "name": "[\w]+",\n"""
        r""" "population": "[\w\d\s]+"\n"""
        r"""\}"""
    )
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "Introduce the capital of France."},
    ]

    api = openai.Client(api_key="EMPTY", base_url=args.base_url)
    completion = api.chat.completions.create(
        model="default",
        messages=conversation,
        temperature=0,
        max_tokens=128,
        extra_body={"regex": pattern},
    )

    first_choice = completion.choices[0]
    print(json.loads(first_choice.message.content))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1") parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
...@@ -169,5 +193,6 @@ if __name__ == "__main__": ...@@ -169,5 +193,6 @@ if __name__ == "__main__":
test_completion_stream(args, echo=True, logprobs=True) test_completion_stream(args, echo=True, logprobs=True)
test_chat_completion(args) test_chat_completion(args)
test_chat_completion_stream(args) test_chat_completion_stream(args)
test_regex(args)
if args.test_image: if args.test_image:
test_chat_completion_image(args) test_chat_completion_image(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment