xuwx1 / LightX2V · Commits

Commit daf4c74e
authored Mar 24, 2025 by helloyongyang, committed by Yang Yong(雍洋) Apr 08, 2025

first commit

parent 6c79160f
Changes 107

Showing 20 changed files with 1633 additions and 0 deletions (+1633, -0)
lightx2v/text2v/models/networks/hunyuan/model.py  +95 -0
lightx2v/text2v/models/networks/hunyuan/weights/__init__.py  +0 -0
lightx2v/text2v/models/networks/hunyuan/weights/post_weights.py  +31 -0
lightx2v/text2v/models/networks/hunyuan/weights/pre_weights.py  +89 -0
lightx2v/text2v/models/networks/hunyuan/weights/transformer_weights.py  +137 -0
lightx2v/text2v/models/networks/wan/infer/feature_caching/__init__.py  +0 -0
lightx2v/text2v/models/networks/wan/infer/feature_caching/transformer_infer.py  +90 -0
lightx2v/text2v/models/networks/wan/infer/post_infer.py  +30 -0
lightx2v/text2v/models/networks/wan/infer/pre_infer.py  +119 -0
lightx2v/text2v/models/networks/wan/infer/transformer_infer.py  +162 -0
lightx2v/text2v/models/networks/wan/infer/utils.py  +64 -0
lightx2v/text2v/models/networks/wan/model.py  +133 -0
lightx2v/text2v/models/networks/wan/weights/post_weights.py  +11 -0
lightx2v/text2v/models/networks/wan/weights/pre_weights.py  +53 -0
lightx2v/text2v/models/networks/wan/weights/transformer_weights.py  +81 -0
lightx2v/text2v/models/schedulers/__init__.py  +0 -0
lightx2v/text2v/models/schedulers/hunyuan/feature_caching/scheduler.py  +201 -0
lightx2v/text2v/models/schedulers/hunyuan/scheduler.py  +269 -0
lightx2v/text2v/models/schedulers/scheduler.py  +12 -0
lightx2v/text2v/models/schedulers/wan/feature_caching/scheduler.py  +56 -0
lightx2v/text2v/models/networks/hunyuan/model.py  0 → 100755

import os
import torch
from lightx2v.text2v.models.networks.hunyuan.weights.pre_weights import HunyuanPreWeights
from lightx2v.text2v.models.networks.hunyuan.weights.post_weights import HunyuanPostWeights
from lightx2v.text2v.models.networks.hunyuan.weights.transformer_weights import HunyuanTransformerWeights
from lightx2v.text2v.models.networks.hunyuan.infer.pre_infer import HunyuanPreInfer
from lightx2v.text2v.models.networks.hunyuan.infer.post_infer import HunyuanPostInfer
from lightx2v.text2v.models.networks.hunyuan.infer.transformer_infer import HunyuanTransformerInfer
from lightx2v.text2v.models.networks.hunyuan.infer.feature_caching.transformer_infer import HunyuanTransformerInferFeatureCaching
# from lightx2v.core.distributed.partial_heads_attn.wrap import parallelize_hunyuan
from lightx2v.attentions.distributed.ulysses.wrap import parallelize_hunyuan


class HunyuanModel:
    pre_weight_class = HunyuanPreWeights
    post_weight_class = HunyuanPostWeights
    transformer_weight_class = HunyuanTransformerWeights

    def __init__(self, model_path, config):
        self.model_path = model_path
        self.config = config
        self._init_infer_class()
        self._init_weights()
        self._init_infer()
        if self.config['parallel_attn']:
            parallelize_hunyuan(self)
        if self.config['cpu_offload']:
            self.to_cpu()

    def _init_infer_class(self):
        self.pre_infer_class = HunyuanPreInfer
        self.post_infer_class = HunyuanPostInfer
        if self.config['feature_caching'] == "NoCaching":
            self.transformer_infer_class = HunyuanTransformerInfer
        elif self.config['feature_caching'] == "TaylorSeer":
            self.transformer_infer_class = HunyuanTransformerInferFeatureCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_ckpt(self):
        ckpt_path = os.path.join(self.model_path, "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt")
        weight_dict = torch.load(ckpt_path, map_location="cuda", weights_only=True)["module"]
        return weight_dict

    def _init_weights(self):
        weight_dict = self._load_ckpt()
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class(self.config)
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load_weights(weight_dict)
        self.post_weight.load_weights(weight_dict)
        self.transformer_weights.load_weights(weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class()
        self.post_infer = self.post_infer_class()
        self.transformer_infer = self.transformer_infer_class(self.config)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.transformer_infer.set_scheduler(scheduler)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        self.post_weight.to_cpu()
        self.transformer_weights.to_cpu()

    def to_cuda(self):
        self.pre_weight.to_cuda()
        self.post_weight.to_cuda()
        self.transformer_weights.to_cuda()

    @torch.no_grad()
    def infer(self, text_encoder_output, image_encoder_output, args):
        pre_infer_out = self.pre_infer.infer(
            self.pre_weight,
            self.scheduler.latents,
            self.scheduler.timesteps[self.scheduler.step_index],
            text_encoder_output["text_encoder_1_text_states"],
            text_encoder_output["text_encoder_1_attention_mask"],
            text_encoder_output["text_encoder_2_text_states"],
            self.scheduler.freqs_cos,
            self.scheduler.freqs_sin,
            self.scheduler.guidance,
        )
        img, vec = self.transformer_infer.infer(self.transformer_weights, *pre_infer_out)
        self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, img, vec, self.scheduler.latents.shape)
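
A minimal driver sketch for the class above, under stated assumptions: the config keys are the ones the constructor and weight classes read ('parallel_attn', 'cpu_offload', 'feature_caching', 'do_mm_calib', 'mm_config'), while the model path, the scheduler object, and the encoder outputs are placeholders produced elsewhere in the pipeline, not defaults shipped by the project.

# Hypothetical usage of HunyuanModel; values are illustrative placeholders.
config = {
    "parallel_attn": False,          # skip parallelize_hunyuan
    "cpu_offload": False,            # keep weights on GPU
    "feature_caching": "NoCaching",  # or "TaylorSeer"
    "do_mm_calib": False,
    "mm_config": None,               # fall back to the 'Default' matmul weights
}

model = HunyuanModel("/path/to/hunyuan_ckpt_dir", config)
model.set_scheduler(scheduler)  # scheduler must expose latents, timesteps, step_index, freqs_cos/sin, guidance

for step in range(scheduler.infer_steps):
    scheduler.step_index = step
    model.infer(text_encoder_output, image_encoder_output, args)
    # scheduler.noise_pred now holds this step's prediction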
lightx2v/text2v/models/networks/hunyuan/weights/__init__.py  0 → 100755  (empty file)
lightx2v/text2v/models/networks/hunyuan/weights/post_weights.py  0 → 100755

from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate


class HunyuanPostWeights:
    def __init__(self, config):
        self.config = config

    def load_weights(self, weight_dict):
        self.final_layer_linear = MM_WEIGHT_REGISTER['Default-Force-FP32']('final_layer.linear.weight', 'final_layer.linear.bias')
        self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER['Default']('final_layer.adaLN_modulation.1.weight', 'final_layer.adaLN_modulation.1.bias')

        self.weight_list = [
            self.final_layer_linear,
            self.final_layer_adaLN_modulation_1,
        ]
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate):
                mm_weight.set_config(self.config['mm_config'])
                mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate):
                mm_weight.to_cuda()
lightx2v/text2v/models/networks/hunyuan/weights/pre_weights.py  0 → 100755

from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate


class HunyuanPreWeights:
    def __init__(self, config):
        self.config = config

    def load_weights(self, weight_dict):
        self.img_in_proj = CONV3D_WEIGHT_REGISTER["Default"]('img_in.proj.weight', 'img_in.proj.bias', stride=(1, 2, 2))
        self.txt_in_input_embedder = MM_WEIGHT_REGISTER["Default"]('txt_in.input_embedder.weight', 'txt_in.input_embedder.bias')
        self.txt_in_t_embedder_mlp_0 = MM_WEIGHT_REGISTER["Default"]('txt_in.t_embedder.mlp.0.weight', 'txt_in.t_embedder.mlp.0.bias')
        self.txt_in_t_embedder_mlp_2 = MM_WEIGHT_REGISTER["Default"]('txt_in.t_embedder.mlp.2.weight', 'txt_in.t_embedder.mlp.2.bias')
        self.txt_in_c_embedder_linear_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.c_embedder.linear_1.weight', 'txt_in.c_embedder.linear_1.bias')
        self.txt_in_c_embedder_linear_2 = MM_WEIGHT_REGISTER["Default"]('txt_in.c_embedder.linear_2.weight', 'txt_in.c_embedder.linear_2.bias')
        self.txt_in_individual_token_refiner_blocks_0_norm1 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.norm1.weight', 'txt_in.individual_token_refiner.blocks.0.norm1.bias', eps=1e-6)
        self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight', 'txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias')
        self.txt_in_individual_token_refiner_blocks_0_self_attn_proj = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight', 'txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias')
        self.txt_in_individual_token_refiner_blocks_0_norm2 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.norm2.weight', 'txt_in.individual_token_refiner.blocks.0.norm2.bias', eps=1e-6)
        self.txt_in_individual_token_refiner_blocks_0_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight', 'txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias')
        self.txt_in_individual_token_refiner_blocks_0_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight', 'txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias')
        self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight', 'txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias')
        self.txt_in_individual_token_refiner_blocks_1_norm1 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.norm1.weight', 'txt_in.individual_token_refiner.blocks.1.norm1.bias', eps=1e-6)
        self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight', 'txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias')
        self.txt_in_individual_token_refiner_blocks_1_self_attn_proj = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight', 'txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias')
        self.txt_in_individual_token_refiner_blocks_1_norm2 = LN_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.norm2.weight', 'txt_in.individual_token_refiner.blocks.1.norm2.bias', eps=1e-6)
        self.txt_in_individual_token_refiner_blocks_1_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight', 'txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias')
        self.txt_in_individual_token_refiner_blocks_1_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight', 'txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias')
        self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]('txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight', 'txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias')
        self.time_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]('time_in.mlp.0.weight', 'time_in.mlp.0.bias')
        self.time_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]('time_in.mlp.2.weight', 'time_in.mlp.2.bias')
        self.vector_in_in_layer = MM_WEIGHT_REGISTER["Default"]('vector_in.in_layer.weight', 'vector_in.in_layer.bias')
        self.vector_in_out_layer = MM_WEIGHT_REGISTER["Default"]('vector_in.out_layer.weight', 'vector_in.out_layer.bias')
        self.guidance_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]('guidance_in.mlp.0.weight', 'guidance_in.mlp.0.bias')
        self.guidance_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]('guidance_in.mlp.2.weight', 'guidance_in.mlp.2.bias')

        self.weight_list = [
            self.img_in_proj,
            self.txt_in_input_embedder,
            self.txt_in_t_embedder_mlp_0,
            self.txt_in_t_embedder_mlp_2,
            self.txt_in_c_embedder_linear_1,
            self.txt_in_c_embedder_linear_2,
            self.txt_in_individual_token_refiner_blocks_0_norm1,
            self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv,
            self.txt_in_individual_token_refiner_blocks_0_self_attn_proj,
            self.txt_in_individual_token_refiner_blocks_0_norm2,
            self.txt_in_individual_token_refiner_blocks_0_mlp_fc1,
            self.txt_in_individual_token_refiner_blocks_0_mlp_fc2,
            self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1,
            self.txt_in_individual_token_refiner_blocks_1_norm1,
            self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv,
            self.txt_in_individual_token_refiner_blocks_1_self_attn_proj,
            self.txt_in_individual_token_refiner_blocks_1_norm2,
            self.txt_in_individual_token_refiner_blocks_1_mlp_fc1,
            self.txt_in_individual_token_refiner_blocks_1_mlp_fc2,
            self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1,
            self.time_in_mlp_0,
            self.time_in_mlp_2,
            self.vector_in_in_layer,
            self.vector_in_out_layer,
            self.guidance_in_mlp_0,
            self.guidance_in_mlp_2,
        ]
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
                mm_weight.set_config(self.config['mm_config'])
                mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, LNWeightTemplate) or isinstance(mm_weight, Conv3dWeightTemplate):
                mm_weight.to_cuda()
lightx2v/text2v/models/networks/hunyuan/weights/transformer_weights.py  0 → 100755

from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.norm.rms_norm_weight import RMS_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate


class HunyuanTransformerWeights:
    def __init__(self, config):
        self.config = config
        self.init()

    def init(self):
        self.double_blocks_num = 20
        self.single_blocks_num = 40

    def load_weights(self, weight_dict):
        self.double_blocks_weights = [HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]
        self.single_blocks_weights = [HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]
        for double_block in self.double_blocks_weights:
            double_block.load_weights(weight_dict)
        for single_block in self.single_blocks_weights:
            single_block.load_weights(weight_dict)

    def to_cpu(self):
        for double_block in self.double_blocks_weights:
            double_block.to_cpu()
        for single_block in self.single_blocks_weights:
            single_block.to_cpu()

    def to_cuda(self):
        for double_block in self.double_blocks_weights:
            double_block.to_cuda()
        for single_block in self.single_blocks_weights:
            single_block.to_cuda()


class HunyuanTransformerDoubleBlock:
    def __init__(self, block_index, config):
        self.block_index = block_index
        self.config = config
        self.weight_list = []

    def load_weights(self, weight_dict):
        if self.config['do_mm_calib']:
            mm_type = 'Calib'
        else:
            mm_type = self.config['mm_config'].get('mm_type', 'Default') if self.config['mm_config'] else 'Default'

        self.img_mod = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mod.linear.weight', f'double_blocks.{self.block_index}.img_mod.linear.bias')
        self.img_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_attn_qkv.weight', f'double_blocks.{self.block_index}.img_attn_qkv.bias')
        self.img_attn_q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.img_attn_q_norm.weight', eps=1e-6)
        self.img_attn_k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.img_attn_k_norm.weight', eps=1e-6)
        self.img_attn_proj = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_attn_proj.weight', f'double_blocks.{self.block_index}.img_attn_proj.bias')
        self.img_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mlp.fc1.weight', f'double_blocks.{self.block_index}.img_mlp.fc1.bias')
        self.img_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.img_mlp.fc2.weight', f'double_blocks.{self.block_index}.img_mlp.fc2.bias')
        self.txt_mod = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mod.linear.weight', f'double_blocks.{self.block_index}.txt_mod.linear.bias')
        self.txt_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_attn_qkv.weight', f'double_blocks.{self.block_index}.txt_attn_qkv.bias')
        self.txt_attn_q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.txt_attn_q_norm.weight', eps=1e-6)
        self.txt_attn_k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'double_blocks.{self.block_index}.txt_attn_k_norm.weight', eps=1e-6)
        self.txt_attn_proj = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_attn_proj.weight', f'double_blocks.{self.block_index}.txt_attn_proj.bias')
        self.txt_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mlp.fc1.weight', f'double_blocks.{self.block_index}.txt_mlp.fc1.bias')
        self.txt_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f'double_blocks.{self.block_index}.txt_mlp.fc2.weight', f'double_blocks.{self.block_index}.txt_mlp.fc2.bias')

        self.weight_list = [
            self.img_mod,
            self.img_attn_qkv,
            self.img_attn_q_norm,
            self.img_attn_k_norm,
            self.img_attn_proj,
            self.img_mlp_fc1,
            self.img_mlp_fc2,
            self.txt_mod,
            self.txt_attn_qkv,
            self.txt_attn_q_norm,
            self.txt_attn_k_norm,
            self.txt_attn_proj,
            self.txt_mlp_fc1,
            self.txt_mlp_fc2,
        ]
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
                mm_weight.set_config(self.config['mm_config'])
                mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
                mm_weight.to_cuda()


class HunyuanTransformerSingleBlock:
    def __init__(self, block_index, config):
        self.block_index = block_index
        self.config = config
        self.weight_list = []

    def load_weights(self, weight_dict):
        if self.config['do_mm_calib']:
            mm_type = 'Calib'
        else:
            mm_type = self.config['mm_config'].get('mm_type', 'Default') if self.config['mm_config'] else 'Default'

        self.linear1 = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.linear1.weight', f'single_blocks.{self.block_index}.linear1.bias')
        self.linear2 = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.linear2.weight', f'single_blocks.{self.block_index}.linear2.bias')
        self.q_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'single_blocks.{self.block_index}.q_norm.weight', eps=1e-6)
        self.k_norm = RMS_WEIGHT_REGISTER['sgl-kernel'](f'single_blocks.{self.block_index}.k_norm.weight', eps=1e-6)
        self.modulation = MM_WEIGHT_REGISTER[mm_type](f'single_blocks.{self.block_index}.modulation.linear.weight', f'single_blocks.{self.block_index}.modulation.linear.bias')

        self.weight_list = [
            self.linear1,
            self.linear2,
            self.q_norm,
            self.k_norm,
            self.modulation,
        ]
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
                mm_weight.set_config(self.config['mm_config'])
                mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
                mm_weight.to_cpu()

    def to_cuda(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, MMWeightTemplate) or isinstance(mm_weight, RMSWeightTemplate):
                mm_weight.to_cuda()
lightx2v/text2v/models/networks/wan/infer/feature_caching/__init__.py  0 → 100644  (empty file)
lightx2v/text2v/models/networks/wan/infer/feature_caching/transformer_infer.py  0 → 100644

import numpy as np
from ..transformer_infer import WanTransformerInfer
from lightx2v.attentions import attention


class WanTransformerInferFeatureCaching(WanTransformerInfer):
    def __init__(self, config):
        super().__init__(config)

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        modulated_inp = embed0 if self.scheduler.use_ret_steps else embed
        # teacache
        if self.scheduler.cnt % 2 == 0:
            # even -> condition
            self.scheduler.is_even = True
            if (self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps):
                should_calc_even = True
                self.scheduler.accumulated_rel_l1_distance_even = 0
            else:
                rescale_func = np.poly1d(self.scheduler.coefficients)
                self.scheduler.accumulated_rel_l1_distance_even += rescale_func(
                    ((modulated_inp - self.scheduler.previous_e0_even).abs().mean() / self.scheduler.previous_e0_even.abs().mean()).cpu().item()
                )
                if (self.scheduler.accumulated_rel_l1_distance_even < self.scheduler.teacache_thresh):
                    should_calc_even = False
                else:
                    should_calc_even = True
                    self.scheduler.accumulated_rel_l1_distance_even = 0
            self.scheduler.previous_e0_even = modulated_inp.clone()
        else:
            # odd -> uncondition
            self.scheduler.is_even = False
            if self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps:
                should_calc_odd = True
                self.scheduler.accumulated_rel_l1_distance_odd = 0
            else:
                rescale_func = np.poly1d(self.scheduler.coefficients)
                self.scheduler.accumulated_rel_l1_distance_odd += rescale_func(
                    ((modulated_inp - self.scheduler.previous_e0_odd).abs().mean() / self.scheduler.previous_e0_odd.abs().mean()).cpu().item()
                )
                if self.scheduler.accumulated_rel_l1_distance_odd < self.scheduler.teacache_thresh:
                    should_calc_odd = False
                else:
                    should_calc_odd = True
                    self.scheduler.accumulated_rel_l1_distance_odd = 0
            self.scheduler.previous_e0_odd = modulated_inp.clone()

        if self.scheduler.is_even:
            if not should_calc_even:
                x += self.scheduler.previous_residual_even
            else:
                ori_x = x.clone()
                x = super().infer(
                    weights,
                    grid_sizes,
                    embed,
                    x,
                    embed0,
                    seq_lens,
                    freqs,
                    context,
                )
                self.scheduler.previous_residual_even = x - ori_x
        else:
            if not should_calc_odd:
                x += self.scheduler.previous_residual_odd
            else:
                ori_x = x.clone()
                x = super().infer(
                    weights,
                    grid_sizes,
                    embed,
                    x,
                    embed0,
                    seq_lens,
                    freqs,
                    context,
                )
                self.scheduler.previous_residual_odd = x - ori_x

        return x
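
The caching decision above reduces to a scalar test: accumulate a rescaled relative L1 distance between the current and previous modulated input, and only rerun the transformer blocks once the accumulator crosses a threshold; otherwise the cached residual is re-applied. A self-contained sketch of that rule (the polynomial coefficients and threshold below are illustrative placeholders, not the values LightX2V ships):

import numpy as np
import torch

# Illustrative rescale polynomial and threshold; the real ones live on the scheduler.
coefficients = [1.0, 0.0]        # poly1d([1, 0])(d) == d, i.e. identity rescale
teacache_thresh = 0.08
rescale_func = np.poly1d(coefficients)

accumulated = 0.0
previous_e0 = torch.randn(6, 1536)

for step in range(10):
    modulated_inp = previous_e0 + 0.01 * torch.randn_like(previous_e0)
    rel_l1 = ((modulated_inp - previous_e0).abs().mean() / previous_e0.abs().mean()).item()
    accumulated += rescale_func(rel_l1)
    if accumulated < teacache_thresh:
        should_calc = False      # reuse previous_residual instead of running the blocks
    else:
        should_calc = True       # run the transformer and refresh the cached residual
        accumulated = 0.0
    previous_e0 = modulated_inp.clone()
    print(step, should_calc, round(accumulated, 4))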
lightx2v/text2v/models/networks/wan/infer/post_infer.py  0 → 100755

import math
import torch
import torch.cuda.amp as amp


class WanPostInfer:
    def __init__(self, config):
        self.out_dim = config["out_dim"]
        self.patch_size = (1, 2, 2)

    def infer(self, weights, x, e, grid_sizes):
        e = (weights.head_modulation + e.unsqueeze(1)).chunk(2, dim=1)
        norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
        out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
        x = torch.addmm(weights.head_bias, out, weights.head_weight.t())
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        x = x.unsqueeze(0)
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum("fhwpqrc->cfphqwr", u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out
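
unpatchify above rearranges the flat token sequence back into a channel-first video tensor by viewing each token as a (1, 2, 2) patch and interleaving patch and grid axes with einsum. A small standalone shape check under assumed sizes:

import math
import torch

c = 16                       # out_dim (assumed)
grid = (3, 4, 5)             # (F, H, W) measured in patches (assumed)
patch_size = (1, 2, 2)

tokens = torch.randn(math.prod(grid), math.prod(patch_size) * c)
u = tokens.view(*grid, *patch_size, c)           # (f, h, w, p, q, r, c)
u = torch.einsum("fhwpqrc->cfphqwr", u)          # channels first, patch axes interleaved with grid axes
u = u.reshape(c, *[g * p for g, p in zip(grid, patch_size)])
print(u.shape)               # torch.Size([16, 3, 8, 10])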
lightx2v/text2v/models/networks/wan/infer/pre_infer.py  0 → 100755

import torch
import math
from .utils import rope_params, sinusoidal_embedding_1d
import torch.cuda.amp as amp


class WanPreInfer:
    def __init__(self, config):
        assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
        d = config["dim"] // config["num_heads"]
        self.task = config['task']
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        ).cuda()
        self.freq_dim = config["freq_dim"]
        self.dim = config["dim"]
        self.text_len = config["text_len"]

    def infer(self, weights, x, t, context, seq_len, clip_fea=None, y=None):
        if self.task == 'i2v':
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [weights.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
        assert seq_lens.max() <= seq_len
        x = torch.cat(
            [torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x]
        )

        embed = sinusoidal_embedding_1d(self.freq_dim, t)
        embed = torch.addmm(
            weights.time_embedding_0_bias,
            embed,
            weights.time_embedding_0_weight.t(),
        )
        embed = torch.nn.functional.silu(embed)
        embed = torch.addmm(
            weights.time_embedding_2_bias,
            embed,
            weights.time_embedding_2_weight.t(),
        )
        embed0 = torch.nn.functional.silu(embed)
        embed0 = torch.addmm(
            weights.time_projection_1_bias,
            embed0,
            weights.time_projection_1_weight.t(),
        ).unflatten(1, (6, self.dim))

        # text embeddings
        stacked = torch.stack(
            [torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context]
        )
        out = torch.addmm(
            weights.text_embedding_0_bias,
            stacked.squeeze(0),
            weights.text_embedding_0_weight.t(),
        )
        out = torch.nn.functional.gelu(out, approximate="tanh")
        context = torch.addmm(
            weights.text_embedding_2_bias,
            out,
            weights.text_embedding_2_weight.t(),
        )

        if self.task == 'i2v':
            context_clip = torch.nn.functional.layer_norm(
                clip_fea,
                normalized_shape=(clip_fea.shape[1],),
                weight=weights.proj_0_weight,
                bias=weights.proj_0_bias,
                eps=1e-5,
            )
            context_clip = torch.addmm(
                weights.proj_1_bias,
                context_clip,
                weights.proj_1_weight.t(),
            )
            context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
            context_clip = torch.addmm(
                weights.proj_3_bias,
                context_clip,
                weights.proj_3_weight.t(),
            )
            context_clip = torch.nn.functional.layer_norm(
                context_clip,
                normalized_shape=(context_clip.shape[1],),
                weight=weights.proj_4_weight,
                bias=weights.proj_4_bias,
                eps=1e-5,
            )
            context = torch.concat([context_clip, context], dim=0)

        return (
            embed,
            grid_sizes,
            (x.squeeze(0), embed0.squeeze(0), seq_lens, self.freqs, context),
        )
lightx2v/text2v/models/networks/wan/infer/transformer_infer.py  0 → 100755

import torch
from .utils import compute_freqs, apply_rotary_emb, rms_norm
from lightx2v.attentions import attention


class WanTransformerInfer:
    def __init__(self, config):
        self.config = config
        self.task = config['task']
        self.attention_type = config.get("attention_type", "flash_attn2")
        self.blocks_num = config["num_layers"]
        self.num_heads = config["num_heads"]
        self.head_dim = config["dim"] // config["num_heads"]
        self.window_size = config.get("window_size", (-1, -1))

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler

    def _calculate_q_k_len(self, q, k, k_lens):
        lq, nq, c1 = q.size()
        lk, nk, c1_k = k.size()
        # Handle query and key lengths (use `q_lens` and `k_lens` or set them to Lq and Lk if None)
        q_lens = torch.tensor([lq], dtype=torch.int32, device=q.device)
        # We don't have a batch dimension anymore, so directly use the `q_lens` and `k_lens` values
        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
        return cu_seqlens_q, cu_seqlens_k, lq, lk

    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        for i in range(self.blocks_num):
            x = self.infer_block(
                weights.blocks_weights[i],
                grid_sizes,
                embed,
                x,
                embed0,
                seq_lens,
                freqs,
                context,
            )
        return x

    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        embed0 = (weights.modulation + embed0).chunk(6, dim=1)
        norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)

        s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
        q = rms_norm(weights.self_attn_q.apply(norm1_out), weights.self_attn_norm_q_weight, 1e-6).view(s, n, d)
        k = rms_norm(weights.self_attn_k.apply(norm1_out), weights.self_attn_norm_k_weight, 1e-6).view(s, n, d)
        v = weights.self_attn_v.apply(norm1_out).view(s, n, d)

        freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
        q = apply_rotary_emb(q, freqs_i)
        k = apply_rotary_emb(k, freqs_i)

        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=seq_lens)

        attn_out = attention(
            attention_type=self.attention_type,
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_kv=lk,
        )
        y = weights.self_attn_o.apply(attn_out)
        x = x + y * embed0[2].squeeze(0)

        norm3_out = torch.nn.functional.layer_norm(
            x,
            normalized_shape=(x.shape[1],),
            weight=weights.norm3_weight,
            bias=weights.norm3_bias,
            eps=1e-6,
        )
        if self.task == 'i2v':
            context_img = context[:257]
            context = context[257:]

        n, d = self.num_heads, self.head_dim
        q = rms_norm(weights.cross_attn_q.apply(norm3_out), weights.cross_attn_norm_q_weight, 1e-6).view(-1, n, d)
        k = rms_norm(weights.cross_attn_k.apply(context), weights.cross_attn_norm_k_weight, 1e-6).view(-1, n, d)
        v = weights.cross_attn_v.apply(context).view(-1, n, d)

        if self.task == 'i2v':
            k_img = rms_norm(weights.cross_attn_k_img.apply(context_img), weights.cross_attn_norm_k_img_weight, 1e-6).view(-1, n, d)
            v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
            cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
                q, k_img, k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device)
            )
            img_attn_out = attention(
                attention_type=self.attention_type,
                q=q,
                k=k_img,
                v=v_img,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_k,
                max_seqlen_q=lq,
                max_seqlen_kv=lk,
            )

        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
            q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device)
        )
        attn_out = attention(
            attention_type=self.attention_type,
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_k,
            max_seqlen_q=lq,
            max_seqlen_kv=lk,
        )
        attn_out = weights.cross_attn_o.apply(attn_out)
        x = x + attn_out

        norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
        y = torch.nn.functional.gelu(y, approximate="tanh")
        y = weights.ffn_2.apply(y)
        x = x + y * embed0[5].squeeze(0)
        return x
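
_calculate_q_k_len above packs the single unbatched sequence into the cumulative-length format that varlen attention kernels expect: a cu_seqlens vector whose consecutive differences are the per-sequence lengths. A standalone sketch of that bookkeeping, with assumed lengths:

import torch

# One query sequence of 1536 tokens and one key sequence of 512 tokens (illustrative sizes).
q_lens = torch.tensor([1536], dtype=torch.int32)
k_lens = torch.tensor([512], dtype=torch.int32)

cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)

print(cu_seqlens_q.tolist())   # [0, 1536] -> boundaries of the single packed query sequence
print(cu_seqlens_k.tolist())   # [0, 512]
max_seqlen_q, max_seqlen_kv = int(q_lens.max()), int(k_lens.max())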
lightx2v/text2v/models/networks/wan/infer/utils.py  0 → 100755

import torch
import sgl_kernel
import torch.cuda.amp as amp


def rms_norm(x, weight, eps):
    x = x.contiguous()
    orig_shape = x.shape
    x = x.view(-1, orig_shape[-1])
    x = sgl_kernel.rmsnorm(x, weight, eps).view(orig_shape)
    return x


def compute_freqs(c, grid_sizes, freqs):
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    f, h, w = grid_sizes[0].tolist()
    seq_len = f * h * w
    freqs_i = torch.cat(
        [
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    ).reshape(seq_len, 1, -1)
    return freqs_i


def apply_rotary_emb(x, freqs_i):
    n = x.size(1)
    seq_len = freqs_i.size(0)
    x_i = torch.view_as_complex(x[:seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
    # Apply rotary embedding
    x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
    x_i = torch.cat([x_i, x[seq_len:]]).to(torch.bfloat16)
    return x_i


def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
    )
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)
    # calculation
    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(torch.bfloat16)
    return x
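
The two embedding helpers above are shape-driven: rope_params returns a complex (max_seq_len, dim/2) table, and sinusoidal_embedding_1d returns a (len(position), dim) bfloat16 matrix. A quick standalone check with assumed sizes; the functions are copied locally so the snippet needs only torch, not sgl_kernel:

import torch

def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
    )
    return torch.polar(torch.ones_like(freqs), freqs)

def sinusoidal_embedding_1d(dim, position):
    assert dim % 2 == 0
    half = dim // 2
    position = position.type(torch.float64)
    sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
    return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).to(torch.bfloat16)

print(rope_params(1024, 44).shape)                                 # torch.Size([1024, 22]), complex values
print(sinusoidal_embedding_1d(256, torch.tensor([999.0])).shape)   # torch.Size([1, 256])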
lightx2v/text2v/models/networks/wan/model.py  0 → 100755

import os
import torch
import time
import glob
from lightx2v.text2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.text2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.text2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.text2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.text2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.text2v.models.networks.wan.infer.transformer_infer import (
    WanTransformerInfer,
)
from lightx2v.text2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferFeatureCaching
from safetensors import safe_open


class WanModel:
    pre_weight_class = WanPreWeights
    post_weight_class = WanPostWeights
    transformer_weight_class = WanTransformerWeights

    def __init__(self, model_path, config):
        self.model_path = model_path
        self.config = config
        self._init_infer_class()
        self._init_weights()
        self._init_infer()

    def _init_infer_class(self):
        self.pre_infer_class = WanPreInfer
        self.post_infer_class = WanPostInfer
        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WanTransformerInfer
        elif self.config["feature_caching"] == "Tea":
            self.transformer_infer_class = WanTransformerInferFeatureCaching
        else:
            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

    def _load_safetensor_to_dict(self, file_path):
        with safe_open(file_path, framework="pt") as f:
            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
        return tensor_dict

    def _load_ckpt(self):
        safetensors_pattern = os.path.join(self.model_path, "*.safetensors")
        safetensors_files = glob.glob(safetensors_pattern)
        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors files found in directory: {self.model_path}")
        weight_dict = {}
        for file_path in safetensors_files:
            file_weights = self._load_safetensor_to_dict(file_path)
            weight_dict.update(file_weights)
        return weight_dict

    def _init_weights(self):
        weight_dict = self._load_ckpt()
        # init weights
        self.pre_weight = self.pre_weight_class(self.config)
        self.post_weight = self.post_weight_class()
        self.transformer_weights = self.transformer_weight_class(self.config)
        # load weights
        self.pre_weight.load_weights(weight_dict)
        self.post_weight.load_weights(weight_dict)
        self.transformer_weights.load_weights(weight_dict)

    def _init_infer(self):
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        self.transformer_infer = self.transformer_infer_class(self.config)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.transformer_infer.set_scheduler(scheduler)

    @torch.no_grad()
    def infer(self, text_encoders_output, image_encoder_output, args):
        timestep = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
            self.pre_weight,
            [self.scheduler.latents],
            timestep,
            text_encoders_output["context"],
            self.scheduler.seq_len,
            image_encoder_output["clip_encoder_out"],
            [image_encoder_output["vae_encode_out"]],
        )
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

        embed, grid_sizes, pre_infer_out = self.pre_infer.infer(
            self.pre_weight,
            [self.scheduler.latents],
            timestep,
            text_encoders_output["context_null"],
            self.scheduler.seq_len,
            image_encoder_output["clip_encoder_out"],
            [image_encoder_output["vae_encode_out"]],
        )
        x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
        noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

        if self.config["feature_caching"] == "Tea":
            self.scheduler.cnt += 1
            if self.scheduler.cnt >= self.scheduler.num_steps:
                self.scheduler.cnt = 0

        self.scheduler.noise_pred = noise_pred_uncond + args.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
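
The last line of infer() is standard classifier-free guidance: the unconditional prediction is pushed toward the conditional one by sample_guide_scale. A tiny numeric illustration with toy tensors (not model outputs):

import torch

noise_pred_cond = torch.tensor([1.0, 2.0, 3.0])
noise_pred_uncond = torch.tensor([0.5, 1.0, 1.5])
sample_guide_scale = 5.0   # illustrative guidance scale

noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
print(noise_pred)          # tensor([3., 6., 9.])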
lightx2v/text2v/models/networks/wan/weights/post_weights.py  0 → 100755

class WanPostWeights:
    def __init__(self):
        pass

    def load_weights(self, weight_dict):
        head_layers = {"head": ["head.weight", "head.bias", "modulation"]}
        for param_name, param_keys in head_layers.items():
            for key in param_keys:
                weight_path = f"{param_name}.{key}"
                key = key.split('.')
                setattr(self, f"{param_name}_{key[-1]}", weight_dict[weight_path])
lightx2v/text2v/models/networks/wan/weights/pre_weights.py  0 → 100755

import torch


class WanPreWeights:
    def __init__(self, config):
        self.in_dim = config["in_dim"]
        self.dim = config["dim"]
        self.patch_size = (1, 2, 2)

    def load_weights(self, weight_dict):
        layers = {
            "text_embedding": {"0": ["weight", "bias"], "2": ["weight", "bias"]},
            "time_embedding": {"0": ["weight", "bias"], "2": ["weight", "bias"]},
            "time_projection": {"1": ["weight", "bias"]},
        }
        self.patch_embedding = (
            torch.nn.Conv3d(
                self.in_dim,
                self.dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
            )
            .to(torch.bfloat16)
            .cuda()
        )
        self.patch_embedding.weight.data.copy_(weight_dict["patch_embedding.weight"])
        self.patch_embedding.bias.data.copy_(weight_dict["patch_embedding.bias"])

        for module_name, sub_layers in layers.items():
            for param_name, param_keys in sub_layers.items():
                for key in param_keys:
                    weight_path = f"{module_name}.{param_name}.{key}"
                    setattr(
                        self,
                        f"{module_name}_{param_name}_{key}",
                        weight_dict[weight_path],
                    )

        if 'img_emb.proj.0.weight' in weight_dict.keys():
            MLP_layers = {
                "proj_0_weight": "proj.0.weight",
                "proj_0_bias": "proj.0.bias",
                "proj_1_weight": "proj.1.weight",
                "proj_1_bias": "proj.1.bias",
                "proj_3_weight": "proj.3.weight",
                "proj_3_bias": "proj.3.bias",
                "proj_4_weight": "proj.4.weight",
                "proj_4_bias": "proj.4.bias",
            }
            for layer_name, weight_keys in MLP_layers.items():
                weight_path = f"img_emb.{weight_keys}"
                setattr(self, layer_name, weight_dict[weight_path])
lightx2v/text2v/models/networks/wan/weights/transformer_weights.py  0 → 100755

from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate


class WanTransformerWeights:
    def __init__(self, config):
        self.blocks_num = config["num_layers"]
        self.task = config['task']
        if config['do_mm_calib']:
            self.mm_type = 'Calib'
        else:
            self.mm_type = config['mm_config'].get('mm_type', 'Default') if config['mm_config'] else 'Default'

    def load_weights(self, weight_dict):
        self.blocks_weights = [
            WanTransformerAttentionBlock(i, self.task, self.mm_type) for i in range(self.blocks_num)
        ]
        for block in self.blocks_weights:
            block.load_weights(weight_dict)


class WanTransformerAttentionBlock:
    def __init__(self, block_index, task, mm_type):
        self.block_index = block_index
        self.mm_type = mm_type
        self.task = task

    def load_weights(self, weight_dict):
        if self.task == 't2v':
            layers = {
                "self_attn_q": ["self_attn.q.weight", "self_attn.q.bias"],
                "self_attn_k": ["self_attn.k.weight", "self_attn.k.bias"],
                "self_attn_v": ["self_attn.v.weight", "self_attn.v.bias"],
                "self_attn_o": ["self_attn.o.weight", "self_attn.o.bias"],
                "self_attn_norm_q_weight": "self_attn.norm_q.weight",
                "self_attn_norm_k_weight": "self_attn.norm_k.weight",
                "norm3_weight": "norm3.weight",
                "norm3_bias": "norm3.bias",
                "cross_attn_q": ["cross_attn.q.weight", "cross_attn.q.bias"],
                "cross_attn_k": ["cross_attn.k.weight", "cross_attn.k.bias"],
                "cross_attn_v": ["cross_attn.v.weight", "cross_attn.v.bias"],
                "cross_attn_o": ["cross_attn.o.weight", "cross_attn.o.bias"],
                "cross_attn_norm_q_weight": "cross_attn.norm_q.weight",
                "cross_attn_norm_k_weight": "cross_attn.norm_k.weight",
                "ffn_0": ["ffn.0.weight", "ffn.0.bias"],
                "ffn_2": ["ffn.2.weight", "ffn.2.bias"],
                "modulation": "modulation",
            }
        elif self.task == 'i2v':
            layers = {
                "self_attn_q": ["self_attn.q.weight", "self_attn.q.bias"],
                "self_attn_k": ["self_attn.k.weight", "self_attn.k.bias"],
                "self_attn_v": ["self_attn.v.weight", "self_attn.v.bias"],
                "self_attn_o": ["self_attn.o.weight", "self_attn.o.bias"],
                "self_attn_norm_q_weight": "self_attn.norm_q.weight",
                "self_attn_norm_k_weight": "self_attn.norm_k.weight",
                "norm3_weight": "norm3.weight",
                "norm3_bias": "norm3.bias",
                "cross_attn_q": ["cross_attn.q.weight", "cross_attn.q.bias"],
                "cross_attn_k": ["cross_attn.k.weight", "cross_attn.k.bias"],
                "cross_attn_v": ["cross_attn.v.weight", "cross_attn.v.bias"],
                "cross_attn_o": ["cross_attn.o.weight", "cross_attn.o.bias"],
                "cross_attn_norm_q_weight": "cross_attn.norm_q.weight",
                "cross_attn_norm_k_weight": "cross_attn.norm_k.weight",
                "cross_attn_k_img": ["cross_attn.k_img.weight", "cross_attn.k_img.bias"],
                "cross_attn_v_img": ["cross_attn.v_img.weight", "cross_attn.v_img.bias"],
                "cross_attn_norm_k_img_weight": "cross_attn.norm_k_img.weight",
                "ffn_0": ["ffn.0.weight", "ffn.0.bias"],
                "ffn_2": ["ffn.2.weight", "ffn.2.bias"],
                "modulation": "modulation",
            }
        for layer_name, weight_keys in layers.items():
            if isinstance(weight_keys, list):
                weight_key, bias_key = weight_keys
                weight_path = f"blocks.{self.block_index}.{weight_key}"
                bias_path = f"blocks.{self.block_index}.{bias_key}"
                setattr(self, layer_name, MM_WEIGHT_REGISTER[self.mm_type](weight_path, bias_path))
                getattr(self, layer_name).load(weight_dict)
            else:
                weight_path = f"blocks.{self.block_index}.{weight_keys}"
                setattr(self, layer_name, weight_dict[weight_path])
lightx2v/text2v/models/schedulers/__init__.py  0 → 100755  (empty file)
lightx2v/text2v/models/schedulers/hunyuan/feature_caching/scheduler.py  0 → 100755

import torch
from ..scheduler import HunyuanScheduler


def cache_init(num_steps, model_kwargs=None):
    '''
    Initialization for cache.
    '''
    cache_dic = {}
    cache = {}
    cache_index = {}
    cache[-1] = {}
    cache_index[-1] = {}
    cache_index['layer_index'] = {}
    cache_dic['attn_map'] = {}
    cache_dic['attn_map'][-1] = {}
    cache_dic['attn_map'][-1]['double_stream'] = {}
    cache_dic['attn_map'][-1]['single_stream'] = {}
    cache_dic['k-norm'] = {}
    cache_dic['k-norm'][-1] = {}
    cache_dic['k-norm'][-1]['double_stream'] = {}
    cache_dic['k-norm'][-1]['single_stream'] = {}
    cache_dic['v-norm'] = {}
    cache_dic['v-norm'][-1] = {}
    cache_dic['v-norm'][-1]['double_stream'] = {}
    cache_dic['v-norm'][-1]['single_stream'] = {}
    cache_dic['cross_attn_map'] = {}
    cache_dic['cross_attn_map'][-1] = {}
    cache[-1]['double_stream'] = {}
    cache[-1]['single_stream'] = {}
    cache_dic['cache_counter'] = 0

    for j in range(20):
        cache[-1]['double_stream'][j] = {}
        cache_index[-1][j] = {}
        cache_dic['attn_map'][-1]['double_stream'][j] = {}
        cache_dic['attn_map'][-1]['double_stream'][j]['total'] = {}
        cache_dic['attn_map'][-1]['double_stream'][j]['txt_mlp'] = {}
        cache_dic['attn_map'][-1]['double_stream'][j]['img_mlp'] = {}
        cache_dic['k-norm'][-1]['double_stream'][j] = {}
        cache_dic['k-norm'][-1]['double_stream'][j]['txt_mlp'] = {}
        cache_dic['k-norm'][-1]['double_stream'][j]['img_mlp'] = {}
        cache_dic['v-norm'][-1]['double_stream'][j] = {}
        cache_dic['v-norm'][-1]['double_stream'][j]['txt_mlp'] = {}
        cache_dic['v-norm'][-1]['double_stream'][j]['img_mlp'] = {}

    for j in range(40):
        cache[-1]['single_stream'][j] = {}
        cache_index[-1][j] = {}
        cache_dic['attn_map'][-1]['single_stream'][j] = {}
        cache_dic['attn_map'][-1]['single_stream'][j]['total'] = {}
        cache_dic['k-norm'][-1]['single_stream'][j] = {}
        cache_dic['k-norm'][-1]['single_stream'][j]['total'] = {}
        cache_dic['v-norm'][-1]['single_stream'][j] = {}
        cache_dic['v-norm'][-1]['single_stream'][j]['total'] = {}

    cache_dic['taylor_cache'] = False
    cache_dic['duca'] = False
    cache_dic['test_FLOPs'] = False

    mode = 'Taylor'
    if mode == 'original':
        cache_dic['cache_type'] = 'random'
        cache_dic['cache_index'] = cache_index
        cache_dic['cache'] = cache
        cache_dic['fresh_ratio_schedule'] = 'ToCa'
        cache_dic['fresh_ratio'] = 0.0
        cache_dic['fresh_threshold'] = 1
        cache_dic['force_fresh'] = 'global'
        cache_dic['soft_fresh_weight'] = 0.0
        cache_dic['max_order'] = 0
        cache_dic['first_enhance'] = 1
    elif mode == 'ToCa':
        cache_dic['cache_type'] = 'random'
        cache_dic['cache_index'] = cache_index
        cache_dic['cache'] = cache
        cache_dic['fresh_ratio_schedule'] = 'ToCa'
        cache_dic['fresh_ratio'] = 0.10
        cache_dic['fresh_threshold'] = 5
        cache_dic['force_fresh'] = 'global'
        cache_dic['soft_fresh_weight'] = 0.0
        cache_dic['max_order'] = 0
        cache_dic['first_enhance'] = 1
        cache_dic['duca'] = False
    elif mode == 'DuCa':
        cache_dic['cache_type'] = 'random'
        cache_dic['cache_index'] = cache_index
        cache_dic['cache'] = cache
        cache_dic['fresh_ratio_schedule'] = 'ToCa'
        cache_dic['fresh_ratio'] = 0.10
        cache_dic['fresh_threshold'] = 5
        cache_dic['force_fresh'] = 'global'
        cache_dic['soft_fresh_weight'] = 0.0
        cache_dic['max_order'] = 0
        cache_dic['first_enhance'] = 1
        cache_dic['duca'] = True
    elif mode == 'Taylor':
        cache_dic['cache_type'] = 'random'
        cache_dic['cache_index'] = cache_index
        cache_dic['cache'] = cache
        cache_dic['fresh_ratio_schedule'] = 'ToCa'
        cache_dic['fresh_ratio'] = 0.0
        cache_dic['fresh_threshold'] = 5
        cache_dic['max_order'] = 1
        cache_dic['force_fresh'] = 'global'
        cache_dic['soft_fresh_weight'] = 0.0
        cache_dic['taylor_cache'] = True
        cache_dic['first_enhance'] = 1

    current = {}
    current['num_steps'] = num_steps
    current['activated_steps'] = [0]

    return cache_dic, current


def force_scheduler(cache_dic, current):
    if cache_dic['fresh_ratio'] == 0:
        # FORA
        linear_step_weight = 0.0
    else:
        # TokenCache
        linear_step_weight = 0.0
    step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current['step'] / current['num_steps'])
    threshold = torch.round(cache_dic['fresh_threshold'] / step_factor)
    # no force constrain for sensitive steps, cause the performance is good enough.
    # you may have a try.
    cache_dic['cal_threshold'] = threshold
    # return threshold


def cal_type(cache_dic, current):
    '''
    Determine calculation type for this step
    '''
    if (cache_dic['fresh_ratio'] == 0.0) and (not cache_dic['taylor_cache']):
        # FORA: Uniform
        first_step = (current['step'] == 0)
    else:
        # ToCa: First enhanced
        first_step = (current['step'] < cache_dic['first_enhance'])
        # first_step = (current['step'] <= 3)

    force_fresh = cache_dic['force_fresh']
    if not first_step:
        fresh_interval = cache_dic['cal_threshold']
    else:
        fresh_interval = cache_dic['fresh_threshold']

    if (first_step) or (cache_dic['cache_counter'] == fresh_interval - 1):
        current['type'] = 'full'
        cache_dic['cache_counter'] = 0
        current['activated_steps'].append(current['step'])
        # current['activated_times'].append(current['t'])
        force_scheduler(cache_dic, current)
    elif (cache_dic['taylor_cache']):
        cache_dic['cache_counter'] += 1
        current['type'] = 'taylor_cache'
    else:
        cache_dic['cache_counter'] += 1
        if (cache_dic['duca']):
            if (cache_dic['cache_counter'] % 2 == 1):
                # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
                current['type'] = 'ToCa'
                # 'cache_noise' 'ToCa' 'FORA'
            else:
                current['type'] = 'aggressive'
        else:
            current['type'] = 'ToCa'
        # if current['step'] < 25:
        #     current['type'] = 'FORA'
        # else:
        #     current['type'] = 'aggressive'
    ######################################################################
    # if (current['step'] in [3,2,1,0]):
    #     current['type'] = 'full'


class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
    def __init__(self, args):
        super().__init__(args)
        self.cache_dic, self.current = cache_init(self.infer_steps)

    def step_pre(self, step_index):
        super().step_pre(step_index)
        self.current['step'] = step_index
        cal_type(self.cache_dic, self.current)
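
cache_init() hard-codes mode = 'Taylor' (fresh_threshold=5, taylor_cache=True, first_enhance=1), so cal_type() marks step 0 and then every fifth step as 'full' and the steps in between as 'taylor_cache'. A small replay of that cadence using the two functions above; it assumes the lightx2v package and its dependencies are importable, and num_steps=20 is an arbitrary choice:

# Replays the step-type schedule produced by cache_init()/cal_type() above.
from lightx2v.text2v.models.schedulers.hunyuan.feature_caching.scheduler import cache_init, cal_type

cache_dic, current = cache_init(num_steps=20)
for step in range(20):
    current["step"] = step
    cal_type(cache_dic, current)
    print(step, current["type"])
# Expected: steps 0, 5, 10, 15 -> 'full'; all other steps -> 'taylor_cache'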
lightx2v/text2v/models/schedulers/hunyuan/scheduler.py
0 → 100755
View file @
daf4c74e
import
torch
from
diffusers.utils.torch_utils
import
randn_tensor
from
typing
import
Union
,
Tuple
,
List
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
,
Tuple
from
lightx2v.text2v.models.schedulers.scheduler
import
BaseScheduler
def
_to_tuple
(
x
,
dim
=
2
):
if
isinstance
(
x
,
int
):
return
(
x
,)
*
dim
elif
len
(
x
)
==
dim
:
return
x
else
:
raise
ValueError
(
f
"Expected length
{
dim
}
or int, but got
{
x
}
"
)
def
get_1d_rotary_pos_embed
(
dim
:
int
,
pos
:
Union
[
torch
.
FloatTensor
,
int
],
theta
:
float
=
10000.0
,
use_real
:
bool
=
False
,
theta_rescale_factor
:
float
=
1.0
,
interpolation_factor
:
float
=
1.0
,
)
->
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if
isinstance
(
pos
,
int
):
pos
=
torch
.
arange
(
pos
).
float
()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if
theta_rescale_factor
!=
1.0
:
theta
*=
theta_rescale_factor
**
(
dim
/
(
dim
-
2
))
freqs
=
1.0
/
(
theta
**
(
torch
.
arange
(
0
,
dim
,
2
)[:
(
dim
//
2
)].
float
()
/
dim
)
)
# [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs
=
torch
.
outer
(
pos
*
interpolation_factor
,
freqs
)
# [S, D/2]
if
use_real
:
freqs_cos
=
freqs
.
cos
().
repeat_interleave
(
2
,
dim
=
1
)
# [S, D]
freqs_sin
=
freqs
.
sin
().
repeat_interleave
(
2
,
dim
=
1
)
# [S, D]
return
freqs_cos
,
freqs_sin
else
:
freqs_cis
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
)
# complex64 # [S, D/2]
return
freqs_cis
def get_meshgrid_nd(start, *args, dim=2):
    """
    Get an n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If an n-tuple is provided, the meshgrid will be stacked following the dim order in
            the n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (torch.Tensor): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)  # Left-Top  eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom  eg: 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size  eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implementation of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid
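A quick sketch of the meshgrid helper (the sizes are illustrative): passing a single 3-tuple of grid sizes yields a stacked coordinate tensor of shape [3, t, h, w], with grid[0], grid[1], grid[2] holding the per-axis coordinates.

# Hypothetical usage sketch of get_meshgrid_nd.
grid = get_meshgrid_nd((4, 8, 8), dim=3)
print(grid.shape)  # torch.Size([3, 4, 8, 8])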
def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """
    This is an n-d version of precompute_freqs_cis, i.e. RoPE for tokens with an n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return the real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide the real
            part and the imaginary part separately.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2]
    """
    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]

    if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb
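Again as a sketch with illustrative sizes: for rope_dim_list=[16, 56, 56] (summing to the 128-dim attention head used by the scheduler below) and a latent grid of (2, 4, 4), the real-valued variant returns cos/sin tensors of shape [2*4*4, 128].

# Hypothetical usage sketch of get_nd_rotary_pos_embed.
cos, sin = get_nd_rotary_pos_embed([16, 56, 56], (2, 4, 4), use_real=True)
print(cos.shape, sin.shape)  # torch.Size([32, 128]) torch.Size([32, 128])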
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
    sigmas = torch.linspace(1, 0, num_inference_steps + 1)
    sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
    timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.bfloat16, device=device)
    return timesteps, sigmas
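set_timesteps_sigmas maps num_inference_steps + 1 uniformly spaced sigmas from 1 down to 0 through the shift transform sigma' = shift * sigma / (1 + (shift - 1) * sigma), which pushes the schedule toward the noisy end for shift > 1 (with shift = 7, the midpoint sigma = 0.5 becomes 0.875); the first num_inference_steps sigmas, scaled by num_train_timesteps, become the timesteps. A minimal CPU check (a sketch, not part of the commit):

# Hypothetical sanity check of set_timesteps_sigmas on CPU (the scheduler itself targets CUDA).
timesteps, sigmas = set_timesteps_sigmas(num_inference_steps=4, shift=7.0, device=torch.device("cpu"))
print(sigmas)           # tensor([1.0000, 0.9545, 0.8750, 0.7000, 0.0000])
print(timesteps.shape)  # torch.Size([4])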
class HunyuanScheduler(BaseScheduler):
    def __init__(self, args):
        super().__init__(args)
        self.infer_steps = self.args.infer_steps
        self.shift = 7.0
        self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device('cuda'))
        assert len(self.timesteps) == self.infer_steps
        self.embedded_guidance_scale = 6.0
        self.generator = [torch.Generator('cuda').manual_seed(seed) for seed in [42]]
        self.noise_pred = None
        self.prepare_latents(shape=self.args.target_shape, dtype=torch.bfloat16)
        self.prepare_guidance()
        self.prepare_rotary_pos_embedding(video_length=self.args.target_video_length, height=self.args.target_height, width=self.args.target_width)

    def prepare_guidance(self):
        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device('cuda')) * 1000.0

    def step_post(self):
        sample = self.latents.to(torch.float32)
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
        prev_sample = sample + self.noise_pred.to(torch.float32) * dt
        self.latents = prev_sample

    def prepare_latents(self, shape, dtype):
        self.latents = randn_tensor(shape, generator=self.generator, device=torch.device('cuda'), dtype=dtype)

    def prepare_rotary_pos_embedding(self, video_length, height, width):
        target_ndim = 3
        ndim = 5 - 2
        # 884
        vae = "884-16c-hy"
        patch_size = [1, 2, 2]
        hidden_size = 3072
        heads_num = 24
        rope_theta = 256
        rope_dim_list = [16, 56, 56]

        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(patch_size, int):
            assert all(s % patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // patch_size for s in latents_size]
        elif isinstance(patch_size, list):
            assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis

        head_dim = hidden_size // heads_num
        rope_dim_list = rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
        self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
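Pieced together from HunyuanScheduler and the BaseScheduler in the next file, a denoising loop would presumably drive the scheduler roughly as sketched below; model_infer is a placeholder for the network call, not an API from this commit.

# Hypothetical driver loop: step_pre selects the step index and casts latents to bfloat16,
# the network fills in noise_pred, and step_post applies the Euler update
# latents += noise_pred * (sigmas[i + 1] - sigmas[i]).
for i in range(scheduler.infer_steps):
    scheduler.step_pre(step_index=i)
    scheduler.noise_pred = model_infer(scheduler.latents, scheduler.timesteps[i])  # placeholder
    scheduler.step_post()
final_latents = scheduler.latents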
lightx2v/text2v/models/schedulers/scheduler.py
0 → 100755
View file @ daf4c74e
import torch


class BaseScheduler():
    def __init__(self, args):
        self.args = args
        self.step_index = 0
        self.latents = None

    def step_pre(self, step_index):
        self.step_index = step_index
        self.latents = self.latents.to(dtype=torch.bfloat16)
lightx2v/text2v/models/schedulers/wan/feature_caching/scheduler.py
0 → 100755
View file @ daf4c74e
import torch
from ..scheduler import WanScheduler


class WanSchedulerFeatureCaching(WanScheduler):
    def __init__(self, args):
        super().__init__(args)
        self.cnt = 0
        self.num_steps = self.args.infer_steps * 2
        self.teacache_thresh = self.args.teacache_thresh
        self.accumulated_rel_l1_distance_even = 0
        self.accumulated_rel_l1_distance_odd = 0
        self.previous_e0_even = None
        self.previous_e0_odd = None
        self.previous_residual_even = None
        self.previous_residual_odd = None
        self.use_ret_steps = self.args.use_ret_steps

        if self.use_ret_steps:
            if self.args.target_width == 480 or self.args.target_height == 480:
                self.coefficients = [
                    2.57151496e05,
                    -3.54229917e04,
                    1.40286849e03,
                    -1.35890334e01,
                    1.32517977e-01,
                ]
            if self.args.target_width == 720 or self.args.target_height == 720:
                self.coefficients = [
                    8.10705460e03,
                    2.13393892e03,
                    -3.72934672e02,
                    1.66203073e01,
                    -4.17769401e-02,
                ]
            self.ret_steps = 5 * 2
            self.cutoff_steps = self.args.infer_steps * 2
        else:
            if self.args.target_width == 480 or self.args.target_height == 480:
                self.coefficients = [
                    -3.02331670e02,
                    2.23948934e02,
                    -5.25463970e01,
                    5.87348440e00,
                    -2.01973289e-01,
                ]
            if self.args.target_width == 720 or self.args.target_height == 720:
                self.coefficients = [
                    -114.36346466,
                    65.26524496,
                    -18.82220707,
                    4.91518089,
                    -0.23412683,
                ]
            self.ret_steps = 1 * 2
            self.cutoff_steps = self.args.infer_steps * 2 - 2
\ No newline at end of file
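The even/odd fields above cache the previous modulated input (previous_e0_*) and the previous block-stack residual (previous_residual_*) for the two passes of each step, while coefficients is a degree-4 polynomial chosen per resolution. Presumably, as in the public TeaCache recipe, the polynomial rescales the relative L1 change of the modulated input and a step is skipped (reusing the cached residual) while the accumulated value stays below teacache_thresh and the step count lies between ret_steps and cutoff_steps. A rough sketch under that assumption (names and structure are illustrative, not from this commit):

import numpy as np

# Hypothetical skip test for the "even" branch of a TeaCache-style scheduler.
def should_skip_even(sched, e0):
    if sched.cnt < sched.ret_steps or sched.cnt >= sched.cutoff_steps:
        return False  # always compute near the start/end of sampling
    rel_l1 = ((e0 - sched.previous_e0_even).abs().mean() / sched.previous_e0_even.abs().mean()).item()
    sched.accumulated_rel_l1_distance_even += np.polyval(sched.coefficients, rel_l1)
    if sched.accumulated_rel_l1_distance_even < sched.teacache_thresh:
        return True   # reuse sched.previous_residual_even instead of running the blocks
    sched.accumulated_rel_l1_distance_even = 0
    return False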