main.py 17.7 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import re
Timothy J. Baek's avatar
Timothy J. Baek committed
2
import requests
3
import base64
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware

from constants import ERROR_MESSAGES
from utils.utils import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
18
    get_verified_user,
Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
21
22

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
Timothy J. Baek's avatar
Timothy J. Baek committed
23
24
25
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
26
from pathlib import Path
27
import mimetypes
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29
30
import uuid
import base64
import json
31
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
32

Self Denial's avatar
Self Denial committed
33
34
35
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
36
    IMAGE_GENERATION_ENGINE,
Self Denial's avatar
Self Denial committed
37
    ENABLE_IMAGE_GENERATION,
Self Denial's avatar
Self Denial committed
38
    AUTOMATIC1111_BASE_URL,
39
    AUTOMATIC1111_API_AUTH,
Self Denial's avatar
Self Denial committed
40
    COMFYUI_BASE_URL,
41
42
43
44
    COMFYUI_CFG_SCALE,
    COMFYUI_SAMPLER,
    COMFYUI_SCHEDULER,
    COMFYUI_SD3,
45
46
47
    COMFYUI_FLUX,
    COMFYUI_FLUX_WEIGHT_DTYPE,
    COMFYUI_FLUX_FP8_CLIP,
48
49
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
50
    IMAGE_GENERATION_MODEL,
51
52
    IMAGE_SIZE,
    IMAGE_STEPS,
53
    AppConfig,
Self Denial's avatar
Self Denial committed
54
)
Timothy J. Baek's avatar
Timothy J. Baek committed
55

56
57
# Module-level logger; its level is taken from the "IMAGES" entry of
# SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
Timothy J. Baek's avatar
Timothy J. Baek committed
58
59
60

# Directory (under CACHE_DIR) that generated images are written to. It is
# created at import time, including missing parents, if it does not exist.
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
61
62
63
64
65
66
67
68
69
70

# FastAPI application that the image endpoints below are registered on.
app = FastAPI()
# NOTE(review): CORS is configured with every origin, method and header
# allowed together with credentials — confirm this is intended for the
# deployment this app is mounted in.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

71
# Settings container read (and, through the update endpoints, written) by the
# handlers in this module. Every value starts from the imported default.
app.state.config = AppConfig()

# Active backend and the global on/off switch.
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION

# OpenAI-compatible images API.
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

# Model used when the engine needs one chosen on this side.
app.state.config.MODEL = IMAGE_GENERATION_MODEL

# Self-hosted backends.
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL

# Generation parameters shared by the engines.
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS

# ComfyUI-specific tuning, forwarded into the ComfyUI payload.
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
app.state.config.COMFYUI_FLUX = COMFYUI_FLUX
app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE = COMFYUI_FLUX_WEIGHT_DTYPE
app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
Timothy J. Baek's avatar
Timothy J. Baek committed
94

Timothy J. Baek's avatar
Timothy J. Baek committed
95

96
97
98
99
def get_automatic1111_api_auth():
    """Build the ``Authorization`` header value sent to AUTOMATIC1111.

    Returns:
        ``""`` when ``app.state.config.AUTOMATIC1111_API_AUTH`` is ``None``;
        otherwise ``"Basic <b64>"`` where ``<b64>`` is the Base64 encoding of
        the UTF-8 bytes of the configured value.
    """
    credentials = app.state.config.AUTOMATIC1111_API_AUTH
    # None is a singleton, so compare by identity rather than equality.
    if credentials is None:
        return ""

    encoded = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"


Timothy J. Baek's avatar
Timothy J. Baek committed
106
107
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Return the active engine name and the enabled flag (admin only)."""
    config = app.state.config
    return {"engine": config.ENGINE, "enabled": config.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
112
113


Timothy J. Baek's avatar
Timothy J. Baek committed
114
115
116
117
118
119
120
class ConfigUpdateForm(BaseModel):
    """Request body for ``POST /config/update``."""

    # Stored as the active engine; handlers in this module compare it against
    # "openai" and "comfyui" and treat anything else as AUTOMATIC1111.
    engine: str
    # Stored as the ENABLED flag.
    enabled: bool


@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store the submitted engine and enabled flag, then echo them back."""
    config = app.state.config
    config.ENGINE = form_data.engine
    config.ENABLED = form_data.enabled
    return {"engine": config.ENGINE, "enabled": config.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
127
128


Timothy J. Baek's avatar
Timothy J. Baek committed
129
130
class EngineUrlUpdateForm(BaseModel):
    """Request body for ``POST /url/update``.

    Every field is optional; the handler replaces a field left as ``None``
    with the corresponding imported default.
    """

    # Base URL of the AUTOMATIC1111 server.
    AUTOMATIC1111_BASE_URL: Optional[str] = None
    # Value that gets Base64-encoded into a "Basic ..." authorization header.
    AUTOMATIC1111_API_AUTH: Optional[str] = None
    # Base URL of the ComfyUI server.
    COMFYUI_BASE_URL: Optional[str] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
133
134
135


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    """Return the configured backend URLs and AUTOMATIC1111 auth (admin only)."""
    config = app.state.config
    return {
        "AUTOMATIC1111_BASE_URL": config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": config.COMFYUI_BASE_URL,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
142
143
144


def _validate_engine_url(url: str) -> str:
    """Strip slashes from both ends of *url* and probe it with an HTTP HEAD.

    Returns:
        The stripped URL when the probe succeeds with a non-error status.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.INVALID_URL`` when the
            request fails or the server answers with an error status.
    """
    candidate = url.strip("/")
    try:
        # Without a timeout an unresponsive host would block this call
        # indefinitely.
        response = requests.head(candidate, timeout=10)
        response.raise_for_status()
    except Exception as e:
        # Any failure while probing means the URL is unusable for us.
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.INVALID_URL
        ) from e
    return candidate


@app.post("/url/update")
async def update_engine_url(
    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    """Update the backend URLs and the AUTOMATIC1111 auth value (admin only).

    A field left as ``None`` is reset to its imported default. Submitted URLs
    are probed first; configuration is only written once every submitted URL
    has passed, so a rejected request leaves the previous settings intact.

    Raises:
        HTTPException: 400 when a submitted URL cannot be reached.
    """
    if form_data.AUTOMATIC1111_BASE_URL is None:
        automatic1111_url = AUTOMATIC1111_BASE_URL
    else:
        automatic1111_url = _validate_engine_url(form_data.AUTOMATIC1111_BASE_URL)

    if form_data.COMFYUI_BASE_URL is None:
        comfyui_url = COMFYUI_BASE_URL
    else:
        comfyui_url = _validate_engine_url(form_data.COMFYUI_BASE_URL)

    if form_data.AUTOMATIC1111_API_AUTH is None:
        automatic1111_auth = AUTOMATIC1111_API_AUTH
    else:
        automatic1111_auth = form_data.AUTOMATIC1111_API_AUTH

    # All validation passed: apply the new values together.
    app.state.config.AUTOMATIC1111_BASE_URL = automatic1111_url
    app.state.config.COMFYUI_BASE_URL = comfyui_url
    app.state.config.AUTOMATIC1111_API_AUTH = automatic1111_auth

    return {
        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
182
183


184
185
class OpenAIConfigUpdateForm(BaseModel):
    """Request body for ``POST /openai/config/update``."""

    # Stored as OPENAI_API_BASE_URL; "/images/generations" is appended to it
    # when generating.
    url: str
    # Stored as OPENAI_API_KEY; the handler rejects an empty string.
    key: str


189
190
191
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
    """Return the configured OpenAI base URL and API key (admin only)."""
    config = app.state.config
    return {
        "OPENAI_API_BASE_URL": config.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": config.OPENAI_API_KEY,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
195
196


197
198
@app.post("/openai/config/update")
async def update_openai_config(
    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
    """Store a new OpenAI base URL and API key (admin only).

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.API_KEY_NOT_FOUND`` when the
            submitted key is the empty string.
    """
    # `key` is declared as str, so falsiness here means exactly "".
    if not form_data.key:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    config = app.state.config
    config.OPENAI_API_BASE_URL = form_data.url
    config.OPENAI_API_KEY = form_data.key

    return {
        "status": True,
        "OPENAI_API_BASE_URL": config.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": config.OPENAI_API_KEY,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
214
215
216
217
218
219
class ImageSizeUpdateForm(BaseModel):
    """Request body for ``POST /size/update``."""

    # Image dimensions as "<width>x<height>", e.g. "512x512".
    size: str


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the configured image size (admin only)."""
    current_size = app.state.config.IMAGE_SIZE
    return {"IMAGE_SIZE": current_size}
Timothy J. Baek's avatar
Timothy J. Baek committed
221
222
223
224


@app.post("/size/update")
async def update_image_size(
    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Store a new image size of the form ``<width>x<height>`` (admin only).

    Raises:
        HTTPException: 400 when ``size`` is not two runs of ASCII digits
            separated by a single ``x``.
    """
    # fullmatch must consume the entire string; "^...$" with re.match also
    # accepts a trailing newline. [0-9] keeps the value to ASCII digits, since
    # it is later forwarded verbatim as the "size" of an OpenAI request.
    if not re.fullmatch(r"[0-9]+x[0-9]+", form_data.size):
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
        )

    app.state.config.IMAGE_SIZE = form_data.size
    return {
        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
        "status": True,
    }
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
239

240
241
242
243
244
245
246

class ImageStepsUpdateForm(BaseModel):
    """Request body for ``POST /steps/update``."""

    # Number of generation steps; the handler rejects negative values.
    steps: int


# NOTE(review): this handler reuses the name `get_image_size` from the
# GET /size handler, so the module-level name ends up bound to this function.
# Both routes stay registered; a rename (e.g. get_image_steps) looks intended.
@app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the configured number of generation steps (admin only)."""
    current_steps = app.state.config.IMAGE_STEPS
    return {"IMAGE_STEPS": current_steps}
248
249
250
251


# NOTE(review): this handler reuses the name `update_image_size` from the
# POST /size/update handler and shadows it at module level. Both routes stay
# registered; a rename (e.g. update_image_steps) looks intended.
@app.post("/steps/update")
async def update_image_size(
    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
    """Store a new number of generation steps (admin only).

    Raises:
        HTTPException: 400 when ``steps`` is negative.
    """
    if form_data.steps < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
        )

    app.state.config.IMAGE_STEPS = form_data.steps
    return {
        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
265
266


Timothy J. Baek's avatar
Timothy J. Baek committed
267
@app.get("/models")
def get_models(user=Depends(get_verified_user)):
    """List the models offered by the active engine as ``{"id", "name"}`` dicts.

    The OpenAI list is fixed; the ComfyUI and AUTOMATIC1111 lists are fetched
    from the configured backend.

    Raises:
        HTTPException: 400 on any failure. Image generation is also switched
            off (``ENABLED = False``) in that case.
    """
    try:
        engine = app.state.config.ENGINE

        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]

        if engine == "comfyui":
            response = requests.get(
                url=f"{app.state.config.COMFYUI_BASE_URL}/object_info"
            )
            object_info = response.json()
            # Checkpoint names are read from the first element of the
            # loader node's required "ckpt_name" input.
            checkpoint_names = object_info["CheckpointLoaderSimple"]["input"][
                "required"
            ]["ckpt_name"][0]
            return [{"id": name, "name": name} for name in checkpoint_names]

        # Any other engine value is handled as AUTOMATIC1111.
        response = requests.get(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        return [
            {"id": entry["title"], "name": entry["model_name"]}
            for entry in response.json()
        ]
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
Timothy J. Baek's avatar
Timothy J. Baek committed
302
303
304
305
306


@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return the current default model as ``{"model": ...}`` (admin only).

    For OpenAI an unset model falls back to ``"dall-e-2"``, for ComfyUI to
    ``""``. For any other engine the value is read from the AUTOMATIC1111
    options endpoint.

    Raises:
        HTTPException: 400 on any failure. Image generation is also switched
            off (``ENABLED = False``) in that case.
    """
    try:
        engine = app.state.config.ENGINE

        if engine == "openai":
            configured = app.state.config.MODEL
            return {"model": configured if configured else "dall-e-2"}

        if engine == "comfyui":
            configured = app.state.config.MODEL
            return {"model": configured if configured else ""}

        response = requests.get(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            headers={"authorization": get_automatic1111_api_auth()},
        )
        return {"model": response.json()["sd_model_checkpoint"]}
    except Exception as e:
        app.state.config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
Timothy J. Baek's avatar
Timothy J. Baek committed
325
326
327
328
329
330
331


class UpdateModelForm(BaseModel):
    """Request body for ``POST /models/default/update``."""

    # Model identifier handed to set_model_handler.
    model: str


def set_model_handler(model: str):
    """Make *model* the active model for the current engine.

    For the "openai" and "comfyui" engines the name is only stored in the
    local configuration and returned. For any other engine the AUTOMATIC1111
    options are fetched and, if the checkpoint differs, posted back with
    ``sd_model_checkpoint`` replaced.

    Returns:
        The stored model name (openai / comfyui), or the options dict with the
        requested checkpoint applied (AUTOMATIC1111).
    """
    if app.state.config.ENGINE in ["openai", "comfyui"]:
        app.state.config.MODEL = model
        return app.state.config.MODEL

    auth_header = {"authorization": get_automatic1111_api_auth()}
    current = requests.get(
        url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
        headers=auth_header,
    )
    options = current.json()

    # Only push the options back when the checkpoint actually changes.
    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        requests.post(
            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            json=options,
            headers=auth_header,
        )

    return options
Timothy J. Baek's avatar
Timothy J. Baek committed
352
353
354
355


# NOTE(review): this route only requires a verified user, while
# GET /models/default requires an admin — confirm that is intentional.
@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_verified_user),
):
    """Set the default model and return whatever ``set_model_handler`` returns."""
    requested_model = form_data.model
    return set_model_handler(requested_model)


class GenerateImageForm(BaseModel):
    """Request body for ``POST /generations``."""

    # Applied through set_model_handler on the AUTOMATIC1111 path only.
    model: Optional[str] = None
    # Text prompt forwarded to the engine.
    prompt: str
    # Number of images requested ("n" for OpenAI/ComfyUI, "batch_size" for
    # AUTOMATIC1111).
    n: int = 1
    # Overrides the configured IMAGE_SIZE on the OpenAI path only.
    size: Optional[str] = None
    # Forwarded to ComfyUI / AUTOMATIC1111 when provided.
    negative_prompt: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
370
371
def save_b64_image(b64_str):
    """Decode a Base64 image and write it into ``IMAGE_CACHE_DIR``.

    Args:
        b64_str: Either bare Base64 data (saved as ``.png``) or a data URI of
            the form ``data:<mime>;base64,<data>``.

    Returns:
        The generated file name (``<uuid><extension>``), or ``None`` if the
        image could not be decoded or written; the error is logged.
    """
    try:
        image_id = str(uuid.uuid4())
        image_format = ".png"
        encoded = b64_str

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            # The header looks like "data:image/png;base64". Drop the
            # parameters and the "data:" scheme, because guess_extension only
            # recognises a bare MIME type and otherwise returns None.
            mime_type = header.split(";")[0].split(":", 1)[-1].strip()
            # Unknown types fall back to the same default as bare Base64.
            image_format = mimetypes.guess_extension(mime_type) or ".png"

        img_data = base64.b64decode(encoded)

        image_filename = f"{image_id}{image_format}"
        file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
        with open(file_path, "wb") as f:
            f.write(img_data)
        return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
400
401


Timothy J. Baek's avatar
Timothy J. Baek committed
402
403
404
405
406
def save_url_image(url):
    """Download the image at *url* and write it into ``IMAGE_CACHE_DIR``.

    The file extension is derived from the response's Content-Type.

    Returns:
        The generated file name (``<uuid><extension>``), or ``None`` when the
        download fails, the response is not an image, or the image type cannot
        be mapped to an extension; the reason is logged.
    """
    image_id = str(uuid.uuid4())
    try:
        r = requests.get(url)
        r.raise_for_status()

        # Content-Type may carry parameters ("image/png; charset=binary") or
        # use upper case; reduce it to the bare lower-case MIME type so that
        # guess_extension can recognise it.
        mime_type = r.headers.get("content-type", "").split(";")[0].strip().lower()

        if mime_type.split("/")[0] != "image":
            log.error("Url does not point to an image.")
            return None

        image_format = mimetypes.guess_extension(mime_type)
        if not image_format:
            raise ValueError("Could not determine image type from MIME type")

        image_filename = f"{image_id}{image_format}"
        file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
        with open(file_path, "wb") as image_file:
            for chunk in r.iter_content(chunk_size=8192):
                image_file.write(chunk)
        return image_filename

    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
429
430


Timothy J. Baek's avatar
Timothy J. Baek committed
431
def _record_generated_image(image_filename, metadata):
    """Write *metadata* as ``<image_filename>.json`` and return the URL entry."""
    file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json")
    with open(file_body_path, "w") as f:
        json.dump(metadata, f)
    return {"url": f"/cache/image/generations/{image_filename}"}


def _describe_generation_error(response, error):
    """Pick the most specific description of a failed generation.

    Uses ``error.message`` (or a plain-string ``error``) from the backend's
    JSON body when it has that shape; otherwise returns *error* unchanged.
    Never raises, so it is safe to call from an ``except`` block.
    """
    if response is None:
        return error
    try:
        payload = response.json()
    except ValueError:
        # The body is not JSON; the original exception is all we have.
        return error
    if isinstance(payload, dict):
        upstream = payload.get("error")
        if isinstance(upstream, dict) and "message" in upstream:
            return upstream["message"]
        if isinstance(upstream, str) and upstream:
            return upstream
    return error


@app.post("/generations")
async def image_generations(
    form_data: GenerateImageForm,
    user=Depends(get_verified_user),
):
    """Generate images with the active engine and cache them locally.

    Each image is saved into ``IMAGE_CACHE_DIR`` together with a
    ``<filename>.json`` file holding the request parameters. Images that
    cannot be saved are skipped.

    Returns:
        A list of ``{"url": "/cache/image/generations/<filename>"}`` dicts.

    Raises:
        HTTPException: 400 when generation fails; the detail carries the
            backend's error message when one is available.
    """
    # NOTE(review): the `requests` calls below are blocking and run inside an
    # async handler — confirm this is acceptable for the expected load.
    width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

    r = None
    try:
        if app.state.config.ENGINE == "openai":
            headers = {
                "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }

            data = {
                # Truthiness check, matching GET /models/default, so an unset
                # (None or "") model falls back to the default.
                "model": (
                    app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
                ),
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": (
                    form_data.size if form_data.size else app.state.config.IMAGE_SIZE
                ),
                "response_format": "b64_json",
            }

            r = requests.post(
                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                image_filename = save_b64_image(image["b64_json"])
                if image_filename is None:
                    # Saving failed; skip rather than emit a ".../None" URL.
                    continue
                images.append(_record_generated_image(image_filename, data))

            return images

        elif app.state.config.ENGINE == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }

            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            if app.state.config.COMFYUI_CFG_SCALE:
                data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE

            if app.state.config.COMFYUI_SAMPLER is not None:
                data["sampler"] = app.state.config.COMFYUI_SAMPLER

            if app.state.config.COMFYUI_SCHEDULER is not None:
                data["scheduler"] = app.state.config.COMFYUI_SCHEDULER

            if app.state.config.COMFYUI_SD3 is not None:
                data["sd3"] = app.state.config.COMFYUI_SD3

            if app.state.config.COMFYUI_FLUX is not None:
                data["flux"] = app.state.config.COMFYUI_FLUX

            if app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE is not None:
                data["flux_weight_dtype"] = app.state.config.COMFYUI_FLUX_WEIGHT_DTYPE

            if app.state.config.COMFYUI_FLUX_FP8_CLIP is not None:
                data["flux_fp8_clip"] = app.state.config.COMFYUI_FLUX_FP8_CLIP

            data = ImageGenerationPayload(**data)

            res = await comfyui_generate_image(
                app.state.config.MODEL,
                data,
                user.id,
                app.state.config.COMFYUI_BASE_URL,
            )
            log.debug(f"res: {res}")

            images = []
            for image in res["data"]:
                image_filename = save_url_image(image["url"])
                if image_filename is None:
                    # Saving failed; skip rather than emit a ".../None" URL.
                    continue
                images.append(
                    _record_generated_image(
                        image_filename, data.model_dump(exclude_none=True)
                    )
                )

            log.debug(f"images: {images}")
            return images

        else:
            # Any other engine value is handled as AUTOMATIC1111.
            if form_data.model:
                set_model_handler(form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }

            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            r = requests.post(
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
                headers={"authorization": get_automatic1111_api_auth()},
            )

            res = r.json()
            log.debug(f"res: {res}")

            images = []
            for image in res["images"]:
                image_filename = save_b64_image(image)
                if image_filename is None:
                    # Saving failed; skip rather than emit a ".../None" URL.
                    continue
                images.append(
                    _record_generated_image(
                        image_filename, {**data, "info": res["info"]}
                    )
                )

            return images

    except Exception as e:
        # The helper never raises, so a malformed backend error body can no
        # longer turn this 400 into an unhandled 500.
        error = _describe_generation_error(r, e)
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(error)
        ) from e