Commit 91c5dd15 authored by gushiqiao's avatar gushiqiao Committed by Yang Yong(雍洋)
Browse files

Support cpu offload for hunyuan

parent 2b0139fe
......@@ -28,17 +28,13 @@ class HunyuanTransformerInfer:
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
def _infer_with_offload(
self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
):
def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0]
for double_block_idx in range(self.double_blocks_num):
if double_block_idx == 0:
self.double_weights_stream_mgr.active_weights[
0
] = weights.double_blocks_weights[0]
self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks_weights[0]
self.double_weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
......@@ -53,18 +49,14 @@ class HunyuanTransformerInfer:
)
if double_block_idx < self.double_blocks_num - 1:
self.double_weights_stream_mgr.prefetch_weights(
double_block_idx + 1, weights.double_blocks_weights
)
self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks_weights)
self.double_weights_stream_mgr.swap_weights()
x = torch.cat((img, txt), 0)
for single_block_idx in range(self.single_blocks_num):
if single_block_idx == 0:
self.single_weights_stream_mgr.active_weights[
0
] = weights.single_blocks_weights[0]
self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
self.single_weights_stream_mgr.active_weights[0].to_cuda()
with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
x = self.infer_single_block(
......@@ -77,9 +69,7 @@ class HunyuanTransformerInfer:
freqs_cis,
)
if single_block_idx < self.single_blocks_num - 1:
self.single_weights_stream_mgr.prefetch_weights(
single_block_idx + 1, weights.single_blocks_weights
)
self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks_weights)
self.single_weights_stream_mgr.swap_weights()
img = x[:img_seq_len, ...]
......@@ -116,9 +106,7 @@ class HunyuanTransformerInfer:
img = x[:img_seq_len, ...]
return img, vec
def infer_double_block(
self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
):
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu)
......@@ -141,12 +129,8 @@ class HunyuanTransformerInfer:
txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(
weights, img, img_mod1_scale, img_mod1_shift, freqs_cis
)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(
weights, txt, txt_mod1_scale, txt_mod1_shift
)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
......@@ -197,18 +181,12 @@ class HunyuanTransformerInfer:
)
return img, txt
def infer_double_block_img_pre_atten(
self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis
):
img_modulated = torch.nn.functional.layer_norm(
img, (img.shape[1],), None, None, 1e-6
)
def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis):
img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
img_qkv = weights.img_attn_qkv.apply(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
img_q = weights.img_attn_q_norm.apply(img_q)
img_k = weights.img_attn_k_norm.apply(img_k)
......@@ -216,18 +194,12 @@ class HunyuanTransformerInfer:
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
return img_q, img_k, img_v
def infer_double_block_txt_pre_atten(
self, weights, txt, txt_mod1_scale, txt_mod1_shift
):
txt_modulated = torch.nn.functional.layer_norm(
txt, (txt.shape[1],), None, None, 1e-6
)
def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
txt_q = weights.txt_attn_q_norm.apply(txt_q)
txt_k = weights.txt_attn_k_norm.apply(txt_k)
......@@ -279,9 +251,7 @@ class HunyuanTransformerInfer:
txt = txt + out
return txt
def infer_single_block(
self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
):
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
......@@ -291,9 +261,7 @@ class HunyuanTransformerInfer:
x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(
x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
......
......@@ -101,4 +101,4 @@ class HunyuanModel:
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
\ No newline at end of file
self.post_weight.to_cpu()
......@@ -136,7 +136,6 @@ class HunyuanTransformerSingleBlock:
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)
def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment