main.py 17 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import re
Timothy J. Baek's avatar
Timothy J. Baek committed
2
import requests
3
import base64
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware

from constants import ERROR_MESSAGES
from utils.utils import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
18
    get_verified_user,
Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
Timothy J. Baek's avatar
Timothy J. Baek committed
23
24
25
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
26
from pathlib import Path
27
import mimetypes
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29
30
import uuid
import base64
import json
31
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
32

Self Denial's avatar
Self Denial committed
33
34
35
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
36
    IMAGE_GENERATION_ENGINE,
Self Denial's avatar
Self Denial committed
37
    ENABLE_IMAGE_GENERATION,
Self Denial's avatar
Self Denial committed
38
    AUTOMATIC1111_BASE_URL,
39
    AUTOMATIC1111_API_AUTH,
Self Denial's avatar
Self Denial committed
40
    COMFYUI_BASE_URL,
41
42
43
44
    COMFYUI_CFG_SCALE,
    COMFYUI_SAMPLER,
    COMFYUI_SCHEDULER,
    COMFYUI_SD3,
45
46
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
47
    IMAGE_GENERATION_MODEL,
48
49
    IMAGE_SIZE,
    IMAGE_STEPS,
50
    AppConfig,
Self Denial's avatar
Self Denial committed
51
)
Timothy J. Baek's avatar
Timothy J. Baek committed
52

53
54
# Module logger, filtered at the level configured for the images sub-app.
log = logging.getLogger(name=__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

# Generated images (and their ".json" request sidecars) are written here.
# The directory is created eagerly, at import time.
IMAGE_CACHE_DIR = Path(CACHE_DIR) / "./image/generations/"
IMAGE_CACHE_DIR.mkdir(exist_ok=True, parents=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
58
59
60
61
62
63
64
65
66
67

# The images sub-application. CORS is left fully open: any origin, method
# and header, with credentials allowed.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Runtime-mutable settings, seeded from the values imported from `config`.
# The admin endpoints of this module overwrite these attributes in place.
app.state.config = AppConfig()

# Engine selection and the global on/off switch.
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION

# OpenAI-compatible images API.
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

# Default model for the active engine.
app.state.config.MODEL = IMAGE_GENERATION_MODEL

# Self-hosted backends.
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL

# Generation parameters shared by the engines.
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS

# ComfyUI sampler settings.
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
Timothy J. Baek's avatar
Timothy J. Baek committed
88
89


90
91
92
93
def get_automatic1111_api_auth():
    """Build the ``authorization`` header value for AUTOMATIC1111 requests.

    ``app.state.config.AUTOMATIC1111_API_AUTH`` holds the raw credentials
    (presumably ``user:password`` as HTTP Basic expects; not validated here).
    They are UTF-8 encoded and base64-encoded into a Basic auth value.

    Returns:
        ``"Basic <base64(credentials)>"``, or ``""`` when no credentials are
        configured (None or an empty string).
    """
    credentials = app.state.config.AUTOMATIC1111_API_AUTH
    # Empty credentials mean "no auth". Encoding them would only produce a
    # bare "Basic " value that carries nothing.
    if not credentials:
        return ""
    token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
    return f"Basic {token}"


Timothy J. Baek's avatar
Timothy J. Baek committed
100
101
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Report the active engine and whether image generation is enabled (admin only)."""
    cfg = app.state.config
    return dict(engine=cfg.ENGINE, enabled=cfg.ENABLED)
Timothy J. Baek's avatar
Timothy J. Baek committed
106
107


Timothy J. Baek's avatar
Timothy J. Baek committed
108
109
110
111
112
113
114
class ConfigUpdateForm(BaseModel):
    """Request body for ``POST /config/update``."""

    engine: str
    enabled: bool


@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Switch the active engine and the enabled flag (admin only).

    Returns the values as stored afterwards.
    """
    cfg = app.state.config
    # Tuple assignment still writes ENGINE first, then ENABLED.
    cfg.ENGINE, cfg.ENABLED = form_data.engine, form_data.enabled
    return {"engine": cfg.ENGINE, "enabled": cfg.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
121
122


Timothy J. Baek's avatar
Timothy J. Baek committed
123
124
class EngineUrlUpdateForm(BaseModel):
    """Request body for ``POST /url/update``.

    A field left as None makes the update endpoint fall back to the value
    imported from ``config`` for that setting.
    """

    AUTOMATIC1111_BASE_URL: Optional[str] = None
    AUTOMATIC1111_API_AUTH: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    """Return the backend base URLs and the AUTOMATIC1111 credentials (admin only)."""
    cfg = app.state.config
    reported = ("AUTOMATIC1111_BASE_URL", "AUTOMATIC1111_API_AUTH", "COMFYUI_BASE_URL")
    return {name: getattr(cfg, name) for name in reported}
Timothy J. Baek's avatar
Timothy J. Baek committed
136
137
138


def _probe_base_url(url: str) -> str:
    """Normalize a backend base URL and check that its host is reachable.

    Trailing slashes are removed because endpoint paths are later appended
    as ``f"{base}/..."``. Only connection-level failures count: whatever
    ``requests.head`` raises (unreachable host, malformed or scheme-less URL,
    timeout) propagates. An HTTP error *status* is accepted.

    Returns:
        The normalized URL.
    """
    url = url.rstrip("/")
    # Bounded wait so an unreachable host cannot hold the request forever.
    requests.head(url, timeout=10)
    return url


@app.post("/url/update")
async def update_engine_url(
    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    """Update the AUTOMATIC1111 / ComfyUI base URLs and AUTOMATIC1111 credentials.

    A field left as None resets that setting to the value imported from
    ``config``. Submitted URLs are normalized and probed first. Nothing is
    written until every submitted URL has passed, so a bad ComfyUI URL no
    longer leaves a half-applied AUTOMATIC1111 change behind.

    Raises:
        HTTPException: 400 if a submitted URL cannot be reached.
    """
    import asyncio

    loop = asyncio.get_running_loop()

    async def resolve(submitted, default):
        if submitted is None:
            return default
        try:
            # requests is blocking: probe in a worker thread so the event
            # loop keeps serving other requests meanwhile.
            return await loop.run_in_executor(None, _probe_base_url, submitted)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
            ) from e

    automatic1111_url = await resolve(
        form_data.AUTOMATIC1111_BASE_URL, AUTOMATIC1111_BASE_URL
    )
    comfyui_url = await resolve(form_data.COMFYUI_BASE_URL, COMFYUI_BASE_URL)

    # Every submitted URL was reachable; commit all settings together.
    app.state.config.AUTOMATIC1111_BASE_URL = automatic1111_url
    app.state.config.COMFYUI_BASE_URL = comfyui_url
    app.state.config.AUTOMATIC1111_API_AUTH = (
        AUTOMATIC1111_API_AUTH
        if form_data.AUTOMATIC1111_API_AUTH is None
        else form_data.AUTOMATIC1111_API_AUTH
    )

    return {
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
174
175


176
177
class OpenAIConfigUpdateForm(BaseModel):
    """Request body for ``POST /openai/config/update``."""

    url: str
    key: str


@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
    """Return the OpenAI-compatible images API base URL and key (admin only)."""
    cfg = app.state.config
    base_url, api_key = cfg.OPENAI_API_BASE_URL, cfg.OPENAI_API_KEY
    return {"OPENAI_API_BASE_URL": base_url, "OPENAI_API_KEY": api_key}
Timothy J. Baek's avatar
Timothy J. Baek committed
187
188


189
190
@app.post("/openai/config/update")
async def update_openai_config(
    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
    """Store a new OpenAI-compatible base URL and API key (admin only).

    Trailing slashes are dropped from the URL because requests are built as
    ``f"{base}/images/generations"``; a trailing slash would double it.

    Raises:
        HTTPException: 400 with ``API_KEY_NOT_FOUND`` when the key is empty
        or whitespace-only.
    """
    if not form_data.key.strip():
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    app.state.config.OPENAI_API_BASE_URL = form_data.url.rstrip("/")
    app.state.config.OPENAI_API_KEY = form_data.key

    return {
        "status": True,
        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
206
207
208
209
210
211
class ImageSizeUpdateForm(BaseModel):
    """Request body for ``POST /size/update``; ``size`` is ``"<width>x<height>"``."""

    size: str


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the configured default image size (admin only)."""
    return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}


@app.post("/size/update")
async def update_image_size(
    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Set the default image size (admin only).

    The whole string must be ``<width>x<height>`` with ASCII digits and both
    dimensions positive. ``fullmatch`` also rejects a trailing newline, which
    ``re.match`` with ``$`` would let through.

    Raises:
        HTTPException: 400 when the size is malformed.
    """
    match = re.fullmatch(r"([0-9]+)x([0-9]+)", form_data.size)
    if match and all(int(dimension) > 0 for dimension in match.groups()):
        app.state.config.IMAGE_SIZE = form_data.size
        return {
            "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
            "status": True,
        }
    raise HTTPException(
        status_code=400,
        detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
    )
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
231

232
233
234
235
236
237
238

class ImageStepsUpdateForm(BaseModel):
    """Request body for ``POST /steps/update``."""

    steps: int


# Distinct handler names: reusing get_image_size/update_image_size here would
# rebind the module-level names of the /size handlers.
@app.get("/steps")
async def get_image_steps(user=Depends(get_admin_user)):
    """Return the configured number of sampling steps (admin only)."""
    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}


@app.post("/steps/update")
async def update_image_steps(
    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
    """Set the number of sampling steps (admin only).

    Raises:
        HTTPException: 400 when ``steps`` is negative.
    """
    if form_data.steps < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
        )
    app.state.config.IMAGE_STEPS = form_data.steps
    return {
        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
257
258


Timothy J. Baek's avatar
Timothy J. Baek committed
259
@app.get("/models")
def get_models(user=Depends(get_verified_user)):
    """List the models available to the active engine.

    - ``openai``: a fixed DALL·E 2 / DALL·E 3 list.
    - ``comfyui``: the checkpoint names that ComfyUI's ``/object_info``
      reports for the ``CheckpointLoaderSimple`` node.
    - anything else (AUTOMATIC1111): ``/sdapi/v1/sd-models``, with ``title``
      as ``id`` and ``model_name`` as ``name``.

    Returns:
        A list of ``{"id": ..., "name": ...}`` dicts.

    Raises:
        HTTPException: 400 on any failure, including a non-2xx backend
        response. A failure also sets ``app.state.config.ENABLED`` to False.
    """
    try:
        engine = app.state.config.ENGINE
        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]

        if engine == "comfyui":
            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
            # Fail on the HTTP status instead of on a confusing parse/KeyError.
            r.raise_for_status()
            info = r.json()
            checkpoints = info["CheckpointLoaderSimple"]["input"]["required"][
                "ckpt_name"
            ][0]
            return [{"id": name, "name": name} for name in checkpoints]

        r = requests.get(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        r.raise_for_status()
        return [
            {"id": model["title"], "name": model["model_name"]} for model in r.json()
        ]
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
294
295
296
297
298


@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return the default model of the active engine (admin only).

    - ``openai``: the configured MODEL, or ``"dall-e-2"`` if it is unset.
    - ``comfyui``: the configured MODEL, or ``""`` if it is unset.
    - otherwise (AUTOMATIC1111): the ``sd_model_checkpoint`` option reported
      by ``/sdapi/v1/options``.

    Raises:
        HTTPException: 400 on any failure. A failure also sets
        ``app.state.config.ENABLED`` to False.
    """
    import asyncio

    try:
        if app.state.config.ENGINE == "openai":
            return {"model": app.state.config.MODEL or "dall-e-2"}
        elif app.state.config.ENGINE == "comfyui":
            return {"model": app.state.config.MODEL or ""}
        else:
            url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
            headers = {"authorization": get_automatic1111_api_auth()}
            # requests is blocking: run it in a worker thread instead of
            # stalling the event loop that serves every other request.
            r = await asyncio.get_running_loop().run_in_executor(
                None, lambda: requests.get(url=url, headers=headers)
            )
            r.raise_for_status()
            options = r.json()
            return {"model": options["sd_model_checkpoint"]}
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
317
318
319
320
321
322
323


class UpdateModelForm(BaseModel):
    """Request body for ``POST /models/default/update``."""

    model: str


def set_model_handler(model: str):
    """Make *model* the default model of the active engine.

    For ``openai`` and ``comfyui`` the model is only stored in
    ``app.state.config.MODEL``, and the stored value is returned. For
    AUTOMATIC1111 the server options are fetched. If ``sd_model_checkpoint``
    differs, it is replaced and the options are posted back. The (possibly
    updated) options dict is returned.

    Raises:
        requests.HTTPError: if AUTOMATIC1111 answers either call with a
        non-2xx status. Previously a rejected update went unnoticed.
    """
    if app.state.config.ENGINE in ["openai", "comfyui"]:
        app.state.config.MODEL = model
        return app.state.config.MODEL

    options_url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
    headers = {"authorization": get_automatic1111_api_auth()}

    r = requests.get(url=options_url, headers=headers)
    r.raise_for_status()
    options = r.json()

    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        r = requests.post(url=options_url, json=options, headers=headers)
        r.raise_for_status()

    return options
Timothy J. Baek's avatar
Timothy J. Baek committed
344
345
346
347


@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_verified_user),
):
    """Make ``form_data.model`` the default model of the active engine.

    Delegates to ``set_model_handler`` and returns whatever it returns.
    """
    # NOTE(review): this mutates global settings but only requires a verified
    # user, whereas the other settings endpoints here use get_admin_user.
    # Confirm this is intended.
    requested_model = form_data.model
    return set_model_handler(requested_model)


class GenerateImageForm(BaseModel):
    """Request body for ``POST /generations``; only ``prompt`` is required."""

    # Model to switch to before generating; None leaves the model unchanged.
    model: Optional[str] = None
    # Text prompt describing the image.
    prompt: str
    # Number of images requested.
    n: int = 1
    # "<width>x<height>"; None falls back to the configured IMAGE_SIZE.
    size: Optional[str] = None
    # Optional negative prompt; omitted from the backend request when None.
    negative_prompt: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
362
363
def save_b64_image(b64_str, cache_dir: Optional[Path] = None):
    """Decode a base64 image and write it into the image cache.

    Accepts either a data URI (``data:image/png;base64,...``), whose MIME type
    selects the file extension, or a bare base64 payload, which is saved as
    ``.png``. An unrecognized MIME type also falls back to ``.png``.

    Args:
        b64_str: Data URI or raw base64 string.
        cache_dir: Directory to write into; defaults to ``IMAGE_CACHE_DIR``.

    Returns:
        The generated file name (``<uuid4><ext>``) inside the cache directory,
        or None if decoding or writing failed (the error is logged).
    """
    try:
        image_id = str(uuid.uuid4())
        image_format = ".png"
        encoded = b64_str

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            # header looks like "data:image/png;base64". Drop the "data:"
            # scheme; otherwise guess_extension() never recognizes the type
            # and the file would be named "<uuid>None".
            mime_type = header.split(";")[0]
            if mime_type.startswith("data:"):
                mime_type = mime_type[len("data:"):]
            image_format = mimetypes.guess_extension(mime_type) or ".png"

        img_data = base64.b64decode(encoded)

        image_filename = f"{image_id}{image_format}"
        target_dir = IMAGE_CACHE_DIR if cache_dir is None else Path(cache_dir)
        with open(target_dir.joinpath(image_filename), "wb") as f:
            f.write(img_data)
        return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
392
393


Timothy J. Baek's avatar
Timothy J. Baek committed
394
395
396
397
398
def save_url_image(url, cache_dir: Optional[Path] = None):
    """Download the image at *url* into the image cache.

    Args:
        url: URL of the image to fetch.
        cache_dir: Directory to write into; defaults to ``IMAGE_CACHE_DIR``.

    Returns:
        The generated file name (``<uuid4><ext>``), or None if the download
        failed, the response is not ``image/*``, or no file extension is
        known for its MIME type. Failures are logged, never raised.
    """
    image_id = str(uuid.uuid4())
    try:
        # stream=True so iter_content below actually streams to disk instead
        # of buffering the whole body first.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()

            # Drop parameters such as "; charset=..." before the lookup. A
            # missing header is treated as "not an image".
            content_type = r.headers.get("content-type", "")
            mime_type = content_type.split(";")[0].strip().lower()
            if mime_type.split("/")[0] != "image":
                log.error("Url does not point to an image.")
                return None

            image_format = mimetypes.guess_extension(mime_type)
            if not image_format:
                raise ValueError("Could not determine image type from MIME type")

            image_filename = f"{image_id}{image_format}"
            target_dir = IMAGE_CACHE_DIR if cache_dir is None else Path(cache_dir)
            with open(target_dir.joinpath(image_filename), "wb") as image_file:
                for chunk in r.iter_content(chunk_size=8192):
                    image_file.write(chunk)
            return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
421
422


Timothy J. Baek's avatar
Timothy J. Baek committed
423
424
def _upstream_error_message(r, fallback):
    """Return the error message carried by an upstream JSON error body.

    Understands ``{"error": {"message": ...}}`` (OpenAI style) and
    ``{"error": "..."}``. A None response, a non-JSON body or any other shape
    yields *fallback*, so building the HTTP error can never raise itself.
    """
    if r is None:
        return fallback
    try:
        body = r.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return fallback
    if not isinstance(body, dict) or "error" not in body:
        return fallback
    error = body["error"]
    if isinstance(error, dict):
        return error.get("message", fallback)
    return error if error else fallback


def _cache_image_entry(image_filename, metadata):
    """Write *metadata* as a JSON sidecar next to a cached image.

    Returns the ``{"url": ...}`` entry for the image. Returns None, writing
    nothing, when *image_filename* is None, i.e. when saving the image
    failed; this avoids ``/None`` URLs and ``None.json`` sidecars.
    """
    if image_filename is None:
        return None
    with open(IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json"), "w") as f:
        json.dump(metadata, f)
    return {"url": f"/cache/image/generations/{image_filename}"}


def _generate_openai_images(form_data, size):
    """Generate images through the OpenAI-compatible ``/images/generations`` API."""
    headers = {
        "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
        "Content-Type": "application/json",
    }
    data = {
        # Unset (None or "") model falls back to DALL·E 2.
        "model": app.state.config.MODEL or "dall-e-2",
        "prompt": form_data.prompt,
        "n": form_data.n,
        "size": size,
        "response_format": "b64_json",
    }
    r = requests.post(
        url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
        json=data,
        headers=headers,
    )
    r.raise_for_status()
    res = r.json()

    images = []
    for image in res["data"]:
        entry = _cache_image_entry(save_b64_image(image["b64_json"]), data)
        if entry is not None:
            images.append(entry)
    return images


def _generate_comfyui_images(form_data, width, height, user):
    """Generate images through ComfyUI via ``comfyui_generate_image``."""
    data = {
        "prompt": form_data.prompt,
        "width": width,
        "height": height,
        "n": form_data.n,
    }
    if app.state.config.IMAGE_STEPS is not None:
        data["steps"] = app.state.config.IMAGE_STEPS
    if form_data.negative_prompt is not None:
        data["negative_prompt"] = form_data.negative_prompt
    # Truthiness test: a CFG scale of 0 is treated as "not set".
    if app.state.config.COMFYUI_CFG_SCALE:
        data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
    if app.state.config.COMFYUI_SAMPLER is not None:
        data["sampler"] = app.state.config.COMFYUI_SAMPLER
    if app.state.config.COMFYUI_SCHEDULER is not None:
        data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
    if app.state.config.COMFYUI_SD3 is not None:
        data["sd3"] = app.state.config.COMFYUI_SD3

    payload = ImageGenerationPayload(**data)
    res = comfyui_generate_image(
        app.state.config.MODEL,
        payload,
        user.id,
        app.state.config.COMFYUI_BASE_URL,
    )
    log.debug(f"res: {res}")

    metadata = payload.model_dump(exclude_none=True)
    images = []
    for image in res["data"]:
        entry = _cache_image_entry(save_url_image(image["url"]), metadata)
        if entry is not None:
            images.append(entry)
    log.debug(f"images: {images}")
    return images


def _generate_automatic1111_images(form_data, width, height):
    """Generate images through AUTOMATIC1111's ``/sdapi/v1/txt2img``."""
    if form_data.model:
        set_model_handler(form_data.model)

    data = {
        "prompt": form_data.prompt,
        "batch_size": form_data.n,
        "width": width,
        "height": height,
    }
    if app.state.config.IMAGE_STEPS is not None:
        data["steps"] = app.state.config.IMAGE_STEPS
    if form_data.negative_prompt is not None:
        data["negative_prompt"] = form_data.negative_prompt

    r = requests.post(
        url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
        json=data,
        headers={"authorization": get_automatic1111_api_auth()},
    )
    r.raise_for_status()
    res = r.json()
    log.debug(f"res: {res}")

    metadata = {**data, "info": res["info"]}
    images = []
    for image in res["images"]:
        entry = _cache_image_entry(save_b64_image(image), metadata)
        if entry is not None:
            images.append(entry)
    return images


@app.post("/generations")
def generate_image(
    form_data: GenerateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the active engine and cache them locally.

    The requested ``size`` (or the configured IMAGE_SIZE when it is omitted)
    is passed verbatim to OpenAI. For ComfyUI/AUTOMATIC1111 it is split into
    an integer width and height.

    Returns:
        A list of ``{"url": "/cache/image/generations/<file>"}`` dicts. Images
        the backend returned but that could not be stored are left out.

    Raises:
        HTTPException: 400 on any failure. When the failure came from an
        HTTP response, the upstream error message is used if one is present.
    """
    try:
        size = form_data.size or app.state.config.IMAGE_SIZE
        if app.state.config.ENGINE == "openai":
            return _generate_openai_images(form_data, size)

        width, height = tuple(map(int, size.split("x")))
        if app.state.config.ENGINE == "comfyui":
            return _generate_comfyui_images(form_data, width, height, user)
        return _generate_automatic1111_images(form_data, width, height)

    except Exception as e:
        # requests exceptions (e.g. HTTPError from raise_for_status) carry
        # the failing response, whose body may hold a better message.
        response = e.response if isinstance(e, requests.RequestException) else None
        error = _upstream_error_message(response, e)
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) from e