Unverified Commit 9bcd4ce5 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3559 from open-webui/dev

0.3.8
parents 824966ad b38abf23
...@@ -2,13 +2,10 @@ import json ...@@ -2,13 +2,10 @@ import json
import logging import logging
from typing import Optional from typing import Optional
import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -32,7 +29,7 @@ class ModelParams(BaseModel): ...@@ -32,7 +29,7 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table # ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel): class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/favicon.png" profile_image_url: Optional[str] = "/static/favicon.png"
description: Optional[str] = None description: Optional[str] = None
""" """
...@@ -46,38 +43,37 @@ class ModelMeta(BaseModel): ...@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass pass
class Model(pw.Model): class Model(Base):
id = pw.TextField(unique=True) __tablename__ = "model"
id = Column(Text, primary_key=True)
""" """
The model's id as used in the API. If set to an existing model, it will override the model. The model's id as used in the API. If set to an existing model, it will override the model.
""" """
user_id = pw.TextField() user_id = Column(Text)
base_model_id = pw.TextField(null=True) base_model_id = Column(Text, nullable=True)
""" """
An optional pointer to the actual model that should be used when proxying requests. An optional pointer to the actual model that should be used when proxying requests.
""" """
name = pw.TextField() name = Column(Text)
""" """
The human-readable display name of the model. The human-readable display name of the model.
""" """
params = JSONField() params = Column(JSONField)
""" """
Holds a JSON encoded blob of parameters, see `ModelParams`. Holds a JSON encoded blob of parameters, see `ModelParams`.
""" """
meta = JSONField() meta = Column(JSONField)
""" """
Holds a JSON encoded blob of metadata, see `ModelMeta`. Holds a JSON encoded blob of metadata, see `ModelMeta`.
""" """
updated_at = BigIntegerField() updated_at = Column(BigInteger)
created_at = BigIntegerField() created_at = Column(BigInteger)
class Meta:
database = DB
class ModelModel(BaseModel): class ModelModel(BaseModel):
...@@ -92,6 +88,8 @@ class ModelModel(BaseModel): ...@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -115,12 +113,6 @@ class ModelForm(BaseModel): ...@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def __init__(
self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase,
):
self.db = db
self.db.create_tables([Model])
def insert_new_model( def insert_new_model(
self, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str
...@@ -134,34 +126,50 @@ class ModelsTable: ...@@ -134,34 +126,50 @@ class ModelsTable:
} }
) )
try: try:
result = Model.create(**model.model_dump())
if result: with get_db() as db:
return model
else: result = Model(**model.model_dump())
return None db.add(result)
db.commit()
db.refresh(result)
if result:
return ModelModel.model_validate(result)
else:
return None
except Exception as e: except Exception as e:
print(e) print(e)
return None return None
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()] with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
model = Model.get(Model.id == id) with get_db() as db:
return ModelModel(**model_to_dict(model))
model = db.get(Model, id)
return ModelModel.model_validate(model)
except: except:
return None return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try: try:
# update only the fields that are present in the model with get_db() as db:
query = Model.update(**model.model_dump()).where(Model.id == id) # update only the fields that are present in the model
query.execute() result = (
db.query(Model)
model = Model.get(Model.id == id) .filter_by(id=id)
return ModelModel(**model_to_dict(model)) .update(model.model_dump(exclude={"id"}, exclude_none=True))
)
db.commit()
model = db.get(Model, id)
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e: except Exception as e:
print(e) print(e)
...@@ -169,11 +177,14 @@ class ModelsTable: ...@@ -169,11 +177,14 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
query = Model.delete().where(Model.id == id) with get_db() as db:
query.execute()
return True db.query(Model).filter_by(id=id).delete()
db.commit()
return True
except: except:
return False return False
Models = ModelsTable(DB) Models = ModelsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import * from typing import List, Optional
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time import time
from utils.utils import decode_token from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB from apps.webui.internal.db import Base, get_db
import json import json
...@@ -16,15 +13,14 @@ import json ...@@ -16,15 +13,14 @@ import json
#################### ####################
class Prompt(Model): class Prompt(Base):
command = CharField(unique=True) __tablename__ = "prompt"
user_id = CharField()
title = TextField()
content = TextField()
timestamp = BigIntegerField()
class Meta: command = Column(String, primary_key=True)
database = DB user_id = Column(String)
title = Column(Text)
content = Column(Text)
timestamp = Column(BigInteger)
class PromptModel(BaseModel): class PromptModel(BaseModel):
...@@ -34,6 +30,8 @@ class PromptModel(BaseModel): ...@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content: str content: str
timestamp: int # timestamp in epoch timestamp: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -48,10 +46,6 @@ class PromptForm(BaseModel): ...@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt( def insert_new_prompt(
self, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
...@@ -66,53 +60,60 @@ class PromptsTable: ...@@ -66,53 +60,60 @@ class PromptsTable:
) )
try: try:
result = Prompt.create(**prompt.model_dump()) with get_db() as db:
if result:
return prompt result = Prompt(**prompt.dict())
else: db.add(result)
return None db.commit()
except: db.refresh(result)
if result:
return PromptModel.model_validate(result)
else:
return None
except Exception as e:
return None return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try: try:
prompt = Prompt.get(Prompt.command == command) with get_db() as db:
return PromptModel(**model_to_dict(prompt))
prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt)
except: except:
return None return None
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
return [ with get_db() as db:
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select() return [
# .limit(limit).offset(skip) PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
query = Prompt.update( with get_db() as db:
title=form_data.title,
content=form_data.content, prompt = db.query(Prompt).filter_by(command=command).first()
timestamp=int(time.time()), prompt.title = form_data.title
).where(Prompt.command == command) prompt.content = form_data.content
prompt.timestamp = int(time.time())
query.execute() db.commit()
return PromptModel.model_validate(prompt)
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
except: except:
return None return None
def delete_prompt_by_command(self, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
try: try:
query = Prompt.delete().where((Prompt.command == command)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed.
db.query(Prompt).filter_by(command=command).delete()
db.commit()
return True return True
except: except:
return False return False
Prompts = PromptsTable(DB) Prompts = PromptsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional from typing import List, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json import json
import uuid import uuid
import time import time
import logging import logging
from apps.webui.internal.db import DB from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Tag(Model): class Tag(Base):
id = CharField(unique=True) __tablename__ = "tag"
name = CharField()
user_id = CharField()
data = TextField(null=True)
class Meta: id = Column(String, primary_key=True)
database = DB name = Column(String)
user_id = Column(String)
data = Column(Text, nullable=True)
class ChatIdTag(Model): class ChatIdTag(Base):
id = CharField(unique=True) __tablename__ = "chatidtag"
tag_name = CharField()
chat_id = CharField()
user_id = CharField()
timestamp = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
class TagModel(BaseModel): class TagModel(BaseModel):
...@@ -47,6 +45,8 @@ class TagModel(BaseModel): ...@@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id: str user_id: str
data: Optional[str] = None data: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel): class ChatIdTagModel(BaseModel):
id: str id: str
...@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel): ...@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id: str user_id: str
timestamp: int timestamp: int
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -75,28 +77,31 @@ class ChatTagsResponse(BaseModel): ...@@ -75,28 +77,31 @@ class ChatTagsResponse(BaseModel):
class TagTable: class TagTable:
def __init__(self, db):
self.db = db
db.create_tables([Tag, ChatIdTag])
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
id = str(uuid.uuid4()) with get_db() as db:
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: id = str(uuid.uuid4())
result = Tag.create(**tag.model_dump()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
if result: try:
return tag result = Tag(**tag.model_dump())
else: db.add(result)
db.commit()
db.refresh(result)
if result:
return TagModel.model_validate(result)
else:
return None
except Exception as e:
return None return None
except Exception as e:
return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(
self, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
tag = Tag.get(Tag.name == name, Tag.user_id == user_id) with get_db() as db:
return TagModel(**model_to_dict(tag)) tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception as e: except Exception as e:
return None return None
...@@ -118,82 +123,110 @@ class TagTable: ...@@ -118,82 +123,110 @@ class TagTable:
} }
) )
try: try:
result = ChatIdTag.create(**chatIdTag.model_dump()) with get_db() as db:
if result: result = ChatIdTag(**chatIdTag.model_dump())
return chatIdTag db.add(result)
else: db.commit()
return None db.refresh(result)
if result:
return ChatIdTagModel.model_validate(result)
else:
return None
except: except:
return None return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
tag_names = [ with get_db() as db:
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name tag_names = [
for chat_id_tag in ChatIdTag.select() chat_id_tag.tag_name
.where(ChatIdTag.user_id == user_id) for chat_id_tag in (
.order_by(ChatIdTag.timestamp.desc()) db.query(ChatIdTag)
] .filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
return [ .all()
TagModel(**model_to_dict(tag)) )
for tag in Tag.select() ]
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names)) return [
] TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:
tag_names = [ with get_db() as db:
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select() tag_names = [
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)) chat_id_tag.tag_name
.order_by(ChatIdTag.timestamp.desc()) for chat_id_tag in (
] db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
return [ .order_by(ChatIdTag.timestamp.desc())
TagModel(**model_to_dict(tag)) .all()
for tag in Tag.select() )
.where(Tag.user_id == user_id) ]
.where(Tag.name.in_(tag_names))
] return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> Optional[ChatIdTagModel]: ) -> List[ChatIdTagModel]:
return [ with get_db() as db:
ChatIdTagModel(**model_to_dict(chat_id_tag))
for chat_id_tag in ChatIdTag.select() return [
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name)) ChatIdTagModel.model_validate(chat_id_tag)
.order_by(ChatIdTag.timestamp.desc()) for chat_id_tag in (
] db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id( def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> int: ) -> int:
return ( with get_db() as db:
ChatIdTag.select()
.where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)) return (
.count() db.query(ChatIdTag)
) .filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try: try:
query = ChatIdTag.delete().where( with get_db() as db:
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id) res = (
) db.query(ChatIdTag)
res = query.execute() # Remove the rows, return number of rows removed. .filter_by(tag_name=tag_name, user_id=user_id)
log.debug(f"res: {res}") .delete()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
) )
query.execute() # Remove the rows, return number of rows removed. log.debug(f"res: {res}")
db.commit()
return True tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
return False return False
...@@ -202,23 +235,25 @@ class TagTable: ...@@ -202,23 +235,25 @@ class TagTable:
self, tag_name: str, chat_id: str, user_id: str self, tag_name: str, chat_id: str, user_id: str
) -> bool: ) -> bool:
try: try:
query = ChatIdTag.delete().where( with get_db() as db:
(ChatIdTag.tag_name == tag_name)
& (ChatIdTag.chat_id == chat_id) res = (
& (ChatIdTag.user_id == user_id) db.query(ChatIdTag)
) .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
res = query.execute() # Remove the rows, return number of rows removed. .delete()
log.debug(f"res: {res}") )
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) db.commit()
if tag_count == 0:
# Remove tag item from Tag col as well tag_count = self.count_chat_ids_by_tag_name_and_user_id(
query = Tag.delete().where( tag_name, user_id
(Tag.name == tag_name) & (Tag.user_id == user_id)
) )
query.execute() # Remove the rows, return number of rows removed. if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
return False return False
...@@ -234,4 +269,4 @@ class TagTable: ...@@ -234,4 +269,4 @@ class TagTable:
return True return True
Tags = TagTable(DB) Tags = TagTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import * from typing import List, Optional
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
...@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Tool(Model): class Tool(Base):
id = CharField(unique=True) __tablename__ = "tool"
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
valves = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB user_id = Column(String)
name = Column(Text)
content = Column(Text)
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ToolMeta(BaseModel): class ToolMeta(BaseModel):
...@@ -51,6 +50,8 @@ class ToolModel(BaseModel): ...@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -78,61 +79,68 @@ class ToolValves(BaseModel): ...@@ -78,61 +79,68 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool( def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict] self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try: with get_db() as db:
result = Tool.create(**tool.model_dump())
if result: tool = ToolModel(
return tool **{
else: **form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool(**tool.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ToolModel.model_validate(result)
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
tool = Tool.get(Tool.id == id) with get_db() as db:
return ToolModel(**model_to_dict(tool))
tool = db.get(Tool, id)
return ToolModel.model_validate(tool)
except: except:
return None return None
def get_tools(self) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] with get_db() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
tool = Tool.get(Tool.id == id) with get_db() as db:
return tool.valves if tool.valves else {}
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
query = Tool.update( with get_db() as db:
**{"valves": valves},
updated_at=int(time.time()), db.query(Tool).filter_by(id=id).update(
).where(Tool.id == id) {"valves": valves, "updated_at": int(time.time())}
query.execute() )
db.commit()
tool = Tool.get(Tool.id == id) return self.get_tool_by_id(id)
return ToolValves(**model_to_dict(tool))
except: except:
return None return None
...@@ -141,7 +149,7 @@ class ToolsTable: ...@@ -141,7 +149,7 @@ class ToolsTable:
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
if "tools" not in user_settings: if "tools" not in user_settings:
...@@ -159,7 +167,7 @@ class ToolsTable: ...@@ -159,7 +167,7 @@ class ToolsTable:
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
if "tools" not in user_settings: if "tools" not in user_settings:
...@@ -170,8 +178,7 @@ class ToolsTable: ...@@ -170,8 +178,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database # Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings}) Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
return user_settings["tools"]["valves"][id] return user_settings["tools"]["valves"][id]
except Exception as e: except Exception as e:
...@@ -180,25 +187,27 @@ class ToolsTable: ...@@ -180,25 +187,27 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
query = Tool.update( with get_db() as db:
**updated, db.query(Tool).filter_by(id=id).update(
updated_at=int(time.time()), {**updated, "updated_at": int(time.time())}
).where(Tool.id == id) )
query.execute() db.commit()
tool = Tool.get(Tool.id == id) tool = db.query(Tool).get(id)
return ToolModel(**model_to_dict(tool)) db.refresh(tool)
return ToolModel.model_validate(tool)
except: except:
return None return None
def delete_tool_by_id(self, id: str) -> bool: def delete_tool_by_id(self, id: str) -> bool:
try: try:
query = Tool.delete().where((Tool.id == id)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed. db.query(Tool).filter_by(id=id).delete()
db.commit()
return True return True
except: except:
return False return False
Tools = ToolsTable(DB) Tools = ToolsTable()
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, parse_obj_as
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
...@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats ...@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
#################### ####################
class User(Model): class User(Base):
id = CharField(unique=True) __tablename__ = "user"
name = CharField()
email = CharField()
role = CharField()
profile_image_url = TextField()
last_active_at = BigIntegerField() id = Column(String, primary_key=True)
updated_at = BigIntegerField() name = Column(String)
created_at = BigIntegerField() email = Column(String)
role = Column(String)
profile_image_url = Column(Text)
api_key = CharField(null=True, unique=True) last_active_at = Column(BigInteger)
settings = JSONField(null=True) updated_at = Column(BigInteger)
info = JSONField(null=True) created_at = Column(BigInteger)
oauth_sub = TextField(null=True, unique=True) api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
class Meta: oauth_sub = Column(Text, unique=True)
database = DB
class UserSettings(BaseModel): class UserSettings(BaseModel):
...@@ -57,6 +57,8 @@ class UserModel(BaseModel): ...@@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub: Optional[str] = None oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel): ...@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class UsersTable: class UsersTable:
def __init__(self, db):
self.db = db
self.db.create_tables([User])
def insert_new_user( def insert_new_user(
self, self,
...@@ -89,77 +88,92 @@ class UsersTable: ...@@ -89,77 +88,92 @@ class UsersTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
user = UserModel( with get_db() as db:
**{ user = UserModel(
"id": id, **{
"name": name, "id": id,
"email": email, "name": name,
"role": role, "email": email,
"profile_image_url": profile_image_url, "role": role,
"last_active_at": int(time.time()), "profile_image_url": profile_image_url,
"created_at": int(time.time()), "last_active_at": int(time.time()),
"updated_at": int(time.time()), "created_at": int(time.time()),
"oauth_sub": oauth_sub, "updated_at": int(time.time()),
} "oauth_sub": oauth_sub,
) }
result = User.create(**user.model_dump()) )
if result: result = User(**user.model_dump())
return user db.add(result)
else: db.commit()
return None db.refresh(result)
if result:
return user
else:
return None
def get_user_by_id(self, id: str) -> Optional[UserModel]: def get_user_by_id(self, id: str) -> Optional[UserModel]:
try: try:
user = User.get(User.id == id) with get_db() as db:
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(id=id).first()
except: return UserModel.model_validate(user)
except Exception as e:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try: try:
user = User.get(User.api_key == api_key) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
user = User.get(User.email == email) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try: try:
user = User.get(User.oauth_sub == sub) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ with get_db() as db:
UserModel(**model_to_dict(user)) users = (
for user in User.select() db.query(User)
# .limit(limit).offset(skip) # .offset(skip).limit(limit)
] .all()
)
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
return User.select().count() with get_db() as db:
return db.query(User).count()
def get_first_user(self) -> UserModel: def get_first_user(self) -> UserModel:
try: try:
user = User.select().order_by(User.created_at).first() with get_db() as db:
return UserModel(**model_to_dict(user)) user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user)
except: except:
return None return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
query = User.update(role=role).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = User.get(User.id == id) user = db.query(User).filter_by(id=id).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
...@@ -167,23 +181,28 @@ class UsersTable: ...@@ -167,23 +181,28 @@ class UsersTable:
self, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
query = User.update(profile_image_url=profile_image_url).where( with get_db() as db:
User.id == id db.query(User).filter_by(id=id).update(
) {"profile_image_url": profile_image_url}
query.execute() )
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except: except:
return None return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try: try:
query = User.update(last_active_at=int(time.time())).where(User.id == id) with get_db() as db:
query.execute()
user = User.get(User.id == id) db.query(User).filter_by(id=id).update(
return UserModel(**model_to_dict(user)) {"last_active_at": int(time.time())}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except: except:
return None return None
...@@ -191,22 +210,25 @@ class UsersTable: ...@@ -191,22 +210,25 @@ class UsersTable:
self, id: str, oauth_sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
db.commit()
user = User.get(User.id == id) user = db.query(User).filter_by(id=id).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
query = User.update(**updated).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update(updated)
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(id=id).first()
except: return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception as e:
return None return None
def delete_user_by_id(self, id: str) -> bool: def delete_user_by_id(self, id: str) -> bool:
...@@ -215,9 +237,10 @@ class UsersTable: ...@@ -215,9 +237,10 @@ class UsersTable:
result = Chats.delete_chats_by_user_id(id) result = Chats.delete_chats_by_user_id(id)
if result: if result:
# Delete User with get_db() as db:
query = User.delete().where(User.id == id) # Delete User
query.execute() # Remove the rows, return number of rows removed. db.query(User).filter_by(id=id).delete()
db.commit()
return True return True
else: else:
...@@ -227,19 +250,20 @@ class UsersTable: ...@@ -227,19 +250,20 @@ class UsersTable:
def update_user_api_key_by_id(self, id: str, api_key: str) -> str: def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try: try:
query = User.update(api_key=api_key).where(User.id == id) with get_db() as db:
result = query.execute() result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try: try:
user = User.get(User.id == id) with get_db() as db:
return user.api_key user = db.query(User).filter_by(id=id).first()
except: return user.api_key
except Exception as e:
return None return None
Users = UsersTable(DB) Users = UsersTable()
...@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user ...@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse]) @router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id( async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50 user_id: str,
user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
): ):
return Chats.get_chat_list_by_user_id( return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit user_id, include_archived=True, skip=skip, limit=limit
...@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)): ...@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=List[ChatResponse]) @router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)): async def get_user_archived_chats(user=Depends(get_verified_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id) for chat in Chats.get_archived_chats_by_user_id(user.id)
...@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name( ...@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user) form_data: TagNameForm, user=Depends(get_verified_user)
): ):
print(form_data)
chat_ids = [ chat_ids = [
chat_id_tag.chat_id chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
......
...@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_ ...@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@router.post("/doc/update", response_model=Optional[DocumentResponse]) @router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name( async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user) name: str,
form_data: DocumentUpdateForm,
user=Depends(get_admin_user),
): ):
doc = Documents.update_doc_by_name(name, form_data) doc = Documents.update_doc_by_name(name, form_data)
if doc: if doc:
......
...@@ -50,10 +50,7 @@ router = APIRouter() ...@@ -50,10 +50,7 @@ router = APIRouter()
@router.post("/") @router.post("/")
def upload_file( def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
try: try:
unsanitized_filename = file.filename unsanitized_filename = file.filename
......
...@@ -233,7 +233,10 @@ async def delete_function_by_id( ...@@ -233,7 +233,10 @@ async def delete_function_by_id(
# delete the function file # delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
os.remove(function_path) try:
os.remove(function_path)
except:
pass
return result return result
......
...@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel): ...@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel):
@router.post("/add", response_model=Optional[MemoryModel]) @router.post("/add", response_model=Optional[MemoryModel])
async def add_memory( async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user) request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
): ):
memory = Memories.insert_new_memory(user.id, form_data.content) memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
......
...@@ -5,6 +5,7 @@ from typing import List, Union, Optional ...@@ -5,6 +5,7 @@ from typing import List, Union, Optional
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
...@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)): ...@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)):
@router.post("/add", response_model=Optional[ModelModel]) @router.post("/add", response_model=Optional[ModelModel])
async def add_new_model( async def add_new_model(
request: Request, form_data: ModelForm, user=Depends(get_admin_user) request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
): ):
if form_data.id in request.app.state.MODELS: if form_data.id in request.app.state.MODELS:
raise HTTPException( raise HTTPException(
...@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): ...@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/update", response_model=Optional[ModelModel]) @router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id( async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user) request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
): ):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(id)
if model: if model:
......
...@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): ...@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel]) @router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command( async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_admin_user) command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
): ):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt: if prompt:
......
...@@ -6,7 +6,6 @@ from fastapi import APIRouter ...@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.models.users import Users from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.utils import load_toolkit_module_by_id
...@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)): ...@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)):
@router.post("/create", response_model=Optional[ToolResponse]) @router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit( async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user) request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
): ):
if not form_data.id.isidentifier(): if not form_data.id.isidentifier():
raise HTTPException( raise HTTPException(
...@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ...@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[ToolModel]) @router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id( async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user) request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
): ):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
......
...@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)): ...@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict]) @router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user( async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user) form_data: dict, user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(user.id) user = Users.get_user_by_id(user.id)
...@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): ...@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.post("/{user_id}/update", response_model=Optional[UserModel]) @router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id( async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user) user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
): ):
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
......
from fastapi import APIRouter, UploadFile, File, Response from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
...@@ -10,7 +9,6 @@ import markdown ...@@ -10,7 +9,6 @@ import markdown
import black import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url from utils.misc import calculate_sha256, get_gravatar_url
...@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)): ...@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
if not isinstance(DB, SqliteDatabase): from apps.webui.internal.db import engine
if engine.name != "sqlite":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE, detail=ERROR_MESSAGES.DB_NOT_SQLITE,
) )
return FileResponse( return FileResponse(
DB.database, engine.url.database,
media_type="application/octet-stream", media_type="application/octet-stream",
filename="webui.db", filename="webui.db",
) )
......
...@@ -5,9 +5,8 @@ import importlib.metadata ...@@ -5,9 +5,8 @@ import importlib.metadata
import pkgutil import pkgutil
import chromadb import chromadb
from chromadb import Settings from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union from typing import TypeVar, Generic
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
...@@ -19,7 +18,6 @@ import markdown ...@@ -19,7 +18,6 @@ import markdown
import requests import requests
import shutil import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
#################################### ####################################
...@@ -395,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig( ...@@ -395,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"), os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
) )
OAUTH_USERNAME_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.username_claim",
os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
def load_oauth_providers(): def load_oauth_providers():
OAUTH_PROVIDERS.clear() OAUTH_PROVIDERS.clear()
...@@ -440,16 +450,27 @@ load_oauth_providers() ...@@ -440,16 +450,27 @@ load_oauth_providers()
STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve() STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png" frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
if frontend_favicon.exists(): if frontend_favicon.exists():
try: try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png") shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e: except Exception as e:
logging.error(f"An error occurred: {e}") logging.error(f"An error occurred: {e}")
else: else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}") logging.warning(f"Frontend favicon not found at {frontend_favicon}")
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
if frontend_splash.exists():
try:
shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend splash not found at {frontend_splash}")
#################################### ####################################
# CUSTOM_NAME # CUSTOM_NAME
#################################### ####################################
...@@ -474,6 +495,19 @@ if CUSTOM_NAME: ...@@ -474,6 +495,19 @@ if CUSTOM_NAME:
r.raw.decode_content = True r.raw.decode_content = True
shutil.copyfileobj(r.raw, f) shutil.copyfileobj(r.raw, f)
if "splash" in data:
url = (
f"https://api.openwebui.com{data['splash']}"
if data["splash"][0] == "/"
else data["splash"]
)
r = requests.get(url, stream=True)
if r.status_code == 200:
with open(f"{STATIC_DIR}/splash.png", "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
WEBUI_NAME = data["name"] WEBUI_NAME = data["name"]
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
...@@ -769,11 +803,14 @@ class BannerModel(BaseModel): ...@@ -769,11 +803,14 @@ class BannerModel(BaseModel):
timestamp: int timestamp: int
WEBUI_BANNERS = PersistentConfig( try:
"WEBUI_BANNERS", banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
"ui.banners", banners = [BannerModel(**banner) for banner in banners]
[BannerModel(**banner) for banner in json.loads("[]")], except Exception as e:
) print(f"Error loading WEBUI_BANNERS: {e}")
banners = []
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
SHOW_ADMIN_DETAILS = PersistentConfig( SHOW_ADMIN_DETAILS = PersistentConfig(
...@@ -885,6 +922,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get( ...@@ -885,6 +922,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
if WEBUI_AUTH and WEBUI_SECRET_KEY == "": if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document content extraction
####################################
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
"rag.CONTENT_EXTRACTION_ENGINE",
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
TIKA_SERVER_URL = PersistentConfig(
"TIKA_SERVER_URL",
"rag.tika_server_url",
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
#################################### ####################################
# RAG # RAG
#################################### ####################################
...@@ -1302,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig( ...@@ -1302,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig(
#################################### ####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
...@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum): ...@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED = ( OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature." "The Ollama API is disabled. Please enable it to use this feature."
) )
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"
This diff is collapsed.
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from apps.webui.models.auths import Auth
from apps.webui.models.chats import Chat
from apps.webui.models.documents import Document
from apps.webui.models.memories import Memory
from apps.webui.models.models import Model
from apps.webui.models.prompts import Prompt
from apps.webui.models.tags import Tag, ChatIdTag
from apps.webui.models.tools import Tool
from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
from config import DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Auth.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
DB_URL = DATABASE_URL
if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment