xuwx1 / LightX2V · Commits

Commit 69c2f650 (unverified), authored Oct 01, 2025 by Yang Yong (雍洋), committed via GitHub on Oct 01, 2025
Parent: 08d2f46a

Remove outdated models (#348)

Changes: 59 files in this commit; this page shows 20 of them.
Showing 20 changed files with 0 additions and 1781 deletions (+0 -1781)

lightx2v/models/networks/cogvideox/weights/post_weights.py  +0 -31
lightx2v/models/networks/cogvideox/weights/pre_weights.py  +0 -30
lightx2v/models/networks/cogvideox/weights/transformers_weights.py  +0 -77
lightx2v/models/networks/hunyuan/__init__.py  +0 -0
lightx2v/models/networks/hunyuan/infer/__init__.py  +0 -0
lightx2v/models/networks/hunyuan/infer/feature_caching/__init__.py  +0 -0
lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py  +0 -604
lightx2v/models/networks/hunyuan/infer/feature_caching/utils.py  +0 -51
lightx2v/models/networks/hunyuan/infer/post_infer.py  +0 -33
lightx2v/models/networks/hunyuan/infer/pre_infer.py  +0 -157
lightx2v/models/networks/hunyuan/infer/transformer_infer.py  +0 -385
lightx2v/models/networks/hunyuan/infer/utils.py  +0 -9
lightx2v/models/networks/hunyuan/infer/utils_bf16.py  +0 -32
lightx2v/models/networks/hunyuan/infer/utils_fp32.py  +0 -36
lightx2v/models/networks/hunyuan/model.py  +0 -170
lightx2v/models/networks/hunyuan/weights/__init__.py  +0 -0
lightx2v/models/networks/hunyuan/weights/post_weights.py  +0 -11
lightx2v/models/networks/hunyuan/weights/pre_weights.py  +0 -84
lightx2v/models/networks/hunyuan/weights/transformer_weights.py  +0 -71
lightx2v/models/runners/cogvideox/__init__.py  +0 -0
lightx2v/models/networks/cogvideox/weights/post_weights.py
deleted 100644 → 0

from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER


class CogvideoxPostWeights:
    def __init__(self, config, mm_type="Default"):
        self.config = config
        self.mm_type = mm_type

    def load_weights(self, weight_dict):
        self.norm_out_linear = MM_WEIGHT_REGISTER[self.mm_type]("norm_out.linear.weight", "norm_out.linear.bias")
        self.proj_out = MM_WEIGHT_REGISTER[self.mm_type]("proj_out.weight", "proj_out.bias")
        self.norm_final = LN_WEIGHT_REGISTER[self.mm_type]("norm_final.weight", "norm_final.bias")
        self.norm_out_norm = LN_WEIGHT_REGISTER[self.mm_type]("norm_out.norm.weight", "norm_out.norm.bias", eps=1e-5)
        self.weight_list = [self.norm_out_linear, self.proj_out, self.norm_final, self.norm_out_norm]
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cuda()
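
The MM_WEIGHT_REGISTER and LN_WEIGHT_REGISTER lookups above follow a name-to-class registry pattern; a minimal standalone sketch of that pattern (illustrative only; the real registry lives in lightx2v.utils.registry_factory and its internals may differ):

# Minimal sketch of a name -> weight-class registry (hypothetical, not the
# actual lightx2v implementation).
class Registry(dict):
    def register(self, name):
        def deco(cls):
            self[name] = cls
            return cls
        return deco

MM_WEIGHT_REGISTER = Registry()

@MM_WEIGHT_REGISTER.register("Default")
class DefaultMMWeight:
    def __init__(self, weight_name, bias_name):
        self.weight_name, self.bias_name = weight_name, bias_name

    def load(self, weight_dict):
        # pull the named tensors out of the checkpoint dict
        self.weight = weight_dict[self.weight_name]
        self.bias = weight_dict[self.bias_name]

# Usage mirrors the deleted code: MM_WEIGHT_REGISTER["Default"]("w", "b").load(ckpt)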

lightx2v/models/networks/cogvideox/weights/pre_weights.py
deleted 100644 → 0

from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER


class CogvideoxPreWeights:
    def __init__(self, config):
        self.config = config

    def load_weights(self, weight_dict):
        self.time_embedding_linear_1 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_1.weight", "time_embedding.linear_1.bias")
        self.time_embedding_linear_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.linear_2.weight", "time_embedding.linear_2.bias")
        self.patch_embed_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.proj.weight", "patch_embed.proj.bias")
        self.patch_embed_text_proj = MM_WEIGHT_REGISTER["Default"]("patch_embed.text_proj.weight", "patch_embed.text_proj.bias")
        self.weight_list = [self.time_embedding_linear_1, self.time_embedding_linear_2, self.patch_embed_proj, self.patch_embed_text_proj]
        for mm_weight in self.weight_list:
            mm_weight.set_config(self.config)
            mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cuda()

lightx2v/models/networks/cogvideox/weights/transformers_weights.py
deleted 100644 → 0

from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER


class CogvideoxTransformerWeights:
    def __init__(self, config, task="t2v", mm_type="Default"):
        self.config = config
        self.task = task
        self.mm_type = mm_type
        self.init()

    def init(self):
        self.num_layers = self.config["num_layers"]

    def load_weights(self, weight_dict):
        self.blocks_weights = [CogVideoXBlock(i, self.task, self.mm_type) for i in range(self.num_layers)]
        for block in self.blocks_weights:
            block.load_weights(weight_dict)

    def to_cpu(self):
        for block in self.blocks_weights:
            block.to_cpu()

    def to_cuda(self):
        for block in self.blocks_weights:
            block.to_cuda()


class CogVideoXBlock:
    def __init__(self, block_index, task="t2v", mm_type="Default"):
        super().__init__()
        self.block_index = block_index
        self.mm_type = mm_type
        self.task = task

    def load_weights(self, weight_dict):
        self.attn1_to_k = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_k.weight", f"transformer_blocks.{self.block_index}.attn1.to_k.bias")
        self.attn1_to_q = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_q.weight", f"transformer_blocks.{self.block_index}.attn1.to_q.bias")
        self.attn1_to_v = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_v.weight", f"transformer_blocks.{self.block_index}.attn1.to_v.bias")
        self.attn1_to_out = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.to_out.0.weight", f"transformer_blocks.{self.block_index}.attn1.to_out.0.bias")
        self.ff_net_0_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.0.proj.weight", f"transformer_blocks.{self.block_index}.ff.net.0.proj.bias")
        self.ff_net_2_proj = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.ff.net.2.weight", f"transformer_blocks.{self.block_index}.ff.net.2.bias")
        self.norm1_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.linear.weight", f"transformer_blocks.{self.block_index}.norm1.linear.bias")
        self.norm2_linear = MM_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.linear.weight", f"transformer_blocks.{self.block_index}.norm2.linear.bias")
        self.attn1_norm_k = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_k.weight", f"transformer_blocks.{self.block_index}.attn1.norm_k.bias")
        self.attn1_norm_q = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.attn1.norm_q.weight", f"transformer_blocks.{self.block_index}.attn1.norm_q.bias")
        self.norm1_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm1.norm.weight", f"transformer_blocks.{self.block_index}.norm1.norm.bias", eps=1e-05)
        self.norm2_norm = LN_WEIGHT_REGISTER[self.mm_type](f"transformer_blocks.{self.block_index}.norm2.norm.weight", f"transformer_blocks.{self.block_index}.norm2.norm.bias", eps=1e-05)
        self.weight_list = [
            self.attn1_to_k,
            self.attn1_to_q,
            self.attn1_to_v,
            self.attn1_to_out,
            self.ff_net_0_proj,
            self.ff_net_2_proj,
            self.norm1_linear,
            self.norm2_linear,
            self.attn1_norm_k,
            self.attn1_norm_q,
            self.norm1_norm,
            self.norm2_norm,
        ]
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate)):
                mm_weight.to_cuda()

lightx2v/models/networks/hunyuan/__init__.py
deleted 100755 → 0
(empty file)

lightx2v/models/networks/hunyuan/infer/__init__.py
deleted 100755 → 0
(empty file)

lightx2v/models/networks/hunyuan/infer/feature_caching/__init__.py
deleted 100755 → 0
(empty file)

lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py
deleted 100755 → 0
(diff collapsed on the original page; 604 deleted lines not shown)

lightx2v/models/networks/hunyuan/infer/feature_caching/utils.py
deleted 100755 → 0

import math
from typing import Dict

import torch


def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Initialize Taylor cache, expanding storage areas for Taylor series derivatives
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    if current["step"] == 0:
        cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = {}


def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Compute derivative approximation
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    difference_distance = current["activated_steps"][-1] - current["activated_steps"][-2]
    # difference_distance = current['activated_times'][-1] - current['activated_times'][-2]

    updated_taylor_factors = {}
    updated_taylor_factors[0] = feature

    for i in range(cache_dic["max_order"]):
        if (cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]].get(i, None) is not None) and (current["step"] > cache_dic["first_enhance"] - 2):
            updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i]) / difference_distance
        else:
            break

    cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = updated_taylor_factors


def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Compute Taylor expansion error
    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    x = current["step"] - current["activated_steps"][-1]
    # x = current['t'] - current['activated_times'][-1]
    output = 0
    for i in range(len(cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]])):
        output += (1 / math.factorial(i)) * cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i] * (x**i)
    return output
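
For orientation, these three helpers implement a truncated Taylor extrapolation of a cached module output: derivative_approximation keeps finite-difference estimates of the first max_order derivatives, and taylor_formula evaluates the sum of f^(i)/i! * x^i at an offset x from the last fully computed step. A toy run, assuming the functions above are in scope and a hand-built cache_dic/current layout (the real caller's layout may differ):

import torch

cache_dic = {"cache": [{0: {0: {"attn": {}}}}], "max_order": 1, "first_enhance": 0}
current = {"stream": 0, "layer": 0, "module": "attn",
           "step": 0, "activated_steps": [0, 0]}  # padded so the [-2] lookup is valid at step 0

taylor_cache_init(cache_dic, current)
derivative_approximation(cache_dic, current, torch.ones(4))  # caches the 0th-order factor

current["step"], current["activated_steps"] = 2, [0, 2]
derivative_approximation(cache_dic, current, torch.full((4,), 3.0))  # adds d/dstep ~ (3-1)/2 = 1.0

current["step"] = 3  # one step past the last activated step
print(taylor_formula(cache_dic, current))  # tensor([4., 4., 4., 4.]) = 3 + 1.0 * 1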

lightx2v/models/networks/hunyuan/infer/post_infer.py
deleted 100755 → 0

import torch


class HunyuanPostInfer:
    def __init__(self, config):
        self.config = config

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, img, vec):
        out = torch.nn.functional.silu(vec)
        out = weights.final_layer_adaLN_modulation_1.apply(out)
        shift, scale = out.chunk(2, dim=1)
        out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        out = out * (1 + scale) + shift
        out = weights.final_layer_linear.apply(out.to(torch.float32))

        _, _, ot, oh, ow = self.scheduler.latents.shape
        patch_size = [1, 2, 2]
        tt, th, tw = (
            ot // patch_size[0],
            oh // patch_size[1],
            ow // patch_size[2],
        )
        c = 16
        pt, ph, pw = patch_size
        out = out.reshape(shape=(1, tt, th, tw, c, pt, ph, pw))
        out = torch.einsum("nthwcopq->nctohpwq", out)
        out = out.reshape(shape=(1, c, tt * pt, th * ph, tw * pw))
        return out
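
The reshape → einsum → reshape at the end undoes the (1, 2, 2) patchify: tokens laid out over a (t, h, w) grid with c*pt*ph*pw channels each are scattered back into a (1, c, T, H, W) latent video. A toy shape walk-through with illustrative sizes:

import torch

tt, th, tw = 3, 4, 5             # token grid
c, (pt, ph, pw) = 16, (1, 2, 2)  # latent channels and patch size, as above
tokens = torch.randn(tt * th * tw, c * pt * ph * pw)

out = tokens.reshape(1, tt, th, tw, c, pt, ph, pw)
out = torch.einsum("nthwcopq->nctohpwq", out)  # interleave patch dims with the grid
out = out.reshape(1, c, tt * pt, th * ph, tw * pw)
print(out.shape)  # torch.Size([1, 16, 3, 8, 10])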

lightx2v/models/networks/hunyuan/infer/pre_infer.py
deleted 100755 → 0

import math

import torch
from einops import rearrange

from lightx2v.utils.envs import *


class HunyuanPreInfer:
    def __init__(self, config):
        self.heads_num = 24
        self.config = config

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def infer(self, weights, inputs):
        x = self.scheduler.latents
        t = self.scheduler.timesteps[self.scheduler.step_index]
        freqs_cos = self.scheduler.freqs_cos
        freqs_sin = self.scheduler.freqs_sin
        guidance = self.scheduler.guidance
        text_states = inputs["text_encoder_output"]["text_encoder_1_text_states"]
        text_mask = inputs["text_encoder_output"]["text_encoder_1_attention_mask"]
        text_states_2 = inputs["text_encoder_output"]["text_encoder_2_text_states"]

        if self.config["task"] == "i2v":
            token_replace_t = torch.zeros_like(t)
            token_replace_vec = self.infer_time_in(weights, token_replace_t)
            th = x.shape[-2] // 2
            tw = x.shape[-1] // 2
            frist_frame_token_num = th * tw

        time_out = self.infer_time_in(weights, t)
        img_out = self.infer_img_in(weights, x)
        infer_text_out = self.infer_text_in(weights, text_states, text_mask, t)
        infer_vector_out = self.infer_vector_in(weights, text_states_2)
        vec = time_out + infer_vector_out

        if self.config["task"] == "i2v":
            token_replace_vec = token_replace_vec + infer_vector_out

        guidance_out = self.infer_guidance_in(weights, guidance)
        vec = vec + guidance_out

        txt_seq_len = infer_text_out.shape[0]
        img_seq_len = img_out.shape[1]

        batch_size = text_mask.shape[0]
        text_len = text_mask.sum(dim=1)
        max_len = text_mask.shape[1] + img_seq_len
        cu_seqlens_qkv = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
        for i in range(batch_size):
            s = text_len[i] + img_seq_len
            s1 = i * max_len + s
            s2 = (i + 1) * max_len
            cu_seqlens_qkv[2 * i + 1] = s1
            cu_seqlens_qkv[2 * i + 2] = s2
        max_seqlen_qkv = img_seq_len + txt_seq_len

        if self.config["task"] == "i2v":
            return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin), token_replace_vec, frist_frame_token_num
        return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)

    def infer_time_in(self, weights, t):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
        out = weights.time_in_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        out = weights.time_in_mlp_2.apply(out)
        return out

    def infer_img_in(self, weights, x):
        out = weights.img_in_proj.apply(x)
        out = out.flatten(2).transpose(1, 2)
        return out

    def infer_text_in(self, weights, text_states, text_mask, t):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
        args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
        out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)

        mask_float = text_mask.float().unsqueeze(-1).to(GET_DTYPE())  # [b, s1, 1]
        context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
        out = weights.txt_in_c_embedder_linear_1.apply(context_aware_representations)
        out = torch.nn.functional.silu(out)
        context_aware_representations = weights.txt_in_c_embedder_linear_2.apply(out)
        c = timestep_aware_representations + context_aware_representations

        txt_in_input_embed = weights.txt_in_input_embedder.apply(text_states[0])

        batch_size = text_mask.shape[0]
        seq_len = text_mask.shape[1]
        self_attn_mask_1 = text_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
        self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
        self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
        self_attn_mask[:, :, :, 0] = True

        cx = torch.nn.functional.silu(c)
        cx = weights.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1.apply(cx)
        gate_msa, gate_mlp = cx.chunk(2, dim=1)
        normx = weights.txt_in_individual_token_refiner_blocks_0_norm1.apply(txt_in_input_embed)
        qkv = weights.txt_in_individual_token_refiner_blocks_0_self_attn_qkv.apply(normx)
        q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
        out = weights.txt_in_individual_token_refiner_blocks_0_self_attn_proj.apply(attn)
        out_1 = txt_in_input_embed + out * gate_msa
        out = weights.txt_in_individual_token_refiner_blocks_0_norm2.apply(out_1)
        # mlp
        out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc1.apply(out)
        out = torch.nn.functional.silu(out)
        out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc2.apply(out)
        txt_in_input_embed = out_1 + out * gate_mlp

        cx = torch.nn.functional.silu(c)
        cx = weights.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1.apply(cx)
        gate_msa, gate_mlp = cx.chunk(2, dim=1)
        normx = weights.txt_in_individual_token_refiner_blocks_1_norm1.apply(txt_in_input_embed)
        qkv = weights.txt_in_individual_token_refiner_blocks_1_self_attn_qkv.apply(normx)
        q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
        out = weights.txt_in_individual_token_refiner_blocks_1_self_attn_proj.apply(attn)
        out_1 = txt_in_input_embed + out * gate_msa
        out = weights.txt_in_individual_token_refiner_blocks_1_norm2.apply(out_1)
        # mlp
        out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc1.apply(out)
        out = torch.nn.functional.silu(out)
        out = weights.txt_in_individual_token_refiner_blocks_1_mlp_fc2.apply(out)
        out = out_1 + out * gate_mlp
        return out

    def infer_vector_in(self, weights, text_states_2):
        out = weights.vector_in_in_layer.apply(text_states_2)
        out = torch.nn.functional.silu(out)
        out = weights.vector_in_out_layer.apply(out)
        return out

    def infer_guidance_in(self, weights, guidance):
        freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=guidance.device)
        args = guidance.float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype=GET_DTYPE())
        out = weights.guidance_in_mlp_0.apply(embedding)
        out = torch.nn.functional.silu(out)
        out = weights.guidance_in_mlp_2.apply(out)
        return out
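
infer_time_in, infer_text_in, and infer_guidance_in above all build the same 256-dim sinusoidal embedding before their MLPs; pulled out on its own it is just the following (a hedged reconstruction, with the batching convention simplified relative to the methods above):

import math
import torch

def sinusoidal_embedding(t: torch.Tensor, half_dim: int = 128) -> torch.Tensor:
    # frequencies fall geometrically from 1 toward 1/10000, as in the methods above
    freqs = torch.exp(-math.log(10000) * torch.arange(half_dim, dtype=torch.float32) / half_dim)
    args = t.float()[..., None] * freqs
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

print(sinusoidal_embedding(torch.tensor([0.0, 500.0, 999.0])).shape)  # torch.Size([3, 256])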

lightx2v/models/networks/hunyuan/infer/transformer_infer.py
deleted 100755 → 0

import torch
from einops import rearrange

from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *

from .utils_bf16 import apply_rotary_emb


class HunyuanTransformerInfer(BaseTransformerInfer):
    def __init__(self, config):
        self.config = config
        self.attention_type = config.get("attention_type", "flash_attn2")
        self.double_blocks_num = 20
        self.single_blocks_num = 40
        self.heads_num = 24
        self.hidden_size = 3072
        self.mlp_hidden_dim = 12288
        self.parallel_attention = None
        if self.config["cpu_offload"]:
            if "offload_ratio" in self.config:
                offload_ratio = self.config["offload_ratio"]
            else:
                offload_ratio = 1
            self.double_weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.double_blocks_num, offload_ratio=offload_ratio)
            self.single_weights_stream_mgr = WeightAsyncStreamManager(blocks_num=self.single_blocks_num, offload_ratio=offload_ratio)
            self.infer_func = self._infer_with_offload
        else:
            self.infer_func = self._infer_without_offload

    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)

    def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for double_block_idx in range(self.double_blocks_num):
            if double_block_idx == 0:
                self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks[0]
                self.double_weights_stream_mgr.active_weights[0].to_cuda()
            with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
                img, txt = self.infer_double_block(self.double_weights_stream_mgr.active_weights[0], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            if double_block_idx < self.double_blocks_num - 1:
                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks)
            self.double_weights_stream_mgr.swap_weights()

        x = torch.cat((img, txt), 0)
        img = img.cpu()
        txt = txt.cpu()
        del img, txt
        torch.cuda.empty_cache()

        for single_block_idx in range(self.single_blocks_num):
            if single_block_idx == 0:
                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks[0]
                self.single_weights_stream_mgr.active_weights[0].to_cuda()
            with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
                x = self.infer_single_block(self.single_weights_stream_mgr.active_weights[0], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
            if single_block_idx < self.single_blocks_num - 1:
                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks)
            self.single_weights_stream_mgr.swap_weights()
        torch.cuda.empty_cache()

        img = x[:img_seq_len, ...]
        return img, vec

    def _infer_without_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for i in range(self.double_blocks_num):
            img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        x = torch.cat((img, txt), 0)
        for i in range(self.single_blocks_num):
            x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        img = x[:img_seq_len, ...]
        return img, vec

    def infer_double_block_phase_1(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        vec_silu = torch.nn.functional.silu(vec)

        img_mod_out = weights.img_mod.apply(vec_silu)
        (
            img_mod1_shift,
            img_mod1_scale,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
        ) = img_mod_out.chunk(6, dim=-1)

        if token_replace_vec is not None:
            token_replace_vec_silu = torch.nn.functional.silu(token_replace_vec)
            token_replace_vec_img_mod_out = weights.img_mod.apply(token_replace_vec_silu)
            (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = token_replace_vec_img_mod_out.chunk(6, dim=-1)
        else:
            (tr_img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate) = None, None, None, None, None, None

        txt_mod_out = weights.txt_mod.apply(vec_silu)
        (
            txt_mod1_shift,
            txt_mod1_scale,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = txt_mod_out.chunk(6, dim=-1)

        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis)
        txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)

        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)
        v = torch.cat((img_v, txt_v), dim=0)

        if not self.parallel_attention:
            attn = weights.double_attn.apply(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens_qkv,
                cu_seqlens_kv=cu_seqlens_qkv,
                max_seqlen_q=max_seqlen_qkv,
                max_seqlen_kv=max_seqlen_qkv,
            )
        else:
            # world_size = dist.get_world_size()
            attn = self.parallel_attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                img_qkv_len=img_q.shape[0],
                cu_seqlens_qkv=cu_seqlens_qkv,
                # cu_seqlens_qkv=cu_seqlens_qkv,
                # max_seqlen_qkv=max_seqlen_qkv,
            )

        img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
        img_out = weights.img_attn_proj.apply(img_attn)
        txt_out = weights.txt_attn_proj.apply(txt_attn)
        return (
            img_out,
            txt_out,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
            tr_img_mod1_gate,
            tr_img_mod2_shift,
            tr_img_mod2_scale,
            tr_img_mod2_gate,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        )

    def infer_double_block_phase_2(
        self,
        weights,
        img,
        txt,
        vec,
        cu_seqlens_qkv,
        max_seqlen_qkv,
        freqs_cis,
        token_replace_vec,
        frist_frame_token_num,
        img_out,
        txt_out,
        img_mod1_gate,
        img_mod2_shift,
        img_mod2_scale,
        img_mod2_gate,
        tr_img_mod1_gate,
        tr_img_mod2_shift,
        tr_img_mod2_scale,
        tr_img_mod2_gate,
        txt_mod1_gate,
        txt_mod2_shift,
        txt_mod2_scale,
        txt_mod2_gate,
    ):
        if tr_img_mod1_gate is not None:
            x_zero = img_out[:frist_frame_token_num] * tr_img_mod1_gate
            x_orig = img_out[frist_frame_token_num:] * img_mod1_gate
            img_out = torch.concat((x_zero, x_orig), dim=0)
        else:
            img_out = img_out * img_mod1_gate
        img = img + img_out
        img_out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        if tr_img_mod1_gate is not None:
            x_zero = img_out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
            x_orig = img_out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
            img_out = torch.concat((x_zero, x_orig), dim=0)
        else:
            img_out = img_out * (1 + img_mod2_scale) + img_mod2_shift
        img_out = weights.img_mlp_fc1.apply(img_out)
        img_out = torch.nn.functional.gelu(img_out, approximate="tanh")
        img_out = weights.img_mlp_fc2.apply(img_out)

        txt_out = txt_out * txt_mod1_gate
        txt = txt + txt_out
        txt_out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
        txt_out = txt_out * (1 + txt_mod2_scale) + txt_mod2_shift
        txt_out = weights.txt_mlp_fc1.apply(txt_out)
        txt_out = torch.nn.functional.gelu(txt_out, approximate="tanh")
        txt_out = weights.txt_mlp_fc2.apply(txt_out)
        return img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate

    def infer_double_block_phase_3(self, img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt):
        # img
        img_out = img_out * img_mod2_gate
        img = img + img_out
        # txt
        txt_out = txt_out * txt_mod2_gate
        txt = txt + txt_out
        return img, txt

    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
        (
            img_out,
            txt_out,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
            tr_img_mod1_gate,
            tr_img_mod2_shift,
            tr_img_mod2_scale,
            tr_img_mod2_gate,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        ) = self.infer_double_block_phase_1(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
            weights,
            img,
            txt,
            vec,
            cu_seqlens_qkv,
            max_seqlen_qkv,
            freqs_cis,
            token_replace_vec,
            frist_frame_token_num,
            img_out,
            txt_out,
            img_mod1_gate,
            img_mod2_shift,
            img_mod2_scale,
            img_mod2_gate,
            tr_img_mod1_gate,
            tr_img_mod2_shift,
            tr_img_mod2_scale,
            tr_img_mod2_gate,
            txt_mod1_gate,
            txt_mod2_shift,
            txt_mod2_scale,
            txt_mod2_gate,
        )
        img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
        return img, txt

    def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis):
        img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        if tr_img_mod1_scale is not None:
            x_zero = img_modulated[:frist_frame_token_num] * (1 + tr_img_mod1_scale) + tr_img_mod1_shift
            x_orig = img_modulated[frist_frame_token_num:] * (1 + img_mod1_scale) + img_mod1_shift
            img_modulated = torch.concat((x_zero, x_orig), dim=0)
        else:
            img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
        img_qkv = weights.img_attn_qkv.apply(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        img_q = weights.img_attn_q_norm.apply(img_q)
        img_k = weights.img_attn_k_norm.apply(img_k)
        img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
        return img_q, img_k, img_v

    def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
        txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
        txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
        txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        txt_q = weights.txt_attn_q_norm.apply(txt_q)
        txt_k = weights.txt_attn_k_norm.apply(txt_k)
        return txt_q, txt_k, txt_v

    def infer_single_block_phase_1(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        out = torch.nn.functional.silu(vec)
        out = weights.modulation.apply(out)
        mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
        if token_replace_vec is not None:
            token_replace_vec_out = torch.nn.functional.silu(token_replace_vec)
            token_replace_vec_out = weights.modulation.apply(token_replace_vec_out)
            tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1)
        else:
            tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None

        out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        if token_replace_vec is not None:
            x_zero = out[:frist_frame_token_num] * (1 + tr_mod_scale) + tr_mod_shift
            x_orig = out[frist_frame_token_num:] * (1 + mod_scale) + mod_shift
            x_mod = torch.concat((x_zero, x_orig), dim=0)
        else:
            x_mod = out * (1 + mod_scale) + mod_shift

        x_mod = weights.linear1.apply(x_mod)
        qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        q = weights.q_norm.apply(q)
        k = weights.k_norm.apply(k)

        img_q, txt_q = q[:-txt_seq_len, :, :], q[-txt_seq_len:, :, :]
        img_k, txt_k = k[:-txt_seq_len, :, :], k[-txt_seq_len:, :, :]
        img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)

        if not self.parallel_attention:
            attn = weights.single_attn.apply(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=cu_seqlens_qkv,
                cu_seqlens_kv=cu_seqlens_qkv,
                max_seqlen_q=max_seqlen_qkv,
                max_seqlen_kv=max_seqlen_qkv,
            )
        else:
            attn = self.parallel_attention(
                attention_type=self.attention_type,
                q=q,
                k=k,
                v=v,
                img_qkv_len=img_q.shape[0],
                cu_seqlens_qkv=cu_seqlens_qkv,
                # cu_seqlens_qkv=cu_seqlens_qkv,
                # max_seqlen_qkv=max_seqlen_qkv,
            )

        out = torch.nn.functional.gelu(mlp, approximate="tanh")
        out = torch.cat((attn, out), 1)
        out = weights.linear2.apply(out)
        return out, mod_gate, tr_mod_gate

    def infer_single_block_phase_2(self, x, out, tr_mod_gate, mod_gate, token_replace_vec=None, frist_frame_token_num=None):
        if token_replace_vec is not None:
            x_zero = out[:frist_frame_token_num] * tr_mod_gate
            x_orig = out[frist_frame_token_num:] * mod_gate
            out = torch.concat((x_zero, x_orig), dim=0)
        else:
            out = out * mod_gate
        x = x + out
        return x

    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
        x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
        return x
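
The offload path above leans on WeightAsyncStreamManager to hide PCIe transfers: block i computes on one CUDA stream while block i+1's weights are prefetched on another, then the two slots swap. A stripped-down sketch of that double-buffering idea (the manager API is inferred from its call sites above, not from the real class; requires a CUDA device):

import torch

class TwoSlotPrefetcher:
    """Hypothetical stand-in for WeightAsyncStreamManager's prefetch/swap cycle."""

    def __init__(self):
        self.compute_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()
        self.active_weights = [None, None]  # [current block, prefetched block]

    def prefetch_weights(self, idx, blocks):
        with torch.cuda.stream(self.load_stream):
            blocks[idx].to_cuda()            # H2D copy overlaps with compute_stream work
            self.active_weights[1] = blocks[idx]

    def swap_weights(self):
        torch.cuda.current_stream().wait_stream(self.load_stream)  # the copy must land first
        self.active_weights[0].to_cpu()      # evict the block that just finished
        self.active_weights[0], self.active_weights[1] = self.active_weights[1], None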

lightx2v/models/networks/hunyuan/infer/utils.py
deleted 100644 → 0

import sgl_kernel


def rms_norm(x, weight, eps):
    x = x.contiguous()
    orig_shape = x.shape
    x = x.view(-1, orig_shape[-1])
    x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
    return x

lightx2v/models/networks/hunyuan/infer/utils_bf16.py
deleted 100644 → 0

from typing import Tuple, Union

import torch


def rms_norm(x, weight, eps):
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    x = x * weight
    return x


def rotate_half(x, shape_0, shape_1):
    x_real, x_imag = x.reshape(shape_0, shape_1, -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(2)


def rotary_emb(x, shape_0, shape_1, cos, sin):
    x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
    return x_out


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    shape_0, shape_1, shape_2 = xq.shape
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq, shape_0, shape_1, cos, sin)
    xk_out = rotary_emb(xk, shape_0, shape_1, cos, sin)
    return xq_out, xk_out
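
A toy check of apply_rotary_emb, assuming the functions above are in scope: with cos/sin built so both lanes of each adjacent pair share an angle (an assumption about how the scheduler lays out freqs_cos/freqs_sin), each (real, imag) pair is rotated, so norms are preserved:

import torch

L, H, D = 8, 2, 16
pos = torch.arange(L, dtype=torch.float32)[:, None]              # (L, 1)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))  # (D/2,)
angles = (pos * inv_freq).repeat_interleave(2, dim=-1)           # (L, D), paired lanes
freqs_cis = (angles.cos(), angles.sin())

q, k = torch.randn(L, H, D), torch.randn(L, H, D)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)
print(torch.allclose(q_rot.norm(), q.norm(), atol=1e-4))  # True: rotation preserves norms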

lightx2v/models/networks/hunyuan/infer/utils_fp32.py
deleted 100644 → 0

from typing import Tuple, Union

import torch

from lightx2v.utils.envs import *


def rms_norm(x, weight, eps):
    x = x.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    x = x.to(GET_DTYPE())
    x = x * weight
    return x


def rotate_half(x, shape_0, shape_1):
    x_real, x_imag = x.float().reshape(shape_0, shape_1, -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(2)


def rotary_emb(x, shape_0, shape_1, cos, sin):
    x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
    return x_out.to(GET_DTYPE())


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    shape_0, shape_1, shape_2 = xq.shape
    cos = freqs_cis[0].view(shape_0, 1, shape_2)
    sin = freqs_cis[1].view(shape_0, 1, shape_2)
    xq_out = rotary_emb(xq.float(), shape_0, shape_1, cos, sin)
    xk_out = rotary_emb(xk.float(), shape_0, shape_1, cos, sin)
    return xq_out, xk_out

lightx2v/models/networks/hunyuan/model.py
deleted 100755 → 0

import json
import os

import torch
from loguru import logger
from safetensors import safe_open

from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import (
    HunyuanTransformerInferAdaCaching,
    HunyuanTransformerInferCustomCaching,
    HunyuanTransformerInferTaylorCaching,
    HunyuanTransformerInferTeaCaching,
)
from lightx2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.utils.envs import *


class HunyuanModel:
    pre_weight_class = HunyuanPreWeights
    post_weight_class = HunyuanPostWeights
    transformer_weight_class = HunyuanTransformerWeights

    def __init__(self, model_path, config, device, args):
        self.model_path = model_path
        self.config = config
        self.device = device
        self.args = args
        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
        self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
        if self.dit_quantized:
            assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
        self._init_infer_class()
        self._init_weights()
        self._init_infer()
        if self.config["cpu_offload"]:
            self.to_cpu()

    def _load_ckpt(self):
        if self.args.task == "t2v":
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
        else:
            ckpt_path = os.path.join(self.model_path, "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt")
        weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)["module"]
        return weight_dict

    def _load_quant_ckpt(self):
        ckpt_path = self.dit_quantized_ckpt
        logger.info(f"Loading quant dit model from {ckpt_path}")
        if ckpt_path.endswith(".pth"):
            logger.info(f"Loading {ckpt_path} as PyTorch model.")
            weight_dict = torch.load(ckpt_path, map_location=self.device, weights_only=True)
        else:
            index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
            if not index_files:
                raise FileNotFoundError(f"No .pth file or *.index.json found in {ckpt_path}")
            index_path = os.path.join(ckpt_path, index_files[0])
            logger.info(f" Using safetensors index: {index_path}")
            with open(index_path, "r") as f:
                index_data = json.load(f)
            weight_dict = {}
            for filename in set(index_data["weight_map"].values()):
                safetensor_path = os.path.join(ckpt_path, filename)
                with safe_open(safetensor_path, framework="pt", device=str(self.device)) as f:
                    logger.info(f"Loading weights from {safetensor_path}")
                    for k in f.keys():
                        weight_dict[k] = f.get_tensor(k)
                        if weight_dict[k].dtype == torch.float:
                            weight_dict[k] = weight_dict[k].to(GET_DTYPE())
        return weight_dict

    def _init_weights(self):
        if not self.dit_quantized or self.weight_auto_quant:
            weight_dict = self._load_ckpt()
        else:
            weight_dict = self._load_quant_ckpt()
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load(weight_dict)
        self.post_weight.load(weight_dict)
        self.transformer_weights.load(weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def save_weights(self, save_path):
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        pre_state_dict = self.pre_weight.state_dict()
        logger.info(pre_state_dict.keys())
        post_state_dict = self.post_weight.state_dict()
        logger.info(post_state_dict.keys())
        transformer_state_dict = self.transformer_weights.state_dict()
        logger.info(transformer_state_dict.keys())
        save_dict = {}
        save_dict.update(pre_state_dict)
        save_dict.update(post_state_dict)
        save_dict.update(transformer_state_dict)
        save_path = os.path.join(save_path, "quant_weights.pth")
        torch.save(save_dict, save_path)
        logger.info(f"Save weights to {save_path}")

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.pre_infer.set_scheduler(scheduler)
        self.post_infer.set_scheduler(scheduler)
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, inputs):
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()
        inputs = self.pre_infer.infer(self.pre_weight, inputs)
        inputs = self.transformer_infer.infer(self.transformer_weights, *inputs)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, *inputs)
        if self.config["cpu_offload"]:
            self.pre_weight.to_cpu()
            self.post_weight.to_cpu()

    def _init_infer_class(self):
        self.pre_infer_class = HunyuanPreInfer
        self.post_infer_class = HunyuanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = HunyuanTransformerInfer
        elif self.config["feature_caching"] == "TaylorSeer":
            self.transformer_infer_class = HunyuanTransformerInferTaylorCaching
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = HunyuanTransformerInferTeaCaching
        elif self.config["feature_caching"] == "Ada":
            self.transformer_infer_class = HunyuanTransformerInferAdaCaching
        elif self.config["feature_caching"] == "Custom":
            self.transformer_infer_class = HunyuanTransformerInferCustomCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")
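
For orientation, these are the configuration keys HunyuanModel reads above (note that config is accessed both dict-style and attribute-style, so it is presumably a dict-like object such as an EasyDict). The values below are illustrative only, not defaults from the repo:

# Hypothetical config sketch; keys taken from the accesses in HunyuanModel,
# values are illustrative.
config = {
    "task": "t2v",                       # or "i2v" (also selects the checkpoint path)
    "cpu_offload": False,                # move weights to CPU between uses
    "feature_caching": "NoCaching",      # NoCaching | TaylorSeer | Tea | Ada | Custom
    "mm_config": {"mm_type": "Default",  # a non-"Default" mm_type marks the DiT as quantized
                  "weight_auto_quant": False},
    "dit_quantized_ckpt": None,          # path to a .pth or safetensors dir when quantized
}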

lightx2v/models/networks/hunyuan/weights/__init__.py
deleted 100755 → 0
(empty file)

lightx2v/models/networks/hunyuan/weights/post_weights.py
deleted 100755 → 0

from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER


class HunyuanPostWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.add_module("final_layer_linear", MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias"))
        self.add_module("final_layer_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias"))

lightx2v/models/networks/hunyuan/weights/pre_weights.py
deleted 100755 → 0

from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER


class HunyuanPreWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.add_module("img_in_proj", CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2)))
        self.add_module("txt_in_input_embedder", MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias"))
        self.add_module("txt_in_t_embedder_mlp_0", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias"))
        self.add_module("txt_in_t_embedder_mlp_2", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias"))
        self.add_module("txt_in_c_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias"))
        self.add_module("txt_in_c_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias"))
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_norm1",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_self_attn_qkv",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_self_attn_proj",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_norm2",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_mlp_fc1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_mlp_fc2",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_norm1",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_self_attn_qkv",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_self_attn_proj",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_norm2",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_mlp_fc1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_mlp_fc2",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"),
        )
        self.add_module("time_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias"))
        self.add_module("time_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias"))
        self.add_module("vector_in_in_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias"))
        self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
        self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
        self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))
        # attention weights section
        self.add_module("txt_in_attn_1", ATTN_WEIGHT_REGISTER["torch_sdpa"]())

lightx2v/models/networks/hunyuan/weights/transformer_weights.py
deleted 100755 → 0

from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER


class HunyuanTransformerWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.double_blocks_num = 20
        self.single_blocks_num = 40
        self.add_module("double_blocks", WeightModuleList([HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]))
        self.add_module("single_blocks", WeightModuleList([HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]))


class HunyuanTransformerDoubleBlock(WeightModule):
    def __init__(self, block_index, config):
        super().__init__()
        self.block_index = block_index
        self.config = config
        if self.config["do_mm_calib"]:
            mm_type = "Calib"
        else:
            mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"

        self.add_module("img_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias"))
        self.add_module("img_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias"))
        self.add_module("img_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6))
        self.add_module("img_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6))
        self.add_module("img_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias"))
        self.add_module("img_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias"))
        self.add_module("img_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias"))
        self.add_module("txt_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias"))
        self.add_module("txt_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias"))
        self.add_module("txt_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6))
        self.add_module("txt_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6))
        self.add_module("txt_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias"))
        self.add_module("txt_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias"))
        self.add_module("txt_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias"))
        # attention weights section
        self.add_module("double_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())


class HunyuanTransformerSingleBlock(WeightModule):
    def __init__(self, block_index, config):
        super().__init__()
        self.block_index = block_index
        self.config = config
        self.sparge = config.get("sparge", False)
        if self.config["do_mm_calib"]:
            mm_type = "Calib"
        else:
            mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"

        self.add_module("linear1", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias"))
        self.add_module("linear2", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias"))
        self.add_module("q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6))
        self.add_module("k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6))
        self.add_module("modulation", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias"))
        # attention weights section
        if self.sparge:
            # load sparge attention weights
            #! todo
            pass
        else:
            self.add_module("single_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())

lightx2v/models/runners/cogvideox/__init__.py
deleted 100644 → 0
(empty file)