Unverified Commit 9bcd4ce5 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3559 from open-webui/dev

0.3.8
parents 824966ad b38abf23
...@@ -2,13 +2,10 @@ import json ...@@ -2,13 +2,10 @@ import json
import logging import logging
from typing import Optional from typing import Optional
import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -32,7 +29,7 @@ class ModelParams(BaseModel): ...@@ -32,7 +29,7 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table # ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel): class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/favicon.png" profile_image_url: Optional[str] = "/static/favicon.png"
description: Optional[str] = None description: Optional[str] = None
""" """
...@@ -46,38 +43,37 @@ class ModelMeta(BaseModel): ...@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass pass
class Model(pw.Model): class Model(Base):
id = pw.TextField(unique=True) __tablename__ = "model"
id = Column(Text, primary_key=True)
""" """
The model's id as used in the API. If set to an existing model, it will override the model. The model's id as used in the API. If set to an existing model, it will override the model.
""" """
user_id = pw.TextField() user_id = Column(Text)
base_model_id = pw.TextField(null=True) base_model_id = Column(Text, nullable=True)
""" """
An optional pointer to the actual model that should be used when proxying requests. An optional pointer to the actual model that should be used when proxying requests.
""" """
name = pw.TextField() name = Column(Text)
""" """
The human-readable display name of the model. The human-readable display name of the model.
""" """
params = JSONField() params = Column(JSONField)
""" """
Holds a JSON encoded blob of parameters, see `ModelParams`. Holds a JSON encoded blob of parameters, see `ModelParams`.
""" """
meta = JSONField() meta = Column(JSONField)
""" """
Holds a JSON encoded blob of metadata, see `ModelMeta`. Holds a JSON encoded blob of metadata, see `ModelMeta`.
""" """
updated_at = BigIntegerField() updated_at = Column(BigInteger)
created_at = BigIntegerField() created_at = Column(BigInteger)
class Meta:
database = DB
class ModelModel(BaseModel): class ModelModel(BaseModel):
...@@ -92,6 +88,8 @@ class ModelModel(BaseModel): ...@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -115,12 +113,6 @@ class ModelForm(BaseModel): ...@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def __init__(
self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase,
):
self.db = db
self.db.create_tables([Model])
def insert_new_model( def insert_new_model(
self, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str
...@@ -134,34 +126,50 @@ class ModelsTable: ...@@ -134,34 +126,50 @@ class ModelsTable:
} }
) )
try: try:
result = Model.create(**model.model_dump())
if result: with get_db() as db:
return model
else: result = Model(**model.model_dump())
return None db.add(result)
db.commit()
db.refresh(result)
if result:
return ModelModel.model_validate(result)
else:
return None
except Exception as e: except Exception as e:
print(e) print(e)
return None return None
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()] with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
model = Model.get(Model.id == id) with get_db() as db:
return ModelModel(**model_to_dict(model))
model = db.get(Model, id)
return ModelModel.model_validate(model)
except: except:
return None return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]: def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try: try:
# update only the fields that are present in the model with get_db() as db:
query = Model.update(**model.model_dump()).where(Model.id == id) # update only the fields that are present in the model
query.execute() result = (
db.query(Model)
model = Model.get(Model.id == id) .filter_by(id=id)
return ModelModel(**model_to_dict(model)) .update(model.model_dump(exclude={"id"}, exclude_none=True))
)
db.commit()
model = db.get(Model, id)
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e: except Exception as e:
print(e) print(e)
...@@ -169,11 +177,14 @@ class ModelsTable: ...@@ -169,11 +177,14 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
query = Model.delete().where(Model.id == id) with get_db() as db:
query.execute()
return True db.query(Model).filter_by(id=id).delete()
db.commit()
return True
except: except:
return False return False
Models = ModelsTable(DB) Models = ModelsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import * from typing import List, Optional
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time import time
from utils.utils import decode_token from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB from apps.webui.internal.db import Base, get_db
import json import json
...@@ -16,15 +13,14 @@ import json ...@@ -16,15 +13,14 @@ import json
#################### ####################
class Prompt(Model): class Prompt(Base):
command = CharField(unique=True) __tablename__ = "prompt"
user_id = CharField()
title = TextField()
content = TextField()
timestamp = BigIntegerField()
class Meta: command = Column(String, primary_key=True)
database = DB user_id = Column(String)
title = Column(Text)
content = Column(Text)
timestamp = Column(BigInteger)
class PromptModel(BaseModel): class PromptModel(BaseModel):
...@@ -34,6 +30,8 @@ class PromptModel(BaseModel): ...@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content: str content: str
timestamp: int # timestamp in epoch timestamp: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -48,10 +46,6 @@ class PromptForm(BaseModel): ...@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt( def insert_new_prompt(
self, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
...@@ -66,53 +60,60 @@ class PromptsTable: ...@@ -66,53 +60,60 @@ class PromptsTable:
) )
try: try:
result = Prompt.create(**prompt.model_dump()) with get_db() as db:
if result:
return prompt result = Prompt(**prompt.dict())
else: db.add(result)
return None db.commit()
except: db.refresh(result)
if result:
return PromptModel.model_validate(result)
else:
return None
except Exception as e:
return None return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try: try:
prompt = Prompt.get(Prompt.command == command) with get_db() as db:
return PromptModel(**model_to_dict(prompt))
prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt)
except: except:
return None return None
def get_prompts(self) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
return [ with get_db() as db:
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select() return [
# .limit(limit).offset(skip) PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
] ]
def update_prompt_by_command( def update_prompt_by_command(
self, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: try:
query = Prompt.update( with get_db() as db:
title=form_data.title,
content=form_data.content, prompt = db.query(Prompt).filter_by(command=command).first()
timestamp=int(time.time()), prompt.title = form_data.title
).where(Prompt.command == command) prompt.content = form_data.content
prompt.timestamp = int(time.time())
query.execute() db.commit()
return PromptModel.model_validate(prompt)
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
except: except:
return None return None
def delete_prompt_by_command(self, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
try: try:
query = Prompt.delete().where((Prompt.command == command)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed.
db.query(Prompt).filter_by(command=command).delete()
db.commit()
return True return True
except: except:
return False return False
Prompts = PromptsTable(DB) Prompts = PromptsTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional from typing import List, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json import json
import uuid import uuid
import time import time
import logging import logging
from apps.webui.internal.db import DB from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Tag(Model): class Tag(Base):
id = CharField(unique=True) __tablename__ = "tag"
name = CharField()
user_id = CharField()
data = TextField(null=True)
class Meta: id = Column(String, primary_key=True)
database = DB name = Column(String)
user_id = Column(String)
data = Column(Text, nullable=True)
class ChatIdTag(Model): class ChatIdTag(Base):
id = CharField(unique=True) __tablename__ = "chatidtag"
tag_name = CharField()
chat_id = CharField()
user_id = CharField()
timestamp = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
class TagModel(BaseModel): class TagModel(BaseModel):
...@@ -47,6 +45,8 @@ class TagModel(BaseModel): ...@@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id: str user_id: str
data: Optional[str] = None data: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel): class ChatIdTagModel(BaseModel):
id: str id: str
...@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel): ...@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id: str user_id: str
timestamp: int timestamp: int
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -75,28 +77,31 @@ class ChatTagsResponse(BaseModel): ...@@ -75,28 +77,31 @@ class ChatTagsResponse(BaseModel):
class TagTable: class TagTable:
def __init__(self, db):
self.db = db
db.create_tables([Tag, ChatIdTag])
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
id = str(uuid.uuid4()) with get_db() as db:
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: id = str(uuid.uuid4())
result = Tag.create(**tag.model_dump()) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
if result: try:
return tag result = Tag(**tag.model_dump())
else: db.add(result)
db.commit()
db.refresh(result)
if result:
return TagModel.model_validate(result)
else:
return None
except Exception as e:
return None return None
except Exception as e:
return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(
self, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
tag = Tag.get(Tag.name == name, Tag.user_id == user_id) with get_db() as db:
return TagModel(**model_to_dict(tag)) tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception as e: except Exception as e:
return None return None
...@@ -118,82 +123,110 @@ class TagTable: ...@@ -118,82 +123,110 @@ class TagTable:
} }
) )
try: try:
result = ChatIdTag.create(**chatIdTag.model_dump()) with get_db() as db:
if result: result = ChatIdTag(**chatIdTag.model_dump())
return chatIdTag db.add(result)
else: db.commit()
return None db.refresh(result)
if result:
return ChatIdTagModel.model_validate(result)
else:
return None
except: except:
return None return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
tag_names = [ with get_db() as db:
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name tag_names = [
for chat_id_tag in ChatIdTag.select() chat_id_tag.tag_name
.where(ChatIdTag.user_id == user_id) for chat_id_tag in (
.order_by(ChatIdTag.timestamp.desc()) db.query(ChatIdTag)
] .filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
return [ .all()
TagModel(**model_to_dict(tag)) )
for tag in Tag.select() ]
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names)) return [
] TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:
tag_names = [ with get_db() as db:
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select() tag_names = [
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)) chat_id_tag.tag_name
.order_by(ChatIdTag.timestamp.desc()) for chat_id_tag in (
] db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
return [ .order_by(ChatIdTag.timestamp.desc())
TagModel(**model_to_dict(tag)) .all()
for tag in Tag.select() )
.where(Tag.user_id == user_id) ]
.where(Tag.name.in_(tag_names))
] return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> Optional[ChatIdTagModel]: ) -> List[ChatIdTagModel]:
return [ with get_db() as db:
ChatIdTagModel(**model_to_dict(chat_id_tag))
for chat_id_tag in ChatIdTag.select() return [
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name)) ChatIdTagModel.model_validate(chat_id_tag)
.order_by(ChatIdTag.timestamp.desc()) for chat_id_tag in (
] db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id( def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> int: ) -> int:
return ( with get_db() as db:
ChatIdTag.select()
.where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)) return (
.count() db.query(ChatIdTag)
) .filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool: def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try: try:
query = ChatIdTag.delete().where( with get_db() as db:
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id) res = (
) db.query(ChatIdTag)
res = query.execute() # Remove the rows, return number of rows removed. .filter_by(tag_name=tag_name, user_id=user_id)
log.debug(f"res: {res}") .delete()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
) )
query.execute() # Remove the rows, return number of rows removed. log.debug(f"res: {res}")
db.commit()
return True tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
return False return False
...@@ -202,23 +235,25 @@ class TagTable: ...@@ -202,23 +235,25 @@ class TagTable:
self, tag_name: str, chat_id: str, user_id: str self, tag_name: str, chat_id: str, user_id: str
) -> bool: ) -> bool:
try: try:
query = ChatIdTag.delete().where( with get_db() as db:
(ChatIdTag.tag_name == tag_name)
& (ChatIdTag.chat_id == chat_id) res = (
& (ChatIdTag.user_id == user_id) db.query(ChatIdTag)
) .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
res = query.execute() # Remove the rows, return number of rows removed. .delete()
log.debug(f"res: {res}") )
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) db.commit()
if tag_count == 0:
# Remove tag item from Tag col as well tag_count = self.count_chat_ids_by_tag_name_and_user_id(
query = Tag.delete().where( tag_name, user_id
(Tag.name == tag_name) & (Tag.user_id == user_id)
) )
query.execute() # Remove the rows, return number of rows removed. if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
return False return False
...@@ -234,4 +269,4 @@ class TagTable: ...@@ -234,4 +269,4 @@ class TagTable:
return True return True
Tags = TagTable(DB) Tags = TagTable()
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import * from typing import List, Optional
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
...@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
#################### ####################
class Tool(Model): class Tool(Base):
id = CharField(unique=True) __tablename__ = "tool"
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
valves = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta: id = Column(String, primary_key=True)
database = DB user_id = Column(String)
name = Column(Text)
content = Column(Text)
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ToolMeta(BaseModel): class ToolMeta(BaseModel):
...@@ -51,6 +50,8 @@ class ToolModel(BaseModel): ...@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -78,61 +79,68 @@ class ToolValves(BaseModel): ...@@ -78,61 +79,68 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool( def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict] self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try: with get_db() as db:
result = Tool.create(**tool.model_dump())
if result: tool = ToolModel(
return tool **{
else: **form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool(**tool.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ToolModel.model_validate(result)
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
tool = Tool.get(Tool.id == id) with get_db() as db:
return ToolModel(**model_to_dict(tool))
tool = db.get(Tool, id)
return ToolModel.model_validate(tool)
except: except:
return None return None
def get_tools(self) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] with get_db() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
tool = Tool.get(Tool.id == id) with get_db() as db:
return tool.valves if tool.valves else {}
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
query = Tool.update( with get_db() as db:
**{"valves": valves},
updated_at=int(time.time()), db.query(Tool).filter_by(id=id).update(
).where(Tool.id == id) {"valves": valves, "updated_at": int(time.time())}
query.execute() )
db.commit()
tool = Tool.get(Tool.id == id) return self.get_tool_by_id(id)
return ToolValves(**model_to_dict(tool))
except: except:
return None return None
...@@ -141,7 +149,7 @@ class ToolsTable: ...@@ -141,7 +149,7 @@ class ToolsTable:
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
if "tools" not in user_settings: if "tools" not in user_settings:
...@@ -159,7 +167,7 @@ class ToolsTable: ...@@ -159,7 +167,7 @@ class ToolsTable:
) -> Optional[dict]: ) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump() user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
if "tools" not in user_settings: if "tools" not in user_settings:
...@@ -170,8 +178,7 @@ class ToolsTable: ...@@ -170,8 +178,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database # Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings}) Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
return user_settings["tools"]["valves"][id] return user_settings["tools"]["valves"][id]
except Exception as e: except Exception as e:
...@@ -180,25 +187,27 @@ class ToolsTable: ...@@ -180,25 +187,27 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
query = Tool.update( with get_db() as db:
**updated, db.query(Tool).filter_by(id=id).update(
updated_at=int(time.time()), {**updated, "updated_at": int(time.time())}
).where(Tool.id == id) )
query.execute() db.commit()
tool = Tool.get(Tool.id == id) tool = db.query(Tool).get(id)
return ToolModel(**model_to_dict(tool)) db.refresh(tool)
return ToolModel.model_validate(tool)
except: except:
return None return None
def delete_tool_by_id(self, id: str) -> bool: def delete_tool_by_id(self, id: str) -> bool:
try: try:
query = Tool.delete().where((Tool.id == id)) with get_db() as db:
query.execute() # Remove the rows, return number of rows removed. db.query(Tool).filter_by(id=id).delete()
db.commit()
return True return True
except: except:
return False return False
Tools = ToolsTable(DB) Tools = ToolsTable()
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, parse_obj_as
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
...@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats ...@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
#################### ####################
class User(Model): class User(Base):
id = CharField(unique=True) __tablename__ = "user"
name = CharField()
email = CharField()
role = CharField()
profile_image_url = TextField()
last_active_at = BigIntegerField() id = Column(String, primary_key=True)
updated_at = BigIntegerField() name = Column(String)
created_at = BigIntegerField() email = Column(String)
role = Column(String)
profile_image_url = Column(Text)
api_key = CharField(null=True, unique=True) last_active_at = Column(BigInteger)
settings = JSONField(null=True) updated_at = Column(BigInteger)
info = JSONField(null=True) created_at = Column(BigInteger)
oauth_sub = TextField(null=True, unique=True) api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
class Meta: oauth_sub = Column(Text, unique=True)
database = DB
class UserSettings(BaseModel): class UserSettings(BaseModel):
...@@ -57,6 +57,8 @@ class UserModel(BaseModel): ...@@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub: Optional[str] = None oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel): ...@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class UsersTable: class UsersTable:
def __init__(self, db):
self.db = db
self.db.create_tables([User])
def insert_new_user( def insert_new_user(
self, self,
...@@ -89,77 +88,92 @@ class UsersTable: ...@@ -89,77 +88,92 @@ class UsersTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
user = UserModel( with get_db() as db:
**{ user = UserModel(
"id": id, **{
"name": name, "id": id,
"email": email, "name": name,
"role": role, "email": email,
"profile_image_url": profile_image_url, "role": role,
"last_active_at": int(time.time()), "profile_image_url": profile_image_url,
"created_at": int(time.time()), "last_active_at": int(time.time()),
"updated_at": int(time.time()), "created_at": int(time.time()),
"oauth_sub": oauth_sub, "updated_at": int(time.time()),
} "oauth_sub": oauth_sub,
) }
result = User.create(**user.model_dump()) )
if result: result = User(**user.model_dump())
return user db.add(result)
else: db.commit()
return None db.refresh(result)
if result:
return user
else:
return None
def get_user_by_id(self, id: str) -> Optional[UserModel]: def get_user_by_id(self, id: str) -> Optional[UserModel]:
try: try:
user = User.get(User.id == id) with get_db() as db:
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(id=id).first()
except: return UserModel.model_validate(user)
except Exception as e:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try: try:
user = User.get(User.api_key == api_key) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
user = User.get(User.email == email) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try: try:
user = User.get(User.oauth_sub == sub) with get_db() as db:
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except: except:
return None return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ with get_db() as db:
UserModel(**model_to_dict(user)) users = (
for user in User.select() db.query(User)
# .limit(limit).offset(skip) # .offset(skip).limit(limit)
] .all()
)
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
return User.select().count() with get_db() as db:
return db.query(User).count()
def get_first_user(self) -> UserModel: def get_first_user(self) -> UserModel:
try: try:
user = User.select().order_by(User.created_at).first() with get_db() as db:
return UserModel(**model_to_dict(user)) user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user)
except: except:
return None return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]: def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try: try:
query = User.update(role=role).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = User.get(User.id == id) user = db.query(User).filter_by(id=id).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
...@@ -167,23 +181,28 @@ class UsersTable: ...@@ -167,23 +181,28 @@ class UsersTable:
self, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
query = User.update(profile_image_url=profile_image_url).where( with get_db() as db:
User.id == id db.query(User).filter_by(id=id).update(
) {"profile_image_url": profile_image_url}
query.execute() )
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except: except:
return None return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]: def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try: try:
query = User.update(last_active_at=int(time.time())).where(User.id == id) with get_db() as db:
query.execute()
user = User.get(User.id == id) db.query(User).filter_by(id=id).update(
return UserModel(**model_to_dict(user)) {"last_active_at": int(time.time())}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except: except:
return None return None
...@@ -191,22 +210,25 @@ class UsersTable: ...@@ -191,22 +210,25 @@ class UsersTable:
self, id: str, oauth_sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
db.commit()
user = User.get(User.id == id) user = db.query(User).filter_by(id=id).first()
return UserModel(**model_to_dict(user)) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
query = User.update(**updated).where(User.id == id) with get_db() as db:
query.execute() db.query(User).filter_by(id=id).update(updated)
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(id=id).first()
except: return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception as e:
return None return None
def delete_user_by_id(self, id: str) -> bool: def delete_user_by_id(self, id: str) -> bool:
...@@ -215,9 +237,10 @@ class UsersTable: ...@@ -215,9 +237,10 @@ class UsersTable:
result = Chats.delete_chats_by_user_id(id) result = Chats.delete_chats_by_user_id(id)
if result: if result:
# Delete User with get_db() as db:
query = User.delete().where(User.id == id) # Delete User
query.execute() # Remove the rows, return number of rows removed. db.query(User).filter_by(id=id).delete()
db.commit()
return True return True
else: else:
...@@ -227,19 +250,20 @@ class UsersTable: ...@@ -227,19 +250,20 @@ class UsersTable:
def update_user_api_key_by_id(self, id: str, api_key: str) -> str: def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try: try:
query = User.update(api_key=api_key).where(User.id == id) with get_db() as db:
result = query.execute() result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False return True if result == 1 else False
except: except:
return False return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try: try:
user = User.get(User.id == id) with get_db() as db:
return user.api_key user = db.query(User).filter_by(id=id).first()
except: return user.api_key
except Exception as e:
return None return None
Users = UsersTable(DB) Users = UsersTable()
...@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user ...@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse]) @router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id( async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50 user_id: str,
user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
): ):
return Chats.get_chat_list_by_user_id( return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit user_id, include_archived=True, skip=skip, limit=limit
...@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)): ...@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=List[ChatResponse]) @router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)): async def get_user_archived_chats(user=Depends(get_verified_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id) for chat in Chats.get_archived_chats_by_user_id(user.id)
...@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name( ...@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user) form_data: TagNameForm, user=Depends(get_verified_user)
): ):
print(form_data)
chat_ids = [ chat_ids = [
chat_id_tag.chat_id chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
......
...@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_ ...@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@router.post("/doc/update", response_model=Optional[DocumentResponse]) @router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name( async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user) name: str,
form_data: DocumentUpdateForm,
user=Depends(get_admin_user),
): ):
doc = Documents.update_doc_by_name(name, form_data) doc = Documents.update_doc_by_name(name, form_data)
if doc: if doc:
......
...@@ -50,10 +50,7 @@ router = APIRouter() ...@@ -50,10 +50,7 @@ router = APIRouter()
@router.post("/") @router.post("/")
def upload_file( def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
try: try:
unsanitized_filename = file.filename unsanitized_filename = file.filename
......
...@@ -233,7 +233,10 @@ async def delete_function_by_id( ...@@ -233,7 +233,10 @@ async def delete_function_by_id(
# delete the function file # delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
os.remove(function_path) try:
os.remove(function_path)
except:
pass
return result return result
......
...@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel): ...@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel):
@router.post("/add", response_model=Optional[MemoryModel]) @router.post("/add", response_model=Optional[MemoryModel])
async def add_memory( async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user) request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
): ):
memory = Memories.insert_new_memory(user.id, form_data.content) memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
......
...@@ -5,6 +5,7 @@ from typing import List, Union, Optional ...@@ -5,6 +5,7 @@ from typing import List, Union, Optional
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
...@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)): ...@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)):
@router.post("/add", response_model=Optional[ModelModel]) @router.post("/add", response_model=Optional[ModelModel])
async def add_new_model( async def add_new_model(
request: Request, form_data: ModelForm, user=Depends(get_admin_user) request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
): ):
if form_data.id in request.app.state.MODELS: if form_data.id in request.app.state.MODELS:
raise HTTPException( raise HTTPException(
...@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): ...@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/update", response_model=Optional[ModelModel]) @router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id( async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user) request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
): ):
model = Models.get_model_by_id(id) model = Models.get_model_by_id(id)
if model: if model:
......
...@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)): ...@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel]) @router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command( async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_admin_user) command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
): ):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt: if prompt:
......
...@@ -6,7 +6,6 @@ from fastapi import APIRouter ...@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.models.users import Users from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.utils import load_toolkit_module_by_id
...@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)): ...@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)):
@router.post("/create", response_model=Optional[ToolResponse]) @router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit( async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user) request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
): ):
if not form_data.id.isidentifier(): if not form_data.id.isidentifier():
raise HTTPException( raise HTTPException(
...@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)): ...@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[ToolModel]) @router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id( async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user) request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
): ):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
......
...@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)): ...@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict]) @router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user( async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user) form_data: dict, user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(user.id) user = Users.get_user_by_id(user.id)
...@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): ...@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.post("/{user_id}/update", response_model=Optional[UserModel]) @router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id( async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user) user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
): ):
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
......
from fastapi import APIRouter, UploadFile, File, Response from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
...@@ -10,7 +9,6 @@ import markdown ...@@ -10,7 +9,6 @@ import markdown
import black import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url from utils.misc import calculate_sha256, get_gravatar_url
...@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)): ...@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
if not isinstance(DB, SqliteDatabase): from apps.webui.internal.db import engine
if engine.name != "sqlite":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE, detail=ERROR_MESSAGES.DB_NOT_SQLITE,
) )
return FileResponse( return FileResponse(
DB.database, engine.url.database,
media_type="application/octet-stream", media_type="application/octet-stream",
filename="webui.db", filename="webui.db",
) )
......
...@@ -5,9 +5,8 @@ import importlib.metadata ...@@ -5,9 +5,8 @@ import importlib.metadata
import pkgutil import pkgutil
import chromadb import chromadb
from chromadb import Settings from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union from typing import TypeVar, Generic
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
...@@ -19,7 +18,6 @@ import markdown ...@@ -19,7 +18,6 @@ import markdown
import requests import requests
import shutil import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
#################################### ####################################
...@@ -395,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig( ...@@ -395,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"), os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
) )
OAUTH_USERNAME_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.username_claim",
os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
def load_oauth_providers(): def load_oauth_providers():
OAUTH_PROVIDERS.clear() OAUTH_PROVIDERS.clear()
...@@ -440,16 +450,27 @@ load_oauth_providers() ...@@ -440,16 +450,27 @@ load_oauth_providers()
STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve() STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png" frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
if frontend_favicon.exists(): if frontend_favicon.exists():
try: try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png") shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e: except Exception as e:
logging.error(f"An error occurred: {e}") logging.error(f"An error occurred: {e}")
else: else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}") logging.warning(f"Frontend favicon not found at {frontend_favicon}")
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
if frontend_splash.exists():
try:
shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend splash not found at {frontend_splash}")
#################################### ####################################
# CUSTOM_NAME # CUSTOM_NAME
#################################### ####################################
...@@ -474,6 +495,19 @@ if CUSTOM_NAME: ...@@ -474,6 +495,19 @@ if CUSTOM_NAME:
r.raw.decode_content = True r.raw.decode_content = True
shutil.copyfileobj(r.raw, f) shutil.copyfileobj(r.raw, f)
if "splash" in data:
url = (
f"https://api.openwebui.com{data['splash']}"
if data["splash"][0] == "/"
else data["splash"]
)
r = requests.get(url, stream=True)
if r.status_code == 200:
with open(f"{STATIC_DIR}/splash.png", "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
WEBUI_NAME = data["name"] WEBUI_NAME = data["name"]
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
...@@ -769,11 +803,14 @@ class BannerModel(BaseModel): ...@@ -769,11 +803,14 @@ class BannerModel(BaseModel):
timestamp: int timestamp: int
WEBUI_BANNERS = PersistentConfig( try:
"WEBUI_BANNERS", banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
"ui.banners", banners = [BannerModel(**banner) for banner in banners]
[BannerModel(**banner) for banner in json.loads("[]")], except Exception as e:
) print(f"Error loading WEBUI_BANNERS: {e}")
banners = []
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
SHOW_ADMIN_DETAILS = PersistentConfig( SHOW_ADMIN_DETAILS = PersistentConfig(
...@@ -885,6 +922,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get( ...@@ -885,6 +922,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
if WEBUI_AUTH and WEBUI_SECRET_KEY == "": if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document content extraction
####################################
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
"rag.CONTENT_EXTRACTION_ENGINE",
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
TIKA_SERVER_URL = PersistentConfig(
"TIKA_SERVER_URL",
"rag.tika_server_url",
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
#################################### ####################################
# RAG # RAG
#################################### ####################################
...@@ -1302,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig( ...@@ -1302,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig(
#################################### ####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
...@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum): ...@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED = ( OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature." "The Ollama API is disabled. Please enable it to use this feature."
) )
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"
...@@ -4,9 +4,7 @@ from contextlib import asynccontextmanager ...@@ -4,9 +4,7 @@ from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json import json
import markdown
import time import time
import os import os
import sys import sys
...@@ -18,25 +16,22 @@ import shutil ...@@ -18,25 +16,22 @@ import shutil
import os import os
import uuid import uuid
import inspect import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app from apps.socket.main import sio, app as socket_app
from apps.ollama.main import ( from apps.ollama.main import (
app as ollama_app, app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models, get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion, generate_openai_chat_completion as generate_ollama_chat_completion,
) )
...@@ -54,13 +49,14 @@ from apps.webui.main import ( ...@@ -54,13 +49,14 @@ from apps.webui.main import (
get_pipe_models, get_pipe_models,
generate_function_chat_completion, generate_function_chat_completion,
) )
from apps.webui.internal.db import Session
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional, Iterator, Generator, Union from typing import List, Optional
from apps.webui.models.auths import Auths from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel from apps.webui.models.models import Models
from apps.webui.models.tools import Tools from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions from apps.webui.models.functions import Functions
from apps.webui.models.users import Users from apps.webui.models.users import Users
...@@ -83,14 +79,12 @@ from utils.task import ( ...@@ -83,14 +79,12 @@ from utils.task import (
from utils.misc import ( from utils.misc import (
get_last_user_message, get_last_user_message,
add_or_update_system_message, add_or_update_system_message,
stream_message_template,
parse_duration, parse_duration,
) )
from apps.rag.utils import get_rag_context, rag_template from apps.rag.utils import get_rag_context, rag_template
from config import ( from config import (
CONFIG_DATA,
WEBUI_NAME, WEBUI_NAME,
WEBUI_URL, WEBUI_URL,
WEBUI_AUTH, WEBUI_AUTH,
...@@ -98,7 +92,6 @@ from config import ( ...@@ -98,7 +92,6 @@ from config import (
VERSION, VERSION,
CHANGELOG, CHANGELOG,
FRONTEND_BUILD_DIR, FRONTEND_BUILD_DIR,
UPLOAD_DIR,
CACHE_DIR, CACHE_DIR,
STATIC_DIR, STATIC_DIR,
DEFAULT_LOCALE, DEFAULT_LOCALE,
...@@ -126,7 +119,8 @@ from config import ( ...@@ -126,7 +119,8 @@ from config import (
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
AppConfig, AppConfig,
) )
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook from utils.webhook import post_webhook
if SAFE_MODE: if SAFE_MODE:
...@@ -167,8 +161,20 @@ https://github.com/open-webui/open-webui ...@@ -167,8 +161,20 @@ https://github.com/open-webui/open-webui
) )
def run_migrations():
try:
from alembic.config import Config
from alembic import command
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
run_migrations()
yield yield
...@@ -212,8 +218,79 @@ origins = ["*"] ...@@ -212,8 +218,79 @@ origins = ["*"]
################################## ##################################
async def get_body_and_model_and_user(request):
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
body = json.loads(body_str) if body_str else {}
model_id = body["model"]
if model_id not in app.state.MODELS:
raise Exception("Model not found")
model = app.state.MODELS[model_id]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
return body, model, user
def get_task_model_id(default_model_id):
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
return task_model_id
def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
return filter_ids
async def get_function_call_response( async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user messages,
files,
tool_id,
template,
task_model_id,
user,
__event_emitter__=None,
__event_call__=None,
): ):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
...@@ -240,6 +317,7 @@ async def get_function_call_response( ...@@ -240,6 +317,7 @@ async def get_function_call_response(
{"role": "user", "content": f"Query: {prompt}"}, {"role": "user", "content": f"Query: {prompt}"},
], ],
"stream": False, "stream": False,
"task": TASKS.FUNCTION_CALLING,
} }
try: try:
...@@ -252,7 +330,6 @@ async def get_function_call_response( ...@@ -252,7 +330,6 @@ async def get_function_call_response(
response = None response = None
try: try:
response = await generate_chat_completions(form_data=payload, user=user) response = await generate_chat_completions(form_data=payload, user=user)
content = None content = None
if hasattr(response, "body_iterator"): if hasattr(response, "body_iterator"):
...@@ -266,334 +343,367 @@ async def get_function_call_response( ...@@ -266,334 +343,367 @@ async def get_function_call_response(
else: else:
content = response["choices"][0]["message"]["content"] content = response["choices"][0]["message"]["content"]
if content is None:
return None, None, False
# Parse the function response # Parse the function response
if content is not None: print(f"content: {content}")
print(f"content: {content}") result = json.loads(content)
result = json.loads(content) print(result)
print(result)
citation = None
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(
**(valves if valves else {})
)
function = getattr(toolkit_module, result["name"]) citation = None
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
params = result["parameters"]
if "__user__" in sig.parameters: if "name" not in result:
# Call the function with the '__user__' parameter included return None, None, False
__user__ = {
"id": user.id, # Call the function
"email": user.email, if tool_id in webui_app.state.TOOLS:
"name": user.name, toolkit_module = webui_app.state.TOOLS[tool_id]
"role": user.role, else:
} toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
try:
if hasattr(toolkit_module, "UserValves"): file_handler = False
__user__["valves"] = toolkit_module.UserValves( # check if toolkit_module has file_handler self variable
**Tools.get_user_valves_by_id_and_user_id( if hasattr(toolkit_module, "file_handler"):
tool_id, user.id file_handler = True
) print("file_handler: ", file_handler)
)
except Exception as e: if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
print(e) valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
params = {**params, "__user__": __user__}
if "__messages__" in sig.parameters: function = getattr(toolkit_module, result["name"])
# Call the function with the '__messages__' parameter included function_result = None
params = { try:
**params, # Get the signature of the function
"__messages__": messages, sig = inspect.signature(function)
} params = result["parameters"]
if "__files__" in sig.parameters: # Extra parameters to be passed to the function
# Call the function with the '__files__' parameter included extra_params = {
params = { "__model__": model,
**params, "__id__": tool_id,
"__files__": files, "__messages__": messages,
} "__files__": files,
"__event_emitter__": __event_emitter__,
if "__model__" in sig.parameters: "__event_call__": __event_call__,
# Call the function with the '__model__' parameter included }
params = {
**params, # Add extra params in contained in function signature
"__model__": model, for key, value in extra_params.items():
} if key in sig.parameters:
params[key] = value
if "__id__" in sig.parameters:
# Call the function with the '__id__' parameter included if "__user__" in sig.parameters:
params = { # Call the function with the '__user__' parameter included
**params, __user__ = {
"__id__": tool_id, "id": user.id,
} "email": user.email,
"name": user.name,
if inspect.iscoroutinefunction(function): "role": user.role,
function_result = await function(**params) }
else:
function_result = function(**params) try:
if hasattr(toolkit_module, "UserValves"):
if hasattr(toolkit_module, "citation") and toolkit_module.citation: __user__["valves"] = toolkit_module.UserValves(
citation = { **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
"source": {"name": f"TOOL:{tool.name}/{result['name']}"}, )
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e: except Exception as e:
print(e) print(e)
# Add the function result to the system prompt params = {**params, "__user__": __user__}
if function_result is not None:
return function_result, citation, file_handler if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
return None, None, False return None, None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware): async def chat_completion_functions_handler(
async def dispatch(self, request: Request, call_next): body, model, user, __event_emitter__, __event_call__
data_items = [] ):
skip_files = None
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "inlet"):
continue
try:
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": body}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
raise e
if skip_files:
if "files" in body:
del body["files"]
return body, {}
async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
skip_files = None
contexts = []
citations = None
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions
if "tool_ids" in body:
print(body["tool_ids"])
for tool_id in body["tool_ids"]:
print(tool_id)
try:
response, citation, file_handler = await get_function_call_response(
messages=body["messages"],
files=body.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
__event_emitter__=__event_emitter__,
__event_call__=__event_call__,
)
show_citations = False print(file_handler)
citations = [] if isinstance(response, str):
contexts.append(response)
if citation:
if citations is None:
citations = [citation]
else:
citations.append(citation)
if file_handler:
skip_files = True
except Exception as e:
print(f"Error: {e}")
del body["tool_ids"]
print(f"tool_contexts: {contexts}")
if skip_files:
if "files" in body:
del body["files"]
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
async def chat_completion_files_handler(body):
contexts = []
citations = None
if "files" in body:
files = body["files"]
del body["files"]
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and any( if request.method == "POST" and any(
endpoint in request.url.path endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"] for endpoint in ["/ollama/api/chat", "/chat/completions"]
): ):
log.debug(f"request.url.path: {request.url.path}") log.debug(f"request.url.path: {request.url.path}")
# Read the original request body try:
body = await request.body() body, model, user = await get_body_and_model_and_user(request)
body_str = body.decode("utf-8") except Exception as e:
data = json.loads(body_str) if body_str else {} return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
user = get_current_user( content={"detail": str(e)},
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
if data.get("citations"):
show_citations = True
del data["citations"]
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
) )
model = app.state.MODELS[model_id]
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get(
"priority", 0
)
return 0
filter_ids = [ # Extract session_id, chat_id and message_id from the request body
function.id for function in Functions.get_global_filter_functions() session_id = None
] if "session_id" in body:
if "info" in model and "meta" in model["info"]: session_id = body["session_id"]
filter_ids.extend(model["info"]["meta"].get("filterIds", [])) del body["session_id"]
filter_ids = list(set(filter_ids)) chat_id = None
if "chat_id" in body:
enabled_filter_ids = [ chat_id = body["chat_id"]
function.id del body["chat_id"]
for function in Functions.get_functions_by_type( message_id = None
"filter", active_only=True if "id" in body:
message_id = body["id"]
del body["id"]
async def __event_emitter__(data):
await sio.emit(
"chat-events",
{
"chat_id": chat_id,
"message_id": message_id,
"data": data,
},
to=session_id,
) )
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority) async def __event_call__(data):
for filter_id in filter_ids: response = await sio.call(
filter = Functions.get_function_by_id(filter_id) "chat-events",
if filter: {"chat_id": chat_id, "message_id": message_id, "data": data},
if filter_id in webui_app.state.FUNCTIONS: to=session_id,
function_module = webui_app.state.FUNCTIONS[filter_id] )
else: return response
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try: # Initialize data_items to store additional data to be sent to the client
if hasattr(function_module, "inlet"): data_items = []
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(inlet):
data = await inlet(**params)
else:
data = inlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Set the task model # Initialize context, and citations
task_model_id = data["model"] contexts = []
# Check if the user has a custom task model and use that model citations = []
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
try:
response, citation, file_handler = (
await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
)
print(file_handler) try:
if isinstance(response, str): body, flags = await chat_completion_functions_handler(
context += ("\n" if context != "" else "") + response body, model, user, __event_emitter__, __event_call__
)
if citation: except Exception as e:
citations.append(citation) return JSONResponse(
show_citations = True status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
if file_handler: )
skip_files = True
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
# If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if "files" in data:
if not skip_files:
data = {**data}
rag_context, rag_citations = get_rag_context(
files=data["files"],
messages=data["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}") try:
body, flags = await chat_completion_tools_handler(
body, user, __event_emitter__, __event_call__
)
if rag_citations: contexts.extend(flags.get("contexts", []))
citations.extend(rag_citations) citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
del data["files"] try:
body, flags = await chat_completion_files_handler(body)
if show_citations and len(citations) > 0: contexts.extend(flags.get("contexts", []))
data_items.append({"citations": citations}) citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
if context != "": # If context is not empty, insert it into the messages
system_prompt = rag_template( if len(contexts) > 0:
rag_app.state.config.RAG_TEMPLATE, context, prompt context_string = "/n".join(contexts).strip()
) prompt = get_last_user_message(body["messages"])
print(system_prompt) body["messages"] = add_or_update_system_message(
data["messages"] = add_or_update_system_message( rag_template(
system_prompt, data["messages"] rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
) )
modified_body_bytes = json.dumps(data).encode("utf-8") # If there are citations, add them to the data_items
if len(citations) > 0:
data_items.append({"citations": citations})
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one
request._body = modified_body_bytes request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length # Set custom header to ensure content-length matches new body length
...@@ -654,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware) ...@@ -654,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware)
################################## ##################################
def filter_pipeline(payload, user): def get_sorted_filters(model_id):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
filters = [ filters = [
model model
for model in app.state.MODELS.values() for model in app.state.MODELS.values()
...@@ -672,6 +780,13 @@ def filter_pipeline(payload, user): ...@@ -672,6 +780,13 @@ def filter_pipeline(payload, user):
) )
] ]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
return sorted_filters
def filter_pipeline(payload, user):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id)
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
...@@ -704,25 +819,12 @@ def filter_pipeline(payload, user): ...@@ -704,25 +819,12 @@ def filter_pipeline(payload, user):
print(f"Connection error: {e}") print(f"Connection error: {e}")
if r is not None: if r is not None:
try: res = r.json()
res = r.json()
except:
pass
if "detail" in res: if "detail" in res:
raise Exception(r.status_code, res["detail"]) raise Exception(r.status_code, res["detail"])
else: if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
pass del payload["task"]
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]
if "task" in payload:
del payload["task"]
return payload return payload
...@@ -787,6 +889,14 @@ app.add_middleware( ...@@ -787,6 +889,14 @@ app.add_middleware(
) )
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
response = await call_next(request)
log.debug("Commit session after request")
Session.commit()
return response
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0: if len(app.state.MODELS) == 0:
...@@ -863,12 +973,16 @@ async def get_all_models(): ...@@ -863,12 +973,16 @@ async def get_all_models():
model["info"] = custom_model.model_dump() model["info"] = custom_model.model_dump()
else: else:
owned_by = "openai" owned_by = "openai"
pipe = None
for model in models: for model in models:
if ( if (
custom_model.base_model_id == model["id"] custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0] or custom_model.base_model_id == model["id"].split(":")[0]
): ):
owned_by = model["owned_by"] owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
break break
models.append( models.append(
...@@ -880,11 +994,11 @@ async def get_all_models(): ...@@ -880,11 +994,11 @@ async def get_all_models():
"owned_by": owned_by, "owned_by": owned_by,
"info": custom_model.model_dump(), "info": custom_model.model_dump(),
"preset": True, "preset": True,
**({"pipe": pipe} if pipe is not None else {}),
} }
) )
app.state.MODELS = {model["id"]: model for model in models} app.state.MODELS = {model["id"]: model for model in models}
webui_app.state.MODELS = app.state.MODELS webui_app.state.MODELS = app.state.MODELS
return models return models
...@@ -945,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -945,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
filters = [ sorted_filters = get_sorted_filters(model_id)
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
if "pipeline" in model: if "pipeline" in model:
sorted_filters = [model] + sorted_filters sorted_filters = [model] + sorted_filters
...@@ -1008,6 +1107,25 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1008,6 +1107,25 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
else: else:
pass pass
async def __event_emitter__(event_data):
await sio.emit(
"chat-events",
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"data": event_data,
},
to=data["session_id"],
)
async def __event_call__(event_data):
response = await sio.call(
"chat-events",
{"chat_id": data["chat_id"], "message_id": data["id"], "data": event_data},
to=data["session_id"],
)
return response
def get_priority(function_id): def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"): if function is not None and hasattr(function, "valves"):
...@@ -1032,68 +1150,74 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1032,68 +1150,74 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
for filter_id in filter_ids: for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id) filter = Functions.get_function_by_id(filter_id)
if filter: if not filter:
if filter_id in webui_app.state.FUNCTIONS: continue
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try: if filter_id in webui_app.state.FUNCTIONS:
if hasattr(function_module, "outlet"): function_module = webui_app.state.FUNCTIONS[filter_id]
outlet = function_module.outlet else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Get the signature of the function if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
sig = inspect.signature(outlet) valves = Functions.get_function_valves_by_id(filter_id)
params = {"body": data} function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if "__user__" in sig.parameters: if not hasattr(function_module, "outlet"):
__user__ = { continue
"id": user.id, try:
"email": user.email, outlet = function_module.outlet
"name": user.name,
"role": user.role, # Get the signature of the function
} sig = inspect.signature(outlet)
params = {"body": data}
try:
if hasattr(function_module, "UserValves"): # Extra parameters to be passed to the function
__user__["valves"] = function_module.UserValves( extra_params = {
**Functions.get_user_valves_by_id_and_user_id( "__model__": model,
filter_id, user.id "__id__": filter_id,
) "__event_emitter__": __event_emitter__,
) "__event_call__": __event_call__,
except Exception as e: }
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e: # Add extra params in contained in function signature
print(f"Error: {e}") for key, value in extra_params.items():
return JSONResponse( if key in sig.parameters:
status_code=status.HTTP_400_BAD_REQUEST, params[key] = value
content={"detail": str(e)},
) if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data return data
...@@ -1169,19 +1293,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -1169,19 +1293,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model # Check if the user has a custom task model
# If the user has a custom task model, use that model # If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama": model_id = get_task_model_id(model_id)
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
...@@ -1200,7 +1314,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -1200,7 +1314,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream": False, "stream": False,
"max_tokens": 50, "max_tokens": 50,
"chat_id": form_data.get("chat_id", None), "chat_id": form_data.get("chat_id", None),
"title": True, "task": TASKS.TITLE_GENERATION,
} }
log.debug(payload) log.debug(payload)
...@@ -1213,6 +1327,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -1213,6 +1327,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completions(form_data=payload, user=user)
...@@ -1235,19 +1352,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -1235,19 +1352,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
# Check if the user has a custom task model # Check if the user has a custom task model
# If the user has a custom task model, use that model # If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama": model_id = get_task_model_id(model_id)
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
...@@ -1260,7 +1367,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -1260,7 +1367,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}], "messages": [{"role": "user", "content": content}],
"stream": False, "stream": False,
"max_tokens": 30, "max_tokens": 30,
"task": True, "task": TASKS.QUERY_GENERATION,
} }
print(payload) print(payload)
...@@ -1273,6 +1380,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -1273,6 +1380,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completions(form_data=payload, user=user)
...@@ -1289,19 +1399,9 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): ...@@ -1289,19 +1399,9 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model # Check if the user has a custom task model
# If the user has a custom task model, use that model # If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama": model_id = get_task_model_id(model_id)
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = ''' template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
...@@ -1324,7 +1424,7 @@ Message: """{{prompt}}""" ...@@ -1324,7 +1424,7 @@ Message: """{{prompt}}"""
"stream": False, "stream": False,
"max_tokens": 4, "max_tokens": 4,
"chat_id": form_data.get("chat_id", None), "chat_id": form_data.get("chat_id", None),
"task": True, "task": TASKS.EMOJI_GENERATION,
} }
log.debug(payload) log.debug(payload)
...@@ -1337,6 +1437,9 @@ Message: """{{prompt}}""" ...@@ -1337,6 +1437,9 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completions(form_data=payload, user=user)
...@@ -1353,22 +1456,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ...@@ -1353,22 +1456,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
# Check if the user has a custom task model # Check if the user has a custom task model
# If the user has a custom task model, use that model # If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama": model_id = get_task_model_id(model_id)
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id) print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try: try:
context, citation, file_handler = await get_function_call_response( context, _, _ = await get_function_call_response(
form_data["messages"], form_data["messages"],
form_data.get("files", []), form_data.get("files", []),
form_data["tool_id"], form_data["tool_id"],
...@@ -1432,6 +1526,7 @@ async def upload_pipeline( ...@@ -1432,6 +1526,7 @@ async def upload_pipeline(
os.makedirs(upload_folder, exist_ok=True) os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename) file_path = os.path.join(upload_folder, file.filename)
r = None
try: try:
# Save the uploaded file # Save the uploaded file
with open(file_path, "wb") as buffer: with open(file_path, "wb") as buffer:
...@@ -1455,7 +1550,9 @@ async def upload_pipeline( ...@@ -1455,7 +1550,9 @@ async def upload_pipeline(
print(f"Connection error: {e}") print(f"Connection error: {e}")
detail = "Pipeline not found" detail = "Pipeline not found"
status_code = status.HTTP_404_NOT_FOUND
if r is not None: if r is not None:
status_code = r.status_code
try: try:
res = r.json() res = r.json()
if "detail" in res: if "detail" in res:
...@@ -1464,7 +1561,7 @@ async def upload_pipeline( ...@@ -1464,7 +1561,7 @@ async def upload_pipeline(
pass pass
raise HTTPException( raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), status_code=status_code,
detail=detail, detail=detail,
) )
finally: finally:
...@@ -1563,8 +1660,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_ ...@@ -1563,8 +1660,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
r = None r = None
try: try:
urlIdx
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
...@@ -1596,7 +1691,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use ...@@ -1596,7 +1691,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves") @app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves( async def get_pipeline_valves(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user) urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
): ):
models = await get_all_models() models = await get_all_models()
r = None r = None
...@@ -1634,7 +1731,9 @@ async def get_pipeline_valves( ...@@ -1634,7 +1731,9 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec") @app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec( async def get_pipeline_valves_spec(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user) urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
): ):
models = await get_all_models() models = await get_all_models()
...@@ -1920,7 +2019,8 @@ async def oauth_callback(provider: str, request: Request, response: Response): ...@@ -1920,7 +2019,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
if existing_user: if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_url = user_data.get("picture", "") picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url: if picture_url:
# Download the profile image into a base64 string # Download the profile image into a base64 string
try: try:
...@@ -1940,6 +2040,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): ...@@ -1940,6 +2040,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
picture_url = "" picture_url = ""
if not picture_url: if not picture_url:
picture_url = "/user.png" picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = ( role = (
"admin" "admin"
if Users.get_num_users() == 0 if Users.get_num_users() == 0
...@@ -1950,7 +2051,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): ...@@ -1950,7 +2051,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
password=get_password_hash( password=get_password_hash(
str(uuid.uuid4()) str(uuid.uuid4())
), # Random password, not used ), # Random password, not used
name=user_data.get("name", "User"), name=user_data.get(username_claim, "User"),
profile_image_url=picture_url, profile_image_url=picture_url,
role=role, role=role,
oauth_sub=provider_sub, oauth_sub=provider_sub,
...@@ -2008,7 +2109,7 @@ async def get_opensearch_xml(): ...@@ -2008,7 +2109,7 @@ async def get_opensearch_xml():
<ShortName>{WEBUI_NAME}</ShortName> <ShortName>{WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description> <Description>Search {WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding> <InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image> <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/> <Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm> <moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
</OpenSearchDescription> </OpenSearchDescription>
...@@ -2021,6 +2122,12 @@ async def healthcheck(): ...@@ -2021,6 +2122,12 @@ async def healthcheck():
return {"status": True} return {"status": True}
@app.get("/health/db")
async def healthcheck_with_db():
Session.execute(text("SELECT 1;")).all()
return {"status": True}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
......
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from apps.webui.models.auths import Auth
from apps.webui.models.chats import Chat
from apps.webui.models.documents import Document
from apps.webui.models.memories import Memory
from apps.webui.models.models import Model
from apps.webui.models.prompts import Prompt
from apps.webui.models.tags import Tag, ChatIdTag
from apps.webui.models.tools import Tool
from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
from config import DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Auth.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
DB_URL = DATABASE_URL
if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment