Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
93c90dc1
Commit
93c90dc1
authored
Mar 20, 2024
by
Timothy J. Baek
Browse files
feat: litellm model filter support
parent
8cb7127f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
2 deletions
+65
-2
backend/apps/litellm/main.py
backend/apps/litellm/main.py
+64
-1
backend/config.py
backend/config.py
+1
-1
No files found.
backend/apps/litellm/main.py
View file @
93c90dc1
from
litellm.proxy.proxy_server
import
ProxyConfig
,
initialize
from
litellm.proxy.proxy_server
import
ProxyConfig
,
initialize
from
litellm.proxy.proxy_server
import
app
from
litellm.proxy.proxy_server
import
app
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
from
fastapi
import
FastAPI
,
Request
,
Depends
,
status
,
Response
from
fastapi.responses
import
JSONResponse
from
fastapi.responses
import
JSONResponse
from
starlette.middleware.base
import
BaseHTTPMiddleware
,
RequestResponseEndpoint
from
starlette.responses
import
StreamingResponse
import
json
from
utils.utils
import
get_http_authorization_cred
,
get_current_user
from
utils.utils
import
get_http_authorization_cred
,
get_current_user
from
config
import
ENV
from
config
import
ENV
from
config
import
(
MODEL_FILTER_ENABLED
,
MODEL_FILTER_LIST
,
)
proxy_config
=
ProxyConfig
()
proxy_config
=
ProxyConfig
()
...
@@ -26,16 +38,67 @@ async def on_startup():
...
@@ -26,16 +38,67 @@ async def on_startup():
await
startup
()
await
startup
()
app
.
state
.
MODEL_FILTER_ENABLED
=
MODEL_FILTER_ENABLED
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
@
app
.
middleware
(
"http"
)
@
app
.
middleware
(
"http"
)
async
def
auth_middleware
(
request
:
Request
,
call_next
):
async
def
auth_middleware
(
request
:
Request
,
call_next
):
auth_header
=
request
.
headers
.
get
(
"Authorization"
,
""
)
auth_header
=
request
.
headers
.
get
(
"Authorization"
,
""
)
request
.
state
.
user
=
None
if
ENV
!=
"dev"
:
if
ENV
!=
"dev"
:
try
:
try
:
user
=
get_current_user
(
get_http_authorization_cred
(
auth_header
))
user
=
get_current_user
(
get_http_authorization_cred
(
auth_header
))
print
(
user
)
print
(
user
)
request
.
state
.
user
=
user
except
Exception
as
e
:
except
Exception
as
e
:
return
JSONResponse
(
status_code
=
400
,
content
=
{
"detail"
:
str
(
e
)})
return
JSONResponse
(
status_code
=
400
,
content
=
{
"detail"
:
str
(
e
)})
response
=
await
call_next
(
request
)
response
=
await
call_next
(
request
)
return
response
return
response
class
ModifyModelsResponseMiddleware
(
BaseHTTPMiddleware
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
:
RequestResponseEndpoint
)
->
Response
:
response
=
await
call_next
(
request
)
user
=
request
.
state
.
user
# Check if the request is for the `/models` route
if
"/models"
in
request
.
url
.
path
:
# Ensure the response is a StreamingResponse
if
isinstance
(
response
,
StreamingResponse
):
# Read the content of the streaming response
body
=
b
""
async
for
chunk
in
response
.
body_iterator
:
body
+=
chunk
# Modify the content as needed
data
=
json
.
loads
(
body
.
decode
(
"utf-8"
))
print
(
data
)
if
app
.
state
.
MODEL_FILTER_ENABLED
:
if
user
and
user
.
role
==
"user"
:
data
[
"data"
]
=
list
(
filter
(
lambda
model
:
model
[
"id"
]
in
app
.
state
.
MODEL_FILTER_LIST
,
data
[
"data"
],
)
)
# Example modification: Add a new key-value pair
data
[
"modified"
]
=
True
# Return a new JSON response with the modified content
return
JSONResponse
(
content
=
data
)
return
response
# Add the middleware to the app
app
.
add_middleware
(
ModifyModelsResponseMiddleware
)
backend/config.py
View file @
93c90dc1
...
@@ -298,7 +298,7 @@ USER_PERMISSIONS_CHAT_DELETION = (
...
@@ -298,7 +298,7 @@ USER_PERMISSIONS_CHAT_DELETION = (
USER_PERMISSIONS
=
{
"chat"
:
{
"deletion"
:
USER_PERMISSIONS_CHAT_DELETION
}}
USER_PERMISSIONS
=
{
"chat"
:
{
"deletion"
:
USER_PERMISSIONS_CHAT_DELETION
}}
MODEL_FILTER_ENABLED
=
os
.
environ
.
get
(
"MODEL_FILTER_ENABLED"
,
False
)
MODEL_FILTER_ENABLED
=
os
.
environ
.
get
(
"MODEL_FILTER_ENABLED"
,
"
False
"
).
lower
()
==
"true"
MODEL_FILTER_LIST
=
os
.
environ
.
get
(
"MODEL_FILTER_LIST"
,
""
)
MODEL_FILTER_LIST
=
os
.
environ
.
get
(
"MODEL_FILTER_LIST"
,
""
)
MODEL_FILTER_LIST
=
[
model
.
strip
()
for
model
in
MODEL_FILTER_LIST
.
split
(
";"
)]
MODEL_FILTER_LIST
=
[
model
.
strip
()
for
model
in
MODEL_FILTER_LIST
.
split
(
";"
)]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment