Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
07cc7f15
Unverified
Commit
07cc7f15
authored
Jan 03, 2024
by
ThatOneCalculator
Browse files
chore:
🚨
lint and format
parent
03779316
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
170 additions
and
160 deletions
+170
-160
.github/ISSUE_TEMPLATE/bug_report.md
.github/ISSUE_TEMPLATE/bug_report.md
+1
-1
.github/ISSUE_TEMPLATE/feature_request.md
.github/ISSUE_TEMPLATE/feature_request.md
+0
-1
backend/apps/ollama/main.py
backend/apps/ollama/main.py
+13
-11
backend/apps/web/main.py
backend/apps/web/main.py
+3
-2
backend/apps/web/models/auths.py
backend/apps/web/models/auths.py
+16
-9
backend/apps/web/models/chats.py
backend/apps/web/models/chats.py
+18
-21
backend/apps/web/models/modelfiles.py
backend/apps/web/models/modelfiles.py
+13
-12
backend/apps/web/models/prompts.py
backend/apps/web/models/prompts.py
+7
-9
backend/apps/web/models/users.py
backend/apps/web/models/users.py
+11
-8
backend/apps/web/routers/auths.py
backend/apps/web/routers/auths.py
+9
-13
backend/apps/web/routers/chats.py
backend/apps/web/routers/chats.py
+16
-14
backend/apps/web/routers/configs.py
backend/apps/web/routers/configs.py
+3
-4
backend/apps/web/routers/modelfiles.py
backend/apps/web/routers/modelfiles.py
+21
-22
backend/apps/web/routers/prompts.py
backend/apps/web/routers/prompts.py
+7
-6
backend/apps/web/routers/users.py
backend/apps/web/routers/users.py
+5
-6
backend/apps/web/routers/utils.py
backend/apps/web/routers/utils.py
+9
-8
backend/config.py
backend/config.py
+2
-3
backend/constants.py
backend/constants.py
+2
-2
backend/main.py
backend/main.py
+4
-1
backend/utils/utils.py
backend/utils/utils.py
+10
-7
No files found.
.github/ISSUE_TEMPLATE/bug_report.md
View file @
07cc7f15
...
@@ -4,7 +4,6 @@ about: Create a report to help us improve
...
@@ -4,7 +4,6 @@ about: Create a report to help us improve
title
:
'
'
title
:
'
'
labels
:
'
'
labels
:
'
'
assignees
:
'
'
assignees
:
'
'
---
---
# Bug Report
# Bug Report
...
@@ -31,6 +30,7 @@ assignees: ''
...
@@ -31,6 +30,7 @@ assignees: ''
## Reproduction Details
## Reproduction Details
**Confirmation:**
**Confirmation:**
-
[ ] I have read and followed all the instructions provided in the README.md.
-
[ ] I have read and followed all the instructions provided in the README.md.
-
[ ] I have reviewed the troubleshooting.md document.
-
[ ] I have reviewed the troubleshooting.md document.
-
[ ] I have included the browser console logs.
-
[ ] I have included the browser console logs.
...
...
.github/ISSUE_TEMPLATE/feature_request.md
View file @
07cc7f15
...
@@ -4,7 +4,6 @@ about: Suggest an idea for this project
...
@@ -4,7 +4,6 @@ about: Suggest an idea for this project
title
:
'
'
title
:
'
'
labels
:
'
'
labels
:
'
'
assignees
:
'
'
assignees
:
'
'
---
---
**Is your feature request related to a problem? Please describe.**
**Is your feature request related to a problem? Please describe.**
...
...
backend/apps/ollama/main.py
View file @
07cc7f15
from
flask
import
Flask
,
request
,
Response
,
jsonify
from
flask
import
Flask
,
request
,
Response
,
jsonify
from
flask_cors
import
CORS
from
flask_cors
import
CORS
import
requests
import
requests
import
json
import
json
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
from
utils.utils
import
decode_token
from
utils.utils
import
decode_token
...
@@ -20,7 +18,9 @@ CORS(
...
@@ -20,7 +18,9 @@ CORS(
TARGET_SERVER_URL
=
OLLAMA_API_BASE_URL
TARGET_SERVER_URL
=
OLLAMA_API_BASE_URL
@
app
.
route
(
"/"
,
defaults
=
{
"path"
:
""
},
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
@
app
.
route
(
"/"
,
defaults
=
{
"path"
:
""
},
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
@
app
.
route
(
"/<path:path>"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
@
app
.
route
(
"/<path:path>"
,
methods
=
[
"GET"
,
"POST"
,
"PUT"
,
"DELETE"
])
def
proxy
(
path
):
def
proxy
(
path
):
# Combine the base URL of the target server with the requested path
# Combine the base URL of the target server with the requested path
...
@@ -49,13 +49,17 @@ def proxy(path):
...
@@ -49,13 +49,17 @@ def proxy(path):
pass
pass
else
:
else
:
return
(
return
(
jsonify
({
"detail"
:
ERROR_MESSAGES
.
ACCESS_PROHIBITED
}),
jsonify
({
"detail"
:
ERROR_MESSAGES
.
ACCESS_PROHIBITED
}),
401
,
401
,
)
)
else
:
else
:
pass
pass
else
:
else
:
return
jsonify
({
"detail"
:
ERROR_MESSAGES
.
ACCESS_PROHIBITED
}),
401
return
jsonify
(
{
"detail"
:
ERROR_MESSAGES
.
ACCESS_PROHIBITED
}),
401
else
:
else
:
return
jsonify
({
"detail"
:
ERROR_MESSAGES
.
UNAUTHORIZED
}),
401
return
jsonify
({
"detail"
:
ERROR_MESSAGES
.
UNAUTHORIZED
}),
401
else
:
else
:
...
@@ -105,12 +109,10 @@ def proxy(path):
...
@@ -105,12 +109,10 @@ def proxy(path):
print
(
res
)
print
(
res
)
return
(
return
(
jsonify
(
jsonify
({
{
"detail"
:
error_detail
,
"detail"
:
error_detail
,
"message"
:
str
(
e
),
"message"
:
str
(
e
),
}),
}
),
400
,
400
,
)
)
...
...
backend/apps/web/main.py
View file @
07cc7f15
...
@@ -22,10 +22,11 @@ app.add_middleware(
...
@@ -22,10 +22,11 @@ app.add_middleware(
app
.
include_router
(
auths
.
router
,
prefix
=
"/auths"
,
tags
=
[
"auths"
])
app
.
include_router
(
auths
.
router
,
prefix
=
"/auths"
,
tags
=
[
"auths"
])
app
.
include_router
(
users
.
router
,
prefix
=
"/users"
,
tags
=
[
"users"
])
app
.
include_router
(
users
.
router
,
prefix
=
"/users"
,
tags
=
[
"users"
])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
app
.
include_router
(
modelfiles
.
router
,
prefix
=
"/modelfiles"
,
tags
=
[
"modelfiles"
])
app
.
include_router
(
modelfiles
.
router
,
prefix
=
"/modelfiles"
,
tags
=
[
"modelfiles"
])
app
.
include_router
(
prompts
.
router
,
prefix
=
"/prompts"
,
tags
=
[
"prompts"
])
app
.
include_router
(
prompts
.
router
,
prefix
=
"/prompts"
,
tags
=
[
"prompts"
])
app
.
include_router
(
configs
.
router
,
prefix
=
"/configs"
,
tags
=
[
"configs"
])
app
.
include_router
(
configs
.
router
,
prefix
=
"/configs"
,
tags
=
[
"configs"
])
app
.
include_router
(
utils
.
router
,
prefix
=
"/utils"
,
tags
=
[
"utils"
])
app
.
include_router
(
utils
.
router
,
prefix
=
"/utils"
,
tags
=
[
"utils"
])
...
...
backend/apps/web/models/auths.py
View file @
07cc7f15
...
@@ -4,7 +4,6 @@ import time
...
@@ -4,7 +4,6 @@ import time
import
uuid
import
uuid
from
peewee
import
*
from
peewee
import
*
from
apps.web.models.users
import
UserModel
,
Users
from
apps.web.models.users
import
UserModel
,
Users
from
utils.utils
import
(
from
utils.utils
import
(
verify_password
,
verify_password
,
...
@@ -76,20 +75,26 @@ class SignupForm(BaseModel):
...
@@ -76,20 +75,26 @@ class SignupForm(BaseModel):
class
AuthsTable
:
class
AuthsTable
:
def
__init__
(
self
,
db
):
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
=
db
self
.
db
.
create_tables
([
Auth
])
self
.
db
.
create_tables
([
Auth
])
def
insert_new_auth
(
def
insert_new_auth
(
self
,
self
,
email
:
str
,
password
:
str
,
name
:
str
,
role
:
str
=
"pending"
email
:
str
,
)
->
Optional
[
UserModel
]:
password
:
str
,
name
:
str
,
role
:
str
=
"pending"
)
->
Optional
[
UserModel
]:
print
(
"insert_new_auth"
)
print
(
"insert_new_auth"
)
id
=
str
(
uuid
.
uuid4
())
id
=
str
(
uuid
.
uuid4
())
auth
=
AuthModel
(
auth
=
AuthModel
(
**
{
**
{
"id"
:
id
,
"email"
:
email
,
"password"
:
password
,
"active"
:
True
}
"id"
:
id
,
)
"email"
:
email
,
"password"
:
password
,
"active"
:
True
})
result
=
Auth
.
create
(
**
auth
.
model_dump
())
result
=
Auth
.
create
(
**
auth
.
model_dump
())
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
role
)
user
=
Users
.
insert_new_user
(
id
,
name
,
email
,
role
)
...
@@ -99,7 +104,8 @@ class AuthsTable:
...
@@ -99,7 +104,8 @@ class AuthsTable:
else
:
else
:
return
None
return
None
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
def
authenticate_user
(
self
,
email
:
str
,
password
:
str
)
->
Optional
[
UserModel
]:
print
(
"authenticate_user"
,
email
)
print
(
"authenticate_user"
,
email
)
try
:
try
:
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
auth
=
Auth
.
get
(
Auth
.
email
==
email
,
Auth
.
active
==
True
)
...
@@ -131,7 +137,8 @@ class AuthsTable:
...
@@ -131,7 +137,8 @@ class AuthsTable:
if
result
:
if
result
:
# Delete Auth
# Delete Auth
query
=
Auth
.
delete
().
where
(
Auth
.
id
==
id
)
query
=
Auth
.
delete
().
where
(
Auth
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
query
.
execute
(
)
# Remove the rows, return number of rows removed.
return
True
return
True
else
:
else
:
...
...
backend/apps/web/models/chats.py
View file @
07cc7f15
...
@@ -3,14 +3,12 @@ from typing import List, Union, Optional
...
@@ -3,14 +3,12 @@ from typing import List, Union, Optional
from
peewee
import
*
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
import
json
import
json
import
uuid
import
uuid
import
time
import
time
from
apps.web.internal.db
import
DB
from
apps.web.internal.db
import
DB
####################
####################
# Chat DB Schema
# Chat DB Schema
####################
####################
...
@@ -62,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
...
@@ -62,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
class
ChatTable
:
class
ChatTable
:
def
__init__
(
self
,
db
):
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
=
db
db
.
create_tables
([
Chat
])
db
.
create_tables
([
Chat
])
def
insert_new_chat
(
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
def
insert_new_chat
(
self
,
user_id
:
str
,
form_data
:
ChatForm
)
->
Optional
[
ChatModel
]:
id
=
str
(
uuid
.
uuid4
())
id
=
str
(
uuid
.
uuid4
())
chat
=
ChatModel
(
chat
=
ChatModel
(
**
{
**
{
"id"
:
id
,
"id"
:
id
,
"user_id"
:
user_id
,
"user_id"
:
user_id
,
"title"
:
form_data
.
chat
[
"title"
]
"title"
:
form_data
.
chat
[
"title"
]
if
"title"
in
if
"title"
in
form_data
.
chat
form_data
.
chat
else
"New Chat"
,
else
"New Chat"
,
"chat"
:
json
.
dumps
(
form_data
.
chat
),
"chat"
:
json
.
dumps
(
form_data
.
chat
),
"timestamp"
:
int
(
time
.
time
()),
"timestamp"
:
int
(
time
.
time
()),
}
})
)
result
=
Chat
.
create
(
**
chat
.
model_dump
())
result
=
Chat
.
create
(
**
chat
.
model_dump
())
return
chat
if
result
else
None
return
chat
if
result
else
None
...
@@ -111,27 +109,25 @@ class ChatTable:
...
@@ -111,27 +109,25 @@ class ChatTable:
except
:
except
:
return
None
return
None
def
get_chat_lists_by_user_id
(
def
get_chat_lists_by_user_id
(
self
,
self
,
user_id
:
str
,
skip
:
int
=
0
,
limit
:
int
=
50
user_id
:
str
,
)
->
List
[
ChatModel
]:
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ChatModel
]:
return
[
return
[
ChatModel
(
**
model_to_dict
(
chat
))
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
().
where
(
for
chat
in
Chat
.
select
()
Chat
.
user_id
==
user_id
).
order_by
(
Chat
.
timestamp
.
desc
())
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
timestamp
.
desc
())
# .limit(limit)
# .limit(limit)
# .offset(skip)
# .offset(skip)
]
]
def
get_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
def
get_all_chats_by_user_id
(
self
,
user_id
:
str
)
->
List
[
ChatModel
]:
return
[
return
[
ChatModel
(
**
model_to_dict
(
chat
))
ChatModel
(
**
model_to_dict
(
chat
))
for
chat
in
Chat
.
select
().
where
(
for
chat
in
Chat
.
select
()
Chat
.
user_id
==
user_id
).
order_by
(
Chat
.
timestamp
.
desc
())
.
where
(
Chat
.
user_id
==
user_id
)
.
order_by
(
Chat
.
timestamp
.
desc
())
]
]
def
get_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
def
get_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
Optional
[
ChatModel
]:
try
:
try
:
chat
=
Chat
.
get
(
Chat
.
id
==
id
,
Chat
.
user_id
==
user_id
)
chat
=
Chat
.
get
(
Chat
.
id
==
id
,
Chat
.
user_id
==
user_id
)
return
ChatModel
(
**
model_to_dict
(
chat
))
return
ChatModel
(
**
model_to_dict
(
chat
))
...
@@ -146,7 +142,8 @@ class ChatTable:
...
@@ -146,7 +142,8 @@ class ChatTable:
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
def
delete_chat_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
try
:
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
)
&
(
Chat
.
user_id
==
user_id
))
query
=
Chat
.
delete
().
where
((
Chat
.
id
==
id
)
&
(
Chat
.
user_id
==
user_id
))
query
.
execute
()
# Remove the rows, return number of rows removed.
query
.
execute
()
# Remove the rows, return number of rows removed.
return
True
return
True
...
...
backend/apps/web/models/modelfiles.py
View file @
07cc7f15
...
@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel):
...
@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel):
class
ModelfilesTable
:
class
ModelfilesTable
:
def
__init__
(
self
,
db
):
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
=
db
self
.
db
.
create_tables
([
Modelfile
])
self
.
db
.
create_tables
([
Modelfile
])
def
insert_new_modelfile
(
def
insert_new_modelfile
(
self
,
user_id
:
str
,
form_data
:
ModelfileForm
self
,
user_id
:
str
,
)
->
Optional
[
ModelfileModel
]:
form_data
:
ModelfileForm
)
->
Optional
[
ModelfileModel
]:
if
"tagName"
in
form_data
.
modelfile
:
if
"tagName"
in
form_data
.
modelfile
:
modelfile
=
ModelfileModel
(
modelfile
=
ModelfileModel
(
**
{
**
{
...
@@ -72,8 +73,7 @@ class ModelfilesTable:
...
@@ -72,8 +73,7 @@ class ModelfilesTable:
"tag_name"
:
form_data
.
modelfile
[
"tagName"
],
"tag_name"
:
form_data
.
modelfile
[
"tagName"
],
"modelfile"
:
json
.
dumps
(
form_data
.
modelfile
),
"modelfile"
:
json
.
dumps
(
form_data
.
modelfile
),
"timestamp"
:
int
(
time
.
time
()),
"timestamp"
:
int
(
time
.
time
()),
}
})
)
try
:
try
:
result
=
Modelfile
.
create
(
**
modelfile
.
model_dump
())
result
=
Modelfile
.
create
(
**
modelfile
.
model_dump
())
...
@@ -87,28 +87,29 @@ class ModelfilesTable:
...
@@ -87,28 +87,29 @@ class ModelfilesTable:
else
:
else
:
return
None
return
None
def
get_modelfile_by_tag_name
(
self
,
tag_name
:
str
)
->
Optional
[
ModelfileModel
]:
def
get_modelfile_by_tag_name
(
self
,
tag_name
:
str
)
->
Optional
[
ModelfileModel
]:
try
:
try
:
modelfile
=
Modelfile
.
get
(
Modelfile
.
tag_name
==
tag_name
)
modelfile
=
Modelfile
.
get
(
Modelfile
.
tag_name
==
tag_name
)
return
ModelfileModel
(
**
model_to_dict
(
modelfile
))
return
ModelfileModel
(
**
model_to_dict
(
modelfile
))
except
:
except
:
return
None
return
None
def
get_modelfiles
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ModelfileResponse
]:
def
get_modelfiles
(
self
,
skip
:
int
=
0
,
limit
:
int
=
50
)
->
List
[
ModelfileResponse
]:
return
[
return
[
ModelfileResponse
(
ModelfileResponse
(
**
{
**
{
**
model_to_dict
(
modelfile
),
**
model_to_dict
(
modelfile
),
"modelfile"
:
json
.
loads
(
modelfile
.
modelfile
),
"modelfile"
:
}
json
.
loads
(
modelfile
.
modelfile
),
)
})
for
modelfile
in
Modelfile
.
select
()
for
modelfile
in
Modelfile
.
select
()
# .limit(limit).offset(skip)
# .limit(limit).offset(skip)
]
]
def
update_modelfile_by_tag_name
(
def
update_modelfile_by_tag_name
(
self
,
tag_name
:
str
,
modelfile
:
dict
self
,
tag_name
:
str
,
modelfile
:
dict
)
->
Optional
[
ModelfileModel
]:
)
->
Optional
[
ModelfileModel
]:
try
:
try
:
query
=
Modelfile
.
update
(
query
=
Modelfile
.
update
(
modelfile
=
json
.
dumps
(
modelfile
),
modelfile
=
json
.
dumps
(
modelfile
),
...
...
backend/apps/web/models/prompts.py
View file @
07cc7f15
...
@@ -47,13 +47,13 @@ class PromptForm(BaseModel):
...
@@ -47,13 +47,13 @@ class PromptForm(BaseModel):
class
PromptsTable
:
class
PromptsTable
:
def
__init__
(
self
,
db
):
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
=
db
self
.
db
.
create_tables
([
Prompt
])
self
.
db
.
create_tables
([
Prompt
])
def
insert_new_prompt
(
def
insert_new_prompt
(
self
,
user_id
:
str
,
self
,
user_id
:
str
,
form_data
:
PromptForm
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
)
->
Optional
[
PromptModel
]:
prompt
=
PromptModel
(
prompt
=
PromptModel
(
**
{
**
{
"user_id"
:
user_id
,
"user_id"
:
user_id
,
...
@@ -61,8 +61,7 @@ class PromptsTable:
...
@@ -61,8 +61,7 @@ class PromptsTable:
"title"
:
form_data
.
title
,
"title"
:
form_data
.
title
,
"content"
:
form_data
.
content
,
"content"
:
form_data
.
content
,
"timestamp"
:
int
(
time
.
time
()),
"timestamp"
:
int
(
time
.
time
()),
}
})
)
try
:
try
:
result
=
Prompt
.
create
(
**
prompt
.
model_dump
())
result
=
Prompt
.
create
(
**
prompt
.
model_dump
())
...
@@ -82,14 +81,13 @@ class PromptsTable:
...
@@ -82,14 +81,13 @@ class PromptsTable:
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
def
get_prompts
(
self
)
->
List
[
PromptModel
]:
return
[
return
[
PromptModel
(
**
model_to_dict
(
prompt
))
PromptModel
(
**
model_to_dict
(
prompt
))
for
prompt
in
Prompt
.
select
()
for
prompt
in
Prompt
.
select
()
# .limit(limit).offset(skip)
# .limit(limit).offset(skip)
]
]
def
update_prompt_by_command
(
def
update_prompt_by_command
(
self
,
command
:
str
,
form_data
:
PromptForm
self
,
command
:
str
,
)
->
Optional
[
PromptModel
]:
form_data
:
PromptForm
)
->
Optional
[
PromptModel
]:
try
:
try
:
query
=
Prompt
.
update
(
query
=
Prompt
.
update
(
title
=
form_data
.
title
,
title
=
form_data
.
title
,
...
...
backend/apps/web/models/users.py
View file @
07cc7f15
...
@@ -8,7 +8,6 @@ from utils.misc import get_gravatar_url
...
@@ -8,7 +8,6 @@ from utils.misc import get_gravatar_url
from
apps.web.internal.db
import
DB
from
apps.web.internal.db
import
DB
from
apps.web.models.chats
import
Chats
from
apps.web.models.chats
import
Chats
####################
####################
# User DB Schema
# User DB Schema
####################
####################
...
@@ -46,13 +45,16 @@ class UserRoleUpdateForm(BaseModel):
...
@@ -46,13 +45,16 @@ class UserRoleUpdateForm(BaseModel):
class
UsersTable
:
class
UsersTable
:
def
__init__
(
self
,
db
):
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
=
db
self
.
db
.
create_tables
([
User
])
self
.
db
.
create_tables
([
User
])
def
insert_new_user
(
def
insert_new_user
(
self
,
self
,
id
:
str
,
name
:
str
,
email
:
str
,
role
:
str
=
"pending"
id
:
str
,
)
->
Optional
[
UserModel
]:
name
:
str
,
email
:
str
,
role
:
str
=
"pending"
)
->
Optional
[
UserModel
]:
user
=
UserModel
(
user
=
UserModel
(
**
{
**
{
"id"
:
id
,
"id"
:
id
,
...
@@ -61,8 +63,7 @@ class UsersTable:
...
@@ -61,8 +63,7 @@ class UsersTable:
"role"
:
role
,
"role"
:
role
,
"profile_image_url"
:
get_gravatar_url
(
email
),
"profile_image_url"
:
get_gravatar_url
(
email
),
"timestamp"
:
int
(
time
.
time
()),
"timestamp"
:
int
(
time
.
time
()),
}
})
)
result
=
User
.
create
(
**
user
.
model_dump
())
result
=
User
.
create
(
**
user
.
model_dump
())
if
result
:
if
result
:
return
user
return
user
...
@@ -92,7 +93,8 @@ class UsersTable:
...
@@ -92,7 +93,8 @@ class UsersTable:
def
get_num_users
(
self
)
->
Optional
[
int
]:
def
get_num_users
(
self
)
->
Optional
[
int
]:
return
User
.
select
().
count
()
return
User
.
select
().
count
()
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
def
update_user_role_by_id
(
self
,
id
:
str
,
role
:
str
)
->
Optional
[
UserModel
]:
try
:
try
:
query
=
User
.
update
(
role
=
role
).
where
(
User
.
id
==
id
)
query
=
User
.
update
(
role
=
role
).
where
(
User
.
id
==
id
)
query
.
execute
()
query
.
execute
()
...
@@ -110,7 +112,8 @@ class UsersTable:
...
@@ -110,7 +112,8 @@ class UsersTable:
if
result
:
if
result
:
# Delete User
# Delete User
query
=
User
.
delete
().
where
(
User
.
id
==
id
)
query
=
User
.
delete
().
where
(
User
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
query
.
execute
(
)
# Remove the rows, return number of rows removed.
return
True
return
True
else
:
else
:
...
...
backend/apps/web/routers/auths.py
View file @
07cc7f15
...
@@ -8,7 +8,6 @@ from pydantic import BaseModel
...
@@ -8,7 +8,6 @@ from pydantic import BaseModel
import
time
import
time
import
uuid
import
uuid
from
apps.web.models.auths
import
(
from
apps.web.models.auths
import
(
SigninForm
,
SigninForm
,
SignupForm
,
SignupForm
,
...
@@ -19,12 +18,10 @@ from apps.web.models.auths import (
...
@@ -19,12 +18,10 @@ from apps.web.models.auths import (
)
)
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
utils.utils
import
get_password_hash
,
get_current_user
,
create_token
from
utils.utils
import
get_password_hash
,
get_current_user
,
create_token
from
utils.misc
import
get_gravatar_url
,
validate_email_format
from
utils.misc
import
get_gravatar_url
,
validate_email_format
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
router
=
APIRouter
()
router
=
APIRouter
()
############################
############################
...
@@ -49,9 +46,8 @@ async def get_session_user(user=Depends(get_current_user)):
...
@@ -49,9 +46,8 @@ async def get_session_user(user=Depends(get_current_user)):
@
router
.
post
(
"/update/password"
,
response_model
=
bool
)
@
router
.
post
(
"/update/password"
,
response_model
=
bool
)
async
def
update_password
(
async
def
update_password
(
form_data
:
UpdatePasswordForm
,
form_data
:
UpdatePasswordForm
,
session_user
=
Depends
(
get_current_user
)
session_user
=
Depends
(
get_current_user
)):
):
if
session_user
:
if
session_user
:
user
=
Auths
.
authenticate_user
(
session_user
.
email
,
form_data
.
password
)
user
=
Auths
.
authenticate_user
(
session_user
.
email
,
form_data
.
password
)
...
@@ -101,9 +97,8 @@ async def signup(request: Request, form_data: SignupForm):
...
@@ -101,9 +97,8 @@ async def signup(request: Request, form_data: SignupForm):
try
:
try
:
role
=
"admin"
if
Users
.
get_num_users
()
==
0
else
"pending"
role
=
"admin"
if
Users
.
get_num_users
()
==
0
else
"pending"
hashed
=
get_password_hash
(
form_data
.
password
)
hashed
=
get_password_hash
(
form_data
.
password
)
user
=
Auths
.
insert_new_auth
(
user
=
Auths
.
insert_new_auth
(
form_data
.
email
.
lower
(),
form_data
.
email
.
lower
(),
hashed
,
form_data
.
name
,
role
hashed
,
form_data
.
name
,
role
)
)
if
user
:
if
user
:
token
=
create_token
(
data
=
{
"email"
:
user
.
email
})
token
=
create_token
(
data
=
{
"email"
:
user
.
email
})
...
@@ -120,14 +115,15 @@ async def signup(request: Request, form_data: SignupForm):
...
@@ -120,14 +115,15 @@ async def signup(request: Request, form_data: SignupForm):
}
}
else
:
else
:
raise
HTTPException
(
raise
HTTPException
(
500
,
detail
=
ERROR_MESSAGES
.
CREATE_USER_ERROR
500
,
detail
=
ERROR_MESSAGES
.
CREATE_USER_ERROR
)
)
except
Exception
as
err
:
except
Exception
as
err
:
raise
HTTPException
(
500
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
err
))
raise
HTTPException
(
500
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
err
))
else
:
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EMAIL_TAKEN
)
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
EMAIL_TAKEN
)
else
:
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_EMAIL_FORMAT
)
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
INVALID_EMAIL_FORMAT
)
else
:
else
:
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
raise
HTTPException
(
400
,
detail
=
ERROR_MESSAGES
.
ACCESS_PROHIBITED
)
...
...
backend/apps/web/routers/chats.py
View file @
07cc7f15
...
@@ -17,8 +17,7 @@ from apps.web.models.chats import (
...
@@ -17,8 +17,7 @@ from apps.web.models.chats import (
)
)
from
utils.utils
import
(
from
utils.utils
import
(
bearer_scheme
,
bearer_scheme
,
)
)
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
router
=
APIRouter
()
router
=
APIRouter
()
...
@@ -30,8 +29,7 @@ router = APIRouter()
...
@@ -30,8 +29,7 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
ChatTitleIdResponse
])
@
router
.
get
(
"/"
,
response_model
=
List
[
ChatTitleIdResponse
])
async
def
get_user_chats
(
async
def
get_user_chats
(
user
=
Depends
(
get_current_user
),
skip
:
int
=
0
,
limit
:
int
=
50
user
=
Depends
(
get_current_user
),
skip
:
int
=
0
,
limit
:
int
=
50
):
):
return
Chats
.
get_chat_lists_by_user_id
(
user
.
id
,
skip
,
limit
)
return
Chats
.
get_chat_lists_by_user_id
(
user
.
id
,
skip
,
limit
)
...
@@ -43,8 +41,9 @@ async def get_user_chats(
...
@@ -43,8 +41,9 @@ async def get_user_chats(
@
router
.
get
(
"/all"
,
response_model
=
List
[
ChatResponse
])
@
router
.
get
(
"/all"
,
response_model
=
List
[
ChatResponse
])
async
def
get_all_user_chats
(
user
=
Depends
(
get_current_user
)):
async
def
get_all_user_chats
(
user
=
Depends
(
get_current_user
)):
return
[
return
[
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
ChatResponse
(
**
{
for
chat
in
Chats
.
get_all_chats_by_user_id
(
user
.
id
)
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)
})
for
chat
in
Chats
.
get_all_chats_by_user_id
(
user
.
id
)
]
]
...
@@ -69,11 +68,12 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
...
@@ -69,11 +68,12 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
if
chat
:
if
chat
:
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)
})
else
:
else
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
detail
=
ERROR_MESSAGES
.
NOT_FOUND
detail
=
ERROR_MESSAGES
.
NOT_FOUND
)
)
############################
############################
...
@@ -82,15 +82,17 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
...
@@ -82,15 +82,17 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
@
router
.
post
(
"/{id}"
,
response_model
=
Optional
[
ChatResponse
])
@
router
.
post
(
"/{id}"
,
response_model
=
Optional
[
ChatResponse
])
async
def
update_chat_by_id
(
async
def
update_chat_by_id
(
id
:
str
,
id
:
str
,
form_data
:
ChatForm
,
user
=
Depends
(
get_current_user
)
form_data
:
ChatForm
,
):
user
=
Depends
(
get_current_user
)
):
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
chat
=
Chats
.
get_chat_by_id_and_user_id
(
id
,
user
.
id
)
if
chat
:
if
chat
:
updated_chat
=
{
**
json
.
loads
(
chat
.
chat
),
**
form_data
.
chat
}
updated_chat
=
{
**
json
.
loads
(
chat
.
chat
),
**
form_data
.
chat
}
chat
=
Chats
.
update_chat_by_id
(
id
,
updated_chat
)
chat
=
Chats
.
update_chat_by_id
(
id
,
updated_chat
)
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)})
return
ChatResponse
(
**
{
**
chat
.
model_dump
(),
"chat"
:
json
.
loads
(
chat
.
chat
)
})
else
:
else
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
...
backend/apps/web/routers/configs.py
View file @
07cc7f15
...
@@ -10,7 +10,6 @@ import uuid
...
@@ -10,7 +10,6 @@ import uuid
from
apps.web.models.users
import
Users
from
apps.web.models.users
import
Users
from
utils.utils
import
get_password_hash
,
get_current_user
,
create_token
from
utils.utils
import
get_password_hash
,
get_current_user
,
create_token
from
utils.misc
import
get_gravatar_url
,
validate_email_format
from
utils.misc
import
get_gravatar_url
,
validate_email_format
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
...
@@ -28,9 +27,9 @@ class SetDefaultModelsForm(BaseModel):
...
@@ -28,9 +27,9 @@ class SetDefaultModelsForm(BaseModel):
@
router
.
post
(
"/default/models"
,
response_model
=
str
)
@
router
.
post
(
"/default/models"
,
response_model
=
str
)
async
def
set_global_default_models
(
async
def
set_global_default_models
(
request
:
Request
,
request
:
Request
,
form_data
:
SetDefaultModelsForm
,
user
=
Depends
(
get_current_user
)
form_data
:
SetDefaultModelsForm
,
):
user
=
Depends
(
get_current_user
)
):
if
user
.
role
==
"admin"
:
if
user
.
role
==
"admin"
:
request
.
app
.
state
.
DEFAULT_MODELS
=
form_data
.
models
request
.
app
.
state
.
DEFAULT_MODELS
=
form_data
.
models
return
request
.
app
.
state
.
DEFAULT_MODELS
return
request
.
app
.
state
.
DEFAULT_MODELS
...
...
backend/apps/web/routers/modelfiles.py
View file @
07cc7f15
...
@@ -24,7 +24,9 @@ router = APIRouter()
...
@@ -24,7 +24,9 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
ModelfileResponse
])
@
router
.
get
(
"/"
,
response_model
=
List
[
ModelfileResponse
])
async
def
get_modelfiles
(
skip
:
int
=
0
,
limit
:
int
=
50
,
user
=
Depends
(
get_current_user
)):
async
def
get_modelfiles
(
skip
:
int
=
0
,
limit
:
int
=
50
,
user
=
Depends
(
get_current_user
)):
return
Modelfiles
.
get_modelfiles
(
skip
,
limit
)
return
Modelfiles
.
get_modelfiles
(
skip
,
limit
)
...
@@ -34,9 +36,8 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_curren
...
@@ -34,9 +36,8 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_curren
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
ModelfileResponse
])
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
ModelfileResponse
])
async
def
create_new_modelfile
(
async
def
create_new_modelfile
(
form_data
:
ModelfileForm
,
form_data
:
ModelfileForm
,
user
=
Depends
(
get_current_user
)
user
=
Depends
(
get_current_user
)):
):
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
@@ -49,9 +50,9 @@ async def create_new_modelfile(
...
@@ -49,9 +50,9 @@ async def create_new_modelfile(
return
ModelfileResponse
(
return
ModelfileResponse
(
**
{
**
{
**
modelfile
.
model_dump
(),
**
modelfile
.
model_dump
(),
"modelfile"
:
json
.
loads
(
modelfile
.
modelfile
),
"modelfile"
:
}
json
.
loads
(
modelfile
.
modelfile
),
)
}
)
else
:
else
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
@@ -65,16 +66,17 @@ async def create_new_modelfile(
...
@@ -65,16 +66,17 @@ async def create_new_modelfile(
@
router
.
post
(
"/"
,
response_model
=
Optional
[
ModelfileResponse
])
@
router
.
post
(
"/"
,
response_model
=
Optional
[
ModelfileResponse
])
async
def
get_modelfile_by_tag_name
(
form_data
:
ModelfileTagNameForm
,
user
=
Depends
(
get_current_user
)):
async
def
get_modelfile_by_tag_name
(
form_data
:
ModelfileTagNameForm
,
user
=
Depends
(
get_current_user
)):
modelfile
=
Modelfiles
.
get_modelfile_by_tag_name
(
form_data
.
tag_name
)
modelfile
=
Modelfiles
.
get_modelfile_by_tag_name
(
form_data
.
tag_name
)
if
modelfile
:
if
modelfile
:
return
ModelfileResponse
(
return
ModelfileResponse
(
**
{
**
{
**
modelfile
.
model_dump
(),
**
modelfile
.
model_dump
(),
"modelfile"
:
json
.
loads
(
modelfile
.
modelfile
),
"modelfile"
:
}
json
.
loads
(
modelfile
.
modelfile
),
)
}
)
else
:
else
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
@@ -88,9 +90,8 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depend
...
@@ -88,9 +90,8 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depend
@
router
.
post
(
"/update"
,
response_model
=
Optional
[
ModelfileResponse
])
@
router
.
post
(
"/update"
,
response_model
=
Optional
[
ModelfileResponse
])
async
def
update_modelfile_by_tag_name
(
async
def
update_modelfile_by_tag_name
(
form_data
:
ModelfileUpdateForm
,
form_data
:
ModelfileUpdateForm
,
user
=
Depends
(
get_current_user
)
user
=
Depends
(
get_current_user
)):
):
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
@@ -104,15 +105,14 @@ async def update_modelfile_by_tag_name(
...
@@ -104,15 +105,14 @@ async def update_modelfile_by_tag_name(
}
}
modelfile
=
Modelfiles
.
update_modelfile_by_tag_name
(
modelfile
=
Modelfiles
.
update_modelfile_by_tag_name
(
form_data
.
tag_name
,
updated_modelfile
form_data
.
tag_name
,
updated_modelfile
)
)
return
ModelfileResponse
(
return
ModelfileResponse
(
**
{
**
{
**
modelfile
.
model_dump
(),
**
modelfile
.
model_dump
(),
"modelfile"
:
json
.
loads
(
modelfile
.
modelfile
),
"modelfile"
:
}
json
.
loads
(
modelfile
.
modelfile
),
)
}
)
else
:
else
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
@@ -126,9 +126,8 @@ async def update_modelfile_by_tag_name(
...
@@ -126,9 +126,8 @@ async def update_modelfile_by_tag_name(
@
router
.
delete
(
"/delete"
,
response_model
=
bool
)
@
router
.
delete
(
"/delete"
,
response_model
=
bool
)
async
def
delete_modelfile_by_tag_name
(
async
def
delete_modelfile_by_tag_name
(
form_data
:
ModelfileTagNameForm
,
form_data
:
ModelfileTagNameForm
,
user
=
Depends
(
get_current_user
)
user
=
Depends
(
get_current_user
)):
):
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
...
backend/apps/web/routers/prompts.py
View file @
07cc7f15
...
@@ -6,7 +6,6 @@ from fastapi import APIRouter
...
@@ -6,7 +6,6 @@ from fastapi import APIRouter
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
import
json
import
json
from
apps.web.models.prompts
import
Prompts
,
PromptForm
,
PromptModel
from
apps.web.models.prompts
import
Prompts
,
PromptForm
,
PromptModel
from
utils.utils
import
get_current_user
from
utils.utils
import
get_current_user
...
@@ -30,7 +29,8 @@ async def get_prompts(user=Depends(get_current_user)):
...
@@ -30,7 +29,8 @@ async def get_prompts(user=Depends(get_current_user)):
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
PromptModel
])
@
router
.
post
(
"/create"
,
response_model
=
Optional
[
PromptModel
])
async
def
create_new_prompt
(
form_data
:
PromptForm
,
user
=
Depends
(
get_current_user
)):
async
def
create_new_prompt
(
form_data
:
PromptForm
,
user
=
Depends
(
get_current_user
)):
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
@@ -79,9 +79,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
...
@@ -79,9 +79,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
@
router
.
post
(
"/{command}/update"
,
response_model
=
Optional
[
PromptModel
])
@
router
.
post
(
"/{command}/update"
,
response_model
=
Optional
[
PromptModel
])
async
def
update_prompt_by_command
(
async
def
update_prompt_by_command
(
command
:
str
,
command
:
str
,
form_data
:
PromptForm
,
user
=
Depends
(
get_current_user
)
form_data
:
PromptForm
,
):
user
=
Depends
(
get_current_user
)
):
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
@@ -104,7 +104,8 @@ async def update_prompt_by_command(
...
@@ -104,7 +104,8 @@ async def update_prompt_by_command(
@
router
.
delete
(
"/{command}/delete"
,
response_model
=
bool
)
@
router
.
delete
(
"/{command}/delete"
,
response_model
=
bool
)
async
def
delete_prompt_by_command
(
command
:
str
,
user
=
Depends
(
get_current_user
)):
async
def
delete_prompt_by_command
(
command
:
str
,
user
=
Depends
(
get_current_user
)):
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
status_code
=
status
.
HTTP_401_UNAUTHORIZED
,
...
...
backend/apps/web/routers/users.py
View file @
07cc7f15
...
@@ -11,11 +11,9 @@ import uuid
...
@@ -11,11 +11,9 @@ import uuid
from
apps.web.models.users
import
UserModel
,
UserRoleUpdateForm
,
Users
from
apps.web.models.users
import
UserModel
,
UserRoleUpdateForm
,
Users
from
apps.web.models.auths
import
Auths
from
apps.web.models.auths
import
Auths
from
utils.utils
import
get_current_user
from
utils.utils
import
get_current_user
from
constants
import
ERROR_MESSAGES
from
constants
import
ERROR_MESSAGES
router
=
APIRouter
()
router
=
APIRouter
()
############################
############################
...
@@ -24,7 +22,9 @@ router = APIRouter()
...
@@ -24,7 +22,9 @@ router = APIRouter()
@
router
.
get
(
"/"
,
response_model
=
List
[
UserModel
])
@
router
.
get
(
"/"
,
response_model
=
List
[
UserModel
])
async
def
get_users
(
skip
:
int
=
0
,
limit
:
int
=
50
,
user
=
Depends
(
get_current_user
)):
async
def
get_users
(
skip
:
int
=
0
,
limit
:
int
=
50
,
user
=
Depends
(
get_current_user
)):
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_403_FORBIDDEN
,
status_code
=
status
.
HTTP_403_FORBIDDEN
,
...
@@ -39,9 +39,8 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
...
@@ -39,9 +39,8 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_use
@
router
.
post
(
"/update/role"
,
response_model
=
Optional
[
UserModel
])
@
router
.
post
(
"/update/role"
,
response_model
=
Optional
[
UserModel
])
async
def
update_user_role
(
async
def
update_user_role
(
form_data
:
UserRoleUpdateForm
,
form_data
:
UserRoleUpdateForm
,
user
=
Depends
(
get_current_user
)
user
=
Depends
(
get_current_user
)):
):
if
user
.
role
!=
"admin"
:
if
user
.
role
!=
"admin"
:
raise
HTTPException
(
raise
HTTPException
(
status_code
=
status
.
HTTP_403_FORBIDDEN
,
status_code
=
status
.
HTTP_403_FORBIDDEN
,
...
...
backend/apps/web/routers/utils.py
View file @
07cc7f15
...
@@ -9,12 +9,10 @@ import os
...
@@ -9,12 +9,10 @@ import os
import
aiohttp
import
aiohttp
import
json
import
json
from
utils.misc
import
calculate_sha256
from
utils.misc
import
calculate_sha256
from
config
import
OLLAMA_API_BASE_URL
from
config
import
OLLAMA_API_BASE_URL
router
=
APIRouter
()
router
=
APIRouter
()
...
@@ -42,7 +40,10 @@ def parse_huggingface_url(hf_url):
...
@@ -42,7 +40,10 @@ def parse_huggingface_url(hf_url):
return
None
return
None
async
def
download_file_stream
(
url
,
file_path
,
file_name
,
chunk_size
=
1024
*
1024
):
async
def
download_file_stream
(
url
,
file_path
,
file_name
,
chunk_size
=
1024
*
1024
):
done
=
False
done
=
False
if
os
.
path
.
exists
(
file_path
):
if
os
.
path
.
exists
(
file_path
):
...
@@ -56,7 +57,8 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
...
@@ -56,7 +57,8 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
async
with
aiohttp
.
ClientSession
(
timeout
=
timeout
)
as
session
:
async
with
aiohttp
.
ClientSession
(
timeout
=
timeout
)
as
session
:
async
with
session
.
get
(
url
,
headers
=
headers
)
as
response
:
async
with
session
.
get
(
url
,
headers
=
headers
)
as
response
:
total_size
=
int
(
response
.
headers
.
get
(
"content-length"
,
0
))
+
current_size
total_size
=
int
(
response
.
headers
.
get
(
"content-length"
,
0
))
+
current_size
with
open
(
file_path
,
"ab+"
)
as
file
:
with
open
(
file_path
,
"ab+"
)
as
file
:
async
for
data
in
response
.
content
.
iter_chunked
(
chunk_size
):
async
for
data
in
response
.
content
.
iter_chunked
(
chunk_size
):
...
@@ -89,9 +91,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
...
@@ -89,9 +91,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
@
router
.
get
(
"/download"
)
@
router
.
get
(
"/download"
)
async
def
download
(
async
def
download
(
url
:
str
,
):
url
:
str
,
):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name
=
parse_huggingface_url
(
url
)
file_name
=
parse_huggingface_url
(
url
)
...
@@ -161,4 +161,5 @@ async def upload(file: UploadFile = File(...)):
...
@@ -161,4 +161,5 @@ async def upload(file: UploadFile = File(...)):
res
=
{
"error"
:
str
(
e
)}
res
=
{
"error"
:
str
(
e
)}
yield
f
"data:
{
json
.
dumps
(
res
)
}
\n\n
"
yield
f
"data:
{
json
.
dumps
(
res
)
}
\n\n
"
return
StreamingResponse
(
file_write_stream
(),
media_type
=
"text/event-stream"
)
return
StreamingResponse
(
file_write_stream
(),
media_type
=
"text/event-stream"
)
backend/config.py
View file @
07cc7f15
...
@@ -19,9 +19,8 @@ ENV = os.environ.get("ENV", "dev")
...
@@ -19,9 +19,8 @@ ENV = os.environ.get("ENV", "dev")
# OLLAMA_API_BASE_URL
# OLLAMA_API_BASE_URL
####################################
####################################
OLLAMA_API_BASE_URL
=
os
.
environ
.
get
(
OLLAMA_API_BASE_URL
=
os
.
environ
.
get
(
"OLLAMA_API_BASE_URL"
,
"OLLAMA_API_BASE_URL"
,
"http://localhost:11434/api"
"http://localhost:11434/api"
)
)
if
ENV
==
"prod"
:
if
ENV
==
"prod"
:
if
OLLAMA_API_BASE_URL
==
"/ollama/api"
:
if
OLLAMA_API_BASE_URL
==
"/ollama/api"
:
...
...
backend/constants.py
View file @
07cc7f15
...
@@ -6,6 +6,7 @@ class MESSAGES(str, Enum):
...
@@ -6,6 +6,7 @@ class MESSAGES(str, Enum):
class
ERROR_MESSAGES
(
str
,
Enum
):
class
ERROR_MESSAGES
(
str
,
Enum
):
def
__str__
(
self
)
->
str
:
def
__str__
(
self
)
->
str
:
return
super
().
__str__
()
return
super
().
__str__
()
...
@@ -29,8 +30,7 @@ class ERROR_MESSAGES(str, Enum):
...
@@ -29,8 +30,7 @@ class ERROR_MESSAGES(str, Enum):
UNAUTHORIZED
=
"401 Unauthorized"
UNAUTHORIZED
=
"401 Unauthorized"
ACCESS_PROHIBITED
=
"You do not have permission to access this resource. Please contact your administrator for assistance."
ACCESS_PROHIBITED
=
"You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED
=
(
ACTION_PROHIBITED
=
(
"The requested action has been restricted as a security measure."
"The requested action has been restricted as a security measure."
)
)
NOT_FOUND
=
"We could not find what you're looking for :/"
NOT_FOUND
=
"We could not find what you're looking for :/"
USER_NOT_FOUND
=
"We could not find what you're looking for :/"
USER_NOT_FOUND
=
"We could not find what you're looking for :/"
MALICIOUS
=
"Unusual activities detected, please try again in a few minutes."
MALICIOUS
=
"Unusual activities detected, please try again in a few minutes."
backend/main.py
View file @
07cc7f15
...
@@ -12,6 +12,7 @@ import time
...
@@ -12,6 +12,7 @@ import time
class
SPAStaticFiles
(
StaticFiles
):
class
SPAStaticFiles
(
StaticFiles
):
async
def
get_response
(
self
,
path
:
str
,
scope
):
async
def
get_response
(
self
,
path
:
str
,
scope
):
try
:
try
:
return
await
super
().
get_response
(
path
,
scope
)
return
await
super
().
get_response
(
path
,
scope
)
...
@@ -47,4 +48,6 @@ async def check_url(request: Request, call_next):
...
@@ -47,4 +48,6 @@ async def check_url(request: Request, call_next):
app
.
mount
(
"/api/v1"
,
webui_app
)
app
.
mount
(
"/api/v1"
,
webui_app
)
app
.
mount
(
"/ollama/api"
,
WSGIMiddleware
(
ollama_app
))
app
.
mount
(
"/ollama/api"
,
WSGIMiddleware
(
ollama_app
))
app
.
mount
(
"/"
,
SPAStaticFiles
(
directory
=
"../build"
,
html
=
True
),
name
=
"spa-static-files"
)
app
.
mount
(
"/"
,
SPAStaticFiles
(
directory
=
"../build"
,
html
=
True
),
name
=
"spa-static-files"
)
backend/utils/utils.py
View file @
07cc7f15
...
@@ -23,16 +23,16 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
...
@@ -23,16 +23,16 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def
verify_password
(
plain_password
,
hashed_password
):
def
verify_password
(
plain_password
,
hashed_password
):
return
(
return
(
pwd_context
.
verify
(
plain_password
,
hashed_password
)
pwd_context
.
verify
(
plain_password
,
hashed_password
)
if
hashed_password
else
None
if
hashed_password
else
None
)
)
def
get_password_hash
(
password
):
def
get_password_hash
(
password
):
return
pwd_context
.
hash
(
password
)
return
pwd_context
.
hash
(
password
)
def
create_token
(
data
:
dict
,
expires_delta
:
Union
[
timedelta
,
None
]
=
None
)
->
str
:
def
create_token
(
data
:
dict
,
expires_delta
:
Union
[
timedelta
,
None
]
=
None
)
->
str
:
payload
=
data
.
copy
()
payload
=
data
.
copy
()
if
expires_delta
:
if
expires_delta
:
...
@@ -45,17 +45,20 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
...
@@ -45,17 +45,20 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st
def
decode_token
(
token
:
str
)
->
Optional
[
dict
]:
def
decode_token
(
token
:
str
)
->
Optional
[
dict
]:
try
:
try
:
decoded
=
jwt
.
decode
(
token
,
JWT_SECRET_KEY
,
options
=
{
"verify_signature"
:
False
})
decoded
=
jwt
.
decode
(
token
,
JWT_SECRET_KEY
,
options
=
{
"verify_signature"
:
False
})
return
decoded
return
decoded
except
Exception
as
e
:
except
Exception
as
e
:
return
None
return
None
def
extract_token_from_auth_header
(
auth_header
:
str
):
def
extract_token_from_auth_header
(
auth_header
:
str
):
return
auth_header
[
len
(
"Bearer "
)
:]
return
auth_header
[
len
(
"Bearer "
):]
def
get_current_user
(
auth_token
:
HTTPAuthorizationCredentials
=
Depends
(
HTTPBearer
())):
def
get_current_user
(
auth_token
:
HTTPAuthorizationCredentials
=
Depends
(
HTTPBearer
())):
data
=
decode_token
(
auth_token
.
credentials
)
data
=
decode_token
(
auth_token
.
credentials
)
if
data
!=
None
and
"email"
in
data
:
if
data
!=
None
and
"email"
in
data
:
user
=
Users
.
get_user_by_email
(
data
[
"email"
])
user
=
Users
.
get_user_by_email
(
data
[
"email"
])
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment