main.py 17.7 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import re
Timothy J. Baek's avatar
Timothy J. Baek committed
2
import requests
3
import base64
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware

from constants import ERROR_MESSAGES
from utils.utils import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
18
    get_verified_user,
Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
Timothy J. Baek's avatar
Timothy J. Baek committed
23
24
25
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
26
from pathlib import Path
27
import mimetypes
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29
30
import uuid
import base64
import json
31
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
32

Self Denial's avatar
Self Denial committed
33
34
35
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
36
    IMAGE_GENERATION_ENGINE,
Self Denial's avatar
Self Denial committed
37
    ENABLE_IMAGE_GENERATION,
Self Denial's avatar
Self Denial committed
38
    AUTOMATIC1111_BASE_URL,
39
    AUTOMATIC1111_API_AUTH,
Self Denial's avatar
Self Denial committed
40
    COMFYUI_BASE_URL,
41
42
43
44
    COMFYUI_CFG_SCALE,
    COMFYUI_SAMPLER,
    COMFYUI_SCHEDULER,
    COMFYUI_SD3,
45
46
47
    COMFYUI_FLUX,
    COMFYUI_FLUX_WEIGHT_DTYPE,
    COMFYUI_FLUX_FP8_CLIP,
48
49
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
50
    IMAGE_GENERATION_MODEL,
51
52
    IMAGE_SIZE,
    IMAGE_STEPS,
53
    AppConfig,
Self Denial's avatar
Self Denial committed
54
)
Timothy J. Baek's avatar
Timothy J. Baek committed
55

56
57
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

# Directory where generated images (and the ``<file>.json`` request payloads
# written next to them by the generation handler) are stored; created at
# import time.
IMAGE_CACHE_DIR = Path(CACHE_DIR) / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
61
62
63
64
65
66
67
68
69
70

app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# origin make credentialed requests — confirm this is intended for this sub-app.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.config = AppConfig()

# Seed the runtime config from the values imported from ``config``; the admin
# endpoints below may overwrite them later. Entries are applied in this order.
for _config_name, _config_value in {
    "ENGINE": IMAGE_GENERATION_ENGINE,
    "ENABLED": ENABLE_IMAGE_GENERATION,
    "OPENAI_API_BASE_URL": IMAGES_OPENAI_API_BASE_URL,
    "OPENAI_API_KEY": IMAGES_OPENAI_API_KEY,
    "MODEL": IMAGE_GENERATION_MODEL,
    "AUTOMATIC1111_BASE_URL": AUTOMATIC1111_BASE_URL,
    "AUTOMATIC1111_API_AUTH": AUTOMATIC1111_API_AUTH,
    "COMFYUI_BASE_URL": COMFYUI_BASE_URL,
    "IMAGE_SIZE": IMAGE_SIZE,
    "IMAGE_STEPS": IMAGE_STEPS,
    "COMFYUI_CFG_SCALE": COMFYUI_CFG_SCALE,
    "COMFYUI_SAMPLER": COMFYUI_SAMPLER,
    "COMFYUI_SCHEDULER": COMFYUI_SCHEDULER,
    "COMFYUI_SD3": COMFYUI_SD3,
    "COMFYUI_FLUX": COMFYUI_FLUX,
    "COMFYUI_FLUX_WEIGHT_DTYPE": COMFYUI_FLUX_WEIGHT_DTYPE,
    "COMFYUI_FLUX_FP8_CLIP": COMFYUI_FLUX_FP8_CLIP,
}.items():
    setattr(app.state.config, _config_name, _config_value)
del _config_name, _config_value
Timothy J. Baek's avatar
Timothy J. Baek committed
94

Timothy J. Baek's avatar
Timothy J. Baek committed
95

96
97
98
99
def get_automatic1111_api_auth():
    """Build the ``authorization`` header value for AUTOMATIC1111 requests.

    The configured ``AUTOMATIC1111_API_AUTH`` string is UTF-8 encoded,
    base64-encoded and prefixed with ``"Basic "`` (HTTP Basic auth). It is
    presumably in ``user:password`` form — TODO confirm against the settings UI.

    Returns:
        ``"Basic <base64>"``, or ``""`` when no credentials are configured
        (``None`` or an empty string — the latter previously produced a
        meaningless ``"Basic "`` header).
    """
    api_auth = app.state.config.AUTOMATIC1111_API_AUTH
    if not api_auth:
        return ""
    encoded = base64.b64encode(api_auth.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"


Timothy J. Baek's avatar
Timothy J. Baek committed
106
107
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Report the selected image engine and whether generation is enabled (admin only)."""
    config = app.state.config
    return {"engine": config.ENGINE, "enabled": config.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
112
113


Timothy J. Baek's avatar
Timothy J. Baek committed
114
115
116
117
118
119
120
class ConfigUpdateForm(BaseModel):
    """Request body for ``POST /config/update``."""

    # Engine id: "openai" and "comfyui" are special-cased by the handlers;
    # any other value is treated as AUTOMATIC1111.
    engine: str
    # Master on/off switch for image generation.
    enabled: bool


@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store the engine and enabled flag, then echo the stored values (admin only)."""
    config = app.state.config
    config.ENGINE = form_data.engine
    config.ENABLED = form_data.enabled
    return {"engine": config.ENGINE, "enabled": config.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
127
128


Timothy J. Baek's avatar
Timothy J. Baek committed
129
130
class EngineUrlUpdateForm(BaseModel):
    """Request body for ``POST /url/update``.

    Any field left as ``None`` resets that setting to the value imported
    from ``config``.
    """

    AUTOMATIC1111_BASE_URL: Optional[str] = None  # probed with HEAD before saving
    AUTOMATIC1111_API_AUTH: Optional[str] = None  # sent as HTTP Basic credentials
    COMFYUI_BASE_URL: Optional[str] = None  # probed with HEAD before saving
Timothy J. Baek's avatar
Timothy J. Baek committed
133
134
135


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    """Return the AUTOMATIC1111/ComfyUI endpoints and the A1111 auth string (admin only)."""
    config = app.state.config
    return dict(
        AUTOMATIC1111_BASE_URL=config.AUTOMATIC1111_BASE_URL,
        AUTOMATIC1111_API_AUTH=config.AUTOMATIC1111_API_AUTH,
        COMFYUI_BASE_URL=config.COMFYUI_BASE_URL,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
142
143
144


@app.post("/url/update")
async def update_engine_url(
    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    """Update the AUTOMATIC1111/ComfyUI base URLs and AUTOMATIC1111 auth (admin only).

    A field left as ``None`` resets that setting to the value imported from
    ``config``. Supplied URLs have surrounding ``/`` stripped and are probed
    with a HEAD request. Nothing is written to ``app.state.config`` until
    every supplied URL has passed its probe, so a bad ComfyUI URL cannot leave
    a half-applied update behind.

    Raises:
        HTTPException: 400 if a supplied URL cannot be reached (for
            AUTOMATIC1111, also if it answers with an HTTP error status).
    """
    if form_data.AUTOMATIC1111_BASE_URL is None:
        automatic1111_url = AUTOMATIC1111_BASE_URL
    else:
        automatic1111_url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
        try:
            r = requests.head(automatic1111_url)
            r.raise_for_status()
        except Exception as e:
            raise HTTPException(status_code=400, detail="Invalid URL provided.") from e

    if form_data.COMFYUI_BASE_URL is None:
        comfyui_url = COMFYUI_BASE_URL
    else:
        comfyui_url = form_data.COMFYUI_BASE_URL.strip("/")
        try:
            # Only reachability is checked here; the status code is not.
            requests.head(comfyui_url)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)
            ) from e

    # All probes passed: commit the new settings together.
    app.state.config.AUTOMATIC1111_BASE_URL = automatic1111_url
    app.state.config.COMFYUI_BASE_URL = comfyui_url
    if form_data.AUTOMATIC1111_API_AUTH is None:
        app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
    else:
        app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH

    return {
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
181
182


183
184
class OpenAIConfigUpdateForm(BaseModel):
    """Request body for ``POST /openai/config/update``."""

    # Base URL; "/images/generations" is appended when generating.
    url: str
    # Bearer token; the handler rejects an empty string.
    key: str


188
189
190
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
    """Return the OpenAI images base URL and API key (admin only)."""
    config = app.state.config
    return dict(
        OPENAI_API_BASE_URL=config.OPENAI_API_BASE_URL,
        OPENAI_API_KEY=config.OPENAI_API_KEY,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
194
195


196
197
@app.post("/openai/config/update")
async def update_openai_config(
    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
    """Store the OpenAI images base URL and API key, then echo them (admin only).

    Raises:
        HTTPException: 400 with ``API_KEY_NOT_FOUND`` when the key is empty.
    """
    if not form_data.key:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    config = app.state.config
    config.OPENAI_API_BASE_URL = form_data.url
    config.OPENAI_API_KEY = form_data.key

    return {
        "status": True,
        "OPENAI_API_BASE_URL": config.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": config.OPENAI_API_KEY,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
213
214
215
216
217
218
class ImageSizeUpdateForm(BaseModel):
    """Request body for ``POST /size/update``."""

    # "<width>x<height>", e.g. "512x512"; format is validated by the handler.
    size: str


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the default ``"<width>x<height>"`` image size (admin only)."""
    current_size = app.state.config.IMAGE_SIZE
    return {"IMAGE_SIZE": current_size}
Timothy J. Baek's avatar
Timothy J. Baek committed
220
221
222
223


@app.post("/size/update")
async def update_image_size(
    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Set the default image size, which must look like ``"512x512"`` (admin only).

    Raises:
        HTTPException: 400 if the size is not ``<digits>x<digits>``.
    """
    # fullmatch instead of match+"$": "$" also matches before a trailing
    # newline, so "512x512\n" used to be stored. re.ASCII keeps \d to 0-9.
    if re.fullmatch(r"\d+x\d+", form_data.size, flags=re.ASCII):
        app.state.config.IMAGE_SIZE = form_data.size
        return {
            "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
            "status": True,
        }
    raise HTTPException(
        status_code=400,
        detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
    )
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
238

239
240
241
242
243
244
245

class ImageStepsUpdateForm(BaseModel):
    """Request body for ``POST /steps/update``."""

    # Step count sent to ComfyUI/AUTOMATIC1111; the handler rejects negatives.
    steps: int


@app.get("/steps")
async def get_image_steps(user=Depends(get_admin_user)):
    """Return the configured ``IMAGE_STEPS`` value (admin only).

    The step count is sent as ``steps`` to ComfyUI and AUTOMATIC1111 when
    generating.
    """
    # Distinct function name: reusing ``get_image_size`` rebound the
    # module-level name of the GET /size handler to this function.
    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
247
248
249
250


@app.post("/steps/update")
async def update_image_steps(
    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
    """Set ``IMAGE_STEPS`` to a non-negative integer (admin only).

    Raises:
        HTTPException: 400 if ``steps`` is negative.
    """
    # Distinct function name: reusing ``update_image_size`` rebound the
    # module-level name of the POST /size/update handler to this function.
    if form_data.steps >= 0:
        app.state.config.IMAGE_STEPS = form_data.steps
        return {
            "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
            "status": True,
        }
    raise HTTPException(
        status_code=400,
        detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
264
265


Timothy J. Baek's avatar
Timothy J. Baek committed
266
@app.get("/models")
def get_models(user=Depends(get_verified_user)):
    """List the models offered by the active image engine.

    Returns:
        A list of ``{"id": ..., "name": ...}`` dicts: a fixed DALL·E list
        for "openai", ComfyUI checkpoint names for "comfyui", otherwise the
        AUTOMATIC1111 ``/sdapi/v1/sd-models`` entries (``title`` as id,
        ``model_name`` as name).

    Raises:
        HTTPException: 400 if the backend cannot be queried or returns an
            HTTP error; as a side effect image generation is disabled
            (``ENABLED = False``).
    """
    try:
        if app.state.config.ENGINE == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]
        elif app.state.config.ENGINE == "comfyui":
            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
            # Fail on HTTP errors here instead of with a KeyError on an error body.
            r.raise_for_status()
            info = r.json()
            # The first element of the ``ckpt_name`` input spec is presumably
            # the list of available checkpoint file names.
            checkpoints = info["CheckpointLoaderSimple"]["input"]["required"][
                "ckpt_name"
            ][0]
            return [{"id": model, "name": model} for model in checkpoints]
        else:
            r = requests.get(
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
                headers={"authorization": get_automatic1111_api_auth()},
            )
            r.raise_for_status()
            models = r.json()
            return [
                {"id": model["title"], "name": model["model_name"]}
                for model in models
            ]
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
301
302
303
304
305


@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return the model currently selected for the active engine (admin only).

    "openai" falls back to "dall-e-2" and "comfyui" to "" when no model is
    stored. For AUTOMATIC1111 the backend's ``sd_model_checkpoint`` option
    is returned.

    Raises:
        HTTPException: 400 if the lookup fails; as a side effect image
            generation is disabled (``ENABLED = False``).
    """
    try:
        if app.state.config.ENGINE == "openai":
            return {"model": app.state.config.MODEL or "dall-e-2"}
        elif app.state.config.ENGINE == "comfyui":
            return {"model": app.state.config.MODEL or ""}
        else:
            r = requests.get(
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
                headers={"authorization": get_automatic1111_api_auth()},
            )
            # Fail on HTTP errors here instead of with a KeyError on an error body.
            r.raise_for_status()
            options = r.json()
            return {"model": options["sd_model_checkpoint"]}
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
324
325
326
327
328
329
330


class UpdateModelForm(BaseModel):
    """Request body for ``POST /models/default/update``."""

    # Model id to activate (an AUTOMATIC1111 checkpoint title, or a model
    # name stored as-is for openai/comfyui).
    model: str


def set_model_handler(model: str):
    """Make *model* the active model for the configured engine.

    For "openai" and "comfyui" the name is only stored in
    ``app.state.config.MODEL``. For any other engine (AUTOMATIC1111) the
    backend options are fetched. If ``sd_model_checkpoint`` differs from
    *model*, the options are posted back with the checkpoint replaced.

    Returns:
        The stored model name (openai/comfyui), or the AUTOMATIC1111 options
        dict as it was (or would be) written.

    Raises:
        requests.HTTPError: if AUTOMATIC1111 rejects the GET or the POST.
    """
    if app.state.config.ENGINE in ["openai", "comfyui"]:
        app.state.config.MODEL = model
        return app.state.config.MODEL

    api_auth = get_automatic1111_api_auth()
    options_url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"

    r = requests.get(url=options_url, headers={"authorization": api_auth})
    r.raise_for_status()
    options = r.json()

    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        r = requests.post(
            url=options_url,
            json=options,
            headers={"authorization": api_auth},
        )
        # Don't report the switch as done if AUTOMATIC1111 refused it.
        r.raise_for_status()

    return options
Timothy J. Baek's avatar
Timothy J. Baek committed
351
352
353
354


@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_verified_user),
):
    """Switch the active image model via ``set_model_handler``.

    NOTE(review): this is guarded by ``get_verified_user``, not
    ``get_admin_user`` like the other config writes. Since the model is
    stored in the shared ``app.state.config``, any signed-in user can change
    it for everyone — confirm this is intended.
    """
    requested_model = form_data.model
    return set_model_handler(requested_model)


class GenerateImageForm(BaseModel):
    """Request body for ``POST /generations``."""

    # Only honoured by the AUTOMATIC1111 branch, which switches checkpoint first.
    model: Optional[str] = None
    prompt: str
    # Image count: sent as "n" (OpenAI/ComfyUI) or "batch_size" (AUTOMATIC1111).
    n: int = 1
    # "WxH"; only the OpenAI branch uses it, the others use IMAGE_SIZE.
    size: Optional[str] = None
    # Forwarded to ComfyUI and AUTOMATIC1111 when set.
    negative_prompt: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
369
370
def save_b64_image(b64_str, cache_dir: Optional[Path] = None):
    """Decode a base64 image and write it into the image cache.

    *b64_str* may be a data URL (``data:image/png;base64,<payload>``) or a
    bare base64 payload. For a data URL the extension is guessed from its
    MIME type. Bare payloads and MIME types ``mimetypes`` does not recognise
    are saved as ``.png``.

    Args:
        b64_str: Base64 payload, optionally prefixed by a data-URL header.
        cache_dir: Directory to write into; defaults to ``IMAGE_CACHE_DIR``.

    Returns:
        The generated file name (``<uuid4><ext>``), or ``None`` if decoding
        or writing failed (the error is logged).
    """
    if cache_dir is None:
        cache_dir = IMAGE_CACHE_DIR
    try:
        image_id = str(uuid.uuid4())
        image_format = ".png"
        encoded = b64_str

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            mime_type = header.split(";")[0]
            # The header is "data:<mime>;base64"; mimetypes only understands
            # the bare MIME type, so drop the scheme (previously this yielded
            # None and files were named "<uuid>None").
            if mime_type.startswith("data:"):
                mime_type = mime_type[len("data:"):]
            image_format = mimetypes.guess_extension(mime_type) or ".png"

        # Decode before opening so a bad payload doesn't leave an empty file.
        img_data = base64.b64decode(encoded)
        image_filename = f"{image_id}{image_format}"
        with open(Path(cache_dir) / image_filename, "wb") as f:
            f.write(img_data)
        return image_filename
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
399
400


Timothy J. Baek's avatar
Timothy J. Baek committed
401
402
403
404
405
def save_url_image(url, cache_dir: Optional[Path] = None):
    """Download the image at *url* into the image cache.

    Args:
        url: Image URL to fetch.
        cache_dir: Directory to write into; defaults to ``IMAGE_CACHE_DIR``.

    Returns:
        The generated file name (``<uuid4><ext>``), or ``None`` (logged) if
        the request fails, the response is not an image, or its MIME type
        has no known file extension.
    """
    if cache_dir is None:
        cache_dir = IMAGE_CACHE_DIR
    image_id = str(uuid.uuid4())
    try:
        # stream=True so iter_content really streams the body; the with-block
        # releases the connection afterwards.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()

            # Drop parameters such as "; charset=..." so mimetypes can map
            # the type; a missing header is treated as "not an image".
            content_type = r.headers.get("content-type", "")
            mime_type = content_type.split(";")[0].strip().lower()
            if mime_type.split("/")[0] != "image":
                log.error("Url does not point to an image.")
                return None

            image_format = mimetypes.guess_extension(mime_type)
            if not image_format:
                raise ValueError("Could not determine image type from MIME type")

            image_filename = f"{image_id}{image_format}"
            with open(Path(cache_dir) / image_filename, "wb") as image_file:
                for chunk in r.iter_content(chunk_size=8192):
                    image_file.write(chunk)
            return image_filename
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
428
429


Timothy J. Baek's avatar
Timothy J. Baek committed
430
def _extract_error_detail(response, fallback):
    """Return a readable error message from an engine's JSON error body, else *fallback*.

    Understands OpenAI-style ``{"error": {"message": ...}}`` and
    FastAPI-style ``{"detail": "..."}`` bodies. Anything else (non-JSON
    body, ``"error"`` given as a plain string, missing keys) yields
    *fallback* instead of raising.
    """
    try:
        body = response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        return fallback
    if not isinstance(body, dict):
        return fallback
    error = body.get("error")
    if isinstance(error, dict) and error.get("message"):
        return error["message"]
    detail = body.get("detail")
    if isinstance(detail, str) and detail:
        return detail
    return fallback


@app.post("/generations")
async def image_generations(
    form_data: GenerateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the configured engine and cache them locally.

    Width and height come from ``IMAGE_SIZE`` ("WxH"). The OpenAI branch
    instead sends ``form_data.size`` when given.

    Returns:
        A list of ``{"url": "/cache/image/generations/<file>"}`` dicts. Next
        to each cached image a ``<file>.json`` stores the request payload
        (plus AUTOMATIC1111's ``info``).

    Raises:
        HTTPException: 400 on any failure. When the engine returned a JSON
            error body, its message is used as the detail.
    """
    width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

    # Last HTTP response from the engine, kept so the error handler can
    # surface the engine's own message. Stays None for ComfyUI.
    r = None
    try:
        if app.state.config.ENGINE == "openai":
            headers = {
                "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }

            data = {
                # Same fallback as GET /models/default (also covers None).
                "model": app.state.config.MODEL or "dall-e-2",
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": (
                    form_data.size if form_data.size else app.state.config.IMAGE_SIZE
                ),
                "response_format": "b64_json",
            }

            r = requests.post(
                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                image_filename = save_b64_image(image["b64_json"])
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                with open(file_body_path, "w") as f:
                    json.dump(data, f)

            return images

        elif app.state.config.ENGINE == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }

            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            if app.state.config.COMFYUI_CFG_SCALE:
                data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE

            if app.state.config.COMFYUI_SAMPLER is not None:
                data["sampler"] = app.state.config.COMFYUI_SAMPLER

            if app.state.config.COMFYUI_SCHEDULER is not None:
                data["scheduler"] = app.state.config.COMFYUI_SCHEDULER

            if app.state.config.COMFYUI_SD3 is not None:
                data["sd3"] = app.state.config.COMFYUI_SD3

            if app.state.config.COMFYUI_FLUX is not None:
                data["flux"] = app.state.config.COMFYUI_FLUX

            if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None:
                data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE

            if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None:
                data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP

            data = ImageGenerationPayload(**data)

            res = comfyui_generate_image(
                app.state.config.MODEL,
                data,
                user.id,
                app.state.config.COMFYUI_BASE_URL,
            )
            log.debug(f"res: {res}")

            images = []
            for image in res["data"]:
                image_filename = save_url_image(image["url"])
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                with open(file_body_path, "w") as f:
                    json.dump(data.model_dump(exclude_none=True), f)

            log.debug(f"images: {images}")
            return images

        else:
            if form_data.model:
                set_model_handler(form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }

            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            r = requests.post(
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
                headers={"authorization": get_automatic1111_api_auth()},
            )
            # Fail on HTTP errors here instead of with a KeyError on an error body.
            r.raise_for_status()
            res = r.json()
            log.debug(f"res: {res}")

            images = []
            for image in res["images"]:
                image_filename = save_b64_image(image)
                images.append({"url": f"/cache/image/generations/{image_filename}"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
                with open(file_body_path, "w") as f:
                    json.dump({**data, "info": res["info"]}, f)

            return images

    except Exception as e:
        # Never let parsing the engine's error body mask the original failure.
        error = _extract_error_detail(r, e) if r is not None else e
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)) from e