chenpangpang / open-webui · Commits

Commit d9a3e4db (Unverified)
Authored Jun 02, 2024 by Timothy Jaeryang Baek, committed by GitHub on Jun 02, 2024

Merge pull request #2731 from cheahjs/fix/ollama-cancellation

    fix: ollama and openai stream cancellation

Parents: 27ff3ab1, c5ff4c24

Showing 7 changed files with 186 additions and 503 deletions (+186 −503)
backend/apps/ollama/main.py                             +54   −347
backend/apps/openai/main.py                             +29   −10
src/lib/apis/ollama/index.ts                             +3   −22
src/lib/components/chat/Chat.svelte                     +50   −63
src/lib/components/chat/ModelSelector/Selector.svelte   +21   −21
src/lib/components/chat/Settings/Models.svelte          +28   −27
src/lib/components/workspace/Playground.svelte           +1   −13
backend/apps/ollama/main.py   (view file @ d9a3e4db)

@@ -29,6 +29,8 @@ import time
 from urllib.parse import urlparse
 from typing import Optional, List, Union
+from starlette.background import BackgroundTask
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES

@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}

-REQUEST_POOL = []
-
 # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
 # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
 # least connections, or least response time for better resource utilization and performance optimization.

@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
     return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}

-@app.get("/cancel/{request_id}")
-async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
-    if user:
-        if request_id in REQUEST_POOL:
-            REQUEST_POOL.remove(request_id)
-        return True
-    else:
-        raise HTTPException(
-            status_code=401,
-            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-        )
-
 async def fetch_url(url):
     timeout = aiohttp.ClientTimeout(total=5)
     try:

@@ -154,6 +143,45 @@ async def fetch_url(url):
         return None

+
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
+async def post_streaming_url(url: str, payload: str):
+    r = None
+    try:
+        session = aiohttp.ClientSession()
+        r = await session.post(url, data=payload)
+        r.raise_for_status()
+
+        return StreamingResponse(
+            r.content,
+            status_code=r.status,
+            headers=dict(r.headers),
+            background=BackgroundTask(cleanup_response, response=r, session=session),
+        )
+    except Exception as e:
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=error_detail,
+        )
+
+
 def merge_models_lists(model_lists):
     merged_models = {}
@@ -313,65 +341,7 @@ async def pull_model(
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

-    def get_request():
-        nonlocal url
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/pull",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))

 class PushModelForm(BaseModel):
@@ -399,50 +369,9 @@ async def push_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.debug(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/push",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
+    )

 class CreateModelForm(BaseModel):

@@ -461,53 +390,9 @@ async def create_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/create",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            log.debug(f"r: {r}")
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
+    )

 class CopyModelForm(BaseModel):
@@ -797,66 +682,9 @@ async def generate_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal form_data
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if form_data.stream:
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/generate",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
+    )

 class ChatMessage(BaseModel):

@@ -1014,67 +842,7 @@ async def generate_chat_completion(
     print(payload)

-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream", None):
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/chat",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            log.exception(e)
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))

 # TODO: we should update this part once Ollama supports other types
@@ -1165,68 +933,7 @@ async def generate_openai_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream"):
-                        yield json.dumps(
-                            {"request_id": request_id, "done": False}
-                        ) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/v1/chat/completions",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))

 @app.get("/v1/models")

@@ -1555,7 +1262,7 @@ async def deprecated_proxy(
                     if path == "generate":
                         data = json.loads(body.decode("utf-8"))

-                        if not ("stream" in data and data["stream"] == False):
+                        if data.get("stream", True):
                             yield json.dumps({"id": request_id, "done": False}) + "\n"

                     elif path == "chat":
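Taken together, the removed get_request wrappers, REQUEST_POOL, and the /cancel/{request_id} endpoint are replaced by the shared post_streaming_url helper added above. The following is a minimal, self-contained sketch of that pattern rather than the project's exact code: the /proxy route and UPSTREAM_URL are illustrative placeholders. It forwards an upstream aiohttp stream through StreamingResponse and defers closing the upstream response and session to a starlette BackgroundTask, which runs after the response finishes or the client disconnects.

from typing import Optional

import aiohttp
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()

UPSTREAM_URL = "http://localhost:11434/api/generate"  # illustrative upstream


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    # Close the upstream response and its session once streaming is over,
    # whether the stream completed normally or the client went away.
    if response:
        response.close()
    if session:
        await session.close()


@app.post("/proxy")  # illustrative route, not one of the real endpoints
async def proxy(payload: dict):
    r = None
    session = aiohttp.ClientSession()
    try:
        r = await session.post(UPSTREAM_URL, json=payload)
        r.raise_for_status()
        return StreamingResponse(
            r.content,  # aiohttp's streaming body, relayed chunk by chunk
            status_code=r.status,
            headers=dict(r.headers),
            background=BackgroundTask(cleanup_response, response=r, session=session),
        )
    except Exception as e:
        # On failure there is nothing left to stream, so clean up immediately.
        await cleanup_response(r, session)
        raise HTTPException(status_code=r.status if r else 500, detail=str(e))

Because the proxied body is handed to StreamingResponse directly, a client that aborts its HTTP request simply ends the iteration, and the background task still runs to release the upstream connection; no server-side bookkeeping of request IDs is needed.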
backend/apps/openai/main.py   (view file @ d9a3e4db)

@@ -9,6 +9,7 @@ import json
 import logging
 from pydantic import BaseModel
+from starlette.background import BackgroundTask
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users

@@ -194,6 +195,16 @@ async def fetch_url(url, key):
         return None

+
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
 def merge_models_lists(model_lists):
     log.debug(f"merge_models_lists {model_lists}")
     merged_list = []

@@ -447,40 +458,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     headers["Content-Type"] = "application/json"

     r = None
+    session = None
+    streaming = False

     try:
-        r = requests.request(
+        session = aiohttp.ClientSession()
+        r = await session.request(
             method=request.method,
             url=target_url,
             data=payload if payload else body,
             headers=headers,
-            stream=True,
         )

         r.raise_for_status()

         # Check if response is SSE
         if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
             return StreamingResponse(
-                r.iter_content(chunk_size=8192),
-                status_code=r.status_code,
+                r.content,
+                status_code=r.status,
                 headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
             )
         else:
-            response_data = r.json()
+            response_data = await r.json()
             return response_data
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
         if r is not None:
             try:
-                res = r.json()
+                res = await r.json()
                 print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
             except:
                 error_detail = f"External: {e}"

-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
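A rough sketch of what cancellation looks like from the client's side under this scheme (assuming a proxy like the sketch above is listening on localhost:8080; the URL and payload below are placeholders): the client just stops reading and closes the response, and the server's BackgroundTask releases the upstream session, so no explicit /cancel call is involved.

import asyncio

import aiohttp


async def stream_then_cancel():
    async with aiohttp.ClientSession() as session:
        resp = await session.post(
            "http://localhost:8080/proxy",  # placeholder URL for the sketch above
            json={"model": "llama3", "prompt": "hello", "stream": True},
        )
        try:
            count = 0
            async for line in resp.content:  # newline-delimited stream chunks
                print(line.decode(errors="replace").rstrip())
                count += 1
                if count >= 3:
                    break  # stop reading: this is the whole "cancel" step
        finally:
            # Closing the response drops the connection to the proxy; the
            # proxy's background cleanup then closes its upstream session.
            resp.close()


asyncio.run(stream_then_cancel())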
src/lib/apis/ollama/index.ts   (view file @ d9a3e4db)

@@ -369,27 +369,6 @@ export const generateChatCompletion = async (token: string = '', body: object) =
   return [res, controller];
 };

-export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
-  let error = null;
-
-  const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
-    method: 'GET',
-    headers: {
-      'Content-Type': 'text/event-stream',
-      Authorization: `Bearer ${token}`
-    }
-  }).catch((err) => {
-    error = err;
-    return null;
-  });
-
-  if (error) {
-    throw error;
-  }
-
-  return res;
-};
-
 export const createModel = async (token: string, tagName: string, content: string) => {
   let error = null;

@@ -461,8 +440,10 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
 export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => {
   let error = null;
+  const controller = new AbortController();

   const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, {
+    signal: controller.signal,
     method: 'POST',
     headers: {
       Accept: 'application/json',

@@ -485,7 +466,7 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
   if (error) {
     throw error;
   }
-  return res;
+  return [res, controller];
 };

 export const downloadModel = async (
src/lib/components/chat/Chat.svelte   (view file @ d9a3e4db)

@@ -26,7 +26,7 @@
     splitStream
   } from '$lib/utils';

-  import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
+  import { generateChatCompletion } from '$lib/apis/ollama';
   import {
     addTagById,
     createNewChat,

@@ -65,7 +65,6 @@
   let autoScroll = true;
   let processing = '';
   let messagesContainerElement: HTMLDivElement;
-  let currentRequestId = null;

   let showModelSelector = true;

@@ -130,10 +129,6 @@
   //////////////////////////

   const initNewChat = async () => {
-    if (currentRequestId !== null) {
-      await cancelOllamaRequest(localStorage.token, currentRequestId);
-      currentRequestId = null;
-    }
-
     window.history.replaceState(history.state, '', `/`);
     await chatId.set('');

@@ -616,7 +611,6 @@
           if (stopResponseFlag) {
             controller.abort('User: Stop Response');
-            await cancelOllamaRequest(localStorage.token, currentRequestId);
           } else {
             const messages = createMessagesList(responseMessageId);
             const res = await chatCompleted(localStorage.token, {

@@ -647,8 +641,6 @@
             }
           }

-          currentRequestId = null;
-
           break;
         }

@@ -669,63 +661,58 @@
            throw data;
          }

-          if ('id' in data) {
-            console.log(data);
-            currentRequestId = data.id;
-          } else {
           if (data.done == false) {
             if (responseMessage.content == '' && data.message.content == '\n') {
               continue;
             } else {
               responseMessage.content += data.message.content;
               messages = messages;
             }
           } else {
             responseMessage.done = true;

             if (responseMessage.content == '') {
               responseMessage.error = {
                 code: 400,
                 content: `Oops! No text generated from Ollama, Please try again.`
               };
             }

             responseMessage.context = data.context ?? null;
             responseMessage.info = {
               total_duration: data.total_duration,
               load_duration: data.load_duration,
               sample_count: data.sample_count,
               sample_duration: data.sample_duration,
               prompt_eval_count: data.prompt_eval_count,
               prompt_eval_duration: data.prompt_eval_duration,
               eval_count: data.eval_count,
               eval_duration: data.eval_duration
             };
             messages = messages;

             if ($settings.notificationEnabled && !document.hasFocus()) {
               const notification = new Notification(
                 selectedModelfile
                   ? `${
                       selectedModelfile.title.charAt(0).toUpperCase() +
                       selectedModelfile.title.slice(1)
                     }`
                   : `${model}`,
                 {
                   body: responseMessage.content,
                   icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
                 }
               );
             }

             if ($settings.responseAutoCopy) {
               copyToClipboard(responseMessage.content);
             }

             if ($settings.responseAutoPlayback) {
               await tick();
               document.getElementById(`speak-button-${responseMessage.id}`)?.click();
             }
           }
-          }
src/lib/components/chat/ModelSelector/Selector.svelte   (view file @ d9a3e4db)

@@ -8,7 +8,7 @@
   import Check from '$lib/components/icons/Check.svelte';
   import Search from '$lib/components/icons/Search.svelte';

-  import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
+  import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
   import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores';
   import { toast } from 'svelte-sonner';

@@ -72,10 +72,12 @@
       return;
     }

-    const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
-      toast.error(error);
-      return null;
-    });
+    const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
+      (error) => {
+        toast.error(error);
+        return null;
+      }
+    );

     if (res) {
       const reader = res.body

@@ -83,6 +85,16 @@
         .pipeThrough(splitStream('\n'))
         .getReader();

+      MODEL_DOWNLOAD_POOL.set({
+        ...$MODEL_DOWNLOAD_POOL,
+        [sanitizedModelTag]: {
+          ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
+          abortController: controller,
+          reader,
+          done: false
+        }
+      });
+
       while (true) {
         try {
           const { value, done } = await reader.read();

@@ -101,19 +113,6 @@
             throw data.detail;
           }

-          if (data.id) {
-            MODEL_DOWNLOAD_POOL.set({
-              ...$MODEL_DOWNLOAD_POOL,
-              [sanitizedModelTag]: {
-                ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
-                requestId: data.id,
-                reader,
-                done: false
-              }
-            });
-            console.log(data);
-          }
-
           if (data.status) {
             if (data.digest) {
               let downloadProgress = 0;

@@ -181,11 +180,12 @@
   });

   const cancelModelPullHandler = async (model: string) => {
-    const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
+    const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
+    if (abortController) {
+      abortController.abort();
+    }
     if (reader) {
       await reader.cancel();
-      await cancelOllamaRequest(localStorage.token, requestId);
       delete $MODEL_DOWNLOAD_POOL[model];
       MODEL_DOWNLOAD_POOL.set({
         ...$MODEL_DOWNLOAD_POOL
src/lib/components/chat/Settings/Models.svelte   (view file @ d9a3e4db)

@@ -8,7 +8,6 @@
     getOllamaUrls,
     getOllamaVersion,
     pullModel,
-    cancelOllamaRequest,
     uploadModel,
     getOllamaConfig
   } from '$lib/apis/ollama';

@@ -70,12 +69,14 @@
     console.log(model);

     updateModelId = model.id;
-    const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
-      (error) => {
-        toast.error(error);
-        return null;
-      }
-    );
+    const [res, controller] = await pullModel(
+      localStorage.token,
+      model.id,
+      selectedOllamaUrlIdx
+    ).catch((error) => {
+      toast.error(error);
+      return null;
+    });

     if (res) {
       const reader = res.body

@@ -144,10 +145,12 @@
       return;
     }

-    const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
-      toast.error(error);
-      return null;
-    });
+    const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
+      (error) => {
+        toast.error(error);
+        return null;
+      }
+    );

     if (res) {
       const reader = res.body

@@ -155,6 +158,16 @@
         .pipeThrough(splitStream('\n'))
         .getReader();

+      MODEL_DOWNLOAD_POOL.set({
+        ...$MODEL_DOWNLOAD_POOL,
+        [sanitizedModelTag]: {
+          ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
+          abortController: controller,
+          reader,
+          done: false
+        }
+      });
+
       while (true) {
         try {
           const { value, done } = await reader.read();

@@ -173,19 +186,6 @@
             throw data.detail;
           }

-          if (data.id) {
-            MODEL_DOWNLOAD_POOL.set({
-              ...$MODEL_DOWNLOAD_POOL,
-              [sanitizedModelTag]: {
-                ...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
-                requestId: data.id,
-                reader,
-                done: false
-              }
-            });
-            console.log(data);
-          }
-
           if (data.status) {
             if (data.digest) {
               let downloadProgress = 0;

@@ -419,11 +419,12 @@
   };

   const cancelModelPullHandler = async (model: string) => {
-    const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
+    const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
+    if (abortController) {
+      abortController.abort();
+    }
     if (reader) {
       await reader.cancel();
-      await cancelOllamaRequest(localStorage.token, requestId);
       delete $MODEL_DOWNLOAD_POOL[model];
       MODEL_DOWNLOAD_POOL.set({
         ...$MODEL_DOWNLOAD_POOL
src/lib/components/workspace/Playground.svelte   (view file @ d9a3e4db)

@@ -8,7 +8,7 @@
   import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants';
   import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';

-  import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
+  import { generateChatCompletion } from '$lib/apis/ollama';
   import { generateOpenAIChatCompletion } from '$lib/apis/openai';
   import { splitStream } from '$lib/utils';

@@ -24,7 +24,6 @@
   let selectedModelId = '';
   let loading = false;
-  let currentRequestId = null;
   let stopResponseFlag = false;

   let messagesContainerElement: HTMLDivElement;

@@ -46,14 +45,6 @@
     }
   };

-  // const cancelHandler = async () => {
-  //   if (currentRequestId) {
-  //     const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
-  //     currentRequestId = null;
-  //     loading = false;
-  //   }
-  // };
-
   const stopResponse = () => {
     stopResponseFlag = true;
     console.log('stopResponse');

@@ -171,8 +162,6 @@
           if (stopResponseFlag) {
             controller.abort('User: Stop Response');
           }
-
-          currentRequestId = null;
           break;
         }

@@ -229,7 +218,6 @@
       loading = false;
       stopResponseFlag = false;
-      currentRequestId = null;
     }
   };