Commit e1deac83 authored by chenpangpang's avatar chenpangpang
Browse files

feat: 更新kolors到新版本,带参考图版本

parents 0f1bd6af 1b51b523
Pipeline #1592 canceled with stages
This diff is collapsed.
...@@ -331,7 +331,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -331,7 +331,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
position_ids=text_inputs['position_ids'], position_ids=text_inputs['position_ids'],
output_hidden_states=True) output_hidden_states=True)
prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
text_proj = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096] pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
bs_embed, seq_len, _ = prompt_embeds.shape bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
...@@ -387,7 +387,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -387,7 +387,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
position_ids=uncond_input['position_ids'], position_ids=uncond_input['position_ids'],
output_hidden_states=True) output_hidden_states=True)
negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
negative_text_proj = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096] negative_pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() # [batch_size, 4096]
if do_classifier_free_guidance: if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
...@@ -409,15 +409,16 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -409,15 +409,16 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
# negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) # negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
negative_prompt_embeds = negative_prompt_embeds_list[0] negative_prompt_embeds = negative_prompt_embeds_list[0]
bs_embed = text_proj.shape[0] bs_embed = pooled_prompt_embeds.shape[0]
text_proj = text_proj.repeat(1, num_images_per_prompt).view( pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
negative_text_proj = negative_text_proj.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1 bs_embed * num_images_per_prompt, -1
) )
if do_classifier_free_guidance:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
return prompt_embeds, negative_prompt_embeds, text_proj, negative_text_proj return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta): def prepare_extra_step_kwargs(self, generator, eta):
......
deepspeed==0.8.1 accelerate
imageio==2.25.1 diffusers
omegaconf==2.3.0 invisible_watermark
diffusers==0.28.2 torch
gradio==3.38.0 transformers
matplotlib-inline sentencepiece
ipython gradio==3.40.0
import os, torch
# from PIL import Image
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler
# Repository root: the parent of the directory that holds this script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(_script_dir)
def infer(prompt):
    """Generate a single 1024x1024 image for ``prompt`` and save it as a JPEG.

    The Kolors components (text encoder, tokenizer, VAE, scheduler, UNet) are
    loaded from ``{root_dir}/weights/Kolors`` on every call, the pipeline is run
    for 50 steps with guidance scale 5.0 and a fixed seed (66), and the first
    image is written to ``{root_dir}/scripts/outputs/sample_test.jpg``.

    Args:
        prompt: Text prompt forwarded to the pipeline.
    """
    ckpt_dir = f'{root_dir}/weights/Kolors'
    text_encoder = ChatGLMModel.from_pretrained(
        f'{ckpt_dir}/text_encoder',
        torch_dtype=torch.float16).half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False)
    pipe = pipe.to("cuda")
    pipe.enable_model_cpu_offload()
    image = pipe(
        prompt=prompt,
        height=1024,
        width=1024,
        num_inference_steps=50,
        guidance_scale=5.0,
        num_images_per_prompt=1,
        generator=torch.Generator(pipe.device).manual_seed(66)).images[0]

    # Create the output directory up front: without it, save() fails on a fresh
    # checkout only after the whole (expensive) generation has already run.
    output_dir = f'{root_dir}/scripts/outputs'
    os.makedirs(output_dir, exist_ok=True)
    image.save(f'{output_dir}/sample_test.jpg')
if __name__ == '__main__':
    # Expose infer() as a command-line entry point via python-fire.
    from fire import Fire

    Fire(infer)
import os
import torch
import gradio as gr
# from PIL import Image
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers import EulerDiscreteScheduler
# Repository root: the parent of the directory that holds this script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(_script_dir)

# Model components and the assembled pipeline; all start unset and are
# assigned by load_models().
text_encoder = tokenizer = vae = scheduler = unet = pipe = None
def remove_folder(path):
    """Recursively delete ``path``; do nothing if it does not exist.

    Regular files and symbolic links are unlinked (links are never followed);
    directories are emptied depth-first and then removed.

    Args:
        path: File, symlink, or directory to delete.

    Raises:
        OSError: Propagated from ``os.remove`` / ``os.listdir`` / ``os.rmdir``
            (for example on insufficient permissions).
    """
    # lexists() is also true for dangling symlinks, whereas exists() is not.
    # With exists(), a dangling link inside the tree was silently skipped and
    # the later rmdir() of its parent failed because the directory was not empty.
    if not os.path.lexists(path):
        return
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
        return
    for filename in os.listdir(path):
        remove_folder(os.path.join(path, filename))
    os.rmdir(path)
def load_models():
    """Load the Kolors components and build the shared pipeline once.

    Assigns the module-level ``text_encoder``, ``tokenizer``, ``vae``,
    ``scheduler``, ``unet`` and ``pipe`` from ``{root_dir}/weights/Kolors``.
    Later calls return immediately once ``pipe`` has been set.
    """
    global text_encoder, tokenizer, vae, scheduler, unet, pipe

    # Guard on the pipeline (the last global assigned) rather than on
    # text_encoder: if loading fails part-way, the next call retries instead of
    # skipping the load and leaving ``pipe`` as None.
    if pipe is not None:
        return

    ckpt_dir = f'{root_dir}/weights/Kolors'

    # Text encoder is placed on the CPU in fp16. The original author's note says
    # this "speeds stuff up 2x"; that figure is not verified here.
    # NOTE(review): pipe.to("cuda") below presumably moves it to the GPU anyway —
    # confirm whether the explicit CPU placement has any effect.
    text_encoder = ChatGLMModel.from_pretrained(
        f'{ckpt_dir}/text_encoder',
        torch_dtype=torch.float16).to('cpu').half()
    tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')

    # VAE and UNet are converted to fp16 and moved to the GPU.
    vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to('cuda')
    scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
    unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to('cuda')

    # Build and configure the pipeline in a local first, so the global is only
    # published once every step has succeeded.
    pipeline = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        force_zeros_for_empty_prompt=False)
    pipeline = pipeline.to("cuda")
    pipeline.enable_model_cpu_offload()  # Enable offloading to balance CPU/GPU usage
    pipe = pipeline
def infer(prompt, use_random_seed, seed, height, width, num_inference_steps, guidance_scale, num_images_per_prompt):
    """Generate images for ``prompt`` and return the paths of the saved files.

    The output directory ``{root_dir}/scripts/outputs`` is wiped and recreated
    on every call, so only the images of the latest request are kept.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        use_random_seed: When truthy, ``seed`` is ignored and a random one is drawn.
        seed: Seed for the torch generator (used when ``use_random_seed`` is falsy).
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_images_per_prompt: How many images to generate.

    Returns:
        List of file paths, one per generated image, in generation order.
    """
    load_models()

    if use_random_seed:
        seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()

    # Gradio documents slider values as floats, while manual_seed() requires an
    # int; the size/step/count arguments are presumably expected to be ints by
    # the pipeline as well, so normalise all integer-valued inputs here.
    seed = int(seed)
    height = int(height)
    width = int(width)
    num_inference_steps = int(num_inference_steps)
    num_images_per_prompt = int(num_images_per_prompt)

    generator = torch.Generator(pipe.device).manual_seed(seed)
    images = pipe(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
        generator=generator
    ).images

    output_dir = f'{root_dir}/scripts/outputs'
    remove_folder(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    first_path = os.path.join(output_dir, 'sample_test.jpg')
    base_name, ext = os.path.splitext(first_path)

    saved_images = []
    for image in images:
        # First image keeps the plain name; later ones get _1, _2, ... suffixes.
        file_path = first_path
        counter = 1
        while os.path.exists(file_path):
            file_path = f"{base_name}_{counter}{ext}"
            counter += 1
        image.save(file_path)
        saved_images.append(file_path)
    return saved_images
def gradio_interface():
    """Build the Gradio Blocks UI for the generator and return it unlaunched.

    The left column holds the prompt and generation controls, the right column
    holds the output gallery; the button runs ``infer`` with the control values.
    """

    def _seed_slider_update(random_enabled):
        # The manual seed slider is only shown while "Use Random Seed" is off.
        return gr.update(visible=not random_enabled)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Kolors: Diffusion Model Gradio Interface")
                prompt_box = gr.Textbox(label="Prompt")
                random_seed_box = gr.Checkbox(label="Use Random Seed", value=True)
                seed_slider = gr.Slider(
                    minimum=0, maximum=2 ** 32 - 1, step=1,
                    label="Seed", randomize=True, visible=False)
                random_seed_box.change(_seed_slider_update, random_seed_box, seed_slider)
                height_slider = gr.Slider(
                    minimum=128, maximum=2048, step=64, label="Height", value=1024)
                width_slider = gr.Slider(
                    minimum=128, maximum=2048, step=64, label="Width", value=1024)
                steps_slider = gr.Slider(
                    minimum=1, maximum=100, step=1, label="Inference Steps", value=50)
                guidance_slider = gr.Slider(
                    minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0)
                count_slider = gr.Slider(
                    minimum=1, maximum=10, step=1, label="Images per Prompt", value=1)
                generate_button = gr.Button("Generate Image")

            with gr.Column():
                gallery = gr.Gallery(label="Output Images", elem_id="output_gallery")

        # Order must match the positional parameters of infer().
        controls = [
            prompt_box,
            random_seed_box,
            seed_slider,
            height_slider,
            width_slider,
            steps_slider,
            guidance_slider,
            count_slider,
        ]
        generate_button.click(fn=infer, inputs=controls, outputs=gallery)

    return demo
if __name__ == '__main__':
    # Load the models before the server starts rather than on the first request.
    load_models()
    demo = gradio_interface()
    demo.launch(server_name='0.0.0.0')
from setuptools import setup, find_packages
# Distribution metadata for the ``kolors`` package; no install-time
# dependencies are declared here.
_METADATA = dict(
    name="kolors",
    version="0.1",
    author="Kolors",
    description="The training and inference code for Kolors models.",
    packages=find_packages(),
    install_requires=[],
    dependency_links=[],
)

setup(**_METADATA)
# Downloads the Kolors checkpoints with the `huggingface-cli` command, which is
# provided by the `huggingface-hub` package (installed/upgraded below).
import os
import subprocess
import sys

# Route Hugging Face traffic through the mirror; child processes inherit this.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

model_list = [
    "Kwai-Kolors/Kolors",
    "Kwai-Kolors/Kolors-IP-Adapter-Plus"
]

# Use the running interpreter's pip so the package lands in this environment.
# A failed upgrade is only a warning: an already-installed CLI may still work.
pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-U", "huggingface-hub"])
if pip_result.returncode != 0:
    print("warning: could not install/upgrade huggingface-hub", file=sys.stderr)

failed_models = []
for model_path in model_list:
    # Argument list instead of a shell string: no quoting/injection surprises.
    command = [
        "huggingface-cli", "download", "--resume-download", model_path,
        "--local-dir", f"./{model_path}",
        "--local-dir-use-symlinks", "False",
    ]
    try:
        returncode = subprocess.run(command).returncode
    except FileNotFoundError:
        sys.exit("huggingface-cli not found on PATH; install huggingface-hub first")
    if returncode != 0:
        # Keep going so the remaining models still get a chance to download.
        failed_models.append(model_path)

# Report failures with a non-zero exit status instead of ignoring them.
if failed_models:
    sys.exit(f"download failed for: {', '.join(failed_models)}")
#!/bin/bash
# Start app.py from the Kolors checkout.
# Abort if the directory is missing rather than running python in the wrong place.
cd /root/Kolors || exit 1
python app.py
...@@ -7,14 +7,16 @@ ...@@ -7,14 +7,16 @@
"tags": [] "tags": []
}, },
"source": [ "source": [
"## 说明\n", "## 项目介绍\n",
"\n",
"- 启动和重启 Notebook 点上方工具栏中的「重启并运行所有单元格」。出现 `http://0.0.0.0:7860` 这个字样就算成功了。可以去控制台打开「自定义服务」了\n",
"- 访问自定义服务端口号设置为7860\n",
"\n",
"## 功能介绍\n",
"- 原项目地址:https://github.com/Kwai-Kolors/Kolors\n", "- 原项目地址:https://github.com/Kwai-Kolors/Kolors\n",
"- 可图大模型是由快手可图团队开发的基于潜在扩散的大规模文本到图像生成模型。Kolors 在数十亿图文对下进行训练,在视觉质量、复杂语义理解、文字生成(中英文字符)等方面,相比于开源/闭源模型,都展示出了巨大的优势。同时,Kolors 支持中英双语,在中文特色内容理解方面更具竞争力。" "- 可图大模型是由快手可图团队开发的基于潜在扩散的大规模文本到图像生成模型。Kolors 在数十亿图文对下进行训练,在视觉质量、复杂语义理解、文字生成(中英文字符)等方面,相比于开源/闭源模型,都展示出了巨大的优势。同时,Kolors 支持中英双语,在中文特色内容理解方面更具竞争力。\n",
"## 使用说明\n",
"- 启动和重启 Notebook 点上方工具栏中的「重启并运行所有单元格」。出现如下内容就算成功了:\n",
" - `Running on local URL: http://0.0.0.0:7860`\n",
" - `Running on public URL: https://xxxxxxxxxxxxxxx.gradio.live`\n",
"- 通过以下方式开启页面:\n",
" - 控制台打开「自定义服务」,访问自定义服务端口号设置为7860\n",
" - 直接打开显示的公开链接`public URL`\n"
] ]
}, },
{ {
...@@ -24,39 +26,34 @@ ...@@ -24,39 +26,34 @@
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"/opt/conda/envs/kolors/lib/python3.8/site-packages/gradio_client/documentation.py:105: UserWarning: Could not get documentation group for <class 'gradio.mix.Parallel'>: No known documentation group for module 'gradio.mix'\n",
" warnings.warn(f\"Could not get documentation group for {cls}: {exc}\")\n",
"/opt/conda/envs/kolors/lib/python3.8/site-packages/gradio_client/documentation.py:105: UserWarning: Could not get documentation group for <class 'gradio.mix.Series'>: No known documentation group for module 'gradio.mix'\n",
" warnings.warn(f\"Could not get documentation group for {cls}: {exc}\")\n",
"/opt/conda/envs/kolors/lib/python3.8/site-packages/diffusers/models/transformers/transformer_2d.py:34: FutureWarning: `Transformer2DModelOutput` is deprecated and will be removed in version 1.0.0. Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead.\n",
" deprecate(\"Transformer2DModelOutput\", \"1.0.0\", deprecation_message)\n",
"Loading checkpoint shards: 100%|██████████████████| 7/7 [00:16<00:00, 2.35s/it]\n",
"Running on local URL: http://0.0.0.0:7860\n",
"\n",
"To create a public link, set `share=True` in `launch()`.\n",
"IMPORTANT: You are using gradio version 3.38.0, however version 4.29.0 is available, please upgrade.\n",
"--------\n",
"100%|███████████████████████████████████████████| 50/50 [00:26<00:00, 1.90it/s]\n"
]
}
],
"source": [ "source": [
"# 启动\n", "# 启动\n",
"!bash run.sh" "!sh start.sh"
] ]
}, },
{
"cell_type": "markdown",
"source": [
"---\n",
"**扫码关注公众号,获取更多资讯**<br>\n",
"<div align=center>\n",
"<img src=\"assets/二维码.jpeg\" width = 20% />\n",
"</div>\n"
],
"metadata": {
"collapsed": false
},
"id": "2f54158c2967bc25"
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"id": "9e81ae9d-3a34-43a0-943a-ff5e9d6ce961",
"metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": [],
"metadata": {
"collapsed": false
},
"id": "6dc59fbbcf222b6b"
} }
], ],
"metadata": { "metadata": {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment