chenpangpang / open-webui / Commits / fc31267a
"...resnet50_tensorflow.git" did not exist on "ad3427f9938c3051b8e12c7f2f4facf6f307743c"
Commit fc31267a authored Aug 06, 2024 by Michael Poluektov

refac: re-use utils.misc

parent 44c781f4
Showing 4 changed files with 85 additions and 229 deletions (+85 -229)
backend/apps/ollama/main.py   +38 -214
backend/apps/openai/main.py   +5 -2
backend/apps/webui/main.py    +2 -2
backend/utils/misc.py         +40 -11
backend/apps/ollama/main.py
@@ -44,7 +44,13 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import calculate_sha256, add_or_update_system_message
+from utils.misc import (
+    apply_model_params_to_body_ollama,
+    calculate_sha256,
+    add_or_update_system_message,
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -699,6 +705,18 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
 
 
+def get_ollama_url(url_idx: Optional[int], model: str):
+    if url_idx is None:
+        if model not in app.state.MODELS:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
+            )
+        url_idx = random.choice(app.state.MODELS[model]["urls"])
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+    return url
+
+
 @app.post("/api/chat")
 @app.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
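The new get_ollama_url helper consolidates backend selection that both completion endpoints previously inlined. A minimal sketch of its behavior in isolation; the SimpleNamespace app state below is a hypothetical stand-in for open-webui's real app.state, not part of the commit:

import random
from types import SimpleNamespace

# Hypothetical stand-in for app.state, for illustration only.
state = SimpleNamespace(
    # model name -> indices into OLLAMA_BASE_URLS that serve it
    MODELS={"llama3:latest": {"urls": [0, 1]}},
    config=SimpleNamespace(
        OLLAMA_BASE_URLS=["http://ollama-a:11434", "http://ollama-b:11434"]
    ),
)


def get_ollama_url(url_idx, model):
    # No pinned backend: validate the model, then pick one of its URLs at random.
    if url_idx is None:
        if model not in state.MODELS:
            raise ValueError(f"model not found: {model}")  # HTTPException(400) in the real code
        url_idx = random.choice(state.MODELS[model]["urls"])
    return state.config.OLLAMA_BASE_URLS[url_idx]


print(get_ollama_url(None, "llama3:latest"))  # one of the two base URLs
print(get_ollama_url(1, "llama3:latest"))     # always http://ollama-b:11434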
@@ -706,17 +724,12 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
 
     payload = {
         **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
     }
-    payload.pop("metadata")
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
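Note that the replacement f-string keeps a stray "=" at the end of the literal; it reads like an attempt at Python 3.8's self-documenting f-string specifier, which puts the = before the closing brace. For comparison (illustrative, not from the commit):

value = 42
print(f"{value}=")   # what the new line does: prints '42='
print(f"{value=}")   # self-documenting form: prints 'value=42'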
@@ -731,148 +744,15 @@ async def generate_chat_completion(
             if payload.get("options") is None:
                 payload["options"] = {}
 
-            if (
-                params.get("mirostat", None)
-                and payload["options"].get("mirostat") is None
-            ):
-                payload["options"]["mirostat"] = params.get("mirostat", None)
-
-            if (
-                params.get("mirostat_eta", None)
-                and payload["options"].get("mirostat_eta") is None
-            ):
-                payload["options"]["mirostat_eta"] = params.get("mirostat_eta", None)
-
-            if (
-                params.get("mirostat_tau", None)
-                and payload["options"].get("mirostat_tau") is None
-            ):
-                payload["options"]["mirostat_tau"] = params.get("mirostat_tau", None)
-
-            if (
-                params.get("num_ctx", None)
-                and payload["options"].get("num_ctx") is None
-            ):
-                payload["options"]["num_ctx"] = params.get("num_ctx", None)
-
-            if (
-                params.get("num_batch", None)
-                and payload["options"].get("num_batch") is None
-            ):
-                payload["options"]["num_batch"] = params.get("num_batch", None)
-
-            if (
-                params.get("num_keep", None)
-                and payload["options"].get("num_keep") is None
-            ):
-                payload["options"]["num_keep"] = params.get("num_keep", None)
-
-            if (
-                params.get("repeat_last_n", None)
-                and payload["options"].get("repeat_last_n") is None
-            ):
-                payload["options"]["repeat_last_n"] = params.get("repeat_last_n", None)
-
-            if (
-                params.get("frequency_penalty", None)
-                and payload["options"].get("frequency_penalty") is None
-            ):
-                payload["options"]["repeat_penalty"] = params.get("frequency_penalty", None)
-
-            if (
-                params.get("temperature", None) is not None
-                and payload["options"].get("temperature") is None
-            ):
-                payload["options"]["temperature"] = params.get("temperature", None)
-
-            if (
-                params.get("seed", None) is not None
-                and payload["options"].get("seed") is None
-            ):
-                payload["options"]["seed"] = params.get("seed", None)
-
-            if params.get("stop", None) and payload["options"].get("stop") is None:
-                payload["options"]["stop"] = (
-                    [
-                        bytes(stop, "utf-8").decode("unicode_escape")
-                        for stop in params["stop"]
-                    ]
-                    if params.get("stop", None)
-                    else None
-                )
-
-            if params.get("tfs_z", None) and payload["options"].get("tfs_z") is None:
-                payload["options"]["tfs_z"] = params.get("tfs_z", None)
-
-            if (
-                params.get("max_tokens", None)
-                and payload["options"].get("max_tokens") is None
-            ):
-                payload["options"]["num_predict"] = params.get("max_tokens", None)
-
-            if params.get("top_k", None) and payload["options"].get("top_k") is None:
-                payload["options"]["top_k"] = params.get("top_k", None)
-
-            if params.get("top_p", None) and payload["options"].get("top_p") is None:
-                payload["options"]["top_p"] = params.get("top_p", None)
-
-            if params.get("min_p", None) and payload["options"].get("min_p") is None:
-                payload["options"]["min_p"] = params.get("min_p", None)
-
-            if (
-                params.get("use_mmap", None)
-                and payload["options"].get("use_mmap") is None
-            ):
-                payload["options"]["use_mmap"] = params.get("use_mmap", None)
-
-            if (
-                params.get("use_mlock", None)
-                and payload["options"].get("use_mlock") is None
-            ):
-                payload["options"]["use_mlock"] = params.get("use_mlock", None)
-
-            if (
-                params.get("num_thread", None)
-                and payload["options"].get("num_thread") is None
-            ):
-                payload["options"]["num_thread"] = params.get("num_thread", None)
-
-            system = params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-
-                if payload.get("messages"):
-                    payload["messages"] = add_or_update_system_message(
-                        system, payload["messages"]
-                    )
-
-    if url_idx is None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+            payload["options"] = apply_model_params_to_body_ollama(
+                params, payload["options"]
+            )
+            payload = apply_model_system_prompt_to_body(params, payload, user)
+
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
+    url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(payload)
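This hunk is the bulk of the 229 deletions: nineteen near-identical guarded assignments collapse into one data-driven loop (see apply_model_params_to_body in backend/utils/misc.py below). One behavioral detail worth flagging: the removed chain renamed frequency_penalty to Ollama's repeat_penalty and max_tokens to num_predict, while the new shared mappings appear to copy those keys through under their OpenAI names. A toy before/after sketch of the consolidation (values are illustrative, not from the commit):

# Toy illustration of the consolidation; params/options values are made up.
params = {"mirostat": 2, "temperature": 0.7, "stop": ["\\nUser:"]}
options = {}

# Old style: one hand-written guard per parameter, repeated ~19 times.
if params.get("mirostat", None) and options.get("mirostat") is None:
    options["mirostat"] = params.get("mirostat", None)

# New style: a single loop over a key -> cast-function mapping.
mappings = {
    "mirostat": lambda x: x,
    "temperature": float,
    "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
for key, cast_func in mappings.items():
    if (value := params.get(key)) is not None:
        options[key] = cast_func(value)

print(options)  # {'mirostat': 2, 'temperature': 0.7, 'stop': ['\nUser:']}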
@@ -906,83 +786,27 @@ async def generate_openai_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    form_data = OpenAIChatCompletionForm(**form_data)
-    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
-    payload.pop("metadata")
+    completion_form = OpenAIChatCompletionForm(**form_data)
+    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
+    if "metadata" in payload:
+        del payload["metadata"]
 
-    model_id = form_data.model
+    model_id = completion_form.model
     model_info = Models.get_model_by_id(model_id)
 
     if model_info:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
 
-        model_info.params = model_info.params.model_dump()
+        params = model_info.params.model_dump()
 
-        if model_info.params:
-            payload["temperature"] = model_info.params.get("temperature", None)
-            payload["top_p"] = model_info.params.get("top_p", None)
-            payload["max_tokens"] = model_info.params.get("max_tokens", None)
-            payload["frequency_penalty"] = model_info.params.get("frequency_penalty", None)
-            payload["seed"] = model_info.params.get("seed", None)
-            payload["stop"] = (
-                [
-                    bytes(stop, "utf-8").decode("unicode_escape")
-                    for stop in model_info.params["stop"]
-                ]
-                if model_info.params.get("stop", None)
-                else None
-            )
-
-            system = model_info.params.get("system", None)
-            if system:
-                system = prompt_template(
-                    system,
-                    **(
-                        {
-                            "user_name": user.name,
-                            "user_location": (
-                                user.info.get("location") if user.info else None
-                            ),
-                        }
-                        if user
-                        else {}
-                    ),
-                )
-                # Check if the payload already has a system message
-                # If not, add a system message to the payload
-                if payload.get("messages"):
-                    for message in payload["messages"]:
-                        if message.get("role") == "system":
-                            message["content"] = system + message["content"]
-                            break
-                    else:
-                        payload["messages"].insert(
-                            0,
-                            {
-                                "role": "system",
-                                "content": system,
-                            },
-                        )
-
-    if url_idx is None:
-        if ":" not in payload["model"]:
-            payload["model"] = f"{payload['model']}:latest"
-
-        if payload["model"] in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+        if params:
+            payload = apply_model_params_to_body_openai(params, payload)
+            payload = apply_model_system_prompt_to_body(params, payload, user)
+
+    if ":" not in payload["model"]:
+        payload["model"] = f"{payload['model']}:latest"
+
+    url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
 
     return await post_streaming_url(
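apply_model_system_prompt_to_body itself is defined in backend/utils/misc.py and is not shown in this diff; judging from the inline block it replaces, its core merge logic should behave roughly like this sketch (the prompt_template expansion of user name/location variables is elided):

# Sketch of the system-prompt merge the removed inline block performed;
# the real helper lives in backend/utils/misc.py and also runs template
# expansion first. Illustrative only.
def apply_system_prompt(system: str, payload: dict) -> dict:
    messages = payload.get("messages")
    if not system or not messages:
        return payload
    for message in messages:
        if message.get("role") == "system":
            # Prepend to an existing system message...
            message["content"] = system + message["content"]
            break
    else:
        # ...or insert one at the front if none exists.
        messages.insert(0, {"role": "system", "content": system})
    return payload


payload = {"messages": [{"role": "user", "content": "hi"}]}
print(apply_system_prompt("Be terse.", payload)["messages"][0])
# {'role': 'system', 'content': 'Be terse.'}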
backend/apps/openai/main.py
@@ -17,7 +17,10 @@ from utils.utils import (
     get_verified_user,
     get_admin_user,
 )
-from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
+from utils.misc import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
 from config import (
     SRC_LOG_LEVELS,
@@ -366,7 +369,7 @@ async def generate_chat_completion(
             payload["model"] = model_info.base_model_id
 
         params = model_info.params.model_dump()
-        payload = apply_model_params_to_body(params, payload)
+        payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
     model = app.state.MODELS[payload.get("model")]
backend/apps/webui/main.py
@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
 from utils.misc import (
     openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
-    apply_model_params_to_body,
+    apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
 )
@@ -289,7 +289,7 @@ async def generate_function_chat_completion(form_data, user):
         form_data["model"] = model_info.base_model_id
 
     params = model_info.params.model_dump()
-    form_data = apply_model_params_to_body(params, form_data)
+    form_data = apply_model_params_to_body_openai(params, form_data)
     form_data = apply_model_system_prompt_to_body(params, form_data, user)
 
     pipe_id = get_pipe_id(form_data)
backend/utils/misc.py
@@ -2,7 +2,7 @@ from pathlib import Path
 import hashlib
 import re
 from datetime import timedelta
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Callable
 import uuid
 import time
@@ -135,19 +135,12 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
 
 # inplace function: form_data is modified
-def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
+def apply_model_params_to_body(
+    params: dict, form_data: dict, mappings: dict[str, Callable]
+) -> dict:
     if not params:
         return form_data
 
-    mappings = {
-        "temperature": float,
-        "top_p": int,
-        "max_tokens": int,
-        "frequency_penalty": int,
-        "seed": lambda x: x,
-        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
-    }
-
     for key, cast_func in mappings.items():
         # Only set the value if it is provided
         if (value := params.get(key)) is not None:
             form_data[key] = cast_func(value)
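Passing mappings in as an argument is the whole refactor: the loop body is unchanged. The walrus-operator guard applies a parameter whenever it is present and not None, so falsy values like 0 survive, unlike the truthiness checks (params.get(...) and ...) in the removed Ollama chain, which would have dropped them. A quick illustration with a hypothetical mapping:

# Hypothetical mapping demonstrating the walrus guard's semantics.
params = {"temperature": 0, "seed": None, "top_p": 0.9}
form_data = {}
mappings = {"temperature": float, "seed": lambda x: x, "top_p": float}

for key, cast_func in mappings.items():
    if (value := params.get(key)) is not None:
        form_data[key] = cast_func(value)

print(form_data)  # {'temperature': 0.0, 'top_p': 0.9} -- seed=None skipped, 0 kept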
@@ -155,6 +148,42 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
     return form_data
 
 
+OPENAI_MAPPINGS = {
+    "temperature": float,
+    "top_p": int,
+    "max_tokens": int,
+    "frequency_penalty": int,
+    "seed": lambda x: x,
+    "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
+}
+
+
+# inplace function: form_data is modified
+def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
+    return apply_model_params_to_body(params, form_data, OPENAI_MAPPINGS)
+
+
+def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
+    opts = [
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "num_batch",
+        "num_keep",
+        "repeat_last_n",
+        "tfs_z",
+        "top_k",
+        "min_p",
+        "use_mmap",
+        "use_mlock",
+        "num_thread",
+    ]
+    mappings = {i: lambda x: x for i in opts}
+    mappings = {**mappings, **OPENAI_MAPPINGS}
+
+    return apply_model_params_to_body(params, form_data, mappings)
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters
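The two new wrappers differ only in their key sets: the Ollama variant passes the thirteen Ollama-specific options through unchanged and layers the OpenAI casts on top. A usage sketch, assuming the definitions above (toy values, not from the commit):

# Usage sketch for the new wrappers (assumes the definitions above).
openai_body = apply_model_params_to_body_openai(
    {"temperature": 0.7, "stop": ["\\n\\n"]}, {}
)
print(openai_body)  # {'temperature': 0.7, 'stop': ['\n\n']}

ollama_options = apply_model_params_to_body_ollama(
    {"num_ctx": 4096, "temperature": 0.7}, {}
)
print(ollama_options)  # {'num_ctx': 4096, 'temperature': 0.7}

Note that OPENAI_MAPPINGS keeps the int casts for top_p and frequency_penalty verbatim from the removed inline dict; since int(0.9) evaluates to 0, fractional values for those keys truncate. That quirk predates this refactor, but it may be worth a follow-up.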