Commit d8864e63 authored by dengjb

update codes

parent ba4486d8
from datetime import datetime
import os
import shutil
from .ui import StableDiffusionUI
from .utils import save_image_info
from .png_info_helper import deserialize_from_filename, InfoFormat
from IPython.display import clear_output, display
import ipywidgets as widgets
from ipywidgets import Layout,HBox,VBox,Box
# HBox/VBox should only be used for single-row/single-column content
# Box should always assume its display property is unspecified
from . import views
from .views import Div, Tab, Bunch
class StableDiffusionUI_img2img(StableDiffusionUI):
def __init__(self, **kwargs):
        super().__init__()  # the pipeline is not handled here for now
CLASS_NAME = self.__class__.__name__ \
+ '_{:X}'.format(hash(self))[-4:]
STYLE_SHEETS = '''
@media (max-width:576px) {
{root} .column_left,
{root} .column_right {
                min-width: 240px !important; /* fix X-axis overflowing the frame */
}
{root} .column_right .panel01 {
                max-height: 500px !important; /* fix Y-axis overflowing the frame */
}
{root} .p-TabPanel-tabContents.widget-tab-contents {
padding: 0.5rem 0.1rem !important;
}
{root} button.run_button,
{root} button.collect_button {
_width: 45% !important;
}
{root} .seed {
                min-width: 6rem;
}
}
        /* strength slider */
{root} .strength.widget-slider > .widget-readout {
border: var(--jp-border-width) solid var(--jp-border-color1) !important;
}
@media (max-width:576px) {
{root} .strength.widget-slider {
flex: 4 4 30% !important;
min-width: 6rem !important;
}
{root} .strength.widget-slider > .slider-container {
                display: none !important; /* cannot be dragged on mobile */
}
}
'''
self.task = 'img2img'
        # Default parameter override order:
        # user_config.py > config.py > these args > views.py
        args = {  # note: invalid keys raise errors
"prompt": '',
"negative_prompt": '',
"width": -1,
"height": -1,
"num_return_images": 1,
"strength": 0.8,
"image_path": '',
"mask_path": '',
"upload_image_path": 'resources/upload.png',
"upload_mask_path": 'resources/upload-mask.png',
}
args.update(kwargs)
        args['num_return_images'] = 1  # batch generation is not supported
# widget_opt = self.widget_opt
        # build the main controls
self._generateControls(args)
        # build the left column
self._renderColumnLeft(args)
        # build the right column
self._renderColumnRight(args)
        # style sheet
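        # Substituting the generated class name for {root} scopes these rules to this instance.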
STYLE_SHEETS = ('<style>' \
+ views.SHARED_STYLE_SHEETS \
+ STYLE_SHEETS \
+ self.view_prompts.style_sheets \
+ self.view_width_height.style_sheets \
+ '</style>'
).replace('{root}', '.' + CLASS_NAME)
html_css = widgets.HTML(STYLE_SHEETS)
html_css.layout.display = 'none'
box_gui = Box([
html_css,
self._column_left,
self._column_right,
],
layout = Layout(
display = "flex",
                flex_flow = "row wrap", # HBox would overwrite this property
max_width = '100%',
),
)
box_gui.add_class(CLASS_NAME)
clear_output()
self.gui = box_gui
    # Build the main controls
def _generateControls(self, args):
widget_opt = self.widget_opt
        # prompt section
view_prompts = views.createPromptsView(
value = args['prompt'],
negative_value = args['negative_prompt'],
)
widget_opt['prompt'] = view_prompts['prompt']
widget_opt['negative_prompt'] = view_prompts['negative_prompt']
self.view_prompts = view_prompts
        # image size section
view_width_height = views.createWidthHeightView(
width_value = args['width'],
            height_value = args['height'],
)
widget_opt['width'] = view_width_height['width']
widget_opt['height'] = view_width_height['height']
self.view_width_height = view_width_height
        # strength
widget_opt['strength'] = widgets.FloatSlider(
style={
'description_width': "4rem"
},
description='修改强度',
description_tooltip='修改图片的强度',
value=args['strength'],
min=0.01,
max=0.99,
step=0.01,
readout=True,
readout_format='.2f',
orientation='horizontal',
disabled=False,
continuous_update=False
)
widget_opt['strength'].add_class('strength')
views.setLayout('col08', widget_opt['strength'])
for key in (
'num_return_images',
'num_inference_steps',
'guidance_scale',
'seed',
'output_dir',
'sampler',
'model_name',
'concepts_library_dir',
):
widget_opt[key] = views.createView(key)
if key in args:
widget_opt[key].value = args[key]
for key in (
'enable_parsing',
'max_embeddings_multiples',
'superres_model_name',
'fp16',
):
widget_opt[key] = views.createView(
key,
layout_name = 'col06',
)
if key in args:
widget_opt[key].value = args[key]
widget_opt['seed'].layout.min_width = '8rem'
        # two buttons
self.run_button = views.createView('run_button')
self.collect_button = views.createView('collect_button')
self._output_collections = []
self.collect_button.on_click(self.on_collect_button_click)
self.run_button.on_click(self.on_run_button_click)
        # bind event handlers
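        # A fixed seed forces num_return_images back to 1, and asking for more than
        # one image resets the seed to -1, keeping the two options mutually exclusive.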
def on_seed_change(change):
if change.new != -1:
widget_opt['num_return_images'].value = 1
def on_num_return_images(change):
if change.new != 1:
widget_opt['seed'].value = -1
widget_opt['seed'].observe(on_seed_change, names='value')
widget_opt['num_return_images'].observe(on_num_return_images, names='value')
    # Build the views
def _renderColumnRight(self, args):
widget_opt = self.widget_opt
views.setLayout('col12', self.view_prompts.container)
widget_opt['sampler'].layout.min_width = '10rem'
_panel_layout = Layout(
display = 'flex',
            flex_flow = 'row wrap', # HBox would overwrite this property
max_width = '100%',
min_height = '360px',
max_height = '455px',
# height = '455px',
align_items = 'center',
align_content = 'center',
)
panel01 = Box(
layout = _panel_layout,
children = (
self.view_prompts.container,
widget_opt['strength'],
widget_opt['seed'],
widget_opt['num_inference_steps'],
widget_opt['guidance_scale'],
widget_opt['sampler'],
widget_opt['model_name'],
),
)
panel02 = Box(
layout = _panel_layout,
children = (
self.view_width_height.container,
widget_opt['superres_model_name'],
widget_opt['fp16'],
widget_opt['enable_parsing'],
widget_opt['max_embeddings_multiples'],
widget_opt['output_dir'],
widget_opt['concepts_library_dir'],
),
)
panel03 = Box(
layout = _panel_layout,
children = (
self.run_button_out,
),
)
self.run_button_out.layout.width = '100%'
self.run_button_out.layout.margin = '0'
        self.run_button_out.layout.align_self = 'stretch'
panel01.add_class('panel01')
tab_right = Tab(
titles = ('参数','其他','输出'),
children = (
panel01,
panel02,
panel03,
),
layout = Layout(
# flex = '1 1 360px',
margin = '0',
)
)
column_right = Div(
children = [
tab_right,
HBox(
(self.run_button, self.collect_button,),
layout = Layout(
justify_content = 'space-around',
                        align_content = 'center',
height = '45px',
)
),
],
layout = Layout(
flex = '1 1 300px',
min_width = '300px',
margin = '0.5rem 0',
)
)
column_right.add_class('column_right')
self._column_right = column_right
self._tab_right = tab_right
def _renderColumnLeft(self, args):
widget_opt = self.widget_opt
#--------------------------------------------------
# ImportPanel
view_upload = _createUploadView(
label = '输入图片',
tooltip = '选择一张图片开始图生图',
default_path = args['image_path'],
upload_path = args['upload_image_path'],
text = '选择一张图片作为图生图的原始图片。你可以选择云端的文件,或者上传图片。',
)
view_upload_mask = _createUploadView(
label = '蒙版图片',
tooltip = '选择一张图片限定重绘范围',
default_path = args['mask_path'],
upload_path = args['upload_mask_path'],
text = '选择一张图片作为蒙版,用白色区域限定图片重绘的范围。折叠此面板时不启用蒙版。',
)
widget_opt['image_path'] = view_upload.input
widget_opt['mask_path'] = view_upload_mask.input
self.uploador = view_upload.uploador
btn_confirm = widgets.Button(
description='导 入',
disabled=False,
button_style='success',
layout=Layout(
flex = '0 1 auto',
),
)
btn_reset = widgets.Button(
description='重 置',
disabled=False,
button_style='warning',
layout=Layout(
flex = '0 1 auto',
),
)
btn_reset.add_class('btnV5')
btn_reset.add_class('btn-small')
btn_confirm.add_class('btnV5')
btn_confirm.add_class('btn-small')
accordion = widgets.Accordion([
view_upload_mask.container,
],
layout = Layout(
margin = '0.5rem 0',
max_width = '100%',
)
)
accordion.set_title(0, '启用蒙版')
accordion.selected_index = None
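        # The mask panel starts collapsed; the mask is only used while it is expanded
        # (see whether_use_mask below).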
panel_import = HBox(
layout = Layout(
height = '100%',
max_width = '100%',
max_height = '500px',
min_height = '360px',
align_items = "center",
justify_content = 'center',
),
children = [
Div([
view_upload.container,
accordion,
HBox(
(btn_confirm, btn_reset),
layout = Layout(
justify_content = 'space-around',
max_width = '100%',
)
),
]
),
],
)
#--------------------------------------------------
        # other panels and the tab
view_image = createPanelImage(args['image_path'])
view_image_mask = createPanelImage(args['mask_path'])
view_image_output = createPanelImage()
tab_left = Tab(
titles = ('导入','原图','蒙版','输出'),
children = (
panel_import,
view_image.container,
view_image_mask.container,
view_image_output.container,
),
layout = Layout(
flex = '1 1 300px',
# max_width = '360px',
min_width = '300px',
margin = '0.5rem 0',
)
)
tab_left.add_class('column_left')
#--------------------------------------------------
        # event handling
def whether_use_mask():
return accordion.selected_index == 0
def on_reset_button_click(b):
with self.run_button_out:
view_upload.reset()
view_upload_mask.reset()
        def on_confirm_button_click(b):
with self.run_button_out:
path = view_upload.confirm()
if not view_image.set_file(path): raise IOError('未能读取文件:'+path)
                # try to update the prompt info from the image
self._update_prompt_from_image(path)
self._tab_right.selected_index = 0
if whether_use_mask():
path = view_upload_mask.confirm()
if not view_image_mask.set_file(path): raise IOError('未能读取文件:'+path)
tab_left.selected_index = 2
else:
view_image_mask.set_file()
tab_left.selected_index = 1
view_image_output.set_file()
return
btn_reset.on_click(on_reset_button_click)
        btn_confirm.on_click(on_confirm_button_click)
self.is_inpaint_task = whether_use_mask
self._column_left = tab_left
self._tab_left = tab_left
        self._set_output_image = view_image_output.set_file  # not a class, so no self is needed
return
def on_collect_button_click(self, b):
with self.run_button_out:
dir = datetime.now().strftime(f'Favorates/{self.task}-%m%d/')
info = '收藏图片到 ' + dir
dir = './' + dir
os.makedirs(dir, exist_ok=True)
for file in self._output_collections:
if os.path.isfile(file):
shutil.move(file, dir)
print(info + os.path.basename(file))
file = file[:-4] + '.txt'
if os.path.isfile(file):
shutil.move(file, dir)
self._output_collections.clear()
self.collect_button.disabled = True
def on_run_button_click(self, b):
with self.run_button_out:
self._output_collections.clear()
self.collect_button.disabled = True
self.run_button.disabled = True
self.task = 'img2img' if not self.is_inpaint_task() else 'inpaint'
self._tab_left.selected_index = 1
self._tab_right.selected_index = 2
try:
super().on_run_button_click(b)
finally:
self.run_button.disabled = False
self.collect_button.disabled = len(self._output_collections) < 1
def on_image_generated(self, image, options, count = 0, total = 1, image_info = None):
image_path = save_image_info( image, options.output_dir, image_info)
self._output_collections.append(image_path)
self._set_output_image(image_path)
self._tab_left.selected_index = 3
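        # clear the log every 5 images so the output area stays short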
if count % 5 == 0:
clear_output()
print('> Seed = ' + str(image.argument["seed"]))
print('> ' + image_path)
print(' (%d / %d ... %.2f%%)'%(count + 1, total, (count + 1.) / total * 100))
def _update_prompt_from_image(self, path):
info, fmt = deserialize_from_filename(path)
if fmt is InfoFormat.Unknown: return False
for key in ('prompt','negative_prompt', 'seed'):
if key in info:
self.widget_opt[key].value = info[key]
if 'max_embeddings_multiples' in info:
self.widget_opt['max_embeddings_multiples'].value = str(info['max_embeddings_multiples'])
        # check the bracket format
if fmt is InfoFormat.WebUI:
self.widget_opt['enable_parsing'].value = '圆括号 () 加强权重'
elif fmt is InfoFormat.NAIFU:
self.widget_opt['enable_parsing'].value = '花括号 {} 加权权重'
elif 'prompt' in info:
c1 = info['prompt'].count('(') + info['prompt'].count(')')
c2 = info['prompt'].count('{') + info['prompt'].count('}')
self.widget_opt['enable_parsing'].value = '花括号 {} 加权权重' if (c2 > c1 + 1) else '圆括号 () 加强权重'
return True
def _createUploadView(
label = '输入图片',
tooltip = '需要转换的图片的路径',
default_path = 'resources/Ring.png',
upload_path = 'resources/upload.png',
text = ''):
input = widgets.Text(
style={ 'description_width': "4rem" },
description = label,
description_tooltip = tooltip,
value = default_path,
)
upload = widgets.FileUpload(
accept = '.png,.jpg,.jpeg',
description = '上传图片',
layout = Layout(
padding = '0.5rem',
height = 'auto',
)
)
description = widgets.HTML(
text,
)
views.setLayout('col08', input)
views.setLayout('col12', upload)
views.setLayout('col12', description)
input.layout.margin = '0'
container = Box([
description,
input,
upload,
], layout = Layout(
display = 'flex',
flex_flow = 'row wrap',
max_width = '100%',
))
# views.setLayout('col12', input_image_path)
# input_image_path.add_class('image_path')
def reset():
input.value = default_path
try:
upload.value = ()
except:
pass
def confirm():
        # NOTE: the structure of upload.value differs between ipywidgets 8.0 and 7.5
for name in upload.value:
dict = upload.value[name]
            # check the file type
path = upload_path
if dict['metadata']['type'] == 'image/jpeg':
path = upload_path.partition('.')[0] + '.jpg'
elif dict['metadata']['type'] == 'image/png':
path = upload_path.partition('.')[0] + '.png'
print('保存上传到:'+path)
with open(path, 'wb') as file:
file.write(dict['content'])
upload.value.clear()
input.value = path
break
return input.value
return Bunch({
'container': container,
'input': input,
'reset': reset,
'confirm': confirm,
'uploador': upload,
})
def createPanelImage(filename = None):
layout = Layout(
object_fit = 'contain',
#object_position = 'center center',
margin = '0 0 0 0',
max_height = '500px',
)
_None_Image = widgets.HTML('未选中图片或无效的图片')
container = HBox(
layout = Layout(
max_height = '500px',
min_height = '360px',
align_items = 'center',
            align_content = 'center',
justify_content = 'center',
),
)
def set_file(filename = None):
if filename is None or not os.path.isfile(filename):
if _None_Image not in container.children:
container.children = (_None_Image,)
return False
else:
img = widgets.Image.from_file(filename)
img.layout = layout
container.children = (img,)
return True
set_file(filename)
return Bunch({
'container': container,
'set_file': set_file,
})
from datetime import datetime
import os
import shutil
from .ui import StableDiffusionUI
from .utils import save_image_info
from IPython.display import clear_output, display
import ipywidgets as widgets
from ipywidgets import Layout,HBox,VBox,Box
from . import views
class StableDiffusionUI_txt2img(StableDiffusionUI):
def __init__(self, **kwargs):
        super().__init__()  # the pipeline is not handled here for now
CLASS_NAME = self.__class__.__name__ \
+ '_{:X}'.format(hash(self))[-4:]
STYLE_SHEETS = '''
@media (max-width:576px) {
{root} .standard_size,
{root} .superres_model_name {
order: -1;
}
{root} button.run_button,
{root} button.collect_button {
width: 45% !important;
}
}
'''
        # Default parameter override order:
        # user_config.py > config.py > these args > views.py
        args = {  # note: invalid keys raise errors
"prompt": '',
"negative_prompt": '',
"width": 512,
"height": 512,
}
args.update(kwargs)
widget_opt = self.widget_opt
        # prompt section
view_prompts = views.createPromptsView(
value = args['prompt'],
negative_value = args['negative_prompt'],
)
widget_opt['prompt'] = view_prompts['prompt']
widget_opt['negative_prompt'] = view_prompts['negative_prompt']
        # image size section
view_width_height = views.createWidthHeightView(
width_value = args['width'],
height_value = args['height'],
step64 = True,
)
widget_opt['width'] = view_width_height['width']
widget_opt['height'] = view_width_height['height']
for key in (
'standard_size',
'num_return_images',
'enable_parsing',
'num_inference_steps',
'guidance_scale',
'max_embeddings_multiples',
'fp16',
'seed',
'superres_model_name',
'output_dir',
'sampler',
'model_name',
'concepts_library_dir'
):
widget_opt[key] = views.createView(key)
if key in args:
widget_opt[key].value = args[key]
        # bind event handlers
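        # standard_size packs both dimensions into a single integer
        # (width * 10000 + height), so they are recovered with // and %.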
def on_standard_size_change(change):
widget_opt['width'].value = change.new // 10000
widget_opt['height'].value = change.new % 10000
widget_opt['standard_size'].observe(
on_standard_size_change,
names = 'value'
)
def on_seed_change(change):
if change.new != -1:
widget_opt['num_return_images'].value = 1
def on_num_return_images(change):
if change.new != 1:
widget_opt['seed'].value = -1
widget_opt['seed'].observe(on_seed_change, names='value')
widget_opt['num_return_images'].observe(on_num_return_images, names='value')
        # two buttons
self.run_button = views.createView('run_button')
self.collect_button = views.createView('collect_button')
self._output_collections = []
self.run_button.on_click(self.on_run_button_click)
self.collect_button.on_click(self.on_collect_button_click)
        # style sheet
STYLE_SHEETS = ('<style>' \
+ views.SHARED_STYLE_SHEETS \
+ STYLE_SHEETS \
+ view_prompts.style_sheets \
+ view_width_height.style_sheets \
+ '</style>'
).replace('{root}', '.' + CLASS_NAME)
#
self.gui = views.createView("box_gui",
class_name = CLASS_NAME,
children = [
widgets.HTML(STYLE_SHEETS),
view_prompts.container,
views.createView("box_main",
[
widget_opt['standard_size'],
view_width_height.container,
widget_opt['superres_model_name'],
widget_opt['num_inference_steps'],
widget_opt['guidance_scale'],
widget_opt['sampler'],
widget_opt['num_return_images'],
widget_opt['seed'],
widget_opt['enable_parsing'],
widget_opt['max_embeddings_multiples'],
widget_opt['fp16'],
widget_opt['model_name'],
widget_opt['output_dir'],
widget_opt['concepts_library_dir']
]),
HBox(
(self.run_button,self.collect_button,),
layout = Layout(
justify_content = 'space-around',
max_width = '100%',
)
),
self.run_button_out
],
)
def on_collect_button_click(self, b):
with self.run_button_out:
dir = datetime.now().strftime(f'Favorates/{self.task}-%m%d/')
info = '收藏图片到 ' + dir
dir = './' + dir
os.makedirs(dir, exist_ok=True)
for file in self._output_collections:
if os.path.isfile(file):
shutil.move(file, dir)
print(info + os.path.basename(file))
file = file[:-4] + '.txt'
if os.path.isfile(file):
shutil.move(file, dir)
self._output_collections.clear()
self.collect_button.disabled = True
def on_run_button_click(self, b):
with self.run_button_out:
self._output_collections.clear()
self.collect_button.disabled = True
self.run_button.disabled = True
try:
super().on_run_button_click(b)
finally:
self.run_button.disabled = False
self.collect_button.disabled = len(self._output_collections) < 1
def on_image_generated(self, image, options, count = 0, total = 1, image_info = None):
image_path = save_image_info(image, options.output_dir)
self._output_collections.append(image_path)
if count % 5 == 0:
clear_output()
try:
            # display the saved file so the shown image carries the embedded info
display(widgets.Image.from_file(image_path))
except:
display(image)
print('Seed = ', image.argument['seed'],
' (%d / %d ... %.2f%%)'%(count + 1, total, (count + 1.) / total * 100))
from .ui import StableDiffusionUI
import ipywidgets as widgets
from ipywidgets import Layout,HBox,VBox,Box
class SuperResolutionUI(StableDiffusionUI):
def __init__(self, pipeline, **kwargs):
super().__init__(pipeline = pipeline)
self.task = 'superres'
        # Default parameter override order:
        # user_config.py > config.py > these args > views.py
        args = {  # note: invalid keys raise errors
"image_path": 'resources/image_Kurisu.png',
"superres_model_name": 'falsr_a',
"output_dir": 'outputs/highres',
}
args.update(kwargs)
widget_opt = self.widget_opt
layoutCol12 = Layout(
flex = "12 12 90%",
margin = "0.5em",
align_items = "center"
)
styleDescription = {
'description_width': "9rem"
}
widget_opt['image_path'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='需要超分的图片路径' ,
value=args['image_path'],
disabled=False
)
widget_opt['superres_model_name'] = widgets.Dropdown(
layout=layoutCol12, style=styleDescription,
description='超分模型的名字',
value=args['superres_model_name'],
options=["falsr_a", "falsr_b", "falsr_c"],
disabled=False
)
widget_opt['output_dir'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='图片的保存路径',
value=args['output_dir'],
disabled=False
)
self.run_button = widgets.Button(
description='点击超分图片!',
disabled=False,
button_style='success', # 'success', 'info', 'warning', 'danger' or ''
tooltip='Click to run (settings will update automatically)',
icon='check'
)
self.run_button.on_click(self.on_run_button_click)
self.gui = widgets.Box([
widget_opt['image_path'],
widget_opt['superres_model_name'],
widget_opt['output_dir'],
self.run_button,
self.run_button_out
], layout = Layout(
display = "flex",
            flex_flow = "row wrap", # HBox would overwrite this property
align_items = "center",
# max_width = '100%',
margin="0 45px 0 0"
))
from .env import DEBUG_UI
from .config import config
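# In debug mode only the txt2img/img2img GUIs are built; the full set
# (super-resolution, training and conversion UIs) requires the real pipelines.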
if DEBUG_UI:
print('==================================================')
print('调试环境')
print('==================================================')
from .StableDiffusionUI_txt2img import StableDiffusionUI_txt2img
from .StableDiffusionUI_img2img import StableDiffusionUI_img2img
gui_txt2img = StableDiffusionUI_txt2img(**config['txt2img'])
gui_img2img = StableDiffusionUI_img2img(**config['img2img'])
gui_inpaint = gui_img2img
else:
from .ui import (
StableDiffusionUI_text_inversion,
StableDiffusionUI_dreambooth,
pipeline_superres,
pipeline,
StableDiffusionUI_convert,
)
from .StableDiffusionUI_txt2img import StableDiffusionUI_txt2img
from .StableDiffusionUI_img2img import StableDiffusionUI_img2img
from .SuperResolutionUI import SuperResolutionUI
gui_txt2img = StableDiffusionUI_txt2img(
**config['txt2img']
)
gui_img2img = StableDiffusionUI_img2img(
**config['img2img']
)
gui_superres = SuperResolutionUI(
pipeline = pipeline_superres,
**config['superres']
)
gui_train_text_inversion = StableDiffusionUI_text_inversion(
**config['train_text_inversion']
)
gui_text_inversion = StableDiffusionUI_txt2img(
**config['text_inversion']
)
gui_dreambooth = StableDiffusionUI_dreambooth( #dreamboothUI
**config['dreambooth']
)
gui_convert = StableDiffusionUI_convert(
**config['convert']
)
gui_inpaint = gui_img2img
config = {
"txt2img": {
"prompt": 'extremely detailed CG unity 8k wallpaper,black long hair,cute face,1 adult girl,happy, green skirt dress, flower pattern in dress,solo,green gown,art of light novel,in field',
"negative_prompt": 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry',
"width": 512,
"height": 512,
# "seed": -1,
# "num_return_images": 1,
# "num_inference_steps": 50,
# "guidance_scale": 7.5,
# "fp16": 'float16',
# "superres_model_name": '无',
# "max_embeddings_multiples": '3',
# "enable_parsing": '圆括号 () 加强权重',
# "sampler": 'default',
"model_name": 'MoososCap/NOVEL-MODEL',
"output_dir": 'outputs/txt2img',
},
"img2img": {
"prompt": 'red dress',
"negative_prompt": 'lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry',
"width": -1,
"height": -1,
"num_return_images": 1,
"strength": 0.8,
"model_name": 'MoososCap/NOVEL-MODEL',
"image_path": 'resources/cat2.jpg',
"mask_path": 'resources/mask8.jpg',
"output_dir": 'outputs/img2img',
},
"superres": {
"image_path": 'resources/image_Kurisu.png',
"superres_model_name": 'falsr_a',
"output_dir": 'outputs/highres',
},
"train_text_inversion": {
"learnable_property": 'object',
"placeholder_token": '<Alice>',
"initializer_token": 'girl',
"repeats": '100',
"train_data_dir": 'resources/Alices',
"output_dir": 'outputs/textual_inversion',
"height": 512,
"width": 512,
"learning_rate": 5e-4,
"max_train_steps": 1000,
"save_steps": 200,
"model_name": "MoososCap/NOVEL-MODEL",
},
"text_inversion": {
"width": 512,
"height": 512,
"prompt": '<Alice> at the lake',
"negative_prompt": '',
"output_dir": 'outputs/text_inversion_txt2img',
},
"dreambooth": { #Dreambooth配置段
"pretrained_model_name_or_path": "MoososCap/NOVEL-MODEL",
"instance_data_dir": 'resources/Alices',
"instance_prompt": 'a photo of Alices',
"class_data_dir": 'resources/Girls',
"class_prompt": 'a photo of girls',
"num_class_images": 100,
"prior_loss_weight": 1.0,
"with_prior_preservation": True,
#"num_train_epochs": 1,
"max_train_steps": 1000,
"save_steps": 1000,
"train_text_encoder": False,
"height": 512,
"width": 512,
"learning_rate": 5e-6,
"lr_scheduler": "constant",
"lr_warmup_steps": 500,
"center_crop": True,
"output_dir": 'outputs/dreambooth',
},
"convert": {
"checkpoint_path": '',
'dump_path': 'outputs/convert'
},
}
try:
from user_config import config as _config
for k in _config:
if k in config:
config[k].update(_config[k])
else:
config[k] = _config[k]
except:
pass
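# A hypothetical user_config.py sketch (not shipped with this repo): any section
# it defines is merged on top of the defaults above.
#
#     config = {
#         "txt2img": {"output_dir": 'outputs/txt2img-mine'},
#         "img2img": {"strength": 0.6},
#     }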
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import io
import os
import pickle
from functools import lru_cache
import numpy as np
from _io import BufferedReader
MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30
import argparse
import numpy as np
import paddle
import pickle
from functools import lru_cache
from paddlenlp.utils.downloader import get_path_from_url
try:
from omegaconf import OmegaConf
except ImportError:
raise ImportError(
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
)
from paddlenlp.transformers import CLIPTextModel, CLIPTokenizer
from ppdiffusers import (
AutoencoderKL,
DDIMScheduler,
EulerAncestralDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
HeunDiscreteScheduler,
EulerDiscreteScheduler,
DPMSolverMultistepScheduler
)
import io
import os
import pickle
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import pickle
from functools import lru_cache
import numpy as np
from zipfile import ZipFile
from typing import Union
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
# When using encoding='bytes' in Py3, some **internal** keys stored as
# strings in Py2 are loaded as bytes. This function decodes them with
# ascii encoding, one that Py3 uses by default.
#
    # NOTE: This should only be used on internal keys (e.g., `typename` and
    # `location` in `persistent_load` below!)
if isinstance(bytes_str, bytes):
return bytes_str.decode('ascii')
return bytes_str
@lru_cache(maxsize=None)
def _storage_type_to_dtype_to_map():
"""convert storage type to numpy dtype"""
return {
"DoubleStorage": np.double,
"FloatStorage": np.float32,
"HalfStorage": np.half,
"LongStorage": np.int64,
"IntStorage": np.int32,
"ShortStorage": np.int16,
"CharStorage": np.int8,
"ByteStorage": np.uint8,
"BoolStorage": np.bool_,
"ComplexDoubleStorage": np.cdouble,
"ComplexFloatStorage": np.cfloat,
}
class StorageType:
"""Temp Class for Storage Type"""
def __init__(self, name):
self.dtype = _storage_type_to_dtype_to_map()[name]
def __str__(self):
return f"StorageType(dtype={self.dtype})"
def _element_size(dtype: str) -> int:
"""
Returns the element size for a dtype, in bytes
"""
if dtype in [np.float16, np.float32, np.float64]:
return np.finfo(dtype).bits >> 3
elif dtype == np.bool_:
return 1
else:
return np.iinfo(dtype).bits >> 3
class UnpicklerWrapperStage(pickle.Unpickler):
def find_class(self, mod_name, name):
if type(name) is str and "Storage" in name:
try:
return StorageType(name)
except KeyError:
pass
# pure torch tensor builder
if mod_name == "torch._utils":
return _rebuild_tensor_stage
# pytorch_lightning tensor builder
if "pytorch_lightning" in mod_name:
return dumpy
return super().find_class(mod_name, name)
def _rebuild_tensor_stage(storage, storage_offset, size, stride, requires_grad, backward_hooks):
    # if a tensor has shape [M, N] and stride is [1, M], it's column-wise / Fortran-style
    # if a tensor has shape [M, N] and stride is [N, 1], it's row-wise / C-style
    # defaults to C-style
if (
stride is not None
and len(stride) > 1
and stride[0] == 1
and stride[1] > 1
):
order = "F"
else:
order = "C"
return storage.reshape(size, order=order)
def dumpy(*args, **kwargs):
return None
def load_safe(path):
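    # safetensors checkpoints deserialize directly into a dict of numpy arrays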
from safetensors.numpy import load
with open(path, "rb") as f:
data = f.read()
loaded = load(data)
return loaded
def load_torch(path: str, **pickle_load_args):
"""
load torch weight file with the following steps:
1. load the structure of pytorch weight file
2. read the tensor data and re-construct the state-dict
Args:
path: the path of pytorch weight file
**pickle_load_args: args of pickle module
    Returns:
        the deserialized object (typically the checkpoint state dict)
    """
pickle_load_args.update({"encoding": "utf-8"})
torch_zip = ZipFile(path, "r")
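    # A PyTorch .pt/.ckpt file is a zip archive: archive/data.pkl holds the pickled
    # object graph and archive/data/<key> holds each tensor's raw storage bytes.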
loaded_storages = {}
def load_tensor(dtype, numel, key, location):
name = f'archive/data/{key}'
typed_storage = np.frombuffer(torch_zip.open(name).read()[:numel], dtype=dtype)
return typed_storage
def persistent_load(saved_id):
assert isinstance(saved_id, tuple)
typename = _maybe_decode_ascii(saved_id[0])
data = saved_id[1:]
assert typename == 'storage', \
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, numel = data
dtype = storage_type.dtype
if key in loaded_storages:
typed_storage = loaded_storages[key]
else:
nbytes = numel * _element_size(dtype)
typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
loaded_storages[key] = typed_storage
return typed_storage
data_iostream = torch_zip.open("archive/data.pkl").read()
unpickler_stage = UnpicklerWrapperStage(io.BytesIO(data_iostream), **pickle_load_args)
unpickler_stage.persistent_load = persistent_load
result = unpickler_stage.load()
torch_zip.close()
return result
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming
to them. It splits attention layers, and takes into account additional replacements
that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = np.split(old_tensor, 3, axis=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def create_unet_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim = [5, 10, 20, 20]
config = dict(
sample_size=image_size // vae_scale_factor,
in_channels=unet_params.in_channels,
out_channels=unet_params.out_channels,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_res_blocks,
cross_attention_dim=unet_params.context_dim,
attention_head_dim=head_dim,
use_linear_projection=use_linear_projection,
)
return config
def create_vae_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = dict(
sample_size=image_size,
in_channels=vae_params.in_channels,
out_channels=vae_params.out_ch,
down_block_types=tuple(down_block_types),
up_block_types=tuple(up_block_types),
block_out_channels=tuple(block_out_channels),
latent_channels=vae_params.z_channels,
layers_per_block=vae_params.num_res_blocks,
)
return config
def create_diffusers_schedular(original_config):
schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
beta_schedule="scaled_linear",
)
return schedular
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] if len(checkpoint["state_dict"]) > 25 else checkpoint
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
unet_key = "model.diffusion_model."
    # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
print(
"我们将提取EMA版的UNET权重。如果你想使用非EMA版权重进行微调的话,请确保将『是否提取ema权重』选项设置为 『否』!"
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
print(
"我们将提取非EMA版的UNET权重。如果你想使用EMA版权重的话,请确保将『是否提取ema权重』选项设置为 『是』"
)
if extract_ema and len(unet_state_dict) == 0:
print("由于我们在CKPT中未找到EMA权重,因此我们将不会『提取ema权重』!")
    # if no EMA weights were found, fall back to the regular (non-EMA) weights
if len(unet_state_dict) == 0:
for key in keys:
if "model_ema" in key: continue
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
if len(unet_state_dict) == 0:
return None
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
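        # each resolution level holds layers_per_block resnets plus a downsample
        # block, hence the (layers_per_block + 1) grouping of input blocks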
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = renew_resnet_paths(resnet_0)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
resnet_1_paths = renew_resnet_paths(resnet_1)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if ["conv.weight", "conv.bias"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
elif ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = unet_state_dict[old_path]
return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config, only_vae=False):
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] if len(checkpoint["state_dict"]) > 25 else checkpoint
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
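    # a standalone VAE checkpoint has no "first_stage_model." prefix, so it is used as-is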
if only_vae:
vae_state_dict = checkpoint
if len(vae_state_dict) == 0:
return None
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
def convert_diffusers_vae_unet_to_ppdiffusers(vae_or_unet, diffusers_vae_unet_checkpoint, dtype="float32"):
need_transpose = []
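    # paddle.nn.Linear stores its weight as (in_features, out_features), the transpose
    # of the torch/diffusers layout, so Linear weights need to be transposed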
for k, v in vae_or_unet.named_sublayers(include_self=True):
if isinstance(v, paddle.nn.Linear):
need_transpose.append(k + ".weight")
new_vae_or_unet = {}
for k, v in diffusers_vae_unet_checkpoint.items():
if k not in need_transpose:
new_vae_or_unet[k] = v.astype(dtype)
else:
new_vae_or_unet[k] = v.T.astype(dtype)
return new_vae_or_unet
def check_keys(model, state_dict):
cls_name = model.__class__.__name__
missing_keys = []
mismatched_keys = []
for k, v in model.state_dict().items():
if k not in state_dict.keys():
missing_keys.append(k)
else:
if list(v.shape) != list(state_dict[k].shape):
mismatched_keys.append(k)
if len(missing_keys):
missing_keys_str = ", ".join(missing_keys)
print(f"{cls_name} Found missing_keys {missing_keys_str}!")
if len(mismatched_keys):
mismatched_keys_str = ", ".join(mismatched_keys)
print(f"{cls_name} Found mismatched_keys {mismatched_keys_str}!")
def convert_hf_clip_to_ppnlp_clip(checkpoint, dtype="float32"):
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] if len(checkpoint["state_dict"]) > 25 else checkpoint
clip = {}
for key in checkpoint.keys():
if key.startswith("cond_stage_model.transformer"):
newkey = key[len("cond_stage_model.transformer.") :]
if not newkey.startswith("text_model."):
newkey = "text_model." + newkey
clip[newkey] = checkpoint[key]
if len(clip) == 0:
return None, None
new_model_state = {}
transformers2ppnlp = {
".encoder.": ".transformer.",
".layer_norm": ".norm",
".mlp.": ".",
".fc1.": ".linear1.",
".fc2.": ".linear2.",
".final_layer_norm.": ".ln_final.",
".embeddings.": ".",
".position_embedding.": ".positional_embedding.",
".patch_embedding.": ".conv1.",
"visual_projection.weight": "vision_projection",
"text_projection.weight": "text_projection",
".pre_layrnorm.": ".ln_pre.",
".post_layernorm.": ".ln_post.",
".vision_model.": ".",
}
ignore_value = ["position_ids"]
donot_transpose = ["embeddings", "norm", "concept_embeds", "special_care_embeds"]
for name, value in clip.items():
# step1: ignore position_ids
if any(i in name for i in ignore_value):
continue
# step2: transpose nn.Linear weight
if value.ndim == 2 and not any(i in name for i in donot_transpose):
value = value.T
# step3: hf_name -> ppnlp_name mapping
for hf_name, ppnlp_name in transformers2ppnlp.items():
name = name.replace(hf_name, ppnlp_name)
# step4: 0d tensor -> 1d tensor
if name == "logit_scale":
value = value.reshape((1,))
new_model_state[name] = value.astype(dtype)
new_config = {
"max_text_length": new_model_state["text_model.positional_embedding.weight"].shape[0],
"vocab_size": new_model_state["text_model.token_embedding.weight"].shape[0],
"text_embed_dim": new_model_state["text_model.token_embedding.weight"].shape[1],
"text_heads": 12,
"text_layers": 12,
"text_hidden_act": "quick_gelu",
"projection_dim": 768,
"initializer_range": 0.02,
"initializer_factor": 1.0,
}
return new_model_state, new_config
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_path", default=None, type=str, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--vae_checkpoint_path", default=None, type=str, help="Path to the vae checkpoint to convert."
)
# !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
parser.add_argument(
"--original_config_file",
default=None,
type=str,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--num_in_channels",
default=None,
type=int,
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
)
parser.add_argument(
"--scheduler_type",
default="pndm",
type=str,
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
)
parser.add_argument(
"--extract_ema",
action="store_true",
help=(
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
),
)
parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.")
args = parser.parse_known_args()[0]
return args
def main(args):  # main entry point
if args.checkpoint_path.strip() == "":
print("ckpt or safetensors 模型文件位置不能为空!")
return
if not os.path.exists(args.checkpoint_path):
print(f"{args.checkpoint_path} 文件不存在,请检查是否存在!")
return
if args.vae_checkpoint_path is not None and args.vae_checkpoint_path.strip() == "":
args.vae_checkpoint_path = None
if args.vae_checkpoint_path is not None:
if not os.path.exists(args.vae_checkpoint_path):
print(f"{args.vae_checkpoint_path} vae 文件不存在,我们将尝试使用ckpt文件的vae权重!")
args.vae_checkpoint_path = None
print("正在开始转换,请耐心等待!!!")
image_size = 512
if "safetensors" in args.checkpoint_path:
checkpoint = load_safe(args.checkpoint_path)
else:
checkpoint = load_torch(args.checkpoint_path)
checkpoint = checkpoint.get("state_dict", checkpoint)
if args.original_config_file is None:
get_path_from_url("https://paddlenlp.bj.bcebos.com/models/community/CompVis/stable-diffusion-v1-4/v1-inference.yaml", root_dir="./")
args.original_config_file = "./v1-inference.yaml"
original_config = OmegaConf.load(args.original_config_file)
if args.num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = args.num_in_channels
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
scheduler = DDIMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
steps_offset=1,
clip_sample=False,
set_alpha_to_one=False,
)
# make sure scheduler works correctly with DDIM
scheduler.register_to_config(clip_sample=False)
if args.scheduler_type == "pndm":
config = dict(scheduler.config)
config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(config)
elif args.scheduler_type == "lms":
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "heun":
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "euler":
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif args.scheduler_type == "ddim":
        pass  # keep the DDIMScheduler created above
else:
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
print("1. 开始转换Unet!")
# 1. Convert the UNet2DConditionModel model.
diffusers_unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
diffusers_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, diffusers_unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
)
if diffusers_unet_checkpoint is not None:
unet = UNet2DConditionModel.from_config(diffusers_unet_config)
ppdiffusers_unet_checkpoint = convert_diffusers_vae_unet_to_ppdiffusers(unet, diffusers_unet_checkpoint)
check_keys(unet, ppdiffusers_unet_checkpoint)
unet.load_dict(ppdiffusers_unet_checkpoint)
print(">>> Unet转换成功!")
else:
unet = None
print("在CKPT中,未发现Unet权重,请确认是否存在!")
print("2. 开始转换Vae!")
# 2. Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
if args.vae_checkpoint_path is not None:
if "safetensors" in args.vae_checkpoint_path:
vae_checkpoint = load_safe(args.vae_checkpoint_path)
else:
vae_checkpoint = load_torch(args.vae_checkpoint_path)
print(f"发现 {args.vae_checkpoint_path},我们将转换该文件的vae权重!")
only_vae = True
else:
vae_checkpoint = checkpoint
only_vae = False
diffusers_vae_checkpoint = convert_ldm_vae_checkpoint(vae_checkpoint, vae_config, only_vae=only_vae)
if diffusers_vae_checkpoint is not None:
vae = AutoencoderKL.from_config(vae_config)
ppdiffusers_vae_checkpoint = convert_diffusers_vae_unet_to_ppdiffusers(vae, diffusers_vae_checkpoint)
check_keys(vae, ppdiffusers_vae_checkpoint)
vae.load_dict(ppdiffusers_vae_checkpoint)
print(">>> VAE转换成功!")
else:
vae = None
print("在CKPT中,未发现Vae权重,请确认是否存在!")
print("3. 开始转换text_encoder!")
# 3. Convert the text_encoder model.
text_model_state_dict, text_config = convert_hf_clip_to_ppnlp_clip(checkpoint, dtype="float32")
if text_model_state_dict is not None:
text_model = CLIPTextModel(**text_config)
text_model.eval()
check_keys(text_model, text_model_state_dict)
text_model.load_dict(text_model_state_dict)
print(">>> text_encoder转换成功!")
else:
text_model = None
print("在CKPT中,未发现TextModel权重,请确认是否存在!")
print("4. 开始转换CLIPTokenizer!")
# 4. Convert the tokenizer.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", pad_token="!", model_max_length=77)
print(">>> CLIPTokenizer 转换成功!")
if text_model is not None and vae is not None and unet is not None:
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
pipe.save_pretrained(args.dump_path)
print(">>> 所有权重转换完成啦,请前往"+str(args.dump_path)+"查看转换好的模型!")
else:
if vae is not None:
vae.save_pretrained(os.path.join(args.dump_path, "vae"))
if text_model is not None:
text_model.save_pretrained(os.path.join(args.dump_path, "text_encoder"))
if unet is not None:
unet.save_pretrained(os.path.join(args.dump_path, "unet"))
scheduler.save_pretrained(os.path.join(args.dump_path, "scheduler"))
tokenizer.save_pretrained(os.path.join(args.dump_path, "tokenizer"))
print(">>> 部分权重转换完成啦,请前往"+str(args.dump_path)+"查看转换好的部分模型!")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#Modified By He
import argparse
import contextlib
import hashlib
import itertools
import math
import os
import sys
from pathlib import Path
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
fused_allreduce_gradients,
)
from paddle.io import BatchSampler, DataLoader, Dataset, DistributedBatchSampler
from paddle.optimizer import AdamW
from paddle.vision import transforms
from PIL import Image
from tqdm.auto import tqdm
from paddlenlp.trainer import set_seed
from paddlenlp.transformers import AutoTokenizer, BertModel, CLIPTextModel
from paddlenlp.utils.log import logger
from ppdiffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from ppdiffusers.modeling_utils import freeze_params, unwrap_model
from ppdiffusers.optimization import get_scheduler
#def parse_args(input_args=None):
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training dreambooth script.")
parser.add_argument(
"--save_steps",
type=int,
default=500,
help="Save pipe every X updates steps.",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default="CompVis/stable-diffusion-v1-4",
#required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
#required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
" sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="./dreambooth-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--height",
type=int,
default=512,
help=(
"The height for input images, all the images in the train/validation dataset will be resized to this"
" height"
),
)
parser.add_argument(
"--width",
type=int,
default=512,
help=(
"The width for input images, all the images in the train/validation dataset will be resized to this"
" width"
),
)
parser.add_argument(
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
)
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
parser.add_argument(
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=1, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) or [VisualDL](https://www.paddlepaddle.org.cn/paddle/visualdl) log directory. Will default to"
"*output_dir/logs"
),
)
parser.add_argument(
"--writer_type", type=str, default="visualdl", choices=["tensorboard", "visualdl"], help="Log writer type."
)
#if input_args is not None:
# args = parser.parse_args(input_args)
#else:
#args = parser.parse_args()
#args = parser.parse_known_args()[0]
args = parser.parse_args(args=[])
#if args.instance_data_dir is None:
# raise ValueError("You must specify a train data directory.")
#
#if args.with_prior_preservation:
# if args.class_data_dir is None:
# raise ValueError("You must specify a data directory for class images.")
# if args.class_prompt is None:
# raise ValueError("You must specify prompt for class images.")
#args.logging_dir = os.path.join(args.output_dir, args.logging_dir)
return args
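# A minimal usage sketch for notebook-style runs (attribute names follow the arguments
# defined above; the paths and prompt are hypothetical examples):
#   args = parse_args()
#   args.instance_data_dir = "./instance_images"
#   args.instance_prompt = "a photo of sks dog"
#   args.max_train_steps = 400
#   main(args)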
class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
"""
def __init__(
self,
instance_data_root,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
height=512,
width=512,
center_crop=False,
):
self.height = height
self.width = width
self.center_crop = center_crop
self.tokenizer = tokenizer
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
ext = ["png", "jpg", "jpeg", "bmp", "PNG", "JPG", "JPEG", "BMP"]
self.instance_images_path = []
for p in Path(instance_data_root).iterdir():
if any(suffix in p.name for suffix in ext):
self.instance_images_path.append(p)
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = []
for p in Path(class_data_root).iterdir():
if any(suffix in p.name for suffix in ext):
self.class_images_path.append(p)
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize((height, width), interpolation="bilinear"),
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
return example
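# Each DreamBoothDataset example is a dict with "instance_images" and "instance_prompt_ids",
# plus "class_images" and "class_prompt_ids" when class_data_root is given; padding and
# batching of the prompt ids happen later in collate_fn inside main().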
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
def get_writer(args):
if args.writer_type == "visualdl":
from visualdl import LogWriter
writer = LogWriter(logdir=args.logging_dir)
elif args.writer_type == "tensorboard":
from tensorboardX import SummaryWriter
writer = SummaryWriter(logdir=args.logging_dir)
else:
raise ValueError("writer_type must be in ['visualdl', 'tensorboard']")
return writer
def main(args):
rank = paddle.distributed.get_rank()
num_processes = paddle.distributed.get_world_size()
if num_processes > 1:
paddle.distributed.init_parallel_env()
# If passed along, set the training seed now.
if args.seed is not None:
seed = args.seed + rank
set_seed(seed)
if args.with_prior_preservation:
if rank == 0:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, safety_checker=None
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = DataLoader(sample_dataset, batch_size=args.sample_batch_size)
for example in tqdm(
sample_dataloader,
desc="Generating class images",
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = (
class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
)
image.save(image_filename)
del pipeline
                # do not use paddle.device.cuda.empty_cache
# if paddle.device.is_compiled_with_cuda():
# paddle.device.cuda.empty_cache()
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "tokenizer"))
# Load models and create wrapper for stable diffusion
if "Taiyi-Stable-Diffusion-1B-Chinese-v0.1" in args.pretrained_model_name_or_path:
model_cls = BertModel
else:
model_cls = CLIPTextModel
text_encoder = model_cls.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "text_encoder"))
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
freeze_params(vae.parameters())
if not args.train_text_encoder:
freeze_params(text_encoder.parameters())
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * num_processes
)
lr_scheduler = get_scheduler(
args.lr_scheduler,
learning_rate=args.learning_rate,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
if num_processes > 1:
unet = paddle.DataParallel(unet)
if args.train_text_encoder:
text_encoder = paddle.DataParallel(text_encoder)
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
)
optimizer = AdamW(
learning_rate=lr_scheduler,
parameters=params_to_optimize,
beta1=args.adam_beta1,
beta2=args.adam_beta2,
weight_decay=args.adam_weight_decay,
epsilon=args.adam_epsilon,
grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm) if args.max_grad_norm is not None else None,
)
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
height=args.height,
width=args.width,
center_crop=args.center_crop,
)
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if args.with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
pixel_values = paddle.stack(pixel_values).astype("float32")
input_ids = tokenizer.pad(
{"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pd"
).input_ids
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
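    # With prior preservation enabled, each batch holds 2 * train_batch_size samples:
    # the instance images/prompts come first and the class images/prompts second, which
    # is why the training loop later splits the predicted noise with chunk(2).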
train_sampler = (
DistributedBatchSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True)
if num_processes > 1
else BatchSampler(train_dataset, batch_size=args.train_batch_size, shuffle=True)
)
train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler, collate_fn=collate_fn, num_workers=1)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if rank == 0:
logger.info("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).items()):
logger.info("%s: %s" % (arg, value))
logger.info("------------------------------------------------")
writer = get_writer(args)
# Train!
total_batch_size = args.train_batch_size * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=rank > 0)
progress_bar.set_description("Train Steps")
global_step = 0
    # Keep vae in eval mode as we don't train it
vae.eval()
if args.train_text_encoder:
text_encoder.train()
else:
text_encoder.eval()
unet.train()
for epoch in range(args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = paddle.randn(latents.shape)
batch_size = latents.shape[0]
# Sample a random timestep for each image
timesteps = paddle.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,)).astype("int64")
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if num_processes > 1 and (
args.gradient_checkpointing or ((step + 1) % args.gradient_accumulation_steps != 0)
):
# grad acc, no_sync when (step + 1) % args.gradient_accumulation_steps != 0:
# gradient_checkpointing, no_sync every where
# gradient_checkpointing + grad_acc, no_sync every where
unet_ctx_manager = unet.no_sync()
if args.train_text_encoder:
text_encoder_ctx_manager = text_encoder.no_sync()
else:
text_encoder_ctx_manager = (
contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
)
else:
unet_ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
text_encoder_ctx_manager = (
contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
)
with text_encoder_ctx_manager:
# Get the text embedding for conditioning
attention_mask = paddle.ones_like(batch["input_ids"])
encoder_hidden_states = text_encoder(batch["input_ids"], attention_mask=attention_mask)[0]
with unet_ctx_manager:
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.with_prior_preservation:
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = noise_pred.chunk(2, axis=0)
noise, noise_prior = noise.chunk(2, axis=0)
# Compute instance loss
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
if num_processes > 1 and args.gradient_checkpointing:
fused_allreduce_gradients(params_to_optimize, None)
optimizer.step()
lr_scheduler.step()
optimizer.clear_grad()
progress_bar.update(1)
global_step += 1
logs = {
"epoch": str(epoch).zfill(4),
"step_loss": round(loss.item() * args.gradient_accumulation_steps, 10),
"lr": lr_scheduler.get_lr(),
}
progress_bar.set_postfix(**logs)
if rank == 0:
for name, val in logs.items():
if name == "epoch":
continue
writer.add_scalar(f"train/{name}", val, step=global_step)
if global_step % args.save_steps == 0:
                        # Create the pipeline using the trained modules and save it.
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unwrap_model(unet),
text_encoder=unwrap_model(text_encoder),
safety_checker=None,
tokenizer=tokenizer,
)
pipeline.save_pretrained(args.output_dir+str(global_step))
if global_step >= args.max_train_steps:
break
if rank == 0:
writer.close()
        # Create the pipeline using the trained modules and save it.
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unwrap_model(unet),
text_encoder=unwrap_model(text_encoder),
safety_checker=None,
tokenizer=tokenizer,
)
pipeline.save_pretrained(args.output_dir+str(global_step))
print("训练完成了,请重启内核")
if __name__ == "__main__":
args = parse_args()
main(args)
DEBUG_UI = False
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import random
import re
import time
from typing import Callable, List, Optional, Union
import numpy as np
import paddle
import PIL
import PIL.Image
from packaging import version
from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ppdiffusers.configuration_utils import FrozenDict
from ppdiffusers.models import AutoencoderKL, UNet2DConditionModel
from ppdiffusers.pipeline_utils import DiffusionPipeline
from ppdiffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ppdiffusers.utils import PIL_INTERPOLATION, deprecate, logging
from ppdiffusers.utils.testing_utils import load_image
from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from ppdiffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def save_all(images, FORMAT="jpg", OUTDIR="./outputs/"):
if not isinstance(images, (list, tuple)):
images = [images]
for image in images:
PRECISION = "fp32"
argument = image.argument
os.makedirs(OUTDIR, exist_ok=True)
epoch_time = argument["epoch_time"]
PROMPT = argument["prompt"]
NEGPROMPT = argument["negative_prompt"]
HEIGHT = argument["height"]
WIDTH = argument["width"]
SEED = argument["seed"]
STRENGTH = argument.get("strength", 1)
INFERENCE_STEPS = argument["num_inference_steps"]
GUIDANCE_SCALE = argument["guidance_scale"]
filename = f"{str(epoch_time)}_scale_{GUIDANCE_SCALE}_steps_{INFERENCE_STEPS}_seed_{SEED}.{FORMAT}"
filedir = f"{OUTDIR}/{filename}"
image.save(filedir)
with open(f"{OUTDIR}/{epoch_time}_prompt.txt", "w") as file:
file.write(
f"PROMPT: {PROMPT}\nNEG_PROMPT: {NEGPROMPT}\n\nINFERENCE_STEPS: {INFERENCE_STEPS}\nHeight: {HEIGHT}\nWidth: {WIDTH}\nSeed: {SEED}\n\nPrecision: {PRECISION}\nSTRENGTH: {STRENGTH}\nGUIDANCE_SCALE: {GUIDANCE_SCALE}"
)
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
break
# truncate
if len(text_token) > max_length:
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
return tokens, weights
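# For example, for the prompt "an (important) word" the tokens of "important" get weight
# 1.1 and the surrounding tokens weight 1.0, mirroring the parse_prompt_attention docstring.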
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
for i in range(len(tokens)):
tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range((len(weights[i]) - 1) // chunk_length + 1):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
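# For example, with chunk_length=77 and no_boseos_middle=True, a prompt of n (< 76) tokens
# is padded to exactly max_length entries, [bos] + ids + [eos] + pads, and every added
# special or pad token receives weight 1.0.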
def get_unweighted_text_embeddings(
pipe: DiffusionPipeline, text_input: paddle.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
text_embeddings = paddle.concat(text_embeddings, axis=1)
else:
text_embeddings = pipe.text_encoder(text_input)[0]
return text_embeddings
def get_weighted_text_embeddings(
pipe: DiffusionPipeline,
prompt: Union[str, List[str]],
uncond_prompt: Optional[Union[str, List[str]]] = None,
max_embeddings_multiples: Optional[int] = 1,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
**kwargs
):
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
Args:
pipe (`DiffusionPipeline`):
Pipe to provide access to the tokenizer and the text encoder.
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
max_embeddings_multiples (`int`, *optional*, defaults to `1`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of the text tokens is a multiple of the text encoder capacity, whether to keep the
            starting and ending tokens in each of the middle chunks.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When parsing is skipped, the weighting is skipped as well.
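    Example:
        A minimal usage sketch (assuming `pipe` is a pipeline exposing `tokenizer` and
        `text_encoder`, as described above):
            embeddings = get_weighted_text_embeddings(
                pipe, "a (very beautiful) masterpiece", uncond_prompt="", max_embeddings_multiples=2
            )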
"""
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
else:
prompt_tokens = [
token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
if uncond_prompt is not None:
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
max_embeddings_multiples = min(
max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
# support bert tokenizer
bos = pipe.tokenizer.bos_token_id if pipe.tokenizer.bos_token_id is not None else pipe.tokenizer.cls_token_id
eos = pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id is not None else pipe.tokenizer.sep_token_id
pad = pipe.tokenizer.pad_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
pad,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
prompt_tokens = paddle.to_tensor(prompt_tokens)
if uncond_prompt is not None:
uncond_tokens, uncond_weights = pad_tokens_and_weights(
uncond_tokens,
uncond_weights,
max_length,
bos,
eos,
pad,
no_boseos_middle=no_boseos_middle,
chunk_length=pipe.tokenizer.model_max_length,
)
uncond_tokens = paddle.to_tensor(uncond_tokens)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
)
prompt_weights = paddle.to_tensor(prompt_weights, dtype=text_embeddings.dtype)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
)
uncond_weights = paddle.to_tensor(uncond_weights, dtype=uncond_embeddings.dtype)
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
if (not skip_parsing) and (not skip_weighting):
previous_mean = text_embeddings.mean(axis=[-2, -1])
text_embeddings *= prompt_weights.unsqueeze(-1)
text_embeddings *= previous_mean / text_embeddings.mean(axis=[-2, -1])
if uncond_prompt is not None:
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
uncond_embeddings *= uncond_weights.unsqueeze(-1)
uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if uncond_prompt is not None:
text_embeddings = paddle.concat([uncond_embeddings, text_embeddings])
return text_embeddings
def preprocess_image(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = paddle.to_tensor(image)
return 2.0 * image - 1.0
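# For an RGB input, preprocess_image returns a (1, 3, H, W) float32 tensor scaled to
# [-1, 1], with H and W rounded down to multiples of 32.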
def preprocess_mask(mask):
mask = mask.convert("L")
w, h = mask.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
mask = np.array(mask).astype(np.float32) / 255.0
mask = np.tile(mask, (4, 1, 1))
    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension; the (0, 1, 2, 3) transpose is an identity permutation
mask = 1 - mask # repaint white, keep black
mask = paddle.to_tensor(mask)
return mask
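# preprocess_mask returns a (1, 4, H // 8, W // 8) tensor at latent resolution: white mask
# pixels (regions to repaint) map to 0 and black pixels (regions to keep) map to 1.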
class StableDiffusionPipelineAllinOne(DiffusionPipeline):
r"""
    Pipeline for text-to-image, image-to-image and inpainting generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`]
or [`DPMSolverMultistepScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/junnyu/stable-diffusion-v1-4-paddle) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = False,
):
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_ppdiffusers_version") and version.parse(
version.parse(unet.config._ppdiffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def __call__(self, *args, **kwargs):
return self.text2image(*args, **kwargs)
def text2img(self, *args, **kwargs):
return self.text2image(*args, **kwargs)
def _encode_prompt(
self,
prompt,
negative_prompt,
max_embeddings_multiples,
no_boseos_middle,
skip_parsing,
skip_weighting,
do_classifier_free_guidance,
num_images_per_prompt,
):
if do_classifier_free_guidance and negative_prompt is None:
negative_prompt = ""
text_embeddings = get_weighted_text_embeddings(
self, prompt, negative_prompt, max_embeddings_multiples, no_boseos_middle, skip_parsing, skip_weighting
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.tile([1, num_images_per_prompt, 1])
text_embeddings = text_embeddings.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
return text_embeddings
def run_safety_checker(self, image, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
def decode_latents(self, latents):
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clip(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.transpose([0, 2, 3, 1]).cast("float32").numpy()
return image
def prepare_extra_step_kwargs(self, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
return extra_step_kwargs
def check_inputs_text2img(self, prompt, height, width, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def check_inputs_img2img_inpaint(self, prompt, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [1.0, 1.0] but is {strength}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents_text2img(self, batch_size, num_channels_latents, height, width, dtype, latents=None):
shape = [batch_size, num_channels_latents, height // 8, width // 8]
if latents is None:
latents = paddle.randn(shape, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_latents_img2img(self, image, timestep, num_images_per_prompt, dtype):
image = image.cast(dtype=dtype)
init_latent_dist = self.vae.encode(image).latent_dist
init_latents = init_latent_dist.sample()
init_latents = 0.18215 * init_latents
b, c, h, w = init_latents.shape
init_latents = init_latents.tile([1, num_images_per_prompt, 1, 1])
init_latents = init_latents.reshape([b * num_images_per_prompt, c, h, w])
# add noise to latents using the timesteps
noise = paddle.randn(init_latents.shape, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
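    # For example, with num_inference_steps=50, strength=0.8 and steps_offset=1,
    # init_timestep = 41 and t_start = 10, so the last 40 scheduler timesteps are used
    # and 40 is returned as the effective number of denoising steps.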
def prepare_latents_inpaint(self, image, timestep, num_images_per_prompt, dtype):
image = image.cast(dtype)
init_latent_dist = self.vae.encode(image).latent_dist
init_latents = init_latent_dist.sample()
init_latents = 0.18215 * init_latents
b, c, h, w = init_latents.shape
init_latents = init_latents.tile([1, num_images_per_prompt, 1, 1])
init_latents = init_latents.reshape([b * num_images_per_prompt, c, h, w])
init_latents_orig = init_latents
# add noise to latents using the timesteps
noise = paddle.randn(init_latents.shape, dtype=dtype)
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents, init_latents_orig, noise
@paddle.no_grad()
def text2image(
self,
prompt: Union[str, List[str]],
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
seed: Optional[int] = None,
latents: Optional[paddle.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
callback_steps: Optional[int] = 1,
# new add
max_embeddings_multiples: Optional[int] = 1,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
seed (`int`, *optional*):
Random number seed.
latents (`paddle.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `seed`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
seed = random.randint(0, 2**32) if seed is None else seed
argument = dict(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
seed=seed,
latents=latents,
max_embeddings_multiples=max_embeddings_multiples,
no_boseos_middle=no_boseos_middle,
skip_parsing=skip_parsing,
skip_weighting=skip_weighting,
epoch_time=time.time(),
)
paddle.seed(seed)
# 1. Check inputs. Raise error if not correct
self.check_inputs_text2img(prompt, height, width, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
negative_prompt,
max_embeddings_multiples,
no_boseos_middle,
skip_parsing,
skip_weighting,
do_classifier_free_guidance,
num_images_per_prompt,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents_text2img(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
# 10. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image, argument=argument)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@paddle.no_grad()
def img2img(
self,
prompt: Union[str, List[str]],
image: Union[paddle.Tensor, PIL.Image.Image],
strength: float = 0.8,
height=None,
width=None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
seed: Optional[int] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
callback_steps: Optional[int] = 1,
# new add
max_embeddings_multiples: Optional[int] = 1,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`paddle.Tensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
noise will be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
seed (`int`, *optional*):
A random seed.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
seed = random.randint(0, 2**32) if seed is None else seed
image_str = image
if isinstance(image_str, str):
image = load_image(image_str)
if height is None and width is None:
width = (image.size[0] // 8) * 8
height = (image.size[1] // 8) * 8
elif height is None and width is not None:
height = (image.size[1] // 8) * 8
elif width is None and height is not None:
width = (image.size[0] // 8) * 8
argument = dict(
prompt=prompt,
image=image_str,
negative_prompt=negative_prompt,
height=height,
width=width,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
seed=seed,
max_embeddings_multiples=max_embeddings_multiples,
no_boseos_middle=no_boseos_middle,
skip_parsing=skip_parsing,
skip_weighting=skip_weighting,
epoch_time=time.time(),
)
paddle.seed(seed)
# 1. Check inputs
self.check_inputs_img2img_inpaint(prompt, strength, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
negative_prompt,
max_embeddings_multiples,
no_boseos_middle,
skip_parsing,
skip_weighting,
do_classifier_free_guidance,
num_images_per_prompt,
)
# 4. Preprocess image
if isinstance(image, PIL.Image.Image):
image = image.resize((width, height))
image = preprocess_image(image)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
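# Note: get_timesteps trims the schedule according to strength, so only roughly
# num_inference_steps * strength denoising steps actually run; a low strength
# therefore keeps the result close to the input image.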
latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
# 6. Prepare latent variables
latents = self.prepare_latents_img2img(image, latent_timestep, num_images_per_prompt, text_embeddings.dtype)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 9. Post-processing
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image, argument=argument)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@paddle.no_grad()
def inpaint(
self,
prompt: Union[str, List[str]],
image: Union[paddle.Tensor, PIL.Image.Image],
mask_image: Union[paddle.Tensor, PIL.Image.Image],
height=None,
width=None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
seed: Optional[int] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
callback_steps: Optional[int] = 1,
# new add
max_embeddings_multiples: Optional[int] = 1,
no_boseos_middle: Optional[bool] = False,
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
image (`paddle.Tensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process. This is the image whose masked region will be inpainted.
mask_image (`paddle.Tensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
is 1, the denoising process will be run on the masked area for the full number of iterations specified
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
num_inference_steps (`int`, *optional*, defaults to 50):
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
seed (`int`, *optional*):
A random seed.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
seed = random.randint(0, 2**32) if seed is None else seed
image_str = image
mask_image_str = mask_image
if isinstance(image_str, str):
image = load_image(image_str)
if isinstance(mask_image_str, str):
mask_image = load_image(mask_image_str)
if height is None and width is None:
width = (image.size[0] // 8) * 8
height = (image.size[1] // 8) * 8
elif height is None and width is not None:
height = (image.size[1] // 8) * 8
elif width is None and height is not None:
width = (image.size[0] // 8) * 8
argument = dict(
prompt=prompt,
image=image_str,
mask_image=mask_image_str,
negative_prompt=negative_prompt,
height=height,
width=width,
strength=strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
seed=seed,
max_embeddings_multiples=max_embeddings_multiples,
no_boseos_middle=no_boseos_middle,
skip_parsing=skip_parsing,
skip_weighting=skip_weighting,
epoch_time=time.time(),
)
paddle.seed(seed)
# 1. Check inputs
self.check_inputs_img2img_inpaint(prompt, strength, callback_steps)
# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_embeddings = self._encode_prompt(
prompt,
negative_prompt,
max_embeddings_multiples,
no_boseos_middle,
skip_parsing,
skip_weighting,
do_classifier_free_guidance,
num_images_per_prompt,
)
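# 4. Preprocess image and mask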
if not isinstance(image, paddle.Tensor):
image = image.resize((width, height))
image = preprocess_image(image)
if not isinstance(mask_image, paddle.Tensor):
mask_image = mask_image.resize((width, height))
mask_image = preprocess_mask(mask_image)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
latent_timestep = timesteps[:1].tile([batch_size * num_images_per_prompt])
# 6. Prepare latent variables
# encode the init image into latents and scale the latents
latents, init_latents_orig, noise = self.prepare_latents_inpaint(
image, latent_timestep, num_images_per_prompt, text_embeddings.dtype
)
# 7. Prepare mask latent
mask = mask_image.cast(latents.dtype)
mask = paddle.concat([mask] * batch_size * num_images_per_prompt)
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
latents = (init_latents_proper * mask) + (latents * (1 - mask))
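# Outside the repaint region the latents are reset each step to the original
# latents re-noised to the current timestep, so only the masked area is changed.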
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 10. Post-processing
image = self.decode_latents(latents)
# 11. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, text_embeddings.dtype)
# 12. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image, argument=argument)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@staticmethod
def numpy_to_pil(images, **kwargs):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
pil_images = []
argument = kwargs.pop("argument", None)
for image in images:
image = PIL.Image.fromarray(image)
if argument is not None:
image.argument = argument
pil_images.append(image)
return pil_images
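# --------------------------------------------------
# Usage sketch (illustrative only): assuming `pipe` is an already-constructed
# instance of this pipeline class and "input.png" is a hypothetical file, an
# img2img call looks like:
#
#   result = pipe.img2img(
#       prompt="a cat sitting on a bench",
#       image="input.png",          # a file path or a PIL.Image both work
#       strength=0.6,
#       num_inference_steps=30,
#       seed=1234,
#   )
#   result.images[0].save("out.png")
# --------------------------------------------------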
import os
import re
import json
from json import JSONDecodeError
from PIL import Image, PngImagePlugin
from enum import Enum #IntFlag?
'''
There are six possible sources of image info:
1. [Paddle] written directly to a txt file
2. [Paddle] saved directly into the png info
3. [PaddleLikeWebUI] webui-style output to a txt file
4. [PaddleLikeWebUI] webui-style output into the png info["parameters"]
5. [WebUI] a webui image
6. [NAIFU] a NAIFU image
'''
class InfoFormat(Enum):
Paddle = 1
WebUI = 4
NAIFU = 8
PaddleLikeWebUI = 5
Unknown = 255
# All parameters output by Paddle
PRAM_NAME_LIST = (
'prompt',
'negative_prompt',
'height',
'width',
'num_inference_steps',
'guidance_scale',
'num_images_per_prompt',
'eta',
'seed',
'latents',
'max_embeddings_multiples',
'no_boseos_middle',
'skip_parsing',
'skip_weighting',
'epoch_time',
'sampler',
'superres_model_name',
'model_name',
# img2img
'init_image',
)
# Cases 3/4: [Paddle] => [PaddleLikeWebUI]
MAP_PARAM_TO_LABEL = {
'prompt': '',
'negative_prompt': 'Negative prompt: ', # kept consistent with webui
'num_inference_steps': 'Steps: ', # kept consistent with webui
'sampler': 'Sampler: ',
'guidance_scale': 'CFG scale: ',
'strength': 'Strength: ',
'seed': 'Seed: ',
'width':'width: ',
'height':'height: ',
}
# Cases 3/4/5: [PaddleLikeWebUI]/[WebUI] => [Paddle]
MAP_LAEBL_TO_PARAM = {
'Prompt': 'prompt',
'Negative prompt': 'negative_prompt',
'Steps': 'num_inference_steps',
'Sampler': 'sampler',
'CFG Scale': 'guidance_scale',
'CFG scale': 'guidance_scale',
'Strength': 'strength',
'Seed': 'seed',
'Width':'width',
'Height':'height',
#webui
'Eta': 'eta',
'Model': 'model_name', # note: model_name may be a webui model and cannot be used directly
'Model hash': 'model_hash',
}
# [NAIFU]=>[Paddle]
MAP_NAIFU_TAG_TO_PARAM = {
'steps': 'num_inference_steps',
'scale': 'guidance_scale',
'uc': 'negative_prompt',
}
# Trusted tags whose values are plain text; used for inheriting info
RELIABEL_TAG_LIST = (
'Title',
'Description',
'Software',
'Source',
'Comment',
'parameters',
'original_parameters',
)
# --------------------------------------------------
# Serialization
# --------------------------------------------------
# Output in the [PaddleLikeWebUI] style
def serialize_to_text(params):
"""
Serialize parameters to text for saving alongside an image. Format: [PaddleLikeWebUI]
"""
# TODO: strip useless entries
labels = MAP_PARAM_TO_LABEL
info = ''
for k in labels:
if k in params: info += labels[k] + str(params[k]) + '\n'
for k in params:
if k not in labels: info += k + ': ' + str(params[k]) + '\n'
return info
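# Example of the text produced above (illustrative values):
#
#   serialize_to_text({'prompt': '1girl', 'negative_prompt': 'lowres',
#                      'num_inference_steps': 50, 'seed': 42})
#   # -> "1girl\nNegative prompt: lowres\nSteps: 50\nSeed: 42\n"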
def serialize_to_pnginfo(params, existing_info = None, mark_paddle = True):
"""
Serialize parameters into the image info. Format: [PaddleLikeWebUI]
The existing_info argument is used to inherit existing image info.
"""
text = serialize_to_text(params)
dict = {}
if existing_info is None:
pass
elif 'parameters' in existing_info:
# dict.update(existing_info)
dict['original_parameters'] = existing_info.pop('parameters')
elif 'prompt' in existing_info:
# If the info is in [Paddle] format, convert it to [PaddleLikeWebUI] and drop the [Paddle] entries
dict['original_parameters'] = serialize_to_text(existing_info)
for k in existing_info:
if k not in PRAM_NAME_LIST: dict[k] = existing_info[k]
else:
# treat it as [NAIFU] (unknown tags are not guaranteed to be of the iTXt type)
for k in ('Title', 'Description', 'Software', 'Source', 'Comment'):
if k in existing_info:
dict[k] = existing_info[k]
if mark_paddle: dict['Software'] = 'PaddleNLP'
dict['parameters'] = text
pnginfo = PngImagePlugin.PngInfo()
for key, val in dict.items():
pnginfo.add_text(key, str(val))
return pnginfo
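# Minimal sketch (file name is illustrative): embedding the parameters into a
# PNG so that deserialize_from_image() below can recover them later.
#
#   from PIL import Image
#   img = Image.new('RGB', (64, 64))
#   img.save('demo.png', pnginfo=serialize_to_pnginfo({'prompt': '1girl', 'seed': 42}))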
def imageinfo_to_pnginfo(info, update_format = True):
"""
Build a [PngInfo] from [Image.info] so image info can be inherited.
Only used when saving super-resolution (highres) images.
Unrecognized entries are filtered out.
"""
dict = {}
if ('prompt' in info):
if update_format:
# If the info is in [Paddle] format, convert it to [PaddleLikeWebUI] and drop the [Paddle] entries
dict['parameters'] = serialize_to_text(info)
for k in info:
if k not in PRAM_NAME_LIST: dict[k] = info[k]
else:
for k in PRAM_NAME_LIST:
if k in info: dict[k] = info[k]
# trusted info entries
for k in RELIABEL_TAG_LIST:
if k in info:
dict[k] = info[k]
pnginfo = PngImagePlugin.PngInfo()
for key, val in dict.items():
pnginfo.add_text(key, str(val))
return pnginfo
# --------------------------------------------------
# Deserialization
# --------------------------------------------------
def _parse_value(s):
if s == 'None': return None
if s == 'False': return False
if s == 'True': return True
if re.fullmatch(r'[\d\.]+', s):
return int(s) if s.find('.') < 0 else float(s)
return s
# Only supports [Paddle][PaddleLikeWebUI][WebUI]
def _deserialize_from_lines(enumerable, format_presumed = InfoFormat.Unknown):
dict = {}
fmt = format_presumed
ln = -1
name = 'prompt'
for line in enumerable:
ln += 1
line = line.rstrip('\n')
key, colon, val = line.partition(': ')
if (ln == 0) and (key == 'prompt'):
fmt = InfoFormat.Paddle
# no colon separator
if colon == '':
if ln == 0:
name = 'prompt'
dict[name] = line
elif name == 'prompt' or name == 'negative_prompt':
dict[name] += '\n' + line # append to the previous line
elif line == '':
pass
elif re.fullmatch(r'Strength[\d\.]+', key):
# compatibility with an earlier malformed 'Strength' label (no colon separator)
dict['strength'] = key[8:]
else:
# unrecognized multi-line parameter
dict[name] += '\n' + line
# has a colon separator
elif key in PRAM_NAME_LIST:
# cases 1/2: the original format
name = key
dict[name] = val
elif key in MAP_LAEBL_TO_PARAM:
# cases 3/4 format
fmt = InfoFormat.PaddleLikeWebUI
name = MAP_LAEBL_TO_PARAM[key]
dict[name] = val
# a label was found but it is not recognized
elif name == 'prompt' or name == 'negative_prompt':
# inside a prompt, do not treat it as a label
dict[name] += '\n' + line
# it looks like a label
elif re.fullmatch(r'\w+', name):
# treat it as a label
name = key
dict[name] = val
else:
dict[name] += '\n' + line # append to the previous line
# handle the webui format ([WebUI] => [Paddle])
if ('num_inference_steps' in dict) and (dict['num_inference_steps'].find(', ') > -1):
webui_text = dict['num_inference_steps']
fmt = InfoFormat.WebUI
webui_text = 'num_inference_steps: '+webui_text
for pair in webui_text.split(', '):
key, colon, val = pair.partition(': ')
key = key if key not in MAP_LAEBL_TO_PARAM else MAP_LAEBL_TO_PARAM[key]
dict[key] = val
# handle 'Size: 768x512'
if ('Size' in dict):
size = re.split(r'\D',dict.pop('Size'))
dict['width'] = size[0]
dict['height'] = size[1]
for k in dict:
dict[k] = _parse_value(dict[k])
return (dict,fmt)
def deserialize_from_txt(text, format_presumed = InfoFormat.Unknown):
""" 从一段文本提取参数信息。支持格式[Paddle][PaddleLikeWebUI][WebUI] """
return _deserialize_from_lines(text.splitlines(), format_presumed)
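# Round-trip sketch: text written by serialize_to_text() parses back into the
# original parameter names (values go through _parse_value):
#
#   params, fmt = deserialize_from_txt('1girl\nNegative prompt: lowres\nSteps: 50\nSeed: 42\n')
#   # params == {'prompt': '1girl', 'negative_prompt': 'lowres',
#   #            'num_inference_steps': 50, 'seed': 42}
#   # fmt is InfoFormat.PaddleLikeWebUI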
# Collect info directly from the image's png info [Paddle]
def _collect_from_pnginfo(info):
dict = {}
for key in info:
if key in PRAM_NAME_LIST:
dict[key] = _parse_value(info[key])
fmt = InfoFormat.Paddle if 'prompt' in dict else InfoFormat.Unknown
return (dict,fmt)
# Extract info from a NAIFU image [NAIFU]
def _collect_from_pnginfo_naifu(info):
if ('Description' not in info) \
or ('Comment' not in info) \
or ('Software' not in info):
return ({}, InfoFormat.Unknown)
try:
data = json.loads(info['Comment'])
data['prompt'] = info['Description']
for key in MAP_NAIFU_TAG_TO_PARAM:
if key in data:
data[MAP_NAIFU_TAG_TO_PARAM[key]] = data.pop(key)
return (data, InfoFormat.NAIFU)
except JSONDecodeError:
return ({}, InfoFormat.Unknown)
def deserialize_from_image(image):
""" 从图片获取参数信息。参数为Image或文件地址。"""
if isinstance(image, str):
assert os.path.isfile(image), f'{image}不是可读文件'
image = Image.open(image)
if 'parameters' in image.info: # cases 4/5 [PaddleLikeWebUI]
return deserialize_from_txt(image.info['parameters'], InfoFormat.PaddleLikeWebUI)
# [NAIFU]
dict, fmt = _collect_from_pnginfo_naifu(image.info)
if fmt is InfoFormat.NAIFU: return (dict, fmt)
# [Paddle]
return _collect_from_pnginfo(image.info)
def deserialize_from_filename(filename):
""" 从文本文件或图像文件获取参数信息,优先从其对应的文本文件中提取。参数为文件地址。"""
txt_path, dot, ext = filename.rpartition('.')
txt_path += '.txt'
if os.path.isfile(txt_path):
with open(txt_path, 'r') as f:
(dict,fmt) = _deserialize_from_lines(f)
if (fmt is not InfoFormat.Unknown) or ext.lower() == 'txt':
return (dict, fmt)
return deserialize_from_image(filename)
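# Self-test sketch (the sample path is hypothetical): running this module
# directly prints whatever parameters can be recovered from a saved image.
if __name__ == '__main__':
_sample = 'outputs/sample.png' # hypothetical path, replace with a real image
if os.path.isfile(_sample):
_params, _fmt = deserialize_from_filename(_sample)
print(_fmt, _params.get('prompt'), _params.get('seed'))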
import paddle
from PIL import Image
import gc
def image_grid(imgs, rows=2, cols=2):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols * w, rows * h))
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
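# Usage sketch: tile four solid-colour images into a 2x2 grid (runs only when
# this module is executed directly; the output path is illustrative).
if __name__ == '__main__':
tiles = [Image.new('RGB', (64, 64), c) for c in ('red', 'green', 'blue', 'white')]
image_grid(tiles, rows=2, cols=2).save('grid_demo.png')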
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import glob
import itertools
import math
import os
import random
from pathlib import Path
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import BatchSampler, DataLoader, Dataset, DistributedBatchSampler
from paddle.optimizer import AdamW
from paddle.vision.transforms import RandomHorizontalFlip
from PIL import Image
from tqdm.auto import tqdm
from paddlenlp.trainer import set_seed
from paddlenlp.transformers import AutoTokenizer, BertModel, CLIPTextModel
from paddlenlp.utils.log import logger
try:
from ppdiffusers import (
AutoencoderKL,
DDPMScheduler,
PNDMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
patch_to,
)
from ppdiffusers.modeling_utils import freeze_params, unwrap_model
from ppdiffusers.optimization import get_scheduler
from ppdiffusers.pipelines.alt_diffusion import RobertaSeriesModelWithTransformation
from ppdiffusers.ppnlp_patch_utils import XLMRobertaTokenizer
from ppdiffusers.utils import PIL_INTERPOLATION
# patch
@patch_to(RobertaSeriesModelWithTransformation)
def get_input_embeddings(self):
return self.roberta.embeddings.word_embeddings
@patch_to(RobertaSeriesModelWithTransformation)
def set_input_embeddings(self, value):
self.roberta.embeddings.word_embeddings = value
except:
AutoencoderKL = None
DDPMScheduler = None
PNDMScheduler = None
StableDiffusionPipeline = None
UNet2DConditionModel = None
patch_to = None
freeze_params, unwrap_model = None, None
get_scheduler = None
RobertaSeriesModelWithTransformation = None
XLMRobertaTokenizer = None
PIL_INTERPOLATION = None
def get_writer(args):
if args.writer_type == "visualdl":
from visualdl import LogWriter
writer = LogWriter(logdir=args.logging_dir)
elif args.writer_type == "tensorboard":
from tensorboardX import SummaryWriter
writer = SummaryWriter(logdir=args.logging_dir)
else:
raise ValueError("writer_type must be in ['visualdl', 'tensorboard']")
return writer
def save_progress(text_encoder, placeholder_token_id, args, global_step=-1):
learned_embeds = unwrap_model(
text_encoder).get_input_embeddings().weight[placeholder_token_id]
learned_embeds_dict = {
args.placeholder_token: learned_embeds.detach().cpu()
}
# remove \/"*:?<>| in filename
name = args.placeholder_token
name = name.translate({
92: 95,
47: 95,
42: 95,
34: 95,
58: 95,
63: 95,
60: 95,
62: 95,
124: 95
})
path = os.path.join(args.output_dir, "step-"+str(global_step))
os.makedirs(path, exist_ok=True)
paddle.save(learned_embeds_dict,
os.path.join(args.output_dir, "step-"+str(global_step), f"{name}.pdparams"))
print(
f"Global_step: {global_step} 程序没有卡住,目前正在生成评估图片,请耐心等待!训练好的权重和评估图片将会自动保存到 {path} 目录下。")
def generate_image(text_encoder, unet, vae, tokenizer, eval_scheduler, args):
text_encoder.eval()
temp_pipeline = StableDiffusionPipeline(
text_encoder=unwrap_model(text_encoder),
unet=unet,
vae=vae,
tokenizer=tokenizer,
scheduler=eval_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
temp_pipeline.set_progress_bar_config(disable=True)
all_images = []
for _ in range(4):
all_images.append(
temp_pipeline(args.image_logging_prompt, height=args.height, width=args.width, output_type="numpy").images[0]
)
all_images = np.stack(all_images, axis=0)
text_encoder.train()
return all_images, temp_pipeline.numpy_to_pil(all_images)
def parse_args():
parser = argparse.ArgumentParser(
description="Simple example of a training script.")
parser.add_argument(
"--save_steps",
type=int,
default=10,
help="Save learned_embeds.pdparams every X updates steps.",
)
parser.add_argument(
"--image_logging_prompt",
type=str,
default=None,
help="Logging image use which prompt.",
)
parser.add_argument(
"--model_name",
type=str,
default="CompVis/stable-diffusion-v1-4",
required=False,
help="Path to pretrained model or model identifier from local models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument("--train_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data.")
parser.add_argument(
"--placeholder_token",
type=str,
default=None,
required=False,
help="A token to use as a placeholder for the concept.",
)
parser.add_argument("--initializer_token",
type=str,
default=None,
required=False,
help="A token to use as initializer word.")
parser.add_argument("--learnable_property",
type=str,
default="object",
help="Choose between 'object' and 'style'")
parser.add_argument("--repeats",
type=int,
default=100,
help="How many times to repeat the training data.")
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help=
"The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed",
type=int,
default=None,
help="A seed for reproducible training.")
parser.add_argument(
"--height",
type=int,
default=512,
help=
("The height for input images, all the images in the train/validation dataset will be resized to this"
" height"),
)
parser.add_argument(
"--width",
type=int,
default=512,
help=
("The width for input images, all the images in the train/validation dataset will be resized to this"
" width"),
)
parser.add_argument(
"--center_crop",
action="store_true",
help="Whether to center crop images before resizing to resolution")
parser.add_argument(
"--train_batch_size",
type=int,
default=1,
help="Batch size (per device) for the training dataloader.")
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=5000,
help=
"Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=4,
help=
"Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help=
"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=
('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
default=0,
help="Number of steps for the warmup in the lr scheduler.")
parser.add_argument("--adam_beta1",
type=float,
default=0.9,
help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2",
type=float,
default=0.999,
help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay",
type=float,
default=1e-2,
help="Weight decay to use.")
parser.add_argument("--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm",
default=-1,
type=float,
help="Max gradient norm.")
parser.add_argument("--language", default="en", choices=["en", "zh", "zh_en"], help="Model language.")
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=
("[TensorBoard](https://www.tensorflow.org/tensorboard) or [VisualDL](https://www.paddlepaddle.org.cn/paddle/visualdl) log directory. Will default to"
"*output_dir/logs"),
)
parser.add_argument("--writer_type",
type=str,
default="visualdl",
choices=["tensorboard", "visualdl"],
help="Log writer type.")
args = parser.parse_args(args=[])
return args
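# Sketch of how parse_args() is meant to be used here: it is called with an
# empty argv (args=[]) so it works inside a notebook, and individual fields are
# then overridden programmatically (values below are illustrative):
#
#   args = parse_args()
#   args.placeholder_token = '<Alice>'
#   args.initializer_token = 'girl'
#   args.train_data_dir = 'resources/Alices'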
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
zh_imagenet_templates_small = [
"一张{}的照片",
"{}的渲染",
"{}裁剪过的照片",
"一张干净的{}的照片",
"{}的黑暗照片",
"我的{}的照片",
"酷的{}的照片",
"{}的特写照片",
"{}的明亮照片",
"{}的裁剪照片",
"{}的照片",
"{}的好照片",
"一张{}的照片",
"干净的照片{}",
"一张漂亮的{}的照片",
"漂亮的照片{}",
"一张很酷的照片{}",
"一张奇怪的照片{}",
]
zh_imagenet_style_templates_small = [
"一幅{}风格的画",
"{}风格的渲染",
"{}风格的裁剪画",
"{}风格的绘画",
"{}风格的一幅干净的画",
"{}风格的黑暗画作",
"{}风格的图片",
"{}风格的一幅很酷的画",
"{}风格的特写画",
"一幅{}风格的明亮画作",
"{}风格的一幅好画",
"{}风格的特写画",
"{}风格的艺术画",
"一幅{}风格的漂亮画",
"一幅{}风格的奇怪的画",
]
class TextualInversionDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
learnable_property="object", # [object, style]
height=512,
width=512,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
center_crop=False,
language="en",
):
self.data_root = data_root
self.tokenizer = tokenizer
self.learnable_property = learnable_property
self.height = height
self.width = width
self.placeholder_token = placeholder_token
self.center_crop = center_crop
self.flip_p = flip_p
if not Path(data_root).exists():
raise ValueError(f"{data_root} dir doesn't exists.")
ext = ["png", "jpg", "jpeg", "bmp", "PNG", "JPG", "JPEG", "BMP"]
self.image_paths = []
for e in ext:
self.image_paths.extend(glob.glob(os.path.join(data_root, "*." + e)))
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL_INTERPOLATION["linear"],
"bilinear": PIL_INTERPOLATION["bilinear"],
"bicubic": PIL_INTERPOLATION["bicubic"],
"lanczos": PIL_INTERPOLATION["lanczos"],
}[interpolation]
self.templates = []
if learnable_property == "style":
if "en" in language:
self.templates.extend(imagenet_style_templates_small)
if "zh" in language:
self.templates.extend(zh_imagenet_style_templates_small)
else:
if "en" in language:
self.templates.extend(imagenet_templates_small)
if "zh" in language:
self.templates.extend(zh_imagenet_templates_small)
self.flip_transform = RandomHorizontalFlip(prob=self.flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
text = random.choice(self.templates).format(placeholder_string)
example["input_ids"] = self.tokenizer(
text,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w = img.shape[0], img.shape[1]
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
image = Image.fromarray(img)
image = image.resize((self.width, self.height), resample=self.interpolation)
image = self.flip_transform(image)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32).transpose([2, 0, 1])
example["pixel_values"] = image
return example
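# Dataset sketch (assumes ppdiffusers is available and a tokenizer is already
# loaded; the folder name is the repository's example data path):
#
#   dataset = TextualInversionDataset(
#       data_root='resources/Alices',
#       tokenizer=tokenizer,
#       placeholder_token='<Alice>',
#       learnable_property='object',
#       language='en',
#   )
#   sample = dataset[0]   # dict with 'input_ids' and 'pixel_values'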
def main(args):
rank = paddle.distributed.get_rank()
num_processes = paddle.distributed.get_world_size()
if num_processes > 1:
paddle.distributed.init_parallel_env()
# If passed along, set the training seed now.
if args.seed is not None:
seed = args.seed + rank
set_seed(seed)
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer and add the placeholder token as an additional special token
try:
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "tokenizer"))
except:
tokenizer = XLMRobertaTokenizer.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "tokenizer"))
# Add the placeholder token in tokenizer
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
if num_added_tokens == 0:
raise ValueError(f"单词 {args.placeholder_token} 原本就已经存在了哦. 请用一个新的词汇.")
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer.encode(args.initializer_token,
add_special_tokens=False)["input_ids"]
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
# raise ValueError("The initializer token must be a single token.")
print(
f"用来初始化的 ‘最接近的单词’ 只能是一个简单词, {args.initializer_token} 不可以哟, 因此我们使用随机生成的单词!")
initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(
args.placeholder_token)
# Load models and create wrapper for stable diffusion
if args.text_encoder is None:
# pick the text encoder class that matches the pretrained model
if "Taiyi-Stable-Diffusion-1B-Chinese-v0.1" in args.pretrained_model_name_or_path:
model_cls = BertModel
elif "AltDiffusion" in args.pretrained_model_name_or_path:
model_cls = RobertaSeriesModelWithTransformation
else:
model_cls = CLIPTextModel
text_encoder = model_cls.from_pretrained(os.path.join(args.pretrained_model_name_or_path, "text_encoder"))
else:
text_encoder = args.text_encoder
if args.vae is None:
vae = AutoencoderKL.from_pretrained(args.model_name, subfolder="vae")
else:
vae = args.vae
if args.unet is None:
unet = UNet2DConditionModel.from_pretrained(args.model_name,
subfolder="unet")
else:
unet = args.unet
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
eval_scheduler = PNDMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# Initialise the newly added placeholder token with the embeddings of the initializer token
with paddle.no_grad():
token_embeds = text_encoder.get_input_embeddings()
if len(token_ids) == 1:
token_embeds.weight[placeholder_token_id] = token_embeds.weight[
initializer_token_id]
# Freeze vae and unet
freeze_params(vae.parameters())
freeze_params(unet.parameters())
# Freeze all parameters except for the token embeddings in text encoder
if isinstance(text_encoder, BertModel):
# bert text_encoder
params_to_freeze = itertools.chain(
text_encoder.encoder.parameters(),
text_encoder.pooler.parameters(),
text_encoder.embeddings.position_embeddings.parameters(),
text_encoder.embeddings.token_type_embeddings.parameters(),
text_encoder.embeddings.layer_norm.parameters(),
)
# Freeze all parameters except for the token embeddings in text encoder
elif isinstance(text_encoder, RobertaSeriesModelWithTransformation):
# roberta text_encoder
params_to_freeze = itertools.chain(
text_encoder.transformation.parameters(),
text_encoder.roberta.encoder.parameters(),
text_encoder.roberta.pooler.parameters(),
text_encoder.roberta.embeddings.position_embeddings.parameters(),
text_encoder.roberta.embeddings.token_type_embeddings.parameters(),
text_encoder.roberta.embeddings.layer_norm.parameters(),
)
else:
# clip text_encoder
params_to_freeze = itertools.chain(
text_encoder.text_model.transformer.parameters(),
text_encoder.text_model.ln_final.parameters(),
text_encoder.text_model.positional_embedding.parameters(),
)
freeze_params(params_to_freeze)
if args.scale_lr:
args.learning_rate = (args.learning_rate *
args.gradient_accumulation_steps *
args.train_batch_size * num_processes)
lr_scheduler = get_scheduler(
args.lr_scheduler,
learning_rate=args.learning_rate,
num_warmup_steps=args.lr_warmup_steps *
args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps *
args.gradient_accumulation_steps,
)
# Initialize the optimizer
optimizer = AdamW(learning_rate=lr_scheduler,
parameters=text_encoder.get_input_embeddings().parameters(),
beta1=args.adam_beta1,
beta2=args.adam_beta2,
weight_decay=args.adam_weight_decay,
epsilon=args.adam_epsilon,
grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm) if args.max_grad_norm > 0 else None)
if num_processes > 1:
text_encoder = paddle.DataParallel(text_encoder)
train_dataset = TextualInversionDataset(
data_root=args.train_data_dir,
tokenizer=tokenizer,
height=args.height,
width=args.width,
placeholder_token=args.placeholder_token,
repeats=args.repeats,
learnable_property=args.learnable_property,
center_crop=args.center_crop,
set="train",
language=args.language,
)
def collate_fn(examples):
input_ids = [example["input_ids"] for example in examples]
pixel_values = paddle.to_tensor([example["pixel_values"] for example in examples], dtype="float32")
input_ids = tokenizer.pad(
{"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pd"
).input_ids
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
train_sampler = DistributedBatchSampler(
train_dataset, batch_size=args.train_batch_size,
shuffle=True) if num_processes > 1 else BatchSampler(
train_dataset, batch_size=args.train_batch_size, shuffle=True)
train_dataloader = DataLoader(train_dataset,
batch_sampler=train_sampler,
collate_fn=collate_fn)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps /
num_update_steps_per_epoch)
if rank == 0:
writer = get_writer(args)
# Train!
total_batch_size = args.train_batch_size * args.gradient_accumulation_steps * num_processes
progress_bar = tqdm(range(args.max_train_steps), disable=rank > 0)
progress_bar.set_description("Train Steps")
global_step = 0
text_encoder_embedding_clone = unwrap_model(
text_encoder).get_input_embeddings().weight.clone()
# Keep vae and unet in eval mode as we don't train these
vae.eval()
unet.eval()
text_encoder.train()
try:
for epoch in range(args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = paddle.randn(latents.shape)
batch_size = latents.shape[0]
# Sample a random timestep for each image
timesteps = paddle.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), dtype="int64")
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the unet output
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
loss.backward()
with paddle.no_grad():
# Get the index for tokens that we want to zero the grads for
index_grads_to_zero = (paddle.arange(
len(tokenizer)) == placeholder_token_id
).astype("float32").unsqueeze(-1)
unwrap_model(text_encoder).get_input_embeddings(
).weight.grad = unwrap_model(
text_encoder).get_input_embeddings(
).weight.grad * index_grads_to_zero
if (step + 1) % args.gradient_accumulation_steps == 0:
optimizer.step()
with paddle.no_grad():
unwrap_model(text_encoder).get_input_embeddings(
).weight[:-1] = text_encoder_embedding_clone[:-1]
lr_scheduler.step()
optimizer.clear_grad()
progress_bar.update(1)
global_step += 1
logs = {
"epoch":
str(epoch).zfill(4),
"step_loss":
round(loss.item() * args.gradient_accumulation_steps,
10),
"lr":
lr_scheduler.get_lr()
}
progress_bar.set_postfix(**logs)
if rank == 0:
for name, val in logs.items():
if name == "epoch": continue
writer.add_scalar(f"train/{name}",
val,
step=global_step)
if global_step % args.save_steps == 0:
save_progress(text_encoder, placeholder_token_id,
args, global_step)
images, pil_images = generate_image(text_encoder, unet, vae, tokenizer, eval_scheduler, args)
writer.add_image("images", images, step=global_step, dataformats="NHWC")
name = args.placeholder_token
name = name.translate({
92: 95,
47: 95,
42: 95,
34: 95,
58: 95,
63: 95,
60: 95,
62: 95,
124: 95
})
image_grid(pil_images).save(os.path.join(args.output_dir, "step-"+str(global_step), f"{name}.jpg"))
if global_step >= args.max_train_steps:
break
if rank == 0:
writer.close()
save_progress(text_encoder, placeholder_token_id, args, global_step)
print(f'训练完毕, 可以用新词 {args.placeholder_token} 去生成图片了.')
del text_encoder
del optimizer
del vae
del unet
del text_encoder_embedding_clone
gc.collect()
except:
save_progress(text_encoder, placeholder_token_id, args, global_step)
del text_encoder
del optimizer
del vae
del unet
del text_encoder_embedding_clone
gc.collect()
# Code credit: 凉心半浅良心人
# Modified from the original
import os
os.environ['PPNLP_HOME'] = "./model_weights"
model_name_list = [
"Linaqruf/anything-v3.0",
"MoososCap/NOVEL-MODEL",
"Baitian/momocha",
"Baitian/momoco",
"hequanshaguo/monoko-e",
"ruisi/anything",
"hakurei/waifu-diffusion-v1-3",
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2",
"stabilityai/stable-diffusion-2-base",
"hakurei/waifu-diffusion",
"naclbit/trinart_stable_diffusion_v2_60k",
"naclbit/trinart_stable_diffusion_v2_95k",
"naclbit/trinart_stable_diffusion_v2_115k",
"ringhyacinth/nail-set-diffuser",
"Deltaadams/Hentai-Diffusion",
"BAAI/AltDiffusion",
"BAAI/AltDiffusion-m9",
"IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1",
"IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-EN-v0.1",
"huawei-noah/Wukong-Huahua"]
import time
from IPython.display import clear_output, display
from .png_info_helper import serialize_to_pnginfo, imageinfo_to_pnginfo
from .env import DEBUG_UI
if not DEBUG_UI:
from .utils import diffusers_auto_update
diffusers_auto_update()
#from tqdm.auto import tqdm
import paddle
from .textual_inversion import parse_args as textual_inversion_parse_args
from .textual_inversion import main as textual_inversion_main
from .dreambooth import parse_args as dreambooth_parse_args
from .dreambooth import main as dreambooth_main
from .utils import StableDiffusionFriendlyPipeline, SuperResolutionPipeline, diffusers_auto_update
from .utils import compute_gpu_memory, empty_cache
from .utils import save_image_info
from .convert import parse_args as convert_parse_args
from .convert import main as convert_parse_main
#_ENABLE_ENHANCE = False
if paddle.device.get_device() != 'cpu':
# settings for super-resolution, currently not supporting multi-gpus
# see docs at https://github.com/PaddlePaddle/PaddleHub/tree/develop/modules/image/Image_editing/super_resolution/falsr_a
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
pipeline_superres = SuperResolutionPipeline()
pipeline = StableDiffusionFriendlyPipeline(superres_pipeline = pipeline_superres)
else:
pipeline_superres = None
pipeline = None
####################################################################
#
# Graphics User Interface
#
####################################################################
# Code to turn kwargs into Jupyter widgets
import ipywidgets as widgets
from ipywidgets import Layout,HBox,VBox,Box
from collections import OrderedDict
# Allows long widget descriptions
style = {'description_width': 'initial'}
# Force widget width to max
layout = widgets.Layout(width='100%')
def get_widget_extractor(widget_dict):
# allows accessing after setting, this is to reduce the diff against the argparse code
class WidgetDict(OrderedDict):
def __getattr__(self, val):
x = self.get(val)
return x.value if x is not None else None
return WidgetDict(widget_dict)
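# Sketch of what get_widget_extractor() enables: widget values are read with
# attribute access, mirroring an argparse Namespace (the widget below is illustrative):
#
#   opts = get_widget_extractor({'prompt': widgets.Text(value='1girl')})
#   opts.prompt       # -> '1girl'
#   opts.missing_key  # -> None instead of raising AttributeError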
class StableDiffusionUI():
def __init__(self, pipeline = pipeline):
self.widget_opt = OrderedDict()
self.pipeline = pipeline
self.gui = None
self.run_button = None
self.run_button_out = widgets.Output()
self.task = 'txt2img'
def on_run_button_click(self, b):
with self.run_button_out:
clear_output()
self.pipeline.run(
get_widget_extractor(self.widget_opt),
task = self.task,
on_image_generated = self.on_image_generated
)
def on_image_generated(self, image, options, count = 0, total = 1, image_info = None):
# Super-resolution
# --------------------------------------------------
if self.task == 'superres':
cur_time = time.time()
os.makedirs(options.output_dir, exist_ok = True)
image_path = os.path.join(
options.output_dir,
time.strftime(f'%Y-%m-%d_%H-%M-%S_Highres.png')
)
image.save(
image_path,
quality=100,
pnginfo = imageinfo_to_pnginfo(image_info) if image_info is not None else None
)
clear_output()
display(widgets.Image.from_file(image_path))
return
# img2img / txt2img
# --------------------------------------------------
image_path = save_image_info(image, options.output_dir, image_info)
if count % 5 == 0:
clear_output()
try:
# display the saved file so that the shown image keeps its embedded info
display(widgets.Image.from_file(image_path))
except:
display(image)
if 'seed' in image.argument:
print('Seed = ', image.argument['seed'],
' (%d / %d ... %.2f%%)'%(count + 1, total, (count + 1.) / total * 100))
####################################################################
#
# Training
#
####################################################################
class StableDiffusionTrainUI():
def __init__(self, pipeline = pipeline):
self.widget_opt = OrderedDict()
self.gui = None
self.run_button = None
self.run_button_out = widgets.Output()
self.pipeline = pipeline
# function pointers
#self.parse_args = None #Test
self.main = None
def run(self, opt):
args = self.parse_args()
for k, v in opt.items():
setattr(args, k, v.value)
self.pipeline.from_pretrained(model_name=opt.model_name)
# todo junnyu
args.pretrained_model_name_or_path = opt.model_name
if args.language == "en":
if "chinese-en" in args.pretrained_model_name_or_path.lower():
args.language = "zh_en"
elif "chinese" in args.pretrained_model_name_or_path.lower():
args.language = "zh"
if args.image_logging_prompt is None:
args.image_logging_prompt = args.placeholder_token
## done
args.text_encoder = self.pipeline.pipe.text_encoder
args.unet = self.pipeline.pipe.unet
args.vae = self.pipeline.pipe.vae
if compute_gpu_memory() <= 17. or args.height==768 or args.width==768:
args.train_batch_size = 1
args.gradient_accumulation_steps = 4
else:
args.train_batch_size = 4
args.gradient_accumulation_steps = 1
if args.pretrained_model_name_or_path in ["stabilityai/stable-diffusion-2", "stabilityai/stable-diffusion-2-base", "BAAI/AltDiffusion", "BAAI/AltDiffusion-m9"]:
args.train_batch_size = 1
args.gradient_accumulation_steps = 4
if args.train_data_dir is None:
raise ValueError("You must specify a train data directory.")
# remove \/"*:?<>| in filename
name = args.placeholder_token
name = name.translate({92: 95, 47: 95, 42: 95, 34: 95, 58: 95, 63: 95, 60: 95, 62: 95, 124: 95})
args.logging_dir = os.path.join(args.output_dir, 'logs', name)
self.main(args)
empty_cache()
def on_run_button_click(self, b):
with self.run_button_out:
clear_output()
self.run(get_widget_extractor(self.widget_opt))
class StableDiffusionUI_text_inversion(StableDiffusionTrainUI):
def __init__(self, **kwargs):
super().__init__()
self.parse_args = textual_inversion_parse_args
self.main = textual_inversion_main
# Default parameter override order:
# user_config.py > config.py > current args > instantiation defaults
args = { # note: invalid keys raise errors
"learnable_property": 'object',
"placeholder_token": '<Alice>',
"initializer_token": 'girl',
"repeats": '100',
"train_data_dir": 'resources/Alices',
"output_dir": 'outputs/textual_inversion',
"height": 512,
"width": 512,
"learning_rate": 5e-6,
"max_train_steps": 1000,
"save_steps": 200,
"model_name": "MoososCap/NOVEL-MODEL",
}
args.update(kwargs)
layoutCol12 = Layout(
flex = "12 12 90%",
margin = "0.5em",
max_width = "100%",
align_items = "center"
)
styleDescription = {
'description_width': "10rem"
}
widget_opt = self.widget_opt
widget_opt['learnable_property'] = widgets.Dropdown(
layout=layoutCol12, style=styleDescription,
description='训练目标',
description_tooltip='训练目标是什么?风格还是实体?',
value="object",
options=[
('风格(style)', "style"),
('实体(object)', "object"),
],
orientation='horizontal',
disabled=False
)
widget_opt['placeholder_token'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='用来表示该内容的新词',
description_tooltip='用来表示该内容的新词,建议用<>封闭',
value="<Alice>",
disabled=False
)
widget_opt['initializer_token'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='该内容最接近的单词是',
description_tooltip='该内容最接近的单词是?若无则用*表示',
value="girl",
disabled=False
)
widget_opt['repeats'] = widgets.IntText(
layout=layoutCol12, style=styleDescription,
description='图片重复次数',
description_tooltip='训练图片需要重复多少遍',
value="100",
disabled=False
)
widget_opt['train_data_dir'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='训练图片的文件夹路径',
value="resources/Alices",
disabled=False
)
widget_opt['output_dir'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='训练结果的保存路径',
value="outputs/textual_inversion",
disabled=False
)
widget_opt['height'] = widgets.IntSlider(
layout=layoutCol12, style=styleDescription,
description='训练图片的高度',
description_tooltip='训练图片的高度。越大尺寸,消耗的显存也越多。',
value=512,
min=64,
max=1024,
step=64,
disabled=False
)
widget_opt['width'] = widgets.IntSlider(
layout=layoutCol12, style=styleDescription,
description='训练图片的宽度',
description_tooltip='训练图片的宽度。越大尺寸,消耗的显存也越多。',
value=512,
min=64,
max=1024,
step=64,
disabled=False
)
widget_opt['learning_rate'] = widgets.FloatText(
layout=layoutCol12, style=styleDescription,
description='训练学习率',
description_tooltip='训练学习率',
value=5e-4,
step=1e-4,
disabled=False
)
widget_opt['max_train_steps'] = widgets.IntText(
layout=layoutCol12, style=styleDescription,
description='最大训练步数',
description_tooltip='最大训练步数',
value=1000,
step=100,
disabled=False
)
widget_opt['save_steps'] = widgets.IntText(
layout=layoutCol12, style=styleDescription,
description='每隔多少步保存模型',
value=200,
step=100,
disabled=False
)
widget_opt['model_name'] = widgets.Combobox(
layout=layoutCol12, style=styleDescription,
description='需要训练的模型名称',
value="MoososCap/NOVEL-MODEL",
options=model_name_list,
ensure_option=False,
disabled=False
)
for key in widget_opt:
if (key in args) and (args[key] != widget_opt[key].value):
widget_opt[key].value = args[key]
self.run_button = widgets.Button(
description='开始训练',
disabled=False,
button_style='success', # 'success', 'info', 'warning', 'danger' or ''
tooltip='点击运行(配置将自动更新)',
icon='check'
)
self.run_button.on_click(self.on_run_button_click)
self.gui = Box([
Box([
widget_opt['learnable_property'],
widget_opt['placeholder_token'],
widget_opt['initializer_token'],
widget_opt['train_data_dir'],
widget_opt['width'],
widget_opt['height'],
widget_opt['repeats'],
widget_opt['learning_rate'],
widget_opt['max_train_steps'],
widget_opt['save_steps'],
widget_opt['model_name'],
widget_opt['output_dir'],
], layout = Layout(
display = "flex",
flex_flow = "row wrap", #HBox会覆写此属性
align_items = "center",
max_width = '100%',
)),
self.run_button,
self.run_button_out
], layout = Layout(display="block",margin="0 45px 0 0")
)
#Dreambooth训练
class StableDiffusionDreamboothUI():
def __init__(self, pipeline = pipeline):
self.widget_opt = OrderedDict()
self.gui = None
self.run_button = None
self.run_button_out = widgets.Output()
self.pipeline = pipeline
# function pointers
self.parse_args = None
self.main = None
def run(self, opt):
args = self.parse_args()
for k, v in opt.items():
setattr(args, k, v.value)
args.train_batch_size = 1
args.gradient_accumulation_steps = 1
if args.instance_data_dir is None:
raise ValueError("You must specify a train data directory.")
# remove \/"*:?<>| in filename
name = args.instance_prompt
name = name.translate({92: 95, 47: 95, 42: 95, 34: 95, 58: 95, 63: 95, 60: 95, 62: 95, 124: 95})
args.logging_dir = os.path.join(args.output_dir, 'logs', name)
self.main(args)
empty_cache()
def on_run_button_click(self, b):
with self.run_button_out:
clear_output()
self.run(get_widget_extractor(self.widget_opt))
class StableDiffusionUI_dreambooth(StableDiffusionDreamboothUI):
def __init__(self, **kwargs):
super().__init__()
self.parse_args = dreambooth_parse_args #配置加载
self.main = dreambooth_main
args = { #注意无效Key错误
"pretrained_model_name_or_path": "MoososCap/NOVEL-MODEL",# 预训练模型名称/路径
"instance_data_dir": 'resources/Alices',
"instance_prompt": 'a photo of Alices',
"class_data_dir": 'resources/Girls',
"class_prompt": 'a photo of girl',
"num_class_images": 100,
"prior_loss_weight": 1.0,
"with_prior_preservation": True,
#"num_train_epochs": 1,
"max_train_steps": 1000,
"save_steps": 1000,
"train_text_encoder": False,
"height": 512,
"width": 512,
"learning_rate": 5e-4,
"lr_scheduler": "constant",
"lr_warmup_steps": 500,
"center_crop": True,
"output_dir": 'outputs/dreambooth',
}
args.update(kwargs)
layoutCol12 = Layout(
flex = "12 12 90%",
margin = "0.5em",
max_width = "100%",
align_items = "center"
)
styleDescription = {
'description_width': "10rem"
}
widget_opt = self.widget_opt
widget_opt['pretrained_model_name_or_path'] = widgets.Combobox(
layout=layoutCol12, style=styleDescription,
description='需要训练的模型名称',
value="MoososCap/NOVEL-MODEL",
options=model_name_list,
ensure_option=False,
disabled=False
)
widget_opt['instance_data_dir'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='实例(物体)图片文件夹地址。',
description_tooltip='你要训练的特殊图片目录(人物,背景等)',
value="resources/Alices",
disabled=False
)
widget_opt['instance_prompt'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='实例(物体)的提示词描述文本',
description_tooltip='带有特定实例 物体的提示词描述文本例如『a photo of sks dog』其中dog代表实例物体。',
value="a photo of Alices",
disabled=False
)
widget_opt['class_data_dir'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='类别 class 图片文件夹地址',
description_tooltip='类别 class 图片文件夹地址,这个文件夹里可以不放东西,会自动生成',
value="resources/Girls",
disabled=False
)
widget_opt['class_prompt'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='类别class提示词文本',
description_tooltip='该提示器要与实例物体是同一种类别 例如『a photo of dog』',
value="a photo of girl",
disabled=False
)
widget_opt['num_class_images'] = widgets.IntText(
layout=layoutCol12, style=styleDescription,
description='类别class提示词对应图片数',
description_tooltip='如果文件夹里图片不够会自动补全',
value=100,
disabled=False
)
widget_opt['with_prior_preservation'] = widgets.Dropdown(
layout=layoutCol12, style=styleDescription,
description='是否将生成的同类图片(先验知识)一同加入训练',
description_tooltip='当开启的时候 上面的设置才生效。',
value=True,
options=[
('开启', True),
('关闭', False),
],
disabled=False
)
widget_opt['prior_loss_weight'] = widgets.FloatText(
layout=layoutCol12, style=styleDescription,
description='先验loss占比权重',
description_tooltip='不用改',
value=1.0,
disabled=False
)
'''widget_opt['num_train_epochs'] = widgets.IntText(
layout=layoutCol12, style=styleDescription,
description='训练的轮数',
description_tooltip='与最大训练步数互斥',
value=1,
disabled=False
)'''
widget_opt['max_train_steps'] = widgets.IntText(
layout=layoutCol12, style=styleDescription,
description='最大训练步数',
description_tooltip='当我们设置这个值后它会重新计算所需的轮数',
value=1000,
disabled=False
)
widget_opt['save_steps'] = widgets.IntText(
layout=layoutCol12, style=styleDescription,
description='模型保存步数',
description_tooltip='达到这个数后会保存模型',
value=1000,
disabled=False
)
widget_opt['train_text_encoder'] = widgets.Dropdown(
layout=layoutCol12, style=styleDescription,
description='是否一同训练文本编码器的部分',
description_tooltip='可以理解为是否同时训练textual_inversion',
value=False,
options=[
('开启', True),
('关闭', False),
],
disabled=False
)
widget_opt['height'] = widgets.IntSlider(
layout=layoutCol12, style=styleDescription,
description='训练图片的高度',
description_tooltip='训练图片的高度。越大尺寸,消耗的显存也越多。',
value=512,
min=64,
max=1024,
step=64,
disabled=False
)
widget_opt['width'] = widgets.IntSlider(
layout=layoutCol12, style=styleDescription,
description='训练图片的宽度',
description_tooltip='训练图片的宽度。越大尺寸,消耗的显存也越多。',
value=512,
min=64,
max=1024,
step=64,
disabled=False
)
widget_opt['learning_rate'] = widgets.FloatText(
layout=layoutCol12, style=styleDescription,
description='训练学习率',
description_tooltip='训练学习率',
value=5e-6,
step=1e-6,
disabled=False
)
widget_opt['lr_scheduler'] = widgets.Dropdown(
layout=layoutCol12, style=styleDescription,
description='学习率调度策略',
description_tooltip='可以选不同的学习率调度策略',
value='constant',
options=[
('linear', "linear"),
('cosine', "cosine"),
('cosine_with_restarts', "cosine_with_restarts"),
('polynomial', "polynomial"),
('constant', "constant"),
('constant_with_warmup', "constant_with_warmup"),
],
disabled=False
)
widget_opt['lr_warmup_steps'] = widgets.IntText(
layout=layoutCol12, style=styleDescription,
description='线性 warmup 的步数',
description_tooltip='用于从 0 到 learning_rate 的线性 warmup 的步数。',
value=500,
disabled=False
)
widget_opt['center_crop'] = widgets.Dropdown(
layout=layoutCol12, style=styleDescription,
description='自动裁剪图片时将人像居中',
description_tooltip='自动裁剪图片时将人像居中',
value=False,
options=[
('开启', True),
('关闭', False),
],
disabled=False
)
widget_opt['output_dir'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='输出目录',
description_tooltip='训练好模型输出的地方',
value="outputs/dreambooth",
disabled=False
)
for key in widget_opt:
if (key in args) and (args[key] != widget_opt[key].value):
widget_opt[key].value = args[key]
self.run_button = widgets.Button(
description='开始训练',
disabled=False,
button_style='success', # 'success', 'info', 'warning', 'danger' or ''
tooltip='点击运行(配置将自动更新)',
icon='check'
)
self.run_button.on_click(self.on_run_button_click)
self.gui = Box([
Box([
widget_opt['pretrained_model_name_or_path'],
widget_opt['instance_data_dir'],
widget_opt['instance_prompt'],
widget_opt['class_data_dir'],
widget_opt['class_prompt'],
widget_opt['num_class_images'],
widget_opt['prior_loss_weight'],
widget_opt['with_prior_preservation'],
#widget_opt['num_train_epochs'],
widget_opt['max_train_steps'],
widget_opt['save_steps'],
widget_opt['train_text_encoder'],
widget_opt['height'],
widget_opt['width'],
widget_opt['learning_rate'],
widget_opt['lr_scheduler'],
widget_opt['lr_warmup_steps'],
widget_opt['center_crop'],
widget_opt['output_dir'],
], layout = Layout(
display = "flex",
flex_flow = "row wrap", #HBox会覆写此属性
align_items = "center",
max_width = '100%',
)),
self.run_button,
self.run_button_out
], layout = Layout(display="block",margin="0 45px 0 0")
)
####################################################################
#
# Model Convert
#
####################################################################
class StableDiffusionConvertUI():
def __init__(self, pipeline = pipeline):
self.widget_opt = OrderedDict()
self.gui = None
self.run_button = None
self.run_button_out = widgets.Output()
self.pipeline = pipeline
# function pointers
self.parse_args = None
self.main = None
def run(self, opt):
args = self.parse_args()
for k, v in opt.items():
setattr(args, k, v.value)
if args.checkpoint_path is None or len(str(args.checkpoint_path).strip()) == 0:
raise ValueError("你必须给出一个可用的ckpt模型路径")
self.main(args)
empty_cache()
def on_run_button_click(self, b):
with self.run_button_out:
clear_output()
self.run(get_widget_extractor(self.widget_opt))
class StableDiffusionUI_convert(StableDiffusionConvertUI):
def __init__(self, **kwargs):
super().__init__()
self.parse_args = convert_parse_args #配置加载
self.main = convert_parse_main
args = { #注意无效Key错误
"checkpoint_path": '',
"vae_checkpoint_path": '',
"extract_ema": False,
'dump_path': 'outputs/convert'
}
args.update(kwargs)
layoutCol12 = Layout(
flex = "12 12 90%",
margin = "0.5em",
max_width = "100%",
align_items = "center"
)
styleDescription = {
'description_width': "10rem"
}
widget_opt = self.widget_opt
widget_opt['checkpoint_path'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='ckpt或safetensors模型文件位置',
description_tooltip='你要转换的模型位置',
value=" ",
disabled=False
)
widget_opt['extract_ema'] = widgets.Dropdown(
layout=layoutCol12, style=styleDescription,
description='是否提取ema权重',
description_tooltip='是否提取ema权重',
value=False,
options=[
('是', True),
('否', False),
],
orientation='horizontal',
disabled=False
)
widget_opt['vae_checkpoint_path'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='vae文件位置',
description_tooltip='你要转换的vae模型位置',
value=" ",
disabled=False
)
widget_opt['dump_path'] = widgets.Text(
layout=layoutCol12, style=styleDescription,
description='输出目录',
description_tooltip='转换模型输出的地方',
value="outputs/convert",
disabled=False
)
for key in widget_opt:
if (key in args) and (args[key] != widget_opt[key].value):
widget_opt[key].value = args[key]
self.run_button = widgets.Button(
description='开始转换',
disabled=False,
button_style='success', # 'success', 'info', 'warning', 'danger' or ''
tooltip='点击运行(配置将自动更新)',
icon='check'
)
self.run_button.on_click(self.on_run_button_click)
self.gui = Box([
Box([
widget_opt['checkpoint_path'],
widget_opt['extract_ema'],
widget_opt['vae_checkpoint_path'],
widget_opt['dump_path'],
], layout = Layout(
display = "flex",
flex_flow = "row wrap", #HBox会覆写此属性
align_items = "center",
max_width = '100%',
)),
self.run_button,
self.run_button_out
], layout = Layout(display="block",margin="0 45px 0 0")
)
#pt加载功能基于群内@作者版本修改
import os
import time
from contextlib import nullcontext, contextmanager
from IPython.display import clear_output, display
from pathlib import Path
from PIL import Image
from .png_info_helper import serialize_to_text, serialize_to_pnginfo
import paddle
_VAE_SIZE_THRESHOLD_ = 300000000 # vae should not be smaller than this
_MODEL_SIZE_THRESHOLD_ = 3000000000 # model should not be smaller than this
def compute_gpu_memory():
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
return round(meminfo.total / 1024 / 1024 / 1024, 2)
def empty_cache():
"""Free Python-side references between generations (only runs the garbage collector)."""
import gc
gc.collect()
def check_is_model_complete(path = None, check_vae_size=_VAE_SIZE_THRESHOLD_):
"""Auto check whether a model is complete by checking the size of vae > check_vae_size.
The vae of the model should be named by model_state.pdparams."""
path = path or os.path.join('./',os.path.basename(model_get_default())).rstrip('.zip')
return os.path.exists(os.path.join(path,'vae/model_state.pdparams')) and\
os.path.getsize(os.path.join(path,'vae/model_state.pdparams')) > check_vae_size
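# A model folder counts as "complete" when vae/model_state.pdparams exists and is larger than
# check_vae_size (default _VAE_SIZE_THRESHOLD_, ~300 MB); a partially extracted zip fails this check.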
def model_get_default(base_path = '/home/aistudio/data'):
"""Return an absolute path of model zip file in the `base_path`."""
available_models = []
for folder in os.walk(base_path):
for filename_ in folder[2]:
filename = os.path.join(folder[0], filename_)
if filename.endswith('.zip') and os.path.isfile(filename) and os.path.getsize(filename) > _MODEL_SIZE_THRESHOLD_:
available_models.append((os.path.getsize(filename), filename, filename_))
available_models.sort()
# use the model with smallest size to save computation
return available_models[0][1]
def model_vae_get_default(base_path = 'data'):
"""Return an absolute path of extra vae if there is any."""
for folder in os.walk(base_path):
for filename_ in folder[2]:
filename = os.path.join(folder[0], filename_)
if filename.endswith('vae.pdparams'):
return filename
return None
def model_unzip(abs_path = None, name = None, dest_path = './', verbose = True):
"""Unzip a model from `abs_path`, `name` is the model name after unzipping."""
if abs_path is None:
abs_path = model_get_default()
if name is None:
name = os.path.basename(abs_path)
from zipfile import ZipFile
dest = os.path.join(dest_path, name).rstrip('.zip')
if not check_is_model_complete(dest):
if os.path.exists(dest):
# clear the incomplete zipfile
if verbose: print('检测到模型文件破损, 正在删除......')
import shutil
shutil.rmtree(dest)
if verbose: print('正在解压模型......')
with ZipFile(abs_path, 'r') as f:
f.extractall(dest_path)
else:
print('模型已存在')
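# Usage sketch (paths below are examples only):
#   model_unzip()                                   # auto-detect the zip via model_get_default()
#   model_unzip('data/novel.zip', dest_path='./')   # extract a specific archive into the working dir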
def package_install(verbose = True):
try:
import safetensors
from ppdiffusers.utils import image_grid
from paddlenlp.transformers.clip.feature_extraction import CLIPFeatureExtractor
from paddlenlp.transformers import FeatureExtractionMixin
except (ModuleNotFoundError, ImportError, AttributeError):
if verbose: print('检测到库不完整, 正在安装库')
os.system("pip install -U pip -i https://mirror.baidu.com/pypi/simple")
os.system("pip install -U OmegaConf --user")
os.system("pip install ppdiffusers==0.9.0 --user")
os.system("pip install paddlenlp==2.4.9 --user")
os.system("pip install -U safetensors --user")
clear_output()
def diffusers_auto_update(verbose = True):
package_install(verbose=verbose)
def try_get_catched_model(model_name):
path = os.path.join('./models/', model_name)
if check_is_model_complete(path):
return path
path = os.path.join('./', model_name)
if check_is_model_complete(path):
return path
return model_name
@contextmanager
def context_nologging():
import logging
logging.disable(100)
try:
yield
finally:
logging.disable(30)
def save_image_info(image, path = './outputs/', existing_info = None):
"""Save image to a path with arguments."""
os.makedirs(path, exist_ok=True)
seed = image.argument['seed']
filename = time.strftime(f'%Y-%m-%d_%H-%M-%S_SEED_{seed}')
pnginfo_data = serialize_to_pnginfo(image.argument, existing_info)
info_text = serialize_to_text(image.argument)
with open(os.path.join(path, filename + '.txt'), 'w') as f:
f.write('Prompt: '+ info_text)
image_path = os.path.join(path, filename + '.png')
image.save(image_path,
quality=100,
pnginfo=pnginfo_data
)
return image_path
def ReadImage(image, height = None, width = None):
"""
Read an image and resize it to (height,width) if given.
If (height,width) = (-1,-1), resize it so that
it has w,h being multiples of 64 and in medium size.
"""
if isinstance(image, str):
image = Image.open(image).convert('RGB')
# clever auto inference of image size
w, h = image.size
if height == -1 or width == -1:
if w > h:
width = 768
height = max(64, round(width / w * h / 64) * 64)
else: # w < h
height = 768
width = max(64, round(height / h * w / 64) * 64)
if width > 576 and height > 576:
width = 576
height = 576
if (height is not None) and (width is not None) and (w != width or h != height):
image = image.resize((width, height), Image.ANTIALIAS)
return image
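# e.g. ReadImage('resources/upload.png', height=-1, width=-1) returns an RGB PIL image whose sides are
# multiples of 64 (long side 768, near-square results capped at 576x576); passing height=None, width=None
# keeps the original size.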
def convert_pt_to_pdparams(path, dim = 768, save = True):
"""Unsafe .pt embedding to .pdparams."""
path = str(path)
assert path.endswith('.pt'), 'Only support conversion of .pt files.'
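# Heuristic used below: a torch .pt file is a zip archive; instead of unpickling it (which could run
# arbitrary code), scan the raw bytes between 'PK' entry headers, keep the longest candidate whose
# length is a multiple of dim*4 bytes, and reinterpret it as float32 rows of width `dim`.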
import struct
with open(path, 'rb') as f:
data = f.read()
data = ''.join(chr(i) for i in data)
# locate the tensor in the file
tensors = []
for chunk in data.split('ZZZZ'):
chunk = chunk.strip('Z').split('PK')
if len(chunk) == 0:
continue
tensor = ''
for i in range(len(chunk)):
# extract the string with length 768 * (4k)
tensor += 'PK' + chunk[i]
if len(tensor) > 2 and len(tensor) % (dim * 4) == 2:
# remove the leading 'PK'
tensors.append(tensor[2:])
tensor = max(tensors, key = lambda x: len(x))
# convert back to binary representation
tensor = tensor.encode('latin')
# every four chars represent a float32
tensor = [struct.unpack('f', tensor[i:i+4])[0] for i in range(0, len(tensor), 4)]
tensor = paddle.to_tensor(tensor).reshape((-1, dim))
if tensor.shape[0] == 1:
tensor = tensor.flatten()
if save:
# locate the name of embedding
name = ''.join(filter(lambda x: ord(x) > 20, data.split('nameq\x12X')[1].split('q\x13X')[0]))
paddle.save({name: tensor}, path[:-3] + '.pdparams')
return tensor
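# e.g. convert_pt_to_pdparams('concepts/char.pt') writes concepts/char.pdparams next to the source
# file and returns the embedding tensor ('concepts/char.pt' is only an illustrative path).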
def get_multiple_tokens(token, num = 1, ret_list = True):
"""Parse a single token to multiple tokens."""
tokens = ['%s_EMB_TOKEN_%d'%(token, i) for i in range(num)]
if ret_list:
return tokens
return ' '.join(tokens)
def collect_local_module_names(base_paths = None):
"""从指定位置检索可用的模型名称,以用于UI选择模型"""
base_paths = (
'./',
'./models',
os.path.join(os.environ['PPNLP_HOME'], 'models')
) if base_paths is None \
else (base_paths,) if isinstance(base_paths, str) \
else base_paths
is_model = lambda base, name: (os.path.isfile(
os.path.join(base, name,'model_index.json')
)) and (os.path.isfile(
os.path.join(base, name,'vae', 'config.json')
)) and (os.path.isfile(
os.path.join(base, name,'unet', 'model_state.pdparams')
))
models = []
for base_path in base_paths:
if not os.path.isdir(base_path): continue
for name in os.listdir(base_path):
if name.startswith('.'): continue
path = os.path.join(base_path, name)
if path in base_paths: continue
if not os.path.isdir(path): continue
if is_model(base_path, name):
models.append(name)
continue
for name2 in os.listdir(path):
if name2.startswith('.'): continue
path2 = os.path.join(path, name2)
if os.path.isdir(path2) and is_model(path, name2):
models.append(f'{name}/{name2}')
continue
models.sort()
return models
class StableDiffusionFriendlyPipeline():
def __init__(self, model_name = "runwayml/stable-diffusion-v1-5", superres_pipeline = None):
self.pipe = None
# model
self.model = model_name
# vae
self.vae = None
# schedulers
self.available_schedulers = {}
# super-resolution
self.superres_pipeline = superres_pipeline
self.added_tokens = []
def from_pretrained(self, verbose = True, force = False, model_name=None):
if model_name is not None:
if len(model_name.strip()) == 0:
print("!!!!!检测出模型名称为空,我们将默认使用 MoososCap/NOVEL-MODEL")
model_name = "MoososCap/NOVEL-MODEL"
if model_name != self.model.strip():
print(f"!!!!!正在切换新模型, {model_name}")
self.model = model_name.strip()
force=True
model = self.model
if (not force) and self.pipe is not None:
return
if verbose: print('!!!!!正在加载模型, 请耐心等待, 如果出现两行红字是正常的, 不要惊慌!!!!!')
_ = paddle.zeros((1,)) # activate the paddle on CUDA
with context_nologging():
from .pipeline_stable_diffusion_all_in_one import StableDiffusionPipelineAllinOne
self.pipe = StableDiffusionPipelineAllinOne.from_pretrained(model, safety_checker = None, requires_safety_checker=False)
# update scheduler
scheduler = self.pipe.scheduler
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
from ppdiffusers.configuration_utils import FrozenDict
scheduler._internal_dict = FrozenDict(new_config)
self.pipe.register_modules(scheduler=scheduler)
self.available_schedulers = {}
self.available_schedulers['default'] = scheduler
# schedulers
from ppdiffusers import KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler, HeunDiscreteScheduler
self.available_schedulers.update({
"DPMSolver": DPMSolverMultistepScheduler.from_pretrained(
model_name,
subfolder="scheduler",
thresholding=False,
algorithm_type="dpmsolver++",
solver_type="midpoint",
lower_order_final=True,
),
"EulerDiscrete": EulerDiscreteScheduler.from_config(model_name, subfolder="scheduler"),
'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler"),
'PNDM': PNDMScheduler.from_pretrained(model_name, subfolder="scheduler"),
'DDIM': DDIMScheduler.from_pretrained(model_name, subfolder="scheduler", clip_sample=False),
'LMSDiscrete' : LMSDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler"),
'HeunDiscrete': HeunDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler"),
'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler"),
'KDPM2Discrete': KDPM2DiscreteScheduler.from_pretrained(model_name, subfolder="scheduler"),
})
if verbose: print('成功加载完毕, 若默认设置无法生成, 请停止项目等待保存完毕选择GPU重新进入')
def load_concepts(self, opt):
added_tokens = []
is_exist_concepts_library_dir = False
original_dtype = None
has_updated = False
if opt.concepts_library_dir is not None:
file_paths = None
path = Path(opt.concepts_library_dir)
if path.exists():
#file_paths = path.glob("*.pdparams")
file_paths = [p for p in path.glob("*.pdparams")]
# conversion of .pt -> .pdparams embedding
pt_files = path.glob("*.pt")
for pt_file in pt_files:
try:
convert_pt_to_pdparams(pt_file, dim = 768, save = True)
except:
pass
if opt.concepts_library_dir.endswith('.pdparams') and os.path.exists(opt.concepts_library_dir):
# load single file
file_paths = [opt.concepts_library_dir]
if file_paths is not None and len(file_paths)>0:
is_exist_concepts_library_dir = True
# load the token safely in float32
original_dtype = self.pipe.text_encoder.dtype
self.pipe.text_encoder = self.pipe.text_encoder.to(dtype = 'float32')
self.added_tokens = []
for p in file_paths:
for token, embeds in paddle.load(str(p)).items():
added_tokens.append(token)
if embeds.dim() == 1:
embeds = embeds.reshape((1, -1))
tokens = get_multiple_tokens(token, embeds.shape[0], ret_list = True)
self.added_tokens.append((token, ' '.join(tokens)))
for token, embed in zip(tokens, embeds):
self.pipe.tokenizer.add_tokens(token)
self.pipe.text_encoder.resize_token_embeddings(len(self.pipe.tokenizer))
token_id = self.pipe.tokenizer.convert_tokens_to_ids(token)
with paddle.no_grad():
if paddle.max(paddle.abs(self.pipe.text_encoder.get_input_embeddings().weight[token_id] - embed)) > 1e-4:
# only add / update new token if it has changed
has_updated = True
self.pipe.text_encoder.get_input_embeddings().weight[token_id] = embed
if is_exist_concepts_library_dir:
if has_updated and len(added_tokens):
str_added_tokens = ", ".join(added_tokens)
print(f"[导入训练文件] 成功加载了这些新词: {str_added_tokens} ")
else:
print(f"[导入训练文件] {opt.concepts_library_dir} 文件夹下没有发现任何文件,跳过加载!")
if self.added_tokens:
#self_str_added_tokens = ", ".join((self.added_tokens))
str_added_tokens = ", ".join(added_tokens)
print(f"[支持的'风格'或'人物'单词]: {str_added_tokens} ")
if original_dtype is not None:
self.pipe.text_encoder = self.pipe.text_encoder.to(dtype = original_dtype)
def run(self, opt, task = 'txt2img', on_image_generated = None):
model_name = try_get_catched_model(opt.model_name)
self.from_pretrained(model_name=model_name)
self.load_concepts(opt)
seed = None if opt.seed == -1 else opt.seed
# switch scheduler
self.pipe.scheduler = self.available_schedulers[opt.sampler]
task_func = None
# process prompts
enable_parsing = False
prompt = opt.prompt
negative_prompt = opt.negative_prompt
if '{}' in opt.enable_parsing:
enable_parsing = True
# convert {} to ()
prompt = prompt.translate({40:123, 41:125, 123:40, 125:41})
negative_prompt = negative_prompt.translate({40:123, 41:125, 123:40, 125:41})
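# the table above swaps '(' <-> '{' and ')' <-> '}', so {}-weighted prompts are parsed as ()-weighted
# ones while literal parentheses in the prompt survive as braces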
elif '()' in opt.enable_parsing:
enable_parsing = True
for token in self.added_tokens:
prompt = prompt.replace(token[0], token[1])
negative_prompt = negative_prompt.replace(token[0], token[1])
init_image = None
mask_image = None
if task == 'txt2img':
def task_func():
return self.pipe.text2image(
prompt, seed=seed,
width=opt.width,
height=opt.height,
guidance_scale=opt.guidance_scale,
num_inference_steps=opt.num_inference_steps,
negative_prompt=negative_prompt,
max_embeddings_multiples=int(opt.max_embeddings_multiples),
skip_parsing=(not enable_parsing)
).images[0]
elif task == 'img2img':
init_image = ReadImage(opt.image_path, height=opt.height, width=opt.width)
def task_func():
return self.pipe.img2img(
prompt, seed=seed,
image=init_image,
num_inference_steps=opt.num_inference_steps,
strength=opt.strength,
guidance_scale=opt.guidance_scale,
negative_prompt=negative_prompt,
max_embeddings_multiples=int(opt.max_embeddings_multiples),
skip_parsing=(not enable_parsing)
)[0][0]
elif task == 'inpaint':
init_image = ReadImage(opt.image_path, height=opt.height, width=opt.width)
mask_image = ReadImage(opt.mask_path, height=opt.height, width=opt.width)
def task_func():
return self.pipe.inpaint(
prompt, seed=seed,
image=init_image,
mask_image=mask_image,
num_inference_steps=opt.num_inference_steps,
strength=opt.strength,
guidance_scale=opt.guidance_scale,
negative_prompt=negative_prompt,
max_embeddings_multiples=int(opt.max_embeddings_multiples),
skip_parsing=(not enable_parsing)
)[0][0]
if opt.fp16 == 'float16' and opt.sampler != "LMSDiscrete":
context = paddle.amp.auto_cast(True, level = 'O2') # level = 'O2' # seems to have BUG if enable O2
else:
context = nullcontext()
image_info = init_image.info if init_image is not None else None
with context:
for i in range(opt.num_return_images):
empty_cache()
image = task_func()
image.argument['sampler'] = opt.sampler
# super resolution
if (self.superres_pipeline is not None):
argument = image.argument
argument['superres_model_name'] = opt.superres_model_name
image = self.superres_pipeline.run(opt, image = image, end_to_end = False)
image.argument = argument
if task == 'img2img':
image.argument['init_image'] = opt.image_path
elif task == 'inpaint':
image.argument['init_image'] = opt.image_path
image.argument['mask_path'] = opt.mask_path
image.argument['model_name'] = opt.model_name
if on_image_generated is not None:
on_image_generated(
image = image,
options = opt,
count = i,
total = opt.num_return_images,
image_info = image_info,
)
continue
save_image_info(image, opt.output_dir,image_info)
if i % 50 == 0:
clear_output()
display(image)
print('Seed = ', image.argument['seed'],
' (%d / %d ... %.2f%%)'%(i + 1, opt.num_return_images, (i + 1.) / opt.num_return_images * 100))
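# Usage sketch (opt is the object produced by get_widget_extractor(widget_opt); my_callback is hypothetical):
#   pipeline = StableDiffusionFriendlyPipeline(superres_pipeline=SuperResolutionPipeline())
#   pipeline.run(opt, task='txt2img', on_image_generated=my_callback)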
class SuperResolutionPipeline():
def __init__(self):
self.model = None
self.model_name = ''
def run(self, opt,
image = None,
task = 'superres',
end_to_end = True,
force_empty_cache = True,
on_image_generated = None,
):
"""
end_to_end: return PIL image if False, display in the notebook and autosave otherwise
force_empty_cache: force clear the GPU cache by deleting the model
"""
if opt.superres_model_name is None or opt.superres_model_name in ('','无'):
return image
import numpy as np
if image is None:
image = ReadImage(opt.image_path, height=None, width=None) # avoid resizing
image_info = image.info
image = np.array(image)
image = image[:,:,[2,1,0]] # RGB -> BGR
empty_cache()
if self.model_name != opt.superres_model_name:
if self.model is not None:
del self.model
with context_nologging():
# [ WARNING] - The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object
import paddlehub as hub
# print('正在加载超分模型! 如果出现两三行红字是正常的, 不要担心哦!')
self.model = hub.Module(name = opt.superres_model_name)
self.model_name = opt.superres_model_name
# time.sleep(.1) # wait until the warning prints
# print('正在超分......请耐心等待')
try:
image = self.model.reconstruct([image], use_gpu = (paddle.device.get_device() != 'cpu'))[0]['data']
except:
print('图片尺寸过大, 超分时超过显存限制')
self.empty_cache(force_empty_cache)
paddle.disable_static()
return
image = image[:,:,[2,1,0]] # BGR -> RGB
image = Image.fromarray(image)
self.empty_cache(force_empty_cache)
paddle.disable_static()
if on_image_generated is not None:
on_image_generated(
image = image,
options = opt,
count = 0,
total = 1,
image_info = image_info,
)
return
if end_to_end:
cur_time = time.time()
os.makedirs(opt.output_dir, exist_ok = True)
image.save(os.path.join(opt.output_dir,f'Highres_{cur_time}.png'), quality=100)
clear_output()
display(image)
return
return image
def empty_cache(self, force = True):
# NOTE: it seems that ordinary method cannot clear the cache
# so we have to delete the model (?)
if not force:
return
del self.model
self.model = None
self.model_name = ''
from traitlets import Bunch
import ipywidgets
from ipywidgets import (
IntText,
BoundedIntText,
Layout,
Button,
Label,
Box, HBox,
# Box应当始终假定display不明
# HBox/VBox应当仅用于【单行/单列】内容
)
from .utils import collect_local_module_names
model_name_list = [
"Linaqruf/anything-v3.0",
"MoososCap/NOVEL-MODEL",
"Baitian/momocha",
"Baitian/momoco",
"hequanshaguo/monoko-e",
"ruisi/anything",
"hakurei/waifu-diffusion-v1-3",
"CompVis/stable-diffusion-v1-4",
"runwayml/stable-diffusion-v1-5",
"stabilityai/stable-diffusion-2",
"stabilityai/stable-diffusion-2-base",
"hakurei/waifu-diffusion",
"naclbit/trinart_stable_diffusion_v2_60k",
"naclbit/trinart_stable_diffusion_v2_95k",
"naclbit/trinart_stable_diffusion_v2_115k",
"ringhyacinth/nail-set-diffuser",
"Deltaadams/Hentai-Diffusion",
"BAAI/AltDiffusion",
"BAAI/AltDiffusion-m9",
"IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1",
"IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-EN-v0.1",
"huawei-noah/Wukong-Huahua"]
samler_list = [
"default",
"DPMSolver",
"EulerDiscrete",
"EulerAncestralDiscrete",
"PNDM",
"DDIM",
"LMSDiscrete",
"HeunDiscrete",
"KDPM2AncestralDiscrete",
"KDPM2Discrete"
]
_DefaultLayout = {
'col04': {
'flex': "4 4 30%",
'min_width': "6rem", #480/ sm-576, md768, lg-992, xl-12000
'max_width': "calc(100% - 0.75rem)",
'margin': "0.375rem",
'align_items': "center"
},
'col06': {
'flex': "6 6 45%",
'min_width': "9rem", #手机9rem会换行
'max_width': "calc(100% - 0.75rem)",
'margin': "0.375rem",
'align_items': "center"
},
'col08': {
'flex': "8 8 60%",
'min_width': "12rem",
'max_width': "calc(100% - 0.75rem)",
'margin': "0.375rem",
'align_items': "center"
},
'col12': {
'flex': "12 12 90%",
'max_width': "calc(100% - 0.75rem)",
'margin': "0.375rem",
'align_items': "center"
},
'btnV5': {}, #见css
}
# 为工具设置布局,并标记dom class
def setLayout(layout_name, widget):
_lists = layout_name if isinstance(layout_name, list) else [layout_name]
for name in _lists:
if name not in _DefaultLayout:
raise Exception(f'未定义的layout名称:{name}')
styles = _DefaultLayout[name]
for key in styles:
setattr(widget.layout, key, styles[key])
widget.add_class(name)
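# e.g. setLayout('col04', some_widget) or setLayout(['col08', 'btnV5'], some_button)
# (the widget names here are placeholders)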
_description_style = { 'description_width': "4rem" }
_Views = {
# Textarea
"prompt": {
"__type": 'Textarea',
"class_name": 'prompt',
"layout": {
"flex": '1',
"min_height": '10rem',
"max_width": 'calc(100% - 0.75rem)',
"margin": '0.375rem',
"align_items": 'stretch'
},
"style": _description_style,
"description": '正面描述' ,
"description_tooltip": '仅支持(xxx)、(xxx:1.2)、[xxx]三种语法。设置括号格式可以对{}进行转换。',
},
"negative_prompt": {
"__type": 'Textarea',
"class_name": 'negative_prompt',
"layout": {
"flex": '1',
"max_width": 'calc(100% - 0.75rem)',
"margin": '0.375rem',
"align_items": 'stretch'
},
"style": _description_style,
"description": '负面描述',
"description_tooltip": '使生成图像的内容远离负面描述的文本',
},
# Text
"concepts_library_dir": {
"__type": 'Text',
"class_name": 'concepts_library_dir',
"layout_name": 'col08',
"style": _description_style,
"description": '风格权重',
"description_tooltip": 'TextualInversion训练的、“风格”或“人物”的权重文件路径',
"value": 'outputs/textual_inversion',
},
"output_dir": {
"__type": 'Text',
"class_name": 'output_dir',
"layout_name": 'col08',
"style": _description_style,
"description": '保存路径',
"description_tooltip": '用于保存输出图片的路径',
"value": 'outputs',
},
"seed": {
"__type": 'IntText',
"class_name": 'seed',
"layout_name": 'col04',
"style": _description_style,
"description": '随机种子',
"description_tooltip": '-1表示随机生成。',
"value": -1,
},
"num_inference_steps": {
"__type": 'BoundedIntText',
"class_name": 'num_inference_steps',
"layout_name": 'col04',
"style": _description_style,
"description": '推理步数',
"description_tooltip": '推理步数(Step):生成图片的迭代次数,步数越多运算次数越多。',
"value": 50,
"min": 2,
"max": 10000,
},
"num_return_images": {
"__type": 'BoundedIntText',
"class_name": 'num_return_images',
"layout_name": 'col04',
"style": _description_style,
"description": '生成数量',
"description_tooltip": '生成图片的数量',
"value": 1,
"min": 1,
"max": 100,
"step": 1,
},
"guidance_scale": {
"__type": 'BoundedFloatText',
"class_name": 'guidance_scale',
"layout_name": 'col04',
"style": _description_style,
"description": 'CFG',
"description_tooltip": '引导度(CFG Scale):控制图片与描述词之间的相关程度。',
"min": 0,
"max": 50,
"value": 7.5,
},
# Dropdown
"enable_parsing": {
"__type": 'Dropdown',
"class_name": 'enable_parsing',
"layout_name": 'col04',
"style": _description_style,
"description": '括号格式',
"description_tooltip": '增加权重所用括号的格式,可以将{}替换为()。选择“否”则不解析加权语法',
"value": '圆括号 () 加强权重',
"options": ['圆括号 () 加强权重','花括号 {} 加权权重', '否'],
},
"fp16": {
"__type": 'Dropdown',
"class_name": 'fp16',
"layout_name": 'col04',
"style": _description_style,
"description": '算术精度',
"description_tooltip": '模型推理使用的精度。选择float16可以加快模型的推理速度,但会牺牲部分的模型性能。',
"value": 'float32',
"options": ['float32', 'float16'],
},
"max_embeddings_multiples": {
"__type": 'Dropdown',
"class_name": 'max_embeddings_multiples',
"layout_name": 'col04',
"style": _description_style,
"description": '描述上限',
"description_tooltip": '修改描述词的上限倍数,使模型能够输入更长更多的描述词。',
"value": '3',
"options": ['1','2','3','4','5'],
},
"sampler": {
"__type": 'Dropdown',
"class_name": 'sampler',
"layout_name": 'col04',
"style": _description_style,
"description": '采样器',
"value": 'default',
"options": samler_list,
},
"standard_size": {
"__type": 'Dropdown',
"class_name": 'standard_size',
"layout_name": 'col04',
"style": _description_style,
"description": '图片尺寸',
"description_tooltip": '生成图片的尺寸',
"value": 5120512,
"options": [
('竖向(512x768)', 5120768),
('横向(768x512)', 7680512),
('正方形(640x640)', 6400640),
('大尺寸-竖向(512x1024)', 5121024),
('大尺寸-横向(1024x512)', 10240512),
('大尺寸-正方形(1024x1024)',10241024),
('小尺寸-竖向(384x640)', 3840640),
('小尺寸-横向(640x384)', 6400384),
('小尺寸-正方形(512x512)', 5120512),
],
},
"superres_model_name": {
"__type": 'Dropdown',
"class_name": 'superres_model_name',
"layout_name": 'col04',
"style": _description_style,
"description": '图像放大',
"description_tooltip": '指定放大图片尺寸所用的模型',
"value": '无',
"options": ['falsr_a', 'falsr_b', 'falsr_c', '无'],
},
# Combobox
"model_name": {
"__type": 'Combobox',
"class_name": 'model_name',
"layout_name": 'col08',
"style": _description_style,
"description": '模型名称',
"description_tooltip": '需要加载的模型名称',
"value": 'MoososCap/NOVEL-MODEL',
"options": model_name_list,
"ensure_option": False,
},
# Button
"run_button": {
"__type": 'Button',
"class_name": 'run_button',
"layout_name": 'btnV5',
"button_style": 'success', # 'success', 'info', 'warning', 'danger' or ''
"description": '生成图片!',
"tooltip": '单击开始生成图片',
"icon": 'check'
},
"collect_button": {
"__type": 'Button',
"class_name": 'collect_button',
"layout_name": 'btnV5',
"button_style": 'info', # 'success', 'info', 'warning', 'danger' or ''
"description": '收藏图片',
"tooltip": '将图片转移到Favorates文件夹中',
"icon": 'star-o',
"disabled": True,
},
# Box
"box_gui": {
"__type": 'Box',
"layout": {
"display": 'block', #Box默认值为flex
"margin": '0 45px 0 0',
},
},
"box_main": {
"__type": 'Box',
"layout": {
"display": 'flex',
"flex_flow": 'row wrap', #HBox会覆写此属性
"align_items": 'center',
"max_width": '100%',
},
},
}
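# Each _Views entry describes one widget: '__type' names the ipywidgets class to instantiate,
# 'class_name' adds a DOM class, 'layout_name' picks a preset from _DefaultLayout via setLayout,
# and every remaining key is passed straight to the widget constructor by createView().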
SHARED_STYLE_SHEETS = '''
@media (max-width:576px) {
{root} {
margin-right: 0 !important;
}
{root} .widget-text,
{root} .widget-dropdown,
{root} .widget-hslider,
{root} .widget-textarea {
flex-wrap: wrap !important;
height: auto;
margin-top: 0.1rem !important;
margin-bottom: 0.1rem !important;
}
{root} .widget-text > label,
{root} .widget-dropdown > label,
{root} .widget-hslider > label,
{root} .widget-textarea > label {
width: 100% !important;
text-align: left !important;
font-size: small !important;
}
{root} .col04,
{root} .col06 {
/*手机9rem会换行*/
min-width: 6rem !important;
}
}
{root} {
background-color: var(--jp-layout-color1);
}
{root} .widget-text > label,
{root} .widget-text > .widget-label {
user-select: none;
}
/* bootcss v5 */
{root} button.btnV5.jupyter-button.widget-button
{
height:auto;
font-weight: 400;
line-height: 1.5;
text-align: center;
vertical-align: middle;
padding: .375rem .75rem;
font-size: 1rem;
}
{root} button.btnV5.btn-sm.jupyter-button.widget-button
{
padding: .25rem .5rem;
font-size: .875rem;
}
{root} .jupyter-widgets.widget-tab > .p-TabBar .p-TabBar-tab {
padding: .5rem 0;
text-align: center;
transition: color .15s ease-in-out,background-color .15s ease-in-out,border-color .15s ease-in-out;
}
'''
CUSTOM_OPTIONS = ('__type','class_name', 'layout_name')
def _mergeViewOptions(defaultOpt,kwargs):
r = {}
r.update(defaultOpt)
for k in kwargs:
r[k] = kwargs[k]
if (k in defaultOpt) and isinstance(defaultOpt[k], dict) \
and isinstance(kwargs[k], dict):
r[k] = {}
r[k].update(defaultOpt[k])
r[k].update(kwargs[k])
#处理layout
if ('layout' in r) and isinstance(r['layout'], dict):
r['layout'] = Layout(**r['layout'])
#提取非ipywidgets参数
r2 = {}
for k in CUSTOM_OPTIONS:
if (k in r):
r2[k] = r.pop(k)
return (r, r2)
def createView(name, value = None, **kwargs):
assert name in _Views, f'未定义的View名称 {name}'
assert '__type' in _Views[name], f'View {name} 没有声明组件类型'
# 合并参数
args, options = _mergeViewOptions(_Views[name], kwargs)
# 反射
__type = _Views[name]['__type']
assert hasattr(ipywidgets, __type), f'View {name} 声明的组件{__type}未被实现'
ctor = getattr(ipywidgets, __type)
if value is None:
pass
elif hasattr(ctor, 'value'):
args['value'] = value
elif hasattr(ctor, 'children'):
args['children'] = value
#实例化
widget = ctor(**args)
# 给模型列表补充本地模型
if name == 'model_name':
widget.options = list(widget.options) + \
[m for m in collect_local_module_names() if m not in widget.options]
# 添加DOM class名
if 'class_name' in options:
widget.add_class(options['class_name'])
# 设置预设布局
if 'layout_name' in options:
setLayout(options['layout_name'], widget)
return widget
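# e.g. createView('seed'), createView('num_inference_steps', value=30),
# or createView('model_name', layout_name='col12')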
DEFAULT_BADWORDS = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
def createPromptsView(value = '', negative_value = ''):
style_sheets = '''
@media (max-width:576px) {
{root} .box_prompts .prompt > textarea {
min-height:8rem;
margin-left:2rem!important;
}
{root} .box_prompts .negative_prompt > textarea {
margin-left:2rem!important;
}
{root} .box_prompts > .box_wrap_quikbtns {
margin-left: 0 !important;
}
{root} .box_prompts > .box_wrap_quikbtns > button {
padding: 0 !important;
}
}
'''
prompt = createView('prompt', value = value)
negative_prompt = createView('negative_prompt', value = negative_value)
# 按钮
btnGoodQuality = Button(
description= '',
tooltip='填充标准质量描述',
icon='palette',
layout = Layout(
#不支持的属性? position = 'absolute',
height = '1.8rem',
width = '1.8rem',
margin = '-11rem 0 0 0'
)
)
btnBadwards = Button(
description= '',
tooltip='填充标准负面描述',
icon='paper-plane',
layout = Layout(
#不支持的属性? position = 'absolute',
height = '1.8rem',
width = '1.8rem',
margin = '-2rem 0px 0rem -1.8rem'
)
)
def fill_good_quality(b):
if not prompt.value.startswith('masterpiece,best quality,'):
prompt.value = 'masterpiece,best quality,' + prompt.value
def fill_bad_words(b):
negative_prompt.value = DEFAULT_BADWORDS
btnGoodQuality.on_click(fill_good_quality)
btnBadwards.on_click(fill_bad_words)
box_wrap_quikbtns = Box([
btnGoodQuality,btnBadwards,
], layout = Layout(
margin = '0 1rem',
height = '0',
overflow = 'visible'
))
box_wrap_quikbtns.add_class('box_wrap_quikbtns')
container = Box([
HBox([prompt]),
HBox([negative_prompt]),
box_wrap_quikbtns,
])
container.layout.display = 'block'
container.add_class('box_prompts')
return Bunch({
'container': container,
'prompt': prompt,
'negative_prompt': negative_prompt,
'style_sheets': style_sheets,
})
def _create_WHView(width_value = 512, height_value = 512):
style_sheets = '''
@media (max-width:576px) {
{root} .box_width_height {
flex: 8 8 60% !important;
}
}
'''
_layout = Layout(
flex = '1 0 2rem',
width = '2rem',
)
w_width = BoundedIntText(
layout=_layout,
value=width_value,
min=64,
max=1088,
step=64,
)
w_height = BoundedIntText(
layout=_layout,
value=height_value,
min=64,
max=1088,
step=64,
)
def validate(change):
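# snap the entered value to the nearest multiple of 64, with a floor of 64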
num = change.new % 64
if change.new < 64:
change.owner.value = 64
elif num == 0:
pass
elif num < 32:
change.owner.value = change.new - num
else:
change.owner.value = change.new - num + 64
w_width.observe(validate,names = 'value')
w_height.observe(validate,names = 'value')
container = HBox([
w_width,
Label(
value = 'X',
layout = Layout(
flex='0 0 auto',
padding='0 0.75rem'
),
# description_tooltip = '图片尺寸' if not step64 else '图片尺寸(-1为自动判断)'
),
w_height,
])
setLayout('col04', container)
container.add_class('box_width_height')
return Bunch({
'container': container,
'width': w_width,
'height': w_height,
'style_sheets': style_sheets,
})
def _create_WHView_for_img2img(width_value = -1, height_value = -1):
style_sheets = '''
{root} .box_width_height > .widget-label:first-of-type {
text-align: right !important;
}
@media (max-width:576px) {
{root} .box_width_height {
flex: 8 8 60% !important;
flex-wrap: wrap !important;
height: auto;
margin-top: 0.1rem !important;
margin-bottom: 0.1rem !important;
}
{root} .box_width_height > .widget-label:first-of-type {
width: 100% !important;
text-align: left !important;
font-size: small !important;
}
}
'''
_layout = Layout(
flex = '1 0 2rem',
width = '2rem',
)
w_width = IntText(
layout=_layout,
value=width_value,
)
w_height = IntText(
layout=_layout,
value=height_value,
)
container = HBox([
Label(
value = '图片尺寸',
description_tooltip = '-1表示自动检测',
style = _description_style,
layout = Layout(
flex='0 0 auto',
width = '4rem',
# margin = '0 4px 0 0',
margin = '0 calc(var(--jp-widgets-inline-margin)*2) 0 0',
)
),
w_width,
Label(
value = 'X',
layout = Layout(
flex='0 0 auto',
padding='0 0.75rem'
),
),
w_height,
])
setLayout('col08', container)
container.add_class('box_width_height')
return Bunch({
'container': container,
'width': w_width,
'height': w_height,
'style_sheets': style_sheets,
})
def createWidthHeightView(width_value = 512, height_value = 512, step64 = False):
if step64:
return _create_WHView(width_value, height_value)
else:
return _create_WHView_for_img2img(width_value, height_value)
# --------------------------------------------------
def Tab(children = None, **kwargs):
titles = None if 'titles' not in kwargs else kwargs.pop('titles')
if children is not None: kwargs['children'] = children
tab = ipywidgets.Tab(**kwargs)
if titles is not None:
for i in range(len(titles)):
tab.set_title(i, titles[i])
return tab
def Div(children = None, **kwargs):
if children is not None: kwargs['children'] = children
box = Box(**kwargs)
box.layout.display = 'block' # Box 默认flex
return box
def FlexBox(children = None, **kwargs):
if children is not None: kwargs['children'] = children
box = Box(**kwargs)
box.layout.display = 'flex'
box.layout.flex_flow = 'row wrap' # HBox覆写nowrap,Box默认nowrap
box.layout.justify_content = 'space-around' #注意覆写
box.layout.align_items = 'center'
box.layout.align_content = 'center'
return box
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: True
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]
f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
import os
# check main envs
def check_install(verbose = True):
try:
import paddle
except:
print("please install paddle==2.4.2 with version dtk-23.04 before running")
exit()
try:
print('checking install.......')
import safetensors
from ppdiffusers.utils import image_grid
from paddlenlp.transformers.clip.feature_extraction import CLIPFeatureExtractor
from paddlenlp.transformers import FeatureExtractionMixin
import ipywidgets
import PIL
import tqdm
print('检测完成,库完整')
except (ModuleNotFoundError, ImportError, AttributeError):
if verbose: print('检测到库不完整, 正在安装库')
os.system("pip install -U pip -i https://mirror.baidu.com/pypi/simple")
os.system("pip install -U OmegaConf --user")
os.system("pip install ppdiffusers==0.9.0 --user")
os.system("pip install paddlenlp==2.4.9 --user")
os.system("pip install -U safetensors --user")
os.system("pip install ipython==8.14.0")
os.system("pip install ipywidgets==8.0.7")
os.system("pip install pillow==9.5.0")
os.system("pip install tqdm==4.65.0")
def start_end(name,func):
print(f'------- test {name} start -------')
func.on_run_button_click('test')
print(f'------- test {name} finished -------')
if __name__=='__main__':
check_install()
from ui import gui_train_text_inversion,gui_txt2img,gui_img2img
start_end('txt2img',gui_txt2img)
start_end('img2img',gui_img2img)
start_end('txt2img train',gui_train_text_inversion)