Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
b1265c9c
Commit
b1265c9c
authored
May 25, 2024
by
Jun Siang Cheah
Browse files
Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search
parents
60433856
e9c8341d
Changes
88
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1197 additions
and
336 deletions
+1197
-336
backend/apps/litellm/main.py
backend/apps/litellm/main.py
+11
-2
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+150
-26
backend/apps/openai/main.py
backend/apps/openai/main.py
+86
-28
backend/apps/web/internal/db.py
backend/apps/web/internal/db.py
+12
-0
backend/apps/web/internal/migrations/009_add_models.py
backend/apps/web/internal/migrations/009_add_models.py
+61
-0
backend/apps/web/internal/migrations/010_migrate_modelfiles_to_models.py
...b/internal/migrations/010_migrate_modelfiles_to_models.py
+130
-0
backend/apps/web/main.py
backend/apps/web/main.py
+5
-3
backend/apps/web/models/modelfiles.py
backend/apps/web/models/modelfiles.py
+8
-0
backend/apps/web/models/models.py
backend/apps/web/models/models.py
+179
-0
backend/apps/web/routers/modelfiles.py
backend/apps/web/routers/modelfiles.py
+0
-124
backend/apps/web/routers/models.py
backend/apps/web/routers/models.py
+108
-0
backend/constants.py
backend/constants.py
+2
-0
backend/main.py
backend/main.py
+104
-6
backend/utils/misc.py
backend/utils/misc.py
+74
-0
backend/utils/models.py
backend/utils/models.py
+10
-0
src/lib/apis/index.ts
src/lib/apis/index.ts
+123
-0
src/lib/apis/litellm/index.ts
src/lib/apis/litellm/index.ts
+2
-1
src/lib/apis/models/index.ts
src/lib/apis/models/index.ts
+17
-32
src/lib/apis/openai/index.ts
src/lib/apis/openai/index.ts
+6
-1
src/lib/components/chat/Chat.svelte
src/lib/components/chat/Chat.svelte
+109
-113
No files found.
backend/apps/litellm/main.py
View file @
b1265c9c
...
@@ -18,8 +18,9 @@ import requests
...
@@ -18,8 +18,9 @@ import requests
from
pydantic
import
BaseModel
,
ConfigDict
from
pydantic
import
BaseModel
,
ConfigDict
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
from
apps.web.models.models
import
Models
from
utils.utils
import
get_verified_user
,
get_current_user
,
get_admin_user
from
utils.utils
import
get_verified_user
,
get_current_user
,
get_admin_user
from
config
import
SRC_LOG_LEVELS
,
ENV
from
config
import
SRC_LOG_LEVELS
from
constants
import
MESSAGES
from
constants
import
MESSAGES
import
os
import
os
...
@@ -77,7 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
...
@@ -77,7 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
.
value
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
.
value
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
.
value
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
.
value
app
.
state
.
MODEL_CONFIG
=
Models
.
get_all_models
()
app
.
state
.
ENABLE
=
ENABLE_LITELLM
app
.
state
.
ENABLE
=
ENABLE_LITELLM
app
.
state
.
CONFIG
=
litellm_config
app
.
state
.
CONFIG
=
litellm_config
...
@@ -261,6 +262,14 @@ async def get_models(user=Depends(get_current_user)):
...
@@ -261,6 +262,14 @@ async def get_models(user=Depends(get_current_user)):
"object"
:
"model"
,
"object"
:
"model"
,
"created"
:
int
(
time
.
time
()),
"created"
:
int
(
time
.
time
()),
"owned_by"
:
"openai"
,
"owned_by"
:
"openai"
,
"custom_info"
:
next
(
(
item
for
item
in
app
.
state
.
MODEL_CONFIG
if
item
.
id
==
model
[
"model_name"
]
),
None
,
),
}
}
for
model
in
app
.
state
.
CONFIG
[
"model_list"
]
for
model
in
app
.
state
.
CONFIG
[
"model_list"
]
],
],
...
...
backend/apps/ollama/main.py
View file @
b1265c9c
...
@@ -29,7 +29,7 @@ import time
...
@@ -29,7 +29,7 @@ import time
from
urllib.parse
import
urlparse
from
urllib.parse
import
urlparse
from
typing
import
Optional
,
List
,
Union
from
typing
import
Optional
,
List
,
Union
from
apps.web.models.models
import
Models
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
utils.utils
import
(
from
utils.utils
import
(
...
@@ -39,6 +39,8 @@ from utils.utils import (
...
@@ -39,6 +39,8 @@ from utils.utils import (
get_admin_user
,
get_admin_user
,
)
)
from
utils.models
import
get_model_id_from_custom_model_id
from
config
import
(
from
config
import
(
SRC_LOG_LEVELS
,
SRC_LOG_LEVELS
,
...
@@ -68,7 +70,6 @@ app.state.config = AppConfig()
...
@@ -68,7 +70,6 @@ app.state.config = AppConfig()
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
ENABLE_OLLAMA_API
=
ENABLE_OLLAMA_API
app
.
state
.
config
.
ENABLE_OLLAMA_API
=
ENABLE_OLLAMA_API
app
.
state
.
config
.
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
app
.
state
.
config
.
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
app
.
state
.
MODELS
=
{}
app
.
state
.
MODELS
=
{}
...
@@ -875,14 +876,93 @@ async def generate_chat_completion(
...
@@ -875,14 +876,93 @@ async def generate_chat_completion(
user
=
Depends
(
get_verified_user
),
user
=
Depends
(
get_verified_user
),
):
):
if
url_idx
==
None
:
log
.
debug
(
model
=
form_data
.
model
"form_data.model_dump_json(exclude_none=True).encode(): {0} "
.
format
(
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
()
)
)
if
":"
not
in
model
:
payload
=
{
model
=
f
"
{
model
}
:latest"
**
form_data
.
model_dump
(
exclude_none
=
True
),
}
if
model
in
app
.
state
.
MODELS
:
model_id
=
form_data
.
model
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
model
][
"urls"
])
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
model_info
:
print
(
model_info
)
if
model_info
.
base_model_id
:
payload
[
"model"
]
=
model_info
.
base_model_id
model_info
.
params
=
model_info
.
params
.
model_dump
()
if
model_info
.
params
:
payload
[
"options"
]
=
{}
payload
[
"options"
][
"mirostat"
]
=
model_info
.
params
.
get
(
"mirostat"
,
None
)
payload
[
"options"
][
"mirostat_eta"
]
=
model_info
.
params
.
get
(
"mirostat_eta"
,
None
)
payload
[
"options"
][
"mirostat_tau"
]
=
model_info
.
params
.
get
(
"mirostat_tau"
,
None
)
payload
[
"options"
][
"num_ctx"
]
=
model_info
.
params
.
get
(
"num_ctx"
,
None
)
payload
[
"options"
][
"repeat_last_n"
]
=
model_info
.
params
.
get
(
"repeat_last_n"
,
None
)
payload
[
"options"
][
"repeat_penalty"
]
=
model_info
.
params
.
get
(
"frequency_penalty"
,
None
)
payload
[
"options"
][
"temperature"
]
=
model_info
.
params
.
get
(
"temperature"
,
None
)
payload
[
"options"
][
"seed"
]
=
model_info
.
params
.
get
(
"seed"
,
None
)
payload
[
"options"
][
"stop"
]
=
(
[
bytes
(
stop
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
stop
in
model_info
.
params
[
"stop"
]
]
if
model_info
.
params
.
get
(
"stop"
,
None
)
else
None
)
payload
[
"options"
][
"tfs_z"
]
=
model_info
.
params
.
get
(
"tfs_z"
,
None
)
payload
[
"options"
][
"num_predict"
]
=
model_info
.
params
.
get
(
"max_tokens"
,
None
)
payload
[
"options"
][
"top_k"
]
=
model_info
.
params
.
get
(
"top_k"
,
None
)
payload
[
"options"
][
"top_p"
]
=
model_info
.
params
.
get
(
"top_p"
,
None
)
if
model_info
.
params
.
get
(
"system"
,
None
):
# Check if the payload already has a system message
# If not, add a system message to the payload
if
payload
.
get
(
"messages"
):
for
message
in
payload
[
"messages"
]:
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
(
model_info
.
params
.
get
(
"system"
,
None
)
+
message
[
"content"
]
)
break
else
:
payload
[
"messages"
].
insert
(
0
,
{
"role"
:
"system"
,
"content"
:
model_info
.
params
.
get
(
"system"
,
None
),
},
)
if
url_idx
==
None
:
if
":"
not
in
payload
[
"model"
]:
payload
[
"model"
]
=
f
"
{
payload
[
'model'
]
}
:latest"
if
payload
[
"model"
]
in
app
.
state
.
MODELS
:
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
payload
[
"model"
]][
"urls"
])
else
:
else
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
400
,
status_code
=
400
,
...
@@ -892,16 +972,12 @@ async def generate_chat_completion(
...
@@ -892,16 +972,12 @@ async def generate_chat_completion(
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
print
(
payload
)
log
.
debug
(
r
=
None
"form_data.model_dump_json(exclude_none=True).encode(): {0} "
.
format
(
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
()
)
)
def
get_request
():
def
get_request
():
nonlocal
form_data
nonlocal
payload
nonlocal
r
nonlocal
r
request_id
=
str
(
uuid
.
uuid4
())
request_id
=
str
(
uuid
.
uuid4
())
...
@@ -910,7 +986,7 @@ async def generate_chat_completion(
...
@@ -910,7 +986,7 @@ async def generate_chat_completion(
def
stream_content
():
def
stream_content
():
try
:
try
:
if
form_data
.
stream
:
if
payload
.
get
(
"stream"
,
None
)
:
yield
json
.
dumps
({
"id"
:
request_id
,
"done"
:
False
})
+
"
\n
"
yield
json
.
dumps
({
"id"
:
request_id
,
"done"
:
False
})
+
"
\n
"
for
chunk
in
r
.
iter_content
(
chunk_size
=
8192
):
for
chunk
in
r
.
iter_content
(
chunk_size
=
8192
):
...
@@ -928,7 +1004,7 @@ async def generate_chat_completion(
...
@@ -928,7 +1004,7 @@ async def generate_chat_completion(
r
=
requests
.
request
(
r
=
requests
.
request
(
method
=
"POST"
,
method
=
"POST"
,
url
=
f
"
{
url
}
/api/chat"
,
url
=
f
"
{
url
}
/api/chat"
,
data
=
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
(
),
data
=
json
.
dumps
(
payload
),
stream
=
True
,
stream
=
True
,
)
)
...
@@ -984,14 +1060,62 @@ async def generate_openai_chat_completion(
...
@@ -984,14 +1060,62 @@ async def generate_openai_chat_completion(
user
=
Depends
(
get_verified_user
),
user
=
Depends
(
get_verified_user
),
):
):
if
url_idx
==
None
:
payload
=
{
model
=
form_data
.
model
**
form_data
.
model_dump
(
exclude_none
=
True
),
}
if
":"
not
in
model
:
model_id
=
form_data
.
model
model
=
f
"
{
model
}
:latest"
model
_info
=
Models
.
get_model_by_id
(
model_id
)
if
model
in
app
.
state
.
MODELS
:
if
model_info
:
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
model
][
"urls"
])
print
(
model_info
)
if
model_info
.
base_model_id
:
payload
[
"model"
]
=
model_info
.
base_model_id
model_info
.
params
=
model_info
.
params
.
model_dump
()
if
model_info
.
params
:
payload
[
"temperature"
]
=
model_info
.
params
.
get
(
"temperature"
,
None
)
payload
[
"top_p"
]
=
model_info
.
params
.
get
(
"top_p"
,
None
)
payload
[
"max_tokens"
]
=
model_info
.
params
.
get
(
"max_tokens"
,
None
)
payload
[
"frequency_penalty"
]
=
model_info
.
params
.
get
(
"frequency_penalty"
,
None
)
payload
[
"seed"
]
=
model_info
.
params
.
get
(
"seed"
,
None
)
payload
[
"stop"
]
=
(
[
bytes
(
stop
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
stop
in
model_info
.
params
[
"stop"
]
]
if
model_info
.
params
.
get
(
"stop"
,
None
)
else
None
)
if
model_info
.
params
.
get
(
"system"
,
None
):
# Check if the payload already has a system message
# If not, add a system message to the payload
if
payload
.
get
(
"messages"
):
for
message
in
payload
[
"messages"
]:
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
(
model_info
.
params
.
get
(
"system"
,
None
)
+
message
[
"content"
]
)
break
else
:
payload
[
"messages"
].
insert
(
0
,
{
"role"
:
"system"
,
"content"
:
model_info
.
params
.
get
(
"system"
,
None
),
},
)
if
url_idx
==
None
:
if
":"
not
in
payload
[
"model"
]:
payload
[
"model"
]
=
f
"
{
payload
[
'model'
]
}
:latest"
if
payload
[
"model"
]
in
app
.
state
.
MODELS
:
url_idx
=
random
.
choice
(
app
.
state
.
MODELS
[
payload
[
"model"
]][
"urls"
])
else
:
else
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
400
,
status_code
=
400
,
...
@@ -1004,7 +1128,7 @@ async def generate_openai_chat_completion(
...
@@ -1004,7 +1128,7 @@ async def generate_openai_chat_completion(
r
=
None
r
=
None
def
get_request
():
def
get_request
():
nonlocal
form_data
nonlocal
payload
nonlocal
r
nonlocal
r
request_id
=
str
(
uuid
.
uuid4
())
request_id
=
str
(
uuid
.
uuid4
())
...
@@ -1013,7 +1137,7 @@ async def generate_openai_chat_completion(
...
@@ -1013,7 +1137,7 @@ async def generate_openai_chat_completion(
def
stream_content
():
def
stream_content
():
try
:
try
:
if
form_data
.
stream
:
if
payload
.
get
(
"
stream
"
)
:
yield
json
.
dumps
(
yield
json
.
dumps
(
{
"request_id"
:
request_id
,
"done"
:
False
}
{
"request_id"
:
request_id
,
"done"
:
False
}
)
+
"
\n
"
)
+
"
\n
"
...
@@ -1033,7 +1157,7 @@ async def generate_openai_chat_completion(
...
@@ -1033,7 +1157,7 @@ async def generate_openai_chat_completion(
r
=
requests
.
request
(
r
=
requests
.
request
(
method
=
"POST"
,
method
=
"POST"
,
url
=
f
"
{
url
}
/v1/chat/completions"
,
url
=
f
"
{
url
}
/v1/chat/completions"
,
data
=
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
(
),
data
=
json
.
dumps
(
payload
),
stream
=
True
,
stream
=
True
,
)
)
...
...
backend/apps/openai/main.py
View file @
b1265c9c
...
@@ -10,7 +10,7 @@ import logging
...
@@ -10,7 +10,7 @@ import logging
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
apps.web.models.models
import
Models
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
utils.utils
import
(
from
utils.utils
import
(
...
@@ -53,7 +53,6 @@ app.state.config = AppConfig()
...
@@ -53,7 +53,6 @@ app.state.config = AppConfig()
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
ENABLE_OPENAI_API
=
ENABLE_OPENAI_API
app
.
state
.
config
.
ENABLE_OPENAI_API
=
ENABLE_OPENAI_API
app
.
state
.
config
.
OPENAI_API_BASE_URLS
=
OPENAI_API_BASE_URLS
app
.
state
.
config
.
OPENAI_API_BASE_URLS
=
OPENAI_API_BASE_URLS
app
.
state
.
config
.
OPENAI_API_KEYS
=
OPENAI_API_KEYS
app
.
state
.
config
.
OPENAI_API_KEYS
=
OPENAI_API_KEYS
...
@@ -206,7 +205,13 @@ def merge_models_lists(model_lists):
...
@@ -206,7 +205,13 @@ def merge_models_lists(model_lists):
if
models
is
not
None
and
"error"
not
in
models
:
if
models
is
not
None
and
"error"
not
in
models
:
merged_list
.
extend
(
merged_list
.
extend
(
[
[
{
**
model
,
"urlIdx"
:
idx
}
{
**
model
,
"name"
:
model
.
get
(
"name"
,
model
[
"id"
]),
"owned_by"
:
"openai"
,
"openai"
:
model
,
"urlIdx"
:
idx
,
}
for
model
in
models
for
model
in
models
if
"api.openai.com"
if
"api.openai.com"
not
in
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
not
in
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
...
@@ -252,7 +257,7 @@ async def get_all_models():
...
@@ -252,7 +257,7 @@ async def get_all_models():
log
.
info
(
f
"models:
{
models
}
"
)
log
.
info
(
f
"models:
{
models
}
"
)
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
[
"data"
]}
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
[
"data"
]}
return
models
return
models
@
app
.
get
(
"/models"
)
@
app
.
get
(
"/models"
)
...
@@ -306,44 +311,97 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
...
@@ -306,44 +311,97 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
async
def
proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_verified_user
)):
async
def
proxy
(
path
:
str
,
request
:
Request
,
user
=
Depends
(
get_verified_user
)):
idx
=
0
idx
=
0
pipeline
=
False
body
=
await
request
.
body
()
body
=
await
request
.
body
()
# TODO: Remove below after gpt-4-vision fix from Open AI
# TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
payload
=
None
try
:
try
:
body
=
body
.
decode
(
"utf-8"
)
if
"chat/completions"
in
path
:
body
=
json
.
loads
(
body
)
body
=
body
.
decode
(
"utf-8"
)
body
=
json
.
loads
(
body
)
model
=
app
.
state
.
MODELS
[
body
.
get
(
"model"
)]
payload
=
{
**
body
}
idx
=
model
[
"urlIdx"
]
model_id
=
body
.
get
(
"model"
)
model_info
=
Models
.
get_model_by_id
(
model_id
)
if
"pipeline"
in
model
:
if
model_info
:
pipeline
=
model
.
get
(
"pipeline"
)
print
(
model_info
)
if
model_info
.
base_model_id
:
payload
[
"model"
]
=
model_info
.
base_model_id
if
pipeline
:
model_info
.
params
=
model_info
.
params
.
model_dump
()
body
[
"user"
]
=
{
"name"
:
user
.
name
,
"id"
:
user
.
id
}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
if
model_info
.
params
:
# This is a workaround until OpenAI fixes the issue with this model
payload
[
"temperature"
]
=
model_info
.
params
.
get
(
"temperature"
,
None
)
if
body
.
get
(
"model"
)
==
"gpt-4-vision-preview"
:
payload
[
"top_p"
]
=
model_info
.
params
.
get
(
"top_p"
,
None
)
if
"max_tokens"
not
in
body
:
payload
[
"max_tokens"
]
=
model_info
.
params
.
get
(
"max_tokens"
,
None
)
body
[
"max_tokens"
]
=
4000
payload
[
"frequency_penalty"
]
=
model_info
.
params
.
get
(
log
.
debug
(
"Modified body_dict:"
,
body
)
"frequency_penalty"
,
None
)
payload
[
"seed"
]
=
model_info
.
params
.
get
(
"seed"
,
None
)
payload
[
"stop"
]
=
(
[
bytes
(
stop
,
"utf-8"
).
decode
(
"unicode_escape"
)
for
stop
in
model_info
.
params
[
"stop"
]
]
if
model_info
.
params
.
get
(
"stop"
,
None
)
else
None
)
if
model_info
.
params
.
get
(
"system"
,
None
):
# Check if the payload already has a system message
# If not, add a system message to the payload
if
payload
.
get
(
"messages"
):
for
message
in
payload
[
"messages"
]:
if
message
.
get
(
"role"
)
==
"system"
:
message
[
"content"
]
=
(
model_info
.
params
.
get
(
"system"
,
None
)
+
message
[
"content"
]
)
break
else
:
payload
[
"messages"
].
insert
(
0
,
{
"role"
:
"system"
,
"content"
:
model_info
.
params
.
get
(
"system"
,
None
),
},
)
else
:
pass
print
(
app
.
state
.
MODELS
)
model
=
app
.
state
.
MODELS
[
payload
.
get
(
"model"
)]
idx
=
model
[
"urlIdx"
]
if
"pipeline"
in
model
and
model
.
get
(
"pipeline"
):
payload
[
"user"
]
=
{
"name"
:
user
.
name
,
"id"
:
user
.
id
}
payload
[
"title"
]
=
(
True
if
payload
[
"stream"
]
==
False
and
payload
[
"max_tokens"
]
==
50
else
False
)
# Fix for ChatGPT calls failing because the num_ctx key is in body
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
if
"num_ctx"
in
body
:
# This is a workaround until OpenAI fixes the issue with this model
# If 'num_ctx' is in the dictionary, delete it
if
payload
.
get
(
"model"
)
==
"gpt-4-vision-preview"
:
# Leaving it there generates an error with the
if
"max_tokens"
not
in
payload
:
# OpenAI API (Feb 2024)
payload
[
"max_tokens"
]
=
4000
del
body
[
"num_ctx"
]
log
.
debug
(
"Modified payload:"
,
payload
)
# Convert the modified body back to JSON
payload
=
json
.
dumps
(
payload
)
# Convert the modified body back to JSON
body
=
json
.
dumps
(
body
)
except
json
.
JSONDecodeError
as
e
:
except
json
.
JSONDecodeError
as
e
:
log
.
error
(
"Error loading request body into a dictionary:"
,
e
)
log
.
error
(
"Error loading request body into a dictionary:"
,
e
)
print
(
payload
)
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
url
=
app
.
state
.
config
.
OPENAI_API_BASE_URLS
[
idx
]
key
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
key
=
app
.
state
.
config
.
OPENAI_API_KEYS
[
idx
]
...
@@ -362,7 +420,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
...
@@ -362,7 +420,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
r
=
requests
.
request
(
r
=
requests
.
request
(
method
=
request
.
method
,
method
=
request
.
method
,
url
=
target_url
,
url
=
target_url
,
data
=
body
,
data
=
payload
if
payload
else
body
,
headers
=
headers
,
headers
=
headers
,
stream
=
True
,
stream
=
True
,
)
)
...
...
backend/apps/web/internal/db.py
View file @
b1265c9c
import
json
from
peewee
import
*
from
peewee
import
*
from
peewee_migrate
import
Router
from
peewee_migrate
import
Router
from
playhouse.db_url
import
connect
from
playhouse.db_url
import
connect
...
@@ -8,6 +10,16 @@ import logging
...
@@ -8,6 +10,16 @@ import logging
log
=
logging
.
getLogger
(
__name__
)
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
log
.
setLevel
(
SRC_LOG_LEVELS
[
"DB"
])
class
JSONField
(
TextField
):
def
db_value
(
self
,
value
):
return
json
.
dumps
(
value
)
def
python_value
(
self
,
value
):
if
value
is
not
None
:
return
json
.
loads
(
value
)
# Check if the file exists
# Check if the file exists
if
os
.
path
.
exists
(
f
"
{
DATA_DIR
}
/ollama.db"
):
if
os
.
path
.
exists
(
f
"
{
DATA_DIR
}
/ollama.db"
):
# Rename the file
# Rename the file
...
...
backend/apps/web/internal/migrations/009_add_models.py
0 → 100644
View file @
b1265c9c
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from
contextlib
import
suppress
import
peewee
as
pw
from
peewee_migrate
import
Migrator
with
suppress
(
ImportError
):
import
playhouse.postgres_ext
as
pw_pext
def
migrate
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your migrations here."""
@
migrator
.
create_model
class
Model
(
pw
.
Model
):
id
=
pw
.
TextField
(
unique
=
True
)
user_id
=
pw
.
TextField
()
base_model_id
=
pw
.
TextField
(
null
=
True
)
name
=
pw
.
TextField
()
meta
=
pw
.
TextField
()
params
=
pw
.
TextField
()
created_at
=
pw
.
BigIntegerField
(
null
=
False
)
updated_at
=
pw
.
BigIntegerField
(
null
=
False
)
class
Meta
:
table_name
=
"model"
def
rollback
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your rollback migrations here."""
migrator
.
remove_model
(
"model"
)
backend/apps/web/internal/migrations/010_migrate_modelfiles_to_models.py
0 → 100644
View file @
b1265c9c
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from
contextlib
import
suppress
import
peewee
as
pw
from
peewee_migrate
import
Migrator
import
json
from
utils.misc
import
parse_ollama_modelfile
with
suppress
(
ImportError
):
import
playhouse.postgres_ext
as
pw_pext
def
migrate
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your migrations here."""
# Fetch data from 'modelfile' table and insert into 'model' table
migrate_modelfile_to_model
(
migrator
,
database
)
# Drop the 'modelfile' table
migrator
.
remove_model
(
"modelfile"
)
def
migrate_modelfile_to_model
(
migrator
:
Migrator
,
database
:
pw
.
Database
):
ModelFile
=
migrator
.
orm
[
"modelfile"
]
Model
=
migrator
.
orm
[
"model"
]
modelfiles
=
ModelFile
.
select
()
for
modelfile
in
modelfiles
:
# Extract and transform data in Python
modelfile
.
modelfile
=
json
.
loads
(
modelfile
.
modelfile
)
meta
=
json
.
dumps
(
{
"description"
:
modelfile
.
modelfile
.
get
(
"desc"
),
"profile_image_url"
:
modelfile
.
modelfile
.
get
(
"imageUrl"
),
"ollama"
:
{
"modelfile"
:
modelfile
.
modelfile
.
get
(
"content"
)},
"suggestion_prompts"
:
modelfile
.
modelfile
.
get
(
"suggestionPrompts"
),
"categories"
:
modelfile
.
modelfile
.
get
(
"categories"
),
"user"
:
{
**
modelfile
.
modelfile
.
get
(
"user"
,
{}),
"community"
:
True
},
}
)
info
=
parse_ollama_modelfile
(
modelfile
.
modelfile
.
get
(
"content"
))
# Insert the processed data into the 'model' table
Model
.
create
(
id
=
f
"ollama-
{
modelfile
.
tag_name
}
"
,
user_id
=
modelfile
.
user_id
,
base_model_id
=
info
.
get
(
"base_model_id"
),
name
=
modelfile
.
modelfile
.
get
(
"title"
),
meta
=
meta
,
params
=
json
.
dumps
(
info
.
get
(
"params"
,
{})),
created_at
=
modelfile
.
timestamp
,
updated_at
=
modelfile
.
timestamp
,
)
def
rollback
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your rollback migrations here."""
recreate_modelfile_table
(
migrator
,
database
)
move_data_back_to_modelfile
(
migrator
,
database
)
migrator
.
remove_model
(
"model"
)
def
recreate_modelfile_table
(
migrator
:
Migrator
,
database
:
pw
.
Database
):
query
=
"""
CREATE TABLE IF NOT EXISTS modelfile (
user_id TEXT,
tag_name TEXT,
modelfile JSON,
timestamp BIGINT
)
"""
migrator
.
sql
(
query
)
def
move_data_back_to_modelfile
(
migrator
:
Migrator
,
database
:
pw
.
Database
):
Model
=
migrator
.
orm
[
"model"
]
Modelfile
=
migrator
.
orm
[
"modelfile"
]
models
=
Model
.
select
()
for
model
in
models
:
# Extract and transform data in Python
meta
=
json
.
loads
(
model
.
meta
)
modelfile_data
=
{
"title"
:
model
.
name
,
"desc"
:
meta
.
get
(
"description"
),
"imageUrl"
:
meta
.
get
(
"profile_image_url"
),
"content"
:
meta
.
get
(
"ollama"
,
{}).
get
(
"modelfile"
),
"suggestionPrompts"
:
meta
.
get
(
"suggestion_prompts"
),
"categories"
:
meta
.
get
(
"categories"
),
"user"
:
{
k
:
v
for
k
,
v
in
meta
.
get
(
"user"
,
{}).
items
()
if
k
!=
"community"
},
}
# Insert the processed data back into the 'modelfile' table
Modelfile
.
create
(
user_id
=
model
.
user_id
,
tag_name
=
model
.
id
,
modelfile
=
modelfile_data
,
timestamp
=
model
.
created_at
,
)
backend/apps/web/main.py
View file @
b1265c9c
...
@@ -6,7 +6,7 @@ from apps.web.routers import (
...
@@ -6,7 +6,7 @@ from apps.web.routers import (
users
,
users
,
chats
,
chats
,
documents
,
documents
,
model
file
s
,
models
,
prompts
,
prompts
,
configs
,
configs
,
memories
,
memories
,
...
@@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
...
@@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app
.
state
.
config
.
DEFAULT_USER_ROLE
=
DEFAULT_USER_ROLE
app
.
state
.
config
.
DEFAULT_USER_ROLE
=
DEFAULT_USER_ROLE
app
.
state
.
config
.
USER_PERMISSIONS
=
USER_PERMISSIONS
app
.
state
.
config
.
USER_PERMISSIONS
=
USER_PERMISSIONS
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
MODELS
=
{}
app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
=
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
=
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
...
@@ -56,11 +59,10 @@ app.include_router(users.router, prefix="/users", tags=["users"])
...
@@ -56,11 +59,10 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
app
.
include_router
(
documents
.
router
,
prefix
=
"/documents"
,
tags
=
[
"documents"
])
app
.
include_router
(
documents
.
router
,
prefix
=
"/documents"
,
tags
=
[
"documents"
])
app
.
include_router
(
model
file
s
.
router
,
prefix
=
"/model
file
s"
,
tags
=
[
"model
file
s"
])
app
.
include_router
(
models
.
router
,
prefix
=
"/models"
,
tags
=
[
"models"
])
app
.
include_router
(
prompts
.
router
,
prefix
=
"/prompts"
,
tags
=
[
"prompts"
])
app
.
include_router
(
prompts
.
router
,
prefix
=
"/prompts"
,
tags
=
[
"prompts"
])
app
.
include_router
(
memories
.
router
,
prefix
=
"/memories"
,
tags
=
[
"memories"
])
app
.
include_router
(
memories
.
router
,
prefix
=
"/memories"
,
tags
=
[
"memories"
])
app
.
include_router
(
configs
.
router
,
prefix
=
"/configs"
,
tags
=
[
"configs"
])
app
.
include_router
(
configs
.
router
,
prefix
=
"/configs"
,
tags
=
[
"configs"
])
app
.
include_router
(
utils
.
router
,
prefix
=
"/utils"
,
tags
=
[
"utils"
])
app
.
include_router
(
utils
.
router
,
prefix
=
"/utils"
,
tags
=
[
"utils"
])
...
...
backend/apps/web/models/modelfiles.py
View file @
b1265c9c
################################################################################
# DEPRECATION NOTICE #
# #
# This file has been deprecated since version 0.2.0. #
# #
################################################################################
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
peewee
import
*
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
...
...
backend/apps/web/models/models.py
0 → 100644
View file @
b1265c9c
import
json
import
logging
from
typing
import
Optional
import
peewee
as
pw
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
pydantic
import
BaseModel
,
ConfigDict
from
apps.web.internal.db
import
DB
,
JSONField
from
typing
import
List
,
Union
,
Optional
from
config
import
SRC_LOG_LEVELS
import
time
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
####################
# Models DB Schema
####################
# ModelParams is a model for the data stored in the params field of the Model table
class
ModelParams
(
BaseModel
):
model_config
=
ConfigDict
(
extra
=
"allow"
)
pass
# ModelMeta is a model for the data stored in the meta field of the Model table
class
ModelMeta
(
BaseModel
):
profile_image_url
:
Optional
[
str
]
=
"/favicon.png"
description
:
Optional
[
str
]
=
None
"""
User-facing description of the model.
"""
capabilities
:
Optional
[
dict
]
=
None
model_config
=
ConfigDict
(
extra
=
"allow"
)
pass
class
Model
(
pw
.
Model
):
id
=
pw
.
TextField
(
unique
=
True
)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id
=
pw
.
TextField
()
base_model_id
=
pw
.
TextField
(
null
=
True
)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name
=
pw
.
TextField
()
"""
The human-readable display name of the model.
"""
params
=
JSONField
()
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta
=
JSONField
()
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
database
=
DB
class
ModelModel
(
BaseModel
):
id
:
str
user_id
:
str
base_model_id
:
Optional
[
str
]
=
None
name
:
str
params
:
ModelParams
meta
:
ModelMeta
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
####################
# Forms
####################
class
ModelResponse
(
BaseModel
):
id
:
str
name
:
str
meta
:
ModelMeta
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
class
ModelForm
(
BaseModel
):
id
:
str
base_model_id
:
Optional
[
str
]
=
None
name
:
str
meta
:
ModelMeta
params
:
ModelParams
class
ModelsTable
:
def
__init__
(
self
,
db
:
pw
.
SqliteDatabase
|
pw
.
PostgresqlDatabase
,
):
self
.
db
=
db
self
.
db
.
create_tables
([
Model
])
def
insert_new_model
(
self
,
form_data
:
ModelForm
,
user_id
:
str
)
->
Optional
[
ModelModel
]:
model
=
ModelModel
(
**
{
**
form_data
.
model_dump
(),
"user_id"
:
user_id
,
"created_at"
:
int
(
time
.
time
()),
"updated_at"
:
int
(
time
.
time
()),
}
)
try
:
result
=
Model
.
create
(
**
model
.
model_dump
())
if
result
:
return
model
else
:
return
None
except
Exception
as
e
:
print
(
e
)
return
None
def
get_all_models
(
self
)
->
List
[
ModelModel
]:
return
[
ModelModel
(
**
model_to_dict
(
model
))
for
model
in
Model
.
select
()]
def
get_model_by_id
(
self
,
id
:
str
)
->
Optional
[
ModelModel
]:
try
:
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_to_dict
(
model
))
except
:
return
None
def
update_model_by_id
(
self
,
id
:
str
,
model
:
ModelForm
)
->
Optional
[
ModelModel
]:
try
:
# update only the fields that are present in the model
query
=
Model
.
update
(
**
model
.
model_dump
()).
where
(
Model
.
id
==
id
)
query
.
execute
()
model
=
Model
.
get
(
Model
.
id
==
id
)
return
ModelModel
(
**
model_to_dict
(
model
))
except
Exception
as
e
:
print
(
e
)
return
None
def
delete_model_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Model
.
delete
().
where
(
Model
.
id
==
id
)
query
.
execute
()
return
True
except
:
return
False
Models
=
ModelsTable
(
DB
)
backend/apps/web/routers/modelfiles.py
deleted
100644 → 0
View file @
60433856
from
fastapi
import
Depends
,
FastAPI
,
HTTPException
,
status
from
datetime
import
datetime
,
timedelta
from
typing
import
List
,
Union
,
Optional
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.web.models.modelfiles
import
(
Modelfiles
,
ModelfileForm
,
ModelfileTagNameForm
,
ModelfileUpdateForm
,
ModelfileResponse
,
)
from
utils.utils
import
get_current_user
,
get_admin_user
from
constants
import
ERROR_MESSAGES
router
=
APIRouter
()
############################
# GetModelfiles
############################
@
router
.
get
(
"/"
,
response_model
=
List
[
ModelfileResponse
])
async
def
get_modelfiles
(
skip
:
int
=
0
,
limit
:
int
=
50
,
user
=
Depends
(
get_current_user
)
):
return
Modelfiles
.
get_modelfiles
(
skip
,
limit
)
############################
# CreateNewModelfile
############################
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
ModelfileResponse
])
async
def
create_new_modelfile
(
form_data
:
ModelfileForm
,
user
=
Depends
(
get_admin_user
)):
modelfile
=
Modelfiles
.
insert_new_modelfile
(
user
.
id
,
form_data
)
if
modelfile
:
return
ModelfileResponse
(
**
{
**
modelfile
.
model_dump
(),
"modelfile"
:
json
.
loads
(
modelfile
.
modelfile
),
}
)
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
)
############################
# GetModelfileByTagName
############################
@
router
.
post
(
"/"
,
response_model
=
Optional
[
ModelfileResponse
])
async
def
get_modelfile_by_tag_name
(
form_data
:
ModelfileTagNameForm
,
user
=
Depends
(
get_current_user
)
):
modelfile
=
Modelfiles
.
get_modelfile_by_tag_name
(
form_data
.
tag_name
)
if
modelfile
:
return
ModelfileResponse
(
**
{
**
modelfile
.
model_dump
(),
"modelfile"
:
json
.
loads
(
modelfile
.
modelfile
),
}
)
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
NOT_FOUND
,
)
############################
# UpdateModelfileByTagName
############################
@
router
.
post
(
"/update"
,
response_model
=
Optional
[
ModelfileResponse
])
async
def
update_modelfile_by_tag_name
(
form_data
:
ModelfileUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
modelfile
=
Modelfiles
.
get_modelfile_by_tag_name
(
form_data
.
tag_name
)
if
modelfile
:
updated_modelfile
=
{
**
json
.
loads
(
modelfile
.
modelfile
),
**
form_data
.
modelfile
,
}
modelfile
=
Modelfiles
.
update_modelfile_by_tag_name
(
form_data
.
tag_name
,
updated_modelfile
)
return
ModelfileResponse
(
**
{
**
modelfile
.
model_dump
(),
"modelfile"
:
json
.
loads
(
modelfile
.
modelfile
),
}
)
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
,
)
############################
# DeleteModelfileByTagName
############################
@
router
.
delete
(
"/delete"
,
response_model
=
bool
)
async
def
delete_modelfile_by_tag_name
(
form_data
:
ModelfileTagNameForm
,
user
=
Depends
(
get_admin_user
)
):
result
=
Modelfiles
.
delete_modelfile_by_tag_name
(
form_data
.
tag_name
)
return
result
backend/apps/web/routers/models.py
0 → 100644
View file @
b1265c9c
from
fastapi
import
Depends
,
FastAPI
,
HTTPException
,
status
,
Request
from
datetime
import
datetime
,
timedelta
from
typing
import
List
,
Union
,
Optional
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
import
json
from
apps.web.models.models
import
Models
,
ModelModel
,
ModelForm
,
ModelResponse
from
utils.utils
import
get_verified_user
,
get_admin_user
from
constants
import
ERROR_MESSAGES
router
=
APIRouter
()
###########################
# getModels
###########################
@
router
.
get
(
"/"
,
response_model
=
List
[
ModelResponse
])
async
def
get_models
(
user
=
Depends
(
get_verified_user
)):
return
Models
.
get_all_models
()
############################
# AddNewModel
############################
@
router
.
post
(
"/add"
,
response_model
=
Optional
[
ModelModel
])
async
def
add_new_model
(
request
:
Request
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
)
):
if
form_data
.
id
in
request
.
app
.
state
.
MODELS
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
MODEL_ID_TAKEN
,
)
else
:
model
=
Models
.
insert_new_model
(
form_data
,
user
.
id
)
if
model
:
return
model
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
)
############################
# GetModelById
############################
@
router
.
get
(
"/{id}"
,
response_model
=
Optional
[
ModelModel
])
async
def
get_model_by_id
(
id
:
str
,
user
=
Depends
(
get_verified_user
)):
model
=
Models
.
get_model_by_id
(
id
)
if
model
:
return
model
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
NOT_FOUND
,
)
############################
# UpdateModelById
############################
@
router
.
post
(
"/{id}/update"
,
response_model
=
Optional
[
ModelModel
])
async
def
update_model_by_id
(
request
:
Request
,
id
:
str
,
form_data
:
ModelForm
,
user
=
Depends
(
get_admin_user
)
):
model
=
Models
.
get_model_by_id
(
id
)
if
model
:
model
=
Models
.
update_model_by_id
(
id
,
form_data
)
return
model
else
:
if
form_data
.
id
in
request
.
app
.
state
.
MODELS
:
model
=
Models
.
insert_new_model
(
form_data
,
user
.
id
)
print
(
model
)
if
model
:
return
model
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
)
else
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(),
)
############################
# DeleteModelById
############################
@
router
.
delete
(
"/{id}/delete"
,
response_model
=
bool
)
async
def
delete_model_by_id
(
id
:
str
,
user
=
Depends
(
get_admin_user
)):
result
=
Models
.
delete_model_by_id
(
id
)
return
result
backend/constants.py
View file @
b1265c9c
...
@@ -32,6 +32,8 @@ class ERROR_MESSAGES(str, Enum):
...
@@ -32,6 +32,8 @@ class ERROR_MESSAGES(str, Enum):
COMMAND_TAKEN
=
"Uh-oh! This command is already registered. Please choose another command string."
COMMAND_TAKEN
=
"Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS
=
"Uh-oh! This file is already registered. Please choose another file."
FILE_EXISTS
=
"Uh-oh! This file is already registered. Please choose another file."
MODEL_ID_TAKEN
=
"Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN
=
"Uh-oh! This name tag is already registered. Please choose another name tag string."
NAME_TAG_TAKEN
=
"Uh-oh! This name tag is already registered. Please choose another name tag string."
INVALID_TOKEN
=
(
INVALID_TOKEN
=
(
"Your session has expired or the token is invalid. Please sign in again."
"Your session has expired or the token is invalid. Please sign in again."
...
...
backend/main.py
View file @
b1265c9c
...
@@ -19,8 +19,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
...
@@ -19,8 +19,8 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.responses
import
StreamingResponse
,
Response
from
starlette.responses
import
StreamingResponse
,
Response
from
apps.ollama.main
import
app
as
ollama_app
from
apps.ollama.main
import
app
as
ollama_app
,
get_all_models
as
get_ollama_models
from
apps.openai.main
import
app
as
openai_app
from
apps.openai.main
import
app
as
openai_app
,
get_all_models
as
get_openai_models
from
apps.litellm.main
import
(
from
apps.litellm.main
import
(
app
as
litellm_app
,
app
as
litellm_app
,
...
@@ -36,10 +36,10 @@ from apps.web.main import app as webui_app
...
@@ -36,10 +36,10 @@ from apps.web.main import app as webui_app
import
asyncio
import
asyncio
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
typing
import
List
from
typing
import
List
,
Optional
from
apps.web.models.models
import
Models
,
ModelModel
from
utils.utils
import
get_admin_user
from
utils.utils
import
get_admin_user
,
get_verified_user
from
apps.rag.utils
import
rag_messages
from
apps.rag.utils
import
rag_messages
from
config
import
(
from
config
import
(
...
@@ -53,6 +53,8 @@ from config import (
...
@@ -53,6 +53,8 @@ from config import (
FRONTEND_BUILD_DIR
,
FRONTEND_BUILD_DIR
,
CACHE_DIR
,
CACHE_DIR
,
STATIC_DIR
,
STATIC_DIR
,
ENABLE_OPENAI_API
,
ENABLE_OLLAMA_API
,
ENABLE_LITELLM
,
ENABLE_LITELLM
,
ENABLE_MODEL_FILTER
,
ENABLE_MODEL_FILTER
,
MODEL_FILTER_LIST
,
MODEL_FILTER_LIST
,
...
@@ -111,11 +113,19 @@ app = FastAPI(
...
@@ -111,11 +113,19 @@ app = FastAPI(
)
)
app
.
state
.
config
=
AppConfig
()
app
.
state
.
config
=
AppConfig
()
app
.
state
.
config
.
ENABLE_OPENAI_API
=
ENABLE_OPENAI_API
app
.
state
.
config
.
ENABLE_OLLAMA_API
=
ENABLE_OLLAMA_API
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
MODELS
=
{}
origins
=
[
"*"
]
origins
=
[
"*"
]
...
@@ -232,6 +242,11 @@ app.add_middleware(
...
@@ -232,6 +242,11 @@ app.add_middleware(
@
app
.
middleware
(
"http"
)
@
app
.
middleware
(
"http"
)
async
def
check_url
(
request
:
Request
,
call_next
):
async
def
check_url
(
request
:
Request
,
call_next
):
if
len
(
app
.
state
.
MODELS
)
==
0
:
await
get_all_models
()
else
:
pass
start_time
=
int
(
time
.
time
())
start_time
=
int
(
time
.
time
())
response
=
await
call_next
(
request
)
response
=
await
call_next
(
request
)
process_time
=
int
(
time
.
time
())
-
start_time
process_time
=
int
(
time
.
time
())
-
start_time
...
@@ -248,9 +263,11 @@ async def update_embedding_function(request: Request, call_next):
...
@@ -248,9 +263,11 @@ async def update_embedding_function(request: Request, call_next):
return
response
return
response
# TODO: Deprecate LiteLLM
app
.
mount
(
"/litellm/api"
,
litellm_app
)
app
.
mount
(
"/litellm/api"
,
litellm_app
)
app
.
mount
(
"/ollama"
,
ollama_app
)
app
.
mount
(
"/ollama"
,
ollama_app
)
app
.
mount
(
"/openai
/api
"
,
openai_app
)
app
.
mount
(
"/openai"
,
openai_app
)
app
.
mount
(
"/images/api/v1"
,
images_app
)
app
.
mount
(
"/images/api/v1"
,
images_app
)
app
.
mount
(
"/audio/api/v1"
,
audio_app
)
app
.
mount
(
"/audio/api/v1"
,
audio_app
)
...
@@ -261,6 +278,87 @@ app.mount("/api/v1", webui_app)
...
@@ -261,6 +278,87 @@ app.mount("/api/v1", webui_app)
webui_app
.
state
.
EMBEDDING_FUNCTION
=
rag_app
.
state
.
EMBEDDING_FUNCTION
webui_app
.
state
.
EMBEDDING_FUNCTION
=
rag_app
.
state
.
EMBEDDING_FUNCTION
async
def
get_all_models
():
openai_models
=
[]
ollama_models
=
[]
if
app
.
state
.
config
.
ENABLE_OPENAI_API
:
openai_models
=
await
get_openai_models
()
openai_models
=
openai_models
[
"data"
]
if
app
.
state
.
config
.
ENABLE_OLLAMA_API
:
ollama_models
=
await
get_ollama_models
()
ollama_models
=
[
{
"id"
:
model
[
"model"
],
"name"
:
model
[
"name"
],
"object"
:
"model"
,
"created"
:
int
(
time
.
time
()),
"owned_by"
:
"ollama"
,
"ollama"
:
model
,
}
for
model
in
ollama_models
[
"models"
]
]
models
=
openai_models
+
ollama_models
custom_models
=
Models
.
get_all_models
()
for
custom_model
in
custom_models
:
if
custom_model
.
base_model_id
==
None
:
for
model
in
models
:
if
(
custom_model
.
id
==
model
[
"id"
]
or
custom_model
.
id
==
model
[
"id"
].
split
(
":"
)[
0
]
):
model
[
"name"
]
=
custom_model
.
name
model
[
"info"
]
=
custom_model
.
model_dump
()
else
:
owned_by
=
"openai"
for
model
in
models
:
if
(
custom_model
.
base_model_id
==
model
[
"id"
]
or
custom_model
.
base_model_id
==
model
[
"id"
].
split
(
":"
)[
0
]
):
owned_by
=
model
[
"owned_by"
]
break
models
.
append
(
{
"id"
:
custom_model
.
id
,
"name"
:
custom_model
.
name
,
"object"
:
"model"
,
"created"
:
custom_model
.
created_at
,
"owned_by"
:
owned_by
,
"info"
:
custom_model
.
model_dump
(),
"preset"
:
True
,
}
)
app
.
state
.
MODELS
=
{
model
[
"id"
]:
model
for
model
in
models
}
webui_app
.
state
.
MODELS
=
app
.
state
.
MODELS
return
models
@
app
.
get
(
"/api/models"
)
async
def
get_models
(
user
=
Depends
(
get_verified_user
)):
models
=
await
get_all_models
()
if
app
.
state
.
config
.
ENABLE_MODEL_FILTER
:
if
user
.
role
==
"user"
:
models
=
list
(
filter
(
lambda
model
:
model
[
"id"
]
in
app
.
state
.
config
.
MODEL_FILTER_LIST
,
models
,
)
)
return
{
"data"
:
models
}
return
{
"data"
:
models
}
@
app
.
get
(
"/api/config"
)
@
app
.
get
(
"/api/config"
)
async
def
get_app_config
():
async
def
get_app_config
():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
...
...
backend/utils/misc.py
View file @
b1265c9c
from
pathlib
import
Path
from
pathlib
import
Path
import
hashlib
import
hashlib
import
json
import
re
import
re
from
datetime
import
timedelta
from
datetime
import
timedelta
from
typing
import
Optional
from
typing
import
Optional
...
@@ -110,3 +111,76 @@ def parse_duration(duration: str) -> Optional[timedelta]:
...
@@ -110,3 +111,76 @@ def parse_duration(duration: str) -> Optional[timedelta]:
total_duration
+=
timedelta
(
weeks
=
number
)
total_duration
+=
timedelta
(
weeks
=
number
)
return
total_duration
return
total_duration
def
parse_ollama_modelfile
(
model_text
):
parameters_meta
=
{
"mirostat"
:
int
,
"mirostat_eta"
:
float
,
"mirostat_tau"
:
float
,
"num_ctx"
:
int
,
"repeat_last_n"
:
int
,
"repeat_penalty"
:
float
,
"temperature"
:
float
,
"seed"
:
int
,
"stop"
:
str
,
"tfs_z"
:
float
,
"num_predict"
:
int
,
"top_k"
:
int
,
"top_p"
:
float
,
}
data
=
{
"base_model_id"
:
None
,
"params"
:
{}}
# Parse base model
base_model_match
=
re
.
search
(
r
"^FROM\s+(\w+)"
,
model_text
,
re
.
MULTILINE
|
re
.
IGNORECASE
)
if
base_model_match
:
data
[
"base_model_id"
]
=
base_model_match
.
group
(
1
)
# Parse template
template_match
=
re
.
search
(
r
'TEMPLATE\s+"""(.+?)"""'
,
model_text
,
re
.
DOTALL
|
re
.
IGNORECASE
)
if
template_match
:
data
[
"params"
]
=
{
"template"
:
template_match
.
group
(
1
).
strip
()}
# Parse stops
stops
=
re
.
findall
(
r
'PARAMETER stop "(.*?)"'
,
model_text
,
re
.
IGNORECASE
)
if
stops
:
data
[
"params"
][
"stop"
]
=
stops
# Parse other parameters from the provided list
for
param
,
param_type
in
parameters_meta
.
items
():
param_match
=
re
.
search
(
rf
"PARAMETER
{
param
}
(.+)"
,
model_text
,
re
.
IGNORECASE
)
if
param_match
:
value
=
param_match
.
group
(
1
)
if
param_type
==
int
:
value
=
int
(
value
)
elif
param_type
==
float
:
value
=
float
(
value
)
data
[
"params"
][
param
]
=
value
# Parse adapter
adapter_match
=
re
.
search
(
r
"ADAPTER (.+)"
,
model_text
,
re
.
IGNORECASE
)
if
adapter_match
:
data
[
"params"
][
"adapter"
]
=
adapter_match
.
group
(
1
)
# Parse system description
system_desc_match
=
re
.
search
(
r
'SYSTEM\s+"""(.+?)"""'
,
model_text
,
re
.
DOTALL
|
re
.
IGNORECASE
)
if
system_desc_match
:
data
[
"params"
][
"system"
]
=
system_desc_match
.
group
(
1
).
strip
()
# Parse messages
messages
=
[]
message_matches
=
re
.
findall
(
r
"MESSAGE (\w+) (.+)"
,
model_text
,
re
.
IGNORECASE
)
for
role
,
content
in
message_matches
:
messages
.
append
({
"role"
:
role
,
"content"
:
content
})
if
messages
:
data
[
"params"
][
"messages"
]
=
messages
return
data
backend/utils/models.py
0 → 100644
View file @
b1265c9c
from
apps.web.models.models
import
Models
,
ModelModel
,
ModelForm
,
ModelResponse
def
get_model_id_from_custom_model_id
(
id
:
str
):
model
=
Models
.
get_model_by_id
(
id
)
if
model
:
return
model
.
id
else
:
return
id
src/lib/apis/index.ts
View file @
b1265c9c
import
{
WEBUI_BASE_URL
}
from
'
$lib/constants
'
;
import
{
WEBUI_BASE_URL
}
from
'
$lib/constants
'
;
export
const
getModels
=
async
(
token
:
string
=
''
)
=>
{
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_BASE_URL
}
/api/models`
,
{
method
:
'
GET
'
,
headers
:
{
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
...(
token
&&
{
authorization
:
`Bearer
${
token
}
`
})
}
})
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
return
res
.
json
();
})
.
catch
((
err
)
=>
{
console
.
log
(
err
);
error
=
err
;
return
null
;
});
if
(
error
)
{
throw
error
;
}
let
models
=
res
?.
data
??
[];
models
=
models
.
filter
((
models
)
=>
models
)
.
sort
((
a
,
b
)
=>
{
// Compare case-insensitively
const
lowerA
=
a
.
name
.
toLowerCase
();
const
lowerB
=
b
.
name
.
toLowerCase
();
if
(
lowerA
<
lowerB
)
return
-
1
;
if
(
lowerA
>
lowerB
)
return
1
;
// If same case-insensitively, sort by original strings,
// lowercase will come before uppercase due to ASCII values
if
(
a
<
b
)
return
-
1
;
if
(
a
>
b
)
return
1
;
return
0
;
// They are equal
});
console
.
log
(
models
);
return
models
;
};
export
const
getBackendConfig
=
async
()
=>
{
export
const
getBackendConfig
=
async
()
=>
{
let
error
=
null
;
let
error
=
null
;
...
@@ -196,3 +245,77 @@ export const updateWebhookUrl = async (token: string, url: string) => {
...
@@ -196,3 +245,77 @@ export const updateWebhookUrl = async (token: string, url: string) => {
return
res
.
url
;
return
res
.
url
;
};
};
export
const
getModelConfig
=
async
(
token
:
string
):
Promise
<
GlobalModelConfig
>
=>
{
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_BASE_URL
}
/api/config/models`
,
{
method
:
'
GET
'
,
headers
:
{
'
Content-Type
'
:
'
application/json
'
,
Authorization
:
`Bearer
${
token
}
`
}
})
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
return
res
.
json
();
})
.
catch
((
err
)
=>
{
console
.
log
(
err
);
error
=
err
;
return
null
;
});
if
(
error
)
{
throw
error
;
}
return
res
.
models
;
};
export
interface
ModelConfig
{
id
:
string
;
name
:
string
;
meta
:
ModelMeta
;
base_model_id
?:
string
;
params
:
ModelParams
;
}
export
interface
ModelMeta
{
description
?:
string
;
capabilities
?:
object
;
}
export
interface
ModelParams
{}
export
type
GlobalModelConfig
=
ModelConfig
[];
export
const
updateModelConfig
=
async
(
token
:
string
,
config
:
GlobalModelConfig
)
=>
{
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_BASE_URL
}
/api/config/models`
,
{
method
:
'
POST
'
,
headers
:
{
'
Content-Type
'
:
'
application/json
'
,
Authorization
:
`Bearer
${
token
}
`
},
body
:
JSON
.
stringify
({
models
:
config
})
})
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
return
res
.
json
();
})
.
catch
((
err
)
=>
{
console
.
log
(
err
);
error
=
err
;
return
null
;
});
if
(
error
)
{
throw
error
;
}
return
res
;
};
src/lib/apis/litellm/index.ts
View file @
b1265c9c
...
@@ -33,7 +33,8 @@ export const getLiteLLMModels = async (token: string = '') => {
...
@@ -33,7 +33,8 @@ export const getLiteLLMModels = async (token: string = '') => {
id
:
model
.
id
,
id
:
model
.
id
,
name
:
model
.
name
??
model
.
id
,
name
:
model
.
name
??
model
.
id
,
external
:
true
,
external
:
true
,
source
:
'
LiteLLM
'
source
:
'
LiteLLM
'
,
custom_info
:
model
.
custom_info
}))
}))
.
sort
((
a
,
b
)
=>
{
.
sort
((
a
,
b
)
=>
{
return
a
.
name
.
localeCompare
(
b
.
name
);
return
a
.
name
.
localeCompare
(
b
.
name
);
...
...
src/lib/apis/model
file
s/index.ts
→
src/lib/apis/models/index.ts
View file @
b1265c9c
import
{
WEBUI_API_BASE_URL
}
from
'
$lib/constants
'
;
import
{
WEBUI_API_BASE_URL
}
from
'
$lib/constants
'
;
export
const
create
NewModel
file
=
async
(
token
:
string
,
model
file
:
object
)
=>
{
export
const
add
NewModel
=
async
(
token
:
string
,
model
:
object
)
=>
{
let
error
=
null
;
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
files/create
`
,
{
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
s/add
`
,
{
method
:
'
POST
'
,
method
:
'
POST
'
,
headers
:
{
headers
:
{
Accept
:
'
application/json
'
,
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
authorization
:
`Bearer
${
token
}
`
authorization
:
`Bearer
${
token
}
`
},
},
body
:
JSON
.
stringify
({
body
:
JSON
.
stringify
(
model
)
modelfile
:
modelfile
})
})
})
.
then
(
async
(
res
)
=>
{
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
if
(
!
res
.
ok
)
throw
await
res
.
json
();
...
@@ -31,10 +29,10 @@ export const createNewModelfile = async (token: string, modelfile: object) => {
...
@@ -31,10 +29,10 @@ export const createNewModelfile = async (token: string, modelfile: object) => {
return
res
;
return
res
;
};
};
export
const
getModel
file
s
=
async
(
token
:
string
=
''
)
=>
{
export
const
getModel
Info
s
=
async
(
token
:
string
=
''
)
=>
{
let
error
=
null
;
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
file
s/`
,
{
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/models/`
,
{
method
:
'
GET
'
,
method
:
'
GET
'
,
headers
:
{
headers
:
{
Accept
:
'
application/json
'
,
Accept
:
'
application/json
'
,
...
@@ -59,22 +57,19 @@ export const getModelfiles = async (token: string = '') => {
...
@@ -59,22 +57,19 @@ export const getModelfiles = async (token: string = '') => {
throw
error
;
throw
error
;
}
}
return
res
.
map
((
modelfile
)
=>
modelfile
.
modelfile
)
;
return
res
;
};
};
export
const
getModel
fileByTagName
=
async
(
token
:
string
,
tagName
:
string
)
=>
{
export
const
getModel
ById
=
async
(
token
:
string
,
id
:
string
)
=>
{
let
error
=
null
;
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
files/
`
,
{
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
s/
${
id
}
`
,
{
method
:
'
POS
T
'
,
method
:
'
GE
T
'
,
headers
:
{
headers
:
{
Accept
:
'
application/json
'
,
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
authorization
:
`Bearer
${
token
}
`
authorization
:
`Bearer
${
token
}
`
},
}
body
:
JSON
.
stringify
({
tag_name
:
tagName
})
})
})
.
then
(
async
(
res
)
=>
{
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
if
(
!
res
.
ok
)
throw
await
res
.
json
();
...
@@ -94,27 +89,20 @@ export const getModelfileByTagName = async (token: string, tagName: string) => {
...
@@ -94,27 +89,20 @@ export const getModelfileByTagName = async (token: string, tagName: string) => {
throw
error
;
throw
error
;
}
}
return
res
.
modelfile
;
return
res
;
};
};
export
const
updateModelfileByTagName
=
async
(
export
const
updateModelById
=
async
(
token
:
string
,
id
:
string
,
model
:
object
)
=>
{
token
:
string
,
tagName
:
string
,
modelfile
:
object
)
=>
{
let
error
=
null
;
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
files
/update`
,
{
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
s/
${
id
}
/update`
,
{
method
:
'
POST
'
,
method
:
'
POST
'
,
headers
:
{
headers
:
{
Accept
:
'
application/json
'
,
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
authorization
:
`Bearer
${
token
}
`
authorization
:
`Bearer
${
token
}
`
},
},
body
:
JSON
.
stringify
({
body
:
JSON
.
stringify
(
model
)
tag_name
:
tagName
,
modelfile
:
modelfile
})
})
})
.
then
(
async
(
res
)
=>
{
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
if
(
!
res
.
ok
)
throw
await
res
.
json
();
...
@@ -137,19 +125,16 @@ export const updateModelfileByTagName = async (
...
@@ -137,19 +125,16 @@ export const updateModelfileByTagName = async (
return
res
;
return
res
;
};
};
export
const
deleteModel
fileByTagName
=
async
(
token
:
string
,
tagName
:
string
)
=>
{
export
const
deleteModel
ById
=
async
(
token
:
string
,
id
:
string
)
=>
{
let
error
=
null
;
let
error
=
null
;
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
files
/delete`
,
{
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/model
s/
${
id
}
/delete`
,
{
method
:
'
DELETE
'
,
method
:
'
DELETE
'
,
headers
:
{
headers
:
{
Accept
:
'
application/json
'
,
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
authorization
:
`Bearer
${
token
}
`
authorization
:
`Bearer
${
token
}
`
},
}
body
:
JSON
.
stringify
({
tag_name
:
tagName
})
})
})
.
then
(
async
(
res
)
=>
{
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
if
(
!
res
.
ok
)
throw
await
res
.
json
();
...
...
src/lib/apis/openai/index.ts
View file @
b1265c9c
...
@@ -231,7 +231,12 @@ export const getOpenAIModels = async (token: string = '') => {
...
@@ -231,7 +231,12 @@ export const getOpenAIModels = async (token: string = '') => {
return
models
return
models
?
models
?
models
.
map
((
model
)
=>
({
id
:
model
.
id
,
name
:
model
.
name
??
model
.
id
,
external
:
true
}))
.
map
((
model
)
=>
({
id
:
model
.
id
,
name
:
model
.
name
??
model
.
id
,
external
:
true
,
custom_info
:
model
.
custom_info
}))
.
sort
((
a
,
b
)
=>
{
.
sort
((
a
,
b
)
=>
{
return
a
.
name
.
localeCompare
(
b
.
name
);
return
a
.
name
.
localeCompare
(
b
.
name
);
})
})
...
...
src/lib/components/chat/Chat.svelte
View file @
b1265c9c
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
chatId
,
chatId
,
chats
,
chats
,
config
,
config
,
modelfiles
,
type
Model
,
models
,
models
,
settings
,
settings
,
showSidebar
,
showSidebar
,
...
@@ -65,28 +65,10 @@
...
@@ -65,28 +65,10 @@
let
showModelSelector
=
true
;
let
showModelSelector
=
true
;
let
selectedModels
=
[
''
];
let
selectedModels
=
[
''
];
let
atSelectedModel
=
''
;
let
atSelectedModel
:
Model
|
undefined
;
let
useWebSearch
=
false
;
let
useWebSearch
=
false
;
let
selectedModelfile
=
null
;
$:
selectedModelfile
=
selectedModels
.
length
===
1
&&
$
modelfiles
.
filter
((
modelfile
)
=>
modelfile
.
tagName
===
selectedModels
[
0
]).
length
>
0
?
$
modelfiles
.
filter
((
modelfile
)
=>
modelfile
.
tagName
===
selectedModels
[
0
])[
0
]
:
null
;
let
selectedModelfiles
=
{};
$:
selectedModelfiles
=
selectedModels
.
reduce
((
a
,
tagName
,
i
,
arr
)
=>
{
const
modelfile
=
$
modelfiles
.
filter
((
modelfile
)
=>
modelfile
.
tagName
===
tagName
)?.
at
(
0
)
??
undefined
;
return
{
...
a
,
...(
modelfile
&&
{
[
tagName
]:
modelfile
})
};
},
{});
let
chat
=
null
;
let
chat
=
null
;
let
tags
=
[];
let
tags
=
[];
...
@@ -171,6 +153,7 @@
...
@@ -171,6 +153,7 @@
if
($
page
.
url
.
searchParams
.
get
(
'q'
))
{
if
($
page
.
url
.
searchParams
.
get
(
'q'
))
{
prompt
=
$
page
.
url
.
searchParams
.
get
(
'q'
)
??
''
;
prompt
=
$
page
.
url
.
searchParams
.
get
(
'q'
)
??
''
;
if
(
prompt
)
{
if
(
prompt
)
{
await
tick
();
await
tick
();
submitPrompt
(
prompt
);
submitPrompt
(
prompt
);
...
@@ -218,7 +201,7 @@
...
@@ -218,7 +201,7 @@
await
settings
.
set
({
await
settings
.
set
({
...
_settings
,
...
_settings
,
system
:
chatContent
.
system
??
_settings
.
system
,
system
:
chatContent
.
system
??
_settings
.
system
,
option
s
:
chatContent
.
options
??
_settings
.
option
s
param
s
:
chatContent
.
options
??
_settings
.
param
s
});
});
autoScroll
=
true
;
autoScroll
=
true
;
await
tick
();
await
tick
();
...
@@ -307,7 +290,7 @@
...
@@ -307,7 +290,7 @@
models
:
selectedModels
,
models
:
selectedModels
,
system
:
$
settings
.
system
??
undefined
,
system
:
$
settings
.
system
??
undefined
,
options
:
{
options
:
{
...($
settings
.
option
s
??
{})
...($
settings
.
param
s
??
{})
},
},
messages
:
messages
,
messages
:
messages
,
history
:
history
,
history
:
history
,
...
@@ -324,6 +307,7 @@
...
@@ -324,6 +307,7 @@
//
Reset
chat
input
textarea
//
Reset
chat
input
textarea
prompt
=
''
;
prompt
=
''
;
document
.
getElementById
(
'chat-textarea'
).
style
.
height
=
''
;
files
=
[];
files
=
[];
//
Send
prompt
//
Send
prompt
...
@@ -335,79 +319,96 @@
...
@@ -335,79 +319,96 @@
const
_chatId
=
JSON
.
parse
(
JSON
.
stringify
($
chatId
));
const
_chatId
=
JSON
.
parse
(
JSON
.
stringify
($
chatId
));
await
Promise
.
all
(
await
Promise
.
all
(
(
modelId
?
[
modelId
]
:
atSelectedModel
!== '' ? [atSelectedModel.id] : selectedModels).map(
(
modelId
async
(
modelId
)
=>
{
?
[
modelId
]
console
.
log
(
'modelId'
,
modelId
);
:
atSelectedModel
!== undefined
const
model
=
$
models
.
filter
((
m
)
=>
m
.
id
===
modelId
).
at
(
0
);
?
[
atSelectedModel
.
id
]
:
selectedModels
if
(
model
)
{
).
map
(
async
(
modelId
)
=>
{
//
Create
response
message
console
.
log
(
'modelId'
,
modelId
);
let
responseMessageId
=
uuidv4
();
const
model
=
$
models
.
filter
((
m
)
=>
m
.
id
===
modelId
).
at
(
0
);
let
responseMessage
=
{
parentId
:
parentId
,
if
(
model
)
{
id
:
responseMessageId
,
//
If
there
are
image
files
,
check
if
model
is
vision
capable
childrenIds
:
[],
const
hasImages
=
messages
.
some
((
message
)
=>
role
:
'assistant'
,
message
.
files
?.
some
((
file
)
=>
file
.
type
===
'image'
)
content
:
''
,
);
model
:
model
.
id
,
userContext
:
null
,
timestamp
:
Math
.
floor
(
Date
.
now
()
/
1000
)
//
Unix
epoch
};
//
Add
message
to
history
and
Set
currentId
to
messageId
history
.
messages
[
responseMessageId
]
=
responseMessage
;
history
.
currentId
=
responseMessageId
;
//
Append
messageId
to
childrenIds
of
parent
message
if
(
parentId
!== null) {
history
.
messages
[
parentId
].
childrenIds
=
[
...
history
.
messages
[
parentId
].
childrenIds
,
responseMessageId
];
}
await
tick
();
if
(
hasImages
&&
!(model.info?.meta?.capabilities?.vision ?? true)) {
toast
.
error
(
$
i18n
.
t
(
'Model {{modelName}} is not vision capable'
,
{
modelName
:
model
.
name
??
model
.
id
})
);
}
let
userContext
=
null
;
//
Create
response
message
if
($
settings
?.
memory
??
false
)
{
let
responseMessageId
=
uuidv4
();
if
(
userContext
===
null
)
{
let
responseMessage
=
{
const
res
=
await
queryMemory
(
localStorage
.
token
,
prompt
).
catch
((
error
)
=>
{
parentId
:
parentId
,
toast
.
error
(
error
);
id
:
responseMessageId
,
return
null
;
childrenIds
:
[],
});
role
:
'assistant'
,
content
:
''
,
if
(
res
)
{
model
:
model
.
id
,
if
(
res
.
documents
[
0
].
length
>
0
)
{
modelName
:
model
.
name
??
model
.
id
,
userContext
=
res
.
documents
.
reduce
((
acc
,
doc
,
index
)
=>
{
userContext
:
null
,
const
createdAtTimestamp
=
res
.
metadatas
[
index
][
0
].
created_at
;
timestamp
:
Math
.
floor
(
Date
.
now
()
/
1000
)
//
Unix
epoch
const
createdAtDate
=
new
Date
(
createdAtTimestamp
*
1000
)
};
.
toISOString
()
.
split
(
'T'
)[
0
];
//
Add
message
to
history
and
Set
currentId
to
messageId
acc
.
push
(`${
index
+
1
}.
[${
createdAtDate
}].
${
doc
[
0
]}`);
history
.
messages
[
responseMessageId
]
=
responseMessage
;
return
acc
;
history
.
currentId
=
responseMessageId
;
},
[]);
}
//
Append
messageId
to
childrenIds
of
parent
message
if
(
parentId
!== null) {
history
.
messages
[
parentId
].
childrenIds
=
[
...
history
.
messages
[
parentId
].
childrenIds
,
responseMessageId
];
}
console
.
log
(
userContext
);
await
tick
();
let
userContext
=
null
;
if
($
settings
?.
memory
??
false
)
{
if
(
userContext
===
null
)
{
const
res
=
await
queryMemory
(
localStorage
.
token
,
prompt
).
catch
((
error
)
=>
{
toast
.
error
(
error
);
return
null
;
});
if
(
res
)
{
if
(
res
.
documents
[
0
].
length
>
0
)
{
userContext
=
res
.
documents
.
reduce
((
acc
,
doc
,
index
)
=>
{
const
createdAtTimestamp
=
res
.
metadatas
[
index
][
0
].
created_at
;
const
createdAtDate
=
new
Date
(
createdAtTimestamp
*
1000
)
.
toISOString
()
.
split
(
'T'
)[
0
];
acc
.
push
(`${
index
+
1
}.
[${
createdAtDate
}].
${
doc
[
0
]}`);
return
acc
;
},
[]);
}
}
console
.
log
(
userContext
);
}
}
}
}
responseMessage
.
userContext
=
userContext
;
}
responseMessage
.
userContext
=
userContext
;
if
(
useWebSearch
)
{
if
(
useWebSearch
)
{
await
runWebSearchForPrompt
(
model
.
id
,
parentId
,
responseMessageId
);
await
runWebSearchForPrompt
(
model
.
id
,
parentId
,
responseMessageId
);
}
}
if
(
model
?.
external
)
{
if
(
model
?.
owned_by
===
'openai'
)
{
await
sendPromptOpenAI
(
model
,
prompt
,
responseMessageId
,
_chatId
);
await
sendPromptOpenAI
(
model
,
prompt
,
responseMessageId
,
_chatId
);
}
else
if
(
model
)
{
}
else
if
(
model
)
{
await
sendPromptOllama
(
model
,
prompt
,
responseMessageId
,
_chatId
);
await
sendPromptOllama
(
model
,
prompt
,
responseMessageId
,
_chatId
);
}
}
else
{
toast
.
error
($
i18n
.
t
(`
Model
{{
modelId
}}
not
found
`,
{
modelId
}));
}
}
}
else
{
toast
.
error
($
i18n
.
t
(`
Model
{{
modelId
}}
not
found
`,
{
modelId
}));
}
}
)
}
)
);
);
await
chats
.
set
(
await
getChatList
(
localStorage
.
token
));
await
chats
.
set
(
await
getChatList
(
localStorage
.
token
));
...
@@ -476,7 +477,7 @@
...
@@ -476,7 +477,7 @@
//
Prepare
the
base
message
object
//
Prepare
the
base
message
object
const
baseMessage
=
{
const
baseMessage
=
{
role
:
message
.
role
,
role
:
message
.
role
,
content
:
arr
.
length
-
2
!== idx ? message.content : message?.raContent ??
message.content
content
:
message
.
content
};
};
//
Extract
and
format
image
URLs
if
any
exist
//
Extract
and
format
image
URLs
if
any
exist
...
@@ -488,7 +489,6 @@
...
@@ -488,7 +489,6 @@
if
(
imageUrls
&&
imageUrls
.
length
>
0
&&
message
.
role
===
'user'
)
{
if
(
imageUrls
&&
imageUrls
.
length
>
0
&&
message
.
role
===
'user'
)
{
baseMessage
.
images
=
imageUrls
;
baseMessage
.
images
=
imageUrls
;
}
}
return
baseMessage
;
return
baseMessage
;
});
});
...
@@ -519,13 +519,15 @@
...
@@ -519,13 +519,15 @@
model
:
model
,
model
:
model
,
messages
:
messagesBody
,
messages
:
messagesBody
,
options
:
{
options
:
{
...($
settings
.
option
s
??
{}),
...($
settings
.
param
s
??
{}),
stop
:
stop
:
$
settings
?.
option
s
?.
stop
??
undefined
$
settings
?.
param
s
?.
stop
??
undefined
?
$
settings
.
option
s
.
stop
.
map
((
str
)
=>
?
$
settings
.
param
s
.
stop
.
map
((
str
)
=>
decodeURIComponent
(
JSON
.
parse
(
'"'
+
str
.
replace
(/\
"/g, '
\\
"
') + '
"'))
decodeURIComponent
(
JSON
.
parse
(
'"'
+
str
.
replace
(/\
"/g, '
\\
"
') + '
"'))
)
)
: undefined
: undefined,
num_predict: $settings?.params?.max_tokens ?? undefined,
repeat_penalty: $settings?.params?.frequency_penalty ?? undefined
},
},
format: $settings.requestFormat ?? undefined,
format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined,
keep_alive: $settings.keepAlive ?? undefined,
...
@@ -651,7 +653,8 @@
...
@@ -651,7 +653,8 @@
if ($settings.saveChatHistory ?? true) {
if ($settings.saveChatHistory ?? true) {
chat = await updateChatById(localStorage.token, _chatId, {
chat = await updateChatById(localStorage.token, _chatId, {
messages: messages,
messages: messages,
history: history
history: history,
models: selectedModels
});
});
await chats.set(await getChatList(localStorage.token));
await chats.set(await getChatList(localStorage.token));
}
}
...
@@ -762,18 +765,17 @@
...
@@ -762,18 +765,17 @@
: message?.raContent ?? message.content
: message?.raContent ?? message.content
})
})
})),
})),
seed: $settings?.
option
s?.seed ?? undefined,
seed: $settings?.
param
s?.seed ?? undefined,
stop:
stop:
$settings?.
option
s?.stop ?? undefined
$settings?.
param
s?.stop ?? undefined
? $settings.
option
s.stop.map((str) =>
? $settings.
param
s.stop.map((str) =>
decodeURIComponent(JSON.parse('"
' + str.replace(/\"/g, '
\\
"') + '"
'))
decodeURIComponent(JSON.parse('"
' + str.replace(/\"/g, '
\\
"') + '"
'))
)
)
: undefined,
: undefined,
temperature: $settings?.options?.temperature ?? undefined,
temperature: $settings?.params?.temperature ?? undefined,
top_p: $settings?.options?.top_p ?? undefined,
top_p: $settings?.params?.top_p ?? undefined,
num_ctx: $settings?.options?.num_ctx ?? undefined,
frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
max_tokens: $settings?.params?.max_tokens ?? undefined,
max_tokens: $settings?.options?.num_predict ?? undefined,
docs: docs.length > 0 ? docs : undefined,
docs: docs.length > 0 ? docs : undefined,
citations: docs.length > 0
citations: docs.length > 0
},
},
...
@@ -843,6 +845,7 @@
...
@@ -843,6 +845,7 @@
if ($chatId == _chatId) {
if ($chatId == _chatId) {
if ($settings.saveChatHistory ?? true) {
if ($settings.saveChatHistory ?? true) {
chat = await updateChatById(localStorage.token, _chatId, {
chat = await updateChatById(localStorage.token, _chatId, {
models: selectedModels,
messages: messages,
messages: messages,
history: history
history: history
});
});
...
@@ -981,10 +984,8 @@
...
@@ -981,10 +984,8 @@
) + '
{{
prompt
}}
',
) + '
{{
prompt
}}
',
titleModelId,
titleModelId,
userPrompt,
userPrompt,
titleModel?.external ?? false
titleModel?.owned_by === '
openai
' ?? false
? titleModel?.source?.toLowerCase() === '
litellm
'
? `${OPENAI_API_BASE_URL}`
? `${LITELLM_API_BASE_URL}/v1`
: `${OPENAI_API_BASE_URL}`
: `${OLLAMA_API_BASE_URL}/v1`
: `${OLLAMA_API_BASE_URL}/v1`
);
);
...
@@ -1011,10 +1012,8 @@
...
@@ -1011,10 +1012,8 @@
taskModelId,
taskModelId,
previousMessages,
previousMessages,
userPrompt,
userPrompt,
taskModel?.external ?? false
taskModel?.owned_by === '
openai
' ?? false
? taskModel?.source?.toLowerCase() === '
litellm
'
? `${OPENAI_API_BASE_URL}`
? `${LITELLM_API_BASE_URL}/v1`
: `${OPENAI_API_BASE_URL}`
: `${OLLAMA_API_BASE_URL}/v1`
: `${OLLAMA_API_BASE_URL}/v1`
);
);
};
};
...
@@ -1096,16 +1095,12 @@
...
@@ -1096,16 +1095,12 @@
<Messages
<Messages
chatId={$chatId}
chatId={$chatId}
{selectedModels}
{selectedModels}
{selectedModelfiles}
{processing}
{processing}
bind:history
bind:history
bind:messages
bind:messages
bind:autoScroll
bind:autoScroll
bind:prompt
bind:prompt
bottomPadding={files.length > 0}
bottomPadding={files.length > 0}
suggestionPrompts={chatIdProp
? []
: selectedModelfile?.suggestionPrompts ?? $config.default_prompt_suggestions}
{sendPrompt}
{sendPrompt}
{continueGeneration}
{continueGeneration}
{regenerateResponse}
{regenerateResponse}
...
@@ -1119,8 +1114,9 @@
...
@@ -1119,8 +1114,9 @@
bind:files
bind:files
bind:prompt
bind:prompt
bind:autoScroll
bind:autoScroll
bind:selectedModel={atSelectedModel}
bind:useWebSearch
bind:useWebSearch
bind:atSelectedModel
{selectedModels}
{messages}
{messages}
{submitPrompt}
{submitPrompt}
{stopResponse}
{stopResponse}
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment