Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
831fe9f5
Commit
831fe9f5
authored
Aug 06, 2024
by
Michael Poluektov
Browse files
cleanup
parent
a140d319
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
39 deletions
+24
-39
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+24
-39
No files found.
backend/apps/ollama/main.py
View file @
831fe9f5
from
fastapi
import
(
from
fastapi
import
(
FastAPI
,
FastAPI
,
Request
,
Request
,
Response
,
HTTPException
,
HTTPException
,
Depends
,
Depends
,
status
,
UploadFile
,
UploadFile
,
File
,
File
,
BackgroundTasks
,
)
)
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
StreamingResponse
from
fastapi.responses
import
StreamingResponse
from
fastapi.concurrency
import
run_in_threadpool
from
pydantic
import
BaseModel
,
ConfigDict
from
pydantic
import
BaseModel
,
ConfigDict
import
os
import
os
import
re
import
re
import
copy
import
random
import
random
import
requests
import
requests
import
json
import
json
import
uuid
import
aiohttp
import
aiohttp
import
asyncio
import
asyncio
import
logging
import
logging
...
@@ -32,11 +26,8 @@ from typing import Optional, List, Union
...
@@ -32,11 +26,8 @@ from typing import Optional, List, Union
from
starlette.background
import
BackgroundTask
from
starlette.background
import
BackgroundTask
from
apps.webui.models.models
import
Models
from
apps.webui.models.models
import
Models
from
apps.webui.models.users
import
Users
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
utils.utils
import
(
from
utils.utils
import
(
decode_token
,
get_current_user
,
get_verified_user
,
get_verified_user
,
get_admin_user
,
get_admin_user
,
)
)
...
@@ -183,7 +174,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
...
@@ -183,7 +174,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
res
=
await
r
.
json
()
res
=
await
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
raise
HTTPException
(
...
@@ -238,7 +229,7 @@ async def get_all_models():
...
@@ -238,7 +229,7 @@ async def get_all_models():
async
def
get_ollama_tags
(
async
def
get_ollama_tags
(
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
)
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
)
):
):
if
url_idx
==
None
:
if
url_idx
is
None
:
models
=
await
get_all_models
()
models
=
await
get_all_models
()
if
app
.
state
.
config
.
ENABLE_MODEL_FILTER
:
if
app
.
state
.
config
.
ENABLE_MODEL_FILTER
:
...
@@ -269,7 +260,7 @@ async def get_ollama_tags(
...
@@ -269,7 +260,7 @@ async def get_ollama_tags(
res
=
r
.
json
()
res
=
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
raise
HTTPException
(
...
@@ -282,8 +273,7 @@ async def get_ollama_tags(
...
@@ -282,8 +273,7 @@ async def get_ollama_tags(
@
app
.
get
(
"/api/version/{url_idx}"
)
@
app
.
get
(
"/api/version/{url_idx}"
)
async
def
get_ollama_versions
(
url_idx
:
Optional
[
int
]
=
None
):
async
def
get_ollama_versions
(
url_idx
:
Optional
[
int
]
=
None
):
if
app
.
state
.
config
.
ENABLE_OLLAMA_API
:
if
app
.
state
.
config
.
ENABLE_OLLAMA_API
:
if
url_idx
==
None
:
if
url_idx
is
None
:
# returns lowest version
# returns lowest version
tasks
=
[
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/version"
)
fetch_url
(
f
"
{
url
}
/api/version"
)
...
@@ -323,7 +313,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
...
@@ -323,7 +313,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
res
=
r
.
json
()
res
=
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
raise
HTTPException
(
...
@@ -367,7 +357,7 @@ async def push_model(
...
@@ -367,7 +357,7 @@ async def push_model(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_admin_user
),
user
=
Depends
(
get_admin_user
),
):
):
if
url_idx
==
None
:
if
url_idx
is
None
:
if
form_data
.
name
in
app
.
state
.
MODELS
:
if
form_data
.
name
in
app
.
state
.
MODELS
:
url_idx
=
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
][
0
]
url_idx
=
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
][
0
]
else
:
else
:
...
@@ -417,7 +407,7 @@ async def copy_model(
...
@@ -417,7 +407,7 @@ async def copy_model(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_admin_user
),
user
=
Depends
(
get_admin_user
),
):
):
if
url_idx
==
None
:
if
url_idx
is
None
:
if
form_data
.
source
in
app
.
state
.
MODELS
:
if
form_data
.
source
in
app
.
state
.
MODELS
:
url_idx
=
app
.
state
.
MODELS
[
form_data
.
source
][
"urls"
][
0
]
url_idx
=
app
.
state
.
MODELS
[
form_data
.
source
][
"urls"
][
0
]
else
:
else
:
...
@@ -448,7 +438,7 @@ async def copy_model(
...
@@ -448,7 +438,7 @@ async def copy_model(
res
=
r
.
json
()
res
=
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
raise
HTTPException
(
...
@@ -464,7 +454,7 @@ async def delete_model(
...
@@ -464,7 +454,7 @@ async def delete_model(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_admin_user
),
user
=
Depends
(
get_admin_user
),
):
):
if
url_idx
==
None
:
if
url_idx
is
None
:
if
form_data
.
name
in
app
.
state
.
MODELS
:
if
form_data
.
name
in
app
.
state
.
MODELS
:
url_idx
=
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
][
0
]
url_idx
=
app
.
state
.
MODELS
[
form_data
.
name
][
"urls"
][
0
]
else
:
else
:
...
@@ -495,7 +485,7 @@ async def delete_model(
...
@@ -495,7 +485,7 @@ async def delete_model(
res
=
r
.
json
()
res
=
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
raise
HTTPException
(
...
@@ -533,7 +523,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
...
@@ -533,7 +523,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
res
=
r
.
json
()
res
=
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
raise
HTTPException
(
...
@@ -556,7 +546,7 @@ async def generate_embeddings(
...
@@ -556,7 +546,7 @@ async def generate_embeddings(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
user
=
Depends
(
get_verified_user
),
):
):
if
url_idx
==
None
:
if
url_idx
is
None
:
model
=
form_data
.
model
model
=
form_data
.
model
if
":"
not
in
model
:
if
":"
not
in
model
:
...
@@ -590,7 +580,7 @@ async def generate_embeddings(
...
@@ -590,7 +580,7 @@ async def generate_embeddings(
res
=
r
.
json
()
res
=
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
raise
HTTPException
(
...
@@ -603,10 +593,9 @@ def generate_ollama_embeddings(
...
@@ -603,10 +593,9 @@ def generate_ollama_embeddings(
form_data
:
GenerateEmbeddingsForm
,
form_data
:
GenerateEmbeddingsForm
,
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
):
):
log
.
info
(
f
"generate_ollama_embeddings
{
form_data
}
"
)
log
.
info
(
f
"generate_ollama_embeddings
{
form_data
}
"
)
if
url_idx
==
None
:
if
url_idx
is
None
:
model
=
form_data
.
model
model
=
form_data
.
model
if
":"
not
in
model
:
if
":"
not
in
model
:
...
@@ -638,7 +627,7 @@ def generate_ollama_embeddings(
...
@@ -638,7 +627,7 @@ def generate_ollama_embeddings(
if
"embedding"
in
data
:
if
"embedding"
in
data
:
return
data
[
"embedding"
]
return
data
[
"embedding"
]
else
:
else
:
raise
"Something went wrong :/"
raise
Exception
(
"Something went wrong :/"
)
except
Exception
as
e
:
except
Exception
as
e
:
log
.
exception
(
e
)
log
.
exception
(
e
)
error_detail
=
"Open WebUI: Server Connection Error"
error_detail
=
"Open WebUI: Server Connection Error"
...
@@ -647,10 +636,10 @@ def generate_ollama_embeddings(
...
@@ -647,10 +636,10 @@ def generate_ollama_embeddings(
res
=
r
.
json
()
res
=
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
error_detail
raise
Exception
(
error_detail
)
class
GenerateCompletionForm
(
BaseModel
):
class
GenerateCompletionForm
(
BaseModel
):
...
@@ -674,8 +663,7 @@ async def generate_completion(
...
@@ -674,8 +663,7 @@ async def generate_completion(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
user
=
Depends
(
get_verified_user
),
):
):
if
url_idx
is
None
:
if
url_idx
==
None
:
model
=
form_data
.
model
model
=
form_data
.
model
if
":"
not
in
model
:
if
":"
not
in
model
:
...
@@ -720,7 +708,6 @@ async def generate_chat_completion(
...
@@ -720,7 +708,6 @@ async def generate_chat_completion(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
user
=
Depends
(
get_verified_user
),
):
):
log
.
debug
(
log
.
debug
(
"form_data.model_dump_json(exclude_none=True).encode(): {0} "
.
format
(
"form_data.model_dump_json(exclude_none=True).encode(): {0} "
.
format
(
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
()
form_data
.
model_dump_json
(
exclude_none
=
True
).
encode
()
...
@@ -906,7 +893,7 @@ async def generate_chat_completion(
...
@@ -906,7 +893,7 @@ async def generate_chat_completion(
system
,
payload
[
"messages"
]
system
,
payload
[
"messages"
]
)
)
if
url_idx
==
None
:
if
url_idx
is
None
:
if
":"
not
in
payload
[
"model"
]:
if
":"
not
in
payload
[
"model"
]:
payload
[
"model"
]
=
f
"
{
payload
[
'model'
]
}
:latest"
payload
[
"model"
]
=
f
"
{
payload
[
'model'
]
}
:latest"
...
@@ -1016,7 +1003,7 @@ async def generate_openai_chat_completion(
...
@@ -1016,7 +1003,7 @@ async def generate_openai_chat_completion(
},
},
)
)
if
url_idx
==
None
:
if
url_idx
is
None
:
if
":"
not
in
payload
[
"model"
]:
if
":"
not
in
payload
[
"model"
]:
payload
[
"model"
]
=
f
"
{
payload
[
'model'
]
}
:latest"
payload
[
"model"
]
=
f
"
{
payload
[
'model'
]
}
:latest"
...
@@ -1044,7 +1031,7 @@ async def get_openai_models(
...
@@ -1044,7 +1031,7 @@ async def get_openai_models(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_verified_user
),
user
=
Depends
(
get_verified_user
),
):
):
if
url_idx
==
None
:
if
url_idx
is
None
:
models
=
await
get_all_models
()
models
=
await
get_all_models
()
if
app
.
state
.
config
.
ENABLE_MODEL_FILTER
:
if
app
.
state
.
config
.
ENABLE_MODEL_FILTER
:
...
@@ -1099,7 +1086,7 @@ async def get_openai_models(
...
@@ -1099,7 +1086,7 @@ async def get_openai_models(
res
=
r
.
json
()
res
=
r
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
Exception
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
raise
HTTPException
(
...
@@ -1125,7 +1112,6 @@ def parse_huggingface_url(hf_url):
...
@@ -1125,7 +1112,6 @@ def parse_huggingface_url(hf_url):
path_components
=
parsed_url
.
path
.
split
(
"/"
)
path_components
=
parsed_url
.
path
.
split
(
"/"
)
# Extract the desired output
# Extract the desired output
user_repo
=
"/"
.
join
(
path_components
[
1
:
3
])
model_file
=
path_components
[
-
1
]
model_file
=
path_components
[
-
1
]
return
model_file
return
model_file
...
@@ -1190,7 +1176,6 @@ async def download_model(
...
@@ -1190,7 +1176,6 @@ async def download_model(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_admin_user
),
user
=
Depends
(
get_admin_user
),
):
):
allowed_hosts
=
[
"https://huggingface.co/"
,
"https://github.com/"
]
allowed_hosts
=
[
"https://huggingface.co/"
,
"https://github.com/"
]
if
not
any
(
form_data
.
url
.
startswith
(
host
)
for
host
in
allowed_hosts
):
if
not
any
(
form_data
.
url
.
startswith
(
host
)
for
host
in
allowed_hosts
):
...
@@ -1199,7 +1184,7 @@ async def download_model(
...
@@ -1199,7 +1184,7 @@ async def download_model(
detail
=
"Invalid file_url. Only URLs from allowed hosts are permitted."
,
detail
=
"Invalid file_url. Only URLs from allowed hosts are permitted."
,
)
)
if
url_idx
==
None
:
if
url_idx
is
None
:
url_idx
=
0
url_idx
=
0
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
...
@@ -1222,7 +1207,7 @@ def upload_model(
...
@@ -1222,7 +1207,7 @@ def upload_model(
url_idx
:
Optional
[
int
]
=
None
,
url_idx
:
Optional
[
int
]
=
None
,
user
=
Depends
(
get_admin_user
),
user
=
Depends
(
get_admin_user
),
):
):
if
url_idx
==
None
:
if
url_idx
is
None
:
url_idx
=
0
url_idx
=
0
ollama_url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
ollama_url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment