xuwx1 / LightX2V · Commits

Commit 5b56dc56
Authored May 09, 2025 by Dongz; committed by GitHub on May 09, 2025

[major]: deprecated attention functions (#35)

Parent: ad0237f9

Showing 9 changed files with 37 additions and 25 deletions (+37 -25).
lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py   +2 -5
lightx2v/models/networks/hunyuan/infer/pre_infer.py                           +2 -3
lightx2v/models/networks/hunyuan/infer/transformer_infer.py                   +2 -5
lightx2v/models/networks/hunyuan/weights/pre_weights.py                       +4 -1
lightx2v/models/networks/hunyuan/weights/transformer_weights.py               +13 -1
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py               +10 -5
lightx2v/models/networks/wan/infer/transformer_infer.py                       +0 -1
lightx2v/models/networks/wan/weights/transformer_weights.py                   +4 -3
lightx2v/utils/set_config.py                                                  +0 -1
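For orientation (this note is not part of the diff): the commit retires the free function `attention(attention_type=..., ...)` in favor of attention weight modules fetched once from `ATTN_WEIGHT_REGISTER` at weight-construction time and invoked via `.apply(...)` at inference time. A minimal sketch of that pattern follows; all names and shapes here are illustrative stand-ins, not LightX2V's actual implementation (the real registry lives in `lightx2v.utils.registry_factory`):

```python
import torch
import torch.nn.functional as F

# Illustrative stand-in for lightx2v.utils.registry_factory.ATTN_WEIGHT_REGISTER
ATTN_WEIGHT_REGISTER = {}

def register_attn(name):
    def deco(cls):
        ATTN_WEIGHT_REGISTER[name] = cls
        return cls
    return deco

@register_attn("torch_sdpa")
class TorchSDPAWeight:
    """Stateless backend: selected once when weights are built, not per call."""

    def apply(self, q, k, v, attn_mask=None, **kwargs):
        # q, k, v assumed (B, L, H, D); SDPA expects (B, H, L, D).
        q, k, v = (x.transpose(1, 2) for x in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return out.transpose(1, 2)

# Old call sites: attention(attention_type="torch_sdpa", q=q, k=k, v=v)
# New call sites hold a module instance and just call .apply(...):
attn_module = ATTN_WEIGHT_REGISTER["torch_sdpa"]()
q = k = v = torch.randn(1, 16, 8, 64)
out = attn_module.apply(q=q, k=k, v=v)
print(out.shape)  # torch.Size([1, 16, 8, 64])
```

The design win is that the backend string is resolved once per block instead of being re-dispatched on every forward pass, and backends with state (such as sparge) get a natural place to load it.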
lightx2v/models/networks/hunyuan/infer/feature_caching/transformer_infer.py

```diff
 import torch
 import numpy as np
 from einops import rearrange
-from lightx2v.attentions import attention
 from .utils import taylor_cache_init, derivative_approximation, taylor_formula
 from ..utils_bf16 import apply_rotary_emb
 from ..transformer_infer import HunyuanTransformerInfer
@@ -118,8 +117,7 @@ class HunyuanTransformerInferTaylorCaching(HunyuanTransformerInfer):
         v = torch.cat((img_v, txt_v), dim=0)

         if not self.parallel_attention:
-            attn = attention(
-                attention_type=self.attention_type,
+            attn = weights.double_attn.apply(
                 q=q,
                 k=k,
                 v=v,
@@ -284,8 +282,7 @@ class HunyuanTransformerInferTaylorCaching(HunyuanTransformerInfer):
         k = torch.cat((img_k, txt_k), dim=0)

         if not self.parallel_attention:
-            attn = attention(
-                attention_type=self.attention_type,
+            attn = weights.single_attn.apply(
                 q=q,
                 k=k,
                 v=v,
```
lightx2v/models/networks/hunyuan/infer/pre_infer.py

```diff
 import torch
 import math
 from einops import rearrange
-from lightx2v.attentions import attention


 class HunyuanPreInfer:
@@ -107,7 +106,7 @@ class HunyuanPreInfer:
         normx = weights.txt_in_individual_token_refiner_blocks_0_norm1.apply(txt_in_input_embed)
         qkv = weights.txt_in_individual_token_refiner_blocks_0_self_attn_qkv.apply(normx)
         q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        attn = attention(attention_type="torch_sdpa", q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
+        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
         out = weights.txt_in_individual_token_refiner_blocks_0_self_attn_proj.apply(attn)
         out_1 = txt_in_input_embed + out * gate_msa
         out = weights.txt_in_individual_token_refiner_blocks_0_norm2.apply(out_1)
@@ -126,7 +125,7 @@ class HunyuanPreInfer:
         q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
-        attn = attention(attention_type="torch_sdpa", q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
+        attn = weights.txt_in_attn_1.apply(q=q, k=k, v=v, attn_mask=self_attn_mask)[0]
         out = weights.txt_in_individual_token_refiner_blocks_1_self_attn_proj.apply(attn)
         out_1 = txt_in_input_embed + out * gate_msa
```
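A side note on the surrounding code: both token-refiner blocks now route through the shared `txt_in_attn_1` module instead of calling `attention(...)` with a hard-coded `"torch_sdpa"` string. The fused-QKV split that feeds it can be illustrated in isolation; the shapes below are assumptions chosen for the demo, not values from the repository:

```python
import torch
from einops import rearrange

# Assumed demo dimensions: 8 heads of size 64, a 32-token sequence.
heads_num, head_dim, seq_len = 8, 64, 32
qkv = torch.randn(seq_len, 3 * heads_num * head_dim)  # one linear output of size K*H*D

# Same pattern as the diff: split the fused projection into q, k, v.
q, k, v = rearrange(qkv.unsqueeze(0), "B L (K H D) -> K B L H D", K=3, H=heads_num)
print(q.shape)  # torch.Size([1, 32, 8, 64])
```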
lightx2v/models/networks/hunyuan/infer/transformer_infer.py

```diff
 import torch
 from einops import rearrange
-from lightx2v.attentions import attention
 from .utils_bf16 import apply_rotary_emb
 from lightx2v.common.offload.manager import WeightStreamManager
 from lightx2v.utils.envs import *
@@ -120,8 +119,7 @@ class HunyuanTransformerInfer:
         v = torch.cat((img_v, txt_v), dim=0)

         if not self.parallel_attention:
-            attn = attention(
-                attention_type=self.attention_type,
+            attn = weights.double_attn.apply(
                 q=q,
                 k=k,
                 v=v,
@@ -263,8 +261,7 @@ class HunyuanTransformerInfer:
         k = torch.cat((img_k, txt_k), dim=0)

         if not self.parallel_attention:
-            attn = attention(
-                attention_type=self.attention_type,
+            attn = weights.single_attn.apply(
                 q=q,
                 k=k,
                 v=v,
```
lightx2v/models/networks/hunyuan/weights/pre_weights.py

```diff
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
 from lightx2v.common.modules.weight_module import WeightModule
@@ -79,3 +79,6 @@ class HunyuanPreWeights(WeightModule):
         self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
         self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
         self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))
+
+        # attention weights section
+        self.add_module("txt_in_attn_1", ATTN_WEIGHT_REGISTER["torch_sdpa"]())
```
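For readers unfamiliar with the weights layer: `add_module` presumably follows the usual named-submodule convention, so inference code can reach `weights.txt_in_attn_1.apply(...)` by attribute. A minimal sketch of that contract; this is an assumed API, not the actual `WeightModule` from `lightx2v.common.modules.weight_module`:

```python
class WeightModule:
    """Minimal stand-in: registers named submodules as attributes."""

    def __init__(self):
        self._modules = {}

    def add_module(self, name, module):
        self._modules[name] = module
        setattr(self, name, module)  # enables weights.<name>.apply(...) lookups

class DummyAttn:
    def apply(self, **kwargs):
        return kwargs

weights = WeightModule()
weights.add_module("txt_in_attn_1", DummyAttn())
assert weights.txt_in_attn_1.apply(q=1)["q"] == 1
```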
lightx2v/models/networks/hunyuan/weights/transformer_weights.py

```diff
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
@@ -40,12 +40,16 @@ class HunyuanTransformerDoubleBlock(WeightModule):
         self.add_module("txt_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias"))
         self.add_module("txt_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias"))
+
+        # attention weights section
+        self.add_module("double_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())


 class HunyuanTransformerSingleBlock(WeightModule):
     def __init__(self, block_index, config):
         super().__init__()
         self.block_index = block_index
         self.config = config
+        self.sparge = config.get("sparge", False)

         if self.config["do_mm_calib"]:
             mm_type = "Calib"
@@ -57,3 +61,11 @@ class HunyuanTransformerSingleBlock(WeightModule):
         self.add_module("q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6))
         self.add_module("k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6))
         self.add_module("modulation", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias"))
+
+        # attention weights section
+        if self.sparge:
+            # load sparge attention weights
+            #! todo
+            pass
+        else:
+            self.add_module("single_attn", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
```
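Review note: in `HunyuanTransformerSingleBlock` the sparge branch is a stub, so with `sparge: true` no `single_attn` module is registered and single-block inference would presumably fail until the TODO is filled in. A hedged sketch of the selection logic as written, with a stand-in registry for illustration only:

```python
ATTN_WEIGHT_REGISTER = {"torch_sdpa": dict}  # stand-in; the real registry is keyed by backend name

config = {"attention_type": "torch_sdpa"}
sparge = config.get("sparge", False)  # new pattern: read with a default, not config["sparge"]

if sparge:
    # load sparge attention weights (#! todo in the diff: nothing registered yet)
    pass
else:
    single_attn = ATTN_WEIGHT_REGISTER[config["attention_type"]]()
```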
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py

```diff
 import torch
 import math
 from ..utils import compute_freqs, compute_freqs_causvid, compute_freqs_dist, apply_rotary_emb
-from lightx2v.attentions import attention
 from lightx2v.common.offload.manager import WeightStreamManager
 from lightx2v.utils.envs import *
 from ..transformer_infer import WanTransformerInfer
@@ -125,8 +124,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q=q, k=self.kv_cache[block_idx]["k"][:kv_end], k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device))

         if not self.parallel_attention:
-            attn_out = attention(
-                attention_type=self.attention_type,
+            attn_out = weights.self_attn_1.apply(
                 q=q,
                 k=self.kv_cache[block_idx]["k"][:kv_end],
                 v=self.kv_cache[block_idx]["v"][:kv_end],
@@ -164,8 +162,15 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
-        attn_out = attention(
-            attention_type=self.attention_type,
-            q=q, k=k, v=v,
-            cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k,
-            max_seqlen_q=lq, max_seqlen_kv=lk,
-            model_cls=self.config["model_cls"])
+        attn_out = weights.cross_attn_1.apply(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_k,
+            max_seqlen_q=lq,
+            max_seqlen_kv=lk,
+            model_cls=self.config["model_cls"],
+        )

         # TODO: Implement I2V inference for causvid model
```
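The cross-attention path forwards FlashAttention-style variable-length arguments unchanged through the new module. A self-contained illustration of how cumulative sequence lengths are typically built from per-sample lengths; the helper name and return convention here are assumptions mirroring what `_calculate_q_k_len` presumably produces, not its actual code:

```python
import torch

def calculate_cu_seqlens(lens):
    """Build FlashAttention-style cumulative sequence lengths.

    lens: int32 tensor of per-sample lengths, e.g. torch.tensor([kv_end]).
    Returns [0, l0, l0+l1, ...] as int32, the offsets of each packed sample.
    """
    return torch.cat([torch.zeros(1, dtype=torch.int32),
                      lens.cumsum(0).to(torch.int32)])

k_lens = torch.tensor([77], dtype=torch.int32)   # single sample, 77 kv tokens
cu_seqlens_k = calculate_cu_seqlens(k_lens)
max_seqlen_k = int(k_lens.max())
print(cu_seqlens_k)  # tensor([ 0, 77], dtype=torch.int32)
```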
lightx2v/models/networks/wan/infer/transformer_infer.py

```diff
 import torch
 from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
-from lightx2v.attentions import attention
 from lightx2v.common.offload.manager import WeightStreamManager
 from lightx2v.utils.envs import *
```
lightx2v/models/networks/wan/weights/transformer_weights.py

```diff
@@ -25,6 +25,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config["mm_config"].get("quant_method", None)
+        self.sparge = config.get("sparge", False)

         self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
         self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
@@ -44,8 +45,8 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("ffn_0", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias"))
         self.add_module("ffn_2", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias"))

-        # attention weights
-        if self.config["sparge"]:
+        # attention weights section
+        if self.sparge:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER["Sparge"](f"blocks.{self.block_index}"))
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
@@ -61,7 +62,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())

         # load attn weights
-        if self.config["sparge"]:
+        if self.sparge:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
             sparge_ckpt = torch.load(self.config["sparge_ckpt"])
             self.self_attn_1.load(sparge_ckpt)
```
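For context, a hedged sketch of what the `ATTN_WEIGHT_REGISTER["Sparge"](f"blocks.{self.block_index}")` plus `.load(sparge_ckpt)` pairing implies: the module remembers its block prefix and picks its own entries out of the full checkpoint dict. The class, field, and key names below are assumptions for illustration, not the actual SpargeAttn integration:

```python
import torch

class SpargeAttnWeight:
    """Illustrative only: holds per-block tuned sparse-attention parameters."""

    def __init__(self, prefix):
        self.prefix = prefix  # e.g. "blocks.3", matching the f-string in the diff
        self.params = {}

    def load(self, ckpt):
        # Keep only checkpoint entries belonging to this block's prefix.
        self.params = {k: v for k, v in ckpt.items() if k.startswith(self.prefix)}

ckpt = {"blocks.0.attn.scale": torch.tensor(0.5),
        "blocks.1.attn.scale": torch.tensor(0.7)}
w = SpargeAttnWeight("blocks.0")
w.load(ckpt)
print(sorted(w.params))  # ['blocks.0.attn.scale']
```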
lightx2v/utils/set_config.py

```diff
@@ -20,7 +20,6 @@ def get_default_config():
         "strength_model": 1.0,
         "mm_config": {},
         "use_prompt_enhancer": False,
-        "sparge": False,
     }
     return default_config
```
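The `"sparge"` default can be dropped from `get_default_config()` because the weight modules above now read the flag defensively. A one-line illustration of the replacement pattern:

```python
config = {}                            # "sparge" no longer needs to exist in defaults
sparge = config.get("sparge", False)   # call sites fall back to False themselves
assert sparge is False
```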