"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9062b2847d0ab412ed12b9bd5590779dda28d6b2"
Unverified Commit 007eeb4e authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix the error message and dependency of openai backend (#71)

parent e8f2b155
......@@ -164,7 +164,8 @@ def image_qa(s, image_file, question):
```
### Constrained Decoding
Use `regex=` to specify a regular expression as a decoding constraint.
Use `regex` to specify a regular expression as a decoding constraint.
This is only supported for local models.
```python
@sgl.function
......
......@@ -18,10 +18,11 @@ dependencies = [
]
[project.optional-dependencies]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
"interegular", "lark", "numba", "pydantic", "diskcache", "cloudpickle"]
openai = ["openai>=1.0"]
anthropic = ["anthropic"]
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
"pydantic", "diskcache", "cloudpickle"]
openai = ["openai>=1.0", "numpy"]
anthropic = ["anthropic", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
[project.urls]
......
......@@ -77,7 +77,9 @@ class OpenAI(BaseBackend):
):
if sampling_params.dtype is None:
if self.is_chat_model:
assert s.text_.endswith("ASSISTANT:")
if not s.text_.endswith("ASSISTANT:"):
raise RuntimeError("This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant")
prompt = s.messages_
else:
prompt = s.text_
......@@ -149,6 +151,12 @@ class OpenAI(BaseBackend):
choices: List[str],
temperature: float,
):
if self.is_chat_model:
raise NotImplementedError(
"select/choices is not supported for chat models. "
"Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
)
n_choices = len(choices)
token_ids = [self.tokenizer.encode(x) for x in choices]
scores = [0] * n_choices
......
......@@ -197,16 +197,7 @@ class StreamExecutor:
self.stream_var_event = None
def submit(self, expr: SglExpr):
if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
self.variable_event[expr.name] = threading.Event()
if self.stream:
self.stream_var_event[expr.name] = threading.Event()
elif isinstance(expr, SglExprList):
for e in expr.expr_list:
if isinstance(e, (SglGen, SglSelect, SglVarScopeBegin)):
self.variable_event[e.name] = threading.Event()
if self.stream:
self.stream_var_event[e.name] = threading.Event()
self._init_var_event(expr)
if self.use_thread:
self.queue.put(expr)
......@@ -467,6 +458,15 @@ class StreamExecutor:
src_rids = [state.stream_executor.sid for state in expr.states]
self.backend.concatenate_and_append(src_rids, self.sid)
def _init_var_event(self, expr):
    """Register threading.Event objects for every variable *expr* will produce.

    For a SglGen / SglSelect / SglVarScopeBegin expression, an Event is
    stored in ``self.variable_event`` under the expression's name (and in
    ``self.stream_var_event`` too when streaming is enabled), so waiters
    can block until the variable is filled in.  SglExprList containers are
    handled by recursing into each child expression.
    """
    if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
        # Collect every event table this expression must be tracked in.
        tables = [self.variable_event]
        if self.stream:
            tables.append(self.stream_var_event)
        for table in tables:
            table[expr.name] = threading.Event()
    elif isinstance(expr, SglExprList):
        for sub_expr in expr.expr_list:
            self._init_var_event(sub_expr)
def _resolve_sampling_params(self, sampling_params):
clone = None
for item in [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment