chenpangpang / open-webui · Commits

Commit a41b195f authored Apr 21, 2024 by Timothy J. Baek

DO NOT TRACK ME >:(

parent 5e458d49

Changes 1
Showing 1 changed file with 119 additions and 66 deletions

backend/apps/litellm/main.py  (+119, -66)  view file @ a41b195f
-from fastapi import FastAPI, Depends
+from fastapi import FastAPI, Depends, HTTPException
 from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
...
@@ -9,9 +9,11 @@ from fastapi.responses import JSONResponse

 from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 from starlette.responses import StreamingResponse
 import json
+import requests

-from utils.utils import get_http_authorization_cred, get_current_user
+from utils.utils import get_verified_user, get_current_user
 from config import SRC_LOG_LEVELS, ENV
+from constants import ERROR_MESSAGES

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["LITELLM"])
...
@@ -49,12 +51,13 @@ async def run_background_process(command):

 async def start_litellm_background():
     # Command to run in the background
-    command = "litellm --config ./data/litellm/config.yaml"
+    command = "litellm --telemetry False --config ./data/litellm/config.yaml"

     await run_background_process(command)


 @app.on_event("startup")
 async def startup_event():
+    # TODO: Check config.yaml file and create one
     asyncio.create_task(start_litellm_background())
...
@@ -62,82 +65,132 @@ app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED

 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


-@app.middleware("http")
-async def auth_middleware(request: Request, call_next):
-    auth_header = request.headers.get("Authorization", "")
-    request.state.user = None
-
-    try:
-        user = get_current_user(get_http_authorization_cred(auth_header))
-        log.debug(f"user: {user}")
-        request.state.user = user
-    except Exception as e:
-        return JSONResponse(status_code=400, content={"detail": str(e)})
-
-    response = await call_next(request)
-    return response
-
-
 @app.get("/")
 async def get_status():
     return {"status": True}


-class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
-    async def dispatch(
-        self, request: Request, call_next: RequestResponseEndpoint
-    ) -> Response:
-        response = await call_next(request)
-        user = request.state.user
-
-        if "/models" in request.url.path:
-            if isinstance(response, StreamingResponse):
-                # Read the content of the streaming response
-                body = b""
-                async for chunk in response.body_iterator:
-                    body += chunk
-
-                data = json.loads(body.decode("utf-8"))
-
-                if app.state.MODEL_FILTER_ENABLED:
-                    if user and user.role == "user":
-                        data["data"] = list(
-                            filter(
-                                lambda model: model["id"]
-                                in app.state.MODEL_FILTER_LIST,
-                                data["data"],
-                            )
-                        )
-
-                # Modified Flag
-                data["modified"] = True
-                return JSONResponse(content=data)
-
-        return response
-
-
-app.add_middleware(ModifyModelsResponseMiddleware)
-
-# from litellm.proxy.proxy_server import ProxyConfig, initialize
-# from litellm.proxy.proxy_server import app
-
-# proxy_config = ProxyConfig()
-
-# async def config():
-#     router, model_list, general_settings = await proxy_config.load_config(
-#         router=None, config_file_path="./data/litellm/config.yaml"
-#     )
-
-#     await initialize(config="./data/litellm/config.yaml", telemetry=False)
-
-# async def startup():
-#     await config()
-
-# @app.on_event("startup")
-# async def on_startup():
-#     await startup()
+@app.get("/models")
+@app.get("/v1/models")
+async def get_models(user=Depends(get_current_user)):
+    url = "http://localhost:4000/v1"
+    r = None
+    try:
+        r = requests.request(method="GET", url=f"{url}/models")
+        r.raise_for_status()
+
+        data = r.json()
+
+        if app.state.MODEL_FILTER_ENABLED:
+            if user and user.role == "user":
+                data["data"] = list(
+                    filter(
+                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
+                        data["data"],
+                    )
+                )
+
+        return data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']}"
+            except:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
+    body = await request.body()
+
+    url = "http://localhost:4000/v1"
+    target_url = f"{url}/{path}"
+
+    headers = {}
+    # headers["Authorization"] = f"Bearer {key}"
+    headers["Content-Type"] = "application/json"
+
+    r = None
+
+    try:
+        r = requests.request(
+            method=request.method,
+            url=target_url,
+            data=body,
+            headers=headers,
+            stream=True,
+        )
+
+        r.raise_for_status()
+
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            return StreamingResponse(
+                r.iter_content(chunk_size=8192),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        else:
+            response_data = r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500, detail=error_detail
+        )
+
+
+# class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
+#     async def dispatch(
+#         self, request: Request, call_next: RequestResponseEndpoint
+#     ) -> Response:
+#         response = await call_next(request)
+#         user = request.state.user
+
+#         if "/models" in request.url.path:
+#             if isinstance(response, StreamingResponse):
+#                 # Read the content of the streaming response
+#                 body = b""
+#                 async for chunk in response.body_iterator:
+#                     body += chunk
+#                 data = json.loads(body.decode("utf-8"))
+
+#                 if app.state.MODEL_FILTER_ENABLED:
+#                     if user and user.role == "user":
+#                         data["data"] = list(
+#                             filter(
+#                                 lambda model: model["id"]
+#                                 in app.state.MODEL_FILTER_LIST,
+#                                 data["data"],
+#                             )
+#                         )
+
+#                 # Modified Flag
+#                 data["modified"] = True
+#                 return JSONResponse(content=data)
+
+#         return response
+
+# app.add_middleware(ModifyModelsResponseMiddleware)