Commit 2b0139fe authored by gushiqiao's avatar gushiqiao Committed by Yang Yong(雍洋)
Browse files

Support cpu offload for hunyuan

parent 83c5f3b8
......@@ -41,7 +41,7 @@ def load_models(args, model_config):
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
text_encoders = [text_encoder_1, text_encoder_2]
model = HunyuanModel(args.model_path, model_config)
model = HunyuanModel(args.model_path, model_config, device=init_device)
vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device)
elif args.model_cls == "wan2.1":
......
......@@ -2,6 +2,7 @@ import torch
from einops import rearrange
from lightx2v.attentions import attention
from .utils_bf16 import apply_rotary_emb
from lightx2v.common.offload.manager import WeightStreamManager
class HunyuanTransformerInfer:
......@@ -14,26 +15,110 @@ class HunyuanTransformerInfer:
self.hidden_size = 3072
self.mlp_hidden_dim = 12288
self.parallel_attention = None
if self.config["cpu_offload"]:
self.double_weights_stream_mgr = WeightStreamManager()
self.single_weights_stream_mgr = WeightStreamManager()
self.infer_func = self._infer_with_offload
else:
self.infer_func = self._infer_without_offload
def set_scheduler(self, scheduler):
    """Store the scheduler object for later use by the inference passes."""
    self.scheduler = scheduler
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
    """Run all transformer blocks while streaming weights CPU<->GPU.

    For each block, the weights of the *next* block are prefetched on a
    side stream while the current block computes on the compute stream,
    then the active/prefetch slots are swapped.

    Returns the image tokens and the (unchanged) conditioning vector.
    """
    txt_seq_len = txt.shape[0]
    img_seq_len = img.shape[0]

    for double_block_idx in range(self.double_blocks_num):
        if double_block_idx == 0:
            # Bootstrap: the very first block is loaded synchronously.
            self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks_weights[0]
            self.double_weights_stream_mgr.active_weights[0].to_cuda()
        with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
            img, txt = self.infer_double_block(
                self.double_weights_stream_mgr.active_weights[0],
                img,
                txt,
                vec,
                cu_seqlens_qkv,
                max_seqlen_qkv,
                freqs_cis,
            )
        if double_block_idx < self.double_blocks_num - 1:
            self.double_weights_stream_mgr.prefetch_weights(
                double_block_idx + 1, weights.double_blocks_weights
            )
        self.double_weights_stream_mgr.swap_weights()

    x = torch.cat((img, txt), 0)

    for single_block_idx in range(self.single_blocks_num):
        if single_block_idx == 0:
            self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
            self.single_weights_stream_mgr.active_weights[0].to_cuda()
        with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
            # BUG FIX: previously this indexed weights.single_blocks_weights[single_block_idx]
            # directly, bypassing the stream manager's prefetched on-GPU copy and
            # defeating the offload machinery (cf. the double-block loop above).
            x = self.infer_single_block(
                self.single_weights_stream_mgr.active_weights[0],
                x,
                vec,
                txt_seq_len,
                cu_seqlens_qkv,
                max_seqlen_qkv,
                freqs_cis,
            )
        if single_block_idx < self.single_blocks_num - 1:
            self.single_weights_stream_mgr.prefetch_weights(
                single_block_idx + 1, weights.single_blocks_weights
            )
        self.single_weights_stream_mgr.swap_weights()

    img = x[:img_seq_len, ...]
    return img, vec
def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for i in range(self.double_blocks_num):
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
img, txt = self.infer_double_block(
weights.double_blocks_weights[i],
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
)
x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num):
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
x = self.infer_single_block(
weights.single_blocks_weights[i],
x,
vec,
txt_seq_len,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
)
img = x[:img_seq_len, ...]
return img, vec
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
def infer_double_block(
self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
):
vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu)
......@@ -56,8 +141,12 @@ class HunyuanTransformerInfer:
txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(
weights, img, img_mod1_scale, img_mod1_shift, freqs_cis
)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(
weights, txt, txt_mod1_scale, txt_mod1_shift
)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
......@@ -88,16 +177,38 @@ class HunyuanTransformerInfer:
)
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
img = self.infer_double_block_img_post_atten(
weights,
img,
img_attn,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
)
txt = self.infer_double_block_txt_post_atten(
weights,
txt,
txt_attn,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
return img, txt
def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis):
img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
def infer_double_block_img_pre_atten(
self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis
):
img_modulated = torch.nn.functional.layer_norm(
img, (img.shape[1],), None, None, 1e-6
)
img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
img_qkv = weights.img_attn_qkv.apply(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
img_q, img_k, img_v = rearrange(
img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
img_q = weights.img_attn_q_norm.apply(img_q)
img_k = weights.img_attn_k_norm.apply(img_k)
......@@ -105,18 +216,33 @@ class HunyuanTransformerInfer:
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
return img_q, img_k, img_v
def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
    """Modulate the text stream and produce its attention q/k/v tensors.

    Applies layer norm + scale/shift modulation, projects to fused QKV,
    splits heads via einops, then RMS-normalizes q and k.
    (Diff residue duplicated the def/layer_norm/rearrange lines; deduped.)
    """
    txt_modulated = torch.nn.functional.layer_norm(
        txt, (txt.shape[1],), None, None, 1e-6
    )
    txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
    txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
    txt_q, txt_k, txt_v = rearrange(
        txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
    )
    txt_q = weights.txt_attn_q_norm.apply(txt_q)
    txt_k = weights.txt_attn_k_norm.apply(txt_k)
    return txt_q, txt_k, txt_v
def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate):
def infer_double_block_img_post_atten(
self,
weights,
img,
img_attn,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
):
out = weights.img_attn_proj.apply(img_attn)
out = out * img_mod1_gate
img = img + out
......@@ -130,7 +256,16 @@ class HunyuanTransformerInfer:
img = img + out
return img
def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate):
def infer_double_block_txt_post_atten(
self,
weights,
txt,
txt_attn,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
):
out = weights.txt_attn_proj.apply(txt_attn)
out = out * txt_mod1_gate
txt = txt + out
......@@ -144,7 +279,9 @@ class HunyuanTransformerInfer:
txt = txt + out
return txt
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
def infer_single_block(
self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
):
out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
......@@ -154,7 +291,9 @@ class HunyuanTransformerInfer:
x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
qkv, mlp = torch.split(
x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
......
......@@ -17,9 +17,10 @@ class HunyuanModel:
post_weight_class = HunyuanPostWeights
transformer_weight_class = HunyuanTransformerWeights
def __init__(self, model_path, config, device):
    """Initialize the Hunyuan model wrapper.

    Args:
        model_path: root directory holding the checkpoint files.
        config: model config dict (read elsewhere, e.g. ``cpu_offload``).
        device: target device used by ``_load_ckpt`` as ``map_location``.

    (Diff residue kept the old two-arg def line alongside the new one;
    only the new three-arg signature is retained.)
    """
    self.model_path = model_path
    self.config = config
    self.device = device
    self._init_infer_class()
    self._init_weights()
    self._init_infer()
......@@ -47,7 +48,7 @@ class HunyuanModel:
def _load_ckpt(self):
    """Load the transformer checkpoint and return its ``module`` state dict.

    Uses ``self.device`` as ``map_location`` (was hard-coded to "cuda";
    the stale line is dropped) so cpu_offload setups can land weights on
    host memory first. ``weights_only=True`` avoids unpickling arbitrary
    objects from the checkpoint.
    """
    ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
    weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
    return weight_dict
def _init_weights(self):
......@@ -82,6 +83,9 @@ class HunyuanModel:
@torch.no_grad()
def infer(self, text_encoder_output, image_encoder_output, args):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
pre_infer_out = self.pre_infer.infer(
self.pre_weight,
self.scheduler.latents,
......@@ -95,3 +99,6 @@ class HunyuanModel:
)
img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
\ No newline at end of file
......@@ -80,20 +80,30 @@ class HunyuanTransformerDoubleBlock:
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
    """Move every template-managed weight in this block to host memory.

    (Diff residue kept both the old ``or``-chained isinstance line and the
    new tuple form; only the tuple form is retained.)
    """
    for mm_weight in self.weight_list:
        if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
            mm_weight.to_cpu()
def to_cuda(self):
    """Move every template-managed weight in this block to the GPU.

    (Diff residue kept both the old ``or``-chained isinstance line and the
    new tuple form; only the tuple form is retained.)
    """
    for mm_weight in self.weight_list:
        if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
            mm_weight.to_cuda()
def to_cpu_sync(self):
    """Queue a host offload for every template-managed weight.

    NOTE(review): despite the ``_sync`` suffix this passes
    ``non_blocking=True`` (asynchronous copy); presumably callers
    synchronize on a stream elsewhere — confirm the naming is intentional.
    """
    managed = (MMWeightTemplate, RMSWeightTemplate)
    for w in self.weight_list:
        if isinstance(w, managed):
            w.to_cpu(non_blocking=True)
def to_cuda_sync(self):
    """Queue a device upload for every template-managed weight.

    NOTE(review): despite the ``_sync`` suffix this passes
    ``non_blocking=True`` (asynchronous copy); presumably callers
    synchronize on a stream elsewhere — confirm the naming is intentional.
    """
    managed = (MMWeightTemplate, RMSWeightTemplate)
    for w in self.weight_list:
        if isinstance(w, managed):
            w.to_cuda(non_blocking=True)
class HunyuanTransformerSingleBlock:
def __init__(self, block_index, config):
......@@ -122,16 +132,27 @@ class HunyuanTransformerSingleBlock:
]
for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
    """Move every template-managed weight in this block to host memory.

    (Diff residue kept both the old ``or``-chained isinstance line and the
    new tuple form; only the tuple form is retained.)
    """
    for mm_weight in self.weight_list:
        if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
            mm_weight.to_cpu()
def to_cuda(self):
    """Move every template-managed weight in this block to the GPU.

    (Diff residue kept both the old ``or``-chained isinstance line and the
    new tuple form; only the tuple form is retained.)
    """
    for mm_weight in self.weight_list:
        if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
            mm_weight.to_cuda()
def to_cpu_sync(self):
    """Queue a host offload for every template-managed weight.

    NOTE(review): despite the ``_sync`` suffix this passes
    ``non_blocking=True`` (asynchronous copy); presumably callers
    synchronize on a stream elsewhere — confirm the naming is intentional.
    """
    managed = (MMWeightTemplate, RMSWeightTemplate)
    for w in self.weight_list:
        if isinstance(w, managed):
            w.to_cpu(non_blocking=True)
def to_cuda_sync(self):
    """Queue a device upload for every template-managed weight.

    NOTE(review): despite the ``_sync`` suffix this passes
    ``non_blocking=True`` (asynchronous copy); presumably callers
    synchronize on a stream elsewhere — confirm the naming is intentional.
    """
    managed = (MMWeightTemplate, RMSWeightTemplate)
    for w in self.weight_list:
        if isinstance(w, managed):
            w.to_cuda(non_blocking=True)
......@@ -87,9 +87,6 @@ class WanTransformerAttentionBlock:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
if self.config["cpu_offload"]:
mm_weight.to_cpu()
self.modulation = self.modulation.cpu()
def to_cpu(self):
for mm_weight in self.weight_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment