Unverified Commit de62153d authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #1815 from Yanyutin753/new-dev

 expend the image format type after the file is downloaded
parents 3f0fae1d 80413a76
...@@ -24,6 +24,7 @@ from utils.misc import calculate_sha256 ...@@ -24,6 +24,7 @@ from utils.misc import calculate_sha256
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
from pathlib import Path from pathlib import Path
import mimetypes
import uuid import uuid
import base64 import base64
import json import json
...@@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel): ...@@ -315,38 +316,50 @@ class GenerateImageForm(BaseModel):
def save_b64_image(b64_str): def save_b64_image(b64_str):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try: try:
# Split the base64 string to get the actual image data header, encoded = b64_str.split(",", 1)
img_data = base64.b64decode(b64_str) mime_type = header.split(";")[0]
img_data = base64.b64decode(encoded)
# Write the image data to a file image_id = str(uuid.uuid4())
image_format = mimetypes.guess_extension(mime_type)
image_filename = f"{image_id}{image_format}"
file_path = IMAGE_CACHE_DIR / f"{image_filename}"
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(img_data) f.write(img_data)
return image_filename
return image_id
except Exception as e: except Exception as e:
log.error(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None
def save_url_image(url): def save_url_image(url):
image_id = str(uuid.uuid4()) image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try: try:
r = requests.get(url) r = requests.get(url)
r.raise_for_status() r.raise_for_status()
if r.headers["content-type"].split("/")[0] == "image":
mime_type = r.headers["content-type"]
image_format = mimetypes.guess_extension(mime_type)
if not image_format:
raise ValueError("Could not determine image type from MIME type")
with open(file_path, "wb") as image_file: file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}{image_format}")
image_file.write(r.content) with open(file_path, "wb") as image_file:
for chunk in r.iter_content(chunk_size=8192):
image_file.write(chunk)
return image_id, image_format
else:
log.error(f"Url does not point to an image.")
return None, None
return image_id
except Exception as e: except Exception as e:
log.exception(f"Error saving image: {e}") log.exception(f"Error saving image: {e}")
return None return None, None
@app.post("/generations") @app.post("/generations")
...@@ -385,8 +398,8 @@ def generate_image( ...@@ -385,8 +398,8 @@ def generate_image(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_id = save_b64_image(image["b64_json"]) image_filename = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
...@@ -422,8 +435,10 @@ def generate_image( ...@@ -422,8 +435,10 @@ def generate_image(
images = [] images = []
for image in res["data"]: for image in res["data"]:
image_id = save_url_image(image["url"]) image_id, image_format = save_url_image(image["url"])
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append(
{"url": f"/cache/image/generations/{image_id}{image_format}"}
)
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
...@@ -460,8 +475,8 @@ def generate_image( ...@@ -460,8 +475,8 @@ def generate_image(
images = [] images = []
for image in res["images"]: for image in res["images"]:
image_id = save_b64_image(image) image_filename = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_id}.png"}) images.append({"url": f"/cache/image/generations/{image_filename}"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json") file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f: with open(file_body_path, "w") as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment