chenpangpang / open-webui · Commits · 9c2429ff

Unverified commit 9c2429ff, authored Aug 12, 2024 by Timothy Jaeryang Baek, committed by GitHub on Aug 12, 2024.

Merge pull request #4402 from michaelpoluektov/remove-ollama

refactor: re-use utils in Ollama

Parents: d0645d3c, 547611b7

Showing 5 changed files with 139 additions and 320 deletions:

- backend/apps/images/utils/comfyui.py (+3, -2)
- backend/apps/ollama/main.py (+84, -309)
- backend/apps/openai/main.py (+5, -2)
- backend/apps/webui/main.py (+2, -2)
- backend/utils/misc.py (+45, -5)
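Taken together, the diffs below replace roughly 320 lines of copy-pasted parameter plumbing with a table-driven helper in backend/utils/misc.py. As a rough orientation before the per-file diffs, here is a minimal, self-contained sketch of the pattern the PR introduces (names are abbreviated here for illustration; the real signatures appear in the misc.py hunks at the bottom):

```python
from typing import Callable


def apply_params(params: dict, body: dict, mappings: dict[str, Callable]) -> dict:
    # Copy each mapped parameter into the request body, casting on the way.
    # The walrus assignment skips keys that are missing or explicitly None.
    for key, cast in mappings.items():
        if (value := params.get(key)) is not None:
            body[key] = cast(value)
    return body


# Each backend dialect supplies its own mapping table:
openai_like = {"temperature": float, "top_p": int}

print(apply_params({"temperature": "0.7", "seed": None}, {}, openai_like))
# {'temperature': 0.7}
```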
backend/apps/images/utils/comfyui.py (view file @ 9c2429ff)

```python
import asyncio
import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
```

```diff
@@ -398,7 +397,9 @@ async def comfyui_generate_image(
         return None

     try:
-        images = await asyncio.to_thread(get_images, ws, comfyui_prompt, client_id, base_url)
+        images = await asyncio.to_thread(
+            get_images, ws, comfyui_prompt, client_id, base_url
+        )
     except Exception as e:
         log.exception(f"Error while receiving images: {e}")
         images = None
```
...
...
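The image-collection call here is synchronous, so it is pushed onto a worker thread with asyncio.to_thread and any failure is logged rather than left to crash the coroutine. A minimal sketch of the idiom, with blocking_fetch standing in for get_images:

```python
import asyncio
import logging
import time

log = logging.getLogger(__name__)


def blocking_fetch() -> str:
    time.sleep(0.1)  # stand-in for the blocking websocket receive loop
    return "images"


async def main():
    try:
        # to_thread (Python 3.9+) runs the blocking call off the event loop.
        images = await asyncio.to_thread(blocking_fetch)
    except Exception as e:
        log.exception(f"Error while receiving images: {e}")
        images = None
    print(images)


asyncio.run(main())
```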
backend/apps/ollama/main.py (view file @ 9c2429ff)

```python
from fastapi import (
    FastAPI,
    Request,
    Response,
    HTTPException,
    Depends,
    status,
    UploadFile,
    File,
    BackgroundTasks,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool

from pydantic import BaseModel, ConfigDict

import os
import re
import copy
import random
import requests
import json
import uuid
import aiohttp
import asyncio
import logging
```

...

```diff
@@ -32,16 +26,11 @@ from typing import Optional, List, Union
 from starlette.background import BackgroundTask

 from apps.webui.models.models import Models
-from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    decode_token,
-    get_current_user,
     get_verified_user,
     get_admin_user,
 )
-from utils.task import prompt_template

 from config import (
     SRC_LOG_LEVELS,
```

...

```diff
@@ -53,7 +42,12 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import calculate_sha256, add_or_update_system_message
+from utils.misc import (
+    apply_model_params_to_body_ollama,
+    calculate_sha256,
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
```

...

```diff
@@ -183,7 +177,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
                 res = await r.json()
                 if "error" in res:
                     error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
            error_detail = f"Ollama: {e}"

        raise HTTPException(
```
...
...
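This except: to except Exception: change recurs throughout the file, and the distinction matters: a bare except also traps BaseException subclasses such as SystemExit and KeyboardInterrupt, so Ctrl-C or a deliberate exit can be silently swallowed. A quick demonstration:

```python
import sys

try:
    sys.exit(1)
except:  # noqa: E722 -- a bare except traps SystemExit too
    print("exit swallowed")  # this prints; the interpreter keeps running

try:
    sys.exit(1)
except Exception:
    print("never reached")  # SystemExit derives only from BaseException,
                            # so it propagates and the script exits with status 1
```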
```diff
@@ -238,7 +232,7 @@ async def get_all_models():
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_verified_user)
 ):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()

         if app.state.config.ENABLE_MODEL_FILTER:
```
...
...
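Likewise, the url_idx == None comparisons are tightened to url_idx is None throughout. PEP 8 prescribes the identity test because == dispatches to __eq__, which a type is free to override; is cannot be fooled. For example:

```python
class AlwaysEqual:
    def __eq__(self, other):
        return True  # claims equality with anything, including None


x = AlwaysEqual()
print(x == None)  # True -- misleading        # noqa: E711
print(x is None)  # False -- identity comparison is exact
```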
```diff
@@ -269,7 +263,7 @@ async def get_ollama_tags(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
            error_detail = f"Ollama: {e}"

        raise HTTPException(
```
...
...
```diff
@@ -282,8 +276,7 @@ async def get_ollama_tags(
 @app.get("/api/version/{url_idx}")
 async def get_ollama_versions(url_idx: Optional[int] = None):
     if app.state.config.ENABLE_OLLAMA_API:
-        if url_idx == None:
+        if url_idx is None:
             # returns lowest version
             tasks = [fetch_url(f"{url}/api/version")
```
...
...
```diff
@@ -323,7 +316,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
            error_detail = f"Ollama: {e}"

        raise HTTPException(
```
...
...
```diff
@@ -346,8 +339,6 @@ async def pull_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = None
-
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
```
...
...
```diff
@@ -367,7 +358,7 @@ async def push_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.name in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.name]["urls"][0]
         else:
```
...
...
```diff
@@ -417,7 +408,7 @@ async def copy_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.source in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.source]["urls"][0]
         else:
```
...
...
```diff
@@ -428,13 +419,13 @@ async def copy_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/copy",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/copy",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         log.debug(f"r.text: {r.text}")
```
...
...
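Here (and in the delete, show and embeddings endpoints below, which get the same treatment) the requests call moves inside the try block, so a refused connection is reported through the endpoint's error path instead of escaping as an unhandled exception. A sketch of the difference, with a hypothetical base URL:

```python
import requests

url = "http://localhost:11434"  # hypothetical Ollama base URL

# Before: a ConnectionError raised by requests.request() escaped the handler
# entirely, because only raise_for_status() sat inside the try block.
# After: connection errors and HTTP errors share one error path.
try:
    r = requests.request(method="POST", url=f"{url}/api/copy", data=b"{}")
    r.raise_for_status()
except Exception as e:
    print(f"Ollama: {e}")
```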
```diff
@@ -448,7 +439,7 @@ async def copy_model(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

     raise HTTPException(
```
...
...
```diff
@@ -464,7 +455,7 @@ async def delete_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.name in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.name]["urls"][0]
         else:
```
...
...
```diff
@@ -476,12 +467,12 @@ async def delete_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="DELETE",
-        url=f"{url}/api/delete",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="DELETE",
+            url=f"{url}/api/delete",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         log.debug(f"r.text: {r.text}")
```
...
...
```diff
@@ -495,7 +486,7 @@ async def delete_model(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

     raise HTTPException(
```
...
...
```diff
@@ -516,12 +507,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/show",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/show",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         return r.json()
```
...
...
```diff
@@ -533,7 +524,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

     raise HTTPException(
```
...
...
```diff
@@ -556,7 +547,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
```
...
...
```diff
@@ -573,12 +564,12 @@ async def generate_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         return r.json()
```
...
...
```diff
@@ -590,7 +581,7 @@ async def generate_embeddings(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

     raise HTTPException(
```
...
...
```diff
@@ -603,10 +594,9 @@ def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
     log.info(f"generate_ollama_embeddings {form_data}")

-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
```
...
...
```diff
@@ -623,12 +613,12 @@ def generate_ollama_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         data = r.json()
```
...
...
```diff
@@ -638,7 +628,7 @@ def generate_ollama_embeddings(
         if "embedding" in data:
             return data["embedding"]
         else:
-            raise "Something went wrong :/"
+            raise Exception("Something went wrong :/")
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
```
...
...
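The raise "Something went wrong :/" fix is more than style: in Python 3, raising anything that is not a BaseException instance fails immediately with a TypeError, so the original line could never surface its own message. Quick check:

```python
try:
    raise "Something went wrong :/"  # type: ignore[misc]
except TypeError as e:
    print(e)  # "exceptions must derive from BaseException"
```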
```diff
@@ -647,10 +637,10 @@ def generate_ollama_embeddings(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

-    raise error_detail
+    raise Exception(error_detail)


 class GenerateCompletionForm(BaseModel):
```
...
...
```diff
@@ -674,8 +664,7 @@ async def generate_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
```
...
...
```diff
@@ -713,6 +702,18 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None


+def get_ollama_url(url_idx: Optional[int], model: str):
+    if url_idx is None:
+        if model not in app.state.MODELS:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
+            )
+        url_idx = random.choice(app.state.MODELS[model]["urls"])
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    return url
+
+
 @app.post("/api/chat")
 @app.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
```
...
...
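The new get_ollama_url helper centralizes the "resolve a backend for this model, else 400" logic that both chat endpoints previously inlined. A standalone sketch with stubbed app state (the stub values are illustrative):

```python
import random

MODELS = {"llama3:latest": {"urls": [0, 1]}}             # model -> backend indices
OLLAMA_BASE_URLS = ["http://a:11434", "http://b:11434"]  # index -> base URL


def get_ollama_url(url_idx, model):
    if url_idx is None:
        if model not in MODELS:
            raise ValueError(f"model not found: {model}")  # HTTPException(400) in the app
        url_idx = random.choice(MODELS[model]["urls"])
    return OLLAMA_BASE_URLS[url_idx]


print(get_ollama_url(None, "llama3:latest"))  # picks one of the two backends
print(get_ollama_url(1, "ignored"))           # an explicit index short-circuits
```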
```diff
@@ -720,12 +721,7 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")

     payload = {
         **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
```
...
...
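A side note on the rewritten debug line: its trailing = sits outside the braces, so the output is just the encoded value followed by a literal "=". Python 3.8's self-documenting f-string form puts the = inside the braces and echoes the expression text as well:

```python
data = {"a": 1}
print(f"{data}=")  # {'a': 1}=        (what the new log line produces)
print(f"{data=}")  # data={'a': 1}    (self-documenting debug form)
```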
```diff
@@ -740,185 +736,21 @@ async def generate_chat_completion(
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
+        params = model_info.params.model_dump()

-        if model_info.params:
+        if params:
             if payload.get("options") is None:
                 payload["options"] = {}

-            if (
-                model_info.params.get("mirostat", None)
-                and payload["options"].get("mirostat") is None
-            ):
-                payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
-
-            if (
-                model_info.params.get("mirostat_eta", None)
-                and payload["options"].get("mirostat_eta") is None
-            ):
-                payload["options"]["mirostat_eta"] = model_info.params.get(
-                    "mirostat_eta", None
-                )
-
-            # ... identical hand-written copy blocks removed for mirostat_tau,
-            # num_ctx, num_batch, num_keep, repeat_last_n,
-            # frequency_penalty (-> repeat_penalty), temperature (is not None),
-            # seed (is not None), stop (with unicode_escape decoding), tfs_z,
-            # max_tokens (-> num_predict), top_k, top_p, min_p, use_mmap,
-            # use_mlock and num_thread ...
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-
-                if payload.get("messages"):
-                    payload["messages"] = add_or_update_system_message(
-                        system, payload["messages"]
-                    )
+            payload["options"] = apply_model_params_to_body_ollama(
+                params, payload["options"]
+            )
+            payload = apply_model_system_prompt_to_body(params, payload, user)

-    if url_idx == None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"

-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(payload)
```
...
...
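One subtlety preserved across the rewrite is the stop-sequence handling: both the removed block above and the new mapping tables in utils/misc.py decode stop strings with unicode_escape, so a literal backslash-n typed in the UI becomes a real newline before it reaches the backend:

```python
stops = ["\\n", "###"]  # as stored: the user typed a backslash followed by n
decoded = [bytes(s, "utf-8").decode("unicode_escape") for s in stops]
print(decoded)  # ['\n', '###'] -- the first entry is now a real newline
```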
```diff
@@ -952,83 +784,28 @@ async def generate_openai_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    form_data = OpenAIChatCompletionForm(**form_data)
-    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
+    completion_form = OpenAIChatCompletionForm(**form_data)
+    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
+    if "metadata" in payload:
+        del payload["metadata"]

-    model_id = form_data.model
+    model_id = completion_form.model
     model_info = Models.get_model_by_id(model_id)

     if model_info:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
+        params = model_info.params.model_dump()

-        if model_info.params:
-            payload["temperature"] = model_info.params.get("temperature", None)
-            payload["top_p"] = model_info.params.get("top_p", None)
-            payload["max_tokens"] = model_info.params.get("max_tokens", None)
-            payload["frequency_penalty"] = model_info.params.get(
-                "frequency_penalty", None
-            )
-            payload["seed"] = model_info.params.get("seed", None)
-            payload["stop"] = (
-                [
-                    bytes(stop, "utf-8").decode("unicode_escape")
-                    for stop in model_info.params["stop"]
-                ]
-                if model_info.params.get("stop", None)
-                else None
-            )
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                # Check if the payload already has a system message
-                # If not, add a system message to the payload
-                if payload.get("messages"):
-                    for message in payload["messages"]:
-                        if message.get("role") == "system":
-                            message["content"] = system + message["content"]
-                            break
-                    else:
-                        payload["messages"].insert(
-                            0,
-                            {
-                                "role": "system",
-                                "content": system,
-                            },
-                        )
+        if params:
+            payload = apply_model_params_to_body_openai(params, payload)
+            payload = apply_model_system_prompt_to_body(params, payload, user)

-    if url_idx == None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"

-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")

     return await post_streaming_url(
```
...
...
```diff
@@ -1044,7 +821,7 @@ async def get_openai_models(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()

         if app.state.config.ENABLE_MODEL_FILTER:
```
...
...
```diff
@@ -1099,7 +876,7 @@ async def get_openai_models(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
            error_detail = f"Ollama: {e}"

        raise HTTPException(
```
...
...
```diff
@@ -1125,7 +902,6 @@ def parse_huggingface_url(hf_url):
     path_components = parsed_url.path.split("/")

-    # Extract the desired output
     user_repo = "/".join(path_components[1:3])
     model_file = path_components[-1]

     return model_file
```
...
...
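For context, parse_huggingface_url splits a Hugging Face file URL on its path components; a worked example of the values involved (the sample URL is illustrative):

```python
from urllib.parse import urlparse

hf_url = "https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_K_M.gguf"
path_components = urlparse(hf_url).path.split("/")

user_repo = "/".join(path_components[1:3])  # 'TheBloke/Llama-2-7B-GGUF'
model_file = path_components[-1]            # 'llama-2-7b.Q4_K_M.gguf'
print(user_repo, model_file)
```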
```diff
@@ -1190,7 +966,6 @@ async def download_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
     allowed_hosts = ["https://huggingface.co/", "https://github.com/"]

     if not any(form_data.url.startswith(host) for host in allowed_hosts):
```
...
...
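The host allowlist above relies on a simple prefix check; the trailing slash in each entry is what stops lookalike domains from slipping through. A small sketch:

```python
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]


def is_allowed(url: str) -> bool:
    # Prefix check: the trailing '/' prevents lookalike hosts such as
    # https://huggingface.co.evil.example from passing.
    return any(url.startswith(host) for host in allowed_hosts)


print(is_allowed("https://huggingface.co/TheBloke/x"))      # True
print(is_allowed("https://huggingface.co.evil.example/x"))  # False
```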
```diff
@@ -1199,7 +974,7 @@ async def download_model(
             detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
         )

-    if url_idx == None:
+    if url_idx is None:
         url_idx = 0
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
```
...
...
```diff
@@ -1222,7 +997,7 @@ def upload_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         url_idx = 0
     ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
```
...
...
backend/apps/openai/main.py (view file @ 9c2429ff)
...
...
```diff
@@ -17,7 +17,10 @@ from utils.utils import (
     get_verified_user,
     get_admin_user,
 )
-from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
+from utils.misc import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)

 from config import (
     SRC_LOG_LEVELS,
```
...
...
```diff
@@ -368,7 +371,7 @@ async def generate_chat_completion(
             payload["model"] = model_info.base_model_id

         params = model_info.params.model_dump()
-        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_system_prompt_to_body(params, payload, user)

     model = app.state.MODELS[payload.get("model")]
```
...
...
backend/apps/webui/main.py (view file @ 9c2429ff)
...
...
```diff
@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
 from utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
-    apply_model_params_to_body,
+    apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
 )
```
...
...
```diff
@@ -291,7 +291,7 @@ async def generate_function_chat_completion(form_data, user):
         form_data["model"] = model_info.base_model_id

         params = model_info.params.model_dump()
-        form_data = apply_model_params_to_body(params, form_data)
+        form_data = apply_model_params_to_body_openai(params, form_data)
         form_data = apply_model_system_prompt_to_body(params, form_data, user)

     pipe_id = get_pipe_id(form_data)
```
...
...
backend/utils/misc.py (view file @ 9c2429ff)
...
...
```diff
@@ -2,7 +2,7 @@ from pathlib import Path
 import hashlib
 import re
 from datetime import timedelta
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Callable
 import uuid
 import time
```
...
...
```diff
@@ -135,10 +135,21 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
 # inplace function: form_data is modified
-def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+def apply_model_params_to_body(
+    params: dict, form_data: dict, mappings: dict[str, Callable]
+) -> dict:
     if not params:
         return form_data

+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
+
+    return form_data
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
     mappings = {
         "temperature": float,
         "top_p": int,
```
...
...
```diff
@@ -147,10 +158,39 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
         "seed": lambda x: x,
         "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
     }
+    return apply_model_params_to_body(params, form_data, mappings)
+
+
+def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
+    opts = [
+        "temperature",
+        "top_p",
+        "seed",
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "num_batch",
+        "num_keep",
+        "repeat_last_n",
+        "tfs_z",
+        "top_k",
+        "min_p",
+        "use_mmap",
+        "use_mlock",
+        "num_thread",
+    ]
+    mappings = {i: lambda x: x for i in opts}
+    form_data = apply_model_params_to_body(params, form_data, mappings)
+
+    name_differences = {
+        "max_tokens": "num_predict",
+        "frequency_penalty": "repeat_penalty",
+    }

-    for key, cast_func in mappings.items():
-        if (value := params.get(key)) is not None:
-            form_data[key] = cast_func(value)
+    for key, value in name_differences.items():
+        if (param := params.get(key, None)) is not None:
+            form_data[value] = param

     return form_data
```
...
...
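To close the loop, a quick exercise of the two new wrappers, assuming the definitions from the hunks above are importable (the parameter values here are illustrative):

```python
from utils.misc import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
)

# OpenAI dialect: names pass through, values are cast by the mapping table.
body = apply_model_params_to_body_openai({"temperature": 0.2, "top_p": 1}, {"model": "x"})
print(body)  # {'model': 'x', 'temperature': 0.2, 'top_p': 1}

# Ollama dialect: most options copy verbatim, but max_tokens and
# frequency_penalty are renamed to Ollama's num_predict and repeat_penalty.
options = apply_model_params_to_body_ollama(
    {"temperature": 0.2, "max_tokens": 128, "frequency_penalty": 1.1}, {}
)
print(options)  # {'temperature': 0.2, 'num_predict': 128, 'repeat_penalty': 1.1}
```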