# Commit 727428ec by jerrrrry: Initial commit CI/CD
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(
r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_"
)
class LoRAModule(torch.nn.Module):
"""
    Replaces the forward method of the original Linear/Conv2d module instead of replacing the module itself.
"""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
# if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim:
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(
in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
)
self.lora_up = torch.nn.Conv2d(
self.lora_dim, out_dim, (1, 1), (1, 1), bias=False
)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
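        # y = org_forward(x) + lora_up(lora_down(x)) * multiplier * scale, with optional
        # module dropout (skip the LoRA branch entirely), neuron dropout, and rank dropout
        # (zero random rank channels, then rescale) applied only during training.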
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return org_forwarded
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.training:
mask = (
torch.rand((lx.size(0), self.lora_dim), device=lx.device)
> self.rank_dropout
)
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
            # this could also be computed from the mask, but rank_dropout is used here expecting an augmentation-like effect
scale = self.scale * (
1.0 / (1.0 - self.rank_dropout)
) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale
class LoRAInfModule(LoRAModule):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
        self.org_module_ref = [org_module]  # keep a reference so the original module can be accessed later
self.enabled = True
# check regional or not by lora_name
self.text_encoder = False
if lora_name.startswith("lora_te_"):
self.regional = False
self.use_sub_prompt = True
self.text_encoder = True
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
self.regional = False
self.use_sub_prompt = True
elif "time_emb" in lora_name:
self.regional = False
self.use_sub_prompt = False
else:
self.regional = True
self.use_sub_prompt = False
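        # Classification used by the regional / sub-prompt forward paths below: Text Encoder
        # and attn2 to_k/to_v modules apply LoRA per sub prompt, time_emb is handled normally,
        # and everything else is treated as regional (masked) LoRA.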
self.network: LoRANetwork = None
def set_network(self, network):
self.network = network
    # freeze the weights and merge LoRA into the original module
def merge_to(self, sd, dtype, device):
# get up/down weight
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"].to(torch.float)
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
.unsqueeze(2)
.unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(
down_weight.permute(1, 0, 2, 3), up_weight
).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
    # return this module's delta weight so that the merge can be undone later
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
# get up/down weight from module
up_weight = self.lora_up.weight.to(torch.float)
down_weight = self.lora_down.weight.to(torch.float)
# pre-calculated weight
if len(down_weight.size()) == 2:
# linear
weight = self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2))
.unsqueeze(2)
.unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(
down_weight.permute(1, 0, 2, 3), up_weight
).permute(1, 0, 2, 3)
weight = self.multiplier * conved * self.scale
return weight
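    # Note: in both merge_to and get_weight above, the delta added to the original weight is:
    #   Linear:     multiplier * (up @ down) * scale
    #   Conv2d 1x1: multiplier * (up @ down) * scale, reshaped back to (out, in, 1, 1)
    #   Conv2d 3x3: multiplier * conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) * scale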
def set_region(self, region):
self.region = region
self.region_mask = None
def default_forward(self, x):
# logger.info(f"default_forward {self.lora_name} {x.size()}")
return (
self.org_forward(x)
+ self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
)
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
if self.network is None or self.network.sub_prompt_index is None:
return self.default_forward(x)
if not self.regional and not self.use_sub_prompt:
return self.default_forward(x)
if self.regional:
return self.regional_forward(x)
else:
return self.sub_prompt_forward(x)
def get_mask_for_x(self, x):
# calculate size from shape of x
if len(x.size()) == 4:
h, w = x.size()[2:4]
area = h * w
else:
area = x.size()[1]
mask = self.network.mask_dic.get(area, None)
if mask is None or len(x.size()) == 2:
# emb_layers in SDXL doesn't have mask
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (
(1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
)
return (
torch.ones(mask_size, dtype=x.dtype, device=x.device)
/ self.network.num_sub_prompts
)
if len(x.size()) == 3:
mask = torch.reshape(mask, (1, -1, 1))
return mask
def regional_forward(self, x):
if "attn2_to_out" in self.lora_name:
return self.to_out_forward(x)
if self.network.mask_dic is None: # sub_prompt_index >= 3
return self.default_forward(x)
# apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
# mask = mask.squeeze(-1)
lx = lx * mask
x = self.org_forward(x)
x = x + lx
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
x = self.postp_to_q(x)
return x
def postp_to_q(self, x):
# repeat x to num_sub_prompts
has_real_uncond = x.size()[0] // self.network.batch_size == 3
qc = self.network.batch_size # uncond
qc += self.network.batch_size * self.network.num_sub_prompts # cond
if has_real_uncond:
qc += self.network.batch_size # real_uncond
query = torch.zeros(
(qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype
)
query[: self.network.batch_size] = x[: self.network.batch_size]
for i in range(self.network.batch_size):
qi = self.network.batch_size + i * self.network.num_sub_prompts
query[qi : qi + self.network.num_sub_prompts] = x[
self.network.batch_size + i
]
if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :]
# logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
return query
def sub_prompt_forward(self, x):
if (
x.size()[0] == self.network.batch_size
): # if uncond in text_encoder, do not apply LoRA
return self.org_forward(x)
emb_idx = self.network.sub_prompt_index
if not self.text_encoder:
emb_idx += self.network.batch_size
# apply sub prompt of X
lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx
return x
def to_out_forward(self, x):
# logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts
self.network.shared[self.lora_name] = (None, masks)
else:
lx, masks = self.network.shared[self.lora_name]
# call own LoRA
x1 = x[
self.network.batch_size
+ self.network.sub_prompt_index :: self.network.num_sub_prompts
]
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
if self.network.is_last_network:
lx = torch.zeros(
(
self.network.num_sub_prompts * self.network.batch_size,
*lx1.size()[1:],
),
device=lx1.device,
dtype=lx1.dtype,
)
self.network.shared[self.lora_name] = (lx, masks)
# logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
# if not last network, return x and masks
x = self.org_forward(x)
if not self.network.is_last_network:
return x
lx, masks = self.network.shared.pop(self.lora_name)
# if last network, combine separated x with mask weighted sum
has_real_uncond = (
x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
)
out = torch.zeros(
(self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]),
device=x.device,
dtype=x.dtype,
)
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
if has_real_uncond:
out[-self.network.batch_size :] = x[
-self.network.batch_size :
] # real_uncond
# logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
# if num_sub_prompts > num of LoRAs, fill with zero
for i in range(len(masks)):
if masks[i] is None:
masks[i] = torch.zeros_like(masks[0])
mask = torch.cat(masks)
mask_sum = torch.sum(mask, dim=0) + 1e-4
for i in range(self.network.batch_size):
            # process each image in the batch separately
lx1 = lx[
i
* self.network.num_sub_prompts : (i + 1)
* self.network.num_sub_prompts
]
lx1 = lx1 * mask
lx1 = torch.sum(lx1, dim=0)
xi = self.network.batch_size + i * self.network.num_sub_prompts
x1 = x[xi : xi + self.network.num_sub_prompts]
x1 = x1 * mask
x1 = torch.sum(x1, dim=0)
x1 = x1 / mask_sum
x1 = x1 + lx1
out[self.network.batch_size + i] = x1
# logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
return out
def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]:
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
    # if none of these are specified, block-wise LR is disabled; return None
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
return None
# extract learning rate weight for each block
if down_lr_weight is not None:
# if some parameters are not set, use zero
if "," in down_lr_weight:
down_lr_weight = [
(float(s) if s else 0.0) for s in down_lr_weight.split(",")
]
if mid_lr_weight is not None:
mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")]
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
return get_block_lr_weight(
is_sdxl,
down_lr_weight,
mid_lr_weight,
up_lr_weight,
float(nw_kwargs.get("block_lr_zero_threshold", 0.0)),
)
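# Hypothetical example of the kwargs this parses (SD1.x-sized lists):
#   down_lr_weight="0,0,0,0,0,0,1,1,1,1,1,1", mid_lr_weight="1", up_lr_weight="sine+.5",
#   block_lr_zero_threshold="0.1". Preset names such as "sine" or "cosine" (optionally with
#   "+<base>") are expanded by get_block_lr_weight below.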
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
# extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None:
conv_dim = int(conv_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# block dim/alpha/lr
block_dims = kwargs.get("block_dims", None)
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
    # enable per-block dim (rank) if any of the above is specified
if block_dims is not None or block_lr_weight is not None:
block_alphas = kwargs.get("block_alphas", None)
conv_block_dims = kwargs.get("conv_block_dims", None)
conv_block_alphas = kwargs.get("conv_block_alphas", None)
block_dims, block_alphas, conv_block_dims, conv_block_alphas = (
get_block_dims_and_alphas(
is_sdxl,
block_dims,
block_alphas,
network_dim,
network_alpha,
conv_block_dims,
conv_block_alphas,
conv_dim,
conv_alpha,
)
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = (
remove_block_dims_and_alphas(
is_sdxl,
block_dims,
block_alphas,
conv_block_dims,
conv_block_alphas,
block_lr_weight,
)
)
else:
block_alphas = None
conv_block_dims = None
conv_block_alphas = None
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
    # there are quite a lot of arguments here ( ^ω^)...
network = LoRANetwork(
text_encoder,
unet,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
block_dims=block_dims,
block_alphas=block_alphas,
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
is_sdxl=is_sdxl,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = (
float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
)
loraplus_unet_lr_ratio = (
float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
)
loraplus_text_encoder_lr_ratio = (
float(loraplus_text_encoder_lr_ratio)
if loraplus_text_encoder_lr_ratio is not None
else None
)
if (
loraplus_lr_ratio is not None
or loraplus_unet_lr_ratio is not None
or loraplus_text_encoder_lr_ratio is not None
):
network.set_loraplus_lr_ratio(
loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio
)
if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
return network
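# Minimal usage sketch (hypothetical argument values; vae, text_encoder and unet are the models
# already loaded by the training script):
#   network = create_network(1.0, 16, 8.0, vae, text_encoder, unet,
#                            conv_dim="4", conv_alpha="1", rank_dropout="0.1")
#   network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)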
# this function may be called from outside this module
# network_dim and network_alpha already hold their default values here
# block_dims and block_alphas are either both None or both set
# conv_dim and conv_alpha are either both None or both set
def get_block_dims_and_alphas(
is_sdxl,
block_dims,
block_alphas,
network_dim,
network_alpha,
conv_block_dims,
conv_block_alphas,
conv_dim,
conv_alpha,
):
if not is_sdxl:
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
else:
# 1+9+3+9+1=23, no LoRA for emb_layers (0)
num_total_blocks = (
1
+ LoRANetwork.SDXL_NUM_OF_BLOCKS * 2
+ LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
+ 1
)
def parse_ints(s):
return [int(i) for i in s.split(",")]
def parse_floats(s):
return [float(i) for i in s.split(",")]
    # parse block_dims and block_alphas; both are always populated after this
if block_dims is not None:
block_dims = parse_ints(block_dims)
assert len(block_dims) == num_total_blocks, (
f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given"
+ f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})"
)
else:
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
block_alphas = parse_floats(block_alphas)
assert (
len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else:
logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
)
block_alphas = [network_alpha] * num_total_blocks
    # parse conv_block_dims and conv_block_alphas only if specified; otherwise fall back to conv_dim and conv_alpha
if conv_block_dims is not None:
conv_block_dims = parse_ints(conv_block_dims)
assert (
len(conv_block_dims) == num_total_blocks
), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
if conv_block_alphas is not None:
conv_block_alphas = parse_floats(conv_block_alphas)
assert (
len(conv_block_alphas) == num_total_blocks
), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
else:
if conv_alpha is None:
conv_alpha = 1.0
logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
)
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
if conv_dim is not None:
logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
)
conv_block_dims = [conv_dim] * num_total_blocks
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
conv_block_dims = None
conv_block_alphas = None
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
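# For SD1.x/2.x the returned lists have NUM_OF_BLOCKS * 2 + NUM_OF_MID_BLOCKS = 25 entries
# (down, mid, up); for SDXL they have 1 + 9 + 3 + 9 + 1 = 23 entries (emb, input, mid, output, out).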
# define per-block multipliers for layer-wise learning rates; kept outside the class so it can be called externally
# the return value is a list of multipliers, one per block
def get_block_lr_weight(
is_sdxl,
down_lr_weight: Union[str, List[float]],
mid_lr_weight: List[float],
up_lr_weight: Union[str, List[float]],
zero_threshold: float,
) -> Optional[List[float]]:
    # if no parameters are specified, do nothing and keep the previous behavior
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
return None
if not is_sdxl:
max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS
max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS
else:
max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS
max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
def get_list(name_with_suffix) -> List[float]:
import math
tokens = name_with_suffix.split("+")
name = tokens[0]
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
if name == "cosine":
return [
math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
for i in reversed(range(max_len_for_down_or_up))
]
elif name == "sine":
return [
math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
for i in range(max_len_for_down_or_up)
]
elif name == "linear":
return [
i / (max_len_for_down_or_up - 1) + base_lr
for i in range(max_len_for_down_or_up)
]
elif name == "reverse_linear":
return [
i / (max_len_for_down_or_up - 1) + base_lr
for i in reversed(range(max_len_for_down_or_up))
]
elif name == "zeros":
return [0.0 + base_lr] * max_len_for_down_or_up
else:
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
return None
if type(down_lr_weight) == str:
down_lr_weight = get_list(down_lr_weight)
if type(up_lr_weight) == str:
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or (
down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up
):
logger.warning(
"down_weight or up_weight is too long. Parameters after %d-th are ignored."
% max_len_for_down_or_up
)
logger.warning(
"down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"
% max_len_for_down_or_up
)
        if up_lr_weight is not None:
            up_lr_weight = up_lr_weight[:max_len_for_down_or_up]
        if down_lr_weight is not None:
            down_lr_weight = down_lr_weight[:max_len_for_down_or_up]
if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid:
logger.warning(
"mid_weight is too long. Parameters after %d-th are ignored."
% max_len_for_mid
)
logger.warning(
"mid_weightが長すぎます。%d個目以降のパラメータは無視されます。"
% max_len_for_mid
)
mid_lr_weight = mid_lr_weight[:max_len_for_mid]
if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or (
down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up
):
logger.warning(
"down_weight or up_weight is too short. Parameters after %d-th are filled with 1."
% max_len_for_down_or_up
)
logger.warning(
"down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。"
% max_len_for_down_or_up
)
if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up:
down_lr_weight = down_lr_weight + [1.0] * (
max_len_for_down_or_up - len(down_lr_weight)
)
if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up:
up_lr_weight = up_lr_weight + [1.0] * (
max_len_for_down_or_up - len(up_lr_weight)
)
if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid:
logger.warning(
"mid_weight is too short. Parameters after %d-th are filled with 1."
% max_len_for_mid
)
logger.warning(
"mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。"
% max_len_for_mid
)
mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
logger.info(
f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}"
)
else:
down_lr_weight = [1.0] * max_len_for_down_or_up
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight]
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
mid_lr_weight = [1.0] * max_len_for_mid
logger.info("mid_lr_weight: all 1.0, すべて1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
logger.info(
f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}"
)
else:
up_lr_weight = [1.0] * max_len_for_down_or_up
logger.info("up_lr_weight: all 1.0, すべて1.0")
lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight
if is_sdxl:
lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out
assert (
not is_sdxl
and len(lr_weight)
== LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
) or (
is_sdxl
and len(lr_weight)
== 1
+ LoRANetwork.SDXL_NUM_OF_BLOCKS * 2
+ LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
+ 1
), f"lr_weight length is invalid: {len(lr_weight)}"
return lr_weight
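# The returned list is laid out as down blocks, then mid, then up blocks (with an extra leading
# and trailing 1.0 for SDXL), and is indexed by get_block_index below.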
# exclude blocks whose lr_weight is 0 from block_dims; this may be called from outside this module
def remove_block_dims_and_alphas(
is_sdxl,
block_dims,
block_alphas,
conv_block_dims,
conv_block_alphas,
block_lr_weight: Optional[List[float]],
):
if block_lr_weight is not None:
for i, lr in enumerate(block_lr_weight):
if lr == 0:
block_dims[i] = 0
if conv_block_dims is not None:
conv_block_dims[i] = 0
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# may be called from outside this module
def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
block_idx = -1 # invalid lora name
if not is_sdxl:
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
if g[0] == "down":
                block_idx = 1 + idx  # no LoRA corresponds to index 0
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
else:
# copy from sdxl_train
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith(
"label_emb_"
): # No LoRA
block_idx = 0 # 0
elif name.startswith("input_blocks_"): # 1-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 10-12
block_idx = 10 + int(name.split("_")[2])
elif name.startswith("output_blocks_"): # 13-21
block_idx = 13 + int(name.split("_")[2])
elif name.startswith("out_"): # 22, out, no LoRA
block_idx = 22
return block_idx
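# Examples: for SD1.x, "lora_unet_down_blocks_1_attentions_0_..." gives block_idx = 1 + (3*1 + 0) = 4;
# for SDXL, "lora_unet_input_blocks_4_..." gives block_idx = 1 + 4 = 5.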
def convert_diffusers_to_sai_if_needed(weights_sd):
# only supports U-Net LoRA modules
found_up_down_blocks = False
for k in list(weights_sd.keys()):
if "down_blocks" in k:
found_up_down_blocks = True
break
if "up_blocks" in k:
found_up_down_blocks = True
break
if not found_up_down_blocks:
return
from library.sdxl_model_util import make_unet_conversion_map
unet_conversion_map = make_unet_conversion_map()
unet_conversion_map = {
hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1]
for sd, hf in unet_conversion_map
}
# # add extra conversion
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
logger.info(f"Converting LoRA keys from Diffusers to SAI")
lora_unet_prefix = "lora_unet_"
for k in list(weights_sd.keys()):
if not k.startswith(lora_unet_prefix):
continue
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
        found = False
        for hf_module_name, sd_module_name in unet_conversion_map.items():
if hf_module_name in unet_module_name:
new_key = (
lora_unet_prefix
+ unet_module_name.replace(hf_module_name, sd_module_name)
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
)
weights_sd[new_key] = weights_sd.pop(k)
found = True
break
if not found:
logger.warning(f"Key {k} is not found in unet_conversion_map")
# Create a network from weights for inference; the weights are not loaded here (they may be merged instead)
def create_network_from_weights(
multiplier,
file,
vae,
text_encoder,
unet,
weights_sd=None,
for_inference=False,
**kwargs,
):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# if keys are Diffusers based, convert to SAI based
convert_diffusers_to_sai_if_needed(weights_sd)
# get dim/alpha mapping
modules_dim = {}
modules_alpha = {}
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
if key not in modules_alpha:
modules_alpha[key] = modules_dim[key]
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoder,
unet,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
)
# block lr
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
return network, weights_sd
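# Minimal inference-time usage sketch (hypothetical file name):
#   network, weights_sd = create_network_from_weights(
#       1.0, "my_lora.safetensors", vae, text_encoder, unet, for_inference=True
#   )
#   network.apply_to(text_encoder, unet)
#   network.load_weights("my_lora.safetensors")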
class LoRANetwork(torch.nn.Module):
    NUM_OF_BLOCKS = 12  # number of up/down blocks counted for the full SD1.x/2.x model
NUM_OF_MID_BLOCKS = 1
    SDXL_NUM_OF_BLOCKS = 9  # number of input/output blocks in the SDXL model; total = 1 (base) + 9 (input) + 3 (mid) + 9 (output) + 1 (out) = 23
SDXL_NUM_OF_MID_BLOCKS = 3
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "HunYuanDiTBlock"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = [
"ResnetBlock2D",
"Downsample2D",
"Upsample2D",
]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP", "BertLayer"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
    # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
block_dims: Optional[List[int]] = None,
block_alphas: Optional[List[float]] = None,
conv_block_dims: Optional[List[int]] = None,
conv_block_alphas: Optional[List[float]] = None,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
is_sdxl: Optional[bool] = False,
) -> None:
"""
        LoRA network. There are many arguments, but the supported patterns are:
        1. specify lora_dim and alpha
        2. specify lora_dim, alpha, conv_lora_dim, and conv_alpha
        3. specify block_dims and block_alphas: not applied to Conv2d 3x3
        4. specify block_dims, block_alphas, conv_block_dims, and conv_block_alphas: also applied to Conv2d 3x3
        5. specify modules_dim and modules_alpha (for inference)
"""
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
logger.info(f"create LoRA network from block_dims")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
logger.info(
f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}"
)
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# create module instances
def create_modules(
is_unet: bool,
text_encoder_idx: Optional[int], # None, 1, 2
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_UNET
if is_unet
else (
self.LORA_PREFIX_TEXT_ENCODER
if text_encoder_idx is None
else (
self.LORA_PREFIX_TEXT_ENCODER1
if text_encoder_idx == 1
else self.LORA_PREFIX_TEXT_ENCODER2
)
)
)
loras = []
skipped = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
dim = None
alpha = None
if modules_dim is not None:
                                # per-module dims/alphas specified (loading from weights)
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
elif is_unet and block_dims is not None:
                                # block_dims specified for the U-Net
block_idx = get_block_index(lora_name, is_sdxl)
if is_linear or is_conv2d_1x1:
dim = block_dims[block_idx]
alpha = block_alphas[block_idx]
elif conv_block_dims is not None:
dim = conv_block_dims[block_idx]
alpha = conv_block_alphas[block_idx]
else:
                                # normal case: target all supported modules
if is_linear or is_conv2d_1x1:
dim = self.lora_dim
alpha = self.alpha
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
if dim is None or dim == 0:
                                # record skipped modules for reporting
if (
is_linear
or is_conv2d_1x1
or (
self.conv_lora_dim is not None
or conv_block_dims is not None
)
):
skipped.append(lora_name)
continue
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
loras.append(lora)
return loras, skipped
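        # create_modules builds names as prefix + module path + child path, with dots replaced
        # by underscores, e.g. (for a Diffusers SD1.x U-Net)
        # "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q".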
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
# create LoRA for text encoder
        # creating all modules every time is wasteful; worth revisiting
self.text_encoder_loras = []
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
logger.info(f"create LoRA for Text Encoder {index}:")
else:
index = None
logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(
False,
index,
text_encoder,
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE,
)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
logger.info(
f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules."
)
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
if (
modules_dim is not None
or self.conv_lora_dim is not None
or conv_block_dims is not None
):
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
logger.info(f"\t{name}")
self.block_lr_weight = None
self.block_lr = False
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert (
lora.lora_name not in names
), f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(
f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules"
)
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
    # returns whether this network can be merged into the base model
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
apply_text_encoder = apply_unet = False
for key in weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
apply_unet = True
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
    # set per-block multipliers for layer-wise learning rates
def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]):
self.block_lr = True
self.block_lr_weight = block_lr_weight
def get_lr_weight(self, block_idx: int) -> float:
if not self.block_lr or self.block_lr_weight is None:
return 1.0
return self.block_lr_weight[block_idx]
def set_loraplus_lr_ratio(
self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio
):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(
f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}"
)
logger.info(
f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}"
)
    # it might be useful to allow separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
# if (
# self.loraplus_lr_ratio is not None
# or self.loraplus_text_encoder_lr_ratio is not None
# or self.loraplus_unet_lr_ratio is not None
# ):
# assert (
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr
if (
param_data.get("lr", None) == 0
or param_data.get("lr", None) is None
):
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
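        # assemble_params splits each module's parameters into a "lora" group and, when a LoRA+
        # ratio is given, a "plus" group containing the lora_up weights trained at lr * ratio.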
if self.text_encoder_loras:
params, descriptions = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(
["textencoder" + (" " + d if d else "") for d in descriptions]
)
if self.unet_loras:
if self.block_lr:
is_sdxl = False
for lora in self.unet_loras:
if (
"input_blocks" in lora.lora_name
or "output_blocks" in lora.lora_name
):
is_sdxl = True
break
                # group LoRA modules by block so learning rates can be set (and logged) per block
block_idx_to_lora = {}
for lora in self.unet_loras:
idx = get_block_index(lora.lora_name, is_sdxl)
if idx not in block_idx_to_lora:
block_idx_to_lora[idx] = []
block_idx_to_lora[idx].append(lora)
                # set up parameter groups per block
for idx, block_loras in block_idx_to_lora.items():
params, descriptions = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr)
* self.get_lr_weight(idx),
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(
[
f"unet_block{idx}" + (" " + d if d else "")
for d in descriptions
]
)
else:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(
["unet" + (" " + d if d else "") for d in descriptions]
)
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(
state_dict, metadata
)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
# mask is a tensor with values from 0 to 1
def set_region(self, sub_prompt_index, is_last_network, mask):
if mask.max() == 0:
mask = torch.ones_like(mask)
self.mask = mask
self.sub_prompt_index = sub_prompt_index
self.is_last_network = is_last_network
for lora in self.text_encoder_loras + self.unet_loras:
lora.set_network(self)
def set_current_generation(
self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None
):
self.batch_size = batch_size
self.num_sub_prompts = num_sub_prompts
self.current_size = (height, width)
self.shared = shared
# create masks
mask = self.mask
mask_dic = {}
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
ref_weight = (
self.text_encoder_loras[0].lora_down.weight
if self.text_encoder_loras
else self.unet_loras[0].lora_down.weight
)
dtype = ref_weight.dtype
device = ref_weight.device
def resize_add(mh, mw):
# logger.info(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(
mask, (mh, mw), mode="bilinear"
) # doesn't work in bf16
m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m
h = height // 8
w = width // 8
for _ in range(4):
resize_add(h, w)
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
resize_add(h + h % 2, w + w % 2)
# deep shrink
if ds_ratio is not None:
hd = int(h * ds_ratio)
wd = int(w * ds_ratio)
resize_add(hd, wd)
h = (h + 1) // 2
w = (w + 1) // 2
self.mask_dic = mask_dic
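        # mask_dic maps the flattened spatial size (h * w at each U-Net resolution) to a
        # resized mask, so get_mask_for_x can pick the right mask from the token count alone.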
def backup_weights(self):
        # back up the original weights
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
        # restore the original weights
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
        # pre-compute: fold the LoRA weights into the original modules
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(
org_weight.device, dtype=org_weight.dtype
)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
state_dict = self.state_dict()
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (
(up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2))
.unsqueeze(2)
.unsqueeze(3)
)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(
down.permute(1, 0, 2, 3), up
).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)
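    # apply_max_norm_regularization rescales any LoRA pair whose scaled norm
    # ||up @ down|| * (alpha / dim) exceeds max_norm_value, multiplying both up and down
    # by sqrt(ratio) so their product shrinks by ratio.


# What follows appears to be a separate training-script module concatenated into the same blob:
# it re-imports its own dependencies and defines the NetworkTrainer class used for LoRA training.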
import importlib
import argparse
import math
import os
import sys
import random
import time
import json
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, model_util
import library.train_util as train_util
from library.train_util import DreamBoothDataset
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
class NetworkTrainer:
def __init__(self):
self.vae_scale_factor = 0.18215
self.is_sdxl = False
    # TODO share this with the other training scripts
def generate_step_logs(
self,
args: argparse.Namespace,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/average_key_norm"] = mean_norm
logs["max_norm/max_key_norm"] = maximum_norm
lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs):
if lr_descriptions is not None:
lr_desc = lr_descriptions[i]
else:
idx = i - (0 if args.network_train_unet_only else -1)
if idx == -1:
lr_desc = "textencoder"
else:
if len(lrs) > 2:
lr_desc = f"group{idx}"
else:
lr_desc = "unet"
logs[f"lr/{lr_desc}"] = lr
if (
args.optimizer_type.lower().startswith("DAdapt".lower())
or args.optimizer_type.lower() == "Prodigy".lower()
):
# tracking d*lr value
logs[f"lr/d*lr/{lr_desc}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"]
* lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
return logs
def assert_extra_args(self, args, train_dataset_group):
pass
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator
)
return (
model_util.get_model_version_str_for_sd1_sd2(
args.v2, args.v_parameterization
),
text_encoder,
vae,
unet,
)
def load_tokenizer(self, args):
tokenizer = train_util.load_tokenizer(args)
return tokenizer
def load_noise_scheduler(self, args):
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
return noise_scheduler
def is_text_encoder_outputs_cached(self, args):
return False
def is_train_text_encoder(self, args):
return (
not args.network_train_unet_only
and not self.is_text_encoder_outputs_cached(args)
)
def cache_text_encoder_outputs_if_needed(
self,
args,
accelerator,
unet,
vae,
tokenizers,
text_encoders,
data_loader,
weight_dtype,
):
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
def get_text_cond(
self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype
):
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizers[0], text_encoders[0], weight_dtype
)
return encoder_hidden_states
def call_unet(
self,
args,
accelerator,
unet,
noisy_latents,
timesteps,
text_conds,
batch,
weight_dtype,
):
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def all_reduce_network(self, accelerator, network):
for param in network.parameters():
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")
def sample_images(
self,
accelerator,
args,
epoch,
global_step,
device,
vae,
tokenizer,
text_encoder,
unet,
):
train_util.sample_images(
accelerator,
args,
epoch,
global_step,
device,
vae,
tokenizer,
text_encoder,
unet,
)
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
        # tokenizer may be a single tokenizer or a list; tokenizers is always a list (for compatibility with existing code)
tokenizer = self.load_tokenizer(args)
tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
        # prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(
ConfigSanitizer(True, True, args.masked_loss, True)
)
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(
user_config, args, tokenizer=tokenizer
)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(
blueprint.dataset_group
)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = (
train_dataset_group if args.max_data_loader_n_workers == 0 else None
)
collator = train_util.collator_class(
current_epoch, current_step, ds_for_collator
)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group)
        # prepare the accelerator
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
        # prepare dtypes matching the mixed-precision setting and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
        # load the model
model_version, text_encoder, vae, unet = self.load_target_model(
args, weight_dtype, accelerator
)
# text_encoder is List[CLIPTextModel] or CLIPTextModel
text_encoders = (
text_encoder if isinstance(text_encoder, list) else [text_encoder]
)
        # enable xformers / memory-efficient attention in the model
train_util.replace_unet_modules(
unet, args.mem_eff_attn, args.xformers, args.sdpa
)
if (
torch.__version__ >= "2.0.0"
        ):  # usable with xformers builds that support PyTorch 2.0.0 or later
vae.set_use_memory_efficient_attention_xformers(args.xformers)
        # import the network module (for additional / difference training)
sys.path.append(os.path.dirname(__file__))
accelerator.print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
            # if base_weights is specified, load and merge those weights first
for i, weight_path in enumerate(args.base_weights):
if (
args.base_weights_multiplier is None
or len(args.base_weights_multiplier) <= i
):
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
accelerator.print(
f"merging module: {weight_path} with multiplier {multiplier}"
)
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(
text_encoder,
unet,
weights_sd,
weight_dtype,
accelerator.device if args.lowram else "cpu",
)
accelerator.print(f"all weights merged: {', '.join(args.base_weights)}")
        # prepare for training
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
        # cache text encoder outputs if needed: the Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args,
accelerator,
unet,
vae,
tokenizers,
text_encoders,
train_dataset_group,
weight_dtype,
)
# prepare network
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
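        # Hypothetical example: --network_args "conv_dim=4" "conv_alpha=1" parses to
        # net_kwargs = {"conv_dim": "4", "conv_alpha": "1"}; values stay strings here and the
        # network module casts them as needed.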
        # if a new network type is added in the future, add an if/else branch for it here (;'∀')
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(
1, args.network_weights, vae, text_encoder, unet, **net_kwargs
)
else:
if "dropout" not in net_kwargs:
# workaround for LyCORIS (;^ω^)
net_kwargs["dropout"] = args.network_dropout
network = network_module.create_network(
1.0,
args.network_dim,
args.network_alpha,
vae,
text_encoder,
unet,
neuron_dropout=args.network_dropout,
**net_kwargs,
)
if network is None:
return
network_has_multiplier = hasattr(network, "set_multiplier")
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(
network, "apply_max_norm_regularization"
):
logger.warning(
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
)
args.scale_weight_norms = False
train_unet = not args.network_train_text_encoder_only
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
# FIXME consider alpha of weights
info = network.load_weights(args.network_weights)
accelerator.print(
f"load network weights from {args.network_weights}: {info}"
)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
for t_enc in text_encoders:
t_enc.gradient_checkpointing_enable()
del t_enc
network.enable_gradient_checkpointing() # may have no effect
        # prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
        # keep backward compatibility with the older prepare_optimizer_params signature
try:
results = network.prepare_optimizer_params(
args.text_encoder_lr, args.unet_lr, args.learning_rate
)
if type(results) is tuple:
trainable_params = results[0]
lr_descriptions = results[1]
else:
trainable_params = results
lr_descriptions = None
except TypeError as e:
# logger.warning(f"{e}")
# accelerator.print(
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
# )
trainable_params = network.prepare_optimizer_params(
args.text_encoder_lr, args.unet_lr
)
lr_descriptions = None
# if len(trainable_params) == 0:
# accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
# for params in trainable_params:
# for k, v in params.items():
# if type(v) == float:
# pass
# else:
# v = len(v)
# accelerator.print(f"trainable_params: {k} = {v}")
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(
args, trainable_params
)
# prepare dataloader
# note: persistent_workers cannot be used when the DataLoader worker count is 0
n_workers = min(
args.max_data_loader_n_workers, os.cpu_count()
) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader)
/ accelerator.num_processes
/ args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also pass the number of training steps to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare lr scheduler
lr_scheduler = train_util.get_scheduler_fix(
args, optimizer, accelerator.num_processes
)
# experimental feature: perform fp16/bf16 training including gradients, i.e. cast the whole model to fp16/bf16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
network.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert (
torch.__version__ >= "2.1.0"
), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding does not support FP8
if hasattr(t_enc, "text_model"):
t_enc.text_model.embeddings.to(
dtype=(
weight_dtype
if te_weight_dtype != weight_dtype
else te_weight_dtype
)
)
elif hasattr(t_enc, "embeddings"):
# HunYuan Bert(CLIP)
t_enc.embeddings.to(
dtype=(
weight_dtype
if te_weight_dtype != weight_dtype
else te_weight_dtype
)
)
elif hasattr(t_enc, "get_token_embedding"):
# Others (mT5 or other encoder, will have custom method to get the correct embedding)
t_enc.get_token_embedding().to(
dtype=(
weight_dtype
if te_weight_dtype != weight_dtype
else te_weight_dtype
)
)
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoders[0] if train_text_encoder else None,
text_encoder2=(
text_encoders[1]
if train_text_encoder and len(text_encoders) > 1
else None
),
network=network,
)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_model = ds_model
else:
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(
accelerator.device, dtype=unet_weight_dtype
) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [
accelerator.prepare(t_enc) for t_enc in text_encoders
]
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
training_model = network
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
unet.train()
for t_enc in text_encoders:
t_enc.train()
# set requires_grad = True on the top-level parameters so that gradient checkpointing works
if train_text_encoder:
if hasattr(t_enc, "text_model"):
t_enc.text_model.embeddings.requires_grad_(True)
elif hasattr(t_enc, "embeddings"):
# HunYuan Bert(CLIP)
t_enc.embeddings.requires_grad_(True)
elif hasattr(t_enc, "get_token_embedding"):
# Others (mT5 or other encoder, will have custom method to get the correct embedding)
t_enc.get_token_embedding().requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
t_enc.eval()
del t_enc
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if not cache_latents:  # when latents are not cached, the VAE is used during training, so prepare it
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
# experimental feature: perform full fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
if accelerator.is_main_process or args.deepspeed:
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
if len(weights) > i:
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved")
# save current epoch and step
train_state_file = os.path.join(output_dir, "train_state.json")
# +1 is needed because the state is saved before current_step is set from global_step
logger.info(
f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}"
)
with open(train_state_file, "w", encoding="utf-8") as f:
json.dump(
{
"current_epoch": current_epoch.value,
"current_step": current_step.value + 1,
},
f,
)
steps_from_state = None
def load_model_hook(models, input_dir):
# remove models except network
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
models.pop(i)
# print(f"load model hook: {len(models)} models will be loaded")
# load current epoch and step
nonlocal steps_from_state
train_state_file = os.path.join(input_dir, "train_state.json")
if os.path.exists(train_state_file):
with open(train_state_file, "r", encoding="utf-8") as f:
data = json.load(f)
steps_from_state = data["current_step"]
logger.info(f"load train state from {train_state_file}: {data}")
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# resume training state if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = (
math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
)
# start training
# TODO: find a way to handle total batch size when there are multiple datasets
total_batch_size = (
args.train_batch_size
* accelerator.num_processes
* args.gradient_accumulation_steps
)
accelerator.print("running training / 学習開始")
accelerator.print(
f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
)
accelerator.print(
f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}"
)
accelerator.print(
f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}"
)
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
)
accelerator.print(
f" total optimization steps / 学習ステップ数: {args.max_train_steps}"
)
# TODO refactor metadata creation and move to util
metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset_group.num_train_images,
"ss_num_reg_images": train_dataset_group.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
"ss_gradient_checkpointing": args.gradient_checkpointing,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not have alpha
"ss_network_dropout": args.network_dropout, # some networks may not have dropout
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
"ss_base_model_version": model_version,
"ss_clip_skip": args.clip_skip,
"ss_max_token_length": args.max_token_length,
"ss_cache_latents": bool(args.cache_latents),
"ss_seed": args.seed,
"ss_lowram": args.lowram,
"ss_noise_offset": args.noise_offset,
"ss_multires_noise_iterations": args.multires_noise_iterations,
"ss_multires_noise_discount": args.multires_noise_discount,
"ss_adaptive_noise_scale": args.adaptive_noise_scale,
"ss_zero_terminal_snr": args.zero_terminal_snr,
"ss_training_comment": args.training_comment, # will not be updated after training
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
"ss_optimizer": optimizer_name
+ (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
"ss_max_grad_norm": args.max_grad_norm,
"ss_caption_dropout_rate": args.caption_dropout_rate,
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
"ss_noise_offset_random_strength": args.noise_offset_random_strength,
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
}
if use_user_config:
# save metadata of multiple datasets
# NOTE: pack "ss_datasets" value as json one time
# or should also pack nested collections as json?
datasets_metadata = []
tag_frequency = {} # merge tag frequency for metadata editor
dataset_dirs_info = {} # merge subset dirs for metadata editor
for dataset in train_dataset_group.datasets:
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
dataset_metadata = {
"is_dreambooth": is_dreambooth_dataset,
"batch_size_per_device": dataset.batch_size,
"num_train_images": dataset.num_train_images, # includes repeating
"num_reg_images": dataset.num_reg_images,
"resolution": (dataset.width, dataset.height),
"enable_bucket": bool(dataset.enable_bucket),
"min_bucket_reso": dataset.min_bucket_reso,
"max_bucket_reso": dataset.max_bucket_reso,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
}
subsets_metadata = []
for subset in dataset.subsets:
subset_metadata = {
"img_count": subset.img_count,
"num_repeats": subset.num_repeats,
"color_aug": bool(subset.color_aug),
"flip_aug": bool(subset.flip_aug),
"random_crop": bool(subset.random_crop),
"shuffle_caption": bool(subset.shuffle_caption),
"keep_tokens": subset.keep_tokens,
"keep_tokens_separator": subset.keep_tokens_separator,
"secondary_separator": subset.secondary_separator,
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
}
image_dir_or_metadata_file = None
if subset.image_dir:
image_dir = os.path.basename(subset.image_dir)
subset_metadata["image_dir"] = image_dir
image_dir_or_metadata_file = image_dir
if is_dreambooth_dataset:
subset_metadata["class_tokens"] = subset.class_tokens
subset_metadata["is_reg"] = subset.is_reg
if subset.is_reg:
image_dir_or_metadata_file = None # not merging reg dataset
else:
metadata_file = os.path.basename(subset.metadata_file)
subset_metadata["metadata_file"] = metadata_file
image_dir_or_metadata_file = metadata_file # may overwrite
subsets_metadata.append(subset_metadata)
# merge dataset dirs: non-reg subsets only
# TODO update additional-network extension to show detailed dataset config from metadata
if image_dir_or_metadata_file is not None:
# datasets may have a certain dir multiple times
v = image_dir_or_metadata_file
i = 2
while v in dataset_dirs_info:
v = image_dir_or_metadata_file + f" ({i})"
i += 1
image_dir_or_metadata_file = v
dataset_dirs_info[image_dir_or_metadata_file] = {
"n_repeats": subset.num_repeats,
"img_count": subset.img_count,
}
dataset_metadata["subsets"] = subsets_metadata
datasets_metadata.append(dataset_metadata)
# merge tag frequency:
for ds_dir_name, ds_freq_for_dir in dataset.tag_frequency.items():
# if a directory is used by multiple datasets, count it only once
# the number of repeats is specified per subset, so the tag counts in captions do not match how many times they are used in training
# therefore summing the counts across multiple datasets here would not mean much
if ds_dir_name in tag_frequency:
continue
tag_frequency[ds_dir_name] = ds_freq_for_dir
metadata["ss_datasets"] = json.dumps(datasets_metadata)
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
else:
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
assert (
len(train_dataset_group.datasets) == 1
), f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
dataset = train_dataset_group.datasets[0]
dataset_dirs_info = {}
reg_dataset_dirs_info = {}
if use_dreambooth_method:
for subset in dataset.subsets:
info = reg_dataset_dirs_info if subset.is_reg else dataset_dirs_info
info[os.path.basename(subset.image_dir)] = {
"n_repeats": subset.num_repeats,
"img_count": subset.img_count,
}
else:
for subset in dataset.subsets:
dataset_dirs_info[os.path.basename(subset.metadata_file)] = {
"n_repeats": subset.num_repeats,
"img_count": subset.img_count,
}
metadata.update(
{
"ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size,
"ss_resolution": args.resolution,
"ss_color_aug": bool(args.color_aug),
"ss_flip_aug": bool(args.flip_aug),
"ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption),
"ss_enable_bucket": bool(dataset.enable_bucket),
"ss_bucket_no_upscale": bool(dataset.bucket_no_upscale),
"ss_min_bucket_reso": dataset.min_bucket_reso,
"ss_max_bucket_reso": dataset.max_bucket_reso,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(dataset.tag_frequency),
"ss_bucket_info": json.dumps(dataset.bucket_info),
}
)
# add extra args
if args.network_args:
metadata["ss_network_args"] = json.dumps(net_kwargs)
# model name and hash
if args.pretrained_model_name_or_path is not None:
sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(
sd_model_name
)
sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name
if args.vae is not None:
vae_name = args.vae
if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name
metadata = {k: str(v) for k, v in metadata.items()}
# make minimum metadata for filtering
minimum_metadata = {}
for key in train_util.SS_METADATA_MINIMUM_KEYS:
if key in metadata:
minimum_metadata[key] = metadata[key]
# calculate steps to skip when resuming or starting from a specific step
initial_step = 0
if args.initial_epoch is not None or args.initial_step is not None:
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
if steps_from_state is not None:
logger.warning(
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
)
if args.initial_step is not None:
initial_step = args.initial_step
else:
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
initial_step = (args.initial_epoch - 1) * math.ceil(
len(train_dataloader)
/ accelerator.num_processes
/ args.gradient_accumulation_steps
)
else:
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
if steps_from_state is not None:
initial_step = steps_from_state
steps_from_state = None
if initial_step > 0:
assert (
args.max_train_steps > initial_step
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
progress_bar = tqdm(
range(args.max_train_steps - initial_step),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="steps",
)
epoch_to_start = 0
if initial_step > 0:
if args.skip_until_initial_step:
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
if not args.resume:
logger.info(
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
)
logger.info(
f"skipping {initial_step} steps / {initial_step}ステップをスキップします"
)
initial_step *= args.gradient_accumulation_steps
# set epoch to start to make initial_step less than len(train_dataloader)
epoch_to_start = initial_step // math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
else:
# if not, only the epoch number is skipped, for informational purposes
epoch_to_start = initial_step // math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
initial_step = 0 # do not skip
global_step = 0
noise_scheduler = self.load_noise_scheduler(args)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(
noise_scheduler
)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
(
"network_train"
if args.log_tracker_name is None
else args.log_tracker_name
),
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# callback for step start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None
# function for saving/removing
def save_model(
ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False
):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
metadata["ss_training_finished_at"] = str(time.time())
metadata["ss_steps"] = str(steps)
metadata["ss_epoch"] = str(epoch_no)
metadata_to_save = minimum_metadata if args.no_metadata else metadata
sai_metadata = train_util.get_sai_model_spec(
None, args, self.is_sdxl, True, False
)
metadata_to_save.update(sai_metadata)
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
if args.huggingface_repo_id is not None:
huggingface_util.upload(
args,
ckpt_file,
"/" + ckpt_name,
force_sync_upload=force_sync_upload,
)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
)
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(
f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}"
)
initial_step -= len(train_dataloader)
global_step = initial_step
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
skipped_dataloader = None
if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(
train_dataloader, initial_step - 1
)
initial_step = 1
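# skip_first_batches drops initial_step - 1 batches; the remaining one is skipped by the check at the top of the loop below, so exactly initial_step batches are consumed before training resumes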
for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
continue
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
if "latents" in batch and batch["latents"] is not None:
latents = (
batch["latents"]
.to(accelerator.device)
.to(dtype=weight_dtype)
)
else:
with torch.no_grad():
# encode images into latents
latents = (
vae.encode(batch["images"].to(dtype=vae_dtype))
.latent_dist.sample()
.to(dtype=weight_dtype)
)
# if NaNs are present, warn and replace them with zeros
if torch.any(torch.isnan(latents)):
accelerator.print(
"NaN found in latents, replacing with zeros"
)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
# get multiplier for each sample
if network_has_multiplier:
multipliers = batch["network_multipliers"]
# if all multipliers are the same, use a single multiplier
if torch.all(multipliers == multipliers[0]):
multipliers = multipliers[0].item()
else:
raise NotImplementedError(
"multipliers for each sample is not supported yet"
)
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
with torch.set_grad_enabled(
train_text_encoder
), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
(
args.max_token_length // 75
if args.max_token_length
else 1
),
clip_skip=args.clip_skip,
)
else:
text_encoder_conds = self.get_text_cond(
args,
accelerator,
batch,
tokenizers,
text_encoders,
weight_dtype,
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = (
train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
for x in noisy_latents:
x.requires_grad_(True)
for t in text_encoder_conds:
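# requires_grad_ is only valid for floating-point tensors, so non-float tensors (e.g. attention masks) are skipped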
if t.dtype in {
torch.float16,
torch.bfloat16,
torch.float32,
}:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = train_util.conditional_loss(
noise_pred.float(),
target.float(),
reduction="none",
loss_type=args.loss_type,
huber_c=huber_c,
)
if args.masked_loss or (
"alpha_masks" in batch and batch["alpha_masks"] is not None
):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"]  # per-sample weights
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(
loss,
timesteps,
noise_scheduler,
args.min_snr_gamma,
args.v_parameterization,
)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(
loss, timesteps, noise_scheduler
)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(
loss, timesteps, noise_scheduler, args.v_pred_like_loss
)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(
loss, timesteps, noise_scheduler
)
loss = loss.mean()  # mean over the batch, so no need to divide by batch_size
accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(
accelerator, network
) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(
network
).get_trainable_params()
accelerator.clip_grad_norm_(
params_to_clip, args.max_grad_norm
)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(
network
).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
max_mean_logs = {
"Keys Scaled": keys_scaled,
"Average key norm": mean_norm,
}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
self.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
)
# save the model every specified number of steps
if (
args.save_every_n_steps is not None
and global_step % args.save_every_n_steps == 0
):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(
args, "." + args.save_model_as, global_step
)
save_model(
ckpt_name,
accelerator.unwrap_model(network),
global_step,
epoch,
)
if args.save_state:
train_util.save_and_remove_state_stepwise(
args, accelerator, global_step
)
remove_step_no = train_util.get_remove_step_no(
args, global_step
)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(
args, "." + args.save_model_as, remove_step_no
)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None:
logs = self.generate_step_logs(
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
keys_scaled,
mean_norm,
maximum_norm,
)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# save the model every specified number of epochs
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (
epoch + 1
) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(
args, "." + args.save_model_as, epoch + 1
)
save_model(
ckpt_name,
accelerator.unwrap_model(network),
global_step,
epoch + 1,
)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(
args, "." + args.save_model_as, remove_epoch_no
)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(
args, accelerator, epoch + 1
)
self.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
)
# end of epoch
# metadata["ss_epoch"] = str(num_train_epochs)
metadata["ss_training_finished_at"] = str(time.time())
if is_main_process:
network = accelerator.unwrap_model(network)
accelerator.end_training()
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(
ckpt_name,
network,
global_step,
num_train_epochs,
force_sync_upload=True,
)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save metadata in output model / メタデータを出力先モデルに保存しない",
)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
)
parser.add_argument(
"--unet_lr",
type=float,
default=None,
help="learning rate for U-Net / U-Netの学習率",
)
parser.add_argument(
"--text_encoder_lr",
type=float,
default=None,
help="learning rate for Text Encoder / Text Encoderの学習率",
)
parser.add_argument(
"--network_weights",
type=str,
default=None,
help="pretrained weights for network / 学習するネットワークの初期重み",
)
parser.add_argument(
"--network_module",
type=str,
default=None,
help="network module to train / 学習対象のネットワークのモジュール",
)
parser.add_argument(
"--network_dim",
type=int,
default=None,
help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
)
parser.add_argument(
"--network_alpha",
type=float,
default=1,
help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
)
parser.add_argument(
"--network_dropout",
type=float,
default=None,
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
)
parser.add_argument(
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
parser.add_argument(
"--network_train_unet_only",
action="store_true",
help="only training U-Net part / U-Net関連部分のみ学習する",
)
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
)
parser.add_argument(
"--training_comment",
type=str,
default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
)
parser.add_argument(
"--dim_from_weights",
action="store_true",
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
)
parser.add_argument(
"--scale_weight_norms",
type=float,
default=None,
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
)
parser.add_argument(
"--base_weights",
type=str,
default=None,
nargs="*",
help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
)
parser.add_argument(
"--base_weights_multiplier",
type=float,
default=None,
nargs="*",
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--skip_until_initial_step",
action="store_true",
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
)
parser.add_argument(
"--initial_epoch",
type=int,
default=None,
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
+ " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
)
parser.add_argument(
"--initial_step",
type=int,
default=None,
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = NetworkTrainer()
trainer.train(args)
## Using HunyuanDiT Inference with under 6GB GPU VRAM
### Instructions
Running HunyuanDiT with under 6GB of GPU VRAM is now supported based on [**diffusers**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuandit). Here we provide instructions and a demo for a quick start.
The 6GB lite version supports NVIDIA Ampere-architecture and newer graphics cards such as the RTX 3070/3080/4080/4090, A100, and so on.
The only thing you need to do is install the following libraries:
```bash
pip install -U bitsandbytes
pip install git+https://github.com/huggingface/diffusers
pip install torch==2.0.0
```
Then you can enjoy your HunyuanDiT text-to-image journey with under 6GB of GPU VRAM!
Here is a demo:
```bash
cd HunyuanDiT
# Quick start
model_id=Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers-Distilled
prompt=一个宇航员在骑马
infer_steps=50
guidance_scale=6
python3 lite/inference.py ${model_id} ${prompt} ${infer_steps} ${guidance_scale}
```
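Under the hood, the lite script (`lite/inference.py`, shown further below) keeps peak VRAM low by first loading only the second (T5) text encoder in 8-bit, encoding the prompt, freeing the encoder, and only then loading the fp16 DiT pipeline without any text encoders. The following is a minimal, condensed sketch of that flow, reusing the same diffusers calls as the script; treat it as illustrative rather than a drop-in replacement:
```python
import gc
import torch
from diffusers import HunyuanDiTPipeline
from transformers import T5EncoderModel

model_id = "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers-Distilled"
prompt = "一个宇航员在骑马"

# 1) build an encoder-only pipeline with the large text encoder loaded in 8-bit
text_encoder_2 = T5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto"
)
encoder_pipe = HunyuanDiTPipeline.from_pretrained(
    model_id, text_encoder_2=text_encoder_2, transformer=None, vae=None,
    torch_dtype=torch.float16, device_map="balanced",
)
emb1 = encoder_pipe.encode_prompt(prompt)  # default text encoder
emb2 = encoder_pipe.encode_prompt(prompt, text_encoder_index=1, max_sequence_length=256)  # T5 encoder

# 2) free the text encoders before loading the transformer and VAE
del text_encoder_2, encoder_pipe
gc.collect()
torch.cuda.empty_cache()

# 3) run the fp16 DiT pipeline with the precomputed embeddings
pipe = HunyuanDiTPipeline.from_pretrained(
    model_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
).to("cuda")
image = pipe(
    prompt_embeds=emb1[0], negative_prompt_embeds=emb1[1],
    prompt_attention_mask=emb1[2], negative_prompt_attention_mask=emb1[3],
    prompt_embeds_2=emb2[0], negative_prompt_embeds_2=emb2[1],
    prompt_attention_mask_2=emb2[2], negative_prompt_attention_mask_2=emb2[3],
    num_inference_steps=50, guidance_scale=6,
).images[0]
image.save("lite_image.png")
```
The key point is that the large text encoder and the DiT transformer are never resident on the GPU at the same time.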
Note: other features in hydit require torch 1.13.1, so you may need to downgrade your torch version to use them.
```bash
pip install torch==1.13.1
```
import random
import torch
from diffusers import HunyuanDiTPipeline
from transformers import T5EncoderModel
import time
from loguru import logger
import gc
import sys
NEGATIVE_PROMPT = ""
TEXT_ENCODER_CONF = {
"negative_prompt": NEGATIVE_PROMPT,
"prompt_embeds": None,
"negative_prompt_embeds": None,
"prompt_attention_mask": None,
"negative_prompt_attention_mask": None,
"max_sequence_length": 256,
"text_encoder_index": 1,
}
def flush():
gc.collect()
torch.cuda.empty_cache()
class End2End(object):
def __init__(self, model_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled"):
self.model_id = model_id
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# ========================================================================
self.default_negative_prompt = NEGATIVE_PROMPT
logger.info("==================================================")
logger.info(f" Model is ready. ")
logger.info("==================================================")
def load_pipeline(self):
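# load only the transformer and VAE in fp16; both text encoders are omitted because prompts are pre-encoded in get_text_emb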
self.pipeline = HunyuanDiTPipeline.from_pretrained(
self.model_id,
text_encoder=None,
text_encoder_2=None,
torch_dtype=torch.float16,
).to(self.device)
def get_text_emb(self, prompts):
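# encode the prompt twice: once with the default text encoder (text_encoder_index=0) and once with the second, 8-bit T5 encoder selected via text_encoder_index=1 in TEXT_ENCODER_CONF; everything is freed afterwards to release VRAM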
with torch.no_grad():
text_encoder_2 = T5EncoderModel.from_pretrained(
self.model_id,
subfolder="text_encoder_2",
load_in_8bit=True,
device_map="auto",
)
encoder_pipeline = HunyuanDiTPipeline.from_pretrained(
self.model_id,
text_encoder_2=text_encoder_2,
transformer=None,
vae=None,
torch_dtype=torch.float16,
device_map="balanced",
)
TEXT_ENCODER_CONF["negative_prompt"] = self.default_negative_prompt
prompt_emb1 = encoder_pipeline.encode_prompt(
prompts, negative_prompt=self.default_negative_prompt
)
prompt_emb2 = encoder_pipeline.encode_prompt(prompts, **TEXT_ENCODER_CONF)
del text_encoder_2
del encoder_pipeline
flush()
return prompt_emb1, prompt_emb2
def predict(
self,
user_prompt,
seed=None,
enhanced_prompt=None,
negative_prompt=None,
infer_steps=50,
guidance_scale=6,
batch_size=1,
):
# ========================================================================
# Arguments: seed
# ========================================================================
if seed is None:
seed = random.randint(0, 1_000_000)
if not isinstance(seed, int):
raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
generator = torch.Generator(device=self.device).manual_seed(seed)
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if not isinstance(user_prompt, str):
raise TypeError(
f"`user_prompt` must be a string, but got {type(user_prompt)}"
)
user_prompt = user_prompt.strip()
prompt = user_prompt
if enhanced_prompt is not None:
if not isinstance(enhanced_prompt, str):
raise TypeError(
f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}"
)
enhanced_prompt = enhanced_prompt.strip()
prompt = enhanced_prompt
# negative prompt
if negative_prompt is not None and negative_prompt != "":
self.default_negative_prompt = negative_prompt
if not isinstance(self.default_negative_prompt, str):
raise TypeError(
f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
)
# ========================================================================
logger.debug(
f"""
prompt: {user_prompt}
enhanced prompt: {enhanced_prompt}
seed: {seed}
negative_prompt: {negative_prompt}
batch_size: {batch_size}
guidance_scale: {guidance_scale}
infer_steps: {infer_steps}
"""
)
# get text embeddings
flush()
prompt_emb1, prompt_emb2 = self.get_text_emb(prompt)
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = prompt_emb1
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = prompt_emb2
del prompt_emb1
del prompt_emb2
# get pipeline
self.load_pipeline()
samples = self.pipeline(
prompt_embeds=prompt_embeds,
prompt_embeds_2=prompt_embeds_2,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_2=negative_prompt_embeds_2,
prompt_attention_mask=prompt_attention_mask,
prompt_attention_mask_2=prompt_attention_mask_2,
negative_prompt_attention_mask=negative_prompt_attention_mask,
negative_prompt_attention_mask_2=negative_prompt_attention_mask_2,
num_images_per_prompt=batch_size,
guidance_scale=guidance_scale,
num_inference_steps=infer_steps,
generator=generator,
).images[0]
return {
"images": samples,
"seed": seed,
}
if __name__ == "__main__":
if len(sys.argv) != 5:
print(
"Usage: python lite/inference.py ${model_id} ${prompt} ${infer_steps} ${guidance_scale}"
)
print(
"model_id: Choose a diffusers repository from the official Hugging Face repository https://huggingface.co/Tencent-Hunyuan, "
"such as Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers, "
"Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled, "
"Tencent-Hunyuan/HunyuanDiT-Diffusers, or Tencent-Hunyuan/HunyuanDiT-Diffusers-Distilled."
)
print("prompt: the input prompt")
print("infer_steps: infer_steps")
print("guidance_scale: guidance_scale")
sys.exit(1)
model_id = sys.argv[1]
prompt = sys.argv[2]
infer_steps = int(sys.argv[3])
guidance_scale = float(sys.argv[4])
gen = End2End(model_id)
seed = 42
results = gen.predict(
prompt,
seed=seed,
infer_steps=infer_steps,
guidance_scale=guidance_scale,
)
results["images"].save("./lite_image.png")
## Using LoRA to fine-tune HunyuanDiT
### Instructions
The dependencies and installation are basically the same as the [**base model**](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2).
We provide two types of trained LoRA weights for you to test.
You can download them using the following commands:
```bash
cd HunyuanDiT
# Use the huggingface-cli tool to download the model.
huggingface-cli download Tencent-Hunyuan/HYDiT-LoRA --local-dir ./ckpts/t2i/lora
# Quick start
python sample_t2i.py --prompt "青花瓷风格,一只猫在追蝴蝶" --no-enhance --load-key ema --lora-ckpt ./ckpts/t2i/lora/porcelain --infer-mode fa
```
Examples of training data and inference results are as follows:
<table>
<tr>
<td colspan="4" align="center">Examples of training data</td>
</tr>
<tr>
<td align="center"><img src="asset/porcelain/train/0.png" alt="Image 0" width="200"/></td>
<td align="center"><img src="asset/porcelain/train/1.png" alt="Image 1" width="200"/></td>
<td align="center"><img src="asset/porcelain/train/2.png" alt="Image 2" width="200"/></td>
<td align="center"><img src="asset/porcelain/train/3.png" alt="Image 3" width="200"/></td>
</tr>
<tr>
<td align="center">青花瓷风格,一只蓝色的鸟儿站在蓝色的花瓶上,周围点缀着白色花朵,背景是白色 (Porcelain style, a blue bird stands on a blue vase, surrounded by white flowers, with a white background.
</td>
<td align="center">青花瓷风格,这是一幅蓝白相间的陶瓷盘子,上面描绘着一只狐狸和它的幼崽在森林中漫步,背景是白色 (Porcelain style, this is a blue and white ceramic plate depicting a fox and its cubs strolling in the forest, with a white background.)</td>
<td align="center">青花瓷风格,在黑色背景上,一只蓝色的狼站在蓝白相间的盘子上,周围是树木和月亮 (Porcelain style, on a black background, a blue wolf stands on a blue and white plate, surrounded by trees and the moon.)</td>
<td align="center">青花瓷风格,在蓝色背景上,一只蓝色蝴蝶和白色花朵被放置在中央 (Porcelain style, on a blue background, a blue butterfly and white flowers are placed in the center.)</td>
</tr>
<tr>
<td colspan="4" align="center">Examples of inference results</td>
</tr>
<tr>
<td align="center"><img src="asset/porcelain/inference/0.png" alt="Image 4" width="200"/></td>
<td align="center"><img src="asset/porcelain/inference/1.png" alt="Image 5" width="200"/></td>
<td align="center"><img src="asset/porcelain/inference/2.png" alt="Image 6" width="200"/></td>
<td align="center"><img src="asset/porcelain/inference/3.png" alt="Image 7" width="200"/></td>
</tr>
<tr>
<td align="center">青花瓷风格,苏州园林 (Porcelain style, Suzhou Gardens.)</td>
<td align="center">青花瓷风格,一朵荷花 (Porcelain style, a lotus flower.)</td>
<td align="center">青花瓷风格,一只羊(Porcelain style, a sheep.)</td>
<td align="center">青花瓷风格,一个女孩在雨中跳舞(Porcelain style, a girl dancing in the rain.)</td>
</tr>
</table>
### Training
We provide three types of weights for LoRA fine-tuning, `ema`, `module` and `distill`; you can choose according to the actual results. By default, we use `ema` weights.
Here is an example of LoRA training with HunyuanDiT v1.2: we load the `distill` weights into the main model and perform LoRA fine-tuning via the `resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt` setting.
If multiple resolutions are used, you need to add the `--multireso` and `--reso-step 64` parameters.
If you want to train LoRA with HunyuanDiT v1.1, add `--use-style-cond`, `--size-cond 1024 1024` and `--beta-end 0.03`.
```bash
model='DiT-g/2' # model type
task_flag="lora_porcelain_ema_rank64" # task flag
resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt # resume checkpoint
index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices
results_dir=./log_EXP # save root for results
batch_size=1 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=2 # gradient accumulation steps
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=100 # create a ckpt every few steps.
ckpt_latest_every=2000 # create a ckpt named `latest.pt` every few steps.
rank=64 # rank of lora
max_training_steps=2000 # Maximum training iteration steps
PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \
--task-flag ${task_flag} \
--model ${model} \
--training-parts lora \
--rank ${rank} \
--resume \
--resume-module-root ${resume_module_root} \
--lr ${lr} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
--predict-type v_prediction \
--uncond-p 0 \
--uncond-p-t5 0 \
--index-file ${index_file} \
--random-flip \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--ema-dtype fp32 \
--results-dir ${results_dir} \
--ckpt-every ${ckpt_every} \
--max-training-steps ${max_training_steps} \
--ckpt-latest-every ${ckpt_latest_every} \
--log-every 10 \
--deepspeed \
--deepspeed-optimizer \
--use-zero-stage 2 \
--qk-norm \
--rope-img base512 \
--rope-real \
"$@"
```
Recommended parameter settings
| Parameter | Description | Recommended Parameter Value | Note|
|:---------------:|:---------:|:---------------------------------------------------:|:--:|
| `--batch-size` | Training batch size | 1 | Depends on GPU memory|
| `--grad-accu-steps` | Gradient accumulation steps | 2 | - |
| `--rank` | Rank of LoRA | 64 | Choose from 8-128 |
| `--max-training-steps` | Training steps | 2000 | Depends on training data size; as a reference, apply 2000 steps for 100 images|
| `--lr` | Learning rate | 0.0001 | - |
### Inference
After the training is complete, you can use the following command line for inference.
We provide the `--lora-ckpt` parameter for selecting the folder that contains the LoRA weights and configuration.
a. Using LoRA during inference
```bash
python sample_t2i.py --infer-mode fa --prompt "青花瓷风格,一只小狗" --no-enhance --lora-ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt/
```
b. Using LoRA in gradio
```bash
python app/hydit_app.py --infer-mode fa --no-enhance --lora-ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt/
```
c. Merge LoRA weights into the main model
We provide the `--output-merge-path` parameter to set the path for saving the merged weights.
```bash
PYTHONPATH=./ python lora/merge.py --lora-ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0000100.pt/ --output-merge-path ./ckpts/t2i/model/pytorch_model_merge.pt
```
d. To use the LoRA weights we trained with diffusers, we provide the following script. Some modifications were made to ensure compatibility with diffusers, which means the LoRA weights cannot be loaded directly.
```python
import torch
from diffusers import HunyuanDiTPipeline
num_layers = 40
def load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale):
for i in range(num_layers):
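# HunyuanDiT fuses self-attention Q/K/V into a single Wqkv projection and cross-attention K/V into kv_proj,
# while the diffusers transformer keeps separate to_q/to_k/to_v weights, so each low-rank update (lora_B @ lora_A)
# is reconstructed below and chunked before being added to the corresponding diffusers weights.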
Wqkv = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.Wqkv.lora_A.weight"])
q, k, v = torch.chunk(Wqkv, 3, dim=0)
transformer_state_dict[f"blocks.{i}.attn1.to_q.weight"] += lora_scale * q
transformer_state_dict[f"blocks.{i}.attn1.to_k.weight"] += lora_scale * k
transformer_state_dict[f"blocks.{i}.attn1.to_v.weight"] += lora_scale * v
out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn1.out_proj.lora_A.weight"])
transformer_state_dict[f"blocks.{i}.attn1.to_out.0.weight"] += lora_scale * out_proj
q_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.q_proj.lora_A.weight"])
transformer_state_dict[f"blocks.{i}.attn2.to_q.weight"] += lora_scale * q_proj
kv_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.kv_proj.lora_A.weight"])
k, v = torch.chunk(kv_proj, 2, dim=0)
transformer_state_dict[f"blocks.{i}.attn2.to_k.weight"] += lora_scale * k
transformer_state_dict[f"blocks.{i}.attn2.to_v.weight"] += lora_scale * v
out_proj = torch.matmul(lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_B.weight"], lora_state_dict[f"blocks.{i}.attn2.out_proj.lora_A.weight"])
transformer_state_dict[f"blocks.{i}.attn2.to_out.0.weight"] += lora_scale * out_proj
q_proj = torch.matmul(lora_state_dict["pooler.q_proj.lora_B.weight"], lora_state_dict["pooler.q_proj.lora_A.weight"])
transformer_state_dict["time_extra_emb.pooler.q_proj.weight"] += lora_scale * q_proj
return transformer_state_dict
pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", torch_dtype=torch.float16)
pipe.to("cuda")
from safetensors import safe_open
lora_state_dict = {}
with safe_open("./ckpts/t2i/lora/jade/adapter_model.safetensors", framework="pt", device=0) as f:
for k in f.keys():
lora_state_dict[k[17:]] = f.get_tensor(k) # remove 'basemodel.model'
transformer_state_dict = pipe.transformer.state_dict()
transformer_state_dict = load_hunyuan_dit_lora(transformer_state_dict, lora_state_dict, lora_scale=1.0)
pipe.transformer.load_state_dict(transformer_state_dict)
prompt = "玉石绘画风格,一只猫在追蝴蝶"
image = pipe(
prompt,
num_inference_steps=100,
guidance_scale=6.0,
).images[0]
image.save('img.png')
```
e. For more information, please refer to [HYDiT-LoRA](https://huggingface.co/Tencent-Hunyuan/HYDiT-LoRA).
import torch
import os
from hydit.config import get_args
from hydit.modules.models import HUNYUAN_DIT_MODELS
from hydit.inference import _to_tuple
args = get_args()
image_size = _to_tuple(args.image_size)
latent_size = (image_size[0] // 8, image_size[1] // 8)
model = HUNYUAN_DIT_MODELS[args.model](
args,
input_size=latent_size,
log_fn=print,
)
model_path = os.path.join(
args.model_root, "t2i", "model", f"pytorch_model_{args.load_key}.pt"
)
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
print(f"Loading model from {model_path}")
model.load_state_dict(state_dict)
print(f"Loading lora from {args.lora_ckpt}")
model.load_adapter(args.lora_ckpt)
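# fold the LoRA weights into the base parameters so the merged checkpoint can be used without the adapter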
model.merge_and_unload()
torch.save(model.state_dict(), args.output_merge_path)
print(f"Model saved to {args.output_merge_path}")
export CUDA_VISIBLE_DEVICES=0
model='DiT-g/2' # model type
resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt # resume checkpoint
index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices
results_dir=./log_EXP # save root for results
batch_size=2 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=1 # gradient accumulation steps
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=100 # create a ckpt every few steps.
ckpt_latest_every=2000 # create a ckpt named `latest.pt` every few steps.
rank=128 # rank of lora
max_training_steps=2000 # Maximum training iteration steps
task_flag="lora_porcelain_ema_rank${rank}" # task flag
echo $task_flag
PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \
--no-flash-attn \
--task-flag ${task_flag} \
--model ${model} \
--training-parts lora \
--rank ${rank} \
--resume \
--resume-module-root ${resume_module_root} \
--lr ${lr} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
--predict-type v_prediction \
--uncond-p 0 \
--uncond-p-t5 0 \
--index-file ${index_file} \
--random-flip \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-fp16 \
--ema-dtype fp32 \
--results-dir ${results_dir} \
--ckpt-every ${ckpt_every} \
--max-training-steps ${max_training_steps} \
--ckpt-latest-every ${ckpt_latest_every} \
--log-every 10 \
--use-zero-stage 2 \
--qk-norm \
--rope-img base512 \
--rope-real \
--gradient-checkpointing \
--deepspeed-optimizer \
--deepspeed \
"$@"
export CUDA_VISIBLE_DEVICES=0
model='DiT-g/2' # model type
task_flag="lora_porcelain_ema_rank64" # task flag
resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt # resume checkpoint
index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices
results_dir=./log_EXP # save root for results
batch_size=2 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=1 # gradient accumulation steps
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=100 # create a ckpt every few steps.
ckpt_latest_every=2000 # create a ckpt named `latest.pt` every few steps.
rank=64 # rank of lora
max_training_steps=2000 # Maximum training iteration steps
PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \
--task-flag ${task_flag} \
--model ${model} \
--training-parts lora \
--rank ${rank} \
--resume \
--resume-module-root ${resume_module_root} \
--lr ${lr} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
--predict-type v_prediction \
--uncond-p 0 \
--uncond-p-t5 0 \
--index-file ${index_file} \
--random-flip \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--ema-dtype fp32 \
--results-dir ${results_dir} \
--ckpt-every ${ckpt_every} \
--max-training-steps ${max_training_steps} \
--ckpt-latest-every ${ckpt_latest_every} \
--log-every 10 \
--use-zero-stage 2 \
--deepspeed \
--deepspeed-optimizer \
--qk-norm \
--rope-img base512 \
--rope-real \
--gradient-checkpointing \
"$@"
export CUDA_VISIBLE_DEVICES=0
model='DiT-g/2' # model type
task_flag="lora_porcelain_ema_rank64" # task flag
resume_module_root=./ckpts/t2i/model/pytorch_model_distill.pt # resume checkpoint
index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices
results_dir=./log_EXP # save root for results
batch_size=2 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=1 # gradient accumulation steps
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=100 # create a ckpt every few steps.
ckpt_latest_every=2000 # create a ckpt named `latest.pt` every few steps.
rank=64 # rank of lora
max_training_steps=2000 # Maximum training iteration steps
PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \
--task-flag ${task_flag} \
--model ${model} \
--training-parts lora \
--rank ${rank} \
--resume \
--resume-module-root ${resume_module_root} \
--lr ${lr} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.018 \
--predict-type v_prediction \
--uncond-p 0 \
--uncond-p-t5 0 \
--index-file ${index_file} \
--random-flip \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--ema-dtype fp32 \
--results-dir ${results_dir} \
--ckpt-every ${ckpt_every} \
--max-training-steps ${max_training_steps} \
--ckpt-latest-every ${ckpt_latest_every} \
--log-every 10 \
--use-zero-stage 2 \
--deepspeed \
--deepspeed-optimizer \
--qk-norm \
--rope-img base512 \
--rope-real \
--gradient-checkpointing \
--gc-rate 0.5 \
"$@"
model='DiT-g/2' # model type
task_flag="lora_porcelain_ema_rank64" # task flag
resume=./ckpts/t2i/model/ # resume checkpoint
index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices
results_dir=./log_EXP # save root for results
batch_size=1 # training batch size
image_size=1024 # training image resolution
grad_accu_steps=2 # gradient accumulation steps
warmup_num_steps=0 # warm-up steps
lr=0.0001 # learning rate
ckpt_every=100 # create a ckpt every few steps.
ckpt_latest_every=2000 # create a ckpt named `latest.pt` every few steps.
rank=64 # rank of lora
max_training_steps=2000 # Maximum training iteration steps
PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \
--task-flag ${task_flag} \
--model ${model} \
--training-parts lora \
--rank ${rank} \
--resume-split \
--resume ${resume} \
--ema-to-module \
--lr ${lr} \
--noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \
--predict-type v_prediction \
--uncond-p 0.44 \
--uncond-p-t5 0.44 \
--index-file ${index_file} \
--random-flip \
--batch-size ${batch_size} \
--image-size ${image_size} \
--global-seed 999 \
--grad-accu-steps ${grad_accu_steps} \
--warmup-num-steps ${warmup_num_steps} \
--use-flash-attn \
--use-fp16 \
--ema-dtype fp32 \
--results-dir ${results_dir} \
--ckpt-every ${ckpt_every} \
--max-training-steps ${max_training_steps} \
--ckpt-latest-every ${ckpt_latest_every} \
--log-every 10 \
--deepspeed \
--deepspeed-optimizer \
--use-zero-stage 2 \
--qk-norm \
--rope-img base512 \
--rope-real \
--use-style-cond \
--size-cond 1024 1024 \
"$@"