Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
47dc3b5f
Commit
47dc3b5f
authored
Jan 05, 2024
by
Timothy J. Baek
Browse files
feat: async reverse proxy
parent
76139fc8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
25 deletions
+52
-25
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+52
-25
No files found.
backend/apps/ollama/main.py
View file @
47dc3b5f
...
@@ -11,6 +11,8 @@ from constants import ERROR_MESSAGES
...
@@ -11,6 +11,8 @@ from constants import ERROR_MESSAGES
from
utils.utils
import
decode_token
,
get_current_user
from
utils.utils
import
decode_token
,
get_current_user
from
config
import
OLLAMA_API_BASE_URL
,
WEBUI_AUTH
from
config
import
OLLAMA_API_BASE_URL
,
WEBUI_AUTH
import
aiohttp
app
=
FastAPI
()
app
=
FastAPI
()
app
.
add_middleware
(
app
.
add_middleware
(
CORSMiddleware
,
CORSMiddleware
,
...
@@ -30,8 +32,7 @@ async def get_ollama_api_url(user=Depends(get_current_user)):
...
@@ -30,8 +32,7 @@ async def get_ollama_api_url(user=Depends(get_current_user)):
if
user
and
user
.
role
==
"admin"
:
if
user
and
user
.
role
==
"admin"
:
return
{
"OLLAMA_API_BASE_URL"
:
app
.
state
.
OLLAMA_API_BASE_URL
}
return
{
"OLLAMA_API_BASE_URL"
:
app
.
state
.
OLLAMA_API_BASE_URL
}
else
:
else
:
raise
HTTPException
(
status_code
=
401
,
raise
HTTPException
(
status_code
=
401
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
class
UrlUpdateForm
(
BaseModel
):
class
UrlUpdateForm
(
BaseModel
):
...
@@ -39,14 +40,29 @@ class UrlUpdateForm(BaseModel):
...
@@ -39,14 +40,29 @@ class UrlUpdateForm(BaseModel):
@
app
.
post
(
"/url/update"
)
@
app
.
post
(
"/url/update"
)
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
async
def
update_ollama_api_url
(
user
=
Depends
(
get_current_user
)):
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_current_user
)
):
if
user
and
user
.
role
==
"admin"
:
if
user
and
user
.
role
==
"admin"
:
app
.
state
.
OLLAMA_API_BASE_URL
=
form_data
.
url
app
.
state
.
OLLAMA_API_BASE_URL
=
form_data
.
url
return
{
"OLLAMA_API_BASE_URL"
:
app
.
state
.
OLLAMA_API_BASE_URL
}
return
{
"OLLAMA_API_BASE_URL"
:
app
.
state
.
OLLAMA_API_BASE_URL
}
else
:
else
:
raise
HTTPException
(
status_code
=
401
,
raise
HTTPException
(
status_code
=
401
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
# async def fetch_sse(method, target_url, body, headers):
# async with aiohttp.ClientSession() as session:
# try:
# async with session.request(
# method, target_url, data=body, headers=headers
# ) as response:
# print(response.status)
# async for line in response.content:
# yield line
# except Exception as e:
# print(e)
# error_detail = "Ollama WebUI: Server Connection Error"
# yield json.dumps({"error": error_detail, "message": str(e)}).encode()
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
@
app
.
api_route
(
"/{path:path}"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
...
@@ -59,42 +75,53 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
...
@@ -59,42 +75,53 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
if
user
.
role
in
[
"user"
,
"admin"
]:
if
user
.
role
in
[
"user"
,
"admin"
]:
if
path
in
[
"pull"
,
"delete"
,
"push"
,
"copy"
,
"create"
]:
if
path
in
[
"pull"
,
"delete"
,
"push"
,
"copy"
,
"create"
]:
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
status_code
=
401
,
raise
HTTPException
(
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
status_code
=
401
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
else
:
else
:
raise
HTTPException
(
status_code
=
401
,
raise
HTTPException
(
status_code
=
401
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
headers
.
pop
(
"Host"
,
None
)
headers
.
pop
(
"Host"
,
None
)
headers
.
pop
(
"Authorization"
,
None
)
headers
.
pop
(
"Authorization"
,
None
)
headers
.
pop
(
"Origin"
,
None
)
headers
.
pop
(
"Origin"
,
None
)
headers
.
pop
(
"Referer"
,
None
)
headers
.
pop
(
"Referer"
,
None
)
session
=
aiohttp
.
ClientSession
()
response
=
None
try
:
try
:
r
=
requests
.
request
(
response
=
await
session
.
request
(
method
=
request
.
method
,
request
.
method
,
target_url
,
data
=
body
,
headers
=
headers
url
=
target_url
,
data
=
body
,
headers
=
headers
,
stream
=
True
,
)
)
r
.
raise_for_status
()
if
not
response
.
ok
:
data
=
await
response
.
json
()
print
(
data
)
response
.
raise_for_status
()
async
def
gen
():
async
for
line
in
response
.
content
:
yield
line
await
session
.
close
()
return
StreamingResponse
(
gen
(),
response
.
status
)
return
StreamingResponse
(
r
.
iter_content
(
chunk_size
=
8192
),
status_code
=
r
.
status_code
,
headers
=
dict
(
r
.
headers
),
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
print
(
e
)
error_detail
=
"Ollama WebUI: Server Connection Error"
error_detail
=
"Ollama WebUI: Server Connection Error"
if
r
is
not
None
:
if
r
esponse
is
not
None
:
try
:
try
:
res
=
r
.
json
()
res
=
await
response
.
json
()
if
"error"
in
res
:
if
"error"
in
res
:
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
error_detail
=
f
"Ollama:
{
res
[
'error'
]
}
"
except
:
except
:
error_detail
=
f
"Ollama:
{
e
}
"
error_detail
=
f
"Ollama:
{
e
}
"
raise
HTTPException
(
status_code
=
r
.
status_code
,
detail
=
error_detail
)
await
session
.
close
()
raise
HTTPException
(
status_code
=
response
.
status
if
response
else
500
,
detail
=
error_detail
,
)
# print(e)
# error_detail = "Ollama WebUI: Server Connection Error"
# return {"error": error_detail, "message": str(e)}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment