"vscode:/vscode.git/clone" did not exist on "4ca2c358b178d5e026db925a1ed9f8945010a98f"
Unverified Commit 87260b7b authored by 胡译文's avatar 胡译文 Committed by GitHub
Browse files

Litellm Backend (#502)

parent 651a23ee
@@ -23,7 +23,8 @@
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
"zmq", "vllm==0.4.3", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
openai = ["openai>=1.0", "numpy", "tiktoken"]
anthropic = ["anthropic>=0.20.0", "numpy"]
litellm = ["litellm>=1.0.0"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]

[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
@@ -27,6 +27,7 @@
from sglang.backend.anthropic import Anthropic
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.backend.vertexai import VertexAI
from sglang.backend.litellm import LiteLLM

# Global Configurations
from sglang.global_config import global_config

@@ -35,6 +36,7 @@
__all__ = [
    "global_config",
    "Anthropic",
    "LiteLLM",
    "OpenAI",
    "RuntimeEndpoint",
    "VertexAI",
from typing import Mapping, Optional
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import litellm
except ImportError as e:
litellm = e
class LiteLLM(BaseBackend):
    """sglang backend that routes generation calls through the `litellm` library.

    Parameters mirror the common OpenAI-style client options and are passed
    through to every ``litellm.completion`` call via ``self.client_params``.
    """

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = None,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        super().__init__()

        # When the optional `litellm` package is missing, the module-level
        # try/except binds `litellm` to the ImportError; surface it here,
        # at construction time, with a clear traceback.
        if isinstance(litellm, Exception):
            raise litellm

        # Bug fix: the default was previously `litellm.num_retries` in the
        # signature, which is evaluated at class-definition time and raised
        # AttributeError whenever litellm failed to import (defeating the
        # import guard above). Resolve it lazily instead.
        if max_retries is None:
            max_retries = litellm.num_retries

        self.model_name = model_name
        # Fall back to a chat template inferred from the model path when the
        # caller does not supply one explicitly.
        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        """Return the chat template used to format prompts for this model."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run a non-streaming completion; return ``(text, meta_info)``."""
        # Prefer the structured chat history; fall back to the raw text as a
        # single user turn.
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            **self.client_params,
            # Bug fix: this previously used to_anthropic_kwargs(); use the
            # litellm mapping so the streaming and non-streaming paths send
            # identical sampling parameters.
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield ``(text_chunk, meta_info)`` pairs from a streaming completion."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            # Some stream chunks (e.g. role-only deltas) carry no content.
            if text is not None:
                yield text, {}
@@ -82,6 +82,21 @@ class SglSamplingParams:
    "top_k": self.top_k,
}
def to_litellm_kwargs(self):
    """Translate these sampling parameters into ``litellm.completion`` kwargs."""
    # A regex constraint has no equivalent on this backend, so it would be
    # silently ignored; warn the caller instead of failing.
    if self.regex is not None:
        warnings.warn(
            "Regular expression is not supported in the LiteLLM backend."
        )
    kwargs = {"max_tokens": self.max_new_tokens}
    # Normalize an empty stop list/string to None.
    kwargs["stop"] = self.stop or None
    for name in (
        "temperature",
        "top_p",
        "top_k",
        "frequency_penalty",
        "presence_penalty",
    ):
        kwargs[name] = getattr(self, name)
    return kwargs
def to_srt_kwargs(self):
    return {
        "max_new_tokens": self.max_new_tokens,
import json
import unittest
from sglang import LiteLLM, set_default_backend
from sglang.test.test_programs import test_mt_bench, test_stream
class TestAnthropicBackend(unittest.TestCase):
    # NOTE(review): the class name looks copy-pasted from the Anthropic test
    # suite — it actually exercises the LiteLLM backend. Kept as-is so any
    # existing selection of tests by name keeps working.
    backend = None
    chat_backend = None

    def setUp(self):
        # Create one shared backend lazily on the class, so repeated setUp
        # calls (one per test) do not build a new client each time.
        if type(self).backend is None:
            type(self).backend = LiteLLM("gpt-3.5-turbo")
            set_default_backend(type(self).backend)

    def test_mt_bench(self):
        test_mt_bench()

    def test_stream(self):
        test_stream()
# Run the suite directly; "warnings='ignore'" silences unittest's warning noise.
if __name__ == "__main__":
    unittest.main(warnings="ignore")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment