Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
12c21fac
Commit
12c21fac
authored
Aug 03, 2024
by
Michael Poluektov
Browse files
refac: apps/openai/main.py and utils
parent
774defd1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
149 additions
and
231 deletions
+149
-231
backend/apps/openai/main.py
backend/apps/openai/main.py
+60
-134
backend/apps/socket/main.py
backend/apps/socket/main.py
+15
-12
backend/apps/webui/main.py
backend/apps/webui/main.py
+2
-43
backend/apps/webui/routers/tools.py
backend/apps/webui/routers/tools.py
+13
-18
backend/main.py
backend/main.py
+12
-16
backend/utils/misc.py
backend/utils/misc.py
+43
-0
backend/utils/task.py
backend/utils/task.py
+1
-2
backend/utils/utils.py
backend/utils/utils.py
+3
-6
No files found.
backend/apps/openai/main.py
View file @
12c21fac
from
fastapi
import
FastAPI
,
Request
,
Response
,
HTTPException
,
Depends
from
fastapi
import
FastAPI
,
Request
,
HTTPException
,
Depends
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
StreamingResponse
,
JSONResponse
,
FileResponse
from
fastapi.responses
import
StreamingResponse
,
FileResponse
import
requests
import
aiohttp
...
...
@@ -12,16 +12,12 @@ from pydantic import BaseModel
from
starlette.background
import
BackgroundTask
from
apps.webui.models.models
import
Models
from
apps.webui.models.users
import
Users
from
constants
import
ERROR_MESSAGES
from
utils.utils
import
(
decode_token
,
get_verified_user
,
get_verified_user
,
get_admin_user
,
)
from
utils.task
import
prompt_template
from
utils.misc
import
add_or_update_system_message
from
utils.misc
import
apply_model_params_to_body
,
apply_model_system_prompt_to_body
from
config
import
(
SRC_LOG_LEVELS
,
...
...
@@ -69,8 +65,6 @@ app.state.MODELS = {}
async
def
check_url
(
request
:
Request
,
call_next
):
if
len
(
app
.
state
.
MODELS
)
==
0
:
await
get_all_models
()
else
:
pass
response
=
await
call_next
(
request
)
return
response
...
...
@@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res
=
r
.
json
()
if
"error"
in
res
:
error_detail
=
f
"External:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"External:
{
e
}
"
raise
HTTPException
(
...
...
@@ -234,64 +228,58 @@ def merge_models_lists(model_lists):
return
merged_list
async
def
get_all_models
(
raw
:
bool
=
False
):
def
is_openai_api_disabled
():
api_keys
=
app
.
state
.
config
.
OPENAI_API_KEYS
no_keys
=
len
(
api_keys
)
==
1
and
api_keys
[
0
]
==
""
return
no_keys
or
not
app
.
state
.
config
.
ENABLE_OPENAI_API
async
def
get_all_models_raw
()
->
list
:
if
is_openai_api_disabled
():
return
[]
# Check if API KEYS length is same than API URLS length
num_urls
=
len
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
)
num_keys
=
len
(
app
.
state
.
config
.
OPENAI_API_KEYS
)
if
num_keys
!=
num_urls
:
# if there are more keys than urls, remove the extra keys
if
num_keys
>
num_urls
:
new_keys
=
app
.
state
.
config
.
OPENAI_API_KEYS
[:
num_urls
]
app
.
state
.
config
.
OPENAI_API_KEYS
=
new_keys
# if there are more urls than keys, add empty keys
else
:
app
.
state
.
config
.
OPENAI_API_KEYS
+=
[
""
]
*
(
num_urls
-
num_keys
)
tasks
=
[
fetch_url
(
f
"
{
url
}
/models"
,
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
])
for
idx
,
url
in
enumerate
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
)
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
log
.
debug
(
f
"get_all_models:responses()
{
responses
}
"
)
return
responses
async
def
get_all_models
()
->
dict
[
str
,
list
]:
log
.
info
(
"get_all_models()"
)
if
is_openai_api_disabled
():
return
{
"data"
:
[]}
if
(
len
(
app
.
state
.
config
.
OPENAI_API_KEYS
)
==
1
and
app
.
state
.
config
.
OPENAI_API_KEYS
[
0
]
==
""
)
or
not
app
.
state
.
config
.
ENABLE_OPENAI_API
:
models
=
{
"data"
:
[]}
else
:
# Check if API KEYS length is same than API URLS length
if
len
(
app
.
state
.
config
.
OPENAI_API_KEYS
)
!=
len
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
):
# if there are more keys than urls, remove the extra keys
if
len
(
app
.
state
.
config
.
OPENAI_API_KEYS
)
>
len
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
):
app
.
state
.
config
.
OPENAI_API_KEYS
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
:
len
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
)
]
# if there are more urls than keys, add empty keys
else
:
app
.
state
.
config
.
OPENAI_API_KEYS
+=
[
""
for
_
in
range
(
len
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
)
-
len
(
app
.
state
.
config
.
OPENAI_API_KEYS
)
)
]
responses
=
await
get_all_models_raw
()
tasks
=
[
fetch_url
(
f
"
{
url
}
/models"
,
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
])
for
idx
,
url
in
enumerate
(
app
.
state
.
config
.
OPENAI_API_BASE_URLS
)
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
log
.
debug
(
f
"get_all_models:responses()
{
responses
}
"
)
if
raw
:
return
responses
models
=
{
"data"
:
merge_models_lists
(
list
(
map
(
lambda
response
:
(
response
[
"data"
]
if
(
response
and
"data"
in
response
)
else
(
response
if
isinstance
(
response
,
list
)
else
None
)
),
responses
,
)
)
)
}
def
extract_data
(
response
):
if
response
and
"data"
in
response
:
return
response
[
"data"
]
if
isinstance
(
response
,
list
):
return
response
return
None
log
.
debug
(
f
"models:
{
models
}
"
)
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
[
"data"
]}
models
=
{
"data"
:
merge_models_lists
(
map
(
extract_data
,
responses
))}
log
.
debug
(
f
"models:
{
models
}
"
)
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
[
"data"
]}
return
models
...
...
@@ -299,7 +287,7 @@ async def get_all_models(raw: bool = False):
@
app
.
get
(
"/models"
)
@
app
.
get
(
"/models/{url_idx}"
)
async
def
get_models
(
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
)):
if
url_idx
==
None
:
if
url_idx
is
None
:
models
=
await
get_all_models
()
if
app
.
state
.
config
.
ENABLE_MODEL_FILTER
:
if
user
.
role
==
"user"
:
...
...
@@ -340,7 +328,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
res
=
r
.
json
()
if
"error"
in
res
:
error_detail
=
f
"External:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"External:
{
e
}
"
raise
HTTPException
(
...
...
@@ -358,8 +346,7 @@ async def generate_chat_completion(
):
idx
=
0
payload
=
{
**
form_data
}
if
"metadata"
in
payload
:
del
payload
[
"metadata"
]
payload
.
pop
(
"metadata"
)
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
...
...
@@ -368,70 +355,9 @@ async def generate_chat_completion(
if
model_info
.
base_model_id
:
payload
[
"model"
]
=
model_info
.
base_model_id
model_info
.
params
=
model_info
.
params
.
model_dump
()
if
model_info
.
params
:
if
(
model_info
.
params
.
get
(
"temperature"
,
None
)
is
not
None
and
payload
.
get
(
"temperature"
)
is
None
):
payload
[
"temperature"
]
=
float
(
model_info
.
params
.
get
(
"temperature"
))
if
model_info
.
params
.
get
(
"top_p"
,
None
)
and
payload
.
get
(
"top_p"
)
is
None
:
payload
[
"top_p"
]
=
int
(
model_info
.
params
.
get
(
"top_p"
,
None
))
if
(
model_info
.
params
.
get
(
"max_tokens"
,
None
)
and
payload
.
get
(
"max_tokens"
)
is
None
):
payload
[
"max_tokens"
]
=
int
(
model_info
.
params
.
get
(
"max_tokens"
,
None
))
if
(
model_info
.
params
.
get
(
"frequency_penalty"
,
None
)
and
payload
.
get
(
"frequency_penalty"
)
is
None
):
payload
[
"frequency_penalty"
]
=
int
(
model_info
.
params
.
get
(
"frequency_penalty"
,
None
)
)
if
(
model_info
.
params
.
get
(
"seed"
,
None
)
is
not
None
and
payload
.
get
(
"seed"
)
is
None
):
payload
[
"seed"
]
=
model_info
.
params
.
get
(
"seed"
,
None
)
if
model_info
.
params
.
get
(
"stop"
,
None
)
and
payload
.
get
(
"stop"
)
is
None
:
payload
[
"stop"
]
=
(
[
bytes
(
stop
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
stop
in
model_info
.
params
[
"stop"
]
]
if
model_info
.
params
.
get
(
"stop"
,
None
)
else
None
)
system
=
model_info
.
params
.
get
(
"system"
,
None
)
if
system
:
system
=
prompt_template
(
system
,
**
(
{
"user_name"
:
user
.
name
,
"user_location"
:
(
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
),
}
if
user
else
{}
),
)
if
payload
.
get
(
"messages"
):
payload
[
"messages"
]
=
add_or_update_system_message
(
system
,
payload
[
"messages"
]
)
else
:
pass
params
=
model_info
.
params
.
model_dump
()
payload
=
apply_model_params_to_body
(
params
,
payload
)
payload
=
apply_model_system_prompt_to_body
(
params
,
payload
,
user
)
model
=
app
.
state
.
MODELS
[
payload
.
get
(
"model"
)]
idx
=
model
[
"urlIdx"
]
...
...
@@ -506,7 +432,7 @@ async def generate_chat_completion(
print
(
res
)
if
"error"
in
res
:
error_detail
=
f
"External:
{
res
[
'error'
][
'message'
]
if
'message'
in
res
[
'error'
]
else
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"External:
{
e
}
"
raise
HTTPException
(
status_code
=
r
.
status
if
r
else
500
,
detail
=
error_detail
)
finally
:
...
...
@@ -569,7 +495,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
print
(
res
)
if
"error"
in
res
:
error_detail
=
f
"External:
{
res
[
'error'
][
'message'
]
if
'message'
in
res
[
'error'
]
else
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"External:
{
e
}
"
raise
HTTPException
(
status_code
=
r
.
status
if
r
else
500
,
detail
=
error_detail
)
finally
:
...
...
backend/apps/socket/main.py
View file @
12c21fac
...
...
@@ -44,23 +44,26 @@ async def user_join(sid, data):
print
(
"user-join"
,
sid
,
data
)
auth
=
data
[
"auth"
]
if
"auth"
in
data
else
None
if
not
auth
or
"token"
not
in
auth
:
return
if
auth
and
"token"
in
auth
:
data
=
decode_token
(
auth
[
"token"
])
data
=
decode_token
(
auth
[
"token"
])
if
data
is
None
or
"id"
not
in
data
:
return
if
data
is
not
None
and
"id"
in
data
:
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
if
not
user
:
return
if
user
:
SESSION_POOL
[
sid
]
=
user
.
id
if
user
.
id
in
USER_POOL
:
USER_POOL
[
user
.
id
].
append
(
sid
)
else
:
USER_POOL
[
user
.
id
]
=
[
sid
]
SESSION_POOL
[
sid
]
=
user
.
id
if
user
.
id
in
USER_POOL
:
USER_POOL
[
user
.
id
].
append
(
sid
)
else
:
USER_POOL
[
user
.
id
]
=
[
sid
]
print
(
f
"user
{
user
.
name
}
(
{
user
.
id
}
) connected with session ID
{
sid
}
"
)
print
(
f
"user
{
user
.
name
}
(
{
user
.
id
}
) connected with session ID
{
sid
}
"
)
await
sio
.
emit
(
"user-count"
,
{
"count"
:
len
(
set
(
USER_POOL
))})
await
sio
.
emit
(
"user-count"
,
{
"count"
:
len
(
set
(
USER_POOL
))})
@
sio
.
on
(
"user-count"
)
...
...
backend/apps/webui/main.py
View file @
12c21fac
...
...
@@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
from
utils.misc
import
(
openai_chat_chunk_message_template
,
openai_chat_completion_message_template
,
add_or_update_system_message
,
apply_model_params_to_body
,
apply_model_system_prompt_to_body
,
)
from
utils.task
import
prompt_template
from
config
import
(
...
...
@@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
return
params
# inplace function: form_data is modified
def
apply_model_params_to_body
(
params
:
dict
,
form_data
:
dict
)
->
dict
:
if
not
params
:
return
form_data
mappings
=
{
"temperature"
:
float
,
"top_p"
:
int
,
"max_tokens"
:
int
,
"frequency_penalty"
:
int
,
"seed"
:
lambda
x
:
x
,
"stop"
:
lambda
x
:
[
bytes
(
s
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
s
in
x
],
}
for
key
,
cast_func
in
mappings
.
items
():
if
(
value
:
=
params
.
get
(
key
))
is
not
None
:
form_data
[
key
]
=
cast_func
(
value
)
return
form_data
# inplace function: form_data is modified
def
apply_model_system_prompt_to_body
(
params
:
dict
,
form_data
:
dict
,
user
)
->
dict
:
system
=
params
.
get
(
"system"
,
None
)
if
not
system
:
return
form_data
if
user
:
template_params
=
{
"user_name"
:
user
.
name
,
"user_location"
:
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
,
}
else
:
template_params
=
{}
system
=
prompt_template
(
system
,
**
template_params
)
form_data
[
"messages"
]
=
add_or_update_system_message
(
system
,
form_data
.
get
(
"messages"
,
[])
)
return
form_data
async
def
generate_function_chat_completion
(
form_data
,
user
):
model_id
=
form_data
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
...
...
backend/apps/webui/routers/tools.py
View file @
12c21fac
from
fastapi
import
Depends
,
FastAPI
,
HTTPException
,
status
,
Request
from
datetime
import
datetime
,
timedelta
from
typing
import
List
,
Union
,
Optional
from
fastapi
import
Depends
,
HTTPException
,
status
,
Request
from
typing
import
List
,
Optional
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.webui.models.users
import
Users
from
apps.webui.models.tools
import
Tools
,
ToolForm
,
ToolModel
,
ToolResponse
from
apps.webui.utils
import
load_toolkit_module_by_id
...
...
@@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
from
utils.tools
import
get_tools_specs
from
constants
import
ERROR_MESSAGES
from
importlib
import
util
import
os
from
pathlib
import
Path
...
...
@@ -69,7 +64,7 @@ async def create_new_toolkit(
form_data
.
id
=
form_data
.
id
.
lower
()
toolkit
=
Tools
.
get_tool_by_id
(
form_data
.
id
)
if
toolkit
==
None
:
if
toolkit
is
None
:
toolkit_path
=
os
.
path
.
join
(
TOOLS_DIR
,
f
"
{
form_data
.
id
}
.py"
)
try
:
with
open
(
toolkit_path
,
"w"
)
as
tool_file
:
...
...
@@ -98,7 +93,7 @@ async def create_new_toolkit(
print
(
e
)
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
str
(
e
)
),
)
else
:
raise
HTTPException
(
...
...
@@ -170,7 +165,7 @@ async def update_toolkit_by_id(
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
str
(
e
)
),
)
...
...
@@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
str
(
e
)
),
)
else
:
raise
HTTPException
(
...
...
@@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
if
id
in
request
.
app
.
state
.
TOOLS
:
toolkit_module
=
request
.
app
.
state
.
TOOLS
[
id
]
else
:
toolkit_module
,
frontmatter
=
load_toolkit_module_by_id
(
id
)
toolkit_module
,
_
=
load_toolkit_module_by_id
(
id
)
request
.
app
.
state
.
TOOLS
[
id
]
=
toolkit_module
if
hasattr
(
toolkit_module
,
"Valves"
):
...
...
@@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
if
id
in
request
.
app
.
state
.
TOOLS
:
toolkit_module
=
request
.
app
.
state
.
TOOLS
[
id
]
else
:
toolkit_module
,
frontmatter
=
load_toolkit_module_by_id
(
id
)
toolkit_module
,
_
=
load_toolkit_module_by_id
(
id
)
request
.
app
.
state
.
TOOLS
[
id
]
=
toolkit_module
if
hasattr
(
toolkit_module
,
"Valves"
):
...
...
@@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
print
(
e
)
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
str
(
e
)
),
)
else
:
raise
HTTPException
(
...
...
@@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
str
(
e
)
),
)
else
:
raise
HTTPException
(
...
...
@@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
if
id
in
request
.
app
.
state
.
TOOLS
:
toolkit_module
=
request
.
app
.
state
.
TOOLS
[
id
]
else
:
toolkit_module
,
frontmatter
=
load_toolkit_module_by_id
(
id
)
toolkit_module
,
_
=
load_toolkit_module_by_id
(
id
)
request
.
app
.
state
.
TOOLS
[
id
]
=
toolkit_module
if
hasattr
(
toolkit_module
,
"UserValves"
):
...
...
@@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
if
id
in
request
.
app
.
state
.
TOOLS
:
toolkit_module
=
request
.
app
.
state
.
TOOLS
[
id
]
else
:
toolkit_module
,
frontmatter
=
load_toolkit_module_by_id
(
id
)
toolkit_module
,
_
=
load_toolkit_module_by_id
(
id
)
request
.
app
.
state
.
TOOLS
[
id
]
=
toolkit_module
if
hasattr
(
toolkit_module
,
"UserValves"
):
...
...
@@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
print
(
e
)
raise
HTTPException
(
status_code
=
status
.
HTTP_400_BAD_REQUEST
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
),
detail
=
ERROR_MESSAGES
.
DEFAULT
(
str
(
e
)
),
)
else
:
raise
HTTPException
(
...
...
backend/main.py
View file @
12c21fac
...
...
@@ -36,6 +36,7 @@ from apps.ollama.main import (
from
apps.openai.main
import
(
app
as
openai_app
,
get_all_models
as
get_openai_models
,
get_all_models_raw
as
get_openai_models_raw
,
generate_chat_completion
as
generate_openai_chat_completion
,
)
...
...
@@ -957,7 +958,7 @@ async def get_all_models():
custom_models
=
Models
.
get_all_models
()
for
custom_model
in
custom_models
:
if
custom_model
.
base_model_id
==
None
:
if
custom_model
.
base_model_id
is
None
:
for
model
in
models
:
if
(
custom_model
.
id
==
model
[
"id"
]
...
...
@@ -1656,13 +1657,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
@
app
.
get
(
"/api/pipelines/list"
)
async
def
get_pipelines_list
(
user
=
Depends
(
get_admin_user
)):
responses
=
await
get_openai_models
(
raw
=
True
)
responses
=
await
get_openai_models
_
raw
(
)
print
(
responses
)
urlIdxs
=
[
idx
for
idx
,
response
in
enumerate
(
responses
)
if
response
!=
None
and
"pipelines"
in
response
if
response
is
not
None
and
"pipelines"
in
response
]
return
{
...
...
@@ -1723,7 +1724,7 @@ async def upload_pipeline(
res
=
r
.
json
()
if
"detail"
in
res
:
detail
=
res
[
"detail"
]
except
:
except
Exception
:
pass
raise
HTTPException
(
...
...
@@ -1769,7 +1770,7 @@ async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user))
res
=
r
.
json
()
if
"detail"
in
res
:
detail
=
res
[
"detail"
]
except
:
except
Exception
:
pass
raise
HTTPException
(
...
...
@@ -1811,7 +1812,7 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
res
=
r
.
json
()
if
"detail"
in
res
:
detail
=
res
[
"detail"
]
except
:
except
Exception
:
pass
raise
HTTPException
(
...
...
@@ -1844,7 +1845,7 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
res
=
r
.
json
()
if
"detail"
in
res
:
detail
=
res
[
"detail"
]
except
:
except
Exception
:
pass
raise
HTTPException
(
...
...
@@ -1859,7 +1860,6 @@ async def get_pipeline_valves(
pipeline_id
:
str
,
user
=
Depends
(
get_admin_user
),
):
models
=
await
get_all_models
()
r
=
None
try
:
url
=
openai_app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
urlIdx
]
...
...
@@ -1898,8 +1898,6 @@ async def get_pipeline_valves_spec(
pipeline_id
:
str
,
user
=
Depends
(
get_admin_user
),
):
models
=
await
get_all_models
()
r
=
None
try
:
url
=
openai_app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
urlIdx
]
...
...
@@ -1922,7 +1920,7 @@ async def get_pipeline_valves_spec(
res
=
r
.
json
()
if
"detail"
in
res
:
detail
=
res
[
"detail"
]
except
:
except
Exception
:
pass
raise
HTTPException
(
...
...
@@ -1938,8 +1936,6 @@ async def update_pipeline_valves(
form_data
:
dict
,
user
=
Depends
(
get_admin_user
),
):
models
=
await
get_all_models
()
r
=
None
try
:
url
=
openai_app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
urlIdx
]
...
...
@@ -1967,7 +1963,7 @@ async def update_pipeline_valves(
res
=
r
.
json
()
if
"detail"
in
res
:
detail
=
res
[
"detail"
]
except
:
except
Exception
:
pass
raise
HTTPException
(
...
...
@@ -2068,7 +2064,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
@
app
.
get
(
"/api/version"
)
async
def
get_app_
config
():
async
def
get_app_
version
():
return
{
"version"
:
VERSION
,
}
...
...
@@ -2091,7 +2087,7 @@ async def get_app_latest_release_version():
latest_version
=
data
[
"tag_name"
]
return
{
"current"
:
VERSION
,
"latest"
:
latest_version
[
1
:]}
except
aiohttp
.
ClientError
as
e
:
except
aiohttp
.
ClientError
:
raise
HTTPException
(
status_code
=
status
.
HTTP_503_SERVICE_UNAVAILABLE
,
detail
=
ERROR_MESSAGES
.
RATE_LIMIT_EXCEEDED
,
...
...
backend/utils/misc.py
View file @
12c21fac
...
...
@@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
import
uuid
import
time
from
utils.task
import
prompt_template
def
get_last_user_message_item
(
messages
:
List
[
dict
])
->
Optional
[
dict
]:
for
message
in
reversed
(
messages
):
...
...
@@ -111,6 +113,47 @@ def openai_chat_completion_message_template(model: str, message: str):
template
[
"choices"
][
0
][
"finish_reason"
]
=
"stop"
# inplace function: form_data is modified
def
apply_model_system_prompt_to_body
(
params
:
dict
,
form_data
:
dict
,
user
)
->
dict
:
system
=
params
.
get
(
"system"
,
None
)
if
not
system
:
return
form_data
if
user
:
template_params
=
{
"user_name"
:
user
.
name
,
"user_location"
:
user
.
info
.
get
(
"location"
)
if
user
.
info
else
None
,
}
else
:
template_params
=
{}
system
=
prompt_template
(
system
,
**
template_params
)
form_data
[
"messages"
]
=
add_or_update_system_message
(
system
,
form_data
.
get
(
"messages"
,
[])
)
return
form_data
# inplace function: form_data is modified
def
apply_model_params_to_body
(
params
:
dict
,
form_data
:
dict
)
->
dict
:
if
not
params
:
return
form_data
mappings
=
{
"temperature"
:
float
,
"top_p"
:
int
,
"max_tokens"
:
int
,
"frequency_penalty"
:
int
,
"seed"
:
lambda
x
:
x
,
"stop"
:
lambda
x
:
[
bytes
(
s
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
s
in
x
],
}
for
key
,
cast_func
in
mappings
.
items
():
if
(
value
:
=
params
.
get
(
key
))
is
not
None
:
form_data
[
key
]
=
cast_func
(
value
)
return
form_data
def
get_gravatar_url
(
email
):
# Trim leading and trailing whitespace from
# an email address and force all characters
...
...
backend/utils/task.py
View file @
12c21fac
...
...
@@ -6,7 +6,7 @@ from typing import Optional
def
prompt_template
(
template
:
str
,
user_name
:
str
=
None
,
user_location
:
str
=
None
template
:
str
,
user_name
:
Optional
[
str
]
=
None
,
user_location
:
Optional
[
str
]
=
None
)
->
str
:
# Get the current date
current_date
=
datetime
.
now
()
...
...
@@ -83,7 +83,6 @@ def title_generation_template(
def
search_query_generation_template
(
template
:
str
,
prompt
:
str
,
user
:
Optional
[
dict
]
=
None
)
->
str
:
def
replacement_function
(
match
):
full_match
=
match
.
group
(
0
)
start_length
=
match
.
group
(
1
)
...
...
backend/utils/utils.py
View file @
12c21fac
from
fastapi.security
import
HTTPBearer
,
HTTPAuthorizationCredentials
from
fastapi
import
HTTPException
,
status
,
Depends
,
Request
from
sqlalchemy.orm
import
Session
from
apps.webui.models.users
import
Users
from
pydantic
import
BaseModel
from
typing
import
Union
,
Optional
from
constants
import
ERROR_MESSAGES
from
passlib.context
import
CryptContext
from
datetime
import
datetime
,
timedelta
import
requests
import
jwt
import
uuid
import
logging
...
...
@@ -54,7 +51,7 @@ def decode_token(token: str) -> Optional[dict]:
try
:
decoded
=
jwt
.
decode
(
token
,
SESSION_SECRET
,
algorithms
=
[
ALGORITHM
])
return
decoded
except
Exception
as
e
:
except
Exception
:
return
None
...
...
@@ -71,7 +68,7 @@ def get_http_authorization_cred(auth_header: str):
try
:
scheme
,
credentials
=
auth_header
.
split
(
" "
)
return
HTTPAuthorizationCredentials
(
scheme
=
scheme
,
credentials
=
credentials
)
except
:
except
Exception
:
raise
ValueError
(
ERROR_MESSAGES
.
INVALID_TOKEN
)
...
...
@@ -96,7 +93,7 @@ def get_current_user(
# auth by jwt token
data
=
decode_token
(
token
)
if
data
!=
None
and
"id"
in
data
:
if
data
is
not
None
and
"id"
in
data
:
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
if
user
is
None
:
raise
HTTPException
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment