Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
288d8a3e
Commit
288d8a3e
authored
May 19, 2024
by
Timothy J. Baek
Browse files
feat: memory backend
parent
1fb5ef99
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
296 additions
and
2 deletions
+296
-2
backend/apps/web/internal/migrations/008_add_memory.py
backend/apps/web/internal/migrations/008_add_memory.py
+53
-0
backend/apps/web/main.py
backend/apps/web/main.py
+5
-0
backend/apps/web/models/memories.py
backend/apps/web/models/memories.py
+109
-0
backend/apps/web/routers/memories.py
backend/apps/web/routers/memories.py
+117
-0
backend/main.py
backend/main.py
+12
-2
No files found.
backend/apps/web/internal/migrations/008_add_memory.py
0 → 100644
View file @
288d8a3e
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from
contextlib
import
suppress
import
peewee
as
pw
from
peewee_migrate
import
Migrator
with
suppress
(
ImportError
):
import
playhouse.postgres_ext
as
pw_pext
def
migrate
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
@
migrator
.
create_model
class
Memory
(
pw
.
Model
):
id
=
pw
.
CharField
(
max_length
=
255
,
unique
=
True
)
user_id
=
pw
.
CharField
(
max_length
=
255
)
content
=
pw
.
TextField
(
null
=
False
)
updated_at
=
pw
.
BigIntegerField
(
null
=
False
)
created_at
=
pw
.
BigIntegerField
(
null
=
False
)
class
Meta
:
table_name
=
"memory"
def
rollback
(
migrator
:
Migrator
,
database
:
pw
.
Database
,
*
,
fake
=
False
):
"""Write your rollback migrations here."""
migrator
.
remove_model
(
"memory"
)
backend/apps/web/main.py
View file @
288d8a3e
...
...
@@ -9,6 +9,7 @@ from apps.web.routers import (
modelfiles
,
prompts
,
configs
,
memories
,
utils
,
)
from
config
import
(
...
...
@@ -41,6 +42,7 @@ app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
=
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app
.
add_middleware
(
CORSMiddleware
,
allow_origins
=
origins
,
...
...
@@ -52,9 +54,12 @@ app.add_middleware(
app
.
include_router
(
auths
.
router
,
prefix
=
"/auths"
,
tags
=
[
"auths"
])
app
.
include_router
(
users
.
router
,
prefix
=
"/users"
,
tags
=
[
"users"
])
app
.
include_router
(
chats
.
router
,
prefix
=
"/chats"
,
tags
=
[
"chats"
])
app
.
include_router
(
documents
.
router
,
prefix
=
"/documents"
,
tags
=
[
"documents"
])
app
.
include_router
(
modelfiles
.
router
,
prefix
=
"/modelfiles"
,
tags
=
[
"modelfiles"
])
app
.
include_router
(
prompts
.
router
,
prefix
=
"/prompts"
,
tags
=
[
"prompts"
])
app
.
include_router
(
memories
.
router
,
prefix
=
"/memories"
,
tags
=
[
"memories"
])
app
.
include_router
(
configs
.
router
,
prefix
=
"/configs"
,
tags
=
[
"configs"
])
app
.
include_router
(
utils
.
router
,
prefix
=
"/utils"
,
tags
=
[
"utils"
])
...
...
backend/apps/web/models/memories.py
0 → 100644
View file @
288d8a3e
from
pydantic
import
BaseModel
from
peewee
import
*
from
playhouse.shortcuts
import
model_to_dict
from
typing
import
List
,
Union
,
Optional
from
apps.web.internal.db
import
DB
from
apps.web.models.chats
import
Chats
import
time
import
uuid
####################
# Memory DB Schema
####################
class
Memory
(
Model
):
id
=
CharField
(
unique
=
True
)
user_id
=
CharField
()
content
=
TextField
()
updated_at
=
BigIntegerField
()
created_at
=
BigIntegerField
()
class
Meta
:
database
=
DB
class
MemoryModel
(
BaseModel
):
id
:
str
user_id
:
str
content
:
str
updated_at
:
int
# timestamp in epoch
created_at
:
int
# timestamp in epoch
####################
# Forms
####################
class
MemoriesTable
:
def
__init__
(
self
,
db
):
self
.
db
=
db
self
.
db
.
create_tables
([
Memory
])
def
insert_new_memory
(
self
,
user_id
:
str
,
content
:
str
,
)
->
Optional
[
MemoryModel
]:
id
=
str
(
uuid
.
uuid4
())
memory
=
MemoryModel
(
**
{
"id"
:
id
,
"user_id"
:
user_id
,
"content"
:
content
,
"created_at"
:
int
(
time
.
time
()),
"updated_at"
:
int
(
time
.
time
()),
}
)
result
=
Memory
.
create
(
**
memory
.
model_dump
())
if
result
:
return
memory
else
:
return
None
def
get_memories
(
self
)
->
List
[
MemoryModel
]:
try
:
memories
=
Memory
.
select
()
return
[
MemoryModel
(
**
model_to_dict
(
memory
))
for
memory
in
memories
]
except
:
return
None
def
get_memories_by_user_id
(
self
,
user_id
:
str
)
->
List
[
MemoryModel
]:
try
:
memories
=
Memory
.
select
().
where
(
Memory
.
user_id
==
user_id
)
return
[
MemoryModel
(
**
model_to_dict
(
memory
))
for
memory
in
memories
]
except
:
return
None
def
get_memory_by_id
(
self
,
id
)
->
Optional
[
MemoryModel
]:
try
:
memory
=
Memory
.
get
(
Memory
.
id
==
id
)
return
MemoryModel
(
**
model_to_dict
(
memory
))
except
:
return
None
def
delete_memory_by_id
(
self
,
id
:
str
)
->
bool
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
)
query
.
execute
()
# Remove the rows, return number of rows removed.
return
True
except
:
return
False
def
delete_memory_by_id_and_user_id
(
self
,
id
:
str
,
user_id
:
str
)
->
bool
:
try
:
query
=
Memory
.
delete
().
where
(
Memory
.
id
==
id
,
Memory
.
user_id
==
user_id
)
query
.
execute
()
return
True
except
:
return
False
Memories
=
MemoriesTable
(
DB
)
backend/apps/web/routers/memories.py
0 → 100644
View file @
288d8a3e
from
fastapi
import
Response
,
Request
from
fastapi
import
Depends
,
FastAPI
,
HTTPException
,
status
from
datetime
import
datetime
,
timedelta
from
typing
import
List
,
Union
,
Optional
from
fastapi
import
APIRouter
from
pydantic
import
BaseModel
import
logging
from
apps.web.models.memories
import
Memories
,
MemoryModel
from
utils.utils
import
get_verified_user
from
constants
import
ERROR_MESSAGES
from
config
import
SRC_LOG_LEVELS
,
CHROMA_CLIENT
log
=
logging
.
getLogger
(
__name__
)
log
.
setLevel
(
SRC_LOG_LEVELS
[
"MODELS"
])
router
=
APIRouter
()
@
router
.
get
(
"/ef"
)
async
def
get_embeddings
(
request
:
Request
):
return
{
"result"
:
request
.
app
.
state
.
EMBEDDING_FUNCTION
(
"hello world"
)}
############################
# GetMemories
############################
@
router
.
get
(
"/"
,
response_model
=
List
[
MemoryModel
])
async
def
get_memories
(
user
=
Depends
(
get_verified_user
)):
return
Memories
.
get_memories_by_user_id
(
user
.
id
)
############################
# AddMemory
############################
class
AddMemoryForm
(
BaseModel
):
content
:
str
@
router
.
post
(
"/add"
,
response_model
=
Optional
[
MemoryModel
])
async
def
add_memory
(
request
:
Request
,
form_data
:
AddMemoryForm
,
user
=
Depends
(
get_verified_user
)
):
memory
=
Memories
.
insert_new_memory
(
user
.
id
,
form_data
.
content
)
memory_embedding
=
request
.
app
.
state
.
EMBEDDING_FUNCTION
(
memory
.
content
)
collection
=
CHROMA_CLIENT
.
get_or_create_collection
(
name
=
f
"user-memory-
{
user
.
id
}
"
)
collection
.
upsert
(
documents
=
[
memory
.
content
],
ids
=
[
memory
.
id
],
embeddings
=
[
memory_embedding
],
metadatas
=
[{
"created_at"
:
memory
.
created_at
}],
)
return
memory
############################
# QueryMemory
############################
class
QueryMemoryForm
(
BaseModel
):
content
:
str
@
router
.
post
(
"/query"
,
response_model
=
Optional
[
MemoryModel
])
async
def
add_memory
(
request
:
Request
,
form_data
:
QueryMemoryForm
,
user
=
Depends
(
get_verified_user
)
):
query_embedding
=
request
.
app
.
state
.
EMBEDDING_FUNCTION
(
form_data
.
content
)
collection
=
CHROMA_CLIENT
.
get_or_create_collection
(
name
=
f
"user-memory-
{
user
.
id
}
"
)
results
=
collection
.
query
(
query_embeddings
=
[
query_embedding
],
n_results
=
1
,
# how many results to return
)
return
results
############################
# ResetMemoryFromVectorDB
############################
@
router
.
get
(
"/reset"
,
response_model
=
bool
)
async
def
reset_memory_from_vector_db
(
request
:
Request
,
user
=
Depends
(
get_verified_user
)
):
CHROMA_CLIENT
.
delete_collection
(
f
"user-memory-
{
user
.
id
}
"
)
collection
=
CHROMA_CLIENT
.
get_or_create_collection
(
name
=
f
"user-memory-
{
user
.
id
}
"
)
memories
=
Memories
.
get_memories_by_user_id
(
user
.
id
)
for
memory
in
memories
:
memory_embedding
=
request
.
app
.
state
.
EMBEDDING_FUNCTION
(
memory
.
content
)
collection
.
upsert
(
documents
=
[
memory
.
content
],
ids
=
[
memory
.
id
],
embeddings
=
[
memory_embedding
],
)
return
True
############################
# DeleteUserById
############################
@
router
.
delete
(
"/{memory_id}"
,
response_model
=
bool
)
async
def
delete_memory_by_id
(
memory_id
:
str
,
user
=
Depends
(
get_verified_user
)):
return
Memories
.
delete_memory_by_id_and_user_id
(
memory_id
,
user
.
id
)
backend/main.py
View file @
288d8a3e
...
...
@@ -238,9 +238,15 @@ async def check_url(request: Request, call_next):
return
response
app
.
mount
(
"/api/v1"
,
webui_app
)
app
.
mount
(
"/litellm/api"
,
litellm_app
)
@
app
.
middleware
(
"http"
)
async
def
update_embedding_function
(
request
:
Request
,
call_next
):
response
=
await
call_next
(
request
)
if
"/embedding/update"
in
request
.
url
.
path
:
webui_app
.
state
.
EMBEDDING_FUNCTION
=
rag_app
.
state
.
EMBEDDING_FUNCTION
return
response
app
.
mount
(
"/litellm/api"
,
litellm_app
)
app
.
mount
(
"/ollama"
,
ollama_app
)
app
.
mount
(
"/openai/api"
,
openai_app
)
...
...
@@ -248,6 +254,10 @@ app.mount("/images/api/v1", images_app)
app
.
mount
(
"/audio/api/v1"
,
audio_app
)
app
.
mount
(
"/rag/api/v1"
,
rag_app
)
app
.
mount
(
"/api/v1"
,
webui_app
)
webui_app
.
state
.
EMBEDDING_FUNCTION
=
rag_app
.
state
.
EMBEDDING_FUNCTION
@
app
.
get
(
"/api/config"
)
async
def
get_app_config
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment