Commit 17e4be49 authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files

feat: migrate modelfiles to models

parent 3be0fa63
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
import json
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Fetch data from 'modelfile' table and insert into 'model' table
migrate_modelfile_to_model(migrator, database)
# Drop the 'modelfile' table
migrator.remove_model("modelfile")
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
ModelFile = migrator.orm["modelfile"]
Model = migrator.orm["model"]
modelfiles = ModelFile.select()
for modelfile in modelfiles:
# Extract and transform data in Python
modelfile.modelfile = json.loads(modelfile.modelfile)
meta = json.dumps(
{
"description": modelfile.modelfile.get("desc"),
"profile_image_url": modelfile.modelfile.get("imageUrl"),
"ollama": {"modelfile": modelfile.modelfile.get("content")},
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
"categories": modelfile.modelfile.get("categories"),
"user": {**modelfile.modelfile.get("user", {}), "community": "true"},
}
)
# Insert the processed data into the 'model' table
Model.create(
id=modelfile.tag_name,
user_id=modelfile.user_id,
name=modelfile.modelfile.get("title"),
meta=meta,
params="{}",
created_at=modelfile.timestamp,
updated_at=modelfile.timestamp,
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
recreate_modelfile_table(migrator, database)
move_data_back_to_modelfile(migrator, database)
migrator.remove_model("model")
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
query = """
CREATE TABLE IF NOT EXISTS modelfile (
user_id TEXT,
tag_name TEXT,
modelfile JSON,
timestamp BIGINT
)
"""
migrator.sql(query)
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
Model = migrator.orm["model"]
Modelfile = migrator.orm["modelfile"]
models = Model.select()
for model in models:
# Extract and transform data in Python
meta = json.loads(model.meta)
modelfile_data = {
"title": model.name,
"desc": meta.get("description"),
"imageUrl": meta.get("profile_image_url"),
"content": meta.get("ollama", {}).get("modelfile"),
"suggestionPrompts": meta.get("suggestion_prompts"),
"categories": meta.get("categories"),
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
}
# Insert the processed data back into the 'modelfile' table
Modelfile.create(
user_id=model.user_id,
tag_name=model.id,
modelfile=modelfile_data,
timestamp=model.created_at,
)
...@@ -4,7 +4,7 @@ from typing import Optional ...@@ -4,7 +4,7 @@ from typing import Optional
import peewee as pw import peewee as pw
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from apps.web.internal.db import DB, JSONField from apps.web.internal.db import DB, JSONField
...@@ -22,29 +22,34 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -22,29 +22,34 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
# ModelParams is a model for the data stored in the params field of the Model table # ModelParams is a model for the data stored in the params field of the Model table
# It isn't currently used in the backend, but it's here as a reference # It isn't currently used in the backend, but it's here as a reference
class ModelParams(BaseModel): class ModelParams(BaseModel):
model_config = ConfigDict(extra="allow")
pass pass
# ModelMeta is a model for the data stored in the meta field of the Model table # ModelMeta is a model for the data stored in the meta field of the Model table
# It isn't currently used in the backend, but it's here as a reference # It isn't currently used in the backend, but it's here as a reference
class ModelMeta(BaseModel): class ModelMeta(BaseModel):
description: str description: Optional[str] = None
""" """
User-facing description of the model. User-facing description of the model.
""" """
vision_capable: bool vision_capable: Optional[bool] = None
""" """
A flag indicating if the model is capable of vision and thus image inputs A flag indicating if the model is capable of vision and thus image inputs
""" """
model_config = ConfigDict(extra="allow")
pass
class Model(pw.Model): class Model(pw.Model):
id = pw.TextField(unique=True) id = pw.TextField(unique=True)
""" """
The model's id as used in the API. If set to an existing model, it will override the model. The model's id as used in the API. If set to an existing model, it will override the model.
""" """
user_id = pw.TextField() user_id = pw.TextField()
base_model_id = pw.TextField(null=True) base_model_id = pw.TextField(null=True)
...@@ -89,7 +94,6 @@ class ModelModel(BaseModel): ...@@ -89,7 +94,6 @@ class ModelModel(BaseModel):
class ModelsTable: class ModelsTable:
def __init__( def __init__(
self, self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase, db: pw.SqliteDatabase | pw.PostgresqlDatabase,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment