Commit 727428ec authored by jerrrrry

Initial commit CI/CD
import random
import time
from pathlib import Path
import numpy as np
import torch
# For reproducibility
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
from diffusers import schedulers
from diffusers.models import AutoencoderKL
from loguru import logger
from transformers import BertModel, BertTokenizer
from transformers.modeling_utils import logger as tf_logger
from .constants import (
SAMPLER_FACTORY,
NEGATIVE_PROMPT,
TRT_MAX_WIDTH,
TRT_MAX_HEIGHT,
TRT_MAX_BATCH_SIZE,
)
from .diffusion.pipeline_controlnet import StableDiffusionControlNetPipeline
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
from .modules.controlnet import HunYuanControlNet
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .modules.text_encoder import MT5Embedder
from .utils.tools import set_seeds
from peft import LoraConfig
class Resolution:
def __init__(self, width, height):
self.width = width
self.height = height
def __str__(self):
return f"{self.height}x{self.width}"
class ResolutionGroup:
def __init__(self):
self.data = [
Resolution(1024, 1024), # 1:1
Resolution(1280, 1280), # 1:1
Resolution(1024, 768), # 4:3
Resolution(1152, 864), # 4:3
Resolution(1280, 960), # 4:3
Resolution(768, 1024), # 3:4
Resolution(864, 1152), # 3:4
Resolution(960, 1280), # 3:4
Resolution(1280, 768), # 16:9
Resolution(768, 1280), # 9:16
]
self.supported_sizes = set([(r.width, r.height) for r in self.data])
def is_valid(self, width, height):
return (width, height) in self.supported_sizes
STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)
STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1280, 960)], # 4:3
[(960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
def get_standard_shape(target_width, target_height):
"""
Map image size to standard size.
"""
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(
np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)
)
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
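# Worked example (illustrative, not part of the original module): a 1920x1080
# request has ratio ~1.78, so the closest standard ratio is 16:9; that bucket
# holds a single shape, so get_standard_shape(1920, 1080) returns (1280, 768).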
def _to_tuple(val):
if isinstance(val, (list, tuple)):
if len(val) == 1:
val = [val[0], val[0]]
elif len(val) == 2:
val = tuple(val)
else:
raise ValueError(f"Invalid value: {val}")
elif isinstance(val, (int, float)):
val = (val, val)
else:
raise ValueError(f"Invalid value: {val}")
return val
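# Examples (illustrative): _to_tuple(1024) -> (1024, 1024),
# _to_tuple([768]) -> [768, 768], _to_tuple((768, 1280)) -> (768, 1280).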
def get_pipeline(
args,
vae,
text_encoder,
tokenizer,
model,
device,
rank,
embedder_t5,
infer_mode,
controlnet,
sampler=None,
):
"""
Get scheduler and pipeline for sampling. The sampler and pipeline are both
based on diffusers and make some modifications.
Returns
-------
pipeline: StableDiffusionControlNetPipeline
sampler_name: str
"""
sampler = sampler or args.sampler
# Load sampler from factory
kwargs = SAMPLER_FACTORY[sampler]["kwargs"]
scheduler = SAMPLER_FACTORY[sampler]["scheduler"]
# Update sampler according to the arguments
kwargs["beta_schedule"] = args.noise_schedule
kwargs["beta_start"] = args.beta_start
kwargs["beta_end"] = args.beta_end
kwargs["prediction_type"] = args.predict_type
# Build scheduler according to the sampler.
scheduler_class = getattr(schedulers, scheduler)
scheduler = scheduler_class(**kwargs)
logger.debug(f"Using sampler: {sampler} with scheduler: {scheduler}")
# Set timesteps for inference steps.
scheduler.set_timesteps(args.infer_steps, device)
# Only enable progress bar for rank 0
progress_bar_config = {} if rank == 0 else {"disable": True}
pipeline = StableDiffusionControlNetPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=model,
scheduler=scheduler,
feature_extractor=None,
safety_checker=None,
requires_safety_checker=False,
progress_bar_config=progress_bar_config,
embedder_t5=embedder_t5,
infer_mode=infer_mode,
controlnet=controlnet,
)
pipeline = pipeline.to(device)
return pipeline, sampler
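# Minimal usage sketch (hedged: `args` is assumed to be the namespace produced
# by this project's CLI parser, carrying the fields referenced above such as
# `sampler`, `noise_schedule`, `beta_start`, `beta_end`, `predict_type`, and
# `infer_steps`; the component variables are assumed to be loaded as in the
# End2End class below):
#
#   pipeline, sampler_name = get_pipeline(
#       args, vae, clip_text_encoder, tokenizer, model,
#       device="cuda", rank=0, embedder_t5=embedder_t5,
#       infer_mode="torch", controlnet=controlnet,
#   )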
class End2End(object):
def __init__(self, args, models_root_path):
self.args = args
# Check arguments
t2i_root_path = Path(models_root_path) / "t2i"
self.root = t2i_root_path
logger.info(f"Got text-to-image model root path: {t2i_root_path}")
# Set device and disable gradient
self.device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
# Disable BertModel logging checkpoint info
tf_logger.setLevel("ERROR")
# ========================================================================
logger.info(f"Loading CLIP Text Encoder...")
text_encoder_path = self.root / "clip_text_encoder"
self.clip_text_encoder = BertModel.from_pretrained(
str(text_encoder_path), False, revision=None
).to(self.device)
logger.info(f"Loading CLIP Text Encoder finished")
# ========================================================================
logger.info(f"Loading CLIP Tokenizer...")
tokenizer_path = self.root / "tokenizer"
self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
logger.info(f"Loading CLIP Tokenizer finished")
# ========================================================================
logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
t5_text_encoder_path = self.root / "mt5"
embedder_t5 = MT5Embedder(
t5_text_encoder_path, torch_dtype=torch.float16, max_length=256
)
self.embedder_t5 = embedder_t5
self.embedder_t5.model.to(self.device) # Only move encoder to device
logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
# ========================================================================
logger.info(f"Loading VAE...")
vae_path = self.root / "sdxl-vae-fp16-fix"
self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
logger.info(f"Loading VAE finished")
# ========================================================================
# Create model structure and load the checkpoint
logger.info(f"Building HunYuan-DiT model...")
model_config = HUNYUAN_DIT_CONFIG[self.args.model]
self.patch_size = model_config["patch_size"]
self.head_size = model_config["hidden_size"] // model_config["num_heads"]
self.resolutions, self.freqs_cis_img = (
self.standard_shapes()
) # Used for TensorRT models
self.image_size = _to_tuple(self.args.image_size)
latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
self.infer_mode = self.args.infer_mode
if self.infer_mode in ["fa", "torch"]:
# Build model structure
self.model = (
HunYuanDiT(
self.args,
input_size=latent_size,
**model_config,
log_fn=logger.info,
)
.half()
.to(self.device)
) # Force to use fp16
self.controlnet = (
HunYuanControlNet(
self.args,
input_size=latent_size,
**model_config,
log_fn=logger.info,
)
.half()
.to(self.device)
)
# Load model checkpoint
self.load_torch_weights()
lora_ckpt = args.lora_ckpt
if lora_ckpt is not None and lora_ckpt != "":
logger.info(f"Loading Lora checkpoint {lora_ckpt}...")
self.model.load_adapter(lora_ckpt)
self.model.merge_and_unload()
self.model.eval()
self.controlnet.eval()
logger.info(f"Loading torch model finished")
logger.info(f"Loading controlnet finished")
elif self.infer_mode == "trt":
from .modules.trt.hcf_model import TRTModel
trt_dir = self.root / "model_trt"
engine_dir = trt_dir / "engine"
plugin_path = trt_dir / "fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so"
model_name = "model_onnx"
logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
self.model = TRTModel(
model_name=model_name,
engine_dir=str(engine_dir),
image_height=TRT_MAX_HEIGHT,
image_width=TRT_MAX_WIDTH,
text_maxlen=args.text_len,
embedding_dim=args.text_states_dim,
plugin_path=str(plugin_path),
max_batch_size=TRT_MAX_BATCH_SIZE,
)
logger.info(f"Loading TensorRT model finished")
else:
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
# ========================================================================
# Build inference pipeline. We use a customized StableDiffusionControlNetPipeline.
logger.info(f"Loading inference pipeline...")
self.pipeline, self.sampler = self.load_sampler()
logger.info(f"Loading pipeline finished")
# ========================================================================
self.default_negative_prompt = NEGATIVE_PROMPT
logger.info("==================================================")
logger.info(f" Model is ready. ")
logger.info("==================================================")
def load_torch_weights(self):
load_key = self.args.load_key
# get base model path
if self.args.dit_weight is not None:
dit_weight = Path(self.args.dit_weight)
if dit_weight.is_dir():
files = list(dit_weight.glob("*.pt"))
if len(files) == 0:
raise ValueError(f"No model weights found in {dit_weight}")
if files[0].name.startswith("pytorch_model_"):
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
bare_model = True
elif any(str(f).endswith("_model_states.pt") for f in files):
files = [f for f in files if str(f).endswith("_model_states.pt")]
model_path = files[0]
if len(files) > 1:
logger.warning(
f"Multiple model weights found in {dit_weight}, using {model_path}"
)
bare_model = False
else:
raise ValueError(
f"Invalid model path: {dit_weight} with unrecognized weight format: "
f"{list(map(str, files))}. When given a directory as --dit-weight, only "
f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
f"specific weight file, please provide the full path to the file."
)
elif dit_weight.is_file():
model_path = dit_weight
bare_model = "unknown"
else:
raise ValueError(f"Invalid model path: {dit_weight}")
else:
model_dir = self.root / "model"
model_path = model_dir / f"pytorch_model_{load_key}.pt"
bare_model = True
# get controlnet model path
if self.args.controlnet_weight is not None:
controlnet_weight = Path(self.args.controlnet_weight)
if controlnet_weight.is_dir():
controlnet_dir = controlnet_weight
controlnet_path = (
controlnet_dir
/ f"pytorch_model_{self.args.control_type}_{load_key}.pt"
)
elif controlnet_weight.is_file():
controlnet_path = controlnet_weight
else:
raise ValueError(f"Invalid controlnet path: {controlnet_weight}")
else:
controlnet_dir = self.root / "controlnet"
controlnet_path = (
controlnet_dir / f"pytorch_model_{self.args.control_type}_{load_key}.pt"
)
if not model_path.exists():
raise ValueError(f"model_path not exists: {model_path}")
if not controlnet_path.exists():
raise ValueError(f"controlnet_path not exists: {controlnet_path}")
logger.info(f"Loading torch model {model_path}...")
if model_path.suffix == ".safetensors":
raise NotImplementedError(f"Loading safetensors is not supported yet.")
else:
# Assume it's a single weight file in the *.pt format.
state_dict = torch.load(
model_path, map_location=lambda storage, loc: storage
)
logger.info(f"Loading controlnet model {controlnet_path}...")
if controlnet_path.suffix == ".safetensors":
raise NotImplementedError(f"Loading safetensors is not supported yet.")
else:
# Assume it's a single weight file in the *.pt format.
controlnet_state_dict = torch.load(
controlnet_path, map_location=lambda storage, loc: storage
)
if "module" in controlnet_state_dict:
controlnet_state_dict = controlnet_state_dict["module"]
if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
bare_model = False
if bare_model is False:
if load_key in state_dict:
state_dict = state_dict[load_key]
else:
raise KeyError(
f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
f"are: {list(state_dict.keys())}."
)
if "style_embedder.weight" in state_dict and not hasattr(
self.model, "style_embedder"
):
raise ValueError(
f"You might be attempting to load the weights of HunYuanDiT version <= 1.1. You need "
f"to set `--use-style-cond --size-cond 1024 1024 --beta-end 0.03` to adapt to these weights."
f"Alternatively, you can use weights of version >= 1.2, which no longer depend on "
f"these two parameters."
)
if "style_embedder.weight" not in state_dict and hasattr(
self.model, "style_embedder"
):
raise ValueError(
f"You might be attempting to load the weights of HunYuanDiT version >= 1.2. You need "
f"to remove `--use-style-cond` and `--size-cond 1024 1024` to adapt to these weights."
)
if "style_embedder.weight" in controlnet_state_dict and not hasattr(
self.controlnet, "style_embedder"
):
raise ValueError(
f"You might be attempting to load the weights of HunYuanDiT version <= 1.1. You need "
f"to set `--use-style-cond --size-cond 1024 1024 --beta-end 0.03` to adapt to these weights."
f"Alternatively, you can use weights of version >= 1.2, which no longer depend on "
f"these two parameters."
)
if "style_embedder.weight" not in controlnet_state_dict and hasattr(
self.controlnet, "style_embedder"
):
raise ValueError(
f"You might be attempting to load the weights of HunYuanDiT version >= 1.2. You need "
f"to remove `--use-style-cond` and `--size-cond 1024 1024` to adapt to these weights."
)
# Don't set strict=False. Always explicitly check the state_dict.
self.model.load_state_dict(state_dict, strict=True)
self.controlnet.load_state_dict(controlnet_state_dict, strict=True)
def load_sampler(self, sampler=None):
pipeline, sampler = get_pipeline(
self.args,
self.vae,
self.clip_text_encoder,
self.tokenizer,
self.model,
device=self.device,
rank=0,
embedder_t5=self.embedder_t5,
infer_mode=self.infer_mode,
sampler=sampler,
controlnet=self.controlnet,
)
return pipeline, sampler
def calc_rope(self, height, width):
th = height // 8 // self.patch_size
tw = width // 8 // self.patch_size
base_size = 512 // 8 // self.patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
return rope
def standard_shapes(self):
resolutions = ResolutionGroup()
freqs_cis_img = {}
for reso in resolutions.data:
freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
return resolutions, freqs_cis_img
def predict(
self,
user_prompt,
image,
height=1024,
width=1024,
seed=None,
enhanced_prompt=None,
negative_prompt=None,
infer_steps=100,
guidance_scale=6,
batch_size=1,
src_size_cond=(1024, 1024),
sampler=None,
use_style_cond=False,
):
# ========================================================================
# Arguments: seed
# ========================================================================
if seed is None:
seed = random.randint(0, 1_000_000)
if not isinstance(seed, int):
raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
generator = set_seeds(seed, device=self.device)
# ========================================================================
# Arguments: target_width, target_height
# ========================================================================
if width <= 0 or height <= 0:
raise ValueError(
f"`height` and `width` must be positive integers, got height={height}, width={width}"
)
logger.info(f"Input (height, width) = ({height}, {width})")
if self.infer_mode in ["fa", "torch"]:
# Force height and width to be integers aligned to a multiple of 16.
target_height = int((height // 16) * 16)
target_width = int((width // 16) * 16)
logger.info(
f"Align to 16: (height, width) = ({target_height}, {target_width})"
)
elif self.infer_mode == "trt":
target_width, target_height = get_standard_shape(width, height)
logger.info(
f"Align to standard shape: (height, width) = ({target_height}, {target_width})"
)
else:
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if not isinstance(user_prompt, str):
raise TypeError(
f"`user_prompt` must be a string, but got {type(user_prompt)}"
)
user_prompt = user_prompt.strip()
prompt = user_prompt
if enhanced_prompt is not None:
if not isinstance(enhanced_prompt, str):
raise TypeError(
f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}"
)
enhanced_prompt = enhanced_prompt.strip()
prompt = enhanced_prompt
# negative prompt
if negative_prompt is None or negative_prompt == "":
negative_prompt = self.default_negative_prompt
if not isinstance(negative_prompt, str):
raise TypeError(
f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
)
# ========================================================================
# Arguments: style. (A fixed argument. Don't Change it.)
# ========================================================================
if use_style_cond:
# Only for hydit <= 1.1
style = torch.as_tensor([0, 0] * batch_size, device=self.device)
else:
style = None
# ========================================================================
# Inner arguments: image_meta_size (Please refer to SDXL.)
# ========================================================================
if src_size_cond is None:
size_cond = None
image_meta_size = None
else:
# Only for hydit <= 1.1
if isinstance(src_size_cond, int):
src_size_cond = [src_size_cond, src_size_cond]
if not isinstance(src_size_cond, (list, tuple)):
raise TypeError(
f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}"
)
if len(src_size_cond) != 2:
raise ValueError(
f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}"
)
size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
image_meta_size = torch.as_tensor(
[size_cond] * 2 * batch_size, device=self.device
)
# ========================================================================
start_time = time.time()
logger.debug(
f"""
prompt: {user_prompt}
enhanced prompt: {enhanced_prompt}
seed: {seed}
(height, width): {(target_height, target_width)}
negative_prompt: {negative_prompt}
batch_size: {batch_size}
guidance_scale: {guidance_scale}
infer_steps: {infer_steps}
image_meta_size: {size_cond}
"""
)
reso = f"{target_height}x{target_width}"
if reso in self.freqs_cis_img:
freqs_cis_img = self.freqs_cis_img[reso]
else:
freqs_cis_img = self.calc_rope(target_height, target_width)
if sampler is not None and sampler != self.sampler:
self.pipeline, self.sampler = self.load_sampler(sampler)
samples = self.pipeline(
height=target_height,
width=target_width,
prompt=prompt,
negative_prompt=negative_prompt,
num_images_per_prompt=batch_size,
guidance_scale=guidance_scale,
num_inference_steps=infer_steps,
image_meta_size=image_meta_size,
style=style,
return_dict=False,
generator=generator,
freqs_cis_img=freqs_cis_img,
use_fp16=self.args.use_fp16,
learn_sigma=self.args.learn_sigma,
image=image,
control_weight=eval(self.args.control_weight),
)[0]
gen_time = time.time() - start_time
logger.debug(f"Success, time: {gen_time}")
return {
"images": samples,
"seed": seed,
}
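# End-to-end usage sketch for the ControlNet variant above (hedged: `args` and
# `control_image` are placeholders, not the project's actual CLI or loaders;
# `image` is the conditioning image forwarded to the pipeline's `image=` kwarg):
#
#   gen = End2End(args, models_root_path="ckpts")
#   results = gen.predict(
#       user_prompt="a cute cat",
#       image=control_image,
#       height=1024, width=1024,
#       seed=42, infer_steps=50, guidance_scale=6,
#   )
#   images, seed = results["images"], results["seed"]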
import random
import time
from pathlib import Path
import numpy as np
import torch
# For reproducibility
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
from diffusers import schedulers
from diffusers.models import AutoencoderKL
from loguru import logger
from transformers import BertModel, BertTokenizer
from transformers.modeling_utils import logger as tf_logger
from .constants import (
SAMPLER_FACTORY,
NEGATIVE_PROMPT,
TRT_MAX_WIDTH,
TRT_MAX_HEIGHT,
TRT_MAX_BATCH_SIZE,
)
from .diffusion.pipeline_ipadapter import StableDiffusionIPAPipeline
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .modules.text_encoder import MT5Embedder
from .utils.tools import set_seeds
from peft import LoraConfig
import sys
from .utils.img_clip_emb import ImgClipEmbDetector
class Resolution:
def __init__(self, width, height):
self.width = width
self.height = height
def __str__(self):
return f"{self.height}x{self.width}"
class ResolutionGroup:
def __init__(self):
self.data = [
Resolution(1024, 1024), # 1:1
Resolution(1280, 1280), # 1:1
Resolution(1024, 768), # 4:3
Resolution(1152, 864), # 4:3
Resolution(1280, 960), # 4:3
Resolution(768, 1024), # 3:4
Resolution(864, 1152), # 3:4
Resolution(960, 1280), # 3:4
Resolution(1280, 768), # 16:9
Resolution(768, 1280), # 9:16
]
self.supported_sizes = set([(r.width, r.height) for r in self.data])
def is_valid(self, width, height):
return (width, height) in self.supported_sizes
STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)
STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1280, 960)], # 4:3
[(960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
def get_standard_shape(target_width, target_height):
"""
Map image size to standard size.
"""
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(
np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)
)
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
def _to_tuple(val):
if isinstance(val, (list, tuple)):
if len(val) == 1:
val = [val[0], val[0]]
elif len(val) == 2:
val = tuple(val)
else:
raise ValueError(f"Invalid value: {val}")
elif isinstance(val, (int, float)):
val = (val, val)
else:
raise ValueError(f"Invalid value: {val}")
return val
def get_pipeline(
args,
vae,
text_encoder,
tokenizer,
model,
device,
rank,
embedder_t5,
infer_mode,
sampler=None,
):
"""
Get scheduler and pipeline for sampling. The sampler and pipeline are both
based on diffusers and make some modifications.
Returns
-------
pipeline: StableDiffusionIPAPipeline
sampler_name: str
"""
sampler = sampler or args.sampler
# Load sampler from factory
kwargs = SAMPLER_FACTORY[sampler]["kwargs"]
scheduler = SAMPLER_FACTORY[sampler]["scheduler"]
# Update sampler according to the arguments
kwargs["beta_schedule"] = args.noise_schedule
kwargs["beta_start"] = args.beta_start
kwargs["beta_end"] = args.beta_end
kwargs["prediction_type"] = args.predict_type
# Build scheduler according to the sampler.
scheduler_class = getattr(schedulers, scheduler)
scheduler = scheduler_class(**kwargs)
logger.debug(f"Using sampler: {sampler} with scheduler: {scheduler}")
# Set timesteps for inference steps.
scheduler.set_timesteps(args.infer_steps, device)
# Only enable progress bar for rank 0
progress_bar_config = {} if rank == 0 else {"disable": True}
img_encoder = ImgClipEmbDetector()
pipeline = StableDiffusionIPAPipeline(
vae=vae,
text_encoder=text_encoder,
img_encoder=img_encoder,
tokenizer=tokenizer,
unet=model,
scheduler=scheduler,
feature_extractor=None,
safety_checker=None,
requires_safety_checker=False,
progress_bar_config=progress_bar_config,
embedder_t5=embedder_t5,
infer_mode=infer_mode,
)
pipeline = pipeline.to(device)
return pipeline, sampler
class End2End(object):
def __init__(self, args, models_root_path):
self.args = args
# Check arguments
t2i_root_path = Path(models_root_path) / "t2i"
self.root = t2i_root_path
logger.info(f"Got text-to-image model root path: {t2i_root_path}")
# Set device and disable gradient
self.device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
# Disable BertModel logging checkpoint info
tf_logger.setLevel("ERROR")
# ========================================================================
logger.info(f"Loading CLIP Text Encoder...")
text_encoder_path = self.root / "clip_text_encoder"
self.clip_text_encoder = BertModel.from_pretrained(
str(text_encoder_path), False, revision=None
).to(self.device)
logger.info(f"Loading CLIP Text Encoder finished")
# ========================================================================
logger.info(f"Loading CLIP Tokenizer...")
tokenizer_path = self.root / "tokenizer"
self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
logger.info(f"Loading CLIP Tokenizer finished")
# ========================================================================
logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
t5_text_encoder_path = self.root / "mt5"
embedder_t5 = MT5Embedder(
t5_text_encoder_path, torch_dtype=torch.float16, max_length=256
)
self.embedder_t5 = embedder_t5
self.embedder_t5.model.to(self.device) # Only move encoder to device
logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
# ========================================================================
logger.info(f"Loading VAE...")
vae_path = self.root / "sdxl-vae-fp16-fix"
self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
logger.info(f"Loading VAE finished")
# ========================================================================
# Create model structure and load the checkpoint
logger.info(f"Building HunYuan-DiT model...")
model_config = HUNYUAN_DIT_CONFIG[self.args.model]
self.patch_size = model_config["patch_size"]
self.head_size = model_config["hidden_size"] // model_config["num_heads"]
self.resolutions, self.freqs_cis_img = (
self.standard_shapes()
) # Used for TensorRT models
self.image_size = _to_tuple(self.args.image_size)
latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
self.infer_mode = self.args.infer_mode
if self.infer_mode in ["fa", "torch"]:
# Build model structure
self.model = (
HunYuanDiT(
self.args,
input_size=latent_size,
**model_config,
log_fn=logger.info,
)
.half()
.to(self.device)
) # Force to use fp16
# Load model checkpoint
self.load_torch_weights()
self.load_torch_ipa_weights()
lora_ckpt = args.lora_ckpt
if lora_ckpt is not None and lora_ckpt != "":
logger.info(f"Loading Lora checkpoint {lora_ckpt}...")
self.model.load_adapter(lora_ckpt)
self.model.merge_and_unload()
self.model.eval()
logger.info(f"Loading torch model finished")
elif self.infer_mode == "trt":
from .modules.trt.hcf_model import TRTModel
trt_dir = self.root / "model_trt"
engine_dir = trt_dir / "engine"
plugin_path = trt_dir / "fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so"
model_name = "model_onnx"
logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
self.model = TRTModel(
model_name=model_name,
engine_dir=str(engine_dir),
image_height=TRT_MAX_HEIGHT,
image_width=TRT_MAX_WIDTH,
text_maxlen=args.text_len,
embedding_dim=args.text_states_dim,
plugin_path=str(plugin_path),
max_batch_size=TRT_MAX_BATCH_SIZE,
)
logger.info(f"Loading TensorRT model finished")
else:
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
# ========================================================================
# Build inference pipeline. We use a customized StableDiffusionIPAPipeline.
logger.info(f"Loading inference pipeline...")
self.pipeline, self.sampler = self.load_sampler()
logger.info(f"Loading pipeline finished")
# ========================================================================
self.default_negative_prompt = NEGATIVE_PROMPT
logger.info("==================================================")
logger.info(f" Model is ready. ")
logger.info("==================================================")
def load_torch_weights(self):
load_key = self.args.load_key
if self.args.dit_weight is not None:
dit_weight = Path(self.args.dit_weight)
if dit_weight.is_dir():
files = list(dit_weight.glob("*.pt"))
if len(files) == 0:
raise ValueError(f"No model weights found in {dit_weight}")
if files[0].name.startswith("pytorch_model_"):
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
bare_model = True
elif any(str(f).endswith("_model_states.pt") for f in files):
files = [f for f in files if str(f).endswith("_model_states.pt")]
model_path = files[0]
if len(files) > 1:
logger.warning(
f"Multiple model weights found in {dit_weight}, using {model_path}"
)
bare_model = False
else:
raise ValueError(
f"Invalid model path: {dit_weight} with unrecognized weight format: "
f"{list(map(str, files))}. When given a directory as --dit-weight, only "
f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
f"specific weight file, please provide the full path to the file."
)
elif dit_weight.is_file():
model_path = dit_weight
bare_model = "unknown"
else:
raise ValueError(f"Invalid model path: {dit_weight}")
else:
model_dir = self.root / "model"
model_path = model_dir / f"pytorch_model_{load_key}.pt"
bare_model = True
if not model_path.exists():
raise ValueError(f"model_path not exists: {model_path}")
logger.info(f"Loading torch model {model_path}...")
if model_path.suffix == ".safetensors":
raise NotImplementedError(f"Loading safetensors is not supported yet.")
else:
# Assume it's a single weight file in the *.pt format.
state_dict = torch.load(
model_path, map_location=lambda storage, loc: storage
)
if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
bare_model = False
if bare_model is False:
if load_key in state_dict:
state_dict = state_dict[load_key]
else:
raise KeyError(
f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
f"are: {list(state_dict.keys())}."
)
if "style_embedder.weight" in state_dict and not hasattr(
self.model, "style_embedder"
):
raise ValueError(
f"You might be attempting to load the weights of HunYuanDiT version <= 1.1. You need "
f"to set `--use-style-cond --size-cond 1024 1024 --beta-end 0.03` to adapt to these weights."
f"Alternatively, you can use weights of version >= 1.2, which no longer depend on "
f"these two parameters."
)
if "style_embedder.weight" not in state_dict and hasattr(
self.model, "style_embedder"
):
raise ValueError(
f"You might be attempting to load the weights of HunYuanDiT version >= 1.2. You need "
f"to remove `--use-style-cond` and `--size-cond 1024 1024` to adapt to these weights."
)
# strict=False here: the base checkpoint does not contain the IP-Adapter
# weights, which are loaded separately in load_torch_ipa_weights().
self.model.load_state_dict(state_dict, strict=False)
def load_torch_ipa_weights(self):
ipa_model_dir = self.root / "model"
ipa_model_path = ipa_model_dir / "ipa.pt"
ipa_state_dict = torch.load(
ipa_model_path, map_location=lambda storage, loc: storage
)
self.model.load_state_dict(ipa_state_dict, strict=False)
def load_sampler(self, sampler=None):
pipeline, sampler = get_pipeline(
self.args,
self.vae,
self.clip_text_encoder,
self.tokenizer,
self.model,
device=self.device,
rank=0,
embedder_t5=self.embedder_t5,
infer_mode=self.infer_mode,
sampler=sampler,
)
return pipeline, sampler
def calc_rope(self, height, width):
th = height // 8 // self.patch_size
tw = width // 8 // self.patch_size
base_size = 512 // 8 // self.patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
return rope
def standard_shapes(self):
resolutions = ResolutionGroup()
freqs_cis_img = {}
for reso in resolutions.data:
freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
return resolutions, freqs_cis_img
def predict(
self,
user_prompt,
image,
t_scale,
i_scale,
height=1024,
width=1024,
seed=None,
enhanced_prompt=None,
negative_prompt=None,
infer_steps=100,
guidance_scale=6,
batch_size=1,
src_size_cond=(1024, 1024),
sampler=None,
use_style_cond=False,
):
# ========================================================================
# Arguments: seed
# ========================================================================
if seed is None:
seed = random.randint(0, 1_000_000)
if not isinstance(seed, int):
raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
generator = set_seeds(seed, device=self.device)
# ========================================================================
# Arguments: target_width, target_height
# ========================================================================
if width <= 0 or height <= 0:
raise ValueError(
f"`height` and `width` must be positive integers, got height={height}, width={width}"
)
logger.info(f"Input (height, width) = ({height}, {width})")
if self.infer_mode in ["fa", "torch"]:
# Force height and width to be integers aligned to a multiple of 16.
target_height = int((height // 16) * 16)
target_width = int((width // 16) * 16)
logger.info(
f"Align to 16: (height, width) = ({target_height}, {target_width})"
)
elif self.infer_mode == "trt":
target_width, target_height = get_standard_shape(width, height)
logger.info(
f"Align to standard shape: (height, width) = ({target_height}, {target_width})"
)
else:
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if not isinstance(user_prompt, str):
raise TypeError(
f"`user_prompt` must be a string, but got {type(user_prompt)}"
)
user_prompt = user_prompt.strip()
prompt = user_prompt
if enhanced_prompt is not None:
if not isinstance(enhanced_prompt, str):
raise TypeError(
f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}"
)
enhanced_prompt = enhanced_prompt.strip()
prompt = enhanced_prompt
# negative prompt
if negative_prompt is None or negative_prompt == "":
negative_prompt = self.default_negative_prompt
if not isinstance(negative_prompt, str):
raise TypeError(
f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
)
# ========================================================================
# Arguments: style. (A fixed argument. Don't Change it.)
# ========================================================================
if use_style_cond:
# Only for hydit <= 1.1
style = torch.as_tensor([0, 0] * batch_size, device=self.device)
else:
style = None
# ========================================================================
# Inner arguments: image_meta_size (Please refer to SDXL.)
# ========================================================================
if src_size_cond is None:
size_cond = None
image_meta_size = None
else:
# Only for hydit <= 1.1
if isinstance(src_size_cond, int):
src_size_cond = [src_size_cond, src_size_cond]
if not isinstance(src_size_cond, (list, tuple)):
raise TypeError(
f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}"
)
if len(src_size_cond) != 2:
raise ValueError(
f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}"
)
size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
image_meta_size = torch.as_tensor(
[size_cond] * 2 * batch_size, device=self.device
)
# ========================================================================
start_time = time.time()
logger.debug(
f"""
prompt: {user_prompt}
enhanced prompt: {enhanced_prompt}
seed: {seed}
(height, width): {(target_height, target_width)}
negative_prompt: {negative_prompt}
batch_size: {batch_size}
guidance_scale: {guidance_scale}
infer_steps: {infer_steps}
image_meta_size: {size_cond}
"""
)
reso = f"{target_height}x{target_width}"
if reso in self.freqs_cis_img:
freqs_cis_img = self.freqs_cis_img[reso]
else:
freqs_cis_img = self.calc_rope(target_height, target_width)
if sampler is not None and sampler != self.sampler:
self.pipeline, self.sampler = self.load_sampler(sampler)
samples = self.pipeline(
height=target_height,
width=target_width,
prompt=prompt,
image=image,
t_scale=t_scale,
i_scale=i_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=batch_size,
guidance_scale=guidance_scale,
num_inference_steps=infer_steps,
image_meta_size=image_meta_size,
style=style,
return_dict=False,
generator=generator,
freqs_cis_img=freqs_cis_img,
use_fp16=self.args.use_fp16,
learn_sigma=self.args.learn_sigma,
)[0]
gen_time = time.time() - start_time
logger.debug(f"Success, time: {gen_time}")
return {
"images": samples,
"seed": seed,
}
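# Usage sketch for the IP-Adapter variant above (hedged: values are
# illustrative; `t_scale` and `i_scale` are forwarded unchanged to
# StableDiffusionIPAPipeline and appear to weight the text and image
# conditioning, while `image` is the reference image for the CLIP embedder
# rather than a ControlNet condition):
#
#   gen = End2End(args, models_root_path="ckpts")
#   results = gen.predict(
#       user_prompt="a cute cat",
#       image=reference_image,
#       t_scale=1.0, i_scale=0.7,
#       height=1024, width=1024, seed=42,
#   )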
# import comfy.utils
import logging
import torch
import numpy as np
def load_lora(lora, to_load, weight):
model_dict = to_load
patch_dict = {}
loaded_keys = set()
for x in to_load:
alpha_name = "{}.alpha".format(x)
alpha = None
if alpha_name in lora.keys():
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
dora_scale_name = "{}.dora_scale".format(x)
dora_scale = None
if dora_scale_name in lora.keys():
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
hunyuan_lora = "unet.{}.lora.up.weight".format(
x.replace(".weight", "").replace("_", ".")
)
A_name = None
if hunyuan_lora in lora.keys():
A_name = hunyuan_lora
B_name = "unet.{}.lora.down.weight".format(
x.replace(".weight", "").replace("_", ".")
)
mid_name = None
bias_name = "{}.bias".format(x.replace(".weight", ""))
if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (
"lora",
(lora[A_name], lora[B_name], alpha, mid, dora_scale),
)
lora_update = torch.matmul(lora[A_name].to("cuda"), lora[B_name].to("cuda"))
if alpha:
lora_update *= alpha / lora[A_name].shape[1]
else:
lora_update /= np.sqrt(lora[A_name].shape[1])
lora_update *= weight
model_dict[x] += lora_update
loaded_keys.add(A_name)
loaded_keys.add(B_name)
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
return model_dict
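# Illustrative sketch of the key layout load_lora expects (hedged: the names
# below are examples, not taken from a real checkpoint). For a model entry `x`
# such as "blocks_0_attn1_q_proj.weight", the function looks for:
#   lora["unet.blocks.0.attn1.q.proj.lora.up.weight"]    # A, shape (out, rank)
#   lora["unet.blocks.0.attn1.q.proj.lora.down.weight"]  # B, shape (rank, in)
#   lora["blocks_0_attn1_q_proj.weight.alpha"]           # optional scalar alpha
# and adds weight * (alpha / rank) * (A @ B) to the matching tensor in
# `to_load` (or scales by 1/sqrt(rank) when alpha is absent).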
"""
Implementation of learning rate schedules.
Taken and modified from the PyTorch v1.1.0 source
https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py
"""
import argparse
from torch.optim import Optimizer
import math
LR_SCHEDULE = "lr_schedule"
LR_RANGE_TEST = "LRRangeTest"
ONE_CYCLE = "OneCycle"
WARMUP_LR = "WarmupLR"
WARMUP_DECAY_LR = "WarmupDecayLR"
VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR]
LR_RANGE_TEST_MIN_LR = "lr_range_test_min_lr"
LR_RANGE_TEST_STEP_RATE = "lr_range_test_step_rate"
LR_RANGE_TEST_STEP_SIZE = "lr_range_test_step_size"
LR_RANGE_TEST_STAIRCASE = "lr_range_test_staircase"
EDGE_VALUE = "edge_value"
MID_VALUE = "mid_value"
CYCLE_FIRST_STEP_SIZE = "cycle_first_step_size"
CYCLE_FIRST_STAIR_COUNT = "cycle_first_stair_count"
CYCLE_SECOND_STEP_SIZE = "cycle_second_step_size"
CYCLE_SECOND_STAIR_COUNT = "cycle_second_stair_count"
DECAY_STEP_SIZE = "decay_step_size"
CYCLE_MIN_LR = "cycle_min_lr"
CYCLE_MAX_LR = "cycle_max_lr"
DECAY_LR_RATE = "decay_lr_rate"
CYCLE_MIN_MOM = "cycle_min_mom"
CYCLE_MAX_MOM = "cycle_max_mom"
DECAY_MOM_RATE = "decay_mom_rate"
WARMUP_MIN_LR = "warmup_min_lr"
WARMUP_MAX_LR = "warmup_max_lr"
WARMUP_NUM_STEPS = "warmup_num_steps"
WARMUP_TYPE = "warmup_type"
WARMUP_LOG_RATE = "log"
WARMUP_LINEAR_RATE = "linear"
TOTAL_NUM_STEPS = "total_num_steps"
def add_tuning_arguments(parser):
group = parser.add_argument_group(
"Convergence Tuning", "Convergence tuning configurations"
)
# LR scheduler
group.add_argument(
"--lr_schedule", type=str, default=None, help="LR schedule for training."
)
# Learning rate range test
group.add_argument(
"--lr_range_test_min_lr", type=float, default=0.001, help="Starting lr value."
)
group.add_argument(
"--lr_range_test_step_rate",
type=float,
default=1.0,
help="scaling rate for LR range test.",
)
group.add_argument(
"--lr_range_test_step_size",
type=int,
default=1000,
help="training steps per LR change.",
)
group.add_argument(
"--lr_range_test_staircase",
type=bool,
default=False,
help="use staircase scaling for LR range test.",
)
# OneCycle schedule
group.add_argument(
"--cycle_first_step_size",
type=int,
default=1000,
help="size of first step of 1Cycle schedule (training steps).",
)
group.add_argument(
"--cycle_first_stair_count",
type=int,
default=-1,
help="first stair count for 1Cycle schedule.",
)
group.add_argument(
"--cycle_second_step_size",
type=int,
default=-1,
help="size of second step of 1Cycle schedule (default first_step_size).",
)
group.add_argument(
"--cycle_second_stair_count",
type=int,
default=-1,
help="second stair count for 1Cycle schedule.",
)
group.add_argument(
"--decay_step_size",
type=int,
default=1000,
help="size of intervals for applying post cycle decay (training steps).",
)
# 1Cycle LR
group.add_argument(
"--cycle_min_lr", type=float, default=0.01, help="1Cycle LR lower bound."
)
group.add_argument(
"--cycle_max_lr", type=float, default=0.1, help="1Cycle LR upper bound."
)
group.add_argument(
"--decay_lr_rate", type=float, default=0.0, help="post cycle LR decay rate."
)
# 1Cycle Momentum
group.add_argument(
"--cycle_momentum",
default=False,
action="store_true",
help="Enable 1Cycle momentum schedule.",
)
group.add_argument(
"--cycle_min_mom", type=float, default=0.8, help="1Cycle momentum lower bound."
)
group.add_argument(
"--cycle_max_mom", type=float, default=0.9, help="1Cycle momentum upper bound."
)
group.add_argument(
"--decay_mom_rate",
type=float,
default=0.0,
help="post cycle momentum decay rate.",
)
# Warmup LR
group.add_argument(
"--warmup_min_lr",
type=float,
default=0,
help="WarmupLR minimum/initial LR value",
)
group.add_argument(
"--warmup_max_lr", type=float, default=0.001, help="WarmupLR maximum LR value."
)
group.add_argument(
"--warmup_num_steps",
type=int,
default=1000,
help="WarmupLR step count for LR warmup.",
)
group.add_argument(
"--warmup_type",
type=str,
default=WARMUP_LOG_RATE,
help="WarmupLR increasing function during warmup",
)
return parser
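# Example invocation (illustrative only; the script name and values are
# arbitrary) showing how these flags select and configure a schedule:
#   python train.py --lr_schedule WarmupLR \
#       --warmup_min_lr 0 --warmup_max_lr 1e-4 \
#       --warmup_num_steps 500 --warmup_type linear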
def parse_arguments():
parser = argparse.ArgumentParser()
parser = add_tuning_arguments(parser)
lr_sched_args, unknown_args = parser.parse_known_args()
return lr_sched_args, unknown_args
def override_lr_range_test_params(args, params):
if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None:
params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr
if (
hasattr(args, LR_RANGE_TEST_STEP_RATE)
and args.lr_range_test_step_rate is not None
):
params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate
if (
hasattr(args, LR_RANGE_TEST_STEP_SIZE)
and args.lr_range_test_step_size is not None
):
params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size
if (
hasattr(args, LR_RANGE_TEST_STAIRCASE)
and args.lr_range_test_staircase is not None
):
params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase
def override_1cycle_params(args, params):
if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None:
params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size
if (
hasattr(args, CYCLE_FIRST_STAIR_COUNT)
and args.cycle_first_stair_count is not None
):
params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count
if (
hasattr(args, CYCLE_SECOND_STEP_SIZE)
and args.cycle_second_step_size is not None
):
params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size
if (
hasattr(args, CYCLE_SECOND_STAIR_COUNT)
and args.cycle_second_stair_count is not None
):
params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count
if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None:
params[DECAY_STEP_SIZE] = args.decay_step_size
# 1Cycle LR params
if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None:
params[CYCLE_MIN_LR] = args.cycle_min_lr
if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None:
params[CYCLE_MAX_LR] = args.cycle_max_lr
if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None:
params[DECAY_LR_RATE] = args.decay_lr_rate
# 1Cycle MOM params
if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None:
params[CYCLE_MIN_MOM] = args.cycle_min_mom
if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None:
params[CYCLE_MAX_MOM] = args.cycle_max_mom
if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None:
params[DECAY_MOM_RATE] = args.decay_mom_rate
def override_warmupLR_params(args, params):
if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None:
params[WARMUP_MIN_LR] = args.warmup_min_lr
if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None:
params[WARMUP_MAX_LR] = args.warmup_max_lr
if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None:
params[WARMUP_NUM_STEPS] = args.warmup_num_steps
if hasattr(args, WARMUP_TYPE) and args.warmup_type is not None:
params[WARMUP_TYPE] = args.warmup_type
def override_params(args, params):
# LR range test params
override_lr_range_test_params(args, params)
# 1Cycle params
override_1cycle_params(args, params)
# WarmupLR params
override_warmupLR_params(args, params)
def get_config_from_args(args):
if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None:
return None, "--{} not specified on command line".format(LR_SCHEDULE)
if args.lr_schedule not in VALID_LR_SCHEDULES:
return None, "{} is not a supported LR schedule".format(args.lr_schedule)
config = {}
config["type"] = args.lr_schedule
config["params"] = {}
if args.lr_schedule == LR_RANGE_TEST:
override_lr_range_test_params(args, config["params"])
elif args.lr_schedule == ONE_CYCLE:
override_1cycle_params(args, config["params"])
else:
override_warmupLR_params(args, config["params"])
return config, None
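# Example (illustrative): with `--lr_schedule WarmupLR` on the command line,
# get_config_from_args returns a tuple like
#   ({"type": "WarmupLR",
#     "params": {"warmup_min_lr": 0, "warmup_max_lr": 0.001,
#                "warmup_num_steps": 1000, "warmup_type": "log"}}, None)
# (the params come from the argparse defaults above), while an unset or
# unsupported --lr_schedule yields (None, <error message>).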
def get_lr_from_config(config):
if not "type" in config:
return None, "LR schedule type not defined in config"
if not "params" in config:
return None, "LR schedule params not defined in config"
lr_schedule = config["type"]
lr_params = config["params"]
if lr_schedule not in VALID_LR_SCHEDULES:
return None, "{} is not a valid LR schedule".format(lr_schedule)
if lr_schedule == LR_RANGE_TEST:
return lr_params[LR_RANGE_TEST_MIN_LR], ""
if lr_schedule == ONE_CYCLE:
return lr_params[CYCLE_MAX_LR], ""
# Warmup LR
return lr_params[WARMUP_MAX_LR], ""
"""
Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped
optimizer to see if requirement is satisfied.
TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix.
"""
def get_torch_optimizer(optimizer):
if isinstance(optimizer, Optimizer):
return optimizer
if hasattr(optimizer, "optimizer") and isinstance(optimizer.optimizer, Optimizer):
return optimizer.optimizer
raise TypeError(
"{} is not a subclass of torch.optim.Optimizer".format(type(optimizer).__name__)
)
class LRRangeTest(object):
"""Sets the learning rate of each parameter group according to
learning rate range test (LRRT) policy. The policy increases learning
rate starting from a base value with a constant frequency, as detailed in
the paper `A disciplined approach to neural network hyper-parameters: Part1`_.
LRRT policy is used for finding maximum LR that trains a model without divergence, and can be used to
configure the LR boundaries for Cyclic LR schedules.
LRRT changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
Args:
optimizer (Optimizer): Wrapped optimizer.
lr_range_test_min_lr (float or list): Initial learning rate which is the
lower boundary in the range test for each parameter group.
lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000
lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0
lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False.
last_batch_iteration (int): The index of the last batch. This parameter is used when
resuming a training job. Since `step()` should be invoked after each
batch instead of after each epoch, this number represents the total
number of *batches* computed, not the total number of epochs computed.
When last_batch_iteration=-1, the schedule is started from the beginning.
Default: -1
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = LRRangeTest(optimizer)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> scheduler.step()
.. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay:
https://arxiv.org/abs/1803.09820
"""
def __init__(
self,
optimizer: Optimizer,
lr_range_test_min_lr: float = 1e-3,
lr_range_test_step_size: int = 2000,
lr_range_test_step_rate: float = 1.0,
lr_range_test_staircase: bool = False,
last_batch_iteration: int = -1,
):
self.optimizer = get_torch_optimizer(optimizer)
if isinstance(lr_range_test_min_lr, list) or isinstance(
lr_range_test_min_lr, tuple
):
if len(lr_range_test_min_lr) != len(self.optimizer.param_groups):
raise ValueError(
"expected {} lr_range_test_min_lr, got {}".format(
len(self.optimizer.param_groups), len(lr_range_test_min_lr)
)
)
self.min_lr = list(lr_range_test_min_lr)
else:
self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups)
self.step_size = lr_range_test_step_size
self.step_rate = lr_range_test_step_rate
self.last_batch_iteration = last_batch_iteration
self.staircase = lr_range_test_staircase
self.interval_fn = (
self._staircase_interval
if lr_range_test_staircase
else self._continuous_interval
)
if last_batch_iteration == -1:
self._update_optimizer(self.min_lr)
def _staircase_interval(self):
return math.floor(float(self.last_batch_iteration + 1) / self.step_size)
def _continuous_interval(self):
return float(self.last_batch_iteration + 1) / self.step_size
def _get_increase(self):
return 1 + self.step_rate * self.interval_fn()
def get_lr(self):
lr_increase = self._get_increase()
return [
lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr
]
def get_last_lr(self):
"""Return last computed learning rate by current scheduler."""
assert getattr(self, "_last_lr", None) is not None, "need to call step() first"
return self._last_lr
def _update_optimizer(self, group_lrs):
for param_group, lr in zip(self.optimizer.param_groups, group_lrs):
param_group["lr"] = lr
def step(self, batch_iteration=None):
if batch_iteration is None:
batch_iteration = self.last_batch_iteration + 1
self.last_batch_iteration = batch_iteration
self._update_optimizer(self.get_lr())
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
def state_dict(self):
return {"last_batch_iteration": self.last_batch_iteration}
def load_state_dict(self, sd):
self.last_batch_iteration = sd["last_batch_iteration"]
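# Worked example (sketch): with the defaults above (min_lr=1e-3, step_size=2000,
# step_rate=1.0, continuous scaling), get_lr() evaluates to
#   min_lr * (1 + step_rate * (last_batch_iteration + 1) / step_size),
# so after 2000 calls to step() the LR has doubled to 2e-3 and after 4000 calls
# it reaches 3e-3.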
class OneCycle(object):
"""Sets the learning rate of each parameter group according to
1Cycle learning rate policy (1CLR). 1CLR is a variation of the
Cyclical Learning Rate (CLR) policy that involves one cycle followed by
decay. The policy simultaneously cycles the learning rate (and momentum)
between two boundaries with a constant frequency, as detailed in
the paper `A disciplined approach to neural network hyper-parameters`_.
1CLR policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.
This implementation was adapted from the github repo: `pytorch/pytorch`_
Args:
optimizer (Optimizer): Wrapped optimizer.
cycle_min_lr (float or list): Initial learning rate which is the
lower boundary in the cycle for each parameter group.
cycle_max_lr (float or list): Upper learning rate boundaries in the cycle
for each parameter group. Functionally,
it defines the cycle amplitude (cycle_max_lr - cycle_min_lr).
The lr at any cycle is the sum of cycle_min_lr
and some scaling of the amplitude; therefore
cycle_max_lr may not actually be reached depending on
scaling function.
decay_lr_rate(float): Decay rate for learning rate. Default: 0.
cycle_first_step_size (int): Number of training iterations in the
increasing half of a cycle. Default: 2000
cycle_second_step_size (int): Number of training iterations in the
decreasing half of a cycle. If cycle_second_step_size is None,
it is set to cycle_first_step_size. Default: None
cycle_first_stair_count(int): Number of stairs in first half of cycle phase. This means
lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
cycle_second_stair_count(int): Number of stairs in second half of cycle phase. This means
lr/mom are changed in staircase fashion. Default 0, means staircase disabled.
decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay.
cycle_momentum (bool): If ``True``, momentum is cycled inversely
to learning rate between 'cycle_min_mom' and 'cycle_max_mom'.
Default: True
cycle_min_mom (float or list): Initial momentum which is the
lower boundary in the cycle for each parameter group.
Default: 0.8
cycle_max_mom (float or list): Upper momentum boundaries in the cycle
for each parameter group. Functionally,
it defines the cycle amplitude (cycle_max_mom - cycle_min_mom).
The momentum at any cycle is the difference of cycle_max_mom
and some scaling of the amplitude; therefore
cycle_min_mom may not actually be reached depending on
scaling function. Default: 0.9
decay_mom_rate (float): Decay rate for momentum. Default: 0.
last_batch_iteration (int): The index of the last batch. This parameter is used when
resuming a training job. Since `step()` should be invoked after each
batch instead of after each epoch, this number represents the total
number of *batches* computed, not the total number of epochs computed.
When last_batch_iteration=-1, the schedule is started from the beginning.
Default: -1
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = OneCycle(optimizer, 0.0001, 0.0010)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> scheduler.step()
.. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820
"""
def __init__(
self,
optimizer,
cycle_min_lr,
cycle_max_lr,
decay_lr_rate=0.0,
cycle_first_step_size=2000,
cycle_second_step_size=None,
cycle_first_stair_count=0,
cycle_second_stair_count=None,
decay_step_size=0,
cycle_momentum=True,
cycle_min_mom=0.8,
cycle_max_mom=0.9,
decay_mom_rate=0.0,
last_batch_iteration=-1,
):
self.optimizer = get_torch_optimizer(optimizer)
# Initialize cycle shape
self._initialize_cycle(
cycle_first_step_size,
cycle_second_step_size,
cycle_first_stair_count,
cycle_second_stair_count,
decay_step_size,
)
# Initialize cycle lr
self._initialize_lr(
self.optimizer,
cycle_min_lr,
cycle_max_lr,
decay_lr_rate,
last_batch_iteration,
)
# Initialize cyclic momentum
self.cycle_momentum = cycle_momentum
if cycle_momentum:
self._initialize_momentum(
self.optimizer,
cycle_min_mom,
cycle_max_mom,
decay_mom_rate,
last_batch_iteration,
)
# Initialize batch iteration tracker
self.last_batch_iteration = last_batch_iteration
# Configure cycle shape
def _initialize_cycle(
self,
cycle_first_step_size,
cycle_second_step_size,
cycle_first_stair_count,
cycle_second_stair_count,
decay_step_size,
):
cycle_first_step_size = float(cycle_first_step_size)
cycle_second_step_size = (
float(cycle_second_step_size)
if cycle_second_step_size is not None
else cycle_first_step_size
)
self.total_size = cycle_first_step_size + cycle_second_step_size
self.step_ratio = cycle_first_step_size / self.total_size
self.first_stair_count = cycle_first_stair_count
self.second_stair_count = (
cycle_first_stair_count
if cycle_second_stair_count is None
else cycle_second_stair_count
)
self.decay_step_size = decay_step_size
if math.isclose(self.decay_step_size, 0):
self.skip_lr_decay = True
self.skip_mom_decay = True
else:
self.skip_lr_decay = False
self.skip_mom_decay = False
# Configure lr schedule
def _initialize_lr(
self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration
):
self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups)
if last_batch_iteration == -1:
for lr, group in zip(self.min_lrs, optimizer.param_groups):
group["lr"] = lr
self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups)
self.decay_lr_rate = decay_lr_rate
if math.isclose(self.decay_lr_rate, 0):
self.skip_lr_decay = True
# Configure momentum schedule
def _initialize_momentum(
self,
optimizer,
cycle_min_mom,
cycle_max_mom,
decay_mom_rate,
last_batch_iteration,
):
if "betas" not in optimizer.defaults:
optimizer_name = type(optimizer).__name__
print(
f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
)
self.cycle_momentum = False
return
self.decay_mom_rate = decay_mom_rate
self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups)
self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups)
if last_batch_iteration == -1:
for momentum, group in zip(self.min_moms, optimizer.param_groups):
group["betas"] = momentum
if math.isclose(self.decay_mom_rate, 0):
self.skip_mom_decay = True
def _get_scale_factor(self):
batch_iteration = self.last_batch_iteration + 1
cycle = math.floor(1 + batch_iteration / self.total_size)
x = 1.0 + batch_iteration / self.total_size - cycle
if x <= self.step_ratio:
scale_factor = x / self.step_ratio
else:
scale_factor = (x - 1) / (self.step_ratio - 1)
return scale_factor
def _get_cycle_mom(self):
scale_factor = self._get_scale_factor()
momentums = []
for base_betas, max_betas in zip(self.min_moms, self.max_moms):
cycle_min_mom = base_betas[0]
cycle_max_mom = max_betas[0]
base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
momentum = cycle_max_mom - base_height
momentums.append((momentum, base_betas[1]))
return momentums
def _get_cycle_lr(self):
scale_factor = self._get_scale_factor()
lrs = []
for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
lr = cycle_min_lr + base_height
lrs.append(lr)
return lrs
def _get_decay_mom(self, decay_batch_iteration):
if self.skip_mom_decay:
return self.max_moms
decay_interval = decay_batch_iteration / self.decay_step_size
mom_decay_factor = 1 + self.decay_mom_rate * decay_interval
momentums = [
(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms
]
return momentums
def _get_decay_lr(self, decay_batch_iteration):
"""Calculates the learning rate at batch index. This function is used
after the cycle completes and post cycle decaying of lr/mom is enabled.
This function treats `self.last_batch_iteration` as the last batch index.
"""
if self.skip_lr_decay:
return self.min_lrs
decay_interval = decay_batch_iteration / self.decay_step_size
lr_decay_factor = 1 + self.decay_lr_rate * decay_interval
lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]
return lrs
def get_lr(self):
"""Calculates the learning rate at batch index. This function treats
`self.last_batch_iteration` as the last batch index.
"""
if self.last_batch_iteration < self.total_size:
return self._get_cycle_lr()
return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)
def get_mom(self):
"""Calculates the momentum at batch index. This function treats
`self.last_batch_iteration` as the last batch index.
"""
if not self.cycle_momentum:
return None
if self.last_batch_iteration < self.total_size:
return self._get_cycle_mom()
return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)
def get_last_lr(self):
"""Return last computed learning rate by current scheduler."""
assert getattr(self, "_last_lr", None) is not None, "need to call step() first"
return self._last_lr
def step(self, batch_iteration=None):
"""Updates the optimizer with the learning rate for the last batch index.
`self.last_batch_iteration` is treated as the last batch index.
If self.cycle_momentum is true, also updates optimizer momentum.
"""
if batch_iteration is None:
batch_iteration = self.last_batch_iteration + 1
self.last_batch_iteration = batch_iteration
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group["lr"] = lr
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
if self.cycle_momentum:
momentums = self.get_mom()
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
param_group["betas"] = momentum
def state_dict(self):
return {"last_batch_iteration": self.last_batch_iteration}
def load_state_dict(self, sd):
self.last_batch_iteration = sd["last_batch_iteration"]
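# Illustrative sketch: the triangular scale factor behind the cyclic schedule
# above, written as a standalone function so the lr/momentum shape can be
# inspected without building an optimizer. The function name and the sample
# hyper-parameters are hypothetical; the math mirrors _get_scale_factor,
# _get_cycle_lr and _get_cycle_mom.
def _demo_cycle_scale(step, first_step_size=2000.0, second_step_size=2000.0):
    """Return (scale, lr, momentum) for one step, using min_lr=1e-4,
    max_lr=1e-3, min_mom=0.8, max_mom=0.9 purely for illustration."""
    import math
    total_size = first_step_size + second_step_size
    step_ratio = first_step_size / total_size
    it = step + 1
    cycle = math.floor(1 + it / total_size)
    x = 1.0 + it / total_size - cycle
    scale = x / step_ratio if x <= step_ratio else (x - 1) / (step_ratio - 1)
    lr = 1e-4 + (1e-3 - 1e-4) * scale  # rises in the first half, falls in the second
    momentum = 0.9 - (0.9 - 0.8) * scale  # moves opposite to the lr
    return scale, lr, momentum
# e.g. _demo_cycle_scale(0)[0] is ~0.0005 (cycle start) and
# _demo_cycle_scale(1999)[0] == 1.0 (cycle midpoint).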
class WarmupLR(object):
"""Increase the learning rate of each parameter group from min lr to max lr
over warmup_num_steps steps, and then fix at max lr.
Args:
optimizer (Optimizer): Wrapped optimizer.
warmup_min_lr (float or list): minimum learning rate. Default: 0
warmup_max_lr (float or list): maximum learning rate. Default: 0.001
warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
        warmup_type {'log', 'linear'}: increasing function from min_lr to max_lr during warmup. Default: log
last_batch_iteration (int): The index of the last batch. Default: -1.
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = WarmupLR(optimizer)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> scheduler.step()
"""
def __init__(
self,
optimizer: Optimizer,
warmup_min_lr: float = 0.0,
warmup_max_lr: float = 0.001,
warmup_num_steps: int = 1000,
warmup_type: str = WARMUP_LOG_RATE,
last_batch_iteration: int = -1,
):
self.optimizer = get_torch_optimizer(optimizer)
self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr")
self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr")
self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)]
self.warmup_num_steps = max(2, warmup_num_steps)
# Currently only support linear and log function
if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}:
print(
f"Using unknown warmup_type: {warmup_type}. The increasing function "
f"is set to default (log)"
)
warmup_type = WARMUP_LOG_RATE
self.warmup_type = warmup_type
self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)
self.last_batch_iteration = last_batch_iteration
def get_lr(self):
if self.last_batch_iteration < 0:
print(
"Attempting to get learning rate from scheduler before it has started"
)
return [0.0]
gamma = self._get_gamma()
return [
min_lr + (delta_lr * gamma)
for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)
]
def get_last_lr(self):
"""Return last computed learning rate by current scheduler."""
assert getattr(self, "_last_lr", None) is not None, "need to call step() first"
return self._last_lr
def step(self, last_batch_iteration=None):
if last_batch_iteration is None:
last_batch_iteration = self.last_batch_iteration + 1
self.last_batch_iteration = last_batch_iteration
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group["lr"] = lr
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
def state_dict(self):
return {"last_batch_iteration": self.last_batch_iteration}
def load_state_dict(self, sd):
self.last_batch_iteration = sd["last_batch_iteration"]
def _get_gamma(self):
if self.last_batch_iteration < self.warmup_num_steps:
if self.warmup_type == WARMUP_LOG_RATE:
return self.inverse_log_warm_up * math.log(
self.last_batch_iteration + 1
)
elif self.warmup_type == WARMUP_LINEAR_RATE:
return self.last_batch_iteration / self.warmup_num_steps
return 1.0
def _format_param(self, optimizer, param_value, param_name):
if isinstance(param_value, list) or isinstance(param_value, tuple):
if len(param_value) != len(optimizer.param_groups):
raise ValueError(
"expected {} value for {}, got {}".format(
len(optimizer.param_groups),
param_name,
FileNotFoundError(param_value),
)
)
return list(param_value)
return [param_value] * len(optimizer.param_groups)
class WarmupDecayLR(WarmupLR):
"""Increase the learning rate of each parameter group from min lr to max lr
over warmup_num_steps steps, and then decay at linear rate over the remaining training steps.
Args:
optimizer (Optimizer): Wrapped optimizer.
total_num_steps (int): total number of training steps
warmup_min_lr (float or list): minimum learning rate. Default: 0
warmup_max_lr (float or list): maximum learning rate. Default: 0.001
warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
        warmup_type {'log', 'linear'}: increasing function from min_lr to max_lr during warmup. Default: log
last_batch_iteration (int): The index of the last batch. Default: -1.
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = WarmupDecayLR(optimizer, 1000000)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> scheduler.step()
"""
def __init__(
self,
optimizer: Optimizer,
total_num_steps: int,
warmup_min_lr: float = 0.0,
warmup_max_lr: float = 0.001,
warmup_num_steps: int = 1000,
warmup_type: str = WARMUP_LOG_RATE,
last_batch_iteration: int = -1,
):
self.total_num_steps = total_num_steps
super(WarmupDecayLR, self).__init__(
optimizer,
warmup_min_lr,
warmup_max_lr,
warmup_num_steps,
warmup_type,
last_batch_iteration,
)
if self.total_num_steps < self.warmup_num_steps:
print(
"total_num_steps {} is less than warmup_num_steps {}".format(
total_num_steps, warmup_num_steps
)
)
def _get_gamma(self):
if self.last_batch_iteration < self.warmup_num_steps:
if self.warmup_type == WARMUP_LOG_RATE:
return self.inverse_log_warm_up * math.log(
self.last_batch_iteration + 1
)
elif self.warmup_type == WARMUP_LINEAR_RATE:
return self.last_batch_iteration / self.warmup_num_steps
return max(
0.0,
float(self.total_num_steps - self.last_batch_iteration)
/ float(max(1.0, self.total_num_steps - self.warmup_num_steps)),
)
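# Illustrative sketch: the warmup-then-linear-decay gamma that WarmupDecayLR
# multiplies into `min_lr + delta_lr * gamma`. The function name and the
# sample step counts are hypothetical; the formula follows _get_gamma above.
def _demo_warmup_decay_gamma(step, warmup_num_steps=1000, total_num_steps=10000, warmup_type="log"):
    import math
    warmup_num_steps = max(2, warmup_num_steps)
    if step < warmup_num_steps:
        if warmup_type == "log":
            return math.log(step + 1) / math.log(warmup_num_steps)
        return step / warmup_num_steps  # "linear"
    return max(
        0.0,
        float(total_num_steps - step)
        / float(max(1.0, total_num_steps - warmup_num_steps)),
    )
# gamma climbs to 1.0 by step == warmup_num_steps and then decays linearly to
# 0.0 at step == total_num_steps.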
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union, Optional
from lightop import op as ops
try:
import flash_attn
if hasattr(flash_attn, "__version__") and int(flash_attn.__version__[0]) == 2:
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
else:
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention
except Exception as e:
print(f"flash_attn import failed: {e}")
def reshape_for_broadcast(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
x: torch.Tensor,
head_first=False,
):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
    assert ndim > 1, "x must have at least 2 dimensions"
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [
d if i == ndim - 2 or i == ndim - 1 else 1
for i, d in enumerate(x.shape)
]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = (
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: Optional[torch.Tensor],
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], -1, 2)
) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
xq.device
) # [S, D//2] --> [1, S, 1, D//2]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
if xk is not None:
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], -1, 2)
) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
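# Illustrative sketch: a minimal sanity check of apply_rotary_emb in its
# (cos, sin) real-valued form. With cos = 1 and sin = 0 the rotation is the
# identity. Tensor sizes are hypothetical; they only need seq_len to match
# freqs.shape[0] and an even head dim.
def _demo_apply_rotary_emb_identity():
    import torch
    B, S, H, D = 2, 4, 3, 8  # [batch, seq, heads, head_dim]
    xq = torch.randn(B, S, H, D)
    xk = torch.randn(B, S, H, D)
    cos = torch.ones(S, D)  # zero rotation angle
    sin = torch.zeros(S, D)
    xq_out, xk_out = apply_rotary_emb(xq, xk, (cos, sin), head_first=False)
    assert torch.allclose(xq_out, xq) and torch.allclose(xk_out, xk)
    return xq_out.shape  # torch.Size([2, 4, 3, 8])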
class FlashSelfMHAModified(nn.Module):
"""
Use QK Normalization.
"""
def __init__(
self,
dim,
num_heads,
qkv_bias=True,
qk_norm=False,
attn_drop=0.0,
proj_drop=0.0,
device=None,
dtype=None,
norm_layer=nn.LayerNorm,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.dim = dim
self.num_heads = num_heads
assert self.dim % num_heads == 0, "self.kdim must be divisible by num_heads"
self.head_dim = self.dim // num_heads
assert (
self.head_dim % 8 == 0 and self.head_dim <= 128
), "Only support head_dim <= 128 and divisible by 8"
self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = (
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
if qk_norm
else nn.Identity()
)
self.k_norm = (
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
if qk_norm
else nn.Identity()
)
self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop)
self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
self.proj_drop = nn.Dropout(proj_drop)
self.gamma_q = torch.ones(self.head_dim, device='cuda', dtype=torch.float16)
self.beta_q = torch.zeros(self.head_dim, device='cuda', dtype=torch.float16)
self.gamma_k = torch.ones(self.head_dim, device='cuda', dtype=torch.float16)
self.beta_k = torch.zeros(self.head_dim, device='cuda', dtype=torch.float16)
self.eps = 1e-6
def forward(self, x, freqs_cis_img=None):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // 2), RoPE for image
"""
b, s, d = x.shape
qkv = self.Wqkv(x)
qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim).contiguous() # [b, s, 3, h, d]
ops.mha_norm_rope_qkv_inplace_cuda(
qkv, # arg0: Tensor [B,S,3,H,D], contiguous
self.q_norm.weight, # arg1: Optional[Tensor] or None, [D]
self.q_norm.bias, # arg2: Optional[Tensor] or None, [D]
self.k_norm.weight, # arg3: Optional[Tensor] or None, [D]
self.k_norm.bias, # arg4: Optional[Tensor] or None, [D]
freqs_cis_img[0], # arg5: Tensor float32 [S,D], CUDA, contiguous
freqs_cis_img[1], # arg6: Tensor float32 [S,D], CUDA, contiguous
self.q_norm.eps # arg7: float
)
context = self.inner_attn(qkv)
out = self.out_proj(context.view(b, s, d))
out = self.proj_drop(out)
out_tuple = (out,)
return out_tuple
class FlashCrossMHAModified(nn.Module):
"""
Use QK Normalization.
"""
def __init__(
self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
is_ipa=False,
attn_drop=0.0,
proj_drop=0.0,
device=None,
dtype=None,
norm_layer=nn.LayerNorm,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
self.head_dim = self.qdim // num_heads
assert (
self.head_dim % 8 == 0 and self.head_dim <= 128
), "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim**-0.5
self.is_ipa = is_ipa
self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = (
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
if qk_norm
else nn.Identity()
)
self.k_norm = (
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
if qk_norm
else nn.Identity()
)
self.inner_attn = FlashCrossAttention(attention_dropout=attn_drop)
self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.proj_drop = nn.Dropout(proj_drop)
if self.is_ipa:
self.kv_proj_ip_adapter = nn.Linear(
kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs
)
self.k_norm_ip_adapter = norm_layer(
self.head_dim, elementwise_affine=True, eps=1e-6
)
def forward(self, x, y, z, t_scale, i_scale, freqs_cis_img=None, is_ipa=False):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
y: torch.Tensor
(batch, seqlen2, hidden_dim2)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // num_heads), RoPE for image
"""
b, s1, _ = x.shape # [b, s1, D]
_, s2, _ = y.shape # [b, s2, 1024]
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim).contiguous() # [b, s1, h, d]
kv = self.kv_proj(y).view(
b, s2, 2, self.num_heads, self.head_dim).contiguous() # [b, s2, 2, h, d]
# print("q:",q.dtype)
# print("self.gamma_q:",self.gamma_q.dtype)
ops.cross_norm_rope_q_kv_inplace_cuda(
q, # [B,S1,H,D]
kv, # [B,S2,2,H,D]
self.q_norm.weight, self.q_norm.bias,
self.k_norm.weight, self.k_norm.bias,
freqs_cis_img[0], freqs_cis_img[1],
self.q_norm.eps)
context = self.inner_attn(q, kv) # [b, s1, h, d]
context = context.view(b, s1, -1) # [b, s1, D]
if is_ipa:
assert z is not None
_, s3, _ = z.shape
kv_2 = self.kv_proj_ip_adapter(z).view(
b, s3, 2, self.num_heads, self.head_dim
)
k_2, v_2 = kv_2.unbind(dim=2) # [b, s, h, d]
k_2 = self.k_norm_ip_adapter(k_2).half()
kv_2 = torch.stack([k_2, v_2], dim=2)
context_2 = self.inner_attn(q, kv_2)
context_2 = context_2.view(b, s1, -1)
context = context * t_scale + context_2 * i_scale
out = self.out_proj(context)
out = self.proj_drop(out)
out_tuple = (out,)
return out_tuple
class CrossAttention(nn.Module):
"""
Use QK Normalization.
"""
def __init__(
self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
is_ipa=False,
attn_drop=0.0,
proj_drop=0.0,
device=None,
dtype=None,
norm_layer=nn.LayerNorm,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
self.head_dim = self.qdim // num_heads
assert (
self.head_dim % 8 == 0 and self.head_dim <= 128
), "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim**-0.5
self.is_ipa = is_ipa
self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = (
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
if qk_norm
else nn.Identity()
)
self.k_norm = (
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
if qk_norm
else nn.Identity()
)
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.proj_drop = nn.Dropout(proj_drop)
if self.is_ipa:
self.kv_proj_ip_adapter = nn.Linear(
kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs
)
self.k_norm_ip_adapter = norm_layer(
self.head_dim, elementwise_affine=True, eps=1e-6
)
def forward(self, x, y, z, t_scale, i_scale, freqs_cis_img=None, is_ipa=False):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
y: torch.Tensor
(batch, seqlen2, hidden_dim2)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // 2), RoPE for image
"""
b, s1, c = x.shape # [b, s1, D]
_, s2, c = y.shape # [b, s2, 1024]
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
kv = self.kv_proj(y).view(
b, s2, 2, self.num_heads, self.head_dim
) # [b, s2, 2, h, d]
k, v = kv.unbind(dim=2) # [b, s, h, d]
q = self.q_norm(q)
k = self.k_norm(k)
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
assert qq.shape == q.shape, f"qq: {qq.shape}, q: {q.shape}"
q = qq
if torch.__version__[0] == '1':
q = q * self.scale
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
k = k.permute(0, 2, 3, 1).contiguous() # k -> B, L2, H, C - B, H, C, L2
attn = q.float() @ k.float() # attn -> B, H, L1, L2
attn = attn.softmax(dim=-1).to(q.dtype) # attn -> B, H, L1, L2
attn = self.attn_drop(attn)
x = attn @ v.transpose(-2, -3) # v -> B, L2, H, C - B, H, L2, C x-> B, H, L1, C
elif torch.__version__[0] == '2':
# import pdb;pdb.set_trace()
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
else:
raise NotImplementedError
context = x.transpose(1, 2) # context -> B, H, L1, C - B, L1, H, C
if is_ipa:
assert z is not None
_, s3, c = z.shape
kv_2 = self.kv_proj_ip_adapter(z).view(
b, s3, 2, self.num_heads, self.head_dim
)
k_2, v_2 = kv_2.unbind(dim=2) # [b, s, h, d]
k_2 = self.k_norm_ip_adapter(k_2)
k_2 = k_2.permute(0, 2, 3, 1).contiguous()
attn_2 = q @ k_2
attn_2 = attn_2.softmax(dim=-1).half()
x_2 = attn_2 @ v_2.transpose(-2, -3)
context_2 = x_2.transpose(1, 2)
context = context * t_scale + context_2 * i_scale
context = context.contiguous().view(b, s1, -1)
out = self.out_proj(context) # context.reshape - B, L1, -1
out = self.proj_drop(out)
out_tuple = (out,)
return out_tuple
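# Illustrative sketch: a shape walk-through for the pure-PyTorch CrossAttention
# above (qk_norm on, no RoPE, no IP-Adapter branch). All sizes are
# hypothetical; head_dim = qdim // num_heads must stay divisible by 8 to
# satisfy the assertion in __init__.
def _demo_cross_attention():
    import torch
    attn = CrossAttention(qdim=64, kdim=32, num_heads=8, qk_norm=True)
    x = torch.randn(2, 16, 64)  # image tokens [B, S1, qdim]
    y = torch.randn(2, 10, 32)  # text tokens  [B, S2, kdim]
    (out,) = attn(x, y, z=None, t_scale=1, i_scale=1)
    return out.shape  # torch.Size([2, 16, 64])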
class Attention(nn.Module):
"""
We rename some layer names to align with flash attention
"""
def __init__(
self,
dim,
num_heads,
qkv_bias=True,
qk_norm=False,
attn_drop=0.0,
proj_drop=0.0,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
assert self.dim % num_heads == 0, "dim should be divisible by num_heads"
self.head_dim = self.dim // num_heads
# This assertion is aligned with flash attention
assert (
self.head_dim % 8 == 0 and self.head_dim <= 128
), "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim**-0.5
# qkv --> Wqkv
self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = (
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
if qk_norm
else nn.Identity()
)
self.k_norm = (
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6)
if qk_norm
else nn.Identity()
)
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, freqs_cis_img=None, mask=None):
B, N, C = x.shape
qkv = (
self.Wqkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
) # [3, b, h, s, d]
q, k, v = qkv.unbind(0) # [b, h, s, d]
q = self.q_norm(q) # [b, h, s, d]
k = self.k_norm(k) # [b, h, s, d]
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
assert (
qq.shape == q.shape and kk.shape == k.shape
), f"qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}"
q, k = qq, kk
if torch.__version__[0] == '1':
q = q * self.scale
# Here we force q and k to be float32 to avoid numerical overflow
attn = q.float() @ k.float().transpose(-2, -1) # [b, h, s, d] @ [b, h, d, s]
if mask is not None:
attn = attn + mask
attn = attn.softmax(dim=-1).to(q.dtype) # [b, h, s, s]
attn = self.attn_drop(attn)
x = attn @ v # [b, h, s, d]
elif torch.__version__[0] == '2':
x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_drop.p)
else:
raise NotImplementedError
x = x.transpose(1, 2).reshape(B, N, C) # [b, s, h, d]
x = self.out_proj(x)
x = self.proj_drop(x)
out_tuple = (x,)
return out_tuple
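# Illustrative sketch: the vanilla (non-flash) self-attention module above on
# random tokens, without RoPE or an attention mask. Sizes are hypothetical and
# respect the head_dim % 8 == 0 constraint asserted in __init__.
def _demo_self_attention():
    import torch
    attn = Attention(dim=64, num_heads=8, qk_norm=True)
    x = torch.randn(2, 16, 64)  # [B, N, C]
    (out,) = attn(x)  # freqs_cis_img=None, mask=None
    return out.shape  # torch.Size([2, 16, 64])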
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from peft.utils import (
ModulesToSaveWrapper,
_get_submodules,
)
from timm.models.vision_transformer import Mlp
from torch.utils import checkpoint
from tqdm import tqdm
from transformers.integrations import PeftAdapterMixin
from .attn_layers import (
Attention,
FlashCrossMHAModified,
FlashSelfMHAModified,
CrossAttention,
)
from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
from .norm_layers import RMSNorm
from .poolers import AttentionPool
from .models import FP32_Layernorm, FP32_SiLU, HunYuanDiTBlock
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
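# Illustrative sketch: zero_module is what makes the ControlNet branch start
# silent -- the zero-initialized `before_proj`/`after_proj_list` linears below
# contribute nothing to the backbone until training moves them. A minimal
# check of that behaviour; the function name and sizes are hypothetical.
def _demo_zero_module():
    import torch
    import torch.nn as nn
    proj = zero_module(nn.Linear(8, 8))
    assert torch.count_nonzero(proj.weight) == 0
    assert torch.count_nonzero(proj.bias) == 0
    return proj(torch.randn(2, 8))  # all zeros at initialization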
class HunYuanControlNet(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
    HunYuanControlNet: a ControlNet branch for HunYuanDiT, the diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
@register_to_config
def __init__(
self,
args: Any,
input_size: tuple = (32, 32),
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1152,
depth: int = 28,
num_heads: int = 16,
mlp_ratio: float = 4.0,
log_fn: callable = print,
):
super().__init__()
self.args = args
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = args.learn_sigma
self.in_channels = in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = args.text_states_dim
self.text_states_dim_t5 = args.text_states_dim_t5
self.text_len = args.text_len
self.text_len_t5 = args.text_len_t5
self.norm = args.norm
use_flash_attn = args.infer_mode == "fa" or args.use_flash_attn
if use_flash_attn:
log_fn(f" Enable Flash Attention.")
qk_norm = args.qk_norm # See http://arxiv.org/abs/2302.05442 for details.
self.mlp_t5 = nn.Sequential(
nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
FP32_SiLU(),
nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
)
        # Learnable padding embeddings that replace masked positions in the text states
self.text_embedding_padding = nn.Parameter(
torch.randn(
self.text_len + self.text_len_t5,
self.text_states_dim,
dtype=torch.float32,
)
)
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(
self.text_len_t5,
self.text_states_dim_t5,
num_heads=8,
output_dim=pooler_out_dim,
)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if args.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if args.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, hidden_size)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.extra_embedder = nn.Sequential(
nn.Linear(self.extra_in_dim, hidden_size * 4),
FP32_SiLU(),
nn.Linear(hidden_size * 4, hidden_size, bias=True),
)
# Image embedding
num_patches = self.x_embedder.num_patches
log_fn(f" Number of tokens: {num_patches}")
        # HunYuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunYuanDiTBlock(
hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
use_flash_attn=use_flash_attn,
qk_norm=qk_norm,
norm_type=self.norm,
skip=False,
)
for _ in range(19)
]
)
# Input zero linear for the first block
self.before_proj = zero_module(nn.Linear(self.hidden_size, self.hidden_size))
        # Output zero linear for every block
self.after_proj_list = nn.ModuleList(
[
zero_module(nn.Linear(self.hidden_size, self.hidden_size))
for _ in range(len(self.blocks))
]
)
self.fix_weight_modules = [
"mlp_t5",
"text_embedding_padding",
"pooler",
"style_embedder",
"x_embedder",
"t_embedder",
"extra_embedder",
]
def check_condition_validation(self, image_meta_size, style):
if self.args.size_cond is None and image_meta_size is not None:
raise ValueError(
f"When `size_cond` is None, `image_meta_size` should be None, but got "
f"{type(image_meta_size)}. "
)
if self.args.size_cond is not None and image_meta_size is None:
raise ValueError(
f"When `size_cond` is not None, `image_meta_size` should not be None. "
)
if not self.args.use_style_cond and style is not None:
raise ValueError(
f"When `use_style_cond` is False, `style` should be None, but got {type(style)}. "
)
if self.args.use_style_cond and style is None:
raise ValueError(
f"When `use_style_cond` is True, `style` should be not None."
)
def enable_gradient_checkpointing(self):
for block in self.blocks:
block.gradient_checkpointing = True
def disable_gradient_checkpointing(self):
for block in self.blocks:
block.gradient_checkpointing = False
def from_dit(self, dit):
"""
Load the parameters from a pre-trained HunYuanDiT model.
Parameters
----------
dit: HunYuanDiT
The pre-trained HunYuanDiT model.
"""
self.mlp_t5.load_state_dict(dit.mlp_t5.state_dict())
self.text_embedding_padding.data = dit.text_embedding_padding.data
self.pooler.load_state_dict(dit.pooler.state_dict())
if self.args.use_style_cond:
self.style_embedder.load_state_dict(dit.style_embedder.state_dict())
self.x_embedder.load_state_dict(dit.x_embedder.state_dict())
self.t_embedder.load_state_dict(dit.t_embedder.state_dict())
self.extra_embedder.load_state_dict(dit.extra_embedder.state_dict())
for i, block in enumerate(self.blocks):
block.load_state_dict(dit.blocks[i].state_dict())
def set_trainable(self):
self.mlp_t5.requires_grad_(False)
self.text_embedding_padding.requires_grad_(False)
self.pooler.requires_grad_(False)
if self.args.use_style_cond:
self.style_embedder.requires_grad_(False)
self.x_embedder.requires_grad_(False)
self.t_embedder.requires_grad_(False)
self.extra_embedder.requires_grad_(False)
self.blocks.requires_grad_(True)
self.before_proj.requires_grad_(True)
self.after_proj_list.requires_grad_(True)
self.blocks.train()
self.before_proj.train()
self.after_proj_list.train()
def forward(
self,
x,
t,
condition,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
cos_cis_img=None,
sin_cis_img=None,
return_dict=True,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
text_states = torch.cat(
[text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1
) # 2,205,1024
clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
clip_t5_mask = clip_t5_mask
text_states = torch.where(
clip_t5_mask.unsqueeze(2),
text_states,
self.text_embedding_padding.to(text_states),
)
_, _, oh, ow = x.shape
th, tw = oh // self.patch_size, ow // self.patch_size
# ========================= Build time and image embedding =========================
t = self.t_embedder(t)
x = self.x_embedder(x)
        # Get image RoPE embedding according to the resolution.
freqs_cis_img = (cos_cis_img, sin_cis_img)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
self.check_condition_validation(image_meta_size, style)
# Build image meta size tokens if applicable
if image_meta_size is not None:
image_meta_size = timestep_embedding(
image_meta_size.view(-1), 256
) # [B * 6, 256]
if self.args.use_fp16:
image_meta_size = image_meta_size.half()
image_meta_size = image_meta_size.view(-1, 6 * 256)
extra_vec = torch.cat(
[extra_vec, image_meta_size], dim=1
) # [B, D + 6 * 256]
# Build style tokens
if style is not None:
style_embedding = self.style_embedder(style)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
# ========================= Deal with Condition =========================
condition = self.x_embedder(condition)
# ========================= Forward pass through HunYuanDiT blocks =========================
controls = []
x = x + self.before_proj(condition) # add condition
for layer, block in enumerate(self.blocks):
x = block(x=x, c=c, text_states=text_states, freq_cis_img=freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output
if return_dict:
return {"controls": controls}
return controls
from collections import OrderedDict
from copy import deepcopy
import torch
from deepspeed.utils import instrument_w_nvtx
from pathlib import Path
def requires_grad(model, flag=True):
"""
Set requires_grad flag for all parameters in a model.
"""
for p in model.parameters():
p.requires_grad = flag
class EMA(object):
def __init__(self, args, model, device, logger):
if args.ema_dtype == "fp32":
self.warmup = args.ema_warmup
self.update_after_step = 0
self.max_value = args.ema_decay if args.ema_decay is not None else 0.9999
self.inv_gamma = 1.0
self.power = (
args.ema_warmup_power if args.ema_warmup_power is not None else 2 / 3
)
self.min_value = 0.0
else:
self.warmup = args.ema_warmup
self.update_after_step = 0
self.max_value = args.ema_decay if args.ema_decay is not None else 0.992
self.inv_gamma = 1.0
self.power = (
args.ema_warmup_power if args.ema_warmup_power is not None else 0.446249
)
# 0.446249 == math.log(1 - 0.992) / math.log(50000)
self.min_value = 0.0
self.ema_reset_decay = args.ema_reset_decay
self.decay_steps = 0
if args.ema_dtype == "none":
ema_dtype = "fp16" if args.use_fp16 else "fp32"
else:
ema_dtype = args.ema_dtype
        # module.half() and module.float() change the dtype in place, so deepcopy the model first and then convert its dtype.
self.ema_model = deepcopy(model)
if ema_dtype == "fp16":
self.ema_model = self.ema_model.half().to(device)
elif ema_dtype == "fp32":
self.ema_model = self.ema_model.float().to(device)
else:
raise ValueError(f"Unknown EMA dtype {ema_dtype}.")
requires_grad(self.ema_model, False)
logger.info(
f" Using EMA with date type {args.ema_dtype} "
f"(decay={args.ema_decay}, warmup={args.ema_warmup}, warmup_power={args.ema_warmup_power}, "
f"reset_decay={args.ema_reset_decay})."
)
def get_decay(self):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
@jarvizhang's notes on EMA max_value when enabling FP16:
        If using FP16 for EMA, max_value=0.995 is better (do not go above 0.999 unless you know
        what you are doing). This is because FP16 has less precision than FP32, so the EMA update
        can lose precision or fall outside the representable range of FP16.
gamma=1, power=0.446249 are good values for models (reaches decay factor 0.99 at 30K steps,
0.992 at 50K steps).
"""
if self.warmup:
step = max(0, self.decay_steps - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma) ** -self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
else:
return self.max_value
@torch.no_grad()
@instrument_w_nvtx
def update(self, model, step, decay=None):
"""
Step the EMA model towards the current model.
Parameters
----------
model: nn.Module
The current model
step: int
The current training step. This is used to determine the decay factor. If you want to control
the decay, you can pass in a custom step instead.
            For example, to restart the EMA decay, pass in step=0 at first and then
            increase it by one on every subsequent update.
decay: float
The decay factor. If None, will be determined by the current step.
"""
if decay is None:
if self.ema_reset_decay:
self.decay_steps += 1
else:
self.decay_steps = step
decay = self.get_decay()
ema_params = OrderedDict(self.ema_model.named_parameters())
model_params = OrderedDict(model.named_parameters())
for name, param in model_params.items():
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
return None
def state_dict(self, *args, **kwargs):
return self.ema_model.state_dict(*args, **kwargs)
def load_state_dict(self, *args, **kwargs):
return self.ema_model.load_state_dict(*args, **kwargs)
def train(self):
self.ema_model.train()
def eval(self):
self.ema_model.eval()
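# Illustrative sketch: the EMA warmup schedule in closed form, so the decay
# curve can be checked without instantiating a model. The function name and
# sample power are hypothetical; the formula mirrors EMA.get_decay with
# inv_gamma = 1, update_after_step = 0 and min_value = 0.
def _demo_ema_decay(step, power=2 / 3, max_value=0.9999):
    s = max(0, step - 1)
    if s <= 0:
        return 0.0
    value = 1 - (1 + s) ** -power
    return max(0.0, min(value, max_value))
# With power = 2/3 the decay reaches ~0.999 around 31.6K steps and ~0.9999
# around 1M steps, matching the notes quoted in EMA.get_decay.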
import math
import torch
import torch.nn as nn
from einops import repeat
from timm.models.layers import to_2tuple
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
    The _assert calls are removed from forward() to stay compatible with multi-resolution images.
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, (tuple, list)) and len(img_size) == 2:
img_size = tuple(img_size)
else:
raise ValueError(
f"img_size must be int or tuple/list of length 2. Got {img_size}"
)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def update_image_size(self, img_size):
self.img_size = img_size
self.grid_size = (
img_size[0] // self.patch_size[0],
img_size[1] // self.patch_size[1],
)
self.num_patches = self.grid_size[0] * self.grid_size[1]
def forward(self, x):
# B, C, H, W = x.shape
# _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2).contiguous() # BCHW -> BNC
x = self.norm(x)
return x
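# Illustrative sketch: PatchEmbed as used by the DiT backbone. A 32x32 latent
# with 4 channels and patch_size 2 becomes 16 * 16 = 256 tokens. The sizes
# here are hypothetical.
def _demo_patch_embed():
    import torch
    embed = PatchEmbed(img_size=32, patch_size=2, in_chans=4, embed_dim=64)
    x = torch.randn(2, 4, 32, 32)  # [B, C, H, W]
    tokens = embed(x)  # [B, num_patches, embed_dim]
    return tokens.shape, embed.num_patches  # (torch.Size([2, 256, 64]), 256)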
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(
device=t.device
        )  # size: [dim/2], an exponentially decaying frequency curve
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(t, "b -> b d", d=dim)
return embedding
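# Illustrative sketch: the sinusoidal timestep embedding puts cosines in the
# first dim // 2 channels and sines in the rest, with exponentially decaying
# frequencies. The sample timesteps and dim are hypothetical.
def _demo_timestep_embedding():
    import torch
    t = torch.tensor([0.0, 10.0, 500.0])
    emb = timestep_embedding(t, 8)  # [3, 8]
    # t == 0 gives cos(0) = 1 for the first half and sin(0) = 0 for the rest.
    assert torch.allclose(emb[0, :4], torch.ones(4))
    assert torch.allclose(emb[0, 4:], torch.zeros(4))
    return emb.shape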
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
super().__init__()
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, out_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
self.mlp[0].weight.dtype
)
t_emb = self.mlp(t_freq)
return t_emb
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
    is a nested tuple/list structure."""
if isinstance(val, dict):
res_dict = {}
for k, v in val.items():
if k != "cos_cis_img" and k != "sin_cis_img":
res_dict[k] = conversion_helper(v, conversion)
else:
res_dict[k] = v
return res_dict
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, (_HALF_TYPES,)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
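# Illustrative sketch: conversion_helper walks nested tuples/lists/dicts, so a
# mixed batch of tensors and non-tensors can be cast in one call, while the
# dict branch deliberately leaves "cos_cis_img"/"sin_cis_img" untouched. The
# sample structure below is hypothetical.
def _demo_fp32_to_fp16_conversion():
    import torch
    batch = (
        torch.randn(2, 3),  # fp32 -> fp16
        {"cos_cis_img": torch.randn(4, 8), "mask": torch.randn(2, 3)},
        "a non-tensor value",  # passed through unchanged
    )
    converted = fp32_to_float16(batch, lambda v: v.half())
    assert converted[0].dtype == torch.float16
    assert converted[1]["cos_cis_img"].dtype == torch.float32  # skipped key
    assert converted[1]["mask"].dtype == torch.float16
    return converted[2]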
class Float16Module(torch.nn.Module):
def __init__(self, module, args):
super(Float16Module, self).__init__()
self.add_module("module", module.half())
def float16_convertor(val):
return val.half()
self.float16_convertor = float16_convertor
self.config = self.module.config
self.dtype = torch.float16
def forward(self, *inputs, **kwargs):
inputs = fp32_to_float16(inputs, self.float16_convertor)
kwargs = fp32_to_float16(kwargs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
outputs = float16_to_fp32(outputs)
return outputs
def state_dict(self, destination=None, prefix="", keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def state_dict_for_save_checkpoint(
self, destination=None, prefix="", keep_vars=False
):
return self.module.state_dict_for_save_checkpoint(
destination, prefix, keep_vars
)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from peft.utils import (
ModulesToSaveWrapper,
_get_submodules,
)
from timm.models.vision_transformer import Mlp
from torch.utils import checkpoint
from tqdm import tqdm
from transformers.integrations import PeftAdapterMixin
from .attn_layers import (
Attention,
FlashCrossMHAModified,
FlashSelfMHAModified,
CrossAttention,
)
from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
from .norm_layers import RMSNorm
from .poolers import AttentionPool
from apex.fused_dense import fused_dense_gelu_dense_function,fused_dense_function
from lightop import op
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def ckpt_wrapper(module):
def ckpt_forward(*inputs):
outputs = module(*inputs)
return outputs
return ckpt_forward
# class FP32_Layernorm(nn.LayerNorm):
# def forward(self, inputs: torch.Tensor) -> torch.Tensor:
# origin_dtype = inputs.dtype
# return F.layer_norm(
# inputs.float(),
# self.normalized_shape,
# self.weight.float(),
# self.bias.float(),
# self.eps,
# ).to(origin_dtype)
class FP32_Layernorm(torch.nn.Module):
    def __init__(self, normalized_shape: int, eps: float = 1e-5, device=None, dtype=None):
super().__init__()
self.eps = eps
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, **factory_kwargs))
self.bias = torch.nn.Parameter(torch.empty(normalized_shape, **factory_kwargs))
torch.nn.init.ones_(self.weight)
torch.nn.init.zeros_(self.bias)
    def forward(self, x):
        return op.layernorm_forward_autograd(x, self.weight, self.bias, self.eps, self.training)
def extra_repr(self):
return f'eps={round(self.eps,5):0.5f}'
class FP32_SiLU(nn.SiLU):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
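# Illustrative sketch: FP32_SiLU upcasts to fp32 for the activation and casts
# back, so half-precision inputs keep their dtype while the nonlinearity
# itself avoids fp16 rounding. Sizes are hypothetical.
def _demo_fp32_silu():
    import torch
    x = torch.randn(4, 8, dtype=torch.float16)
    y = FP32_SiLU()(x)
    assert y.dtype == torch.float16  # dtype preserved, math done in fp32
    return y.shape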
class HunYuanDiTBlock(nn.Module):
"""
A HunYuanDiT block with `add` conditioning.
"""
def __init__(
self,
hidden_size,
c_emb_size,
num_heads,
mlp_ratio=4.0,
text_states_dim=1024,
use_flash_attn=False,
qk_norm=False,
norm_type="layer",
skip=False,
is_ipa=False,
):
super().__init__()
self.use_flash_attn = use_flash_attn
use_ele_affine = True
if norm_type == "layer":
norm_layer = FP32_Layernorm
elif norm_type == "rms":
norm_layer = RMSNorm
else:
raise ValueError(f"Unknown norm_type: {norm_type}")
# ========================= Self-Attention =========================
self.norm1 = norm_layer(
hidden_size, eps=1e-6
)
if use_flash_attn:
self.attn1 = FlashSelfMHAModified(
hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm
)
else:
self.attn1 = Attention(
hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm
)
# ========================= FFN =========================
self.norm2 = norm_layer(
hidden_size, eps=1e-6
)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
# ========================= Add =========================
# Simply use add like SDXL.
self.default_modulation = nn.Sequential(
FP32_SiLU(), nn.Linear(c_emb_size, hidden_size, bias=True)
)
# ========================= Cross-Attention =========================
if use_flash_attn:
self.attn2 = FlashCrossMHAModified(
hidden_size,
text_states_dim,
num_heads=num_heads,
qkv_bias=True,
qk_norm=qk_norm,
is_ipa=is_ipa,
)
else:
self.attn2 = CrossAttention(
hidden_size,
text_states_dim,
num_heads=num_heads,
qkv_bias=True,
qk_norm=qk_norm,
is_ipa=is_ipa,
)
self.norm3 = norm_layer(hidden_size, eps=1e-6)
# ========================= Skip Connection =========================
if skip:
self.skip_norm = norm_layer(
2 * hidden_size, eps=1e-6
)
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
else:
self.skip_linear = None
self.gradient_checkpointing = False
def _forward(
self,
x,
c=None,
text_states=None,
img_clip_embedding=None,
t_scale=1,
i_scale=1,
freq_cis_img=None,
skip=None,
is_ipa=False,
):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
# Self-Attention
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
attn_inputs = (
self.norm1(x) + shift_msa,
freq_cis_img,
)
x = x + self.attn1(*attn_inputs)[0]
# Cross-Attention
cross_inputs = (
self.norm3(x),
text_states,
img_clip_embedding,
t_scale,
i_scale,
freq_cis_img,
is_ipa,
)
x = x + self.attn2(*cross_inputs)[0]
# FFN Layer
mlp_inputs = self.norm2(x)
# x = x + self.mlp(mlp_inputs)
        mlp_shape = mlp_inputs.shape
        mlp_out = fused_dense_gelu_dense_function(
            mlp_inputs.view(-1, mlp_shape[2]),
            self.mlp.fc1.weight, self.mlp.fc1.bias, self.mlp.fc2.weight, self.mlp.fc2.bias,
        )
        mlp_out = mlp_out.view(mlp_shape[0], mlp_shape[1], mlp_shape[2])
x = x + mlp_out
return x
def forward(
self,
x,
c=None,
text_states=None,
img_clip_embedding=None,
t_scale=1,
i_scale=1,
freq_cis_img=None,
skip=None,
is_ipa=False,
):
if self.gradient_checkpointing and self.training:
return checkpoint.checkpoint(
self._forward,
x,
c,
text_states,
img_clip_embedding,
t_scale,
i_scale,
freq_cis_img,
skip,
is_ipa,
)
return self._forward(
x,
c,
text_states,
img_clip_embedding,
t_scale,
i_scale,
freq_cis_img,
skip,
is_ipa,
)
class FinalLayer(nn.Module):
"""
The final layer of HunYuanDiT.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(
final_hidden_size, elementwise_affine=False, eps=1e-6
)
self.linear = nn.Linear(
final_hidden_size, patch_size * patch_size * out_channels, bias=True
)
self.adaLN_modulation = nn.Sequential(
FP32_SiLU(), nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
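# Illustrative sketch: FinalLayer maps each token back to a patch of latent
# pixels. With patch_size 2 and 8 output channels every token becomes
# 2 * 2 * 8 = 32 values, shifted and scaled by the conditioning vector c via
# adaLN. Sizes are hypothetical.
def _demo_final_layer():
    import torch
    head = FinalLayer(final_hidden_size=64, c_emb_size=64, patch_size=2, out_channels=8)
    x = torch.randn(2, 256, 64)  # [B, L, hidden]
    c = torch.randn(2, 64)  # [B, c_emb_size]
    return head(x, c).shape  # torch.Size([2, 256, 32])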
class HunYuanDiT(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
@register_to_config
def __init__(
self,
args: Any,
input_size: tuple = (32, 32),
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1152,
depth: int = 28,
num_heads: int = 16,
mlp_ratio: float = 4.0,
log_fn: callable = print,
):
super().__init__()
self.args = args
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = args.learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if args.learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = args.text_states_dim
self.text_states_dim_t5 = args.text_states_dim_t5
self.text_len = args.text_len
self.text_len_t5 = args.text_len_t5
self.norm = args.norm
self.is_ipa = args.is_ipa
use_flash_attn = args.infer_mode == "fa" or args.use_flash_attn
if use_flash_attn:
log_fn(f" Enable Flash Attention.")
qk_norm = args.qk_norm # See http://arxiv.org/abs/2302.05442 for details.
self.mlp_t5 = nn.Sequential(
nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
FP32_SiLU(),
nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
)
        # Learnable padding embeddings that replace masked positions in the text states
self.text_embedding_padding = nn.Parameter(
torch.randn(
self.text_len + self.text_len_t5,
self.text_states_dim,
dtype=torch.float32,
)
)
if self.is_ipa:
self.img_clip_hidden_dim = 1024
self.img_clip_seq_len = 8
self.ip_adapter_clip_len_trans_block = nn.Linear(
self.img_clip_hidden_dim, self.img_clip_seq_len * 1024
)
nn.init.normal_(
self.ip_adapter_clip_len_trans_block.weight,
std=self.img_clip_hidden_dim**-0.5,
)
self.ip_adapter_img_norm_cond = nn.LayerNorm(1024)
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(
self.text_len_t5,
self.text_states_dim_t5,
num_heads=8,
output_dim=pooler_out_dim,
)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if args.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if args.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, hidden_size)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.extra_embedder = nn.Sequential(
nn.Linear(self.extra_in_dim, hidden_size * 4),
FP32_SiLU(),
nn.Linear(hidden_size * 4, hidden_size, bias=True),
)
# Image embedding
num_patches = self.x_embedder.num_patches
log_fn(f" Number of tokens: {num_patches}")
        # HunYuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunYuanDiTBlock(
hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
use_flash_attn=use_flash_attn,
qk_norm=qk_norm,
norm_type=self.norm,
skip=layer > depth // 2,
is_ipa=self.is_ipa,
)
for layer in range(depth)
]
)
self.final_layer = FinalLayer(
hidden_size, hidden_size, patch_size, self.out_channels
)
self.unpatchify_channels = self.out_channels
self.initialize_weights()
def check_condition_validation(self, image_meta_size, style):
if self.args.size_cond is None and image_meta_size is not None:
raise ValueError(
f"When `size_cond` is None, `image_meta_size` should be None, but got "
f"{type(image_meta_size)}. "
)
if self.args.size_cond is not None and image_meta_size is None:
raise ValueError(
f"When `size_cond` is not None, `image_meta_size` should not be None. "
)
if not self.args.use_style_cond and style is not None:
raise ValueError(
f"When `use_style_cond` is False, `style` should be None, but got {type(style)}. "
)
if self.args.use_style_cond and style is None:
raise ValueError(
f"When `use_style_cond` is True, `style` should be not None."
)
def enable_gradient_checkpointing(self):
for i, block in enumerate(self.blocks):
            if i >= (1 - self.args.gc_rate) * len(self.blocks):
block.gradient_checkpointing = True
def disable_gradient_checkpointing(self):
for block in self.blocks:
block.gradient_checkpointing = False
def enable_input_requires_grad(self):
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
self.x_embedder.register_forward_hook(make_inputs_require_grad)
self.t_embedder.register_forward_hook(make_inputs_require_grad)
self.extra_embedder.register_forward_hook(make_inputs_require_grad)
if hasattr(self, "style_embedder"):
self.style_embedder.register_forward_hook(make_inputs_require_grad)
def forward(
self,
x,
t,
t_scale=1,
i_scale=1,
encoder_hidden_states=None,
img_clip_embedding=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
cos_cis_img=None,
sin_cis_img=None,
image_meta_size=None,
style=None,
return_dict=True,
controls=None,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
if self.is_ipa:
assert img_clip_embedding is not None
img_clip_embedding = img_clip_embedding.view(x.shape[0], 1024)
img_clip_embedding = self.ip_adapter_clip_len_trans_block(
img_clip_embedding
)
img_clip_embedding = img_clip_embedding.reshape(
x.shape[0], -1, self.img_clip_seq_len
)
img_clip_embedding = img_clip_embedding.transpose(2, 1).contiguous()
img_clip_embedding = self.ip_adapter_img_norm_cond(img_clip_embedding)
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
text_states = torch.cat(
[text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1
) # 2,205,1024
clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
clip_t5_mask = clip_t5_mask
text_states = torch.where(
clip_t5_mask.unsqueeze(2),
text_states,
self.text_embedding_padding.to(text_states),
)
_, _, oh, ow = x.shape
th, tw = oh // self.patch_size, ow // self.patch_size
# ========================= Build time and image embedding =========================
t = self.t_embedder(t)
x = self.x_embedder(x)
        # Get image RoPE embedding according to the resolution.
freqs_cis_img = (cos_cis_img.to(x.device).contiguous(), sin_cis_img.to(x.device).contiguous())
# freqs_cis_img = (cos_cis_img, sin_cis_img)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
        if self.args.size_cond is None:
image_meta_size = None
self.check_condition_validation(image_meta_size, style)
# Build image meta size tokens if applicable
if image_meta_size is not None:
image_meta_size = timestep_embedding(
image_meta_size.view(-1), 256
) # [B * 6, 256]
if self.args.use_fp16:
image_meta_size = image_meta_size.half()
image_meta_size = image_meta_size.view(-1, 6 * 256)
extra_vec = torch.cat(
[extra_vec, image_meta_size], dim=1
) # [B, D + 6 * 256]
# Build style tokens
if style is not None:
style_embedding = self.style_embedder(style)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
# ========================= Forward pass through HunYuanDiT blocks =========================
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
else:
skip = skips.pop()
x = block(
x,
c,
text_states,
img_clip_embedding,
t_scale,
i_scale,
freqs_cis_img,
skip,
self.is_ipa,
) # (N, L, D)
else:
x = block(
x,
c,
text_states,
img_clip_embedding,
t_scale,
i_scale,
freqs_cis_img,
None,
self.is_ipa,
) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)
if controls is not None and len(controls) != 0:
raise ValueError(
"The number of controls is not equal to the number of skip connections."
)
# ========================= Final layer =========================
x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
if return_dict:
return {"x": x}
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.extra_embedder[0].weight, std=0.02)
nn.init.normal_(self.extra_embedder[2].weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in HunYuanDiT blocks:
for block in self.blocks:
nn.init.constant_(block.default_modulation[-1].weight, 0)
nn.init.constant_(block.default_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
"""
c = self.unpatchify_channels
p = self.x_embedder.patch_size[0]
# h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum("nhwpqc->nchpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
def _replace_module(self, parent, child_name, new_module, child) -> None:
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
# _mark_only_adapters_as_trainable
# child layer wraps the original module, unpack it
if hasattr(child, "base_layer"):
child = child.get_base_layer()
elif hasattr(child, "quant_linear_module"):
# TODO maybe not necessary to have special treatment?
child = child.quant_linear_module
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias
if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)
# dispatch to correct device
for name, module in new_module.named_modules():
# if any(prefix in name for prefix in PREFIXES):
# module.to(child.weight.device)
if "ranknum" in name:
module.to(child.weight.device)
def merge_and_unload(
self,
merge=True,
progressbar: bool = False,
safe_merge: bool = False,
adapter_names=None,
):
if merge:
if getattr(self, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge layers when the model is gptq quantized")
def merge_recursively(module):
# helper function to recursively merge the base_layer of the target
path = []
layer = module
while hasattr(layer, "base_layer"):
path.append(layer)
layer = layer.base_layer
for layer_before, layer_after in zip(path[:-1], path[1:]):
layer_after.merge(safe_merge=safe_merge, adapter_names=adapter_names)
layer_before.base_layer = layer_after.base_layer
module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
key_list = [key for key, _ in self.named_modules()]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self, key)
except AttributeError:
continue
if hasattr(target, "base_layer"):
if merge:
merge_recursively(target)
self._replace_module(
parent, target_name, target.get_base_layer(), target
)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
if merge:
new_module.merge(
safe_merge=safe_merge, adapter_names=adapter_names
)
new_module = new_module.get_base_layer()
setattr(parent, target_name, new_module)
#################################################################################
# HunYuanDiT Configs #
#################################################################################
HUNYUAN_DIT_CONFIG = {
"DiT-g/2": {
"depth": 40,
"hidden_size": 1408,
"patch_size": 2,
"num_heads": 16,
"mlp_ratio": 4.3637,
},
"DiT-XL/2": {"depth": 28, "hidden_size": 1152, "patch_size": 2, "num_heads": 16},
}
def DiT_g_2(args, **kwargs):
return HunYuanDiT(
args,
depth=40,
hidden_size=1408,
patch_size=2,
num_heads=16,
mlp_ratio=4.3637,
**kwargs,
)
def DiT_XL_2(args, **kwargs):
return HunYuanDiT(
args, depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs
)
HUNYUAN_DIT_MODELS = {
"DiT-g/2": DiT_g_2,
"DiT-XL/2": DiT_XL_2,
}
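# --- Illustrative usage sketch (added; not part of the original file) -----------
# A hedged example of how the factory table above might be used. `args` is the
# argument namespace built by the training/inference entry points and is not
# constructed here; it is assumed that HunYuanDiT's remaining constructor
# arguments have sensible defaults.
def _demo_build_dit(args):
    """Minimal sketch: look up a config by name and instantiate the matching model."""
    assert "DiT-g/2" in HUNYUAN_DIT_CONFIG
    return HUNYUAN_DIT_MODELS["DiT-g/2"](args)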
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
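# --- Illustrative example (added; not part of the original file) ----------------
# RMSNorm rescales each feature vector by the reciprocal of its root-mean-square,
# i.e. y = x / sqrt(mean(x**2) + eps) * weight, without subtracting a mean. A
# minimal self-check of that formula, assuming the default elementwise affine:
def _demo_rms_norm():
    """Minimal sketch: RMSNorm output matches the closed-form expression."""
    norm = RMSNorm(dim=8)
    x = torch.randn(2, 4, 8)
    y = norm(x)
    expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps) * norm.weight
    assert torch.allclose(y, expected, atol=1e-5)
    return y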
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
super().__init__(
num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype
)
def forward(self, x):
y = super().forward(x).to(x.dtype)
return y
def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionPool(nn.Module):
def __init__(
self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5
)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1],
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x.squeeze(0)
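# --- Illustrative example (added; not part of the original file) ----------------
# AttentionPool prepends the mean token, attends from it over all tokens, and
# returns one pooled vector per sample: (N, L, C) -> (N, output_dim). Note that
# L must equal `spacial_dim` for the positional embedding to broadcast. Sizes
# below are hypothetical.
def _demo_attention_pool():
    """Minimal sketch of the pooling shape contract."""
    pool = AttentionPool(spacial_dim=16, embed_dim=32, num_heads=4, output_dim=64)
    x = torch.randn(2, 16, 32)  # (N, L, C) with L == spacial_dim
    pooled = pool(x)
    assert pooled.shape == (2, 64)
    return pooled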
import torch
import numpy as np
from typing import Union
def _to_tuple(x):
if isinstance(x, int):
return x, x
else:
return x
def get_fill_resize_and_crop(src, tgt):
th, tw = _to_tuple(tgt)
h, w = _to_tuple(src)
    tr = th / tw  # aspect ratio of the target (base) size
    r = h / w  # aspect ratio of the source resolution
# resize
if r > tr:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
        resize_height = int(
            round(tw / w * h)
        )  # scale the source so it fits inside the target while keeping its aspect ratio
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
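# --- Illustrative example (added; not part of the original file) ----------------
# get_fill_resize_and_crop scales `src` to fit inside `tgt` while preserving its
# aspect ratio and returns the centred region it occupies as (top, left) and
# (bottom, right) corners. A small worked example:
def _demo_fill_resize_and_crop():
    """Minimal sketch: a 64x32 source centred inside a 32x32 target."""
    start, stop = get_fill_resize_and_crop((64, 32), 32)
    # the source is scaled to 32x16 and centred horizontally
    assert start == (0, 8) and stop == (32, 24)
    return start, stop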
def get_meshgrid(start, *args):
if len(args) == 0:
# start is grid_size
num = _to_tuple(start)
start = (0, 0)
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = (stop[0] - start[0], stop[1] - start[1])
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = _to_tuple(args[1])
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
    grid = np.stack(grid, axis=0)  # [2, H, W]
return grid
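# --- Illustrative example (added; not part of the original file) ----------------
# get_meshgrid returns the w- and h-coordinate grids stacked along a leading axis,
# each of shape (H, W):
def _demo_meshgrid():
    """Minimal sketch: a 4x3 grid of coordinates."""
    grid = get_meshgrid((4, 3))
    assert grid.shape == (2, 4, 3)
    return grid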
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid = get_meshgrid(start, *args) # [2, H, w]
# grid_h = np.arange(grid_size, dtype=np.float32)
# grid_w = np.arange(grid_size, dtype=np.float32)
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
# grid = np.stack(grid, axis=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (W,H)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
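# --- Illustrative example (added; not part of the original file) ----------------
# For a 2D grid, the sincos embedding encodes the two axes in the two halves of
# the channel dimension, one embed_dim-vector per grid position:
def _demo_sincos_pos_embed():
    """Minimal sketch: a 4x4 grid with a 16-dim embedding gives a (16, 16) table."""
    pos = get_2d_sincos_pos_embed(16, (4, 4))
    assert pos.shape == (16, 16)
    return pos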
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
"""
This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
Parameters
----------
embed_dim: int
embedding dimension size
start: int or tuple of int
If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
If len(args) == 2, start is start, args[0] is stop, args[1] is num.
use_real: bool
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns
-------
    pos_embed: torch.Tensor or (torch.Tensor, torch.Tensor)
        If `use_real` is True, a (cos, sin) pair of real tensors, each of shape [HW, D];
        otherwise a single complex tensor of shape [HW, D/2].
    """
grid = get_meshgrid(start, *args) # [2, H, w]
grid = grid.reshape(
[2, 1, *grid.shape[1:]]
) # Returns a sampling matrix with the same resolution as the target resolution
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(
embed_dim // 2, grid[0].reshape(-1), use_real=use_real
) # (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
) # (H*W, D/4)
if use_real:
        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_1d_rotary_pos_embed(
dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
) # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(
torch.ones_like(freqs), freqs
) # complex64 # [S, D/2]
return freqs_cis
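# --- Illustrative example (added; not part of the original file) ----------------
# The RoPE helpers either return real cos/sin tables (use_real=True), each covering
# the full head dimension, or a half-width complex table (use_real=False). Sizes
# below are hypothetical (4 positions / an 8x8 grid, head dim 8 / 64):
def _demo_rotary_pos_embed():
    """Minimal sketch of the 1D and 2D RoPE output shapes."""
    cos, sin = get_1d_rotary_pos_embed(8, 4, use_real=True)
    assert cos.shape == sin.shape == (4, 8)
    freqs_cis = get_1d_rotary_pos_embed(8, 4, use_real=False)
    assert freqs_cis.shape == (4, 4) and freqs_cis.dtype == torch.complex64
    cos2d, sin2d = get_2d_rotary_pos_embed(64, 8, use_real=True)
    assert cos2d.shape == sin2d.shape == (64, 64)
    return cos, sin, cos2d, sin2d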
def calc_sizes(rope_img, patch_size, th, tw):
if rope_img == "extend":
        # "extend" mode: sample RoPE positions over the full target grid.
        sub_args = [(th, tw)]
    elif rope_img.startswith("base"):
        # "baseN" mode (e.g. "base512"): positions for other resolutions are obtained
        # by rescaling (interpolating) a fixed base grid of N // 8 // patch_size tokens.
base_size = int(rope_img[4:]) // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
else:
raise ValueError(f"Unknown rope_img: {rope_img}")
return sub_args
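# --- Illustrative example (added; not part of the original file) ----------------
# With rope_img="base512" and patch_size=2, the base grid is 512 // 8 // 2 = 32
# tokens per side; a 64x64 token grid is therefore sampled over that 32x32 base:
def _demo_calc_sizes():
    """Minimal sketch of the "baseN" RoPE arguments for a 64x64 token grid."""
    sub_args = calc_sizes("base512", 2, 64, 64)
    assert sub_args == [(0, 0), (32, 32), (64, 64)]
    return sub_args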
def init_image_posemb(
rope_img,
resolutions,
patch_size,
hidden_size,
num_heads,
log_fn,
rope_real=True,
):
freqs_cis_img = {}
for reso in resolutions:
th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
sub_args = calc_sizes(rope_img, patch_size, th, tw)
freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(
hidden_size // num_heads, *sub_args, use_real=rope_real
)
log_fn(
f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}"
)
return freqs_cis_img
import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
class MT5Embedder(nn.Module):
available_models = ["t5-v1_1-xxl"]
def __init__(
self,
model_dir="t5-v1_1-xxl",
model_kwargs=None,
torch_dtype=None,
use_tokenizer_only=False,
conditional_generation=False,
max_length=128,
):
super().__init__()
self.device = "cpu"
self.torch_dtype = torch_dtype or torch.bfloat16
self.max_length = max_length
if model_kwargs is None:
model_kwargs = {
# "low_cpu_mem_usage": True,
"torch_dtype": self.torch_dtype,
}
model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device}
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
if use_tokenizer_only:
return
if conditional_generation:
self.model = None
self.generation_model = T5ForConditionalGeneration.from_pretrained(
model_dir
)
return
self.model = (
T5EncoderModel.from_pretrained(model_dir, **model_kwargs)
.eval()
.to(self.torch_dtype)
)
def get_tokens_and_mask(self, texts):
text_tokens_and_mask = self.tokenizer(
texts,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
tokens = text_tokens_and_mask["input_ids"][0]
mask = text_tokens_and_mask["attention_mask"][0]
# tokens = torch.tensor(tokens).clone().detach()
# mask = torch.tensor(mask, dtype=torch.bool).clone().detach()
return tokens, mask
def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1):
text_tokens_and_mask = self.tokenizer(
texts,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
with torch.no_grad():
outputs = self.model(
input_ids=text_tokens_and_mask["input_ids"].to(self.device),
attention_mask=(
text_tokens_and_mask["attention_mask"].to(self.device)
if attention_mask
else None
),
output_hidden_states=True,
)
text_encoder_embs = outputs["hidden_states"][layer_index].detach()
return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device)
@torch.no_grad()
def __call__(self, tokens, attention_mask, layer_index=-1):
with torch.cuda.amp.autocast():
outputs = self.model(
input_ids=tokens,
attention_mask=attention_mask,
output_hidden_states=True,
)
z = outputs.hidden_states[layer_index].detach()
return z
    def general(self, text: str):
        # Requires `conditional_generation=True` so that `self.generation_model` exists.
        input_ids = self.tokenizer(
            text, max_length=self.max_length, truncation=True, return_tensors="pt"
        ).input_ids
        outputs = self.generation_model.generate(input_ids, max_length=self.max_length)
        return outputs
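# --- Illustrative usage sketch (added; not part of the original file) -----------
# A hedged example of the embedder: the checkpoint directory and max_length are
# assumptions, and the weights must already be available locally or resolvable by
# `transformers`. The embeddings come back as (batch, max_length, hidden_size).
def _demo_mt5_embedder(model_dir="t5-v1_1-xxl"):
    """Minimal sketch: tokenize a prompt and fetch the last encoder hidden states."""
    embedder = MT5Embedder(model_dir, max_length=256)
    embeddings, mask = embedder.get_text_embeddings(["a watercolor painting of a fox"])
    return embeddings, mask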
#
# Copyright 2022 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from collections import OrderedDict
from copy import copy
import numpy as np
import tensorrt as trt
import torch
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import CreateConfig, Profile
from polygraphy.backend.trt import (
engine_from_bytes,
engine_from_network,
network_from_onnx_path,
save_engine,
)
from polygraphy.backend.trt import util as trt_util
import ctypes
from glob import glob
from cuda import cudart
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
trt_util.TRT_LOGGER = TRT_LOGGER
class Engine:
def __init__(
self,
model_name,
engine_dir,
onnx_file=None,
):
self.engine_path = os.path.join(engine_dir, model_name + ".plan")
self.engine = None
self.context = None
self.buffers = OrderedDict()
self.tensors = OrderedDict()
self.weightNameList = None
self.refitter = None
self.onnx_initializers = None
self.onnx_file = onnx_file
self.trt_lora_weight = None
self.trt_lora_weight_mem = None
self.torch_weight = None
def __del__(self):
del self.engine
del self.context
del self.buffers
del self.tensors
def build(
self,
onnx_path,
fp16,
input_profile=None,
enable_preview=False,
sparse_weights=False,
):
print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
if input_profile:
for name, dims in input_profile.items():
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
preview_features = []
if enable_preview:
trt_version = [int(i) for i in trt.__version__.split(".")]
# FASTER_DYNAMIC_SHAPES_0805 should only be used for TRT 8.5.1 or above.
if trt_version[0] > 8 or (
trt_version[0] == 8
and (
trt_version[1] > 5 or (trt_version[1] == 5 and trt_version[2] >= 1)
)
):
preview_features = [trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]
engine = engine_from_network(
network_from_onnx_path(onnx_path),
config=CreateConfig(
fp16=fp16,
profiles=[p],
preview_features=preview_features,
sparse_weights=sparse_weights,
),
)
save_engine(engine, path=self.engine_path)
def activate(self, plugin_path=""):
ctypes.cdll.LoadLibrary(plugin_path)
self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
self.context = self.engine.create_execution_context()
def get_shared_memory(self):
_, device_memory = cudart.cudaMalloc(self.engine.device_memory_size)
self.device_memory = device_memory
return self.device_memory
def set_shared_memory(self, device_memory_size):
self.context.device_memory = device_memory_size
def binding_input(self, name, shape):
idx = self.engine.get_binding_index(name)
result = self.context.set_binding_shape(idx, shape)
return result
def allocate_buffers(self, shape_dict=None, device="cuda"):
print("Allocate buffers and bindings inputs:")
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
print("binding: ", binding)
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
else:
shape = self.engine.get_binding_shape(binding)
nv_dtype = self.engine.get_binding_dtype(binding)
dtype_map = {
trt.DataType.FLOAT: np.float32,
trt.DataType.HALF: np.float16,
trt.DataType.INT8: np.int8,
trt.DataType.INT64: np.int64,
trt.DataType.BOOL: bool,
}
if hasattr(trt.DataType, "INT32"):
dtype_map[trt.DataType.INT32] = np.int32
dtype = dtype_map[nv_dtype]
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
# Workaround to convert np dtype to torch
np_type_tensor = np.empty(shape=[], dtype=dtype)
torch_type_tensor = torch.from_numpy(np_type_tensor)
tensor = torch.empty(tuple(shape), dtype=torch_type_tensor.dtype).to(
device=device
)
print(f" binding={binding}, shape={shape}, dtype={tensor.dtype}")
self.tensors[binding] = tensor
self.buffers[binding] = cuda.DeviceView(
ptr=tensor.data_ptr(), shape=shape, dtype=dtype
)
def infer(self, feed_dict, stream):
start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
# shallow copy of ordered dict
device_buffers = copy(self.buffers)
for name, buf in feed_dict.items():
assert isinstance(buf, cuda.DeviceView)
device_buffers[name] = buf
self.binding_input(name, buf.shape)
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
noerror = self.context.execute_async_v2(
bindings=bindings, stream_handle=stream.ptr
)
if not noerror:
            raise ValueError("ERROR: inference failed.")
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
if not self.engine.binding_is_input(binding):
shape = self.context.get_binding_shape(idx)
self.tensors[binding].resize_(tuple(shape))
return self.tensors
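# --- Illustrative usage sketch (added; not part of the original file) -----------
# A hedged example of the Engine lifecycle: build a .plan from an ONNX export,
# then load it for inference. The ONNX path, input-profile shapes, and plugin
# path are placeholders, not files shipped with this repository.
def _demo_engine_build_and_load():
    """Minimal sketch: build, save, and activate a TensorRT engine."""
    engine = Engine("model", "./engines")
    engine.build(
        "./onnx/model.onnx",
        fp16=True,
        input_profile={"x": [(1, 4, 64, 64), (2, 4, 128, 128), (2, 4, 256, 256)]},
    )
    engine.activate(plugin_path="./ckpts/trt_model/fmha_plugins/10.1_plugin_cuda11/fMHAPlugin.so")
    return engine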
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from polygraphy import cuda
from .engine import Engine
class TRTModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=4,
model_name="unet-dyn",
engine_dir="./unet",
device_id=0,
fp16=True,
image_width=1024,
image_height=1024,
text_maxlen=77,
embedding_dim=768,
max_batch_size=1,
plugin_path="./ckpts/trt_model/fmha_plugins/10.1_plugin_cuda11/fMHAPlugin.so",
):
super().__init__()
# create engine
self.in_channels = in_channels # For pipeline compatibility
self.fp16 = fp16
self.max_batch_size = max_batch_size
self.model_name = model_name
self.engine_dir = engine_dir
self.engine = Engine(self.model_name, self.engine_dir)
self.engine.activate(plugin_path)
# create cuda stream
self.stream = torch.cuda.Stream().cuda_stream
self.latent_width = image_width // 8
self.latent_height = image_height // 8
self.text_maxlen = text_maxlen
self.embedding_dim = embedding_dim
device = "cuda:{}".format(device_id)
self.engine_device = torch.device(device)
print("[INFO] Create hcf nv controlled unet success")
@property
def device(self):
return self.engine_device
def __call__(
self,
x,
t_emb,
context,
image_meta_size,
style,
freqs_cis_img0,
freqs_cis_img1,
text_embedding_mask,
encoder_hidden_states_t5,
text_embedding_mask_t5,
):
return self.forward(
x=x,
t_emb=t_emb,
context=context,
image_meta_size=image_meta_size,
style=style,
freqs_cis_img0=freqs_cis_img0,
freqs_cis_img1=freqs_cis_img1,
text_embedding_mask=text_embedding_mask,
encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5,
)
def get_shared_memory(self):
return self.engine.get_shared_memory()
def set_shared_memory(self, shared_memory):
self.engine.set_shared_memory(shared_memory)
def forward(
self,
x,
t_emb,
context,
image_meta_size,
style,
freqs_cis_img0,
freqs_cis_img1,
text_embedding_mask,
encoder_hidden_states_t5,
text_embedding_mask_t5,
):
x_c = x.half()
t_emb_c = t_emb.half()
context_c = context.half()
if image_meta_size is not None:
image_meta_size_c = image_meta_size.half().contiguous()
self.engine.context.set_input_shape(
"image_meta_size", image_meta_size_c.shape
)
self.engine.context.set_tensor_address(
"image_meta_size", image_meta_size_c.contiguous().data_ptr()
)
if style is not None:
style_c = style.long().contiguous()
self.engine.context.set_input_shape("style", style_c.shape)
self.engine.context.set_tensor_address(
"style", style_c.contiguous().data_ptr()
)
freqs_cis_img0_c = freqs_cis_img0.float()
freqs_cis_img1_c = freqs_cis_img1.float()
text_embedding_mask_c = text_embedding_mask.long()
encoder_hidden_states_t5_c = encoder_hidden_states_t5.half()
text_embedding_mask_t5_c = text_embedding_mask_t5.long()
self.engine.context.set_input_shape("x", x_c.shape)
self.engine.context.set_input_shape("t", t_emb_c.shape)
self.engine.context.set_input_shape("encoder_hidden_states", context_c.shape)
self.engine.context.set_input_shape(
"text_embedding_mask", text_embedding_mask_c.shape
)
self.engine.context.set_input_shape(
"encoder_hidden_states_t5", encoder_hidden_states_t5_c.shape
)
self.engine.context.set_input_shape(
"text_embedding_mask_t5", text_embedding_mask_t5_c.shape
)
self.engine.context.set_input_shape("cos_cis_img", freqs_cis_img0_c.shape)
self.engine.context.set_input_shape("sin_cis_img", freqs_cis_img1_c.shape)
self.engine.context.set_tensor_address("x", x_c.contiguous().data_ptr())
self.engine.context.set_tensor_address("t", t_emb_c.contiguous().data_ptr())
self.engine.context.set_tensor_address(
"encoder_hidden_states", context_c.contiguous().data_ptr()
)
self.engine.context.set_tensor_address(
"text_embedding_mask", text_embedding_mask_c.contiguous().data_ptr()
)
self.engine.context.set_tensor_address(
"encoder_hidden_states_t5",
encoder_hidden_states_t5_c.contiguous().data_ptr(),
)
self.engine.context.set_tensor_address(
"text_embedding_mask_t5", text_embedding_mask_t5_c.contiguous().data_ptr()
)
self.engine.context.set_tensor_address(
"cos_cis_img", freqs_cis_img0_c.contiguous().data_ptr()
)
self.engine.context.set_tensor_address(
"sin_cis_img", freqs_cis_img1_c.contiguous().data_ptr()
)
output = torch.zeros(
(2 * self.max_batch_size, 8, self.latent_height, self.latent_width),
dtype=torch.float16,
device="cuda",
)
self.engine.context.set_tensor_address("output", output.contiguous().data_ptr())
self.engine.context.execute_async_v3(self.stream)
torch.cuda.synchronize()
output.resize_(tuple(self.engine.context.get_tensor_shape("output")))
return output
model='DiT-g/2'
params=" \
--qk-norm \
--model ${model} \
--rope-img base512 \
--rope-real \
"
deepspeed hydit/train_deepspeed.py ${params} "$@"
model='DiT-g/2'
params=" \
--qk-norm \
--model ${model} \
--rope-img base512 \
--rope-real \
"
deepspeed hydit/train_deepspeed_controlnet.py ${params} "$@"
model='DiT-g/2'
params=" \
--qk-norm \
--model ${model} \
--rope-img base512 \
--rope-real \
"
deepspeed hydit/train_deepspeed_ipadapter.py ${params} "$@"