Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
c44fc82e
Commit
c44fc82e
authored
Jun 09, 2024
by
Timothy J. Baek
Browse files
refac: openai
parent
8b6f422d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
129 additions
and
81 deletions
+129
-81
backend/apps/openai/main.py
backend/apps/openai/main.py
+129
-81
No files found.
backend/apps/openai/main.py
View file @
c44fc82e
...
@@ -345,108 +345,156 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
...
@@ -345,108 +345,156 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
)
)
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Forward a chat-completion request to an OpenAI-compatible backend.

    Applies the model's stored parameter overrides (base model id,
    temperature, top_p, max_tokens, frequency_penalty, seed, stop,
    system prompt) to the incoming payload before proxying, then either
    streams the upstream SSE response back to the client or returns the
    decoded JSON body.

    Args:
        form_data: Raw OpenAI-style chat/completions request body.
        url_idx: Optional backend index from the path; the effective
            index is resolved from the model registry below.
        user: Verified user injected by FastAPI dependency.

    Raises:
        HTTPException: With the upstream status code (or 500) when the
            request to the backend fails.
    """
    idx = 0
    payload = {**form_data}

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            # Custom models proxy to their underlying base model.
            payload["model"] = model_info.base_model_id

        model_info.params = model_info.params.model_dump()

        if model_info.params:
            # temperature may legitimately be 0, so compare against None
            # instead of relying on truthiness.
            if model_info.params.get("temperature", None) is not None:
                payload["temperature"] = float(model_info.params.get("temperature"))

            if model_info.params.get("top_p", None):
                # top_p is a sampling probability in (0, 1]; int() would
                # truncate any fractional value to 0.
                payload["top_p"] = float(model_info.params.get("top_p", None))

            if model_info.params.get("max_tokens", None):
                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))

            if model_info.params.get("frequency_penalty", None):
                # frequency_penalty is a float in [-2.0, 2.0]; int() would
                # silently truncate fractional values.
                payload["frequency_penalty"] = float(
                    model_info.params.get("frequency_penalty", None)
                )

            if model_info.params.get("seed", None):
                payload["seed"] = model_info.params.get("seed", None)

            if model_info.params.get("stop", None):
                # Decode escape sequences (e.g. "\\n") stored in the
                # configured stop strings.
                payload["stop"] = [
                    bytes(stop, "utf-8").decode("unicode_escape")
                    for stop in model_info.params["stop"]
                ]

            if model_info.params.get("system", None):
                # Check if the payload already has a system message:
                # prepend the model's system prompt to the existing one,
                # otherwise insert a new system message at the front.
                if payload.get("messages"):
                    for message in payload["messages"]:
                        if message.get("role") == "system":
                            message["content"] = (
                                model_info.params.get("system", None)
                                + message["content"]
                            )
                            break
                    else:
                        payload["messages"].insert(
                            0,
                            {
                                "role": "system",
                                "content": model_info.params.get("system", None),
                            },
                        )

    # Resolve the backend URL index from the (possibly overridden) model.
    model = app.state.MODELS[payload.get("model")]
    idx = model["urlIdx"]

    # Pipelines need to know which user issued the request.
    if "pipeline" in model and model.get("pipeline"):
        payload["user"] = {"name": user.name, "id": user.id}

    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
    # This is a workaround until OpenAI fixes the issue with this model
    if payload.get("model") == "gpt-4-vision-preview":
        if "max_tokens" not in payload:
            payload["max_tokens"] = 4000
        # Use a %s placeholder so the payload is actually rendered by the
        # logging formatter (a bare extra arg is dropped/raises instead).
        log.debug("Modified payload: %s", payload)

    # Convert the modified body back to JSON for the upstream request.
    payload = json.dumps(payload)
    log.debug(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method="POST",
            url=f"{url}/chat/completions",
            data=payload,
            headers=headers,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                log.error(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                # Upstream body was not JSON; fall back to the transport error.
                error_detail = f"External: {e}"
        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
    finally:
        # For streaming responses the StreamingResponse background task owns
        # cleanup; otherwise close the response and session here.
        if not streaming and session:
            if r:
                r.close()
            await session.close()
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
async
def
proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_verified_user
)):
idx
=
0
body
=
await
request
.
body
()
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
key
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
...
@@ -466,7 +514,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -466,7 +514,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
r
=
await
session
.
request
(
r
=
await
session
.
request
(
method
=
request
.
method
,
method
=
request
.
method
,
url
=
target_url
,
url
=
target_url
,
data
=
payload
if
payload
else
body
,
data
=
body
,
headers
=
headers
,
headers
=
headers
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment