Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
c44fc82e
Commit
c44fc82e
authored
Jun 09, 2024
by
Timothy J. Baek
Browse files
refac: openai
parent
8b6f422d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
129 additions
and
81 deletions
+129
-81
backend/apps/openai/main.py
backend/apps/openai/main.py
+129
-81
No files found.
backend/apps/openai/main.py
View file @
c44fc82e
...
...
@@ -345,24 +345,17 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
)
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
async
def
proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_verified_user
)):
@
app
.
post
(
"/chat/completions"
)
@
app
.
post
(
"/chat/completions/{url_idx}"
)
async
def
generate_chat_completion
(
form_data
:
dict
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
):
idx
=
0
payload
=
{
**
form_data
}
body
=
await
request
.
body
()
# TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
payload
=
None
try
:
if
"chat/completions"
in
path
:
body
=
body
.
decode
(
"utf-8"
)
body
=
json
.
loads
(
body
)
payload
=
{
**
body
}
model_id
=
body
.
get
(
"model"
)
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
model_info
:
...
...
@@ -374,17 +367,13 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if
model_info
.
params
:
if
model_info
.
params
.
get
(
"temperature"
,
None
)
is
not
None
:
payload
[
"temperature"
]
=
float
(
model_info
.
params
.
get
(
"temperature"
)
)
payload
[
"temperature"
]
=
float
(
model_info
.
params
.
get
(
"temperature"
))
if
model_info
.
params
.
get
(
"top_p"
,
None
):
payload
[
"top_p"
]
=
int
(
model_info
.
params
.
get
(
"top_p"
,
None
))
if
model_info
.
params
.
get
(
"max_tokens"
,
None
):
payload
[
"max_tokens"
]
=
int
(
model_info
.
params
.
get
(
"max_tokens"
,
None
)
)
payload
[
"max_tokens"
]
=
int
(
model_info
.
params
.
get
(
"max_tokens"
,
None
))
if
model_info
.
params
.
get
(
"frequency_penalty"
,
None
):
payload
[
"frequency_penalty"
]
=
int
(
...
...
@@ -411,8 +400,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
for
message
in
payload
[
"messages"
]:
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
(
model_info
.
params
.
get
(
"system"
,
None
)
+
message
[
"content"
]
model_info
.
params
.
get
(
"system"
,
None
)
+
message
[
"content"
]
)
break
else
:
...
...
@@ -423,11 +411,11 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"content"
:
model_info
.
params
.
get
(
"system"
,
None
),
},
)
else
:
pass
model
=
app
.
state
.
MODELS
[
payload
.
get
(
"model"
)]
idx
=
model
[
"urlIdx"
]
if
"pipeline"
in
model
and
model
.
get
(
"pipeline"
):
...
...
@@ -443,11 +431,71 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
# Convert the modified body back to JSON
payload
=
json
.
dumps
(
payload
)
except
json
.
JSONDecodeError
as
e
:
log
.
error
(
"Error loading request body into a dictionary:"
,
e
)
print
(
payload
)
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
print
(
payload
)
headers
=
{}
headers
[
"Authorization"
]
=
f
"Bearer
{
key
}
"
headers
[
"Content-Type"
]
=
"application/json"
r
=
None
session
=
None
streaming
=
False
try
:
session
=
aiohttp
.
ClientSession
(
trust_env
=
True
)
r
=
await
session
.
request
(
method
=
"POST"
,
url
=
f
"
{
url
}
/chat/completions"
,
data
=
payload
,
headers
=
headers
,
)
r
.
raise_for_status
()
# Check if response is SSE
if
"text/event-stream"
in
r
.
headers
.
get
(
"Content-Type"
,
""
):
streaming
=
True
return
StreamingResponse
(
r
.
content
,
status_code
=
r
.
status
,
headers
=
dict
(
r
.
headers
),
background
=
BackgroundTask
(
cleanup_response
,
response
=
r
,
session
=
session
),
)
else
:
response_data
=
await
r
.
json
()
return
response_data
except
Exception
as
e
:
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
if
r
is
not
None
:
try
:
res
=
await
r
.
json
()
print
(
res
)
if
"error"
in
res
:
error_detail
=
f
"External:
{
res
[
'error'
][
'message'
]
if
'message'
in
res
[
'error'
]
else
res
[
'error'
]
}
"
except
:
error_detail
=
f
"External:
{
e
}
"
raise
HTTPException
(
status_code
=
r
.
status
if
r
else
500
,
detail
=
error_detail
)
finally
:
if
not
streaming
and
session
:
if
r
:
r
.
close
()
await
session
.
close
()
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
async
def
proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_verified_user
)):
idx
=
0
body
=
await
request
.
body
()
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
...
...
@@ -466,7 +514,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
r
=
await
session
.
request
(
method
=
request
.
method
,
url
=
target_url
,
data
=
payload
if
payload
else
body
,
data
=
body
,
headers
=
headers
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment