"test/vscode:/vscode.git/clone" did not exist on "17de02f98d8f28e5affec7c5ff8e28f110d0af42"
Commit 91c5dd15 authored by gushiqiao, committed by Yang Yong (雍洋)

Support cpu offload for hunyuan

parent 2b0139fe
@@ -28,17 +28,13 @@ class HunyuanTransformerInfer:
     def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
         return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)

-    def _infer_with_offload(
-        self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
-    ):
+    def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
         txt_seq_len = txt.shape[0]
         img_seq_len = img.shape[0]
         for double_block_idx in range(self.double_blocks_num):
             if double_block_idx == 0:
-                self.double_weights_stream_mgr.active_weights[
-                    0
-                ] = weights.double_blocks_weights[0]
+                self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks_weights[0]
                 self.double_weights_stream_mgr.active_weights[0].to_cuda()
             with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):

@@ -53,18 +49,14 @@ class HunyuanTransformerInfer:
                 )
             if double_block_idx < self.double_blocks_num - 1:
-                self.double_weights_stream_mgr.prefetch_weights(
-                    double_block_idx + 1, weights.double_blocks_weights
-                )
+                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks_weights)
             self.double_weights_stream_mgr.swap_weights()

         x = torch.cat((img, txt), 0)

         for single_block_idx in range(self.single_blocks_num):
             if single_block_idx == 0:
-                self.single_weights_stream_mgr.active_weights[
-                    0
-                ] = weights.single_blocks_weights[0]
+                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
                 self.single_weights_stream_mgr.active_weights[0].to_cuda()
             with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
                 x = self.infer_single_block(

@@ -77,9 +69,7 @@ class HunyuanTransformerInfer:
                     freqs_cis,
                 )
             if single_block_idx < self.single_blocks_num - 1:
-                self.single_weights_stream_mgr.prefetch_weights(
-                    single_block_idx + 1, weights.single_blocks_weights
-                )
+                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks_weights)
             self.single_weights_stream_mgr.swap_weights()

         img = x[:img_seq_len, ...]
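The offload path above keeps only two blocks' weights on the GPU at a time: while block i runs on the compute stream, block i + 1's weights are prefetched, and swap_weights promotes the prefetched slot once both streams are done. The stream manager's internals are not part of this diff; the sketch below is a minimal, hypothetical version of that ping-pong pattern (the class name and attribute names mirror the call sites above, but the side load stream, synchronization, and eviction policy are assumptions, not the repository's actual implementation).

import torch

class _BlockWeights:
    """Stand-in for a per-block weight container exposing to_cuda()/to_cpu(), as the diff uses."""
    def __init__(self, tensors):
        self.tensors = tensors

    def to_cuda(self):
        self.tensors = [t.to("cuda", non_blocking=True) for t in self.tensors]

    def to_cpu(self):
        self.tensors = [t.to("cpu") for t in self.tensors]

class WeightStreamManager:
    """Two-slot ping-pong buffer: slot 0 holds the block being computed, slot 1 the prefetched one."""
    def __init__(self):
        self.active_weights = [None, None]
        self.compute_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

    def prefetch_weights(self, block_idx, blocks_weights):
        # Copy the next block's weights host-to-device on the side stream so the
        # transfer overlaps with the current block's compute.
        with torch.cuda.stream(self.load_stream):
            blocks_weights[block_idx].to_cuda()
            self.active_weights[1] = blocks_weights[block_idx]

    def swap_weights(self):
        # Wait for both the prefetch and the current block's compute, evict the
        # finished block back to CPU, then promote the prefetched slot.
        self.load_stream.synchronize()
        self.compute_stream.synchronize()
        if self.active_weights[1] is not None:
            if self.active_weights[0] is not None:
                self.active_weights[0].to_cpu()
            self.active_weights[0], self.active_weights[1] = self.active_weights[1], None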
@@ -116,9 +106,7 @@ class HunyuanTransformerInfer:
         img = x[:img_seq_len, ...]
         return img, vec

-    def infer_double_block(
-        self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
-    ):
+    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
         vec_silu = torch.nn.functional.silu(vec)
         img_mod_out = weights.img_mod.apply(vec_silu)

@@ -141,12 +129,8 @@ class HunyuanTransformerInfer:
             txt_mod2_gate,
         ) = txt_mod_out.chunk(6, dim=-1)

-        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(
-            weights, img, img_mod1_scale, img_mod1_shift, freqs_cis
-        )
-        txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(
-            weights, txt, txt_mod1_scale, txt_mod1_shift
-        )
+        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
+        txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)

         q = torch.cat((img_q, txt_q), dim=0)
         k = torch.cat((img_k, txt_k), dim=0)

@@ -197,18 +181,12 @@ class HunyuanTransformerInfer:
         )
         return img, txt

-    def infer_double_block_img_pre_atten(
-        self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis
-    ):
-        img_modulated = torch.nn.functional.layer_norm(
-            img, (img.shape[1],), None, None, 1e-6
-        )
+    def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis):
+        img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
         img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
         img_qkv = weights.img_attn_qkv.apply(img_modulated)
-        img_q, img_k, img_v = rearrange(
-            img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
-        )
+        img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
         img_q = weights.img_attn_q_norm.apply(img_q)
         img_k = weights.img_attn_k_norm.apply(img_k)

@@ -216,18 +194,12 @@ class HunyuanTransformerInfer:
         img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
         return img_q, img_k, img_v

-    def infer_double_block_txt_pre_atten(
-        self, weights, txt, txt_mod1_scale, txt_mod1_shift
-    ):
-        txt_modulated = torch.nn.functional.layer_norm(
-            txt, (txt.shape[1],), None, None, 1e-6
-        )
+    def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
+        txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
         txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
         txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(
-            txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
-        )
+        txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
         txt_q = weights.txt_attn_q_norm.apply(txt_q)
         txt_k = weights.txt_attn_k_norm.apply(txt_k)
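Both pre-attention helpers share one pattern: a parameter-free layer norm, an adaLN-style (1 + scale) * x + shift modulation driven by vec, a fused QKV projection, and an einops rearrange that splits Q/K/V and unpacks heads before the per-head Q/K norms and rotary embedding. Below is a small shape sketch of that pattern with made-up dimensions (the random matrix stands in for the weights.*_attn_qkv projection; it is not the model's real weight).

import torch
from einops import rearrange

L, hidden, heads = 16, 128, 8                      # hypothetical sizes, not the model's
tokens = torch.randn(L, hidden)
shift, scale = torch.randn(hidden), torch.randn(hidden)

modulated = torch.nn.functional.layer_norm(tokens, (hidden,), None, None, 1e-6)
modulated = modulated * (1 + scale) + shift        # same modulation formula as img/txt above

qkv = modulated @ torch.randn(hidden, 3 * hidden)  # stand-in for the fused QKV projection
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=heads)
print(q.shape)                                     # torch.Size([16, 8, 16]): tokens x heads x head_dim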
@@ -279,9 +251,7 @@ class HunyuanTransformerInfer:
         txt = txt + out
         return txt

-    def infer_single_block(
-        self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis
-    ):
+    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
         out = torch.nn.functional.silu(vec)
         out = weights.modulation.apply(out)
         mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)

@@ -291,9 +261,7 @@ class HunyuanTransformerInfer:
         x_mod = weights.linear1.apply(x_mod)
-        qkv, mlp = torch.split(
-            x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
-        )
+        qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
         q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
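In the single-stream block, linear1 is one fused projection whose output carries both the attention QKV and the MLP hidden activations; torch.split then separates the two parts by width. A self-contained sketch with illustrative sizes (the real hidden_size and mlp_hidden_dim differ):

import torch

hidden_size, mlp_hidden_dim, L = 64, 256, 10       # made-up sizes for illustration
linear1 = torch.nn.Linear(hidden_size, 3 * hidden_size + mlp_hidden_dim)

x_mod = torch.randn(L, hidden_size)
out = linear1(x_mod)                               # one matmul for QKV + MLP
qkv, mlp = torch.split(out, [3 * hidden_size, mlp_hidden_dim], dim=-1)
print(qkv.shape, mlp.shape)                        # torch.Size([10, 192]) torch.Size([10, 256])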
...

@@ -136,7 +136,6 @@ class HunyuanTransformerSingleBlock:
             mm_weight.set_config(self.config["mm_config"])
             mm_weight.load(weight_dict)

     def to_cpu(self):
         for mm_weight in self.weight_list:
             if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
...
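The second hunk shows the weight-container side of the offload: each block's weight list can be moved wholesale with to_cpu(), the counterpart of the to_cuda() called from the prefetch path above. The repository's MMWeightTemplate / RMSWeightTemplate are not shown here; the following is a rough, hypothetical sketch of such an offloadable weight, where pinned host memory is an assumption that would let the host-to-device copy on the load stream actually overlap compute.

import torch

class OffloadableLinearWeight:
    """Hypothetical weight with the to_cpu()/to_cuda()/apply() surface seen in the diff."""
    def __init__(self, weight: torch.Tensor):
        # Keep the host copy pinned so .to("cuda", non_blocking=True) can run asynchronously.
        self.weight = weight.contiguous().pin_memory()

    def to_cuda(self):
        self.weight = self.weight.to("cuda", non_blocking=True)

    def to_cpu(self):
        self.weight = self.weight.to("cpu").pin_memory()

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for whatever fused or quantized matmul the real template dispatches to.
        return x @ self.weight.T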