main.py 17 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import re
Timothy J. Baek's avatar
Timothy J. Baek committed
2
import requests
3
import base64
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel

from constants import ERROR_MESSAGES
from utils.utils import (
    get_current_user,
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
22
23

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
Timothy J. Baek's avatar
Timothy J. Baek committed
24
25
26
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
27
from pathlib import Path
28
import mimetypes
Timothy J. Baek's avatar
Timothy J. Baek committed
29
30
31
import uuid
import base64
import json
32
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
33

Self Denial's avatar
Self Denial committed
34
35
36
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
37
    IMAGE_GENERATION_ENGINE,
Self Denial's avatar
Self Denial committed
38
    ENABLE_IMAGE_GENERATION,
Self Denial's avatar
Self Denial committed
39
    AUTOMATIC1111_BASE_URL,
40
    AUTOMATIC1111_API_AUTH,
Self Denial's avatar
Self Denial committed
41
    COMFYUI_BASE_URL,
42
43
44
45
    COMFYUI_CFG_SCALE,
    COMFYUI_SAMPLER,
    COMFYUI_SCHEDULER,
    COMFYUI_SD3,
46
47
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
48
    IMAGE_GENERATION_MODEL,
49
50
    IMAGE_SIZE,
    IMAGE_STEPS,
51
    AppConfig,
Self Denial's avatar
Self Denial committed
52
)
Timothy J. Baek's avatar
Timothy J. Baek committed
53

54
55
# Module logger; its level comes from the "IMAGES" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

# Cache directory for generated images and their ".json" metadata sidecars.
# Created at import time so the save helpers can write into it directly.
IMAGE_CACHE_DIR = Path(CACHE_DIR) / "./image/generations/"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
59
60
61
62
63
64
65
66
67
68

app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended for this sub-application.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Runtime settings for image generation. Each entry is seeded from the
# corresponding value imported from ``config``; the endpoints in this module
# read and overwrite them. Assigned one attribute at a time, in this order.
app.state.config = AppConfig()
for _setting_name, _setting_value in (
    # engine selection and master switch
    ("ENGINE", IMAGE_GENERATION_ENGINE),
    ("ENABLED", ENABLE_IMAGE_GENERATION),
    # OpenAI-compatible image API
    ("OPENAI_API_BASE_URL", IMAGES_OPENAI_API_BASE_URL),
    ("OPENAI_API_KEY", IMAGES_OPENAI_API_KEY),
    # default model
    ("MODEL", IMAGE_GENERATION_MODEL),
    # backend endpoints and AUTOMATIC1111 credentials
    ("AUTOMATIC1111_BASE_URL", AUTOMATIC1111_BASE_URL),
    ("AUTOMATIC1111_API_AUTH", AUTOMATIC1111_API_AUTH),
    ("COMFYUI_BASE_URL", COMFYUI_BASE_URL),
    # generation parameters
    ("IMAGE_SIZE", IMAGE_SIZE),
    ("IMAGE_STEPS", IMAGE_STEPS),
    # ComfyUI-specific sampling options
    ("COMFYUI_CFG_SCALE", COMFYUI_CFG_SCALE),
    ("COMFYUI_SAMPLER", COMFYUI_SAMPLER),
    ("COMFYUI_SCHEDULER", COMFYUI_SCHEDULER),
    ("COMFYUI_SD3", COMFYUI_SD3),
):
    setattr(app.state.config, _setting_name, _setting_value)
# The loop variables live at module level; drop them to keep the namespace clean.
del _setting_name, _setting_value
Timothy J. Baek's avatar
Timothy J. Baek committed
89
90


91
92
93
94
95
96
97
98
99
100
def get_automatic1111_api_auth():
    """Build the ``authorization`` header value for AUTOMATIC1111 requests.

    The configured ``app.state.config.AUTOMATIC1111_API_AUTH`` string is
    UTF-8 encoded and base64-encoded into an HTTP Basic auth value (presumably
    it holds ``user:password`` as expected by A1111's API auth — TODO confirm).

    Returns:
        ``"Basic <base64>"`` when credentials are configured, otherwise ``""``.
    """
    api_auth = app.state.config.AUTOMATIC1111_API_AUTH
    # Treat unset *and* empty credentials as "no auth" rather than sending a
    # meaningless "Basic " header with empty credentials.
    if not api_auth:
        return ""
    encoded = base64.b64encode(api_auth.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"


Timothy J. Baek's avatar
Timothy J. Baek committed
101
102
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the selected image engine and whether generation is enabled."""
    config = app.state.config
    return {"engine": config.ENGINE, "enabled": config.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
107
108


Timothy J. Baek's avatar
Timothy J. Baek committed
109
110
111
112
113
114
115
class ConfigUpdateForm(BaseModel):
    """Request body for ``POST /config/update``."""

    # Engine id: "openai" or "comfyui"; any other value is handled as
    # AUTOMATIC1111 by the generation endpoints.
    engine: str
    # Master switch for image generation.
    enabled: bool


@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store the active engine and the enabled flag, then echo both back."""
    config = app.state.config
    # Tuple assignment stores ENGINE first, then ENABLED.
    config.ENGINE, config.ENABLED = form_data.engine, form_data.enabled
    return {"engine": config.ENGINE, "enabled": config.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
122
123


Timothy J. Baek's avatar
Timothy J. Baek committed
124
125
126
class EngineUrlUpdateForm(BaseModel):
    """Request body for ``POST /url/update``.

    A field left as ``None`` makes the handler re-assign the corresponding
    value imported from ``config`` instead of a submitted one.
    """

    AUTOMATIC1111_BASE_URL: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None
    # Read by update_engine_url; without this field every request to that
    # endpoint failed with AttributeError.
    AUTOMATIC1111_API_AUTH: Optional[str] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
127
128
129


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    """Return the backend base URLs and the AUTOMATIC1111 credentials."""
    exposed = ("AUTOMATIC1111_BASE_URL", "AUTOMATIC1111_API_AUTH", "COMFYUI_BASE_URL")
    return {name: getattr(app.state.config, name) for name in exposed}
Timothy J. Baek's avatar
Timothy J. Baek committed
136
137
138


@app.post("/url/update")
async def update_engine_url(
        form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    """Update the backend base URLs and AUTOMATIC1111 credentials.

    For each URL field, ``None`` re-assigns the value imported from ``config``.
    Otherwise trailing slashes are removed and the URL must answer a HEAD
    request before it is stored.

    Raises:
        HTTPException(400): a submitted URL could not be reached. Both URLs are
            checked before either is stored, so a failure changes nothing.
    """
    # Local import keeps this change self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    async def _validated(submitted, fallback):
        # None -> the value from ``config``; otherwise normalise and probe.
        if submitted is None:
            return fallback
        url = submitted.rstrip("/")
        try:
            # ``requests`` blocks; run it in a worker thread so the event loop
            # keeps serving other requests meanwhile.
            await run_in_threadpool(requests.head, url)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
            ) from e
        return url

    # Validate both before storing either, so a bad ComfyUI URL no longer
    # leaves a half-applied update behind.
    automatic1111_url = await _validated(
        form_data.AUTOMATIC1111_BASE_URL, AUTOMATIC1111_BASE_URL
    )
    comfyui_url = await _validated(form_data.COMFYUI_BASE_URL, COMFYUI_BASE_URL)

    app.state.config.AUTOMATIC1111_BASE_URL = automatic1111_url
    app.state.config.COMFYUI_BASE_URL = comfyui_url

    if form_data.AUTOMATIC1111_API_AUTH is None:
        app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
    else:
        app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH

    return {
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
174
175


176
177
class OpenAIConfigUpdateForm(BaseModel):
    """Request body for ``POST /openai/config/update``."""

    # Base URL of the OpenAI-compatible API.
    url: str
    # API key; the handler rejects an empty string.
    key: str


181
182
183
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
    """Return the configured OpenAI-compatible base URL and API key."""
    cfg = app.state.config
    return dict(
        OPENAI_API_BASE_URL=cfg.OPENAI_API_BASE_URL,
        OPENAI_API_KEY=cfg.OPENAI_API_KEY,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
187
188


189
190
@app.post("/openai/config/update")
async def update_openai_config(
        form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
    """Store a new OpenAI-compatible base URL and API key.

    Raises:
        HTTPException(400): the submitted key is the empty string.
    """
    # ``key`` is a validated ``str``, so falsiness means exactly "".
    if not form_data.key:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    cfg = app.state.config
    cfg.OPENAI_API_BASE_URL = form_data.url
    cfg.OPENAI_API_KEY = form_data.key

    return {
        "status": True,
        "OPENAI_API_BASE_URL": cfg.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": cfg.OPENAI_API_KEY,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
206
207
208
209
210
211
class ImageSizeUpdateForm(BaseModel):
    """Request body for ``POST /size/update``."""

    # "<width>x<height>", e.g. "512x512".
    size: str


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the configured ``<width>x<height>`` image size."""
    current_size = app.state.config.IMAGE_SIZE
    return {"IMAGE_SIZE": current_size}
Timothy J. Baek's avatar
Timothy J. Baek committed
213
214
215
216


@app.post("/size/update")
async def update_image_size(
        form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Validate and store a new ``<width>x<height>`` image size.

    Raises:
        HTTPException(400): ``size`` is not two ASCII integers joined by "x".
    """
    # fullmatch + [0-9]: ``re.match(r"^\d+x\d+$")`` also accepted a trailing
    # newline ("512x512\n") and non-ASCII digits, which were then stored
    # verbatim and later sent as the OpenAI "size" parameter.
    if not re.fullmatch(r"[0-9]+x[0-9]+", form_data.size):
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
        )

    app.state.config.IMAGE_SIZE = form_data.size
    return {
        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
        "status": True,
    }
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
231

232
233
234
235
236
237
238

class ImageStepsUpdateForm(BaseModel):
    """Request body for ``POST /steps/update``."""

    # Sampling steps; the handler rejects negative values.
    steps: int


@app.get("/steps")
async def get_image_steps(user=Depends(get_admin_user)):
    """Return the configured number of sampling steps.

    Given its own name: this handler used to be a second ``get_image_size``,
    which rebound that module-level name away from the ``GET /size`` handler.
    The route is unchanged.
    """
    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
240
241
242
243


@app.post("/steps/update")
async def update_image_steps(
        form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
    """Store a new number of sampling steps.

    Given its own name: this handler used to be a second ``update_image_size``,
    which rebound that module-level name away from the ``POST /size/update``
    handler. The route is unchanged.

    Raises:
        HTTPException(400): ``steps`` is negative.
    """
    if form_data.steps < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
        )

    app.state.config.IMAGE_STEPS = form_data.steps
    return {
        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
257
258


Timothy J. Baek's avatar
Timothy J. Baek committed
259
260
261
@app.get("/models")
def get_models(user=Depends(get_current_user)):
    """List the models offered by the active image engine.

    Returns:
        A list of ``{"id": ..., "name": ...}`` dicts. For "openai" this is a
        fixed DALL·E list. For "comfyui" it holds the ``CheckpointLoaderSimple``
        checkpoints from ``/object_info``. For any other engine it comes from
        AUTOMATIC1111's ``/sdapi/v1/sd-models`` (id = title, name = model_name).

    Raises:
        HTTPException(400): the backend could not be queried. As a side effect
            image generation is switched off (``ENABLED = False``).
    """
    try:
        engine = app.state.config.ENGINE

        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]

        if engine == "comfyui":
            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
            # Report HTTP errors as such instead of failing later on an
            # unexpected body shape.
            r.raise_for_status()
            info = r.json()
            checkpoints = info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0]
            return [{"id": model, "name": model} for model in checkpoints]

        r = requests.get(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        r.raise_for_status()
        return [
            {"id": model["title"], "name": model["model_name"]} for model in r.json()
        ]
    except Exception as e:
        # Any failure marks image generation as unavailable.
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
Timothy J. Baek's avatar
Timothy J. Baek committed
294
295
296
297
298


@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return ``{"model": ...}``, the default model of the active engine.

    "openai" falls back to "dall-e-2" and "comfyui" to "" when no model is
    configured. Any other engine reports AUTOMATIC1111's current
    ``sd_model_checkpoint``.

    Raises:
        HTTPException(400): the backend could not be queried. As a side effect
            image generation is switched off (``ENABLED = False``).
    """
    # Local import keeps this change self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        if app.state.config.ENGINE == "openai":
            return {"model": app.state.config.MODEL or "dall-e-2"}

        if app.state.config.ENGINE == "comfyui":
            return {"model": app.state.config.MODEL or ""}

        # ``requests`` blocks; run it in a worker thread so this ``async``
        # endpoint does not stall the event loop.
        r = await run_in_threadpool(
            requests.get,
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        r.raise_for_status()
        options = r.json()
        return {"model": options["sd_model_checkpoint"]}
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
Timothy J. Baek's avatar
Timothy J. Baek committed
317
318
319
320
321
322
323


class UpdateModelForm(BaseModel):
    """Request body for ``POST /models/default/update``."""

    # Model id to make the default (handled by set_model_handler).
    model: str


def set_model_handler(model: str):
    """Make ``model`` the default for the active engine.

    For "openai" and "comfyui" the name is stored in
    ``app.state.config.MODEL`` and returned. For any other engine the
    AUTOMATIC1111 options are fetched. If ``sd_model_checkpoint`` differs, the
    options are posted back with ``model`` as the new checkpoint.

    Returns:
        The stored model name, or the AUTOMATIC1111 options dict.

    Raises:
        requests.HTTPError: AUTOMATIC1111 answered either request with an error
            status. A failed POST used to be ignored, and the options were
            returned as if the switch had succeeded.
    """
    if app.state.config.ENGINE in ["openai", "comfyui"]:
        app.state.config.MODEL = model
        return app.state.config.MODEL

    options_url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
    headers = {"authorization": get_automatic1111_api_auth()}

    r = requests.get(url=options_url, headers=headers)
    r.raise_for_status()
    options = r.json()

    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        r = requests.post(url=options_url, json=options, headers=headers)
        r.raise_for_status()

    return options
Timothy J. Baek's avatar
Timothy J. Baek committed
344
345
346
347


@app.post("/models/default/update")
def update_default_model(
        form_data: UpdateModelForm,
        user=Depends(get_current_user),
):
    """Switch the default image model; the result comes from set_model_handler.

    NOTE(review): this uses ``get_current_user`` while the other settings
    endpoints use ``get_admin_user`` — confirm non-admin access is intended.
    """
    requested_model = form_data.model
    return set_model_handler(requested_model)


class GenerateImageForm(BaseModel):
    """Request body for ``POST /generations``."""

    # AUTOMATIC1111 only: when set, the checkpoint is switched before generating.
    model: Optional[str] = None
    prompt: str
    # Number of images to request.
    n: int = 1
    # OpenAI only: overrides the configured IMAGE_SIZE.
    size: Optional[str] = None
    # Forwarded to ComfyUI and AUTOMATIC1111 when provided.
    negative_prompt: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
362
363
def save_b64_image(b64_str):
    """Decode a base64 image and write it into ``IMAGE_CACHE_DIR``.

    Accepts a data URL (``data:<mime>;base64,<payload>``) or a bare base64
    payload. Bare payloads, and data URLs whose MIME type has no known
    extension, are saved as ``.png``.

    Returns:
        The file name (``<uuid4><ext>``), or ``None`` if decoding or writing
        failed (the exception is logged).
    """
    try:
        image_id = str(uuid.uuid4())

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            # header is e.g. "data:image/png;base64". The "data:" scheme must
            # be dropped, otherwise mimetypes cannot map the type and the file
            # was named "<uuid>None".
            mime_type = header.split(";")[0].strip()
            if mime_type.startswith("data:"):
                mime_type = mime_type[len("data:"):]
            image_format = mimetypes.guess_extension(mime_type) or ".png"
        else:
            encoded = b64_str
            image_format = ".png"

        img_data = base64.b64decode(encoded)
        image_filename = f"{image_id}{image_format}"
        file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
        with open(file_path, "wb") as f:
            f.write(img_data)
        return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
392
393


Timothy J. Baek's avatar
Timothy J. Baek committed
394
395
396
397
398
def save_url_image(url):
    """Download the image at ``url`` into ``IMAGE_CACHE_DIR``.

    Returns:
        The file name (``<uuid4><ext>``), or ``None`` in any of these cases
        (all logged):
        - the response is not an image;
        - its MIME type has no known extension;
        - the download or the write fails.
    """
    image_id = str(uuid.uuid4())
    try:
        # stream=True so iter_content below really streams to disk; the
        # context manager releases the connection afterwards.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()

            # Drop parameters such as "; charset=..." so the extension lookup
            # works; a missing header is treated as "not an image".
            mime_type = r.headers.get("content-type", "").split(";")[0].strip()
            if mime_type.split("/")[0] != "image":
                log.error("Url does not point to an image.")
                return None

            image_format = mimetypes.guess_extension(mime_type)
            if not image_format:
                raise ValueError("Could not determine image type from MIME type")

            image_filename = f"{image_id}{image_format}"
            file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
            with open(file_path, "wb") as image_file:
                for chunk in r.iter_content(chunk_size=8192):
                    image_file.write(chunk)
            return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
421
422


Timothy J. Baek's avatar
Timothy J. Baek committed
423
424
def _extract_error_message(response, fallback):
    """Return the error message from a backend ``response``, else ``fallback``.

    Understands ``{"error": {"message": ...}}`` and ``{"error": "<text>"}``
    bodies; non-JSON or otherwise shaped bodies yield ``fallback``.
    """
    try:
        payload = response.json()
    except ValueError:
        # Not JSON (e.g. an HTML error page); requests' JSON errors subclass
        # ValueError.
        return fallback
    if not isinstance(payload, dict):
        return fallback
    err = payload.get("error")
    if isinstance(err, dict) and "message" in err:
        return err["message"]
    if isinstance(err, str) and err:
        return err
    return fallback


def _cache_generated_images(sources, save_image, metadata):
    """Save each image with ``save_image`` and write a ``<file>.json`` sidecar.

    ``save_image`` is ``save_b64_image`` or ``save_url_image``. Images it could
    not save (it returned ``None``) are skipped instead of being reported as
    ``/cache/image/generations/None``.

    Returns:
        A list of ``{"url": "/cache/image/generations/<file>"}`` dicts.

    Raises:
        ValueError: images were produced but none of them could be saved.
    """
    images = []
    for source in sources:
        image_filename = save_image(source)
        if image_filename is None:
            log.warning("Skipping a generated image that could not be saved")
            continue
        images.append({"url": f"/cache/image/generations/{image_filename}"})
        file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
        with open(file_body_path, "w") as f:
            json.dump(metadata, f)
    if sources and not images:
        raise ValueError("None of the generated images could be saved")
    return images


@app.post("/generations")
def generate_image(
        form_data: GenerateImageForm,
        user=Depends(get_current_user),
):
    """Generate images for ``form_data.prompt`` with the active engine.

    Each image is saved in ``IMAGE_CACHE_DIR`` next to a ``<file>.json``
    sidecar. The sidecar holds the request parameters, plus AUTOMATIC1111's
    ``info`` for that engine.

    Returns:
        A list of ``{"url": "/cache/image/generations/<file>"}`` dicts.

    Raises:
        HTTPException(400): generation or saving failed. When the backend
            returned an error body, its message becomes the detail.
    """
    width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

    # Most recent backend response, kept so the error handler can surface the
    # backend's own error message.
    r = None
    try:
        if app.state.config.ENGINE == "openai":
            headers = {
                "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }
            data = {
                # Truthiness, as in get_default_model: an unset (None) model
                # also falls back to dall-e-2 instead of being sent as null.
                "model": app.state.config.MODEL or "dall-e-2",
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": form_data.size or app.state.config.IMAGE_SIZE,
                "response_format": "b64_json",
            }

            r = requests.post(
                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            return _cache_generated_images(
                [image["b64_json"] for image in res["data"]], save_b64_image, data
            )

        if app.state.config.ENGINE == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }
            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS
            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt
            if app.state.config.COMFYUI_CFG_SCALE:
                data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
            if app.state.config.COMFYUI_SAMPLER is not None:
                data["sampler"] = app.state.config.COMFYUI_SAMPLER
            if app.state.config.COMFYUI_SCHEDULER is not None:
                data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
            if app.state.config.COMFYUI_SD3 is not None:
                data["sd3"] = app.state.config.COMFYUI_SD3

            data = ImageGenerationPayload(**data)

            res = comfyui_generate_image(
                app.state.config.MODEL,
                data,
                user.id,
                app.state.config.COMFYUI_BASE_URL,
            )
            log.debug(f"res: {res}")

            images = _cache_generated_images(
                [image["url"] for image in res["data"]],
                save_url_image,
                data.model_dump(exclude_none=True),
            )
            log.debug(f"images: {images}")
            return images

        # Any other engine is handled as AUTOMATIC1111.
        if form_data.model:
            set_model_handler(form_data.model)

        data = {
            "prompt": form_data.prompt,
            "batch_size": form_data.n,
            "width": width,
            "height": height,
        }
        if app.state.config.IMAGE_STEPS is not None:
            data["steps"] = app.state.config.IMAGE_STEPS
        if form_data.negative_prompt is not None:
            data["negative_prompt"] = form_data.negative_prompt

        r = requests.post(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
            json=data,
            headers={"authorization": get_automatic1111_api_auth()},
        )
        # Fail on HTTP errors here; the handler below extracts the message.
        r.raise_for_status()
        res = r.json()
        log.debug(f"res: {res}")

        return _cache_generated_images(
            res["images"], save_b64_image, {**data, "info": res["info"]}
        )

    except Exception as e:
        error = e
        if r is not None:
            # Never raises: non-JSON or oddly shaped bodies fall back to ``e``.
            error = _extract_error_message(r, e)
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
        ) from e