"src/nni_manager/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "bf8be1e705cdc0c153d6f38d1a83bd3fe55d7a1e"
Unverified Commit 5166e92f authored by arkohut's avatar arkohut Committed by GitHub
Browse files

Merge branch 'dev' into support-py-for-run-code

parents b443d61c b6b71c08
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Model(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
base_model_id = pw.TextField(null=True)
name = pw.TextField()
meta = pw.TextField()
params = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "model"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("model")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
import json
from utils.misc import parse_ollama_modelfile
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Fetch data from 'modelfile' table and insert into 'model' table
migrate_modelfile_to_model(migrator, database)
# Drop the 'modelfile' table
migrator.remove_model("modelfile")
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
ModelFile = migrator.orm["modelfile"]
Model = migrator.orm["model"]
modelfiles = ModelFile.select()
for modelfile in modelfiles:
# Extract and transform data in Python
modelfile.modelfile = json.loads(modelfile.modelfile)
meta = json.dumps(
{
"description": modelfile.modelfile.get("desc"),
"profile_image_url": modelfile.modelfile.get("imageUrl"),
"ollama": {"modelfile": modelfile.modelfile.get("content")},
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
"categories": modelfile.modelfile.get("categories"),
"user": {**modelfile.modelfile.get("user", {}), "community": True},
}
)
info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
# Insert the processed data into the 'model' table
Model.create(
id=f"ollama-{modelfile.tag_name}",
user_id=modelfile.user_id,
base_model_id=info.get("base_model_id"),
name=modelfile.modelfile.get("title"),
meta=meta,
params=json.dumps(info.get("params", {})),
created_at=modelfile.timestamp,
updated_at=modelfile.timestamp,
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
recreate_modelfile_table(migrator, database)
move_data_back_to_modelfile(migrator, database)
migrator.remove_model("model")
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
query = """
CREATE TABLE IF NOT EXISTS modelfile (
user_id TEXT,
tag_name TEXT,
modelfile JSON,
timestamp BIGINT
)
"""
migrator.sql(query)
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
Model = migrator.orm["model"]
Modelfile = migrator.orm["modelfile"]
models = Model.select()
for model in models:
# Extract and transform data in Python
meta = json.loads(model.meta)
modelfile_data = {
"title": model.name,
"desc": meta.get("description"),
"imageUrl": meta.get("profile_image_url"),
"content": meta.get("ollama", {}).get("modelfile"),
"suggestionPrompts": meta.get("suggestion_prompts"),
"categories": meta.get("categories"),
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
}
# Insert the processed data back into the 'modelfile' table
Modelfile.create(
user_id=model.user_id,
tag_name=model.id,
modelfile=modelfile_data,
timestamp=model.created_at,
)
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields settings to the 'user' table
migrator.add_fields("user", settings=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "settings")
...@@ -14,7 +14,7 @@ You will need to create a migration file to ensure that existing databases are u ...@@ -14,7 +14,7 @@ You will need to create a migration file to ensure that existing databases are u
2. Make your changes to the models. 2. Make your changes to the models.
3. From the `backend` directory, run the following command: 3. From the `backend` directory, run the following command:
```bash ```bash
pw_migrate create --auto --auto-source apps.web.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME} pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
``` ```
- `$SQLITE_DB` should be the path to the database file. - `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration. - `$MIGRATION_NAME` should be a descriptive name for the migration.
......
from fastapi import FastAPI, Depends from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import ( from apps.webui.routers import (
auths, auths,
users, users,
chats, chats,
documents, documents,
modelfiles, models,
prompts, prompts,
configs, configs,
memories, memories,
utils, utils,
) )
from config import ( from config import (
WEBUI_VERSION, WEBUI_BUILD_HASH,
WEBUI_AUTH, WEBUI_AUTH,
DEFAULT_MODELS, DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS, DEFAULT_PROMPT_SUGGESTIONS,
...@@ -23,7 +23,9 @@ from config import ( ...@@ -23,7 +23,9 @@ from config import (
WEBHOOK_URL, WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig, AppConfig,
ENABLE_COMMUNITY_SHARING,
) )
app = FastAPI() app = FastAPI()
...@@ -40,6 +42,11 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS ...@@ -40,6 +42,11 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
...@@ -56,11 +63,10 @@ app.include_router(users.router, prefix="/users", tags=["users"]) ...@@ -56,11 +63,10 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"]) app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
......
...@@ -5,10 +5,10 @@ import uuid ...@@ -5,10 +5,10 @@ import uuid
import logging import logging
from peewee import * from peewee import *
from apps.web.models.users import UserModel, Users from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password from utils.utils import verify_password
from apps.web.internal.db import DB from apps.webui.internal.db import DB
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
......
...@@ -7,7 +7,7 @@ import json ...@@ -7,7 +7,7 @@ import json
import uuid import uuid
import time import time
from apps.web.internal.db import DB from apps.webui.internal.db import DB
#################### ####################
# Chat DB Schema # Chat DB Schema
...@@ -191,6 +191,20 @@ class ChatTable: ...@@ -191,6 +191,20 @@ class ChatTable:
except: except:
return None return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
chats = self.get_chats_by_user_id(user_id)
for chat in chats:
query = Chat.update(
archived=True,
).where(Chat.id == chat.id)
query.execute()
return True
except:
return False
def get_archived_chat_list_by_user_id( def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
...@@ -205,17 +219,31 @@ class ChatTable: ...@@ -205,17 +219,31 @@ class ChatTable:
] ]
def get_chat_list_by_user_id( def get_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self,
user_id: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 50,
) -> List[ChatModel]: ) -> List[ChatModel]:
return [ if include_archived:
ChatModel(**model_to_dict(chat)) return [
for chat in Chat.select() ChatModel(**model_to_dict(chat))
.where(Chat.archived == False) for chat in Chat.select()
.where(Chat.user_id == user_id) .where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc()) .order_by(Chat.updated_at.desc())
# .limit(limit) # .limit(limit)
# .offset(skip) # .offset(skip)
] ]
else:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
def get_chat_list_by_chat_ids( def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50 self, chat_ids: List[str], skip: int = 0, limit: int = 50
......
...@@ -8,7 +8,7 @@ import logging ...@@ -8,7 +8,7 @@ import logging
from utils.utils import decode_token from utils.utils import decode_token
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.webui.internal.db import DB
import json import json
......
...@@ -3,8 +3,8 @@ from peewee import * ...@@ -3,8 +3,8 @@ from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
from apps.web.internal.db import DB from apps.webui.internal.db import DB
from apps.web.models.chats import Chats from apps.webui.models.chats import Chats
import time import time
import uuid import uuid
......
import json
import logging
from typing import Optional
import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from apps.webui.internal.db import DB, JSONField
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS
import time
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Models DB Schema
####################
# ModelParams is a model for the data stored in the params field of the Model table
class ModelParams(BaseModel):
model_config = ConfigDict(extra="allow")
pass
# ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/favicon.png"
description: Optional[str] = None
"""
User-facing description of the model.
"""
capabilities: Optional[dict] = None
model_config = ConfigDict(extra="allow")
pass
class Model(pw.Model):
id = pw.TextField(unique=True)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id = pw.TextField()
base_model_id = pw.TextField(null=True)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name = pw.TextField()
"""
The human-readable display name of the model.
"""
params = JSONField()
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta = JSONField()
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ModelModel(BaseModel):
id: str
user_id: str
base_model_id: Optional[str] = None
name: str
params: ModelParams
meta: ModelMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ModelResponse(BaseModel):
id: str
name: str
meta: ModelMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ModelForm(BaseModel):
id: str
base_model_id: Optional[str] = None
name: str
meta: ModelMeta
params: ModelParams
class ModelsTable:
def __init__(
self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase,
):
self.db = db
self.db.create_tables([Model])
def insert_new_model(
self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]:
model = ModelModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
try:
result = Model.create(**model.model_dump())
if result:
return model
else:
return None
except Exception as e:
print(e)
return None
def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
except:
return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
# update only the fields that are present in the model
query = Model.update(**model.model_dump()).where(Model.id == id)
query.execute()
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
except Exception as e:
print(e)
return None
def delete_model_by_id(self, id: str) -> bool:
try:
query = Model.delete().where(Model.id == id)
query.execute()
return True
except:
return False
Models = ModelsTable(DB)
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
from utils.utils import decode_token from utils.utils import decode_token
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.webui.internal.db import DB
import json import json
......
...@@ -8,7 +8,7 @@ import uuid ...@@ -8,7 +8,7 @@ import uuid
import time import time
import logging import logging
from apps.web.internal.db import DB from apps.webui.internal.db import DB
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
......
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from peewee import * from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.webui.internal.db import DB, JSONField
from apps.web.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
# User DB Schema # User DB Schema
...@@ -25,11 +25,18 @@ class User(Model): ...@@ -25,11 +25,18 @@ class User(Model):
created_at = BigIntegerField() created_at = BigIntegerField()
api_key = CharField(null=True, unique=True) api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
class Meta: class Meta:
database = DB database = DB
class UserSettings(BaseModel):
ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
pass
class UserModel(BaseModel): class UserModel(BaseModel):
id: str id: str
name: str name: str
...@@ -42,6 +49,7 @@ class UserModel(BaseModel): ...@@ -42,6 +49,7 @@ class UserModel(BaseModel):
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
api_key: Optional[str] = None api_key: Optional[str] = None
settings: Optional[UserSettings] = None
#################### ####################
......
...@@ -10,7 +10,7 @@ import uuid ...@@ -10,7 +10,7 @@ import uuid
import csv import csv
from apps.web.models.auths import ( from apps.webui.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
AddUserForm, AddUserForm,
...@@ -21,7 +21,7 @@ from apps.web.models.auths import ( ...@@ -21,7 +21,7 @@ from apps.web.models.auths import (
Auths, Auths,
ApiKey, ApiKey,
) )
from apps.web.models.users import Users from apps.webui.models.users import Users
from utils.utils import ( from utils.utils import (
get_password_hash, get_password_hash,
......
...@@ -7,8 +7,8 @@ from pydantic import BaseModel ...@@ -7,8 +7,8 @@ from pydantic import BaseModel
import json import json
import logging import logging
from apps.web.models.users import Users from apps.webui.models.users import Users
from apps.web.models.chats import ( from apps.webui.models.chats import (
ChatModel, ChatModel,
ChatResponse, ChatResponse,
ChatTitleForm, ChatTitleForm,
...@@ -18,7 +18,7 @@ from apps.web.models.chats import ( ...@@ -18,7 +18,7 @@ from apps.web.models.chats import (
) )
from apps.web.models.tags import ( from apps.webui.models.tags import (
TagModel, TagModel,
ChatIdTagModel, ChatIdTagModel,
ChatIdTagForm, ChatIdTagForm,
...@@ -78,43 +78,25 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user) ...@@ -78,43 +78,25 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
async def get_user_chat_list_by_user_id( async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50 user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_chat_list_by_user_id(user_id, skip, limit) return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
)
############################
# GetArchivedChats
############################
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################ ############################
# GetSharedChatById # CreateNewChat
############################ ############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
if user.role == "pending": try:
raise HTTPException( chat = Chats.insert_new_chat(user.id, form_data)
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: except Exception as e:
log.exception(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
) )
...@@ -150,19 +132,49 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): ...@@ -150,19 +132,49 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
############################ ############################
# CreateNewChat # GetArchivedChats
############################ ############################
@router.post("/new", response_model=Optional[ChatResponse]) @router.get("/archived", response_model=List[ChatTitleIdResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): async def get_archived_session_user_chat_list(
try: user=Depends(get_current_user), skip: int = 0, limit: int = 50
chat = Chats.insert_new_chat(user.id, form_data) ):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################
# ArchiveAllChats
############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e: else:
log.exception(e)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment