Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
cc6d9bb8
Commit
cc6d9bb8
authored
May 27, 2024
by
Timothy J. Baek
Browse files
feat: pipeline valve support
parent
abce172b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
85 additions
and
0 deletions
+85
-0
backend/main.py
backend/main.py
+85
-0
No files found.
backend/main.py
View file @
cc6d9bb8
...
@@ -229,6 +229,83 @@ class RAGMiddleware(BaseHTTPMiddleware):
...
@@ -229,6 +229,83 @@ class RAGMiddleware(BaseHTTPMiddleware):
app
.
add_middleware
(
RAGMiddleware
)
app
.
add_middleware
(
RAGMiddleware
)
class
PipelineMiddleware
(
BaseHTTPMiddleware
):
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
if
request
.
method
==
"POST"
and
(
"/api/chat"
in
request
.
url
.
path
or
"/chat/completions"
in
request
.
url
.
path
):
log
.
debug
(
f
"request.url.path:
{
request
.
url
.
path
}
"
)
# Read the original request body
body
=
await
request
.
body
()
# Decode body to string
body_str
=
body
.
decode
(
"utf-8"
)
# Parse string to JSON
data
=
json
.
loads
(
body_str
)
if
body_str
else
{}
model_id
=
data
[
"model"
]
valves
=
[
model
for
model
in
app
.
state
.
MODELS
.
values
()
if
"pipeline"
in
model
and
model
[
"pipeline"
][
"type"
]
==
"valve"
and
model_id
in
[
target_model
[
"id"
]
for
target_model
in
model
[
"pipeline"
][
"pipelines"
]
]
]
sorted_valves
=
sorted
(
valves
,
key
=
lambda
x
:
x
[
"pipeline"
][
"priority"
])
for
valve
in
sorted_valves
:
try
:
urlIdx
=
valve
[
"urlIdx"
]
url
=
openai_app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
urlIdx
]
key
=
openai_app
.
state
.
config
.
OPENAI_API_KEYS
[
urlIdx
]
if
key
!=
""
:
headers
=
{
"Authorization"
:
f
"Bearer
{
key
}
"
}
r
=
requests
.
post
(
f
"
{
url
}
/valve"
,
headers
=
headers
,
json
=
{
"model"
:
valve
[
"id"
],
"body"
:
data
,
},
)
r
.
raise_for_status
()
data
=
r
.
json
()
except
Exception
as
e
:
# Handle connection error here
log
.
error
(
f
"Connection error:
{
e
}
"
)
pass
modified_body_bytes
=
json
.
dumps
(
data
).
encode
(
"utf-8"
)
# Replace the request body with the modified one
request
.
_body
=
modified_body_bytes
# Set custom header to ensure content-length matches new body length
request
.
headers
.
__dict__
[
"_list"
]
=
[
(
b
"content-length"
,
str
(
len
(
modified_body_bytes
)).
encode
(
"utf-8"
)),
*
[
(
k
,
v
)
for
k
,
v
in
request
.
headers
.
raw
if
k
.
lower
()
!=
b
"content-length"
],
]
response
=
await
call_next
(
request
)
return
response
async
def
_receive
(
self
,
body
:
bytes
):
return
{
"type"
:
"http.request"
,
"body"
:
body
,
"more_body"
:
False
}
app
.
add_middleware
(
PipelineMiddleware
)
@
app
.
middleware
(
"http"
)
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
async
def
check_url
(
request
:
Request
,
call_next
):
if
len
(
app
.
state
.
MODELS
)
==
0
:
if
len
(
app
.
state
.
MODELS
)
==
0
:
...
@@ -332,6 +409,14 @@ async def get_all_models():
...
@@ -332,6 +409,14 @@ async def get_all_models():
@
app
.
get
(
"/api/models"
)
@
app
.
get
(
"/api/models"
)
async
def
get_models
(
user
=
Depends
(
get_verified_user
)):
async
def
get_models
(
user
=
Depends
(
get_verified_user
)):
models
=
await
get_all_models
()
models
=
await
get_all_models
()
# Filter out valve models
models
=
[
model
for
model
in
models
if
"pipeline"
not
in
model
or
model
[
"pipeline"
][
"type"
]
!=
"valve"
]
if
app
.
state
.
config
.
ENABLE_MODEL_FILTER
:
if
app
.
state
.
config
.
ENABLE_MODEL_FILTER
:
if
user
.
role
==
"user"
:
if
user
.
role
==
"user"
:
models
=
list
(
models
=
list
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment