main.py 17.6 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import re
Timothy J. Baek's avatar
Timothy J. Baek committed
2
import requests
3
import base64
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware

from constants import ERROR_MESSAGES
from utils.utils import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
18
    get_verified_user,
Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
Timothy J. Baek's avatar
Timothy J. Baek committed
23
24
25
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
26
from pathlib import Path
27
import mimetypes
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29
30
import uuid
import base64
import json
31
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
32

Self Denial's avatar
Self Denial committed
33
34
35
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
36
    IMAGE_GENERATION_ENGINE,
Self Denial's avatar
Self Denial committed
37
    ENABLE_IMAGE_GENERATION,
Self Denial's avatar
Self Denial committed
38
    AUTOMATIC1111_BASE_URL,
39
    AUTOMATIC1111_API_AUTH,
Self Denial's avatar
Self Denial committed
40
    COMFYUI_BASE_URL,
41
42
43
44
    COMFYUI_CFG_SCALE,
    COMFYUI_SAMPLER,
    COMFYUI_SCHEDULER,
    COMFYUI_SD3,
45
46
47
    COMFYUI_FLUX,
    COMFYUI_FLUX_WEIGHT_DTYPE,
    COMFYUI_FLUX_FP8_CLIP,
48
49
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
50
    IMAGE_GENERATION_MODEL,
51
52
    IMAGE_SIZE,
    IMAGE_STEPS,
53
    AppConfig,
Self Denial's avatar
Self Denial committed
54
)
Timothy J. Baek's avatar
Timothy J. Baek committed
55

56
57
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

# Generated images and their ".json" sidecar files are written here. The URLs
# returned by the generation endpoint point at /cache/image/generations/<file>
# (presumably served from this directory by the parent app — confirm).
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
61
62
63
64
65
66
67
68
69
70

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mutable app-scoped settings. Every endpoint in this module reads and/or
# writes these; the initial values come from the imported `config` module.
app.state.config = AppConfig()

# Engine selection ("openai", "comfyui", anything else = AUTOMATIC1111) and
# the global on/off switch.
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION

# OpenAI-compatible image API.
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

app.state.config.MODEL = IMAGE_GENERATION_MODEL

# Self-hosted backends.
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL

# Generation parameters shared by the backends ("<width>x<height>", step count).
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS

# ComfyUI-only workflow options.
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
Timothy J. Baek's avatar
Timothy J. Baek committed
94

95
96
97
98
def get_automatic1111_api_auth():
    """Build the ``Authorization`` header value for AUTOMATIC1111 requests.

    ``AUTOMATIC1111_API_AUTH`` is base64-encoded and sent with the HTTP Basic
    scheme, so it is presumably a ``user:password`` string.

    Returns:
        ``"Basic <base64>"``, or ``""`` when no credentials are configured.
    """
    credentials = app.state.config.AUTOMATIC1111_API_AUTH
    # Treat an empty string like None: a bare "Basic " header authenticates
    # nothing, and the settings form stores "" when the field is cleared.
    if not credentials:
        return ""
    encoded = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"


Timothy J. Baek's avatar
Timothy J. Baek committed
105
106
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the active image engine and whether generation is enabled.

    Guarded by ``get_admin_user``.
    """
    return {
        "engine": app.state.config.ENGINE,
        "enabled": app.state.config.ENABLED,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
111
112


Timothy J. Baek's avatar
Timothy J. Baek committed
113
114
115
116
117
118
119
class ConfigUpdateForm(BaseModel):
    # Request body for POST /config/update.
    # Engine id: "openai" and "comfyui" are recognised; any other value is
    # treated as AUTOMATIC1111 by the handlers in this module.
    engine: str
    enabled: bool  # global image-generation on/off switch


@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store the active image engine and the enabled flag.

    Guarded by ``get_admin_user``. The engine string is stored verbatim (no
    validation). Returns the stored values.
    """
    app.state.config.ENGINE = form_data.engine
    app.state.config.ENABLED = form_data.enabled
    return {
        "engine": app.state.config.ENGINE,
        "enabled": app.state.config.ENABLED,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
126
127


Timothy J. Baek's avatar
Timothy J. Baek committed
128
129
class EngineUrlUpdateForm(BaseModel):
    # Request body for POST /url/update. A field left as None resets that
    # setting to its startup default rather than leaving it unchanged.
    AUTOMATIC1111_BASE_URL: Optional[str] = None
    AUTOMATIC1111_API_AUTH: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
132
133
134


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    """Return the AUTOMATIC1111 / ComfyUI endpoints and the AUTOMATIC1111 credentials.

    Guarded by ``get_admin_user``; note the credentials are returned as stored.
    """
    return {
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
141
142
143


@app.post("/url/update")
async def update_engine_url(
    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    """Update the AUTOMATIC1111 / ComfyUI endpoints and AUTOMATIC1111 credentials.

    Guarded by ``get_admin_user``. For each field, ``None`` restores the startup
    default from ``config``. A supplied URL has surrounding slashes stripped and
    must answer a HEAD request. All URLs are validated before anything is
    stored, so a bad ComfyUI URL no longer leaves the AUTOMATIC1111 URL
    half-updated.

    Raises:
        HTTPException(400): if a supplied URL cannot be reached.
    """

    def _reachable_url(url):
        # Only reachability is checked: requests.head does not raise on HTTP
        # error statuses, so any response at all is accepted.
        # NOTE(review): blocking call inside an async handler; consider an executor.
        url = url.strip("/")
        try:
            requests.head(url)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
            ) from e
        return url

    automatic1111_base_url = (
        AUTOMATIC1111_BASE_URL
        if form_data.AUTOMATIC1111_BASE_URL is None
        else _reachable_url(form_data.AUTOMATIC1111_BASE_URL)
    )
    comfyui_base_url = (
        COMFYUI_BASE_URL
        if form_data.COMFYUI_BASE_URL is None
        else _reachable_url(form_data.COMFYUI_BASE_URL)
    )
    automatic1111_api_auth = (
        AUTOMATIC1111_API_AUTH
        if form_data.AUTOMATIC1111_API_AUTH is None
        else form_data.AUTOMATIC1111_API_AUTH
    )

    # Commit only after every URL has been validated.
    app.state.config.AUTOMATIC1111_BASE_URL = automatic1111_base_url
    app.state.config.COMFYUI_BASE_URL = comfyui_base_url
    app.state.config.AUTOMATIC1111_API_AUTH = automatic1111_api_auth

    return {
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
179
180


181
182
class OpenAIConfigUpdateForm(BaseModel):
    # Request body for POST /openai/config/update.
    url: str  # base URL of the OpenAI-compatible API
    key: str  # API key; the handler rejects an empty string


186
187
188
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
    """Return the OpenAI-compatible base URL and API key (guarded by ``get_admin_user``)."""
    return {
        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
192
193


194
195
@app.post("/openai/config/update")
async def update_openai_config(
    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
    """Store the OpenAI-compatible base URL and API key (guarded by ``get_admin_user``).

    Raises:
        HTTPException(400): if the key is empty.
    """
    if form_data.key == "":
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    # Trailing slashes are dropped (as for the engine URLs) so that the
    # f"{base}/images/generations" request URL never contains "//".
    app.state.config.OPENAI_API_BASE_URL = form_data.url.rstrip("/")
    app.state.config.OPENAI_API_KEY = form_data.key

    return {
        "status": True,
        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
211
212
213
214
215
216
class ImageSizeUpdateForm(BaseModel):
    # Request body for POST /size/update; expected as "<width>x<height>",
    # e.g. "512x512" (format is checked by the handler).
    size: str


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the configured "<width>x<height>" size (guarded by ``get_admin_user``)."""
    return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
Timothy J. Baek's avatar
Timothy J. Baek committed
218
219
220
221


@app.post("/size/update")
async def update_image_size(
    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Store a new "<width>x<height>" image size (guarded by ``get_admin_user``).

    Raises:
        HTTPException(400): if ``size`` is not two ASCII integers joined by "x".
    """
    # fullmatch + [0-9]: the previous re.match(r"^\d+x\d+$") also accepted a
    # trailing newline ("512x512\n") and non-ASCII digits, which were then sent
    # verbatim as the OpenAI "size" field.
    if not re.fullmatch(r"[0-9]+x[0-9]+", form_data.size):
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
        )

    app.state.config.IMAGE_SIZE = form_data.size
    return {
        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
        "status": True,
    }
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
236

237
238
239
240
241
242
243

class ImageStepsUpdateForm(BaseModel):
    # Request body for POST /steps/update; negative values are rejected by the handler.
    steps: int


@app.get("/steps")
async def get_image_steps(user=Depends(get_admin_user)):
    """Return the configured number of sampling steps (guarded by ``get_admin_user``).

    Named distinctly from the /size handler so that neither definition shadows
    the other at module level; the HTTP route is unchanged.
    """
    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
245
246
247
248


@app.post("/steps/update")
async def update_image_steps(
    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
    """Store a new sampling step count (guarded by ``get_admin_user``).

    Named distinctly from the /size/update handler so that neither definition
    shadows the other at module level; the HTTP route is unchanged.

    Raises:
        HTTPException(400): if ``steps`` is negative.
    """
    if form_data.steps < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
        )

    app.state.config.IMAGE_STEPS = form_data.steps
    return {
        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
262
263


Timothy J. Baek's avatar
Timothy J. Baek committed
264
@app.get("/models")
def get_models(user=Depends(get_verified_user)):
    """List the models available to the active engine as ``{"id", "name"}`` dicts.

    - openai: a fixed DALL·E 2 / DALL·E 3 list.
    - comfyui: checkpoint names from ComfyUI's ``/object_info``.
    - anything else (AUTOMATIC1111): ``/sdapi/v1/sd-models`` (id = title).

    Any failure switches image generation off (``ENABLED = False``) and is
    reported as HTTP 400.
    """
    try:
        if app.state.config.ENGINE == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]

        if app.state.config.ENGINE == "comfyui":
            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
            r.raise_for_status()
            info = r.json()
            # NOTE(review): relies on ComfyUI's object_info layout, where element
            # [0] of ckpt_name is the list of checkpoint file names.
            checkpoints = info["CheckpointLoaderSimple"]["input"]["required"][
                "ckpt_name"
            ][0]
            return [{"id": model, "name": model} for model in checkpoints]

        r = requests.get(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        r.raise_for_status()
        return [
            {"id": model["title"], "name": model["model_name"]} for model in r.json()
        ]
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
299
300
301
302
303


@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return ``{"model": ...}`` for the active engine (guarded by ``get_admin_user``).

    - openai: the configured MODEL, or "dall-e-2" when unset/empty.
    - comfyui: the configured MODEL, or "" when unset/empty.
    - AUTOMATIC1111: the server's current ``sd_model_checkpoint`` option.

    Any failure switches image generation off and is reported as HTTP 400.
    """
    try:
        if app.state.config.ENGINE == "openai":
            return {"model": app.state.config.MODEL or "dall-e-2"}
        if app.state.config.ENGINE == "comfyui":
            return {"model": app.state.config.MODEL or ""}

        # NOTE(review): blocking HTTP call inside an async handler.
        r = requests.get(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        r.raise_for_status()
        return {"model": r.json()["sd_model_checkpoint"]}
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
322
323
324
325
326
327
328


class UpdateModelForm(BaseModel):
    # Request body for POST /models/default/update; an engine-specific model id
    # (for AUTOMATIC1111 it is compared against the server's sd_model_checkpoint).
    model: str


def set_model_handler(model: str):
    """Make ``model`` the default for the active engine.

    For "openai" / "comfyui" the name is only stored in
    ``app.state.config.MODEL``, and that stored value is returned. For
    AUTOMATIC1111 the server's own ``sd_model_checkpoint`` option is switched
    (only when it differs), and the full options dict is returned.

    Raises:
        requests.HTTPError: if AUTOMATIC1111 rejects the options GET or POST.
    """
    if app.state.config.ENGINE in ["openai", "comfyui"]:
        app.state.config.MODEL = model
        return app.state.config.MODEL

    headers = {"authorization": get_automatic1111_api_auth()}
    options_url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"

    r = requests.get(url=options_url, headers=headers)
    r.raise_for_status()
    options = r.json()

    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        r = requests.post(url=options_url, json=options, headers=headers)
        # Previously a rejected POST was ignored and reported as a successful switch.
        r.raise_for_status()

    return options
Timothy J. Baek's avatar
Timothy J. Baek committed
349
350
351
352


@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_verified_user),
):
    """Set the default image model via ``set_model_handler`` and return its result.

    NOTE(review): this only requires a verified user, whereas GET
    /models/default requires an admin. For AUTOMATIC1111 it switches the remote
    server's checkpoint for everyone — confirm the weaker guard is intended.
    """
    return set_model_handler(form_data.model)


class GenerateImageForm(BaseModel):
    # Request body for POST /generations.
    # Only the AUTOMATIC1111 path uses `model` (it switches that server's checkpoint).
    model: Optional[str] = None
    prompt: str
    n: int = 1  # number of images to generate
    # Only the OpenAI path uses `size`; the other engines use the configured IMAGE_SIZE.
    size: Optional[str] = None
    negative_prompt: Optional[str] = None  # not sent to the OpenAI engine


Timothy J. Baek's avatar
Timothy J. Baek committed
367
368
def save_b64_image(b64_str, cache_dir=None):
    """Decode a base64 image and write it into the image cache.

    Accepts either a bare base64 payload (saved with a ``.png`` extension) or a
    data URI such as ``data:image/jpeg;base64,...``, whose MIME type selects
    the extension (falling back to ``.png`` when it is not recognised).

    Args:
        b64_str: base64 payload or data URI.
        cache_dir: directory to write into; defaults to ``IMAGE_CACHE_DIR``.

    Returns:
        The new file name (``<uuid4><ext>``), or ``None`` if anything failed
        (the error is logged).
    """
    try:
        target_dir = IMAGE_CACHE_DIR if cache_dir is None else Path(cache_dir)
        image_id = str(uuid.uuid4())

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            # "data:image/png;base64" -> "image/png". Without dropping the
            # "data:" scheme, mimetypes cannot recognise the type and files were
            # saved with a literal "None" extension.
            mime_type = header.split(";")[0].strip()
            if mime_type.lower().startswith("data:"):
                mime_type = mime_type[len("data:"):]
            image_format = mimetypes.guess_extension(mime_type) or ".png"
        else:
            encoded = b64_str
            image_format = ".png"

        img_data = base64.b64decode(encoded)
        image_filename = f"{image_id}{image_format}"
        with open(target_dir / image_filename, "wb") as f:
            f.write(img_data)
        return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
397
398


Timothy J. Baek's avatar
Timothy J. Baek committed
399
400
401
402
403
def save_url_image(url):
    """Download the image at ``url`` into ``IMAGE_CACHE_DIR``.

    Returns:
        The new file name (``<uuid4><ext>``), or ``None`` if the download
        failed, the response is not an image, or its media type has no known
        extension (the cause is logged in every case).
    """
    image_id = str(uuid.uuid4())
    try:
        r = requests.get(url)
        r.raise_for_status()

        # Content-Type may carry parameters ("image/png; charset=binary"), which
        # mimetypes cannot map; keep only the bare media type. A missing header
        # is treated as "not an image" instead of raising KeyError.
        mime_type = r.headers.get("content-type", "").split(";")[0].strip().lower()
        if mime_type.split("/")[0] != "image":
            log.error("Url does not point to an image.")
            return None

        image_format = mimetypes.guess_extension(mime_type)
        if not image_format:
            raise ValueError("Could not determine image type from MIME type")

        image_filename = f"{image_id}{image_format}"
        file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
        with open(file_path, "wb") as image_file:
            for chunk in r.iter_content(chunk_size=8192):
                image_file.write(chunk)
        return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
426
427


Timothy J. Baek's avatar
Timothy J. Baek committed
428
@app.post("/generations")
async def image_generations(
    form_data: GenerateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the active engine and cache them locally.

    Each generated image is stored in ``IMAGE_CACHE_DIR`` next to a
    ``<file>.json`` sidecar holding the request parameters that produced it.

    Returns:
        A list of ``{"url": "/cache/image/generations/<file>"}`` dicts.

    Raises:
        HTTPException(400): on any generation/storage failure, with the
            backend's own error message when its response carries one.
    """
    import asyncio
    import functools

    # All engine paths make blocking calls (requests, and comfyui_generate_image,
    # which is invoked synchronously). Running them in the default executor
    # keeps this async handler from stalling the event loop for the whole
    # generation.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, functools.partial(_generate_images, form_data, user)
    )


def _generate_images(form_data, user):
    """Blocking body of ``image_generations``: dispatch on ENGINE, map errors to 400."""
    width, height = map(int, app.state.config.IMAGE_SIZE.split("x"))

    try:
        if app.state.config.ENGINE == "openai":
            return _generate_with_openai(form_data)
        if app.state.config.ENGINE == "comfyui":
            return _generate_with_comfyui(form_data, user, width, height)
        return _generate_with_automatic1111(form_data, width, height)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.DEFAULT(_generation_error_detail(e)),
        ) from e


def _generation_error_detail(exc):
    """Best-effort error message: the backend's ``error`` text if present, else ``exc``.

    Handles both ``{"error": {"message": ...}}`` (OpenAI style) and
    ``{"error": "..."}`` bodies. Non-JSON bodies and exceptions without an
    attached HTTP response fall back to the exception itself; previously these
    cases crashed the handler with a 500.
    """
    response = getattr(exc, "response", None)
    if response is None:
        return exc
    try:
        body = response.json()
    except ValueError:
        return exc
    error = body.get("error") if isinstance(body, dict) else None
    if isinstance(error, dict):
        return error.get("message", exc)
    if isinstance(error, str) and error:
        return error
    return exc


def _cache_generated_image(image_filename, metadata):
    """Write the JSON sidecar for a cached image and return its URL entry."""
    if image_filename is None:
        # save_*_image already logged the cause. Without this check the client
        # received ".../None" and every failure overwrote a shared "None.json".
        raise ValueError("Failed to save generated image")
    file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
    with open(file_body_path, "w") as f:
        json.dump(metadata, f)
    return {"url": f"/cache/image/generations/{image_filename}"}


def _generate_with_openai(form_data):
    """Call the OpenAI-compatible ``/images/generations`` endpoint (base64 responses)."""
    headers = {
        "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
        "Content-Type": "application/json",
    }
    data = {
        # `or` also covers an unset (None) model, matching GET /models/default.
        "model": app.state.config.MODEL or "dall-e-2",
        "prompt": form_data.prompt,
        "n": form_data.n,
        "size": form_data.size if form_data.size else app.state.config.IMAGE_SIZE,
        "response_format": "b64_json",
    }

    r = requests.post(
        url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
        json=data,
        headers=headers,
    )
    r.raise_for_status()
    res = r.json()

    images = []
    for image in res["data"]:
        images.append(_cache_generated_image(save_b64_image(image["b64_json"]), data))
    return images


def _generate_with_comfyui(form_data, user, width, height):
    """Run the ComfyUI workflow and download the resulting images."""
    data = {
        "prompt": form_data.prompt,
        "width": width,
        "height": height,
        "n": form_data.n,
    }
    if app.state.config.IMAGE_STEPS is not None:
        data["steps"] = app.state.config.IMAGE_STEPS
    if form_data.negative_prompt is not None:
        data["negative_prompt"] = form_data.negative_prompt
    # cfg_scale is only sent when truthy (0 is skipped); the remaining options
    # are sent whenever they are set.
    if app.state.config.COMFYUI_CFG_SCALE:
        data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
    optional_settings = {
        "sampler": app.state.config.COMFYUI_SAMPLER,
        "scheduler": app.state.config.COMFYUI_SCHEDULER,
        "sd3": app.state.config.COMFYUI_SD3,
        "flux": app.state.config.COMFYUI_FLUX,
        "flux_weight_dtype": app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE,
        "flux_fp8_clip": app.state.config.COMFYUI_FLUX_FP8_CLIP,
    }
    data.update({k: v for k, v in optional_settings.items() if v is not None})

    payload = ImageGenerationPayload(**data)
    res = comfyui_generate_image(
        app.state.config.MODEL,
        payload,
        user.id,
        app.state.config.COMFYUI_BASE_URL,
    )
    log.debug(f"res: {res}")

    metadata = payload.model_dump(exclude_none=True)
    images = []
    for image in res["data"]:
        images.append(_cache_generated_image(save_url_image(image["url"]), metadata))

    log.debug(f"images: {images}")
    return images


def _generate_with_automatic1111(form_data, width, height):
    """Call AUTOMATIC1111 ``/sdapi/v1/txt2img``, switching checkpoint first if requested."""
    if form_data.model:
        set_model_handler(form_data.model)

    data = {
        "prompt": form_data.prompt,
        "batch_size": form_data.n,
        "width": width,
        "height": height,
    }
    if app.state.config.IMAGE_STEPS is not None:
        data["steps"] = app.state.config.IMAGE_STEPS
    if form_data.negative_prompt is not None:
        data["negative_prompt"] = form_data.negative_prompt

    r = requests.post(
        url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
        json=data,
        headers={"authorization": get_automatic1111_api_auth()},
    )
    # Surface server errors as HTTPError (response attached, so its message can
    # be extracted) instead of a confusing KeyError on res["images"].
    r.raise_for_status()
    res = r.json()
    log.debug(f"res: {res}")

    images = []
    for image in res["images"]:
        images.append(
            _cache_generated_image(save_b64_image(image), {**data, "info": res["info"]})
        )
    return images