Commit 2b0139fe authored by gushiqiao's avatar gushiqiao Committed by Yang Yong(雍洋)
Browse files

Support cpu offload for hunyuan

parent 83c5f3b8
...@@ -41,7 +41,7 @@ def load_models(args, model_config): ...@@ -41,7 +41,7 @@ def load_models(args, model_config):
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device) text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(args.model_path, "text_encoder"), init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device) text_encoder_2 = TextEncoderHFClipModel(os.path.join(args.model_path, "text_encoder_2"), init_device)
text_encoders = [text_encoder_1, text_encoder_2] text_encoders = [text_encoder_1, text_encoder_2]
model = HunyuanModel(args.model_path, model_config) model = HunyuanModel(args.model_path, model_config, device=init_device)
vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device) vae_model = VideoEncoderKLCausal3DModel(args.model_path, dtype=torch.float16, device=init_device)
elif args.model_cls == "wan2.1": elif args.model_cls == "wan2.1":
......
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
from einops import rearrange from einops import rearrange
from lightx2v.attentions import attention from lightx2v.attentions import attention
from .utils_bf16 import apply_rotary_emb from .utils_bf16 import apply_rotary_emb
from lightx2v.common.offload.manager import WeightStreamManager
class HunyuanTransformerInfer: class HunyuanTransformerInfer:
...@@ -14,26 +15,110 @@ class HunyuanTransformerInfer: ...@@ -14,26 +15,110 @@ class HunyuanTransformerInfer:
self.hidden_size = 3072 self.hidden_size = 3072
self.mlp_hidden_dim = 12288 self.mlp_hidden_dim = 12288
self.parallel_attention = None self.parallel_attention = None
if self.config["cpu_offload"]:
self.double_weights_stream_mgr = WeightStreamManager()
self.single_weights_stream_mgr = WeightStreamManager()
self.infer_func = self._infer_with_offload
else:
self.infer_func = self._infer_without_offload
def set_scheduler(self, scheduler): def set_scheduler(self, scheduler):
self.scheduler = scheduler self.scheduler = scheduler
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis): def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
def _infer_with_offload(
    self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
):
    """Run the transformer blocks with CPU-offloaded weights streamed to GPU.

    Blocks are processed one at a time: while block ``i`` computes on the
    manager's compute stream, block ``i + 1``'s weights are prefetched on a
    separate stream, then the active/prefetch slots are swapped. Only one
    block's weights need to be GPU-resident at a time.

    Args:
        weights: container holding ``double_blocks_weights`` and
            ``single_blocks_weights`` lists (CPU-resident when offload is on).
        img: image-token hidden states, shape (img_seq_len, hidden_size).
        txt: text-token hidden states, shape (txt_seq_len, hidden_size).
        vec: conditioning vector used for modulation.
        cu_seqlens_qkv: cumulative sequence lengths for varlen attention.
        max_seqlen_qkv: maximum sequence length for varlen attention.
        freqs_cis: rotary-embedding frequencies.

    Returns:
        Tuple ``(img, vec)``: updated image hidden states and the
        (pass-through) conditioning vector.
    """
    txt_seq_len = txt.shape[0]
    img_seq_len = img.shape[0]

    for double_block_idx in range(self.double_blocks_num):
        if double_block_idx == 0:
            # First block is loaded synchronously; later blocks arrive via
            # prefetch_weights issued while the previous block computes.
            self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks_weights[0]
            self.double_weights_stream_mgr.active_weights[0].to_cuda()
        with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
            img, txt = self.infer_double_block(
                self.double_weights_stream_mgr.active_weights[0],
                img,
                txt,
                vec,
                cu_seqlens_qkv,
                max_seqlen_qkv,
                freqs_cis,
            )
        if double_block_idx < self.double_blocks_num - 1:
            self.double_weights_stream_mgr.prefetch_weights(
                double_block_idx + 1, weights.double_blocks_weights
            )
        self.double_weights_stream_mgr.swap_weights()

    x = torch.cat((img, txt), 0)

    for single_block_idx in range(self.single_blocks_num):
        if single_block_idx == 0:
            self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
            self.single_weights_stream_mgr.active_weights[0].to_cuda()
        with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
            # Fix: compute with the streamed-in active weights, mirroring the
            # double-block loop above. The original passed
            # weights.single_blocks_weights[single_block_idx] (the CPU-resident
            # list), so the prefetched GPU copies were never actually used.
            x = self.infer_single_block(
                self.single_weights_stream_mgr.active_weights[0],
                x,
                vec,
                txt_seq_len,
                cu_seqlens_qkv,
                max_seqlen_qkv,
                freqs_cis,
            )
        if single_block_idx < self.single_blocks_num - 1:
            self.single_weights_stream_mgr.prefetch_weights(
                single_block_idx + 1, weights.single_blocks_weights
            )
        self.single_weights_stream_mgr.swap_weights()

    img = x[:img_seq_len, ...]
    return img, vec
def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
txt_seq_len = txt.shape[0] txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0] img_seq_len = img.shape[0]
for i in range(self.double_blocks_num): for i in range(self.double_blocks_num):
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis) img, txt = self.infer_double_block(
weights.double_blocks_weights[i],
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
)
x = torch.cat((img, txt), 0) x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num): for i in range(self.single_blocks_num):
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis) x = self.infer_single_block(
weights.single_blocks_weights[i],
x,
vec,
txt_seq_len,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
)
img = x[:img_seq_len, ...] img = x[:img_seq_len, ...]
return img, vec return img, vec
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis): def infer_double_block(
self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
):
vec_silu = torch.nn.functional.silu(vec) vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu) img_mod_out = weights.img_mod.apply(vec_silu)
...@@ -56,8 +141,12 @@ class HunyuanTransformerInfer: ...@@ -56,8 +141,12 @@ class HunyuanTransformerInfer:
txt_mod2_gate, txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1) ) = txt_mod_out.chunk(6, dim=-1)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis) img_q, img_k, img_v = self.infer_double_block_img_pre_atten(
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift) weights, img, img_mod1_scale, img_mod1_shift, freqs_cis
)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(
weights, txt, txt_mod1_scale, txt_mod1_shift
)
q = torch.cat((img_q, txt_q), dim=0) q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0) k = torch.cat((img_k, txt_k), dim=0)
...@@ -88,16 +177,38 @@ class HunyuanTransformerInfer: ...@@ -88,16 +177,38 @@ class HunyuanTransformerInfer:
) )
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :] img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) img = self.infer_double_block_img_post_atten(
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) weights,
img,
img_attn,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
)
txt = self.infer_double_block_txt_post_atten(
weights,
txt,
txt_attn,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
return img, txt return img, txt
def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis): def infer_double_block_img_pre_atten(
img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6) self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis
):
img_modulated = torch.nn.functional.layer_norm(
img, (img.shape[1],), None, None, 1e-6
)
img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
img_qkv = weights.img_attn_qkv.apply(img_modulated) img_qkv = weights.img_attn_qkv.apply(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num) img_q, img_k, img_v = rearrange(
img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
img_q = weights.img_attn_q_norm.apply(img_q) img_q = weights.img_attn_q_norm.apply(img_q)
img_k = weights.img_attn_k_norm.apply(img_k) img_k = weights.img_attn_k_norm.apply(img_k)
...@@ -105,18 +216,33 @@ class HunyuanTransformerInfer: ...@@ -105,18 +216,33 @@ class HunyuanTransformerInfer:
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis) img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
return img_q, img_k, img_v return img_q, img_k, img_v
def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift): def infer_double_block_txt_pre_atten(
txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6) self, weights, txt, txt_mod1_scale, txt_mod1_shift
):
txt_modulated = torch.nn.functional.layer_norm(
txt, (txt.shape[1],), None, None, 1e-6
)
txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
txt_qkv = weights.txt_attn_qkv.apply(txt_modulated) txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num) txt_q, txt_k, txt_v = rearrange(
txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
txt_q = weights.txt_attn_q_norm.apply(txt_q) txt_q = weights.txt_attn_q_norm.apply(txt_q)
txt_k = weights.txt_attn_k_norm.apply(txt_k) txt_k = weights.txt_attn_k_norm.apply(txt_k)
return txt_q, txt_k, txt_v return txt_q, txt_k, txt_v
def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate): def infer_double_block_img_post_atten(
self,
weights,
img,
img_attn,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
):
out = weights.img_attn_proj.apply(img_attn) out = weights.img_attn_proj.apply(img_attn)
out = out * img_mod1_gate out = out * img_mod1_gate
img = img + out img = img + out
...@@ -130,7 +256,16 @@ class HunyuanTransformerInfer: ...@@ -130,7 +256,16 @@ class HunyuanTransformerInfer:
img = img + out img = img + out
return img return img
def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate): def infer_double_block_txt_post_atten(
self,
weights,
txt,
txt_attn,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
):
out = weights.txt_attn_proj.apply(txt_attn) out = weights.txt_attn_proj.apply(txt_attn)
out = out * txt_mod1_gate out = out * txt_mod1_gate
txt = txt + out txt = txt + out
...@@ -144,7 +279,9 @@ class HunyuanTransformerInfer: ...@@ -144,7 +279,9 @@ class HunyuanTransformerInfer:
txt = txt + out txt = txt + out
return txt return txt
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis): def infer_single_block(
self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
):
out = torch.nn.functional.silu(vec) out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out) out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1) mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
...@@ -154,7 +291,9 @@ class HunyuanTransformerInfer: ...@@ -154,7 +291,9 @@ class HunyuanTransformerInfer:
x_mod = weights.linear1.apply(x_mod) x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(
x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num) q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
......
...@@ -17,9 +17,10 @@ class HunyuanModel: ...@@ -17,9 +17,10 @@ class HunyuanModel:
post_weight_class = HunyuanPostWeights post_weight_class = HunyuanPostWeights
transformer_weight_class = HunyuanTransformerWeights transformer_weight_class = HunyuanTransformerWeights
def __init__(self, model_path, config): def __init__(self, model_path, config, device):
self.model_path = model_path self.model_path = model_path
self.config = config self.config = config
self.device = device
self._init_infer_class() self._init_infer_class()
self._init_weights() self._init_weights()
self._init_infer() self._init_infer()
...@@ -47,7 +48,7 @@ class HunyuanModel: ...@@ -47,7 +48,7 @@ class HunyuanModel:
def _load_ckpt(self): def _load_ckpt(self):
ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt") ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
weight_dict = torch.load(ckpt_path, map_location="cuda", weights_only=True)["module"] weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
return weight_dict return weight_dict
def _init_weights(self): def _init_weights(self):
...@@ -82,6 +83,9 @@ class HunyuanModel: ...@@ -82,6 +83,9 @@ class HunyuanModel:
@torch.no_grad() @torch.no_grad()
def infer(self, text_encoder_output, image_encoder_output, args): def infer(self, text_encoder_output, image_encoder_output, args):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
pre_infer_out = self.pre_infer.infer( pre_infer_out = self.pre_infer.infer(
self.pre_weight, self.pre_weight,
self.scheduler.latents, self.scheduler.latents,
...@@ -95,3 +99,6 @@ class HunyuanModel: ...@@ -95,3 +99,6 @@ class HunyuanModel:
) )
img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out) img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape) self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
\ No newline at end of file
...@@ -80,20 +80,30 @@ class HunyuanTransformerDoubleBlock: ...@@ -80,20 +80,30 @@ class HunyuanTransformerDoubleBlock:
] ]
for mm_weight in self.weight_list: for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate): if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"]) mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) mm_weight.load(weight_dict)
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate): if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu() mm_weight.to_cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate): if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda() mm_weight.to_cuda()
def to_cpu_sync(self):
    """Issue non-blocking device-to-host copies for all managed weights.

    NOTE(review): despite the ``_sync`` suffix, the transfer is issued with
    ``non_blocking=True`` and may still be in flight on return — callers are
    expected to synchronize on the relevant stream.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cpu(non_blocking=True)
def to_cuda_sync(self):
    """Issue non-blocking host-to-device copies for all managed weights.

    NOTE(review): despite the ``_sync`` suffix, the transfer is issued with
    ``non_blocking=True`` and may still be in flight on return — callers are
    expected to synchronize on the relevant stream.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cuda(non_blocking=True)
class HunyuanTransformerSingleBlock: class HunyuanTransformerSingleBlock:
def __init__(self, block_index, config): def __init__(self, block_index, config):
...@@ -122,16 +132,27 @@ class HunyuanTransformerSingleBlock: ...@@ -122,16 +132,27 @@ class HunyuanTransformerSingleBlock:
] ]
for mm_weight in self.weight_list: for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate): if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"]) mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) mm_weight.load(weight_dict)
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate): if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cpu() mm_weight.to_cpu()
def to_cuda(self): def to_cuda(self):
for mm_weight in self.weight_list: for mm_weight in self.weight_list:
if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate): if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
mm_weight.to_cuda() mm_weight.to_cuda()
def to_cpu_sync(self):
    """Issue non-blocking device-to-host copies for all managed weights.

    NOTE(review): despite the ``_sync`` suffix, the transfer is issued with
    ``non_blocking=True`` and may still be in flight on return — callers are
    expected to synchronize on the relevant stream.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cpu(non_blocking=True)
def to_cuda_sync(self):
    """Issue non-blocking host-to-device copies for all managed weights.

    NOTE(review): despite the ``_sync`` suffix, the transfer is issued with
    ``non_blocking=True`` and may still be in flight on return — callers are
    expected to synchronize on the relevant stream.
    """
    for weight in self.weight_list:
        if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
            weight.to_cuda(non_blocking=True)
...@@ -87,9 +87,6 @@ class WanTransformerAttentionBlock: ...@@ -87,9 +87,6 @@ class WanTransformerAttentionBlock:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)): if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
mm_weight.set_config(self.config["mm_config"]) mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict) mm_weight.load(weight_dict)
if self.config["cpu_offload"]:
mm_weight.to_cpu()
self.modulation = self.modulation.cpu()
def to_cpu(self): def to_cpu(self):
for mm_weight in self.weight_list: for mm_weight in self.weight_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment