Unverified Commit ec380dfd authored by LiviaSun, committed by GitHub

openai chat speculative execution (#250)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
parent 5b647543
@@ -14,10 +14,25 @@ def gen_character_spec(s):
     s += "\nJob:" + gen("job", stop="\n") + "\n"
 
 
+@function(api_num_spec_tokens=64)
+def gen_character_spec_no_few_shot(s):
+    # s += "Construct a character with name, birthday, and job:\n"
+    s += "Construct a character:\n"
+    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
+    s += "\nJob:" + gen("job", stop="\n") + "\n"
+
+
 set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))
 
 state = gen_character_spec.run()
 
-print("name:", state["name"])
-print("birthday:", state["birthday"])
-print("job:", state["job"])
+print("...name:", state["name"])
+print("...birthday:", state["birthday"])
+print("...job:", state["job"])
+
+state = gen_character_spec_no_few_shot.run()
+
+print("\n...name:", state["name"])
+print("...birthday:", state["birthday"])
+print("...job:", state["job"])
"""
Usage:
***Note: for speculative execution to work, user must put all "gen" in "assistant". Show in "assistant" the desired answer format. Each "gen" term should have a stop token. The stream mode is not supported in speculative execution.
E.g.
correct:
sgl.assistant("\nName:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
incorrect:
s += sgl.assistant("\nName:" + sgl.gen("name", stop="\n"))
s += sgl.assistant("\nBirthday:" + sgl.gen("birthday", stop="\n"))
s += sgl.assistant("\nJob:" + sgl.gen("job", stop="\n"))
export OPENAI_API_KEY=sk-******
python3 openaichat_speculative.py
"""
import sglang as sgl
from sglang import function, gen, set_default_backend, OpenAI


@function(api_num_spec_tokens=512)
def gen_character_spec(s):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("Construct a character within the following format:")
    s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
    s += sgl.user("Please generate new Name, Birthday and Job.\n")
    s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))


@function(api_num_spec_tokens=512)
def gen_character_spec_no_few_shot(s):
    s += sgl.user("Construct a character. For each field stop with a newline\n")
    s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nAge:" + sgl.gen("age", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))


@function
def gen_character_normal(s):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("What's the answer of 23 + 8?")
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))


@function(api_num_spec_tokens=1024)
def multi_turn_question(s, question_1, question_2):
    s += sgl.system("You are a helpful assistant.")
    s += sgl.user("Answer questions in the following format:")
    s += sgl.user("Question 1: What is the capital of France?\nQuestion 2: What is the population of this city?\n")
    s += sgl.assistant("Answer 1: The capital of France is Paris.\nAnswer 2: The population of Paris in 2024 is estimated to be around 2.1 million for the city proper.\n")
    s += sgl.user("Question 1: " + question_1 + "\nQuestion 2: " + question_2)
    s += sgl.assistant("Answer 1: " + sgl.gen("answer_1", stop="\n") + "\nAnswer 2: " + sgl.gen("answer_2", stop="\n"))
def test_spec_single_turn():
    state = gen_character_spec.run()
    for m in state.messages():
        print(m["role"], ":", m["content"])

    print("\n-- name:", state["name"])
    print("\n-- birthday:", state["birthday"])
    print("\n-- job:", state["job"])


def test_inaccurate_spec_single_turn():
    state = gen_character_spec_no_few_shot.run()
    for m in state.messages():
        print(m["role"], ":", m["content"])

    print("\n-- name:", state["name"])
    print("\n-- age:", state["age"])
    print("\n-- job:", state["job"])


def test_normal_single_turn():
    state = gen_character_normal.run()
    for m in state.messages():
        print(m["role"], ":", m["content"])


def test_spec_multi_turn():
    state = multi_turn_question.run(
        question_1="What is the capital of the United States?",
        question_2="List two local attractions in the capital of the United States.",
    )

    for m in state.messages():
        print(m["role"], ":", m["content"])

    print("\n-- answer_1 --\n", state["answer_1"])
    print("\n-- answer_2 --\n", state["answer_2"])


def test_spec_multi_turn_stream():
    state = multi_turn_question.run(
        question_1="What is the capital of the United States?",
        question_2="List two local attractions.",
        stream=True,
    )

    for out in state.text_iter():
        print(out, end="", flush=True)
if __name__ == "__main__":
set_default_backend(OpenAI("gpt-4-turbo"))
print("\n========== test spec single turn ==========\n")
# expect reasonable answer for each field
test_spec_single_turn()
print("\n========== test inaccurate spec single turn ==========\n")
# expect incomplete or unreasonable answers
test_inaccurate_spec_single_turn()
print("\n========== test normal single turn ==========\n")
# expect reasonable answer
test_normal_single_turn()
print("\n========== test spec multi turn ==========\n")
# expect answer with same format as in the few shot
test_spec_multi_turn()
print("\n========== test spec multi turn stream ==========\n")
# expect error in stream_executor: stream is not supported...
test_spec_multi_turn_stream()
@@ -9,6 +9,7 @@ class BaseBackend:
     def __init__(self) -> None:
         self.support_concate_and_append = False
         self.chat_template = get_chat_template("default")
+        self.api_num_spec_tokens = None
 
     def get_model_name(self):
         raise NotImplementedError()
...
 import logging
 import time
+import warnings
 from typing import Callable, List, Optional, Union
 
 import numpy as np
@@ -80,7 +81,15 @@ class OpenAI(BaseBackend):
         else:
             self.is_chat_model = True
 
-        self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
+        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
+
+        self.spec_kwargs = {}
+        self.spec_format = []
+        self.spec_max_num_tries = 3
+        self.api_num_spec_tokens = None
+
+    def set_api_num_spec_tokens(self, num):
+        self.api_num_spec_tokens = num
 
     def get_chat_template(self):
         return self.chat_template
@@ -89,15 +98,45 @@ class OpenAI(BaseBackend):
         self,
         s: StreamExecutor,
         sampling_params: SglSamplingParams,
+        name=None,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
-                    raise RuntimeError(
-                        "This use case is not supported. "
-                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+                if self.api_num_spec_tokens is None:
+                    if not s.text_.endswith(self.chat_prefix):
+                        raise RuntimeError(
+                            "This use case is not supported if api speculative execution is off. "
+                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
+                            "Example of adding api speculative execution: @function(api_num_spec_tokens=128)."
+                        )
+                    prompt = s.messages_
+                else:
+                    # collect assistant answer format
+                    if "max_tokens" not in self.spec_kwargs:
+                        self.spec_kwargs["max_tokens"] = self.api_num_spec_tokens
+                    else:
+                        assert (
+                            self.spec_kwargs["max_tokens"] == self.api_num_spec_tokens
+                        )
+                    params = sampling_params.to_openai_kwargs()
+                    for key, value in params.items():
+                        if key in ["stop"]:
+                            continue
+                        if key in ["max_tokens"]:
+                            warnings.warn(
+                                "The parameter max_tokens will be overwritten by the speculated number of tokens."
+                            )
+                            continue
+                        if key not in self.spec_kwargs:
+                            self.spec_kwargs[key] = value
+                        else:
+                            assert (
+                                value == self.spec_kwargs[key]
+                            ), "sampling parameters should be consistent when api speculative execution is on."
+                    self.spec_format.append(
+                        {"text": "", "stop": params["stop"], "name": name}
                     )
-                prompt = s.messages_
+                    return "", {}
             else:
                 prompt = s.text_
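# (Flow note for the branch above: when api_num_spec_tokens is set for a chat model, generate()
# does not call the OpenAI API here at all; it only accumulates the sampling parameters in
# spec_kwargs and the expected answer format in spec_format, then returns an empty string. The
# single real completion request is issued later, in role_end_generate() below.)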
@@ -110,6 +149,9 @@ class OpenAI(BaseBackend):
                 **kwargs,
             )
         elif sampling_params.dtype in [str, "str", "string"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
@@ -122,6 +164,9 @@ class OpenAI(BaseBackend):
             )
             comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
+            assert (
+                not self.is_chat_model
+            ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")
             comp = openai_completion(
@@ -138,6 +183,62 @@ class OpenAI(BaseBackend):
         return comp, {}
 
+    def spec_fill(self, value: str):
+        assert self.is_chat_model
+        self.spec_format.append({"text": value, "stop": None, "name": None})
+
+    def spec_pattern_match(self, comp):
+        for i, term in enumerate(self.spec_format):
+            text = term["text"]
+            if text != "":
+                if comp.startswith(text):
+                    comp = comp[len(text) :]
+                else:
+                    return False
+            else:
+                pos = comp.find(term["stop"])
+                if pos != -1:
+                    term["text"] = comp[:pos]
+                    comp = comp[pos:]
+                else:
+                    if i == len(self.spec_format) - 1:
+                        term["text"] = comp
+                    else:
+                        return False
+        return True
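# (Worked illustration with hypothetical values: for an assistant turn declared as
#     "Name:" + gen("name", stop="\n") + "\nJob:" + gen("job", stop="\n")
# spec_format holds, in order, a literal term {"text": "Name:"}, a gen term
# {"text": "", "stop": "\n", "name": "name"}, a literal term {"text": "\nJob:"}, and a gen term
# {"text": "", "stop": "\n", "name": "job"}. Matching comp = "Name: Ada\nJob: Mathematician\n"
# strips the literal "Name:", captures " Ada" for "name", strips "\nJob:", and captures
# " Mathematician" for "job". If any literal fails to match, the method returns False so the
# caller can retry with a fresh completion.)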
+    def role_end_generate(
+        self,
+        s: StreamExecutor,
+    ):
+        if self.api_num_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
+            return
+
+        comp = ""
+        if not all(x["name"] is None for x in self.spec_format):
+            for i in range(self.spec_max_num_tries):
+                comp = openai_completion(
+                    client=self.client,
+                    is_chat=self.is_chat_model,
+                    model=self.model_name,
+                    prompt=s.messages_,
+                    **self.spec_kwargs,
+                )
+                if self.spec_pattern_match(comp):
+                    break
+
+        for term in self.spec_format:
+            stop = term["stop"] if term["stop"] is not None else ""
+            s.text_ += term["text"]
+            name = term["name"]
+            if name is not None:
+                s.variables[name] = term["text"]
+                s.meta_info[name] = {}
+                s.variable_event[name].set()
+
+        self.spec_kwargs = {}
+        self.spec_format = []
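# (Flow note: role_end_generate() is where the single speculative completion is actually
# requested. The API call is skipped when the assistant turn declares no gen variables;
# otherwise it retries up to spec_max_num_tries completions until one matches the declared
# format, then appends each term's text to s.text_, stores every named capture in s.variables,
# signals the corresponding variable_event, and finally resets spec_kwargs / spec_format for the
# next assistant turn.)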
     def generate_stream(
         self,
         s: StreamExecutor,
@@ -145,7 +246,7 @@ class OpenAI(BaseBackend):
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
-                if not s.text_.endswith(self.chat_begin_str):
+                if not s.text_.endswith(self.chat_prefix):
                     raise RuntimeError(
                         "This use case is not supported. "
                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
...
@@ -266,4 +266,4 @@ class RuntimeEndpoint(BaseBackend):
 
     def _assert_success(self, res):
         if res.status_code != 200:
-            raise RuntimeError(res.json())
\ No newline at end of file
+            raise RuntimeError(res.json())
@@ -6,6 +6,7 @@ import multiprocessing
 import queue
 import threading
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -185,7 +186,6 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
-        self.api_num_spec_tokens = api_num_spec_tokens
 
         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
@@ -197,6 +197,9 @@ class StreamExecutor:
         self.text_ = ""  # The full text
 
         # For speculative execution
+        from sglang.backend.openai import OpenAI
+        if isinstance(backend, OpenAI):
+            self.backend.set_api_num_spec_tokens(api_num_spec_tokens)
         self.speculated_text = ""
 
         # For chat
@@ -322,7 +325,7 @@ class StreamExecutor:
             try:
                 self._execute(expr)
             except Exception as e:
-                # print(f"Error in stream_executor: {get_exception_traceback()}")
+                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                 error = e
                 break
 
             self.queue.task_done()
@@ -391,12 +394,23 @@ class StreamExecutor:
         else:
             raise ValueError(f"Unknown type: {type(other)}")
 
-    def _execute_fill(self, value: str):
+    def _execute_fill(self, value: str, prefix=False):
         value = str(value)
+
+        if (
+            self.cur_role == "assistant"
+            and self.backend.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
+            and not prefix
+        ):
+            self.backend.spec_fill(value)
+            return
+
         if self.speculated_text.startswith(value):
             self.speculated_text = self.speculated_text[len(value) :]
         else:
             self.speculated_text = ""
 
         self.text_ += value
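# (Note on the new prefix flag: role prefixes filled by _execute_role_begin are passed with
# prefix=True, so only literal text written inside the assistant message is recorded via
# spec_fill() as part of the expected answer format; the chat role prefix itself is presumably
# not part of the completion that gets pattern-matched.)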
     def _execute_image(self, expr: SglImage):
@@ -426,14 +440,29 @@ class StreamExecutor:
         name = expr.name
 
         if not self.stream:
-            if self.api_num_spec_tokens is not None:
+            if self.backend.api_num_spec_tokens is None:
+                comp, meta_info = self.backend.generate(
+                    self,
+                    sampling_params=sampling_params,
+                )
+            elif self.backend.is_chat_model:
+                # spec on model with only chat interface
+                comp, meta_info = self.backend.generate(
+                    self,
+                    sampling_params=sampling_params,
+                    name=name,
+                )
+                return
+            else:  # spec on model with completion
                 stop = sampling_params.stop
                 max_new_tokens = sampling_params.max_new_tokens
                 meta_info = {}
 
                 def regen():
                     sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.api_num_spec_tokens
+                        sampling_params.max_new_tokens, self.backend.api_num_spec_tokens
                     )
                     sampling_params.stop = None
                     self.speculated_text, meta_info = self.backend.generate(
@@ -442,16 +471,14 @@ class StreamExecutor:
 
                 def find_stop():
                     if isinstance(stop, str):
-                        return self.speculated_text.find(stop), len(stop)
+                        return self.speculated_text.find(stop)
                     elif isinstance(stop, (tuple, list)):
                         pos = -1
-                        stop_len = 0
                         for stop_str in stop:
                             stop_pos = self.speculated_text.find(stop_str)
                             if stop_pos != -1 and (pos == -1 or stop_pos < pos):
                                 pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
+                        return pos
                     else:
                         raise Exception("Wrong type of stop in sampling parameters.")
@@ -463,23 +490,16 @@ class StreamExecutor:
                 elif isinstance(stop, (str, list, tuple)):
                     if self.speculated_text == "":
                         regen()
-                    stop_pos, stop_len = find_stop()
+                    stop_pos = find_stop()
                     if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
-                        )
+                        stop_pos = min(
+                            sampling_params.max_new_tokens,
+                            len(self.speculated_text),
+                        )
                     comp = self.speculated_text[:stop_pos]
                     self.speculated_text = self.speculated_text[stop_pos:]
                 else:
                     raise ValueError("Wrong type of stop in sampling parameters.")
-            else:
-                comp, meta_info = self.backend.generate(
-                    self, sampling_params=sampling_params
-                )
 
             self.text_ += comp
@@ -487,6 +507,9 @@ class StreamExecutor:
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
         else:
+            assert (
+                self.backend.api_num_spec_tokens is None
+            ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
             )
@@ -542,10 +565,18 @@ class StreamExecutor:
         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
 
-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)
 
     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.backend.api_num_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()
         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -572,8 +603,6 @@ class StreamExecutor:
             # OpenAI chat API format
             self.messages_.append({"role": expr.role, "content": new_text})
 
-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))
...
@@ -31,8 +31,9 @@ class GenerateReqInput:
     def post_init(self):
-        if ((self.text is None and self.input_ids is None) or
-            (self.text is not None and self.input_ids is not None)):
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
             raise ValueError("Either text or input_ids should be provided.")
         if self.text is not None:
...
@@ -38,7 +38,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-
 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
...
@@ -341,7 +341,6 @@ class TokenizerManager:
         return top_logprobs
 
-
 global global_processor
@@ -385,4 +384,4 @@ def get_pixel_values(
         pixel_values = pixel_values.astype(np.float16)
         return pixel_values, image_hash, image.size
     except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
\ No newline at end of file
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
@@ -9,8 +9,8 @@ import os
 import sys
 import threading
 import time
-from typing import List, Optional, Union
 from http import HTTPStatus
+from typing import List, Optional, Union
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -45,7 +45,6 @@ from sglang.srt.utils import (
 )
 from sglang.utils import get_exception_traceback
 
-
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -84,6 +83,7 @@ async def flush_cache():
 async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
+
         async def stream_results():
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
@@ -99,8 +99,10 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return JSONResponse({"error": {"message": str(e)}},
-                            status_code=HTTPStatus.BAD_REQUEST)
+        return JSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
 
 
 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
...
@@ -19,7 +19,6 @@ from packaging import version as pkg_version
 from pydantic import BaseModel
 from starlette.middleware.base import BaseHTTPMiddleware
 
-
 logger = logging.getLogger(__name__)
@@ -157,7 +156,9 @@ def allocate_init_ports(
             cur_port += 1
 
     if port and ret_ports[0] != port:
-        logger.warn(f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead.")
+        logger.warn(
+            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
+        )
 
     return ret_ports[0], ret_ports[1:]
...