main.py 12.5 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import re
Timothy J. Baek's avatar
Timothy J. Baek committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import requests
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel

from constants import ERROR_MESSAGES
from utils.utils import (
    get_current_user,
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
Timothy J. Baek's avatar
Timothy J. Baek committed
23
24
25
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
26
27
28
29
from pathlib import Path
import uuid
import base64
import json
30
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
31

32
from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL
Timothy J. Baek's avatar
Timothy J. Baek committed
33
34


35
36
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

# Generated images (<uuid>.png) and their JSON metadata sidecars (<uuid>.json)
# are written here. The URLs returned by /generations point at
# /cache/image/generations/ — presumably served from this directory by the
# parent app (not visible in this module).
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
40
41
42
43
44
45
46
47
48
49

# FastAPI application exposing the image-generation endpoints of this module
# (presumably mounted by the main server — not visible here).
app = FastAPI()
# NOTE(review): every origin, method and header is allowed, together with
# credentials. This is very permissive — confirm it is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

Timothy J. Baek's avatar
Timothy J. Baek committed
50
51
52
53
54
55
56
# Runtime settings, changed through the admin endpoints below. They are held
# only in app.state; nothing in this module persists them.

# Active backend: "openai", "comfyui", or anything else for AUTOMATIC1111.
app.state.ENGINE = ""
app.state.ENABLED = False

app.state.OPENAI_API_KEY = ""
# Model name stored for the openai / comfyui engines (AUTOMATIC1111 keeps its
# own current checkpoint server-side).
app.state.MODEL = ""

app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL

# Default output size as "<width>x<height>".
app.state.IMAGE_SIZE = "512x512"
# Sampling steps sent to ComfyUI / AUTOMATIC1111.
app.state.IMAGE_STEPS = 50
Timothy J. Baek's avatar
Timothy J. Baek committed
63
64


Timothy J. Baek's avatar
Timothy J. Baek committed
65
66
67
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Report the active engine and whether image generation is enabled (admin only)."""
    state = app.state
    return {"engine": state.ENGINE, "enabled": state.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
68
69


Timothy J. Baek's avatar
Timothy J. Baek committed
70
71
72
73
74
75
76
77
78
79
class ConfigUpdateForm(BaseModel):
    """Request body for ``/config/update``."""

    engine: str  # "openai", "comfyui", or anything else for AUTOMATIC1111
    enabled: bool


@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Switch the active engine and enable/disable generation (admin only).

    Returns the new settings in the same shape as ``GET /config``.
    """
    app.state.ENGINE, app.state.ENABLED = form_data.engine, form_data.enabled
    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
80
81


Timothy J. Baek's avatar
Timothy J. Baek committed
82
83
84
class EngineUrlUpdateForm(BaseModel):
    """Request body for ``/url/update``.

    A field left as ``None`` resets that engine's URL to the configured default.
    """

    AUTOMATIC1111_BASE_URL: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    """Return the base URLs currently used for AUTOMATIC1111 and ComfyUI (admin only)."""
    return {
        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
    }


# Seconds to wait for a backend to answer the reachability probe.
_URL_CHECK_TIMEOUT = 10


def _resolve_engine_url(url: Optional[str], default: str) -> str:
    """Normalise a submitted engine URL and check that it is reachable.

    Returns ``default`` when ``url`` is None. Otherwise strips surrounding
    "/" characters, sends a HEAD request, and returns the normalised URL.

    Raises:
        HTTPException(400): if the HEAD request raises (bad URL, connection
            error, timeout).
    """
    if url is None:
        return default

    url = url.strip("/")
    try:
        # Only reachability matters; the response status is deliberately not checked.
        requests.head(url, timeout=_URL_CHECK_TIMEOUT)
    except Exception as e:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
    return url


@app.post("/url/update")
async def update_engine_url(
    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    """Update the AUTOMATIC1111 and ComfyUI base URLs (admin only).

    For each field: ``None`` resets the URL to its configured default.
    Otherwise the value is probed with a HEAD request and stored only if the
    probe does not raise. Both URLs are checked before either is stored, so a
    failure leaves the current settings untouched.

    Raises:
        HTTPException(400): if either probe fails.
    """
    # NOTE(review): requests is synchronous, so each probe blocks this async
    # endpoint's event loop for up to _URL_CHECK_TIMEOUT seconds.
    automatic1111_url = _resolve_engine_url(
        form_data.AUTOMATIC1111_BASE_URL, AUTOMATIC1111_BASE_URL
    )
    comfyui_url = _resolve_engine_url(form_data.COMFYUI_BASE_URL, COMFYUI_BASE_URL)

    app.state.AUTOMATIC1111_BASE_URL = automatic1111_url
    app.state.COMFYUI_BASE_URL = comfyui_url

    return {
        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
126
127


Timothy J. Baek's avatar
Timothy J. Baek committed
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class OpenAIKeyUpdateForm(BaseModel):
    """Request body for ``/key/update``."""

    key: str


@app.get("/key")
async def get_openai_key(user=Depends(get_admin_user)):
    """Return the stored OpenAI API key (admin only)."""
    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}


@app.post("/key/update")
async def update_openai_key(
    form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
):
    """Store a new OpenAI API key (admin only).

    Raises:
        HTTPException(400): if the submitted key is the empty string.
    """
    new_key = form_data.key
    if not new_key:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    app.state.OPENAI_API_KEY = new_key
    return {"OPENAI_API_KEY": new_key, "status": True}


Timothy J. Baek's avatar
Timothy J. Baek committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
class ImageSizeUpdateForm(BaseModel):
    """Request body for ``/size/update``; ``size`` is ``"<width>x<height>"``."""

    size: str


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the configured default image size (admin only)."""
    return {"IMAGE_SIZE": app.state.IMAGE_SIZE}


@app.post("/size/update")
async def update_image_size(
    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Set the default image size (admin only).

    Accepts ``"<width>x<height>"`` with positive ASCII-digit dimensions,
    e.g. ``"512x512"``. ``generate_image`` later splits this on ``"x"`` and
    converts both parts to ``int``.

    Raises:
        HTTPException(400): if ``size`` is not in that format.
    """
    # fullmatch rather than match + "$": "$" also matches just before a
    # trailing "\n", which would let "512x512\n" through. [0-9] rather than
    # \d, which also accepts non-ASCII digits.
    match = re.fullmatch(r"([0-9]+)x([0-9]+)", form_data.size)
    if match and all(int(dim) > 0 for dim in match.groups()):
        app.state.IMAGE_SIZE = form_data.size
        return {
            "IMAGE_SIZE": app.state.IMAGE_SIZE,
            "status": True,
        }

    raise HTTPException(
        status_code=400,
        detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
    )
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
177

178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202

class ImageStepsUpdateForm(BaseModel):
    """Request body for ``/steps/update``."""

    steps: int


# These handlers need their own names: reusing get_image_size /
# update_image_size shadowed the /size handlers at module level and gave the
# /steps routes the route name and OpenAPI summary of the /size routes.
@app.get("/steps")
async def get_image_steps(user=Depends(get_admin_user)):
    """Return the configured number of sampling steps (admin only)."""
    return {"IMAGE_STEPS": app.state.IMAGE_STEPS}


@app.post("/steps/update")
async def update_image_steps(
    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
    """Set the number of sampling steps used by ComfyUI / AUTOMATIC1111 (admin only).

    Raises:
        HTTPException(400): if ``steps`` is negative.
    """
    # NOTE(review): 0 is accepted; confirm the backends handle steps=0.
    if form_data.steps >= 0:
        app.state.IMAGE_STEPS = form_data.steps
        return {
            "IMAGE_STEPS": app.state.IMAGE_STEPS,
            "status": True,
        }

    raise HTTPException(
        status_code=400,
        detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
203
204


Timothy J. Baek's avatar
Timothy J. Baek committed
205
206
207
@app.get("/models")
def get_models(user=Depends(get_current_user)):
    """List the models offered by the active engine as ``{"id", "name"}`` dicts.

    * openai: a fixed list of DALL·E models.
    * comfyui: the ``ckpt_name`` choices ComfyUI reports for
      ``CheckpointLoaderSimple`` in ``/object_info``.
    * otherwise (AUTOMATIC1111): ``title`` / ``model_name`` of each entry in
      ``/sdapi/v1/sd-models``.

    Raises:
        HTTPException(400): on any failure. Image generation is also
            disabled (``app.state.ENABLED = False``).
    """
    try:
        if app.state.ENGINE == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]

        if app.state.ENGINE == "comfyui":
            r = requests.get(
                url=f"{app.state.COMFYUI_BASE_URL}/object_info", timeout=30
            )
            r.raise_for_status()
            info = r.json()
            checkpoints = info["CheckpointLoaderSimple"]["input"]["required"][
                "ckpt_name"
            ][0]
            return [{"id": name, "name": name} for name in checkpoints]

        r = requests.get(
            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
            timeout=30,
        )
        # Fail on an HTTP error status instead of trying to parse an error body.
        r.raise_for_status()
        return [
            {"id": model["title"], "name": model["model_name"]} for model in r.json()
        ]
    except Exception as e:
        # A backend that cannot list its models is treated as unusable.
        app.state.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
239
240
241
242
243


@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return the default model for the active engine as ``{"model": name}`` (admin only).

    openai falls back to "dall-e-2" and comfyui to "" when no model is
    stored. AUTOMATIC1111 is asked for its current ``sd_model_checkpoint``.

    Raises:
        HTTPException(400): on any failure. Image generation is also
            disabled (``app.state.ENABLED = False``).
    """
    try:
        if app.state.ENGINE == "openai":
            return {"model": app.state.MODEL or "dall-e-2"}

        if app.state.ENGINE == "comfyui":
            return {"model": app.state.MODEL or ""}

        # NOTE(review): requests is synchronous, so this call blocks the event
        # loop of this async endpoint; the timeout bounds how long.
        r = requests.get(
            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", timeout=30
        )
        r.raise_for_status()
        return {"model": r.json()["sd_model_checkpoint"]}
    except Exception as e:
        app.state.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
255
256
257
258
259
260
261


class UpdateModelForm(BaseModel):
    """Request body for ``/models/default/update``."""

    model: str


def set_model_handler(model: str):
    """Make ``model`` the active model for the current engine.

    For openai / comfyui the name is stored in ``app.state.MODEL`` and
    returned. For AUTOMATIC1111 the server's ``sd_model_checkpoint`` option
    is changed (only when it differs) and the full options dict is returned.

    Raises:
        requests.HTTPError: if AUTOMATIC1111 answers the options read or
            update with an error status. Connection errors from ``requests``
            and a missing ``sd_model_checkpoint`` key propagate as well.
    """
    if app.state.ENGINE in ("openai", "comfyui"):
        app.state.MODEL = model
        return app.state.MODEL

    options_url = f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
    r = requests.get(url=options_url, timeout=30)
    r.raise_for_status()
    options = r.json()

    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        # No timeout on purpose: switching checkpoints presumably makes
        # AUTOMATIC1111 load the model before replying, which can be slow.
        r = requests.post(url=options_url, json=options)
        # Check the status so a rejected update is reported, not returned as success.
        r.raise_for_status()

    return options
Timothy J. Baek's avatar
Timothy J. Baek committed
279
280
281
282
283
284
285
286
287
288
289
290
291
292


@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_admin_user),
):
    """Change the default model for the active engine (admin only).

    The setting is global. Like ``GET /models/default`` and the other
    configuration endpoints, changing it requires an admin rather than any
    authenticated user.
    """
    return set_model_handler(form_data.model)


class GenerateImageForm(BaseModel):
    """Request body for ``/generations``.

    Fields:
        model: AUTOMATIC1111 checkpoint to switch to before generating. Only
            the AUTOMATIC1111 branch of ``generate_image`` reads it.
        prompt: text prompt.
        n: number of images. Sent as ``n`` to OpenAI / ComfyUI and as
            ``batch_size`` to AUTOMATIC1111.
        size: ``"<width>x<height>"``. Only the OpenAI branch reads it; it
            falls back to ``app.state.IMAGE_SIZE``.
        negative_prompt: forwarded to ComfyUI / AUTOMATIC1111 when set.
    """

    model: Optional[str] = None
    prompt: str
    n: int = 1
    size: Optional[str] = None
    negative_prompt: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
297
298
299
300
301
302
303
304
305
306
307
308
309
310
def save_b64_image(b64_str):
    """Decode a base64 image and write it to IMAGE_CACHE_DIR as ``<uuid>.png``.

    Accepts raw base64 or a data URL (``data:<mime>;base64,<payload>``). For a
    data URL, the header before the first comma is dropped before decoding.
    The decoded bytes are written as-is; the ``.png`` extension is assumed,
    not checked.

    Returns:
        The generated image id (the file stem), or None if decoding or
        writing failed. Failures are logged with their traceback.
    """
    image_id = str(uuid.uuid4())
    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")

    try:
        if b64_str.startswith("data:") and "," in b64_str:
            # Raw base64 never contains ",", so this only strips data-URL headers.
            b64_str = b64_str.split(",", 1)[1]
        img_data = base64.b64decode(b64_str)

        with open(file_path, "wb") as f:
            f.write(img_data)

        return image_id
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None


Timothy J. Baek's avatar
Timothy J. Baek committed
315
316
317
318
319
320
321
322
323
324
325
326
327
def save_url_image(url, timeout=60):
    """Download the image at ``url`` into IMAGE_CACHE_DIR as ``<uuid>.png``.

    Args:
        url: address to fetch with a GET request.
        timeout: seconds ``requests`` may wait on the connection/read before
            giving up. Without it, a stalled server would hang the caller.

    Returns:
        The generated image id (the file stem), or None if the download or
        write failed (including non-2xx responses). Failures are logged with
        their traceback.
    """
    image_id = str(uuid.uuid4())
    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")

    try:
        r = requests.get(url, timeout=timeout)
        r.raise_for_status()

        with open(file_path, "wb") as image_file:
            image_file.write(r.content)

        return image_id
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None


Timothy J. Baek's avatar
Timothy J. Baek committed
332
333
334
335
336
337
def _error_message_from_response(r, fallback):
    """Best-effort extraction of a backend error message from response ``r``.

    Understands an ``{"error": {"message": ...}}`` body (the shape the OpenAI
    branch relies on) and an ``{"error": "<text>"}`` body. Returns
    ``fallback`` when ``r`` is None, the body is not JSON, or no usable
    message is present. Never raises for a malformed body.
    """
    if r is None:
        return fallback
    try:
        body = r.json()
    except ValueError:
        # A non-JSON error page must not mask the original exception.
        return fallback
    if not isinstance(body, dict):
        return fallback
    error = body.get("error")
    if isinstance(error, dict) and "message" in error:
        return error["message"]
    if isinstance(error, str) and error:
        return error
    return fallback


def _store_generated_images(items, save, metadata):
    """Cache each generated image plus a JSON metadata sidecar.

    ``save`` turns one backend item into a cached image id, or None on
    failure (the saver logs why). Failed items are skipped; otherwise every
    failure would produce a ``None.png`` URL and overwrite ``None.json``.

    Returns:
        ``[{"url": "/cache/image/generations/<id>.png"}, ...]`` for the
        images that were saved.
    """
    images = []
    for item in items:
        image_id = save(item)
        if image_id is None:
            continue
        with open(IMAGE_CACHE_DIR.joinpath(f"{image_id}.json"), "w") as f:
            json.dump(metadata, f)
        images.append({"url": f"/cache/image/generations/{image_id}.png"})
    return images


@app.post("/generations")
def generate_image(
    form_data: GenerateImageForm,
    user=Depends(get_current_user),
):
    """Generate images for ``form_data.prompt`` with the active engine.

    * openai: POST to the OpenAI images API with ``response_format=b64_json``.
    * comfyui: delegate to ``comfyui_generate_image`` and download the
      returned URLs.
    * otherwise (AUTOMATIC1111): optionally switch checkpoint, then POST to
      ``/sdapi/v1/txt2img``.

    Each image is cached with a JSON sidecar holding the request payload.

    Returns:
        A list of ``{"url": ...}`` dicts for the cached images.

    Raises:
        HTTPException(400): on any failure. When a backend response is
            available and carries an error message, that message is used.
    """
    width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))

    r = None  # last backend response, used to enrich the error message
    try:
        if app.state.ENGINE == "openai":
            headers = {
                "Authorization": f"Bearer {app.state.OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }
            data = {
                "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
                "response_format": "b64_json",
            }

            r = requests.post(
                url="https://api.openai.com/v1/images/generations",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            # The sidecar stores `data`, which does not contain the API key.
            return _store_generated_images(
                (image["b64_json"] for image in res["data"]), save_b64_image, data
            )

        elif app.state.ENGINE == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }
            if app.state.IMAGE_STEPS is not None:
                data["steps"] = app.state.IMAGE_STEPS
            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            data = ImageGenerationPayload(**data)

            res = comfyui_generate_image(
                app.state.MODEL,
                data,
                user.id,
                app.state.COMFYUI_BASE_URL,
            )
            log.debug(f"res: {res}")

            images = _store_generated_images(
                (image["url"] for image in res["data"]),
                save_url_image,
                data.model_dump(exclude_none=True),
            )
            log.debug(f"images: {images}")
            return images

        else:
            if form_data.model:
                set_model_handler(form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }
            if app.state.IMAGE_STEPS is not None:
                data["steps"] = app.state.IMAGE_STEPS
            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            r = requests.post(
                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
            )
            r.raise_for_status()
            res = r.json()
            log.debug(f"res: {res}")

            return _store_generated_images(
                res["images"], save_b64_image, {**data, "info": res["info"]}
            )

    except Exception as e:
        error = _error_message_from_response(r, e)
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
        ) from e