LightX2V, commit 69c2f650 (Unverified)
Authored Oct 01, 2025 by Yang Yong (雍洋); committed by GitHub on Oct 01, 2025
Parent: 08d2f46a

Remove outdated models (#348)

Changes: 59 files
Showing 20 changed files with 0 additions and 1781 deletions (+0 -1781):

- lightx2v/models/networks/cogvideox/weights/post_weights.py (+0 -31)
- lightx2v/models/networks/cogvideox/weights/pre_weights.py (+0 -30)
- lightx2v/models/networks/cogvideox/weights/transformers_weights.py (+0 -77)
- lightx2v/models/networks/hunyuan/__init__.py (+0 -0)
- lightx2v/models/networks/hunyuan/infer/__init__.py (+0 -0)
- lightx2v/models/networks/hunyuan/infer/feature_caching/__init__.py (+0 -0)
- lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py (+0 -604)
- lightx2v/models/networks/hunyuan/infer/feature_caching/utils.py (+0 -51)
- lightx2v/models/networks/hunyuan/infer/post_infer.py (+0 -33)
- lightx2v/models/networks/hunyuan/infer/pre_infer.py (+0 -157)
- lightx2v/models/networks/hunyuan/infer/transformer_infer.py (+0 -385)
- lightx2v/models/networks/hunyuan/infer/utils.py (+0 -9)
- lightx2v/models/networks/hunyuan/infer/utils_bf16.py (+0 -32)
- lightx2v/models/networks/hunyuan/infer/utils_fp32.py (+0 -36)
- lightx2v/models/networks/hunyuan/model.py (+0 -170)
- lightx2v/models/networks/hunyuan/weights/__init__.py (+0 -0)
- lightx2v/models/networks/hunyuan/weights/post_weights.py (+0 -11)
- lightx2v/models/networks/hunyuan/weights/pre_weights.py (+0 -84)
- lightx2v/models/networks/hunyuan/weights/transformer_weights.py (+0 -71)
- lightx2v/models/runners/cogvideox/__init__.py (+0 -0)
lightx2v/models/networks/cogvideox/weights/post_weights.py (deleted, 100644 → 0)

```python
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER


class CogvideoxPostWeights:
    def __init__(self, config, mm_type="Default"):
        self.config = config
        self.mm_type = mm_type

    def load_weights(self, weight_dict):
        self.norm_out_linear = MM_WEIGHT_REGISTER[self.mm_type]("norm_out.linear.weight", "norm_out.linear.bias")
        self.proj_out = MM_WEIGHT_REGISTER[self.mm_type]("proj_out.weight", "proj_out.bias")
        self.norm_final = LN_WEIGHT_REGISTER[self.mm_type]("norm_final.weight", "norm_final.bias")
        self.norm_out_norm = LN_WEIGHT_REGISTER[self.mm_type]("norm_out.norm.weight", "norm_out.norm.bias", eps=1e-5)
        self.weight_list = [self.norm_out_linear, self.proj_out, self.norm_final, self.norm_out_norm]
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cuda()
```
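These removed weight classes all share the same lifecycle: construct with a config, bind checkpoint tensors via load_weights(), then shuttle between host and device. A hypothetical caller sketch; `config` and `weight_dict` are stand-ins, not names from this diff:

```python
# Hypothetical usage of the weight-wrapper lifecycle; config and weight_dict
# stand in for the runner's config and a loaded checkpoint state dict.
post_weights = CogvideoxPostWeights(config, mm_type="Default")
post_weights.load_weights(weight_dict)  # binds the named tensors above
post_weights.to_cuda()                  # move all wrapped weights to GPU
# ... run inference ...
post_weights.to_cpu()                   # offload back to host memory
```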
lightx2v/models/networks/cogvideox/weights/pre_weights.py (deleted, 100644 → 0)

```python
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER


class CogvideoxPreWeights:
    def __init__(self, config):
        self.config = config

    def load_weights(self, weight_dict):
        self.time_embedding_linear_1 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_1.weight", "time_embedding.linear_1.bias")
        self.time_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_2.weight", "time_embedding.linear_2.bias")
        self.patch_embed_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.proj.weight", "patch_embed.proj.bias")
        self.patch_embed_text_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.text_proj.weight", "patch_embed.text_proj.bias")
        self.weight_list = [self.time_embedding_linear_1, self.time_embedding_linear_2, self.patch_embed_proj, self.patch_embed_text_proj]
        for mm_weight in self.weight_list:
            mm_weight.set_config(self.config)
            mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cuda()
```
lightx2v/models/networks/cogvideox/weights/transformers_weights.py (deleted, 100644 → 0)

```python
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER


class CogvideoxTransformerWeights:
    def __init__(self, config, task="t2v", mm_type="Default"):
        self.config = config
        self.task = task
        self.mm_type = mm_type
        self.init()

    def init(self):
        self.num_layers = self.config["num_layers"]

    def load_weights(self, weight_dict):
        self.blocks_weights = [CogVideoXBlock(i, self.task, self.mm_type) for i in range(self.num_layers)]
        for block in self.blocks_weights:
            block.load_weights(weight_dict)

    def to_cpu(self):
        for block in self.blocks_weights:
            block.to_cpu()

    def to_cuda(self):
        for block in self.blocks_weights:
            block.to_cuda()


class CogVideoXBlock:
    def __init__(self, block_index, task="t2v", mm_type="Default"):
        super().__init__()
        self.block_index = block_index
        self.mm_type = mm_type
        self.task = task

    def load_weights(self, weight_dict):
        self.attn1_to_k = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_k.weight", f"transformer_blocks.{self.block_index}.attn1.to_k.bias")
        self.attn1_to_q = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_q.weight", f"transformer_blocks.{self.block_index}.attn1.to_q.bias")
        self.attn1_to_v = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_v.weight", f"transformer_blocks.{self.block_index}.attn1.to_v.bias")
        self.attn1_to_out = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_out.0.weight", f"transformer_blocks.{self.block_index}.attn1.to_out.0.bias")
        self.ff_net_0_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.0.proj.weight", f"transformer_blocks.{self.block_index}.ff.net.0.proj.bias")
        self.ff_net_2_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.2.weight", f"transformer_blocks.{self.block_index}.ff.net.2.bias")
        self.norm1_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.linear.weight", f"transformer_blocks.{self.block_index}.norm1.linear.bias")
        self.norm2_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.linear.weight", f"transformer_blocks.{self.block_index}.norm2.linear.bias")
        self.attn1_norm_k = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_k.weight", f"transformer_blocks.{self.block_index}.attn1.norm_k.bias")
        self.attn1_norm_q = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_q.weight", f"transformer_blocks.{self.block_index}.attn1.norm_q.bias")
        self.norm1_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.norm.weight", f"transformer_blocks.{self.block_index}.norm1.norm.bias", eps=1e-05)
        self.norm2_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.norm.weight", f"transformer_blocks.{self.block_index}.norm2.norm.bias", eps=1e-05)
        self.weight_list = [
            self.attn1_to_k,
            self.attn1_to_q,
            self.attn1_to_v,
            self.attn1_to_out,
            self.ff_net_0_proj,
            self.ff_net_2_proj,
            self.norm1_linear,
            self.norm2_linear,
            self.attn1_norm_k,
            self.attn1_norm_q,
            self.norm1_norm,
            self.norm2_norm,
        ]
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cuda()
```
lightx2v/models/networks/hunyuan/__init__.py (deleted, 100755 → 0; empty file)

lightx2v/models/networks/hunyuan/infer/__init__.py (deleted, 100755 → 0; empty file)

lightx2v/models/networks/hunyuan/infer/feature_caching/__init__.py (deleted, 100755 → 0; empty file)
lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py (deleted, 100755 → 0). This diff is collapsed; its 604 deleted lines are not shown.
lightx2v/models/networks/hunyuan/infer/feature_caching/utils.py (deleted, 100755 → 0)

```python
import math
from typing import Dict

import torch


def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Initialize Taylor cache, expanding storage areas for Taylor series derivatives
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    if current["step"] == 0:
        cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = {}


def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Compute derivative approximation
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    difference_distance = current["activated_steps"][-1] - current["activated_steps"][-2]
    # difference_distance = current['activated_times'][-1] - current['activated_times'][-2]

    updated_taylor_factors = {}
    updated_taylor_factors[0] = feature

    for i in range(cache_dic["max_order"]):
        if (cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]].get(i, None) is not None) and (current["step"] > cache_dic["first_enhance"] - 2):
            updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i]) / difference_distance
        else:
            break

    cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = updated_taylor_factors


def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Compute Taylor expansion error
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    x = current["step"] - current["activated_steps"][-1]
    # x = current['t'] - current['activated_times'][-1]
    output = 0
    for i in range(len(cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]])):
        output += (1 / math.factorial(i)) * cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i] * (x**i)
    return output
```
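These helpers implement a truncated Taylor extrapolation of block features around the last fully computed ("activated") step. A minimal driver sketch of the call order, using the three functions above; the cache layout keys mirror the code, but the refresh schedule, sizes, and compute_feature() stand-in are illustrative assumptions:

```python
import torch

# Illustrative driver for the Taylor-caching helpers defined above.
def compute_feature(step):
    # Stand-in for a real transformer-block output.
    return torch.full((4,), float(step) ** 2)

cache_dic = {"cache": [{"double_stream": {0: {"attn": {}}}}],
             "max_order": 1, "first_enhance": 0}
current = {"stream": "double_stream", "layer": 0, "module": "attn",
           "activated_steps": []}

for step in range(8):
    current["step"] = step
    if step % 4 == 0:                       # refresh steps: run the module
        current["activated_steps"].append(step)
        taylor_cache_init(cache_dic, current)
        feat = compute_feature(step)
        if len(current["activated_steps"]) >= 2:
            derivative_approximation(cache_dic, current, feat)
        else:                               # first refresh: seed order 0 only
            cache_dic["cache"][-1]["double_stream"][0]["attn"][0] = feat
    else:                                   # cached steps: Taylor extrapolation
        feat = taylor_formula(cache_dic, current)
```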
lightx2v/models/networks/hunyuan/infer/post_infer.py (deleted, 100755 → 0)

```python
import torch


class HunyuanPostInfer:
    def __init__(self, config):
        self.config = config

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, img, vec):
        out = torch.nn.functional.silu(vec)
        out = weights.final_layer_adaLN_modulation_1.apply(out)
        shift, scale = out.chunk(2, dim=1)
        out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        out = out * (1 + scale) + shift
        out = weights.final_layer_linear.apply(out.to(torch.float32))

        _, _, ot, oh, ow = self.scheduler.latents.shape
        patch_size = [1, 2, 2]
        tt, th, tw = (
            ot // patch_size[0],
            oh // patch_size[1],
            ow // patch_size[2],
        )
        c = 16
        pt, ph, pw = patch_size
        out = out.reshape(shape=(1, tt, th, tw, c, pt, ph, pw))
        out = torch.einsum("nthwcopq->nctohpwq", out)
        out = out.reshape(shape=(1, c, tt * pt, th * ph, tw * pw))
        return out
```
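The tail of infer() is an unpatchify: per-token outputs are folded back into a (1, c, T, H, W) latent video. A quick shape check of the same einsum pattern, with arbitrary illustrative sizes:

```python
import torch

# Shape check for the unpatchify einsum above; sizes are illustrative.
tt, th, tw = 3, 4, 5             # token grid (time, height, width)
c, (pt, ph, pw) = 16, (1, 2, 2)  # channels and patch size

tokens = torch.randn(1, tt, th, tw, c, pt, ph, pw)
video = torch.einsum("nthwcopq->nctohpwq", tokens)
video = video.reshape(1, c, tt * pt, th * ph, tw * pw)
assert video.shape == (1, 16, 3, 8, 10)
```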
lightx2v/models/networks/hunyuan/infer/pre_infer.py (deleted, 100755 → 0)

```python
import math

import torch
from einops import rearrange

from lightx2v.utils.envs import *


class HunyuanPreInfer:
    def __init__(self, config):
        self.heads_num = 24
        self.config = config

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, inputs):
        x = self.scheduler.latents
        t = self.scheduler.timesteps[self.scheduler.step_index]
        freqs_cos = self.scheduler.freqs_cos
        freqs_sin = self.scheduler.freqs_sin
        guidance = self.scheduler.guidance
        text_states = inputs["text_encoder_output"]["text_encoder_1_text_states"]
        text_mask = inputs["text_encoder_output"]["text_encoder_1_attention_mask"]
        text_states_2 = inputs["text_encoder_output"]["text_encoder_2_text_states"]

        if self.config["task"] == "i2v":
            token_replace_t = torch.zeros_like(t)
            token_replace_vec = self.infer_time_in(weights, token_replace_t)
            th = x.shape[-2] // 2
            tw = x.shape[-1] // 2
            frist_frame_token_num = th * tw

        time_out = self.infer_time_in(weights, t)
        img_out = self.infer_img_in(weights, x)
        infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
        infer_vector_out = self.infer_vector_in(weights, text_states_2)
        vec = time_out + infer_vector_out
        if self.config["task"] == "i2v":
            token_replace_vec = token_replace_vec + infer_vector_out
        guidance_out = self.infer_guidance_in(weights, guidance)
        vec = vec + guidance_out

        txt_seq_len = infer_text_out.shape[0]
        img_seq_len = img_out.shape[1]

        batch_size = text_mask.shape[0]
        text_len = text_mask.sum(dim=1)
        max_len = text_mask.shape[1] + img_seq_len
        cu_seqlens_qkv = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
        for i in range(batch_size):
            s = text_len[i] + img_seq_len
            s1 = i * max_len + s
            s2 = (i + 1) * max_len
            cu_seqlens_qkv[2 * i + 1] = s1
            cu_seqlens_qkv[2 * i + 2] = s2
        max_seqlen_qkv = img_seq_len + txt_seq_len

        if self.config["task"] == "i2v":
            return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
        return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)

    def infer_time_in(self, weights, t):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
        out = weights.time_in_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        out = weights.time_in_mlp_2.apply(out)
        return out

    def infer_img_in(self, weights, x):
        out = weights.img_in_proj.apply(x)
        out = out.flatten(2).transpose(1, 2)
        return out

    def infer_text_in(self, weights, text_states, text_mask, t):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
        out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)

        mask_float = text_mask.float().unsqueeze(-1).to(GET_DTYPE())  # [b, s1, 1]
        context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
        out = weights.txt_in_c_embedder_linear_1.apply(context_aware_representations)
        out = torch.nn.functional.silu(out)
        context_aware_representations = weights.txt_in_c_embedder_linear_2.apply(out)
        c = timestep_aware_representations + context_aware_representations

        txt_in_input_embed = weights.txt_in_input_embedder.apply(text_states[0])

        batch_size = text_mask.shape[0]
        seq_len = text_mask.shape[1]
        self_attn_mask_1 = text_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
        self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
        self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
        self_attn_mask[:, :, :, 0] = True

        cx = torch.nn.functional.silu(c)
        cx = weights.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1.apply(cx)
        gate_msa, gate_mlp = cx.chunk(2, dim=1)
        normx = weights.txt_in_individual_token_refiner_blocks_0_norm1.apply(txt_in_input_embed)
        qkv = weights.txt_in_individual_token_refiner_blocks_0_self_attn_qkv.apply(normx)
        q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
        out = weights.txt_in_individual_token_refiner_blocks_0_self_attn_proj.apply(attn)
        out_1 = txt_in_input_embed + out * gate_msa
        out = weights.txt_in_individual_token_refiner_blocks_0_norm2.apply(out_1)
        # mlp
        out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc1.apply(out)
        out = torch.nn.functional.silu(out)
        out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc2.apply(out)
        txt_in_input_embed = out_1 + out * gate_mlp

        cx = torch.nn.functional.silu(c)
        cx = weights.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1.apply(cx)
        gate_msa, gate_mlp = cx.chunk(2, dim=1)
        normx = weights.txt_in_individual_token_refiner_blocks_1_norm1.apply(txt_in_input_embed)
        qkv = weights.txt_in_individual_token_refiner_blocks_1_self_attn_qkv.apply(normx)
        q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
        out = weights.txt_in_individual_token_refiner_blocks_1_self_attn_proj.apply(attn)
        out_1 = txt_in_input_embed + out * gate_msa
        out = weights.txt_in_individual_token_refiner_blocks_1_norm2.apply(out_1)
        # mlp
        out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc1.apply(out)
        out = torch.nn.functional.silu(out)
        out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc2.apply(out)
        out = out_1 + out * gate_mlp
        return out

    def infer_vector_in(self, weights, text_states_2):
        out = weights.vector_in_in_layer.apply(text_states_2)
        out = torch.nn.functional.silu(out)
        out = weights.vector_in_out_layer.apply(out)
        return out

    def infer_guidance_in(self, weights, guidance):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
        args = guidance.float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
        out = weights.guidance_in_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        out = weights.guidance_in_mlp_2.apply(out)
        return out
```
lightx2v/models/networks/hunyuan/infer/transformer_infer.py (deleted, 100755 → 0)

```python
import torch
from einops import rearrange

from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *

from .utils_bf16 import apply_rotary_emb


class HunyuanTransformerInfer(BaseTransformerInfer):
    def __init__(self, config):
        self.config = config
        self.attention_type = config.get("attention_type", "flash_attn2")
        self.double_blocks_num = 20
        self.single_blocks_num = 40
        self.heads_num = 24
        self.hidden_size = 3072
        self.mlp_hidden_dim = 12288
        self.parallel_attention = None
        if self.config["cpu_offload"]:
            if "offload_ratio" in self.config:
                offload_ratio = self.config["offload_ratio"]
            else:
                offload_ratio = 1
            self.double_weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.double_blocks_num, offload_ratio=offload_ratio)
            self.single_weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.single_blocks_num, offload_ratio=offload_ratio)
            self.infer_func = self._infer_with_offload
        else:
            self.infer_func = self._infer_without_offload

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

    def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for double_block_idx in range(self.double_blocks_num):
            if double_block_idx == 0:
                self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks[0]
                self.double_weights_stream_mgr.active_weights[0].to_cuda()
            with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
                img, txt = self.infer_double_block(self.double_weights_stream_mgr.active_weights[0], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            if double_block_idx < self.double_blocks_num - 1:
                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks)
            self.double_weights_stream_mgr.swap_weights()

        x = torch.cat((img, txt), 0)
        img = img.cpu()
        txt = txt.cpu()
        del img, txt
        torch.cuda.empty_cache()

        for single_block_idx in range(self.single_blocks_num):
            if single_block_idx == 0:
                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks[0]
                self.single_weights_stream_mgr.active_weights[0].to_cuda()
            with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
                x = self.infer_single_block(
                    self.single_weights_stream_mgr.active_weights[0],
                    x,
                    vec,
                    txt_seq_len,
                    cu_seqlens_qkv,
                    max_seqlen_qkv,
                    freqs_cis,
                    token_replace_vec,
                    frist_frame_token_num,
                )
            if single_block_idx < self.single_blocks_num - 1:
                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks)
            self.single_weights_stream_mgr.swap_weights()
        torch.cuda.empty_cache()

        img = x[:img_seq_len, ...]
        return img, vec

    def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for i in range(self.double_blocks_num):
            img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        x = torch.cat((img, txt), 0)
        for i in range(self.single_blocks_num):
            x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        img = x[:img_seq_len, ...]
        return img, vec

    def infer_double_block_phase_1(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        vec_silu = torch.nn.functional.silu(vec)

        img_mod_out = weights.img_mod.apply(vec_silu)
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = img_mod_out.chunk(6, dim=-1)

        if token_replace_vec is not None:
            token_replace_vec_silu = torch.nn.functional.silu(token_replace_vec)
            token_replace_vec_img_mod_out = weights.img_mod.apply(token_replace_vec_silu)
            (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = token_replace_vec_img_mod_out.chunk(6, dim=-1)
        else:
            (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = None, None, None, None, None, None

        txt_mod_out = weights.txt_mod.apply(vec_silu)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = txt_mod_out.chunk(6, dim=-1)

        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis)
        txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)

        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)
        v = torch.cat((img_v, txt_v), dim=0)

        if not self.parallel_attention:
            attn = weights.double_attn.apply(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens_qkv,
                cu_seqlens_kv=cu_seqlens_qkv,
                max_seqlen_q=max_seqlen_qkv,
                max_seqlen_kv=max_seqlen_qkv,
            )
        else:
            # world_size = dist.get_world_size()
            attn = self.parallel_attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                img_qkv_len=img_q.shape[0],
                cu_seqlens_qkv=cu_seqlens_qkv,
                # cu_seqlens_qkv=cu_seqlens_qkv,
                # max_seqlen_qkv=max_seqlen_qkv,
            )

        img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]

        img_out = weights.img_attn_proj.apply(img_attn)
        txt_out = weights.txt_attn_proj.apply(txt_attn)
        return (
            img_out,
            txt_out,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
            tr_img_mod1_gate,
            tr_img_mod2_shift,
            tr_img_mod2_scale,
            tr_img_mod2_gate,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        )

    def infer_double_block_phase_2(
        self,
        weights,
        img,
        txt,
        vec,
        cu_seqlens_qkv,
        max_seqlen_qkv,
        freqs_cis,
        token_replace_vec,
        frist_frame_token_num,
        img_out,
        txt_out,
        img_mod1_gate,
        img_mod2_shift,
        img_mod2_scale,
        img_mod2_gate,
        tr_img_mod1_gate,
        tr_img_mod2_shift,
        tr_img_mod2_scale,
        tr_img_mod2_gate,
        txt_mod1_gate,
        txt_mod2_shift,
        txt_mod2_scale,
        txt_mod2_gate,
    ):
        if tr_img_mod1_gate is not None:
            x_zero = img_out[:frist_frame_token_num] * tr_img_mod1_gate
            x_orig = img_out[frist_frame_token_num:] * img_mod1_gate
            img_out = torch.concat((x_zero, x_orig), dim=0)
        else:
            img_out = img_out * img_mod1_gate
        img = img + img_out
        img_out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        if tr_img_mod1_gate is not None:
            x_zero = img_out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
            x_orig = img_out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
            img_out = torch.concat((x_zero, x_orig), dim=0)
        else:
            img_out = img_out * (1 + img_mod2_scale) + img_mod2_shift
        img_out = weights.img_mlp_fc1.apply(img_out)
        img_out = torch.nn.functional.gelu(img_out, approximate="tanh")
        img_out = weights.img_mlp_fc2.apply(img_out)

        txt_out = txt_out * txt_mod1_gate
        txt = txt + txt_out
        txt_out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
        txt_out = txt_out * (1 + txt_mod2_scale) + txt_mod2_shift
        txt_out = weights.txt_mlp_fc1.apply(txt_out)
        txt_out = torch.nn.functional.gelu(txt_out, approximate="tanh")
        txt_out = weights.txt_mlp_fc2.apply(txt_out)
        return img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate

    def infer_double_block_phase_3(self, img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt):
        # img
        img_out = img_out * img_mod2_gate
        img = img + img_out
        # txt
        txt_out = txt_out * txt_mod2_gate
        txt = txt + txt_out
        return img, txt

    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        (
            img_out,
            txt_out,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
            tr_img_mod1_gate,
            tr_img_mod2_shift,
            tr_img_mod2_scale,
            tr_img_mod2_gate,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.infer_double_block_phase_1(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
            weights,
            img,
            txt,
            vec,
            cu_seqlens_qkv,
            max_seqlen_qkv,
            freqs_cis,
            token_replace_vec,
            frist_frame_token_num,
            img_out,
            txt_out,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
            tr_img_mod1_gate,
            tr_img_mod2_shift,
            tr_img_mod2_scale,
            tr_img_mod2_gate,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        )
        img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
        return img, txt

    def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis):
        img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        if tr_img_mod1_scale is not None:
            x_zero = img_modulated[:frist_frame_token_num] * (1 + tr_img_mod1_scale) + tr_img_mod1_shift
            x_orig = img_modulated[frist_frame_token_num:] * (1 + img_mod1_scale) + img_mod1_shift
            img_modulated = torch.concat((x_zero, x_orig), dim=0)
        else:
            img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
        img_qkv = weights.img_attn_qkv.apply(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        img_q = weights.img_attn_q_norm.apply(img_q)
        img_k = weights.img_attn_k_norm.apply(img_k)
        img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
        return img_q, img_k, img_v

    def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
        txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
        txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
        txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        txt_q = weights.txt_attn_q_norm.apply(txt_q)
        txt_k = weights.txt_attn_k_norm.apply(txt_k)
        return txt_q, txt_k, txt_v

    def infer_single_block_phase_1(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        out = torch.nn.functional.silu(vec)
        out = weights.modulation.apply(out)
        mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)

        if token_replace_vec is not None:
            token_replace_vec_out = torch.nn.functional.silu(token_replace_vec)
            token_replace_vec_out = weights.modulation.apply(token_replace_vec_out)
            tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1)
        else:
            tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None

        out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        if token_replace_vec is not None:
            x_zero = out[:frist_frame_token_num] * (1 + tr_mod_scale) + tr_mod_shift
            x_orig = out[frist_frame_token_num:] * (1 + mod_scale) + mod_shift
            x_mod = torch.concat((x_zero, x_orig), dim=0)
        else:
            x_mod = out * (1 + mod_scale) + mod_shift
        x_mod = weights.linear1.apply(x_mod)

        qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        q = weights.q_norm.apply(q)
        k = weights.k_norm.apply(k)

        img_q, txt_q = q[:-txt_seq_len, :, :], q[-txt_seq_len:, :, :]
        img_k, txt_k = k[:-txt_seq_len, :, :], k[-txt_seq_len:, :, :]
        img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)

        if not self.parallel_attention:
            attn = weights.single_attn.apply(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens_qkv,
                cu_seqlens_kv=cu_seqlens_qkv,
                max_seqlen_q=max_seqlen_qkv,
                max_seqlen_kv=max_seqlen_qkv,
            )
        else:
            attn = self.parallel_attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                img_qkv_len=img_q.shape[0],
                cu_seqlens_qkv=cu_seqlens_qkv,
                # cu_seqlens_qkv=cu_seqlens_qkv,
                # max_seqlen_qkv=max_seqlen_qkv,
            )

        out = torch.nn.functional.gelu(mlp, approximate="tanh")
        out = torch.cat((attn, out), 1)
        out = weights.linear2.apply(out)
        return out, mod_gate, tr_mod_gate

    def infer_single_block_phase_2(self, x, out, tr_mod_gate, mod_gate, token_replace_vec=None, frist_frame_token_num=None):
        if token_replace_vec is not None:
            x_zero = out[:frist_frame_token_num] * tr_mod_gate
            x_orig = out[frist_frame_token_num:] * mod_gate
            out = torch.concat((x_zero, x_orig), dim=0)
        else:
            out = out * mod_gate
        x = x + out
        return x

    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
        return x
```
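The offload path above is a double-buffering pattern: block i runs on a compute stream while block i+1's weights are prefetched host-to-device on another stream, then the buffers swap. The internals of WeightAsyncStreamManager are not part of this diff; the following is a stripped-down, hypothetical sketch of the same idea (assumes a CUDA device; `blocks` stands in for per-block weight holders with to_cuda()/to_cpu() methods as used above):

```python
import torch

# Minimal double-buffering sketch, not the WeightAsyncStreamManager API.
compute_stream = torch.cuda.Stream()
load_stream = torch.cuda.Stream()

def run_blocks(blocks, x, forward_one):
    active = blocks[0]
    active.to_cuda()                        # block 0 loaded synchronously
    for i in range(len(blocks)):
        with torch.cuda.stream(compute_stream):
            x = forward_one(active, x)      # compute block i
        if i + 1 < len(blocks):
            with torch.cuda.stream(load_stream):
                blocks[i + 1].to_cuda()     # prefetch block i+1 weights
        torch.cuda.synchronize()            # join both streams
        if i + 1 < len(blocks):
            active.to_cpu()                 # evict the finished block
            active = blocks[i + 1]
    return x
```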
lightx2v/models/networks/hunyuan/infer/utils.py (deleted, 100644 → 0)

```python
import sgl_kernel


def rms_norm(x, weight, eps):
    x = x.contiguous()
    orig_shape = x.shape
    x = x.view(-1, orig_shape[-1])
    x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
    return x
```
lightx2v/models/networks/hunyuan/infer/utils_bf16.py (deleted, 100644 → 0)

```python
from typing import Tuple, Union

import torch


def rms_norm(x, weight, eps):
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    x = x * weight
    return x


def rotate_half(x, shape_0, shape_1):
    x_real, x_imag = x.reshape(shape_0, shape_1, -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(2)


def rotary_emb(x, shape_0, shape_1, cos, sin):
    x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
    return x_out


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    shape_0, shape_1, shape_2 = xq.shape
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq, shape_0, shape_1, cos, sin)
    xk_out = rotary_emb(xk, shape_0, shape_1, cos, sin)
    return xq_out, xk_out
```
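A quick smoke test of apply_rotary_emb above: it expects (seq_len, heads, head_dim) q/k tensors and a (cos, sin) tuple of per-position angles, with each angle repeated for both elements of a rotated pair. The frequency construction here is an illustrative assumption, not the scheduler's; since each pair undergoes a plain 2D rotation, the overall norm should be preserved:

```python
import torch

# Sanity sketch for apply_rotary_emb; sizes and frequencies are illustrative.
seq_len, heads, head_dim = 6, 2, 8
pos = torch.arange(seq_len, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = pos[:, None] * inv_freq[None, :]            # (seq_len, head_dim/2)
angles = torch.repeat_interleave(angles, 2, dim=-1)  # same angle per pair
freqs_cis = (angles.cos(), angles.sin())

xq = torch.randn(seq_len, heads, head_dim)
xk = torch.randn(seq_len, heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
assert xq_rot.shape == xq.shape
# Per-pair rotations preserve magnitude.
torch.testing.assert_close(xq_rot.norm(), xq.norm(), rtol=1e-4, atol=1e-4)
```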
lightx2v/models/networks/hunyuan/infer/utils_fp32.py (deleted, 100644 → 0)

```python
from typing import Tuple, Union

import torch

from lightx2v.utils.envs import *


def rms_norm(x, weight, eps):
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    x = x.to(GET_DTYPE())
    x = x * weight
    return x


def rotate_half(x, shape_0, shape_1):
    x_real, x_imag = x.float().reshape(shape_0, shape_1, -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(2)


def rotary_emb(x, shape_0, shape_1, cos, sin):
    x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
    return x_out.to(GET_DTYPE())


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    shape_0, shape_1, shape_2 = xq.shape
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq.float(), shape_0, shape_1, cos, sin)
    xk_out = rotary_emb(xk.float(), shape_0, shape_1, cos, sin)
    return xq_out, xk_out
```
lightx2v/models/networks/hunyuan/model.py (deleted, 100755 → 0)

```python
import json
import os

import torch
from loguru import logger
from safetensors import safe_open

from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
    HunyuanTransformerInferAdaCaching,
    HunyuanTransformerInferCustomCaching,
    HunyuanTransformerInferTaylorCaching,
    HunyuanTransformerInferTeaCaching,
)
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.utils.envs import *


class HunyuanModel:
    pre_weight_class = HunyuanPreWeights
    post_weight_class = HunyuanPostWeights
    transformer_weight_class = HunyuanTransformerWeights

    def __init__(self, model_path, config, device, args):
        self.model_path = model_path
        self.config = config
        self.device = device
        self.args = args
        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
        self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        if self.dit_quantized:
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
        self._init_infer_class()
        self._init_weights()
        self._init_infer()
        if self.config["cpu_offload"]:
            self.to_cpu()

    def _load_ckpt(self):
        if self.args.task == "t2v":
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
        else:
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt")
        weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
        return weight_dict

    def _load_quant_ckpt(self):
        ckpt_path = self.dit_quantized_ckpt
        logger.info(f"Loading quant dit model from {ckpt_path}")
        if ckpt_path.endswith(".pth"):
            logger.info(f"Loading {ckpt_path} as PyTorch model.")
            weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
        else:
            index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
            if not index_files:
                raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
            index_path = os.path.join(ckpt_path, index_files[0])
            logger.info(f" Using safetensors index: {index_path}")
            with open(index_path, "r") as f:
                index_data = json.load(f)
            weight_dict = {}
            for filename in set(index_data["weight_map"].values()):
                safetensor_path = os.path.join(ckpt_path, filename)
                with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
                    logger.info(f"Loading weights from {safetensor_path}")
                    for k in f.keys():
                        weight_dict[k] = f.get_tensor(k)
                        if weight_dict[k].dtype == torch.float:
                            weight_dict[k] = weight_dict[k].to(GET_DTYPE())
        return weight_dict

    def _init_weights(self):
        if not self.dit_quantized or self.weight_auto_quant:
            weight_dict = self._load_ckpt()
        else:
            weight_dict = self._load_quant_ckpt()
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(weight_dict)
        self.post_weight.load(weight_dict)
        self.transformer_weights.load(weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def save_weights(self, save_path):
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        pre_state_dict = self.pre_weight.state_dict()
        logger.info(pre_state_dict.keys())
        post_state_dict = self.post_weight.state_dict()
        logger.info(post_state_dict.keys())
        transformer_state_dict = self.transformer_weights.state_dict()
        logger.info(transformer_state_dict.keys())
        save_dict = {}
        save_dict.update(pre_state_dict)
        save_dict.update(post_state_dict)
        save_dict.update(transformer_state_dict)
        save_path = os.path.join(save_path, "quant_weights.pth")
        torch.save(save_dict, save_path)
        logger.info(f"Save weights to {save_path}")

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()
        inputs = self.pre_infer.infer(self.pre_weight, inputs)
        inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)
        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()

    def _init_infer_class(self):
        self.pre_infer_class = HunyuanPreInfer
        self.post_infer_class = HunyuanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = HunyuanTransformerInfer
        elif self.config["feature_caching"] == "TaylorSeer":
            self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = HunyuanTransformerInferTeaCaching
        elif self.config["feature_caching"] == "Ada":
            self.transformer_infer_class = HunyuanTransformerInferAdaCaching
        elif self.config["feature_caching"] == "Custom":
            self.transformer_infer_class = HunyuanTransformerInferCustomCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
```
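For reference, _init_infer_class dispatches on config["feature_caching"]. A minimal config fragment that would route to the TaylorSeer cached inferencer; the key names follow the checks above, while the particular values are illustrative choices rather than shipped defaults:

```python
# Illustrative config fragment; keys match the checks in _init_infer_class.
config = {
    "feature_caching": "TaylorSeer",  # or "NoCaching", "Tea", "Ada", "Custom"
    "cpu_offload": False,
    "attention_type": "flash_attn2",
}
```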
lightx2v/models/networks/hunyuan/weights/__init__.py (deleted, 100755 → 0; empty file)
lightx2v/models/networks/hunyuan/weights/post_weights.py (deleted, 100755 → 0)

```python
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER


class HunyuanPostWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.add_module("final_layer_linear", MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias"))
        self.add_module("final_layer_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias"))
```
lightx2v/models/networks/hunyuan/weights/pre_weights.py (deleted, 100755 → 0)

```python
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER


class HunyuanPreWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.add_module("img_in_proj", CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2)))
        self.add_module("txt_in_input_embedder", MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias"))
        self.add_module("txt_in_t_embedder_mlp_0", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias"))
        self.add_module("txt_in_t_embedder_mlp_2", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias"))
        self.add_module("txt_in_c_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias"))
        self.add_module("txt_in_c_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias"))
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_norm1",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_self_attn_qkv",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_self_attn_proj",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_norm2",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_mlp_fc1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_mlp_fc2",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_norm1",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_self_attn_qkv",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_self_attn_proj",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_norm2",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_mlp_fc1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_mlp_fc2",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"),
        )
        self.add_module("time_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias"))
        self.add_module("time_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias"))
        self.add_module("vector_in_in_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias"))
        self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
        self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
        self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))
        # attention weights section
        self.add_module("txt_in_attn_1", ATTN_WEIGHT_REGISTER["torch_sdpa"]())
```
lightx2v/models/networks/hunyuan/weights/transformer_weights.py (deleted, 100755 → 0)

```python
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER


class HunyuanTransformerWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.double_blocks_num = 20
        self.single_blocks_num = 40
        self.add_module("double_blocks", WeightModuleList([HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]))
        self.add_module("single_blocks", WeightModuleList([HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]))


class HunyuanTransformerDoubleBlock(WeightModule):
    def __init__(self, block_index, config):
        super().__init__()
        self.block_index = block_index
        self.config = config
        if self.config["do_mm_calib"]:
            mm_type = "Calib"
        else:
            mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"

        self.add_module("img_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias"))
        self.add_module("img_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias"))
        self.add_module("img_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6))
        self.add_module("img_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6))
        self.add_module("img_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias"))
        self.add_module("img_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias"))
        self.add_module("img_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias"))
        self.add_module("txt_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias"))
        self.add_module("txt_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias"))
        self.add_module("txt_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6))
        self.add_module("txt_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6))
        self.add_module("txt_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias"))
        self.add_module("txt_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias"))
        self.add_module("txt_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias"))
        # attention weights section
        self.add_module("double_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())


class HunyuanTransformerSingleBlock(WeightModule):
    def __init__(self, block_index, config):
        super().__init__()
        self.block_index = block_index
        self.config = config
        self.sparge = config.get("sparge", False)
        if self.config["do_mm_calib"]:
            mm_type = "Calib"
        else:
            mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"

        self.add_module("linear1", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias"))
        self.add_module("linear2", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias"))
        self.add_module("q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6))
        self.add_module("k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6))
        self.add_module("modulation", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias"))
        # attention weights section
        if self.sparge:
            # load sparge attention weights
            #! todo
            pass
        else:
            self.add_module("single_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
```
lightx2v/models/runners/cogvideox/__init__.py (deleted, 100644 → 0; empty file)