Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
df09d083
Commit
df09d083
authored
Jun 18, 2024
by
Jonathan Rohde
Browse files
feat(sqlalchemy): Replace peewee with sqlalchemy
parent
8dac2a21
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
522 additions
and
10 deletions
+522
-10
backend/test/apps/webui/routers/test_models.py
backend/test/apps/webui/routers/test_models.py
+60
-0
backend/test/apps/webui/routers/test_prompts.py
backend/test/apps/webui/routers/test_prompts.py
+82
-0
backend/test/apps/webui/routers/test_users.py
backend/test/apps/webui/routers/test_users.py
+170
-0
backend/test/util/abstract_integration_test.py
backend/test/util/abstract_integration_test.py
+155
-0
backend/test/util/mock_user.py
backend/test/util/mock_user.py
+45
-0
backend/utils/utils.py
backend/utils/utils.py
+9
-6
src/lib/apis/models/index.ts
src/lib/apis/models/index.ts
+1
-4
No files found.
backend/test/apps/webui/routers/test_models.py
0 → 100644
View file @
df09d083
from
test.util.abstract_integration_test
import
AbstractPostgresTest
from
test.util.mock_user
import
mock_webui_user
class
TestModels
(
AbstractPostgresTest
):
BASE_PATH
=
"/api/v1/models"
def
setup_class
(
cls
):
super
().
setup_class
()
from
apps.webui.models.models
import
Model
cls
.
models
=
Model
def
test_models
(
self
):
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/"
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
0
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/add"
),
json
=
{
"id"
:
"my-model"
,
"base_model_id"
:
"base-model-id"
,
"name"
:
"Hello World"
,
"meta"
:
{
"profile_image_url"
:
"/favicon.png"
,
"description"
:
"description"
,
"capabilities"
:
None
,
"model_config"
:
{},
},
"params"
:
{},
},
)
assert
response
.
status_code
==
200
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/"
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
1
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/my-model"
))
assert
response
.
status_code
==
200
data
=
response
.
json
()
assert
data
[
"id"
]
==
"my-model"
assert
data
[
"name"
]
==
"Hello World"
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
delete
(
self
.
create_url
(
"/delete?id=my-model"
)
)
assert
response
.
status_code
==
200
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/"
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
0
backend/test/apps/webui/routers/test_prompts.py
0 → 100644
View file @
df09d083
from
test.util.abstract_integration_test
import
AbstractPostgresTest
from
test.util.mock_user
import
mock_webui_user
class
TestPrompts
(
AbstractPostgresTest
):
BASE_PATH
=
"/api/v1/prompts"
def
test_prompts
(
self
):
# Get all prompts
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/"
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
0
# Create a two new prompts
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/create"
),
json
=
{
"command"
:
"/my-command"
,
"title"
:
"Hello World"
,
"content"
:
"description"
,
},
)
assert
response
.
status_code
==
200
with
mock_webui_user
(
id
=
"3"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/create"
),
json
=
{
"command"
:
"/my-command2"
,
"title"
:
"Hello World 2"
,
"content"
:
"description 2"
,
},
)
assert
response
.
status_code
==
200
# Get all prompts
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/"
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
2
# Get prompt by command
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/command/my-command"
))
assert
response
.
status_code
==
200
data
=
response
.
json
()
assert
data
[
"command"
]
==
"/my-command"
assert
data
[
"title"
]
==
"Hello World"
assert
data
[
"content"
]
==
"description"
assert
data
[
"user_id"
]
==
"2"
# Update prompt
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/command/my-command2/update"
),
json
=
{
"command"
:
"irrelevant for request"
,
"title"
:
"Hello World Updated"
,
"content"
:
"description Updated"
,
},
)
assert
response
.
status_code
==
200
data
=
response
.
json
()
assert
data
[
"command"
]
==
"/my-command2"
assert
data
[
"title"
]
==
"Hello World Updated"
assert
data
[
"content"
]
==
"description Updated"
assert
data
[
"user_id"
]
==
"3"
# Delete prompt
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
delete
(
self
.
create_url
(
"/command/my-command/delete"
)
)
assert
response
.
status_code
==
200
# Get all prompts
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/"
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
1
backend/test/apps/webui/routers/test_users.py
0 → 100644
View file @
df09d083
from
test.util.abstract_integration_test
import
AbstractPostgresTest
from
test.util.mock_user
import
mock_webui_user
def
_get_user_by_id
(
data
,
param
):
return
next
((
item
for
item
in
data
if
item
[
"id"
]
==
param
),
None
)
def
_assert_user
(
data
,
id
,
**
kwargs
):
user
=
_get_user_by_id
(
data
,
id
)
assert
user
is
not
None
comparison_data
=
{
"name"
:
f
"user
{
id
}
"
,
"email"
:
f
"user
{
id
}
@openwebui.com"
,
"profile_image_url"
:
f
"/user
{
id
}
.png"
,
"role"
:
"user"
,
**
kwargs
,
}
for
key
,
value
in
comparison_data
.
items
():
assert
user
[
key
]
==
value
class
TestUsers
(
AbstractPostgresTest
):
BASE_PATH
=
"/api/v1/users"
def
setup_class
(
cls
):
super
().
setup_class
()
from
apps.webui.models.users
import
Users
cls
.
users
=
Users
def
setup_method
(
self
):
super
().
setup_method
()
self
.
users
.
insert_new_user
(
self
.
db_session
,
id
=
"1"
,
name
=
"user 1"
,
email
=
"user1@openwebui.com"
,
profile_image_url
=
"/user1.png"
,
role
=
"user"
,
)
self
.
users
.
insert_new_user
(
self
.
db_session
,
id
=
"2"
,
name
=
"user 2"
,
email
=
"user2@openwebui.com"
,
profile_image_url
=
"/user2.png"
,
role
=
"user"
,
)
def
test_users
(
self
):
# Get all users
with
mock_webui_user
(
id
=
"3"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
""
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
2
data
=
response
.
json
()
_assert_user
(
data
,
"1"
)
_assert_user
(
data
,
"2"
)
# update role
with
mock_webui_user
(
id
=
"3"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/update/role"
),
json
=
{
"id"
:
"2"
,
"role"
:
"admin"
}
)
assert
response
.
status_code
==
200
_assert_user
([
response
.
json
()],
"2"
,
role
=
"admin"
)
# Get all users
with
mock_webui_user
(
id
=
"3"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
""
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
2
data
=
response
.
json
()
_assert_user
(
data
,
"1"
)
_assert_user
(
data
,
"2"
,
role
=
"admin"
)
# Get (empty) user settings
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/user/settings"
))
assert
response
.
status_code
==
200
assert
response
.
json
()
is
None
# Update user settings
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/user/settings/update"
),
json
=
{
"ui"
:
{
"attr1"
:
"value1"
,
"attr2"
:
"value2"
},
"model_config"
:
{
"attr3"
:
"value3"
,
"attr4"
:
"value4"
},
},
)
assert
response
.
status_code
==
200
# Get user settings
with
mock_webui_user
(
id
=
"2"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/user/settings"
))
assert
response
.
status_code
==
200
assert
response
.
json
()
==
{
"ui"
:
{
"attr1"
:
"value1"
,
"attr2"
:
"value2"
},
"model_config"
:
{
"attr3"
:
"value3"
,
"attr4"
:
"value4"
},
}
# Get (empty) user info
with
mock_webui_user
(
id
=
"1"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/user/info"
))
assert
response
.
status_code
==
200
assert
response
.
json
()
is
None
# Update user info
with
mock_webui_user
(
id
=
"1"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/user/info/update"
),
json
=
{
"attr1"
:
"value1"
,
"attr2"
:
"value2"
},
)
assert
response
.
status_code
==
200
# Get user info
with
mock_webui_user
(
id
=
"1"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/user/info"
))
assert
response
.
status_code
==
200
assert
response
.
json
()
==
{
"attr1"
:
"value1"
,
"attr2"
:
"value2"
}
# Get user by id
with
mock_webui_user
(
id
=
"1"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
"/2"
))
assert
response
.
status_code
==
200
assert
response
.
json
()
==
{
"name"
:
"user 2"
,
"profile_image_url"
:
"/user2.png"
}
# Update user by id
with
mock_webui_user
(
id
=
"1"
):
response
=
self
.
fast_api_client
.
post
(
self
.
create_url
(
"/2/update"
),
json
=
{
"name"
:
"user 2 updated"
,
"email"
:
"user2-updated@openwebui.com"
,
"profile_image_url"
:
"/user2-updated.png"
,
},
)
assert
response
.
status_code
==
200
# Get all users
with
mock_webui_user
(
id
=
"3"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
""
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
2
data
=
response
.
json
()
_assert_user
(
data
,
"1"
)
_assert_user
(
data
,
"2"
,
role
=
"admin"
,
name
=
"user 2 updated"
,
email
=
"user2-updated@openwebui.com"
,
profile_image_url
=
"/user2-updated.png"
,
)
# Delete user by id
with
mock_webui_user
(
id
=
"1"
):
response
=
self
.
fast_api_client
.
delete
(
self
.
create_url
(
"/2"
))
assert
response
.
status_code
==
200
# Get all users
with
mock_webui_user
(
id
=
"3"
):
response
=
self
.
fast_api_client
.
get
(
self
.
create_url
(
""
))
assert
response
.
status_code
==
200
assert
len
(
response
.
json
())
==
1
data
=
response
.
json
()
_assert_user
(
data
,
"1"
)
backend/test/util/abstract_integration_test.py
0 → 100644
View file @
df09d083
import
logging
import
os
import
time
import
docker
import
pytest
from
docker
import
DockerClient
from
pytest_docker.plugin
import
get_docker_ip
from
fastapi.testclient
import
TestClient
from
sqlalchemy
import
text
,
create_engine
log
=
logging
.
getLogger
(
__name__
)
def
get_fast_api_client
():
from
main
import
app
with
TestClient
(
app
)
as
c
:
return
c
class
AbstractIntegrationTest
:
BASE_PATH
=
None
def
create_url
(
self
,
path
):
if
self
.
BASE_PATH
is
None
:
raise
Exception
(
"BASE_PATH is not set"
)
parts
=
self
.
BASE_PATH
.
split
(
"/"
)
parts
=
[
part
.
strip
()
for
part
in
parts
if
part
.
strip
()
!=
""
]
path_parts
=
path
.
split
(
"/"
)
path_parts
=
[
part
.
strip
()
for
part
in
path_parts
if
part
.
strip
()
!=
""
]
return
"/"
.
join
(
parts
+
path_parts
)
@
classmethod
def
setup_class
(
cls
):
pass
def
setup_method
(
self
):
pass
@
classmethod
def
teardown_class
(
cls
):
pass
def
teardown_method
(
self
):
pass
class
AbstractPostgresTest
(
AbstractIntegrationTest
):
DOCKER_CONTAINER_NAME
=
"postgres-test-container-will-get-deleted"
docker_client
:
DockerClient
def
get_db
(
self
):
from
apps.webui.internal.db
import
SessionLocal
return
SessionLocal
()
@
classmethod
def
_create_db_url
(
cls
,
env_vars_postgres
:
dict
)
->
str
:
host
=
get_docker_ip
()
user
=
env_vars_postgres
[
"POSTGRES_USER"
]
pw
=
env_vars_postgres
[
"POSTGRES_PASSWORD"
]
port
=
8081
db
=
env_vars_postgres
[
"POSTGRES_DB"
]
return
f
"postgresql://
{
user
}
:
{
pw
}
@
{
host
}
:
{
port
}
/
{
db
}
"
@
classmethod
def
setup_class
(
cls
):
super
().
setup_class
()
try
:
env_vars_postgres
=
{
"POSTGRES_USER"
:
"user"
,
"POSTGRES_PASSWORD"
:
"example"
,
"POSTGRES_DB"
:
"openwebui"
,
}
cls
.
docker_client
=
docker
.
from_env
()
cls
.
docker_client
.
containers
.
run
(
"postgres:16.2"
,
detach
=
True
,
environment
=
env_vars_postgres
,
name
=
cls
.
DOCKER_CONTAINER_NAME
,
ports
=
{
5432
:
(
"0.0.0.0"
,
8081
)},
command
=
"postgres -c log_statement=all"
,
)
time
.
sleep
(
0.5
)
database_url
=
cls
.
_create_db_url
(
env_vars_postgres
)
os
.
environ
[
"DATABASE_URL"
]
=
database_url
retries
=
10
db
=
None
while
retries
>
0
:
try
:
from
config
import
BACKEND_DIR
db
=
create_engine
(
database_url
,
pool_pre_ping
=
True
)
db
=
db
.
connect
()
log
.
info
(
"postgres is ready!"
)
break
except
Exception
as
e
:
log
.
warning
(
e
)
time
.
sleep
(
3
)
retries
-=
1
if
db
:
# import must be after setting env!
cls
.
fast_api_client
=
get_fast_api_client
()
db
.
close
()
else
:
raise
Exception
(
"Could not connect to Postgres"
)
except
Exception
as
ex
:
log
.
error
(
ex
)
cls
.
teardown_class
()
pytest
.
fail
(
f
"Could not setup test environment:
{
ex
}
"
)
def
_check_db_connection
(
self
):
retries
=
10
while
retries
>
0
:
try
:
self
.
db_session
.
execute
(
text
(
"SELECT 1"
))
self
.
db_session
.
commit
()
break
except
Exception
as
e
:
self
.
db_session
.
rollback
()
log
.
warning
(
e
)
time
.
sleep
(
3
)
retries
-=
1
def
setup_method
(
self
):
super
().
setup_method
()
self
.
db_session
=
self
.
get_db
()
self
.
_check_db_connection
()
@
classmethod
def
teardown_class
(
cls
)
->
None
:
super
().
teardown_class
()
cls
.
docker_client
.
containers
.
get
(
cls
.
DOCKER_CONTAINER_NAME
).
remove
(
force
=
True
)
def
teardown_method
(
self
):
# rollback everything not yet committed
self
.
db_session
.
commit
()
# truncate all tables
tables
=
[
"auth"
,
"chat"
,
"chatidtag"
,
"document"
,
"memory"
,
"model"
,
"prompt"
,
"tag"
,
'"user"'
,
]
for
table
in
tables
:
self
.
db_session
.
execute
(
text
(
f
"TRUNCATE TABLE
{
table
}
"
))
self
.
db_session
.
commit
()
backend/test/util/mock_user.py
0 → 100644
View file @
df09d083
from
contextlib
import
contextmanager
from
fastapi
import
FastAPI
@
contextmanager
def
mock_webui_user
(
**
kwargs
):
from
apps.webui.main
import
app
with
mock_user
(
app
,
**
kwargs
):
yield
@
contextmanager
def
mock_user
(
app
:
FastAPI
,
**
kwargs
):
from
utils.utils
import
(
get_current_user
,
get_verified_user
,
get_admin_user
,
get_current_user_by_api_key
,
)
from
apps.webui.models.users
import
User
def
create_user
():
user_parameters
=
{
"id"
:
"1"
,
"name"
:
"John Doe"
,
"email"
:
"john.doe@openwebui.com"
,
"role"
:
"user"
,
"profile_image_url"
:
"/user.png"
,
"last_active_at"
:
1627351200
,
"updated_at"
:
1627351200
,
"created_at"
:
162735120
,
**
kwargs
,
}
return
User
(
**
user_parameters
)
app
.
dependency_overrides
=
{
get_current_user
:
create_user
,
get_verified_user
:
create_user
,
get_admin_user
:
create_user
,
get_current_user_by_api_key
:
create_user
,
}
yield
app
.
dependency_overrides
=
{}
backend/utils/utils.py
View file @
df09d083
from
fastapi.security
import
HTTPBearer
,
HTTPAuthorizationCredentials
from
fastapi
import
HTTPException
,
status
,
Depends
,
Request
from
sqlalchemy.orm
import
Session
from
apps.webui.internal.db
import
get_db
from
apps.webui.models.users
import
Users
from
pydantic
import
BaseModel
...
...
@@ -77,6 +79,7 @@ def get_http_authorization_cred(auth_header: str):
def
get_current_user
(
request
:
Request
,
auth_token
:
HTTPAuthorizationCredentials
=
Depends
(
bearer_security
),
db
=
Depends
(
get_db
),
):
token
=
None
...
...
@@ -91,19 +94,19 @@ def get_current_user(
# auth by api key
if
token
.
startswith
(
"sk-"
):
return
get_current_user_by_api_key
(
token
)
return
get_current_user_by_api_key
(
db
,
token
)
# auth by jwt token
data
=
decode_token
(
token
)
if
data
!=
None
and
"id"
in
data
:
user
=
Users
.
get_user_by_id
(
data
[
"id"
])
user
=
Users
.
get_user_by_id
(
db
,
data
[
"id"
])
if
user
is
None
:
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
INVALID_TOKEN
,
)
else
:
Users
.
update_user_last_active_by_id
(
user
.
id
)
Users
.
update_user_last_active_by_id
(
db
,
user
.
id
)
return
user
else
:
raise
HTTPException
(
...
...
@@ -112,8 +115,8 @@ def get_current_user(
)
def
get_current_user_by_api_key
(
api_key
:
str
):
user
=
Users
.
get_user_by_api_key
(
api_key
)
def
get_current_user_by_api_key
(
db
:
Session
,
api_key
:
str
):
user
=
Users
.
get_user_by_api_key
(
db
,
api_key
)
if
user
is
None
:
raise
HTTPException
(
...
...
@@ -121,7 +124,7 @@ def get_current_user_by_api_key(api_key: str):
detail
=
ERROR_MESSAGES
.
INVALID_TOKEN
,
)
else
:
Users
.
update_user_last_active_by_id
(
user
.
id
)
Users
.
update_user_last_active_by_id
(
db
,
user
.
id
)
return
user
...
...
src/lib/apis/models/index.ts
View file @
df09d083
...
...
@@ -63,10 +63,7 @@ export const getModelInfos = async (token: string = '') => {
export
const
getModelById
=
async
(
token
:
string
,
id
:
string
)
=>
{
let
error
=
null
;
const
searchParams
=
new
URLSearchParams
();
searchParams
.
append
(
'
id
'
,
id
);
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/models?
${
searchParams
.
toString
()}
`
,
{
const
res
=
await
fetch
(
`
${
WEBUI_API_BASE_URL
}
/models/
${
id
}
`
,
{
method
:
'
GET
'
,
headers
:
{
Accept
:
'
application/json
'
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment