main.py 17.6 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import re
Timothy J. Baek's avatar
Timothy J. Baek committed
2
import requests
3
import base64
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware

from constants import ERROR_MESSAGES
from utils.utils import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
18
    get_verified_user,
Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
Timothy J. Baek's avatar
Timothy J. Baek committed
23
24
25
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
26
from pathlib import Path
27
import mimetypes
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29
30
import uuid
import base64
import json
31
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
32

Self Denial's avatar
Self Denial committed
33
34
35
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
36
    IMAGE_GENERATION_ENGINE,
Self Denial's avatar
Self Denial committed
37
    ENABLE_IMAGE_GENERATION,
Self Denial's avatar
Self Denial committed
38
    AUTOMATIC1111_BASE_URL,
39
    AUTOMATIC1111_API_AUTH,
Self Denial's avatar
Self Denial committed
40
    COMFYUI_BASE_URL,
41
42
43
44
    COMFYUI_CFG_SCALE,
    COMFYUI_SAMPLER,
    COMFYUI_SCHEDULER,
    COMFYUI_SD3,
45
46
47
    COMFYUI_FLUX,
    COMFYUI_FLUX_WEIGHT_DTYPE,
    COMFYUI_FLUX_FP8_CLIP,
48
49
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
50
    IMAGE_GENERATION_MODEL,
51
52
    IMAGE_SIZE,
    IMAGE_STEPS,
53
    AppConfig,
Self Denial's avatar
Self Denial committed
54
)
Timothy J. Baek's avatar
Timothy J. Baek committed
55

56
57
# Module logger; its level is taken from the "IMAGES" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
Timothy J. Baek's avatar
Timothy J. Baek committed
58
59
60

# Directory where generated images and their "<file>.json" metadata files are
# written. Created at import time so later writes can assume it exists.
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
61
62
63
64
65
66
67
68
69
70

# FastAPI application serving the image-generation endpoints defined below.
app = FastAPI()
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended (e.g. the app is only reachable through
# a parent application that enforces its own CORS policy).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

71
# Runtime settings container. Every setting starts from the value imported from
# ``config``; several endpoints in this module overwrite them while running.
app.state.config = AppConfig()

# Which backend to use, and the global on/off switch.
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION

# OpenAI-compatible images API.
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

# Model name (stored directly for the openai and comfyui engines).
app.state.config.MODEL = IMAGE_GENERATION_MODEL

# Self-hosted backends.
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL

# Generation defaults: size as "WIDTHxHEIGHT" and the sampling step count.
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS

# ComfyUI sampling options.
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3

# ComfyUI FLUX options.
app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
Timothy J. Baek's avatar
Timothy J. Baek committed
94

Timothy J. Baek's avatar
Timothy J. Baek committed
95

96
97
98
99
def get_automatic1111_api_auth():
    """Build the ``authorization`` header value for the AUTOMATIC1111 API.

    ``app.state.config.AUTOMATIC1111_API_AUTH`` holds the credentials to send
    with HTTP Basic auth (presumably ``"user:password"``; the format is not
    checked here).

    Returns:
        ``"Basic <base64(credentials)>"`` when credentials are configured,
        otherwise an empty string.
    """
    api_auth = app.state.config.AUTOMATIC1111_API_AUTH
    # None and "" both mean "no credentials": encoding an empty string would
    # produce the meaningless header value "Basic ".
    if not api_auth:
        return ""
    encoded = base64.b64encode(api_auth.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"


Timothy J. Baek's avatar
Timothy J. Baek committed
106
107
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the active engine name and whether image generation is enabled.

    Admin only. ``request`` is accepted but not used.
    """
    cfg = app.state.config
    return {"engine": cfg.ENGINE, "enabled": cfg.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
112
113


Timothy J. Baek's avatar
Timothy J. Baek committed
114
115
116
117
118
119
120
class ConfigUpdateForm(BaseModel):
    # Request body for POST /config/update.
    engine: str  # "openai", "comfyui", or any other value for the AUTOMATIC1111 paths
    enabled: bool  # global on/off switch for image generation


@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store the submitted engine and enabled flag, then echo them back. Admin only."""
    cfg = app.state.config
    # Both values are read from the form first, then ENGINE and ENABLED are
    # assigned in that order.
    cfg.ENGINE, cfg.ENABLED = form_data.engine, form_data.enabled
    return {"engine": cfg.ENGINE, "enabled": cfg.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
127
128


Timothy J. Baek's avatar
Timothy J. Baek committed
129
130
class EngineUrlUpdateForm(BaseModel):
    # Request body for POST /url/update. A field left as None makes the
    # endpoint reassign the corresponding value imported from ``config``.
    AUTOMATIC1111_BASE_URL: Optional[str] = None
    AUTOMATIC1111_API_AUTH: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
133
134
135


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    """Return the backend URLs and AUTOMATIC1111 credentials. Admin only."""
    cfg = app.state.config
    setting_names = (
        "AUTOMATIC1111_BASE_URL",
        "AUTOMATIC1111_API_AUTH",
        "COMFYUI_BASE_URL",
    )
    return {name: getattr(cfg, name) for name in setting_names}
Timothy J. Baek's avatar
Timothy J. Baek committed
142
143
144


def _check_engine_url(url: str) -> str:
    """Drop trailing slashes from ``url`` and verify the host answers a HEAD request.

    Any HTTP response counts as reachable; only transport-level failures
    (connection error, timeout, invalid URL, ...) are rejected.

    Raises:
        HTTPException: 400 when the HEAD request fails.
    """
    url = url.rstrip("/")
    try:
        # Bounded timeout so an unreachable host cannot hang the admin request.
        requests.head(url, timeout=10)
    except Exception as e:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
    return url


@app.post("/url/update")
async def update_engine_url(
    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    """Update the backend URLs and AUTOMATIC1111 credentials. Admin only.

    A field that is None reassigns the corresponding value imported from
    ``config``. Supplied URLs are checked before anything is stored, so one
    bad URL leaves every setting unchanged.

    Returns:
        The stored settings plus ``"status": True``.

    Raises:
        HTTPException: 400 if a supplied URL is unreachable.
    """
    # NOTE(review): whether reassigning the imported config objects really
    # restores defaults depends on AppConfig's semantics; confirm.

    # Validate everything first; previously a failing ComfyUI URL left an
    # already-updated AUTOMATIC1111 URL behind.
    if form_data.AUTOMATIC1111_BASE_URL is None:
        automatic1111_url = AUTOMATIC1111_BASE_URL
    else:
        automatic1111_url = _check_engine_url(form_data.AUTOMATIC1111_BASE_URL)

    if form_data.COMFYUI_BASE_URL is None:
        comfyui_url = COMFYUI_BASE_URL
    else:
        comfyui_url = _check_engine_url(form_data.COMFYUI_BASE_URL)

    if form_data.AUTOMATIC1111_API_AUTH is None:
        api_auth = AUTOMATIC1111_API_AUTH
    else:
        api_auth = form_data.AUTOMATIC1111_API_AUTH

    app.state.config.AUTOMATIC1111_BASE_URL = automatic1111_url
    app.state.config.COMFYUI_BASE_URL = comfyui_url
    app.state.config.AUTOMATIC1111_API_AUTH = api_auth

    return {
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
180
181


182
183
class OpenAIConfigUpdateForm(BaseModel):
    # Request body for POST /openai/config/update.
    url: str  # base URL of the OpenAI-compatible images API
    key: str  # API key; the endpoint rejects an empty key


187
188
189
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
    """Return the OpenAI-compatible API base URL and key. Admin only."""
    cfg = app.state.config
    base_url = cfg.OPENAI_API_BASE_URL
    api_key = cfg.OPENAI_API_KEY
    return {"OPENAI_API_BASE_URL": base_url, "OPENAI_API_KEY": api_key}
Timothy J. Baek's avatar
Timothy J. Baek committed
193
194


195
196
@app.post("/openai/config/update")
async def update_openai_config(
    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
    """Store the OpenAI-compatible API base URL and key. Admin only.

    Returns:
        ``{"status": True, "OPENAI_API_BASE_URL": ..., "OPENAI_API_KEY": ...}``.

    Raises:
        HTTPException: 400 if the key is empty or only whitespace.
    """
    # Surrounding whitespace in a bearer token breaks authentication, and a
    # whitespace-only key is as unusable as an empty one.
    key = form_data.key.strip()
    if not key:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    # Requests are built as f"{base}/images/generations"; a trailing slash on
    # the base would produce a double slash in the path.
    app.state.config.OPENAI_API_BASE_URL = form_data.url.rstrip("/")
    app.state.config.OPENAI_API_KEY = key

    return {
        "status": True,
        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
212
213
214
215
216
217
class ImageSizeUpdateForm(BaseModel):
    # Request body for POST /size/update.
    size: str  # "<width>x<height>", e.g. "512x512"; format is validated by the endpoint


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the default image size as ``{"IMAGE_SIZE": "WIDTHxHEIGHT"}``. Admin only."""
    current_size = app.state.config.IMAGE_SIZE
    return {"IMAGE_SIZE": current_size}
Timothy J. Baek's avatar
Timothy J. Baek committed
219
220
221
222


@app.post("/size/update")
async def update_image_size(
    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Store a new default image size. Admin only.

    The size must be ``"<width>x<height>"`` with positive ASCII decimal
    dimensions, e.g. ``"512x512"``. The generation endpoint later splits it on
    ``"x"`` and parses each side with ``int``.

    Returns:
        ``{"IMAGE_SIZE": ..., "status": True}``.

    Raises:
        HTTPException: 400 if the size is malformed.
    """
    # fullmatch with [0-9] rather than match with \d and $: "$" also matches
    # just before a trailing newline ("512x512\n" used to pass) and "\d"
    # accepts non-ASCII digits. Zero (and zero-padded) dimensions are rejected.
    if re.fullmatch(r"[1-9][0-9]*x[1-9][0-9]*", form_data.size):
        app.state.config.IMAGE_SIZE = form_data.size
        return {
            "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
            "status": True,
        }

    raise HTTPException(
        status_code=400,
        detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
    )
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
237

238
239
240
241
242
243
244

class ImageStepsUpdateForm(BaseModel):
    # Request body for POST /steps/update.
    steps: int  # sampling step count; the endpoint rejects negative values


@app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the default sampling step count as ``{"IMAGE_STEPS": ...}``. Admin only.

    NOTE(review): this reuses the name ``get_image_size`` and so rebinds the
    module-level name of the /size GET handler defined earlier. The /steps
    route itself is unaffected (the decorator registers it), but a rename to
    ``get_image_steps`` is worth considering.
    """
    current_steps = app.state.config.IMAGE_STEPS
    return {"IMAGE_STEPS": current_steps}
246
247
248
249


@app.post("/steps/update")
async def update_image_size(
    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
    """Store a new default sampling step count (must be >= 0). Admin only.

    NOTE(review): this reuses the name ``update_image_size`` of the /size POST
    handler defined earlier; the route registration is unaffected.

    Raises:
        HTTPException: 400 for a negative step count.
    """
    # Guard clause: reject negative counts up front.
    if form_data.steps < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
        )

    app.state.config.IMAGE_STEPS = form_data.steps
    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS, "status": True}
Timothy J. Baek's avatar
Timothy J. Baek committed
263
264


Timothy J. Baek's avatar
Timothy J. Baek committed
265
@app.get("/models")
def get_models(user=Depends(get_verified_user)):
    """List the models available for the configured engine.

    - ``openai``: a fixed list of DALL·E models.
    - ``comfyui``: checkpoint names from ComfyUI's ``/object_info``.
    - anything else: models from AUTOMATIC1111's ``/sdapi/v1/sd-models``.

    Each entry is ``{"id": ..., "name": ...}``.

    Raises:
        HTTPException: 400 on any failure. As a side effect, image generation
            is disabled (``ENABLED = False``).
    """
    try:
        engine = app.state.config.ENGINE

        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]

        if engine == "comfyui":
            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
            # Report HTTP errors directly instead of failing later on an
            # unexpected JSON shape with a misleading KeyError.
            r.raise_for_status()
            info = r.json()
            ckpt_names = info["CheckpointLoaderSimple"]["input"]["required"][
                "ckpt_name"
            ][0]
            return [{"id": name, "name": name} for name in ckpt_names]

        r = requests.get(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        r.raise_for_status()
        return [
            {"id": model["title"], "name": model["model_name"]} for model in r.json()
        ]
    except Exception as e:
        # NOTE(review): a failed lookup by any verified user disables image
        # generation globally; confirm this is intended.
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
300
301
302
303
304


@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return ``{"model": ...}`` for the configured engine. Admin only.

    - ``openai``: the stored MODEL, or ``"dall-e-2"`` when it is empty.
    - ``comfyui``: the stored MODEL, or ``""`` when it is empty.
    - anything else: AUTOMATIC1111's current ``sd_model_checkpoint`` option.

    Raises:
        HTTPException: 400 on any failure. As a side effect, image generation
            is disabled (``ENABLED = False``).
    """
    try:
        engine = app.state.config.ENGINE
        model = app.state.config.MODEL

        if engine == "openai":
            return {"model": model if model else "dall-e-2"}
        if engine == "comfyui":
            return {"model": model if model else ""}

        r = requests.get(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        # Without this, an HTTP error surfaced as KeyError 'sd_model_checkpoint'.
        r.raise_for_status()
        options = r.json()
        return {"model": options["sd_model_checkpoint"]}
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e
Timothy J. Baek's avatar
Timothy J. Baek committed
323
324
325
326
327
328
329


class UpdateModelForm(BaseModel):
    # Request body for POST /models/default/update.
    model: str  # model name / checkpoint passed to set_model_handler


def set_model_handler(model: str):
    """Make ``model`` the active model for the configured engine.

    For the ``openai`` and ``comfyui`` engines, the name is stored in
    ``app.state.config.MODEL`` and returned. For any other engine
    (AUTOMATIC1111), the options are fetched. If ``sd_model_checkpoint``
    differs, it is changed with a POST. The (possibly updated) options dict is
    returned.

    Raises:
        requests.HTTPError: if AUTOMATIC1111 answers either request with an
            error status, so a rejected model switch is not reported as
            success.
    """
    if app.state.config.ENGINE in ["openai", "comfyui"]:
        app.state.config.MODEL = model
        return app.state.config.MODEL

    options_url = f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
    headers = {"authorization": get_automatic1111_api_auth()}

    r = requests.get(url=options_url, headers=headers)
    r.raise_for_status()
    options = r.json()

    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        r = requests.post(url=options_url, json=options, headers=headers)
        r.raise_for_status()

    return options
Timothy J. Baek's avatar
Timothy J. Baek committed
350
351
352
353


@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_verified_user),
):
    """Switch the active model through ``set_model_handler`` and return its result.

    NOTE(review): unlike the other configuration endpoints, this requires only
    a verified (not admin) user; confirm that is intended.
    """
    requested_model = form_data.model
    return set_model_handler(requested_model)


class GenerateImageForm(BaseModel):
    # Request body for POST /generations.
    model: Optional[str] = None  # only used on the AUTOMATIC1111 path (via set_model_handler)
    prompt: str
    n: int = 1  # number of images to request
    size: Optional[str] = None  # only used on the openai path; other engines use IMAGE_SIZE
    negative_prompt: Optional[str] = None  # forwarded to comfyui / AUTOMATIC1111 when set


Timothy J. Baek's avatar
Timothy J. Baek committed
368
369
def save_b64_image(b64_str, cache_dir=None):
    """Decode a base64 image and write it to the image cache.

    Args:
        b64_str: Either a data URI such as ``"data:image/png;base64,<data>"``
            or bare base64 data (saved as PNG).
        cache_dir: Target directory; defaults to ``IMAGE_CACHE_DIR``.

    Returns:
        The new file name (``<uuid4><ext>``), or None if decoding or writing
        failed (the error is logged, not raised).
    """
    try:
        image_id = str(uuid.uuid4())
        directory = IMAGE_CACHE_DIR if cache_dir is None else Path(cache_dir)

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            # The header looks like "data:image/png;base64". The "data:" prefix
            # must go, otherwise guess_extension() returns None and the file
            # would be named "<uuid>None".
            mime_type = header.split(";")[0]
            if mime_type.startswith("data:"):
                mime_type = mime_type[len("data:"):]
            image_format = mimetypes.guess_extension(mime_type) or ".png"
        else:
            encoded = b64_str
            image_format = ".png"

        img_data = base64.b64decode(encoded)

        image_filename = f"{image_id}{image_format}"
        with open(directory.joinpath(image_filename), "wb") as f:
            f.write(img_data)
        return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
398
399


Timothy J. Baek's avatar
Timothy J. Baek committed
400
401
402
403
404
def save_url_image(url):
    """Download the image at ``url`` into ``IMAGE_CACHE_DIR``.

    Returns:
        The new file name (``<uuid4><ext>``, with the extension derived from
        the response's Content-Type), or None if the download fails or the
        response is not an image. Errors are logged, never raised.
    """
    image_id = str(uuid.uuid4())
    try:
        # stream=True makes iter_content stream the body to disk instead of
        # buffering it first; the with-block releases the connection.
        with requests.get(url, stream=True) as r:
            r.raise_for_status()

            # Drop parameters such as "; charset=..." so mimetypes can map the
            # type. A missing header yields "" and falls into the non-image path.
            mime_type = r.headers.get("content-type", "").split(";")[0].strip().lower()
            if mime_type.split("/")[0] != "image":
                log.error("Url does not point to an image.")
                return None

            image_format = mimetypes.guess_extension(mime_type)
            if not image_format:
                raise ValueError("Could not determine image type from MIME type")

            image_filename = f"{image_id}{image_format}"
            file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
            with open(file_path, "wb") as image_file:
                for chunk in r.iter_content(chunk_size=8192):
                    image_file.write(chunk)
            return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
427
428


Timothy J. Baek's avatar
Timothy J. Baek committed
429
def _store_generated_image(images, image_filename, metadata):
    """Append a cached image's URL to ``images`` and write ``<file>.json`` beside it.

    ``image_filename`` is None when ``save_b64_image``/``save_url_image``
    failed (they log the cause). Such images are skipped, so the response
    never contains a ``/cache/image/generations/None`` URL and no
    ``None.json`` file is written.
    """
    if image_filename is None:
        log.warning("Skipping a generated image that could not be saved.")
        return
    images.append({"url": f"/cache/image/generations/{image_filename}"})
    file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
    with open(file_body_path, "w") as f:
        json.dump(metadata, f)


def _upstream_error_message(r, default):
    """Best-effort error message from the upstream HTTP response ``r``.

    Returns ``body["error"]["message"]`` when the JSON body has that shape,
    or ``body["error"]`` when that is a string. Otherwise it returns
    ``default``, including when ``r`` is None or its body is not JSON, so
    that error handling itself never raises.
    """
    if r is None:
        return default
    try:
        body = r.json()
    except ValueError:
        return default
    if not isinstance(body, dict):
        return default
    error = body.get("error")
    if isinstance(error, dict) and "message" in error:
        return error["message"]
    if isinstance(error, str):
        return error
    return default


@app.post("/generations")
async def image_generations(
    form_data: GenerateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the configured engine and cache them locally.

    Returns:
        A list of ``{"url": "/cache/image/generations/<file>"}`` entries. Each
        cached image gets a ``<file>.json`` file with the request parameters.

    Raises:
        HTTPException: 400 with the upstream error message when one can be
            extracted, otherwise with the raised exception.
    """
    # NOTE(review): the blocking `requests` calls below run on the event loop
    # thread because this handler is `async def`.
    width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

    # Last upstream HTTP response, inspected by the except block.
    r = None
    try:
        if app.state.config.ENGINE == "openai":
            headers = {
                "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }
            data = {
                # Any falsy MODEL (None or "") falls back to dall-e-2, matching
                # get_default_model.
                "model": (
                    app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
                ),
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": (
                    form_data.size if form_data.size else app.state.config.IMAGE_SIZE
                ),
                "response_format": "b64_json",
            }

            r = requests.post(
                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                image_filename = save_b64_image(image["b64_json"])
                _store_generated_image(images, image_filename, data)
            return images

        elif app.state.config.ENGINE == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }

            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            if app.state.config.COMFYUI_CFG_SCALE:
                data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE

            # Optional ComfyUI settings are only sent when configured (not None).
            for key, value in (
                ("sampler", app.state.config.COMFYUI_SAMPLER),
                ("scheduler", app.state.config.COMFYUI_SCHEDULER),
                ("sd3", app.state.config.COMFYUI_SD3),
                ("flux", app.state.config.COMFYUI_FLUX),
                ("flux_weight_dtype", app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE),
                ("flux_fp8_clip", app.state.config.COMFYUI_FLUX_FP8_CLIP),
            ):
                if value is not None:
                    data[key] = value

            data = ImageGenerationPayload(**data)

            res = comfyui_generate_image(
                app.state.config.MODEL,
                data,
                user.id,
                app.state.config.COMFYUI_BASE_URL,
            )
            log.debug(f"res: {res}")

            # The payload is identical for every image; serialize it once.
            metadata = data.model_dump(exclude_none=True)
            images = []
            for image in res["data"]:
                image_filename = save_url_image(image["url"])
                _store_generated_image(images, image_filename, metadata)

            log.debug(f"images: {images}")
            return images

        else:
            if form_data.model:
                set_model_handler(form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }

            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            r = requests.post(
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
                headers={"authorization": get_automatic1111_api_auth()},
            )
            r.raise_for_status()
            res = r.json()
            log.debug(f"res: {res}")

            images = []
            for image in res["images"]:
                image_filename = save_b64_image(image)
                _store_generated_image(
                    images, image_filename, {**data, "info": res["info"]}
                )
            return images

    except Exception as e:
        error = _upstream_error_message(r, e)
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
        ) from e