chenpangpang / open-webui, commit 9c2429ff (unverified)

Authored Aug 12, 2024 by Timothy Jaeryang Baek; committed via GitHub on Aug 12, 2024.

Merge pull request #4402 from michaelpoluektov/remove-ollama

refactor: re-use utils in Ollama

Parents: d0645d3c, 547611b7
Showing 5 changed files with 139 additions and 320 deletions:

backend/apps/images/utils/comfyui.py  (+3, -2)
backend/apps/ollama/main.py           (+84, -309)
backend/apps/openai/main.py           (+5, -2)
backend/apps/webui/main.py            (+2, -2)
backend/utils/misc.py                 (+45, -5)
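In short: most of the 309 lines deleted from backend/apps/ollama/main.py were hand-rolled, per-parameter copying, now replaced by shared helpers in backend/utils/misc.py. A minimal sketch of that mapping-table pattern (standalone names, a simplification rather than the project's exact code):

```python
from typing import Callable

def apply_params(params: dict, body: dict, mappings: dict[str, Callable]) -> dict:
    # Copy each known parameter into the request body, casting it on the way.
    for key, cast in mappings.items():
        if (value := params.get(key)) is not None:
            body[key] = cast(value)
    return body

print(apply_params({"temperature": "0.7"}, {}, {"temperature": float}))
# {'temperature': 0.7}
```

One mapping table per API flavor then replaces a long chain of near-identical `if` blocks, which is exactly the shape of the diffs below.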
backend/apps/images/utils/comfyui.py

```diff
 import asyncio
 import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
-import uuid
 import json
 import urllib.request
 import urllib.parse
...
@@ -398,7 +397,9 @@ async def comfyui_generate_image(
         return None

     try:
-        images = await asyncio.to_thread(get_images, ws, comfyui_prompt, client_id, base_url)
+        images = await asyncio.to_thread(
+            get_images, ws, comfyui_prompt, client_id, base_url
+        )
     except Exception as e:
         log.exception(f"Error while receiving images: {e}")
         images = None
...
```
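The visible change here appears to only drop an unused import and rewrap the asyncio.to_thread call, which keeps the blocking websocket receive off the event loop. For context, a self-contained sketch of the same offloading pattern (fetch_blocking is a hypothetical stand-in for the blocking call):

```python
import asyncio
import time

def fetch_blocking(url: str) -> str:
    # Stand-in for a blocking call such as a synchronous websocket receive.
    time.sleep(1)
    return f"payload from {url}"

async def main() -> None:
    try:
        # Run the blocking call in a worker thread so the event loop stays free.
        result = await asyncio.to_thread(fetch_blocking, "http://example.local")
    except Exception:
        result = None
    print(result)

asyncio.run(main())
```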
backend/apps/ollama/main.py

```diff
 from fastapi import (
     FastAPI,
     Request,
-    Response,
     HTTPException,
     Depends,
-    status,
     UploadFile,
     File,
-    BackgroundTasks,
 )
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from fastapi.concurrency import run_in_threadpool
 from pydantic import BaseModel, ConfigDict

 import os
 import re
-import copy
 import random
 import requests
 import json
-import uuid
 import aiohttp
 import asyncio
 import logging
...
@@ -32,16 +26,11 @@ from typing import Optional, List, Union
 from starlette.background import BackgroundTask

 from apps.webui.models.models import Models
-from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    decode_token,
-    get_current_user,
     get_verified_user,
     get_admin_user,
 )
-from utils.task import prompt_template

 from config import (
     SRC_LOG_LEVELS,
...
@@ -53,7 +42,12 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import calculate_sha256, add_or_update_system_message
+from utils.misc import (
+    apply_model_params_to_body_ollama,
+    calculate_sha256,
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
...
@@ -183,7 +177,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
                 res = await r.json()
                 if "error" in res:
                     error_detail = f"Ollama: {res['error']}"
-            except:
+            except Exception:
                 error_detail = f"Ollama: {e}"

             raise HTTPException(
...
@@ -238,7 +232,7 @@ async def get_all_models():
 async def get_ollama_tags(
     url_idx: Optional[int] = None, user=Depends(get_verified_user)
 ):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()

         if app.state.config.ENABLE_MODEL_FILTER:
...
@@ -269,7 +263,7 @@ async def get_ollama_tags(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
...
@@ -282,8 +276,7 @@ async def get_ollama_tags(
 @app.get("/api/version/{url_idx}")
 async def get_ollama_versions(url_idx: Optional[int] = None):
     if app.state.config.ENABLE_OLLAMA_API:
-        if url_idx == None:
+        if url_idx is None:
             # returns lowest version
             tasks = [
                 fetch_url(f"{url}/api/version")
...
@@ -323,7 +316,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
...
@@ -346,8 +339,6 @@ async def pull_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = None
-
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
...
@@ -367,7 +358,7 @@ async def push_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.name in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.name]["urls"][0]
         else:
...
@@ -417,7 +408,7 @@ async def copy_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.source in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.source]["urls"][0]
         else:
...
@@ -428,13 +419,13 @@ async def copy_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/copy",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/copy",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         log.debug(f"r.text: {r.text}")
...
@@ -448,7 +439,7 @@ async def copy_model(
         res = r.json()
         if "error" in res:
             error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

     raise HTTPException(
...
@@ -464,7 +455,7 @@ async def delete_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         if form_data.name in app.state.MODELS:
             url_idx = app.state.MODELS[form_data.name]["urls"][0]
         else:
...
@@ -476,12 +467,12 @@ async def delete_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="DELETE",
-        url=f"{url}/api/delete",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="DELETE",
+            url=f"{url}/api/delete",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         log.debug(f"r.text: {r.text}")
...
@@ -495,7 +486,7 @@ async def delete_model(
         res = r.json()
         if "error" in res:
             error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

     raise HTTPException(
...
@@ -516,12 +507,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/show",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/show",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         return r.json()
...
@@ -533,7 +524,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
         res = r.json()
         if "error" in res:
             error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

     raise HTTPException(
...
@@ -556,7 +547,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
...
@@ -573,12 +564,12 @@ async def generate_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         return r.json()
...
@@ -590,7 +581,7 @@ async def generate_embeddings(
         res = r.json()
         if "error" in res:
             error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

     raise HTTPException(
...
@@ -603,10 +594,9 @@ def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
     log.info(f"generate_ollama_embeddings {form_data}")

-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
...
@@ -623,12 +613,12 @@ def generate_ollama_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embeddings",
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
     try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            data=form_data.model_dump_json(exclude_none=True).encode(),
+        )
         r.raise_for_status()

         data = r.json()
...
@@ -638,7 +628,7 @@ def generate_ollama_embeddings(
         if "embedding" in data:
             return data["embedding"]
         else:
-            raise "Something went wrong :/"
+            raise Exception("Something went wrong :/")
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
...
@@ -647,10 +637,10 @@ def generate_ollama_embeddings(
         res = r.json()
         if "error" in res:
             error_detail = f"Ollama: {res['error']}"
-    except:
+    except Exception:
         error_detail = f"Ollama: {e}"

-    raise error_detail
+    raise Exception(error_detail)


 class GenerateCompletionForm(BaseModel):
...
@@ -674,8 +664,7 @@ async def generate_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         model = form_data.model

         if ":" not in model:
...
@@ -713,6 +702,18 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None


+def get_ollama_url(url_idx: Optional[int], model: str):
+    if url_idx is None:
+        if model not in app.state.MODELS:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
+            )
+        url_idx = random.choice(app.state.MODELS[model]["urls"])
+
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    return url
+
+
 @app.post("/api/chat")
 @app.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
...
@@ -720,12 +721,7 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")

     payload = {
         **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
...
@@ -740,185 +736,21 @@ async def generate_chat_completion(
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
+        params = model_info.params.model_dump()

-        if model_info.params:
+        if params:
             if payload.get("options") is None:
                 payload["options"] = {}

-            if (
-                model_info.params.get("mirostat", None)
-                and payload["options"].get("mirostat") is None
-            ):
-                payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
-            if (
-                model_info.params.get("mirostat_eta", None)
-                and payload["options"].get("mirostat_eta") is None
-            ):
-                payload["options"]["mirostat_eta"] = model_info.params.get("mirostat_eta", None)
-            if (
-                model_info.params.get("mirostat_tau", None)
-                and payload["options"].get("mirostat_tau") is None
-            ):
-                payload["options"]["mirostat_tau"] = model_info.params.get("mirostat_tau", None)
-            if (
-                model_info.params.get("num_ctx", None)
-                and payload["options"].get("num_ctx") is None
-            ):
-                payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
-            if (
-                model_info.params.get("num_batch", None)
-                and payload["options"].get("num_batch") is None
-            ):
-                payload["options"]["num_batch"] = model_info.params.get("num_batch", None)
-            if (
-                model_info.params.get("num_keep", None)
-                and payload["options"].get("num_keep") is None
-            ):
-                payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
-            if (
-                model_info.params.get("repeat_last_n", None)
-                and payload["options"].get("repeat_last_n") is None
-            ):
-                payload["options"]["repeat_last_n"] = model_info.params.get("repeat_last_n", None)
-            if (
-                model_info.params.get("frequency_penalty", None)
-                and payload["options"].get("frequency_penalty") is None
-            ):
-                payload["options"]["repeat_penalty"] = model_info.params.get("frequency_penalty", None)
-            if (
-                model_info.params.get("temperature", None) is not None
-                and payload["options"].get("temperature") is None
-            ):
-                payload["options"]["temperature"] = model_info.params.get("temperature", None)
-            if (
-                model_info.params.get("seed", None) is not None
-                and payload["options"].get("seed") is None
-            ):
-                payload["options"]["seed"] = model_info.params.get("seed", None)
-            if (
-                model_info.params.get("stop", None)
-                and payload["options"].get("stop") is None
-            ):
-                payload["options"]["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in model_info.params["stop"]
-                    ]
-                    if model_info.params.get("stop", None)
-                    else None
-                )
-            if (
-                model_info.params.get("tfs_z", None)
-                and payload["options"].get("tfs_z") is None
-            ):
-                payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
-            if (
-                model_info.params.get("max_tokens", None)
-                and payload["options"].get("max_tokens") is None
-            ):
-                payload["options"]["num_predict"] = model_info.params.get("max_tokens", None)
-            if (
-                model_info.params.get("top_k", None)
-                and payload["options"].get("top_k") is None
-            ):
-                payload["options"]["top_k"] = model_info.params.get("top_k", None)
-            if (
-                model_info.params.get("top_p", None)
-                and payload["options"].get("top_p") is None
-            ):
-                payload["options"]["top_p"] = model_info.params.get("top_p", None)
-            if (
-                model_info.params.get("min_p", None)
-                and payload["options"].get("min_p") is None
-            ):
-                payload["options"]["min_p"] = model_info.params.get("min_p", None)
-            if (
-                model_info.params.get("use_mmap", None)
-                and payload["options"].get("use_mmap") is None
-            ):
-                payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
-            if (
-                model_info.params.get("use_mlock", None)
-                and payload["options"].get("use_mlock") is None
-            ):
-                payload["options"]["use_mlock"] = model_info.params.get("use_mlock", None)
-            if (
-                model_info.params.get("num_thread", None)
-                and payload["options"].get("num_thread") is None
-            ):
-                payload["options"]["num_thread"] = model_info.params.get("num_thread", None)
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-
-                if payload.get("messages"):
-                    payload["messages"] = add_or_update_system_message(
-                        system, payload["messages"]
-                    )
-
-    if url_idx == None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+            payload["options"] = apply_model_params_to_body_ollama(
+                params, payload["options"]
+            )
+            payload = apply_model_system_prompt_to_body(params, payload, user)
+
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
+    url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(payload)
...
@@ -952,83 +784,28 @@ async def generate_openai_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    form_data = OpenAIChatCompletionForm(**form_data)
-    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
+    completion_form = OpenAIChatCompletionForm(**form_data)
+    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}

     if "metadata" in payload:
         del payload["metadata"]

-    model_id = form_data.model
+    model_id = completion_form.model
     model_info = Models.get_model_by_id(model_id)

     if model_info:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id

-        model_info.params = model_info.params.model_dump()
-
-        if model_info.params:
-            payload["temperature"] = model_info.params.get("temperature", None)
-            payload["top_p"] = model_info.params.get("top_p", None)
-            payload["max_tokens"] = model_info.params.get("max_tokens", None)
-            payload["frequency_penalty"] = model_info.params.get("frequency_penalty", None)
-            payload["seed"] = model_info.params.get("seed", None)
-            payload["stop"] = (
-                [
-                    bytes(stop, "utf-8").decode("unicode_escape")
-                    for stop in model_info.params["stop"]
-                ]
-                if model_info.params.get("stop", None)
-                else None
-            )
-
-            system = model_info.params.get("system", None)
-
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                # Check if the payload already has a system message
-                # If not, add a system message to the payload
-                if payload.get("messages"):
-                    for message in payload["messages"]:
-                        if message.get("role") == "system":
-                            message["content"] = system + message["content"]
-                            break
-                    else:
-                        payload["messages"].insert(
-                            0,
-                            {
-                                "role": "system",
-                                "content": system,
-                            },
-                        )
-
-    if url_idx == None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+        params = model_info.params.model_dump()
+
+        if params:
+            payload = apply_model_params_to_body_openai(params, payload)
+            payload = apply_model_system_prompt_to_body(params, payload, user)
+
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
+    url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")

     return await post_streaming_url(
...
@@ -1044,7 +821,7 @@ async def get_openai_models(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         models = await get_all_models()

         if app.state.config.ENABLE_MODEL_FILTER:
...
@@ -1099,7 +876,7 @@ async def get_openai_models(
             res = r.json()
             if "error" in res:
                 error_detail = f"Ollama: {res['error']}"
-        except:
+        except Exception:
             error_detail = f"Ollama: {e}"

         raise HTTPException(
...
@@ -1125,7 +902,6 @@ def parse_huggingface_url(hf_url):
     path_components = parsed_url.path.split("/")

     # Extract the desired output
-    user_repo = "/".join(path_components[1:3])
     model_file = path_components[-1]

     return model_file
...
@@ -1190,7 +966,6 @@ async def download_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
     allowed_hosts = ["https://huggingface.co/", "https://github.com/"]

     if not any(form_data.url.startswith(host) for host in allowed_hosts):
...
@@ -1199,7 +974,7 @@ async def download_model(
             detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
         )

-    if url_idx == None:
+    if url_idx is None:
         url_idx = 0
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
...
@@ -1222,7 +997,7 @@ def upload_model(
     url_idx: Optional[int] = None,
     user=Depends(get_admin_user),
 ):
-    if url_idx == None:
+    if url_idx is None:
         url_idx = 0
     ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
...
```
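Both chat endpoints above now delegate backend selection to the new get_ollama_url helper instead of repeating the lookup inline. A standalone approximation of its behavior, with the app state mocked out (the MODELS and OLLAMA_BASE_URLS values below are hypothetical, not the module's real objects):

```python
import random

# Mocked app state: model name -> indices of the Ollama backends serving it.
MODELS = {"llama3:latest": {"urls": [0, 1]}}
OLLAMA_BASE_URLS = ["http://ollama-a:11434", "http://ollama-b:11434"]

def get_ollama_url(url_idx, model):
    # When no backend is pinned, pick a random one that serves the model.
    if url_idx is None:
        if model not in MODELS:
            raise ValueError(f"model not found: {model}")
        url_idx = random.choice(MODELS[model]["urls"])
    return OLLAMA_BASE_URLS[url_idx]

print(get_ollama_url(None, "llama3:latest"))  # one of the two base URLs
```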
backend/apps/openai/main.py

```diff
...
@@ -17,7 +17,10 @@ from utils.utils import (
     get_verified_user,
     get_admin_user,
 )
-from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
+from utils.misc import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)

 from config import (
     SRC_LOG_LEVELS,
...
@@ -368,7 +371,7 @@ async def generate_chat_completion(
         payload["model"] = model_info.base_model_id

         params = model_info.params.model_dump()
-        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_system_prompt_to_body(params, payload, user)

     model = app.state.MODELS[payload.get("model")]
...
```
backend/apps/webui/main.py

```diff
...
@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
 from utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
-    apply_model_params_to_body,
+    apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
 )
...
@@ -291,7 +291,7 @@ async def generate_function_chat_completion(form_data, user):
         form_data["model"] = model_info.base_model_id

         params = model_info.params.model_dump()
-        form_data = apply_model_params_to_body(params, form_data)
+        form_data = apply_model_params_to_body_openai(params, form_data)
         form_data = apply_model_system_prompt_to_body(params, form_data, user)

     pipe_id = get_pipe_id(form_data)
...
```
backend/utils/misc.py

```diff
...
@@ -2,7 +2,7 @@ from pathlib import Path
 import hashlib
 import re
 from datetime import timedelta
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Callable
 import uuid
 import time
...
@@ -135,10 +135,21 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
 # inplace function: form_data is modified
-def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+def apply_model_params_to_body(
+    params: dict, form_data: dict, mappings: dict[str, Callable]
+) -> dict:
     if not params:
         return form_data

+    for key, cast_func in mappings.items():
+        if (value := params.get(key)) is not None:
+            form_data[key] = cast_func(value)
+
+    return form_data
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
     mappings = {
         "temperature": float,
         "top_p": int,
...
@@ -147,10 +158,39 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
         "seed": lambda x: x,
         "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
     }
+    return apply_model_params_to_body(params, form_data, mappings)
+
+
+def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
+    opts = [
+        "temperature",
+        "top_p",
+        "seed",
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "num_batch",
+        "num_keep",
+        "repeat_last_n",
+        "tfs_z",
+        "top_k",
+        "min_p",
+        "use_mmap",
+        "use_mlock",
+        "num_thread",
+    ]
+    mappings = {i: lambda x: x for i in opts}
+    form_data = apply_model_params_to_body(params, form_data, mappings)
+
+    name_differences = {
+        "max_tokens": "num_predict",
+        "frequency_penalty": "repeat_penalty",
+    }

-    for key, cast_func in mappings.items():
-        if (value := params.get(key)) is not None:
-            form_data[key] = cast_func(value)
+    for key, value in name_differences.items():
+        if (param := params.get(key, None)) is not None:
+            form_data[value] = param

     return form_data
...
```
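A quick usage sketch of the two new helpers, assuming the backend directory is on sys.path as it is for the app itself (expected results shown as comments):

```python
from utils.misc import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
)

# OpenAI-style body: known keys are cast, e.g. temperature -> float, and
# "stop" strings get unicode-escape decoded.
print(apply_model_params_to_body_openai({"temperature": 0.2, "stop": ["\\n\\n"]}, {}))
# {'temperature': 0.2, 'stop': ['\n\n']}

# Ollama options: listed keys pass through unchanged, while the renamed pair
# max_tokens -> num_predict and frequency_penalty -> repeat_penalty is mapped.
print(apply_model_params_to_body_ollama({"num_ctx": 4096, "max_tokens": 256}, {}))
# {'num_ctx': 4096, 'num_predict': 256}
```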