"test/torchaudio_unittest/batch_consistency_test.py" did not exist on "0fafcb3eeca2c260c83366f585c1398ad6b7a6b7"
main.py 17 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import re
Timothy J. Baek's avatar
Timothy J. Baek committed
2
import requests
3
import base64
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel

from constants import ERROR_MESSAGES
from utils.utils import (
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
19
    get_verified_user,
Timothy J. Baek's avatar
Timothy J. Baek committed
20
21
    get_admin_user,
)
Timothy J. Baek's avatar
Timothy J. Baek committed
22
23

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
Timothy J. Baek's avatar
Timothy J. Baek committed
24
25
26
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
27
from pathlib import Path
28
import mimetypes
Timothy J. Baek's avatar
Timothy J. Baek committed
29
30
31
import uuid
import base64
import json
32
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
33

Self Denial's avatar
Self Denial committed
34
35
36
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
37
    IMAGE_GENERATION_ENGINE,
Self Denial's avatar
Self Denial committed
38
    ENABLE_IMAGE_GENERATION,
Self Denial's avatar
Self Denial committed
39
    AUTOMATIC1111_BASE_URL,
40
    AUTOMATIC1111_API_AUTH,
Self Denial's avatar
Self Denial committed
41
    COMFYUI_BASE_URL,
42
43
44
45
    COMFYUI_CFG_SCALE,
    COMFYUI_SAMPLER,
    COMFYUI_SCHEDULER,
    COMFYUI_SD3,
46
47
    IMAGES_OPENAI_API_BASE_URL,
    IMAGES_OPENAI_API_KEY,
48
    IMAGE_GENERATION_MODEL,
49
50
    IMAGE_SIZE,
    IMAGE_STEPS,
51
    AppConfig,
Self Denial's avatar
Self Denial committed
52
)
Timothy J. Baek's avatar
Timothy J. Baek committed
53

54
55
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

# Generated images (plus their "<file>.json" request sidecars) are written
# here; the URLs handed back to clients point at /cache/image/generations/<file>.
IMAGE_CACHE_DIR = Path(CACHE_DIR) / "image" / "generations"
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
Timothy J. Baek's avatar
Timothy J. Baek committed
59
60
61
62
63
64
65
66
67
68

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _seed_default_config(config):
    """Copy the configured image-generation defaults onto ``config``.

    These are only starting values: endpoints in this module overwrite the
    same attributes at runtime.
    """
    # Engine selection and the global on/off switch.
    config.ENGINE = IMAGE_GENERATION_ENGINE
    config.ENABLED = ENABLE_IMAGE_GENERATION

    # Images API used when ENGINE == "openai".
    config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
    config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

    config.MODEL = IMAGE_GENERATION_MODEL

    # Self-hosted backends.
    config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
    config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
    config.COMFYUI_BASE_URL = COMFYUI_BASE_URL

    # Generation parameters.
    config.IMAGE_SIZE = IMAGE_SIZE
    config.IMAGE_STEPS = IMAGE_STEPS
    config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
    config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
    config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
    config.COMFYUI_SD3 = COMFYUI_SD3


app.state.config = AppConfig()
_seed_default_config(app.state.config)
Timothy J. Baek's avatar
Timothy J. Baek committed
89
90


91
92
93
94
def get_automatic1111_api_auth():
    """Build the ``authorization`` header value for AUTOMATIC1111 requests.

    Returns:
        ``"Basic <base64 of AUTOMATIC1111_API_AUTH>"`` when credentials are
        configured (the stored string is encoded as-is; presumably
        ``"user:password"`` — confirm against AUTOMATIC1111's ``--api-auth``),
        otherwise ``""``. An empty string counts as "not configured" so we
        never emit a malformed ``"Basic "`` header.
    """
    credentials = app.state.config.AUTOMATIC1111_API_AUTH
    if not credentials:
        return ""

    encoded = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"


Timothy J. Baek's avatar
Timothy J. Baek committed
101
102
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    """Report the active engine and whether image generation is enabled.

    Guarded by the ``get_admin_user`` dependency; ``request`` is unused.
    """
    config = app.state.config
    return {"engine": config.ENGINE, "enabled": config.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
107
108


Timothy J. Baek's avatar
Timothy J. Baek committed
109
110
111
112
113
114
115
# Request body for POST /config/update.
class ConfigUpdateForm(BaseModel):
    # "openai" or "comfyui"; any other value selects the AUTOMATIC1111 code
    # paths in this module.
    engine: str
    enabled: bool


@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
    """Store the requested engine and enabled flag, then echo what was stored."""
    config = app.state.config
    config.ENGINE, config.ENABLED = form_data.engine, form_data.enabled
    return {"engine": config.ENGINE, "enabled": config.ENABLED}
Timothy J. Baek's avatar
Timothy J. Baek committed
122
123


Timothy J. Baek's avatar
Timothy J. Baek committed
124
125
# Request body for POST /url/update. A field left as None tells the endpoint
# to restore that setting's configured default rather than keep the current one.
class EngineUrlUpdateForm(BaseModel):
    AUTOMATIC1111_BASE_URL: Optional[str] = None
    AUTOMATIC1111_API_AUTH: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None
Timothy J. Baek's avatar
Timothy J. Baek committed
128
129
130


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    """Return the AUTOMATIC1111/ComfyUI base URLs and the AUTOMATIC1111 credentials."""
    exposed_settings = (
        "AUTOMATIC1111_BASE_URL",
        "AUTOMATIC1111_API_AUTH",
        "COMFYUI_BASE_URL",
    )
    return {name: getattr(app.state.config, name) for name in exposed_settings}
Timothy J. Baek's avatar
Timothy J. Baek committed
137
138
139


def _verify_engine_url(url):
    """Raise HTTP 400 unless ``url`` can be reached with a HEAD request.

    Only connectivity is checked: any HTTP status counts as reachable.
    """
    try:
        requests.head(url)
    except Exception as e:
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) from e


@app.post("/url/update")
async def update_engine_url(
    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    """Update the AUTOMATIC1111/ComfyUI base URLs and AUTOMATIC1111 credentials.

    A field sent as None resets that setting to its configured default. A new
    URL has slashes stripped from both ends and must be reachable.

    Raises:
        HTTPException: 400 if a supplied URL cannot be reached. In that case
            nothing is changed.
    """
    # Resolve and verify every value before touching the config, so that a
    # failing ComfyUI check cannot leave the AUTOMATIC1111 URL already updated.
    if form_data.AUTOMATIC1111_BASE_URL is None:
        automatic1111_url = AUTOMATIC1111_BASE_URL
    else:
        automatic1111_url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
        _verify_engine_url(automatic1111_url)

    if form_data.COMFYUI_BASE_URL is None:
        comfyui_url = COMFYUI_BASE_URL
    else:
        comfyui_url = form_data.COMFYUI_BASE_URL.strip("/")
        _verify_engine_url(comfyui_url)

    if form_data.AUTOMATIC1111_API_AUTH is None:
        api_auth = AUTOMATIC1111_API_AUTH
    else:
        api_auth = form_data.AUTOMATIC1111_API_AUTH

    config = app.state.config
    config.AUTOMATIC1111_BASE_URL = automatic1111_url
    config.COMFYUI_BASE_URL = comfyui_url
    config.AUTOMATIC1111_API_AUTH = api_auth

    return {
        "AUTOMATIC1111_BASE_URL": config.AUTOMATIC1111_BASE_URL,
        "AUTOMATIC1111_API_AUTH": config.AUTOMATIC1111_API_AUTH,
        "COMFYUI_BASE_URL": config.COMFYUI_BASE_URL,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
175
176


177
178
# Request body for POST /openai/config/update.
class OpenAIConfigUpdateForm(BaseModel):
    # Base URL; "/images/generations" is appended to it when generating.
    url: str
    # API key; the endpoint rejects an empty string.
    key: str


182
183
184
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
    """Return the stored OpenAI images base URL and API key."""
    config = app.state.config
    return dict(
        OPENAI_API_BASE_URL=config.OPENAI_API_BASE_URL,
        OPENAI_API_KEY=config.OPENAI_API_KEY,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
188
189


190
191
@app.post("/openai/config/update")
async def update_openai_config(
    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
    """Store a new OpenAI base URL and key, echoing the stored values.

    Raises:
        HTTPException: 400 when the key is the empty string (nothing is stored).
    """
    new_url, new_key = form_data.url, form_data.key
    if new_key == "":
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    config = app.state.config
    config.OPENAI_API_BASE_URL = new_url
    config.OPENAI_API_KEY = new_key

    return {
        "status": True,
        "OPENAI_API_BASE_URL": config.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": config.OPENAI_API_KEY,
    }


Timothy J. Baek's avatar
Timothy J. Baek committed
207
208
209
210
211
212
# Request body for POST /size/update; ``size`` must look like "512x512".
class ImageSizeUpdateForm(BaseModel):
    size: str


@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
    """Return the configured default image size, a "<width>x<height>" string."""
    current_size = app.state.config.IMAGE_SIZE
    return {"IMAGE_SIZE": current_size}
Timothy J. Baek's avatar
Timothy J. Baek committed
214
215
216
217


@app.post("/size/update")
async def update_image_size(
    form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
    """Set the default image size.

    The value must be exactly "<width>x<height>" in ASCII digits (e.g.
    "512x512"). It is later split on "x" and sent verbatim as the OpenAI
    ``size`` field.

    Raises:
        HTTPException: 400 if the value is not in that format.
    """
    # fullmatch instead of match + "$": "$" also matches right before a trailing
    # newline, which let "512x512\n" through. [0-9] instead of \d: \d accepts
    # any Unicode digit.
    if re.fullmatch(r"[0-9]+x[0-9]+", form_data.size) is None:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 512x512)."),
        )

    app.state.config.IMAGE_SIZE = form_data.size
    return {
        "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
        "status": True,
    }
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
232

233
234
235
236
237
238
239

# Request body for POST /steps/update; negative values are rejected.
class ImageStepsUpdateForm(BaseModel):
    steps: int


@app.get("/steps")
async def get_image_steps(user=Depends(get_admin_user)):
    """Return the configured step count.

    ComfyUI and AUTOMATIC1111 requests send this value as ``steps``.
    """
    # Formerly also named ``get_image_size``, which rebound the module-level
    # name of the GET /size handler. The route itself is unaffected.
    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
241
242
243
244


@app.post("/steps/update")
async def update_image_steps(
    form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
    """Set the step count.

    Raises:
        HTTPException: 400 if ``steps`` is negative.
    """
    # Formerly also named ``update_image_size``, which rebound the module-level
    # name of the POST /size/update handler. The route itself is unaffected.
    if form_data.steps < 0:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.INCORRECT_FORMAT("  (e.g., 50)."),
        )

    app.state.config.IMAGE_STEPS = form_data.steps
    return {
        "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
        "status": True,
    }
Timothy J. Baek's avatar
Timothy J. Baek committed
258
259


Timothy J. Baek's avatar
Timothy J. Baek committed
260
@app.get("/models")
def get_models(user=Depends(get_verified_user)):
    """List the models selectable for the active engine as ``{"id", "name"}`` dicts.

    - openai: a fixed DALL·E 2 / DALL·E 3 list.
    - comfyui: the ``ckpt_name`` choices ComfyUI's ``/object_info`` reports
      for ``CheckpointLoaderSimple`` (id and name are the same).
    - anything else: AUTOMATIC1111's ``/sdapi/v1/sd-models`` (id = title,
      name = model_name).

    Any failure also switches image generation off (``ENABLED = False``)
    before it is reported as HTTP 400.
    """
    config = app.state.config
    try:
        engine = config.ENGINE

        if engine == "openai":
            return [
                {"id": "dall-e-2", "name": "DALL·E 2"},
                {"id": "dall-e-3", "name": "DALL·E 3"},
            ]

        if engine == "comfyui":
            object_info = requests.get(url=f"{config.COMFYUI_BASE_URL}/object_info").json()
            checkpoint_inputs = object_info["CheckpointLoaderSimple"]["input"]["required"]
            checkpoint_names = checkpoint_inputs["ckpt_name"][0]
            return [{"id": name, "name": name} for name in checkpoint_names]

        sd_models = requests.get(
            url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
            headers={"authorization": get_automatic1111_api_auth()},
        ).json()
        return [
            {"id": entry["title"], "name": entry["model_name"]} for entry in sd_models
        ]
    except Exception as e:
        config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
Timothy J. Baek's avatar
Timothy J. Baek committed
295
296
297
298
299


@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
    """Return ``{"model": ...}`` for the active engine.

    openai falls back to "dall-e-2" and comfyui to "" when MODEL is unset or
    falsy. AUTOMATIC1111 reports the server's ``sd_model_checkpoint`` option.
    Any failure disables image generation and is reported as HTTP 400.
    """
    config = app.state.config
    try:
        engine = config.ENGINE

        if engine == "openai":
            return {"model": config.MODEL or "dall-e-2"}
        if engine == "comfyui":
            return {"model": config.MODEL or ""}

        server_options = requests.get(
            url=f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
            headers={"authorization": get_automatic1111_api_auth()},
        ).json()
        return {"model": server_options["sd_model_checkpoint"]}
    except Exception as e:
        config.ENABLED = False
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
Timothy J. Baek's avatar
Timothy J. Baek committed
318
319
320
321
322
323
324


# Request body for POST /models/default/update.
class UpdateModelForm(BaseModel):
    model: str


def set_model_handler(model: str):
    """Make ``model`` the default for the active engine.

    - openai / comfyui: store it in ``app.state.config.MODEL`` and return the
      stored value.
    - anything else (AUTOMATIC1111): fetch ``/sdapi/v1/options``. If
      ``sd_model_checkpoint`` differs from ``model``, replace it and POST the
      whole options dict back. Return the (possibly updated) options dict.

    Note that the shape of the return value differs by engine.
    """
    config = app.state.config
    if config.ENGINE in ("openai", "comfyui"):
        config.MODEL = model
        return config.MODEL

    auth_headers = {"authorization": get_automatic1111_api_auth()}
    options_endpoint = f"{config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"

    options = requests.get(url=options_endpoint, headers=auth_headers).json()
    if model != options["sd_model_checkpoint"]:
        options["sd_model_checkpoint"] = model
        # The POST response is not inspected.
        requests.post(url=options_endpoint, json=options, headers=auth_headers)

    return options
Timothy J. Baek's avatar
Timothy J. Baek committed
345
346
347
348


@app.post("/models/default/update")
def update_default_model(
    form_data: UpdateModelForm,
    user=Depends(get_verified_user),
):
    """Switch the active engine's default model; returns ``set_model_handler``'s result."""
    # NOTE(review): guarded by get_verified_user, not get_admin_user, yet the
    # change is global (config.MODEL or the AUTOMATIC1111 server options).
    # Confirm this is intended.
    requested_model = form_data.model
    return set_model_handler(requested_model)


# Request body for POST /generations.
class GenerateImageForm(BaseModel):
    # Checkpoint to switch to before generating; only the AUTOMATIC1111 path uses it.
    model: Optional[str] = None
    prompt: str
    # Number of images (sent as "n", or "batch_size" for AUTOMATIC1111).
    n: int = 1
    # "<width>x<height>" override; only the OpenAI path uses it.
    size: Optional[str] = None
    # Forwarded to ComfyUI / AUTOMATIC1111 when set.
    negative_prompt: Optional[str] = None


Timothy J. Baek's avatar
Timothy J. Baek committed
363
364
def save_b64_image(b64_str, cache_dir=None):
    """Decode a base64 image and write it into the image cache.

    Args:
        b64_str: Either a data URL (``"data:image/png;base64,<payload>"``) or a
            bare base64 payload. Bare payloads, and data URLs whose media type
            has no known extension, are saved with a ``.png`` suffix.
        cache_dir: Directory to write into; defaults to ``IMAGE_CACHE_DIR``.

    Returns:
        The new file name (``"<uuid4><ext>"``), or None if decoding or writing
        failed. Failures are logged.
    """
    try:
        image_id = str(uuid.uuid4())
        extension = ".png"
        encoded = b64_str

        if "," in b64_str:
            header, encoded = b64_str.split(",", 1)
            # The header looks like "data:image/png;base64". Drop the "data:"
            # scheme too: mimetypes does not know "data:image/png", and leaving
            # it in produced file names ending in "None".
            mime_type = header.split(";")[0].strip()
            if mime_type.lower().startswith("data:"):
                mime_type = mime_type[len("data:"):]
            extension = mimetypes.guess_extension(mime_type) or ".png"

        img_data = base64.b64decode(encoded)

        image_filename = f"{image_id}{extension}"
        directory = IMAGE_CACHE_DIR if cache_dir is None else Path(cache_dir)
        with open(directory / image_filename, "wb") as f:
            f.write(img_data)
        return image_filename
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
393
394


Timothy J. Baek's avatar
Timothy J. Baek committed
395
396
397
398
399
def save_url_image(url):
    """Download the image at ``url`` into ``IMAGE_CACHE_DIR``.

    Returns:
        The new file name (``"<uuid4><ext>"``), or None if any of these hold:
        the request fails, the response is not ``image/*``, or its media type
        has no known file extension. Failures are logged.
    """
    image_id = str(uuid.uuid4())
    try:
        r = requests.get(url)
        r.raise_for_status()

        # Content-Type may carry parameters ("image/png; charset=binary") or be
        # missing. Keep only the bare media type so the extension lookup works.
        mime_type = r.headers.get("content-type", "").split(";")[0].strip().lower()
        if not mime_type.startswith("image/"):
            log.error("Url does not point to an image.")
            return None

        image_format = mimetypes.guess_extension(mime_type)
        if not image_format:
            raise ValueError("Could not determine image type from MIME type")

        image_filename = f"{image_id}{image_format}"
        file_path = IMAGE_CACHE_DIR.joinpath(image_filename)
        with open(file_path, "wb") as image_file:
            for chunk in r.iter_content(chunk_size=8192):
                image_file.write(chunk)
        return image_filename
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
Timothy J. Baek's avatar
Timothy J. Baek committed
422
423


Timothy J. Baek's avatar
Timothy J. Baek committed
424
425
def _record_cached_image(image_filename, request_body):
    """Write ``<image_filename>.json`` next to a cached image and return its response entry.

    Raises:
        ValueError: if the image was not saved (``image_filename`` is None).
            Without this check the client would get a ".../None" URL, and every
            failed save would overwrite the same "None.json" sidecar.
    """
    if image_filename is None:
        raise ValueError("Failed to save generated image")
    with open(IMAGE_CACHE_DIR.joinpath(f"{image_filename}.json"), "w") as f:
        json.dump(request_body, f)
    return {"url": f"/cache/image/generations/{image_filename}"}


def _error_detail_from_response(r, fallback):
    """Return the upstream error message carried by ``r``, else ``fallback``.

    Understands ``{"error": {"message": ...}}`` and ``{"error": "<text>"}``
    bodies. Any other case returns ``fallback``: no response, a non-JSON body,
    or a body of some other shape.
    """
    if r is None:
        return fallback
    try:
        body = r.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return fallback
    if not isinstance(body, dict):
        return fallback

    error = body.get("error")
    if isinstance(error, dict):
        return error.get("message") or fallback
    if isinstance(error, str) and error:
        return error
    return fallback


@app.post("/generations")
def generate_image(
    form_data: GenerateImageForm,
    user=Depends(get_verified_user),
):
    """Generate ``form_data.n`` images with the active engine and cache them.

    Each image is written to ``IMAGE_CACHE_DIR`` next to a ``<file>.json``
    sidecar holding the request parameters. The response lists
    ``{"url": "/cache/image/generations/<file>"}`` entries.

    Engines:
        openai: POST ``{OPENAI_API_BASE_URL}/images/generations``.
            ``form_data.size`` overrides the configured size.
        comfyui: delegates to ``comfyui_generate_image``.
        anything else: AUTOMATIC1111 ``/sdapi/v1/txt2img``. A set
            ``form_data.model`` first becomes the server's default checkpoint.

    Raises:
        HTTPException: 400 on any failure. The detail is the upstream error
            message when the engine returned one.
    """
    r = None
    try:
        # Inside the try so a malformed IMAGE_SIZE becomes a 400, not a 500.
        width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

        if app.state.config.ENGINE == "openai":
            headers = {
                "Authorization": f"Bearer {app.state.config.OPENAI_API_KEY}",
                "Content-Type": "application/json",
            }
            data = {
                # Same fallback as GET /models/default, so an unset (None)
                # MODEL is never sent to the API as null.
                "model": app.state.config.MODEL or "dall-e-2",
                "prompt": form_data.prompt,
                "n": form_data.n,
                "size": (
                    form_data.size if form_data.size else app.state.config.IMAGE_SIZE
                ),
                "response_format": "b64_json",
            }

            r = requests.post(
                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                json=data,
                headers=headers,
            )
            r.raise_for_status()
            res = r.json()

            images = []
            for image in res["data"]:
                images.append(
                    _record_cached_image(save_b64_image(image["b64_json"]), data)
                )
            return images

        elif app.state.config.ENGINE == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }
            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS
            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt
            # Truthiness (not "is not None") here: a CFG scale of 0 is omitted.
            if app.state.config.COMFYUI_CFG_SCALE:
                data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
            if app.state.config.COMFYUI_SAMPLER is not None:
                data["sampler"] = app.state.config.COMFYUI_SAMPLER
            if app.state.config.COMFYUI_SCHEDULER is not None:
                data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
            if app.state.config.COMFYUI_SD3 is not None:
                data["sd3"] = app.state.config.COMFYUI_SD3

            payload = ImageGenerationPayload(**data)
            res = comfyui_generate_image(
                app.state.config.MODEL,
                payload,
                user.id,
                app.state.config.COMFYUI_BASE_URL,
            )
            log.debug(f"res: {res}")

            sidecar = payload.model_dump(exclude_none=True)
            images = []
            for image in res["data"]:
                images.append(
                    _record_cached_image(save_url_image(image["url"]), sidecar)
                )
            log.debug(f"images: {images}")
            return images

        else:
            if form_data.model:
                set_model_handler(form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
                "width": width,
                "height": height,
            }
            if app.state.config.IMAGE_STEPS is not None:
                data["steps"] = app.state.config.IMAGE_STEPS
            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            r = requests.post(
                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                json=data,
                headers={"authorization": get_automatic1111_api_auth()},
            )
            # Without this, an error response surfaced only as KeyError('images').
            r.raise_for_status()
            res = r.json()
            log.debug(f"res: {res}")

            images = []
            for image in res["images"]:
                images.append(
                    _record_cached_image(
                        save_b64_image(image), {**data, "info": res["info"]}
                    )
                )
            return images

    except Exception as e:
        # Never call r.json() unguarded here: a non-JSON or differently shaped
        # error body used to raise inside the handler and turn into a 500.
        detail = _error_detail_from_response(r, e)
        raise HTTPException(
            status_code=400, detail=ERROR_MESSAGES.DEFAULT(detail)
        ) from e