chenpangpang / open-webui · Commits
"src/lib/vscode:/vscode.git/clone" did not exist on "db817fcf29cdeede00244c8eb8bbbc726fdf4b39"
Commit 6b9453e2 (Unverified) · authored Jan 05, 2024 by Timothy Jaeryang Baek, committed by GitHub on Jan 05, 2024

Merge pull request #398 from ollama-webui/proxy-fix

fix: backend proxy

Parents: 439185be, bb297126
Showing 3 changed files with 120 additions and 188 deletions (+120, -188)
Dockerfile (+2, -3)
backend/apps/ollama/main.py (+28, -44)
backend/apps/ollama/old_main.py (+90, -141)
Dockerfile

```
@@ -12,10 +12,9 @@ RUN npm run build
FROM python:3.11-slim-buster as base

ARG OLLAMA_API_BASE_URL='/ollama/api'

ENV ENV=prod
ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL
ENV OLLAMA_API_BASE_URL "/ollama/api"
ENV OPENAI_API_BASE_URL ""
ENV OPENAI_API_KEY ""
```
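These build arguments and environment variables are what the backend configuration reads at startup. A minimal sketch of that pattern, assuming the values are picked up with `os.environ` (the repository's actual config.py may read them differently):

```python
import os

# Hypothetical sketch of how the backend could pick up the Dockerfile values;
# the project's actual config.py may differ.
OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL", "/ollama/api")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
```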
backend/apps/ollama/main.py

```diff
 from fastapi import FastAPI, Request, Response, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
+from fastapi.concurrency import run_in_threadpool

 import requests
 import json

@@ -11,8 +12,6 @@ from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH

-import aiohttp
-
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,

@@ -50,25 +49,9 @@ async def update_ollama_api_url(
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

-# async def fetch_sse(method, target_url, body, headers):
-#     async with aiohttp.ClientSession() as session:
-#         try:
-#             async with session.request(
-#                 method, target_url, data=body, headers=headers
-#             ) as response:
-#                 print(response.status)
-#                 async for line in response.content:
-#                     yield line
-#         except Exception as e:
-#             print(e)
-#             error_detail = "Ollama WebUI: Server Connection Error"
-#             yield json.dumps({"error": error_detail, "message": str(e)}).encode()
-
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
     print(target_url)

     body = await request.body()
     headers = dict(request.headers)

@@ -87,41 +70,42 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     headers.pop("Origin", None)
     headers.pop("Referer", None)

-    session = aiohttp.ClientSession()
-    response = None
-    try:
-        response = await session.request(
-            request.method, target_url, data=body, headers=headers
-        )
-        print(response)
-
-        if not response.ok:
-            data = await response.json()
-            print(data)
-            response.raise_for_status()
-
-        async def generate():
-            async for line in response.content:
-                print(line)
-                yield line
-            await session.close()
-
-        return StreamingResponse(generate(), response.status)
+    r = None
+
+    def get_request():
+        nonlocal r
+        try:
+            r = requests.request(
+                method=request.method,
+                url=target_url,
+                data=body,
+                headers=headers,
+                stream=True,
+            )
+            r.raise_for_status()
+
+            return StreamingResponse(
+                r.iter_content(chunk_size=8192),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        except Exception as e:
+            raise e
+
+    try:
+        return await run_in_threadpool(get_request)
     except Exception as e:
         print(e)
         error_detail = "Ollama WebUI: Server Connection Error"

-        if response is not None:
+        if r is not None:
             try:
-                res = await response.json()
+                res = r.json()
                 if "error" in res:
                     error_detail = f"Ollama: {res['error']}"
             except:
                 error_detail = f"Ollama: {e}"

-        await session.close()
-
         raise HTTPException(
-            status_code=response.status if response else 500,
+            status_code=r.status_code if r else 500,
             detail=error_detail,
         )
```
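The new proxy hands the blocking `requests` call to a worker thread via `run_in_threadpool` and streams the upstream body back with `StreamingResponse`, so the event loop stays free while the response streams. A minimal, self-contained sketch of that pattern (a standalone demo, not the project's exact code; the `demo` app and `UPSTREAM` URL are placeholders):

```python
# Sketch of the requests + run_in_threadpool streaming pattern used by the new proxy.
import requests
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse

demo = FastAPI()
UPSTREAM = "http://localhost:11434/api"  # placeholder for app.state.OLLAMA_API_BASE_URL


@demo.api_route("/{path:path}", methods=["GET", "POST"])
async def forward(path: str, request: Request):
    body = await request.body()

    def get_request():
        # The blocking requests call runs in a worker thread, keeping the event loop free.
        r = requests.request(
            method=request.method,
            url=f"{UPSTREAM}/{path}",
            data=body,
            stream=True,
        )
        r.raise_for_status()
        return StreamingResponse(
            r.iter_content(chunk_size=8192),
            status_code=r.status_code,
        )

    return await run_in_threadpool(get_request)
```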
backend/apps/ollama/old_main.py

```diff
-from flask import Flask, request, Response, jsonify
-from flask_cors import CORS
+from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse

 import requests
 import json

 from pydantic import BaseModel

 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
-from utils.utils import decode_token
+from utils.utils import decode_token, get_current_user
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH

-app = Flask(__name__)
-CORS(app)  # Enable Cross-Origin Resource Sharing (CORS) to allow requests from different domains
-
-# Define the target server URL
-TARGET_SERVER_URL = OLLAMA_API_BASE_URL
+import aiohttp
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
+
+# TARGET_SERVER_URL = OLLAMA_API_BASE_URL

-@app.route("/url", methods=["GET"])
-def get_ollama_api_url():
-    headers = dict(request.headers)
-    if "Authorization" in headers:
-        _, credentials = headers["Authorization"].split()
-        token_data = decode_token(credentials)
-        if token_data is None or "email" not in token_data:
-            return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-
-        user = Users.get_user_by_email(token_data["email"])
-        if user and user.role == "admin":
-            return (
-                jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}),
-                200,
-            )
-        else:
-            return (
-                jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
-                401,
-            )
-    else:
-        return (
-            jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}),
-            401,
-        )
-
-@app.route("/url/update", methods=["POST"])
-def update_ollama_api_url():
-    headers = dict(request.headers)
-    data = request.get_json(force=True)
-
-    if "Authorization" in headers:
-        _, credentials = headers["Authorization"].split()
-        token_data = decode_token(credentials)
-        if token_data is None or "email" not in token_data:
-            return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-
-        user = Users.get_user_by_email(token_data["email"])
-        if user and user.role == "admin":
-            TARGET_SERVER_URL = data["url"]
-            return (
-                jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}),
-                200,
-            )
-        else:
-            return (
-                jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
-                401,
-            )
-    else:
-        return (
-            jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}),
-            401,
-        )
+@app.get("/url")
+async def get_ollama_api_url(user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
+    else:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+class UrlUpdateForm(BaseModel):
+    url: str

-@app.route("/", defaults={"path": ""}, methods=["GET", "POST", "PUT", "DELETE"])
-@app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
-def proxy(path):
-    # Combine the base URL of the target server with the requested path
-    target_url = f"{TARGET_SERVER_URL}/{path}"
+@app.post("/url/update")
+async def update_ollama_api_url(
+    form_data: UrlUpdateForm, user=Depends(get_current_user)
+):
+    if user and user.role == "admin":
+        app.state.OLLAMA_API_BASE_URL = form_data.url
+        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
+    else:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+# async def fetch_sse(method, target_url, body, headers):
+#     async with aiohttp.ClientSession() as session:
+#         try:
+#             async with session.request(
+#                 method, target_url, data=body, headers=headers
+#             ) as response:
+#                 print(response.status)
+#                 async for line in response.content:
+#                     yield line
+#         except Exception as e:
+#             print(e)
+#             error_detail = "Ollama WebUI: Server Connection Error"
+#             yield json.dumps({"error": error_detail, "message": str(e)}).encode()
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_current_user)):
+    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
+    print(target_url)

-    # Get data from the original request
-    data = request.get_data()
+    body = await request.body()
     headers = dict(request.headers)

-    # Basic RBAC support
-    if WEBUI_AUTH:
-        if "Authorization" in headers:
-            _, credentials = headers["Authorization"].split()
-            token_data = decode_token(credentials)
-            if token_data is None or "email" not in token_data:
-                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-
-            user = Users.get_user_by_email(token_data["email"])
-            if user:
-                # Only user and admin roles can access
-                if user.role in ["user", "admin"]:
-                    if path in ["pull", "delete", "push", "copy", "create"]:
-                        # Only admin role can perform actions above
-                        if user.role == "admin":
-                            pass
-                        else:
-                            return (
-                                jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}),
-                                401,
-                            )
-                    else:
-                        pass
-                else:
-                    return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
-            else:
-                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-        else:
-            return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
-
-    r = None
+    if user.role in ["user", "admin"]:
+        if path in ["pull", "delete", "push", "copy", "create"]:
+            if user.role != "admin":
+                raise HTTPException(
+                    status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
+                )
+        else:
+            pass
+    else:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

     headers.pop("Host", None)
     headers.pop("Authorization", None)
     headers.pop("Origin", None)
     headers.pop("Referer", None)

+    session = aiohttp.ClientSession()
+    response = None
     try:
-        # Make a request to the target server
-        r = requests.request(
-            method=request.method,
-            url=target_url,
-            data=data,
-            headers=headers,
-            stream=True,  # Enable streaming for server-sent events
-        )
-        r.raise_for_status()
-
-        # Proxy the target server's response to the client
-        def generate():
-            for chunk in r.iter_content(chunk_size=8192):
-                yield chunk
-
-        response = Response(generate(), status=r.status_code)
-
-        # Copy headers from the target server's response to the client's response
-        for key, value in r.headers.items():
-            response.headers[key] = value
-
-        return response
+        response = await session.request(
+            request.method, target_url, data=body, headers=headers
+        )
+        print(response)
+
+        if not response.ok:
+            data = await response.json()
+            print(data)
+            response.raise_for_status()
+
+        async def generate():
+            async for line in response.content:
+                print(line)
+                yield line
+            await session.close()
+
+        return StreamingResponse(generate(), response.status)
     except Exception as e:
         print(e)
         error_detail = "Ollama WebUI: Server Connection Error"

-        if r != None:
-            print(r.text)
-            res = r.json()
-            if "error" in res:
-                error_detail = f"Ollama: {res['error']}"
-            print(res)
-
-        return (
-            jsonify(
-                {
-                    "detail": error_detail,
-                    "message": str(e),
-                }
-            ),
-            400,
-        )
-
-if __name__ == "__main__":
-    app.run(debug=True)
+        if response is not None:
+            try:
+                res = await response.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        await session.close()
+
+        raise HTTPException(
+            status_code=response.status if response else 500,
+            detail=error_detail,
+        )
```
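For comparison, the aiohttp variant preserved in old_main.py streams the upstream body with `async for line in response.content` inside an async generator and closes the session once the stream ends. A minimal, self-contained sketch of that idiom (standalone demo; the `demo` app and `UPSTREAM` URL are placeholders, not the project's code):

```python
# Sketch of the aiohttp streaming idiom kept in old_main.py.
import aiohttp
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

demo = FastAPI()
UPSTREAM = "http://localhost:11434/api"  # placeholder


@demo.api_route("/{path:path}", methods=["GET", "POST"])
async def forward(path: str, request: Request):
    body = await request.body()
    session = aiohttp.ClientSession()
    response = await session.request(request.method, f"{UPSTREAM}/{path}", data=body)

    async def generate():
        # Relay the upstream body chunk by chunk, then release the session.
        async for line in response.content:
            yield line
        await session.close()

    return StreamingResponse(generate(), status_code=response.status)
```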