#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gradio as gr
from PIL import Image
import sys
import os
sys.path.append(os.getcwd())
import json
import numpy as np
from pathlib import Path
import io
import hashlib
import requests
import base64
import pandas as pd
from sample_t2i import inferencer
from mllm.dialoggen_demo import init_dialoggen_model, eval_model
# UI label -> (height, width) tuples. image_generation() unpacks these as
# `height, width = SIZES[image_size]`, while the labels read width x height —
# so landscape "1280x768" correctly maps to (768, 1280) == (h, w).
SIZES = {
"正方形(square, 1024x1024)": (1024, 1024),
"风景(landscape, 1280x768)": (768, 1280),
"人像(portrait, 768x1280)": (1280, 768),
}
# Random per-session seed; not used by the functions visible in this chunk —
# presumably the initial value of the UI's seed widget (TODO confirm).
global_seed = np.random.randint(0, 10000)
# Helper Functions
def image_to_base64(image_path):
    """Read the file at ``image_path`` and return its contents base64-encoded as str."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode()
def get_strings(lang):
    """Load the UI translation table for ``lang`` from app/lang/<lang>.csv.

    The CSV has a header row with at least ``key`` and ``value`` columns;
    the result maps each key to its localized value.
    """
    table = pd.read_csv(Path(f"app/lang/{lang}.csv"), header=0)
    return dict(zip(table["key"], table["value"]))
def get_image_md5(image):
    """Return the MD5 hex digest of ``image`` serialized to PNG bytes.

    Used as a content-addressed file name for saved generations.
    """
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return hashlib.md5(buffer.getvalue()).hexdigest()
# mllm call (request to the DialogGen dialogue service)
def request_dialogGen(
    server_url="http://0.0.0.0:8080",
    history_messages=None,
    question="画一个木制的鸟",
    image="",
):
    """Send one chat turn to the DialogGen mllm HTTP service.

    Args:
        server_url: Endpoint of the DialogGen server.
        history_messages: Prior conversation turns as returned by the server;
            ``None`` starts a fresh conversation. (Was a mutable default
            list ``[]``, which is shared across calls — fixed to None.)
        question: User text for this turn.
        image: Optional path to an image file; "" means text-only dialogue.

    Returns:
        Tuple ``(history_messages, response_text)``: the updated history
        echoed back by the server and the model's reply for this turn.
    """
    if history_messages is None:
        history_messages = []
    if image != "":
        # The server expects the image inline as a base64 string.
        with open(image, "rb") as fh:
            image = base64.b64encode(fh.read()).decode()
    print("history_messages before request", history_messages)
    headers = {"accept": "application/json", "Content-Type": "application/json"}
    data = {
        "text": question,
        "image": image,  # empty string => text-only dialogue
        "history": history_messages,
    }
    response = requests.post(server_url, headers=headers, json=data)
    print("response", response)
    response = response.json()
    print(response)
    response_text = response["result"]
    history_messages = response["history"]
    # Log message said "before request" twice — this one is after the call.
    print("history_messages after request", history_messages)
    return history_messages, response_text
# Image generation
def image_generation(prompt, infer_steps, seed, image_size):
    """Run the text-to-image model and return the generated PNG as base64.

    Args:
        prompt: Text prompt forwarded to the T2I model.
        infer_steps: Number of diffusion inference steps.
        seed: RNG seed for reproducible sampling.
        image_size: Key into SIZES selecting the (height, width) pair.

    Returns:
        Base64-encoded PNG; the image is also saved under ``results/``.
    """
    print(
        f"prompt sent to T2I model: {prompt}, infer_steps: {infer_steps}, seed: {seed}, size: {image_size}"
    )
    height, width = SIZES[image_size]
    # NOTE(review): `gen` is a module-level object not defined in this chunk —
    # presumably built from sample_t2i.inferencer() elsewhere; confirm.
    results = gen.predict(
        prompt,
        height=height,
        width=width,
        seed=seed,
        infer_steps=infer_steps,
        batch_size=1,
    )
    image = results["images"][0]
    # Content-addressed file name: MD5 of the PNG bytes.
    file_name = get_image_md5(image)
    save_dir = Path("results")
    save_dir.mkdir(parents=True, exist_ok=True)
    # Build the path from save_dir instead of duplicating the "results" literal.
    save_path = save_dir / f"multiRound_{file_name}.png"
    image.save(save_path)
    return image_to_base64(save_path)
# Image-and-text chat
def chat(history_messages, input_text):
    """One text-only round-trip to DialogGen.

    Returns the updated history and the model's reply.
    """
    updated_history, reply = request_dialogGen(
        history_messages=history_messages,
        question=input_text,
    )
    return updated_history, reply
#
# NOTE(review): from here to the end of the file the source is garbled — an
# extraction step stripped HTML/markup out of the f-string literals, breaking
# them open, and collapsed the Gradio Blocks UI definition (demo, text_input,
# chatbot, gr_state, reset, reset_last, ...) onto two lines. The code below is
# preserved byte-for-byte; it is NOT valid Python as-is and must be restored
# from the original upstream source before any logic change.
#
# pipeline(): one chat turn. Routes the user text through `enhancer`
# (NOTE(review): `enhancer` is not defined in this chunk — presumably the
# DialogGen model wrapper; confirm), detects a "<画图>" draw-intent marker in
# the reply, and if present calls image_generation() with the enhanced prompt.
def pipeline(input_text, state, infer_steps, seed, image_size):
# Ignore empty input
if len(input_text) == 0:
return state, state[0]
# state is a [conversation, history_messages] pair held in gr.State.
conversation = state[0]
history_messages = state[1]
# NOTE(review): system_prompt is assigned but never used in the visible code.
system_prompt = "请先判断用户的意图,若为画图则在输出前加入<画图>:"
print(f"input history:{history_messages}")
# skip_special toggles based on whether a non-list history with >= 2 messages
# already exists (i.e. an ongoing conversation).
if not isinstance(history_messages, list) and len(history_messages.messages) >= 2:
response, history_messages = enhancer(
input_text, return_history=True, history=history_messages, skip_special=True
)
else:
response, history_messages = enhancer(
input_text,
return_history=True,
history=history_messages,
skip_special=False,
)
# Overwrite the last history entry with the (possibly post-processed) reply.
history_messages.messages[-1][-1] = response
# The "<画图>" marker in the reply signals draw intent.
if "<画图>" in response:
intention_draw = True
else:
intention_draw = False
print(f"response:{response}")
print("-" * 80)
print(f"history_messages:{history_messages}")
print(f"intention_draw:{intention_draw}")
if intention_draw:
prompt = response.split("<画图>")[-1]
# draw
image_url = image_generation(prompt, infer_steps, seed, image_size)
# GARBLED BELOW: this f-string originally embedded HTML (likely an <img>
# tag with the base64 image_url) that was stripped by extraction.
response = f'
{prompt}
' conversation += [((input_text, response))] return [conversation, history_messages], conversation # 页面设计 def upload_image(state, image_input): conversation = state[0] history_messages = state[1] input_image = Image.open(image_input.name).resize((224, 224)).convert("RGB") input_image.save(image_input.name) # Overwrite with smaller image. system_prompt = "请先判断用户的意图,若为画图则在输出前加入<画图>:" history_messages, response = request_dialogGen( question="这张图描述了什么?", history_messages=history_messages, image=image_input.name, ) conversation += [ ( f'powered by DialogGen and HunyuanDiT
""" ) text_input.submit( pipeline, [text_input, gr_state, infer_steps, seed, size_dropdown], [gr_state, chatbot], ) text_input.submit(lambda: "", None, text_input) # Reset chatbox. submit_btn.click( pipeline, [text_input, gr_state, infer_steps, seed, size_dropdown], [gr_state, chatbot], ) submit_btn.click(lambda: "", None, text_input) # Reset chatbox. # image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot]) clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot]) clear_btn.click(reset, [], [gr_state, chatbot]) interface = demo interface.launch(server_name="0.0.0.0", server_port=443, share=False)