Commit f91d2ea3 authored by mashun1

hunyuandit
import math
import torch
import torch.nn as nn
from einops import repeat
from timm.models.layers import to_2tuple
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
Remove the _assert function in forward function to be compatible with multi-resolution images.
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, (tuple, list)) and len(img_size) == 2:
img_size = tuple(img_size)
else:
raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}")
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def update_image_size(self, img_size):
self.img_size = img_size
self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
def forward(self, x):
# B, C, H, W = x.shape
# _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
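# A minimal usage sketch of PatchEmbed, assuming the default ViT-style settings; the
# helper name `_demo_patch_embed` is hypothetical and only illustrates the shapes.
def _demo_patch_embed():
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    dummy = torch.randn(2, 3, 224, 224)
    tokens = embed(dummy)  # (2, 196, 768): a 14x14 grid of patches, each projected to 768 dims
    return tokens.shape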
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
        ).to(device=t.device)  # size: [dim/2], an exponentially decaying frequency curve
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(t, "b -> b d", d=dim)
return embedding
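# A minimal sketch of how timestep_embedding is typically consumed, assuming a batch
# of integer timesteps; `_demo_timestep_embedding` is a hypothetical helper name.
def _demo_timestep_embedding():
    t = torch.tensor([0, 250, 999])        # one diffusion timestep per batch element
    emb = timestep_embedding(t, dim=256)   # (3, 256): concatenated [cos | sin] halves
    return emb.shape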
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
super().__init__()
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, out_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from timm.models.vision_transformer import Mlp
from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention
from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
from .norm_layers import RMSNorm
from .poolers import AttentionPool
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class FP32_Layernorm(nn.LayerNorm):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(),
self.eps).to(origin_dtype)
class FP32_SiLU(nn.SiLU):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
class HunYuanDiTBlock(nn.Module):
"""
A HunYuanDiT block with `add` conditioning.
"""
def __init__(self,
hidden_size,
c_emb_size,
num_heads,
mlp_ratio=4.0,
text_states_dim=1024,
use_flash_attn=False,
qk_norm=False,
norm_type="layer",
skip=False,
):
super().__init__()
self.use_flash_attn = use_flash_attn
use_ele_affine = True
if norm_type == "layer":
norm_layer = FP32_Layernorm
elif norm_type == "rms":
norm_layer = RMSNorm
else:
raise ValueError(f"Unknown norm_type: {norm_type}")
# ========================= Self-Attention =========================
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
if use_flash_attn:
self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
else:
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm)
# ========================= FFN =========================
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
# ========================= Add =========================
# Simply use add like SDXL.
self.default_modulation = nn.Sequential(
FP32_SiLU(),
nn.Linear(c_emb_size, hidden_size, bias=True)
)
# ========================= Cross-Attention =========================
if use_flash_attn:
self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
qk_norm=qk_norm)
else:
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
qk_norm=qk_norm)
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6)
# ========================= Skip Connection =========================
if skip:
self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6)
self.skip_linear = nn.Linear(2 * hidden_size, hidden_size)
else:
self.skip_linear = None
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
# Self-Attention
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
attn_inputs = (
self.norm1(x) + shift_msa, freq_cis_img,
)
x = x + self.attn1(*attn_inputs)[0]
# Cross-Attention
cross_inputs = (
self.norm3(x), text_states, freq_cis_img
)
x = x + self.attn2(*cross_inputs)[0]
# FFN Layer
mlp_inputs = self.norm2(x)
x = x + self.mlp(mlp_inputs)
return x
class FinalLayer(nn.Module):
"""
The final layer of HunYuanDiT.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
FP32_SiLU(),
nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class HunYuanDiT(ModelMixin, ConfigMixin):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
@register_to_config
def __init__(
self, args,
input_size=(32, 32),
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
log_fn=print,
):
super().__init__()
self.args = args
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = args.learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if args.learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = args.text_states_dim
self.text_states_dim_t5 = args.text_states_dim_t5
self.text_len = args.text_len
self.text_len_t5 = args.text_len_t5
self.norm = args.norm
use_flash_attn = args.infer_mode == 'fa'
if use_flash_attn:
log_fn(f" Enable Flash Attention.")
qk_norm = True # See http://arxiv.org/abs/2302.05442 for details.
self.mlp_t5 = nn.Sequential(
nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True),
FP32_SiLU(),
nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True),
)
# learnable replace
self.text_embedding_padding = nn.Parameter(
torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32))
# Attention pooling
self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024)
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, hidden_size)
# Image size and crop size conditions
self.extra_in_dim = 256 * 6 + hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.extra_in_dim += 1024
self.extra_embedder = nn.Sequential(
nn.Linear(self.extra_in_dim, hidden_size * 4),
FP32_SiLU(),
nn.Linear(hidden_size * 4, hidden_size, bias=True),
)
# Image embedding
num_patches = self.x_embedder.num_patches
log_fn(f" Number of tokens: {num_patches}")
        # HunYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
use_flash_attn=use_flash_attn,
qk_norm=qk_norm,
norm_type=self.norm,
skip=layer > depth // 2,
)
for layer in range(depth)
])
self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels)
self.unpatchify_channels = self.out_channels
self.initialize_weights()
def forward(self,
x,
t,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
cos_cis_img=None,
sin_cis_img=None,
return_dict=True,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024
        clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
        text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
_, _, oh, ow = x.shape
th, tw = oh // self.patch_size, ow // self.patch_size
# ========================= Build time and image embedding =========================
t = self.t_embedder(t)
x = self.x_embedder(x)
        # Get the image RoPE embedding according to the resolution.
freqs_cis_img = (cos_cis_img, sin_cis_img)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
# Build image meta size tokens
image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256]
if self.args.use_fp16:
image_meta_size = image_meta_size.half()
image_meta_size = image_meta_size.view(-1, 6 * 256)
extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
# Build style tokens
style_embedding = self.style_embedder(style)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
# ========================= Forward pass through HunYuanDiT blocks =========================
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)
# ========================= Final layer =========================
x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
if return_dict:
return {'x': x}
return x
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.extra_embedder[0].weight, std=0.02)
nn.init.normal_(self.extra_embedder[2].weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in HunYuanDiT blocks:
for block in self.blocks:
nn.init.constant_(block.default_modulation[-1].weight, 0)
nn.init.constant_(block.default_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
p = self.x_embedder.patch_size[0]
# h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
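# A minimal shape sketch of the unpatchify reshape above, assuming patch_size p=2 and
# out_channels c=8 (in_channels=4 with learn_sigma); it is standalone and does not
# need a full HunYuanDiT instance.
def _demo_unpatchify_shapes():
    n, h, w, p, c = 2, 64, 64, 2, 8
    x = torch.randn(n, h * w, p * p * c)          # (N, T, patch_size**2 * C) tokens
    x = x.reshape(n, h, w, p, p, c)
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(n, c, h * p, w * p)          # (N, C, H, W) latent image
    return imgs.shape                             # torch.Size([2, 8, 128, 128])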
#################################################################################
# HunYuanDiT Configs #
#################################################################################
HUNYUAN_DIT_CONFIG = {
'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637},
'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16},
'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16},
'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12},
}
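# A minimal construction sketch, assuming argument values that mirror the tensor shapes
# used elsewhere in this file (77/1024 CLIP tokens, 256/2048 T5 tokens); in the real
# pipeline these come from hydit.config.get_args(), so treat the Namespace below as a
# placeholder rather than the authoritative configuration.
def _demo_build_dit_xl_2():
    from argparse import Namespace
    args = Namespace(learn_sigma=True, text_states_dim=1024, text_states_dim_t5=2048,
                     text_len=77, text_len_t5=256, norm='layer', infer_mode='torch')
    model = HunYuanDiT(args, input_size=(128, 128), **HUNYUAN_DIT_CONFIG['DiT-XL/2'])
    return sum(p.numel() for p in model.parameters())  # rough parameter count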
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): If True, add a learnable per-channel scale. Default is True.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
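# A minimal numerical sketch of the RMSNorm formula above, assuming a small float
# tensor; it checks that the normalized rows have approximately unit RMS.
def _demo_rmsnorm():
    norm = RMSNorm(dim=8)
    x = torch.randn(4, 8)
    y = norm(x)
    rms = y.pow(2).mean(-1).sqrt()  # close to 1.0 for every row (weight is all-ones)
    return rms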
class GroupNorm32(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)
def forward(self, x):
y = super().forward(x).to(x.dtype)
return y
def normalization(channels, dtype=None):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionPool(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
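# A minimal usage sketch of AttentionPool as it is wired into HunYuanDiT: pooling a
# 256-token T5 sequence (dim 2048) down to one 1024-d vector per sample. The shapes
# mirror the pooler construction elsewhere in this commit; the tensor values here are
# random placeholders.
def _demo_attention_pool():
    pool = AttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8, output_dim=1024)
    t5_states = torch.randn(2, 256, 2048)   # (B, L_t5, D_t5)
    pooled = pool(t5_states)                # (B, 1024)
    return pooled.shape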
import torch
import numpy as np
from typing import Union
def _to_tuple(x):
if isinstance(x, int):
return x, x
else:
return x
def get_fill_resize_and_crop(src, tgt):  # src: source resolution, tgt: base resolution
    th, tw = _to_tuple(tgt)
    h, w = _to_tuple(src)
    tr = th / tw  # aspect ratio of the base (target) resolution
    r = h / w  # aspect ratio of the source resolution
# resize
if r > tr:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
        resize_height = int(round(tw / w * h))  # resize the source resolution down to fit the base resolution
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
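# A worked example of get_fill_resize_and_crop, assuming patch_size=2: a 1024x768 image
# becomes a 64x48 token grid after the /8 VAE and /2 patching, and is mapped into the
# 32x32 base grid of the 'base512' RoPE mode with the short side centred.
def _demo_fill_resize_and_crop():
    start, stop = get_fill_resize_and_crop((64, 48), 32)
    return start, stop  # ((0, 4), (32, 28))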
def get_meshgrid(start, *args):
if len(args) == 0:
# start is grid_size
num = _to_tuple(start)
start = (0, 0)
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = (stop[0] - start[0], stop[1] - start[1])
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start)  # top-left corner, e.g. (12, 0)
        stop = _to_tuple(args[0])  # bottom-right corner, e.g. (20, 32)
        num = _to_tuple(args[1])  # target size, e.g. (32, 124)
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
    grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)  # e.g. 32 samples between 12 and 20, or 124 samples between 0 and 32
    grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)  # [2, H, W]
return grid
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid = get_meshgrid(start, *args) # [2, H, w]
# grid_h = np.arange(grid_size, dtype=np.float32)
# grid_w = np.arange(grid_size, dtype=np.float32)
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
# grid = np.stack(grid, axis=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (W,H)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
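# A minimal shape sketch for the sin/cos helpers above, assuming a 16x16 token grid and
# a 768-dim embedding; this mirrors the MAE-style positional-embedding construction.
def _demo_sincos_pos_embed():
    pos_embed = get_2d_sincos_pos_embed(768, 16)  # start=16 means a 16x16 grid
    return pos_embed.shape                        # (256, 768)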
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
"""
This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
Parameters
----------
embed_dim: int
embedding dimension size
start: int or tuple of int
If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
If len(args) == 2, start is start, args[0] is stop, args[1] is num.
use_real: bool
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns
-------
    pos_embed: torch.Tensor
        [HW, D/2] complex tensor, or a (cos, sin) tuple of two [HW, D] real tensors when use_real is True.
"""
    grid = get_meshgrid(start, *args)  # [2, H, W]
    grid = grid.reshape([2, 1, *grid.shape[1:]])  # a sampling grid whose resolution matches the target resolution
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
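# A minimal shape sketch for the image RoPE above, assuming head_dim=88 (hidden_size
# 1408 / 16 heads for DiT-g/2) and a 64x64 token grid; with use_real=True the helper
# returns separate cos/sin tables, matching the (4096, 88) cis inputs used elsewhere.
def _demo_rotary_pos_embed():
    cos, sin = get_2d_rotary_pos_embed(88, 64, use_real=True)
    return cos.shape, sin.shape  # (torch.Size([4096, 88]), torch.Size([4096, 88]))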
def calc_sizes(rope_img, patch_size, th, tw):
""" 计算 RoPE 的尺寸. """
if rope_img == 'extend':
# 拓展模式
sub_args = [(th, tw)]
elif rope_img.startswith('base'):
# 基于一个尺寸, 其他尺寸插值获得.
base_size = int(rope_img[4:]) // 8 // patch_size # 基于512作为base,其他根据512差值得到
start, stop = get_fill_resize_and_crop((th, tw), base_size) # 需要在32x32里面 crop的左上角和右下角
sub_args = [start, stop, (th, tw)]
else:
raise ValueError(f"Unknown rope_img: {rope_img}")
return sub_args
def init_image_posemb(rope_img,
resolutions,
patch_size,
hidden_size,
num_heads,
log_fn,
rope_real=True,
):
freqs_cis_img = {}
for reso in resolutions:
th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
        sub_args = calc_sizes(rope_img, patch_size, th, tw)  # [top-left, bottom-right, target (h, w)]: the crop corners inside the base grid
freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
return freqs_cis_img
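# A minimal sketch of init_image_posemb, assuming a simple resolution record with
# `height`/`width` attributes (the real project uses its own resolution type, so the
# namedtuple below is a hypothetical stand-in); for a 1024x1024 image with patch_size=2
# this yields one 64x64 RoPE table of shape (4096, 88) per head dimension.
def _demo_init_image_posemb():
    from collections import namedtuple
    Reso = namedtuple("Reso", ["height", "width"])
    freqs = init_image_posemb("base512", [Reso(1024, 1024)], patch_size=2,
                              hidden_size=1408, num_heads=16, log_fn=print)
    return {key: value[0].shape for key, value in freqs.items()}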
import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
class MT5Embedder(nn.Module):
available_models = ["t5-v1_1-xxl"]
def __init__(
self,
model_dir="t5-v1_1-xxl",
model_kwargs=None,
torch_dtype=None,
use_tokenizer_only=False,
conditional_generation=False,
max_length=128,
):
super().__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.torch_dtype = torch_dtype or torch.bfloat16
self.max_length = max_length
if model_kwargs is None:
model_kwargs = {
# "low_cpu_mem_usage": True,
"torch_dtype": self.torch_dtype,
}
model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device}
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
if use_tokenizer_only:
return
if conditional_generation:
self.model = None
self.generation_model = T5ForConditionalGeneration.from_pretrained(
model_dir
)
return
self.model = T5EncoderModel.from_pretrained(model_dir, **model_kwargs).eval().to(self.torch_dtype)
def get_tokens_and_mask(self, texts):
text_tokens_and_mask = self.tokenizer(
texts,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
tokens = text_tokens_and_mask["input_ids"][0]
mask = text_tokens_and_mask["attention_mask"][0]
# tokens = torch.tensor(tokens).clone().detach()
# mask = torch.tensor(mask, dtype=torch.bool).clone().detach()
return tokens, mask
def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1):
text_tokens_and_mask = self.tokenizer(
texts,
max_length=self.max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
with torch.no_grad():
outputs = self.model(
input_ids=text_tokens_and_mask["input_ids"].to(self.device),
attention_mask=text_tokens_and_mask["attention_mask"].to(self.device)
if attention_mask
else None,
output_hidden_states=True,
)
text_encoder_embs = outputs["hidden_states"][layer_index].detach()
return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device)
@torch.no_grad()
def __call__(self, tokens, attention_mask, layer_index=-1):
with torch.cuda.amp.autocast():
outputs = self.model(
input_ids=tokens,
attention_mask=attention_mask,
output_hidden_states=True,
)
z = outputs.hidden_states[layer_index].detach()
return z
def general(self, text: str):
# input_ids = input_ids = torch.tensor([list(text.encode("utf-8"))]) + num_special_tokens
input_ids = self.tokenizer(text, max_length=128).input_ids
print(input_ids)
outputs = self.generation_model(input_ids)
return outputs
#
# Copyright 2022 The HuggingFace Inc. team.
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from collections import OrderedDict
from copy import copy
import numpy as np
import tensorrt as trt
import torch
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import CreateConfig, Profile
from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine
from polygraphy.backend.trt import util as trt_util
import ctypes
from glob import glob
from cuda import cudart
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
trt_util.TRT_LOGGER = TRT_LOGGER
class Engine():
def __init__(
self,
model_name,
engine_dir,
onnx_file=None,
):
self.engine_path = os.path.join(engine_dir, model_name + '.plan')
self.engine = None
self.context = None
self.buffers = OrderedDict()
self.tensors = OrderedDict()
self.weightNameList = None
self.refitter = None
self.onnx_initializers = None
self.onnx_file = onnx_file
self.trt_lora_weight = None
self.trt_lora_weight_mem = None
self.torch_weight = None
def __del__(self):
del self.engine
del self.context
del self.buffers
del self.tensors
def build(self, onnx_path, fp16, input_profile=None, enable_preview=False, sparse_weights=False):
print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
p = Profile()
if input_profile:
for name, dims in input_profile.items():
assert len(dims) == 3
p.add(name, min=dims[0], opt=dims[1], max=dims[2])
preview_features = []
if enable_preview:
trt_version = [int(i) for i in trt.__version__.split(".")]
# FASTER_DYNAMIC_SHAPES_0805 should only be used for TRT 8.5.1 or above.
if trt_version[0] > 8 or \
(trt_version[0] == 8 and (trt_version[1] > 5 or (trt_version[1] == 5 and trt_version[2] >= 1))):
preview_features = [trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]
engine = engine_from_network(network_from_onnx_path(onnx_path), config=CreateConfig(fp16=fp16, profiles=[p],
preview_features=preview_features,
sparse_weights=sparse_weights))
save_engine(engine, path=self.engine_path)
def activate(self, plugin_path=""):
ctypes.cdll.LoadLibrary(plugin_path)
self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
self.context = self.engine.create_execution_context()
def get_shared_memory(self):
_, device_memory = cudart.cudaMalloc(self.engine.device_memory_size)
self.device_memory = device_memory
return self.device_memory
def set_shared_memory(self, device_memory_size):
self.context.device_memory = device_memory_size
def binding_input(self, name, shape):
idx = self.engine.get_binding_index(name)
result = self.context.set_binding_shape(idx, shape)
return result
def allocate_buffers(self, shape_dict=None, device='cuda'):
print("Allocate buffers and bindings inputs:")
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
print("binding: ", binding)
if shape_dict and binding in shape_dict:
shape = shape_dict[binding]
else:
shape = self.engine.get_binding_shape(binding)
nv_dtype = self.engine.get_binding_dtype(binding)
dtype_map = {trt.DataType.FLOAT: np.float32,
trt.DataType.HALF: np.float16,
trt.DataType.INT8: np.int8,
trt.DataType.INT64: np.int64,
trt.DataType.BOOL: bool}
if hasattr(trt.DataType, 'INT32'):
dtype_map[trt.DataType.INT32] = np.int32
dtype = dtype_map[nv_dtype]
if self.engine.binding_is_input(binding):
self.context.set_binding_shape(idx, shape)
# Workaround to convert np dtype to torch
np_type_tensor = np.empty(shape=[], dtype=dtype)
torch_type_tensor = torch.from_numpy(np_type_tensor)
tensor = torch.empty(tuple(shape), dtype=torch_type_tensor.dtype).to(device=device)
print(f" binding={binding}, shape={shape}, dtype={tensor.dtype}")
self.tensors[binding] = tensor
self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype)
def infer(self, feed_dict, stream):
start_binding, end_binding = trt_util.get_active_profile_bindings(self.context)
# shallow copy of ordered dict
device_buffers = copy(self.buffers)
for name, buf in feed_dict.items():
assert isinstance(buf, cuda.DeviceView)
device_buffers[name] = buf
self.binding_input(name, buf.shape)
bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()]
noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr)
if not noerror:
raise ValueError(f"ERROR: inference failed.")
for idx in range(trt_util.get_bindings_per_profile(self.engine)):
binding = self.engine[idx]
if not self.engine.binding_is_input(binding):
shape = self.context.get_binding_shape(idx)
self.tensors[binding].resize_(tuple(shape))
return self.tensors
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from polygraphy import cuda
from .engine import Engine
class TRTModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=4,
model_name="unet-dyn",
engine_dir="./unet",
device_id=0,
fp16=True,
image_width=1024,
image_height=1024,
text_maxlen=77,
embedding_dim=768,
max_batch_size=1,
plugin_path="./ckpts/trt_model/fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so",
):
super().__init__()
# create engine
self.in_channels = in_channels # For pipeline compatibility
self.fp16 = fp16
self.max_batch_size = max_batch_size
self.model_name = model_name
self.engine_dir = engine_dir
self.engine = Engine(self.model_name, self.engine_dir)
self.engine.activate(plugin_path)
# create cuda stream
self.stream = cuda.Stream()
# create inputs buffer
self.latent_width = image_width // 8
self.latent_height = image_height // 8
self.text_maxlen = text_maxlen
self.embedding_dim = embedding_dim
shape_dict = {
'x': (2 * self.max_batch_size, 4, self.latent_height, self.latent_width),
't': (2 * self.max_batch_size,),
'encoder_hidden_states': (2 * self.max_batch_size, self.text_maxlen, self.embedding_dim),
'text_embedding_mask': (2 * self.max_batch_size, self.text_maxlen),
'encoder_hidden_states_t5': (2 * self.max_batch_size, 256, 2048),
'text_embedding_mask_t5': (2 * self.max_batch_size, 256),
'image_meta_size': (2 * self.max_batch_size, 6),
'style': (2 * self.max_batch_size,),
'cos_cis_img': (6400, 88),
'sin_cis_img': (6400, 88),
'output': (2 * self.max_batch_size, 8, self.latent_height, self.latent_width),
}
device = "cuda:{}".format(device_id)
self.engine_device = torch.device(device)
self.engine.allocate_buffers(shape_dict=shape_dict, device=device)
print("[INFO] Create hcf nv controlled unet success")
@property
def device(self):
return self.engine_device
def __call__(self, x, t_emb, context, image_meta_size, style, freqs_cis_img0,
freqs_cis_img1, text_embedding_mask, encoder_hidden_states_t5, text_embedding_mask_t5):
return self.forward(x=x, t_emb=t_emb, context=context, image_meta_size=image_meta_size, style=style,
freqs_cis_img0=freqs_cis_img0, freqs_cis_img1=freqs_cis_img1,
text_embedding_mask=text_embedding_mask, encoder_hidden_states_t5=encoder_hidden_states_t5,
text_embedding_mask_t5=text_embedding_mask_t5)
def get_shared_memory(self):
return self.engine.get_shared_memory()
def set_shared_memory(self, shared_memory):
self.engine.set_shared_memory(shared_memory)
def forward(self, x, t_emb, context, image_meta_size, style, freqs_cis_img0,
freqs_cis_img1, text_embedding_mask, encoder_hidden_states_t5, text_embedding_mask_t5):
x_c = x.half()
t_emb_c = t_emb.half()
context_c = context.half()
image_meta_size_c = image_meta_size.half()
style_c = style.long()
freqs_cis_img0_c = freqs_cis_img0.float()
freqs_cis_img1_c = freqs_cis_img1.float()
text_embedding_mask_c = text_embedding_mask.long()
encoder_hidden_states_t5_c = encoder_hidden_states_t5.half()
text_embedding_mask_t5_c = text_embedding_mask_t5.long()
dtype = np.float16
batch_size = x.shape[0] // 2
if batch_size <= self.max_batch_size:
sample_inp = cuda.DeviceView(ptr=x_c.reshape(-1).data_ptr(), shape=x_c.shape, dtype=np.float16)
t_emb_inp = cuda.DeviceView(ptr=t_emb_c.reshape(-1).data_ptr(), shape=t_emb_c.shape, dtype=np.float16)
embeddings_inp = cuda.DeviceView(ptr=context_c.reshape(-1).data_ptr(), shape=context_c.shape,
dtype=np.float16)
image_meta_size_inp = cuda.DeviceView(ptr=image_meta_size_c.reshape(-1).data_ptr(),
shape=image_meta_size_c.shape, dtype=np.float16)
style_inp = cuda.DeviceView(ptr=style_c.reshape(-1).data_ptr(), shape=style_c.shape, dtype=np.int64)
freqs_cis_img0_inp = cuda.DeviceView(ptr=freqs_cis_img0_c.reshape(-1).data_ptr(),
shape=freqs_cis_img0_c.shape, dtype=np.float32)
freqs_cis_img1_inp = cuda.DeviceView(ptr=freqs_cis_img1_c.reshape(-1).data_ptr(),
shape=freqs_cis_img1_c.shape, dtype=np.float32)
text_embedding_mask_inp = cuda.DeviceView(ptr=text_embedding_mask_c.reshape(-1).data_ptr(),
shape=text_embedding_mask_c.shape, dtype=np.int64)
encoder_hidden_states_t5_inp = cuda.DeviceView(ptr=encoder_hidden_states_t5_c.reshape(-1).data_ptr(),
shape=encoder_hidden_states_t5_c.shape, dtype=np.float16)
text_embedding_mask_t5_inp = cuda.DeviceView(ptr=text_embedding_mask_t5_c.reshape(-1).data_ptr(),
shape=text_embedding_mask_t5_c.shape, dtype=np.int64)
feed_dict = {"x": sample_inp,
"t": t_emb_inp,
"encoder_hidden_states": embeddings_inp,
"image_meta_size": image_meta_size_inp,
"text_embedding_mask": text_embedding_mask_inp,
"encoder_hidden_states_t5": encoder_hidden_states_t5_inp,
"text_embedding_mask_t5": text_embedding_mask_t5_inp,
"style": style_inp, "cos_cis_img": freqs_cis_img0_inp,
"sin_cis_img": freqs_cis_img1_inp}
latent = self.engine.infer(feed_dict, self.stream)
return latent['output']
else:
            raise ValueError(
                "[ERROR] Input batch_size={} exceeds max_batch_size={}".format(batch_size, self.max_batch_size))
import random
import numpy as np
import torch
def set_seeds(seed_list, device=None):
if isinstance(seed_list, (tuple, list)):
seed = sum(seed_list)
else:
seed = seed_list
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
return torch.Generator(device).manual_seed(seed)
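# A minimal usage sketch: seed every RNG once and draw the initial latent noise from the
# returned generator; "cpu" is used here so the sketch runs without a GPU.
def _demo_set_seeds():
    generator = set_seeds(42, device="cpu")
    noise = torch.randn(1, 4, 128, 128, generator=generator)
    return noise.shape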
# Unique model identifier
modelCode=658
# Model name
modelName=HunyuanDiT
# Model description
modelDescription=A text-to-image generation model that supports Chinese prompts.
# Application scenarios
appScenario=Inference,AIGC,Media,Research,Education
# Framework type
frameType=pytorch
# --extra-index-url https://pypi.ngc.nvidia.com
# timm==0.9.5
diffusers==0.21.2
peft==0.10.0
protobuf==3.19.0
# torchvision==0.14.1
transformers==4.37.2
accelerate==0.29.3
loguru==0.7.2
einops==0.7.0
sentencepiece==0.1.99
# cuda-python==11.7.1
# onnxruntime==1.12.1
# onnx==1.12.0
# nvidia-pyindex==1.0.9
# onnx-graphsurgeon==0.3.27
# polygraphy==0.47.1
pandas==2.0.3
gradio==3.50.2
from pathlib import Path
from loguru import logger
from dialoggen.dialoggen_demo import DialogGen
from hydit.config import get_args
from hydit.inference import End2End
def inferencer():
args = get_args()
models_root_path = Path(args.model_root)
if not models_root_path.exists():
raise ValueError(f"`models_root` not exists: {models_root_path}")
# Load models
gen = End2End(args, models_root_path)
# Try to enhance prompt
if args.enhance:
logger.info("Loading DialogGen model (for prompt enhancement)...")
enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
logger.info("DialogGen model loaded.")
else:
enhancer = None
return args, gen, enhancer
if __name__ == "__main__":
args, gen, enhancer = inferencer()
if enhancer:
logger.info("Prompt Enhancement...")
success, enhanced_prompt = enhancer(args.prompt)
if not success:
logger.info("Sorry, the prompt is not compliant, refuse to draw.")
exit()
logger.info(f"Enhanced prompt: {enhanced_prompt}")
else:
enhanced_prompt = None
# Run inference
logger.info("Generating images...")
height, width = args.image_size
results = gen.predict(args.prompt,
height=height,
width=width,
seed=args.seed,
enhanced_prompt=enhanced_prompt,
negative_prompt=args.negative,
infer_steps=args.infer_steps,
guidance_scale=args.cfg_scale,
batch_size=args.batch_size,
src_size_cond=args.size_cond,
)
images = results['images']
# Save images
save_dir = Path('results')
save_dir.mkdir(exist_ok=True)
# Find the first available index
all_files = list(save_dir.glob('*.png'))
if all_files:
start = max([int(f.stem) for f in all_files]) + 1
else:
start = 0
for idx, pil_img in enumerate(images):
save_path = save_dir / f"{idx + start}.png"
pil_img.save(save_path)
logger.info(f"Save to {save_path}")
# ==============================================================================
# Description: Export ONNX model and build TensorRT engine.
# ==============================================================================
# Check whether the model root path is provided and exists.
if [ -z "$1" ]; then
if [ -d "ckpts" ]; then
echo "The model root directory is not provided. Use the default path 'ckpts'."
export MODEL_ROOT=ckpts
else
echo "Default model path 'ckpts' does not exist. Please provide the path of the model root directory."
exit 1
fi
elif [ ! -d "$1" ]; then
echo "The model root directory ($1) does not exist."
exit 1
else
export MODEL_ROOT=$(cd "$1"; pwd)
fi
export ONNX_WORKDIR=${MODEL_ROOT}/onnx_model
echo "MODEL_ROOT=${MODEL_ROOT}"
echo "ONNX_WORKDIR=${ONNX_WORKDIR}"
# Remove old directories.
if [ -d "${ONNX_WORKDIR}" ]; then
echo "Remove old ONNX directories..."
rm -r ${ONNX_WORKDIR}
fi
# Inspect the project directory.
SCRIPT_PATH="$( cd "$( dirname "$0" )" && pwd )"
PROJECT_DIR=$(dirname "$SCRIPT_PATH")
export PYTHONPATH=${PROJECT_DIR}:${PYTHONPATH}
echo "PYTHONPATH=${PYTHONPATH}"
cd ${PROJECT_DIR}
echo "Change directory to ${PROJECT_DIR}"
# ----------------------------------------
# 1. Export ONNX model.
# ----------------------------------------
# Pause briefly so the message above can be read.
sleep 2s
echo "Exporting ONNX model..."
python trt/export_onnx.py --model-root ${MODEL_ROOT} --onnx-workdir ${ONNX_WORKDIR}
echo "Exporting ONNX model finished"
# ----------------------------------------
# 2. Build TensorRT engine.
# ----------------------------------------
echo "Building TensorRT engine..."
ENGINE_DIR="${MODEL_ROOT}/t2i/model_trt/engine"
mkdir -p ${ENGINE_DIR}
ENGINE_PATH=${ENGINE_DIR}/model_onnx.plan
PLUGIN_PATH=${MODEL_ROOT}/t2i/model_trt/fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so
trtexec \
--onnx=${ONNX_WORKDIR}/export_modified_fmha/model.onnx \
--fp16 \
--saveEngine=${ENGINE_PATH} \
--minShapes=x:2x4x90x90,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:2025x88,sin_cis_img:2025x88 \
--optShapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--maxShapes=x:2x4x160x160,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:6400x88,sin_cis_img:6400x88 \
--shapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--verbose \
--builderOptimizationLevel=4 \
--staticPlugins=${PLUGIN_PATH}
from pathlib import Path
import torch
from loguru import logger
from hydit.config import get_args
from hydit.modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import polygraphy.backend.onnx.loader
def _to_tuple(val):
if isinstance(val, (list, tuple)):
if len(val) == 1:
val = [val[0], val[0]]
elif len(val) == 2:
val = tuple(val)
else:
raise ValueError(f"Invalid value: {val}")
elif isinstance(val, (int, float)):
val = (val, val)
else:
raise ValueError(f"Invalid value: {val}")
return val
class ExportONNX(object):
def __init__(self, args, models_root_path):
self.args = args
self.model = None
# Set device and disable gradient
self.device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
# Check arguments
t2i_root_path = Path(models_root_path) / "t2i"
self.root = t2i_root_path
logger.info(f"Got text-to-image model root path: {t2i_root_path}")
# Create folder to save onnx model
onnx_workdir = Path(self.args.onnx_workdir)
self.onnx_workdir = onnx_workdir
self.onnx_export = self.onnx_workdir / "export/model.onnx"
self.onnx_export.parent.mkdir(parents=True, exist_ok=True)
self.onnx_modify = self.onnx_workdir / "export_modified/model.onnx"
self.onnx_modify.parent.mkdir(parents=True, exist_ok=True)
self.onnx_fmha = self.onnx_workdir / "export_modified_fmha/model.onnx"
self.onnx_fmha.parent.mkdir(parents=True, exist_ok=True)
def load_model(self):
# ========================================================================
# Create model structure and load the checkpoint
logger.info(f"Building HunYuan-DiT model...")
model_config = HUNYUAN_DIT_CONFIG[self.args.model]
image_size = _to_tuple(self.args.image_size)
latent_size = (image_size[0] // 8, image_size[1] // 8)
model_dir = self.root / "model"
model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
if not model_path.exists():
raise ValueError(f"model_path not exists: {model_path}")
# Build model structure
self.model = HunYuanDiT(self.args,
input_size=latent_size,
**model_config,
log_fn=logger.info,
).half().to(self.device) # Force to use fp16
# Load model checkpoint
logger.info(f"Loading torch model {model_path}...")
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
self.model.load_state_dict(state_dict)
self.model.eval()
logger.info(f"Loading torch model finished")
logger.info("==================================================")
logger.info(f" Model is ready. ")
logger.info("==================================================")
def export(self):
if self.model is None:
self.load_model()
# Construct model inputs
latent_model_input = torch.randn(2, 4, 128, 128, device=self.device).half()
t_expand = torch.randint(0, 1000, [2], device=self.device).half()
prompt_embeds = torch.randn(2, 77, 1024, device=self.device).half()
attention_mask = torch.randint(0, 2, [2, 77], device=self.device).long()
prompt_embeds_t5 = torch.randn(2, 256, 2048, device=self.device).half()
attention_mask_t5 = torch.randint(0, 2, [2, 256], device=self.device).long()
ims = torch.tensor([[1024, 1024, 1024, 1024, 0, 0], [1024, 1024, 1024, 1024, 0, 0]], device=self.device).half()
style = torch.tensor([0, 0], device=self.device).long()
freqs_cis_img = (
torch.randn(4096, 88),
torch.randn(4096, 88),
)
save_to = self.onnx_export
logger.info(f"Exporting ONNX model {save_to}...")
logger.info(f"Exporting ONNX external data {save_to.parent}...")
model_args = (
latent_model_input,
t_expand,
prompt_embeds,
attention_mask,
prompt_embeds_t5,
attention_mask_t5,
ims, style,
freqs_cis_img[0],
freqs_cis_img[1]
)
torch.onnx.export(self.model,
model_args,
str(save_to),
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=["x", "t", "encoder_hidden_states", "text_embedding_mask",
"encoder_hidden_states_t5", "text_embedding_mask_t5", "image_meta_size", "style",
"cos_cis_img", "sin_cis_img"],
output_names=["output"],
dynamic_axes={"x": {0: "2B", 2: "H", 3: "W"}, "t": {0: "2B"},
"encoder_hidden_states": {0: "2B"},
"text_embedding_mask": {0: "2B"}, "encoder_hidden_states_t5": {0: "2B"},
"text_embedding_mask_t5": {0: "2B"},
"image_meta_size": {0: "2B"}, "style": {0: "2B"}, "cos_cis_img": {0: "seqlen"},
"sin_cis_img": {0: "seqlen"}},
)
logger.info("Exporting onnx finished")
def postprocessing(self):
load_from = self.onnx_export
save_to = self.onnx_modify
logger.info(f"Postprocessing ONNX model {load_from}...")
onnxModel = onnx.load(str(load_from), load_external_data=False)
onnx.load_external_data_for_model(onnxModel, str(load_from.parent))
graph = gs.import_onnx(onnxModel)
# ADD GAMMA BETA FOR LN
for node in graph.nodes:
if node.name == "/final_layer/norm_final/LayerNormalization":
constantKernel = gs.Constant("final_layer.norm_final.weight",
np.ascontiguousarray(np.ones((1408,), dtype=np.float16)))
constantBias = gs.Constant("final_layer.norm_final.bias",
np.ascontiguousarray(np.zeros((1408,), dtype=np.float16)))
node.inputs = [node.inputs[0], constantKernel, constantBias]
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph.cleanup()),
str(save_to),
save_as_external_data=True,
all_tensors_to_one_file=False,
location=str(save_to.parent),
)
logger.info(f"Postprocessing ONNX model finished: {save_to}")
def fuse_attn(self):
load_from = self.onnx_modify
save_to = self.onnx_fmha
logger.info(f"FuseAttn ONNX model {load_from}...")
onnx_graph = polygraphy.backend.onnx.loader.fold_constants(
onnx.load(str(load_from)),
allow_onnxruntime_shape_inference=True,
)
graph = gs.import_onnx(onnx_graph)
cnt = 0
for node in graph.nodes:
if node.op == "Softmax" and node.i().op == "MatMul" and node.o().op == "MatMul" and \
node.o().o().op == "Transpose":
if "pooler" in node.name:
continue
if "attn1" in node.name:
matmul_0 = node.i()
transpose = matmul_0.i(1, 0)
transpose.attrs["perm"] = [0, 2, 1, 3]
k = transpose.outputs[0]
q = gs.Variable("transpose_0_v_{}".format(cnt), np.dtype(np.float16))
transpose_0 = gs.Node("Transpose", "Transpose_0_{}".format(cnt),
attrs={"perm": [0, 2, 1, 3]},
inputs=[matmul_0.inputs[0]],
outputs=[q])
graph.nodes.append(transpose_0)
matmul_1 = node.o()
v = gs.Variable("transpose_1_v_{}".format(cnt), np.dtype(np.float16))
transpose_1 = gs.Node("Transpose", "Transpose_1_{}".format(cnt),
attrs={"perm": [0, 2, 1, 3]},
inputs=[matmul_1.inputs[1]],
outputs=[v])
graph.nodes.append(transpose_1)
output_variable = node.o().o().outputs[0]
# fMHA_v = gs.Variable("fMHA_v", np.dtype(np.float16))
fMHA = gs.Node("fMHAPlugin", "fMHAPlugin_1_{}".format(cnt),
# attrs={"scale": 1.0},
inputs=[q, k, v],
outputs=[output_variable])
graph.nodes.append(fMHA)
node.o().o().outputs = []
cnt = cnt + 1
elif "attn2" in node.name:
matmul_0 = node.i()
transpose_q = matmul_0.i()
transpose_k = matmul_0.i(1, 0)
matmul_1 = node.o()
transpose_v = matmul_1.i(1, 0)
q = transpose_q.inputs[0]
k = transpose_k.inputs[0]
v = transpose_v.inputs[0]
output_variable = node.o().o().outputs[0]
fMHA = gs.Node("fMHAPlugin", "fMHAPlugin_1_{}".format(cnt),
# attrs={"scale": 1.0},
inputs=[q, k, v],
outputs=[output_variable])
graph.nodes.append(fMHA)
node.o().o().outputs = []
cnt = cnt + 1
logger.info("mha count: ", cnt)
onnx.save(gs.export_onnx(graph.cleanup()),
str(save_to),
save_as_external_data=True,
)
logger.info(f"FuseAttn ONNX model finished: {save_to}")
if __name__ == "__main__":
args = get_args()
models_root_path = Path(args.model_root)
if not models_root_path.exists():
raise ValueError(f"`models_root` not exists: {models_root_path}")
exporter = ExportONNX(args, models_root_path)
exporter.export()
exporter.postprocessing()
exporter.fuse_attn()