xuwx1 / LightX2V · Commits

Commit 63f233ad (unverified), authored Oct 09, 2025 by gushiqiao; committed by GitHub on Oct 09, 2025.
Parent: 69c2f650

[Fix] Fix vace and animate models config bug (#351)

Showing 9 changed files with 106 additions and 126 deletions:
- lightx2v/common/ops/attn/flash_attn.py (+16, −19)
- lightx2v/common/ops/attn/sage_attn.py (+11, −34)
- lightx2v/models/networks/wan/infer/animate/transformer_infer.py (+2, −2)
- lightx2v/models/networks/wan/infer/vace/transformer_infer.py (+2, −2)
- lightx2v/models/networks/wan/weights/vace/transformer_weights.py (+1, −1)
- lightx2v/models/runners/default_runner.py (+1, −1)
- lightx2v/models/runners/wan/wan_animate_runner.py (+43, −46)
- lightx2v/models/runners/wan/wan_vace_runner.py (+16, −16)
- tools/convert/converter.py (+14, −5)
lightx2v/common/ops/attn/flash_attn.py

```diff
@@ -34,6 +34,10 @@ class FlashAttn2Weight(AttnWeightTemplate):
         max_seqlen_kv=None,
         model_cls=None,
     ):
+        if len(q.shape) == 3:
+            bs = 1
+        elif len(q.shape) == 4:
+            bs = q.shape[0]
         x = flash_attn_varlen_func(
             q,
             k,
@@ -42,7 +46,7 @@ class FlashAttn2Weight(AttnWeightTemplate):
             cu_seqlens_kv,
             max_seqlen_q,
             max_seqlen_kv,
-        ).reshape(max_seqlen_q, -1)
+        ).reshape(bs * max_seqlen_q, -1)
         return x
@@ -63,23 +67,16 @@ class FlashAttn3Weight(AttnWeightTemplate):
         model_cls=None,
     ):
-        if len(q.shape) == 3:
-            x = flash_attn_varlen_func_v3(
-                q,
-                k,
-                v,
-                cu_seqlens_q,
-                cu_seqlens_kv,
-                max_seqlen_q,
-                max_seqlen_kv,
-            ).reshape(max_seqlen_q, -1)
-        elif len(q.shape) == 4:
-            x = flash_attn_varlen_func_v3(
-                q,
-                k,
-                v,
-                cu_seqlens_q,
-                cu_seqlens_kv,
-                max_seqlen_q,
-                max_seqlen_kv,
-            ).reshape(q.shape[0] * max_seqlen_q, -1)
+        if len(q.shape) == 3:
+            bs = 1
+        elif len(q.shape) == 4:
+            bs = q.shape[0]
+        x = flash_attn_varlen_func_v3(
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_kv,
+            max_seqlen_q,
+            max_seqlen_kv,
+        ).reshape(bs * max_seqlen_q, -1)
         return x
```
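Both attention paths now compute a batch size up front so that the packed 3D layout and the batched 4D layout flatten the same way. A minimal shape sketch of why the `bs * max_seqlen_q` reshape matters, in plain PyTorch rather than the real flash-attn kernel call:

```python
import torch

# Stand-in for the varlen kernel's output handling; the real code calls
# flash_attn_varlen_func / flash_attn_varlen_func_v3 instead of taking `out`.
def flatten_attn_output(q: torch.Tensor, out: torch.Tensor, max_seqlen_q: int) -> torch.Tensor:
    if len(q.shape) == 3:    # packed layout: (tokens, heads, head_dim), implicit batch of 1
        bs = 1
    elif len(q.shape) == 4:  # batched layout: (bs, tokens, heads, head_dim)
        bs = q.shape[0]
    # The previous FA2 path reshaped to (max_seqlen_q, -1) unconditionally, which
    # silently folded the batch dimension into the feature dimension when bs > 1.
    return out.reshape(bs * max_seqlen_q, -1)

out3 = torch.randn(128, 8, 64)       # packed
out4 = torch.randn(2, 128, 8, 64)    # batched
assert flatten_attn_output(out3, out3, 128).shape == (128, 8 * 64)
assert flatten_attn_output(out4, out4, 128).shape == (2 * 128, 8 * 64)
```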
lightx2v/common/ops/attn/sage_attn.py

```diff
@@ -36,38 +36,15 @@ class SageAttn2Weight(AttnWeightTemplate):
         model_cls=None,
     ):
         q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
-        if model_cls == "hunyuan":
-            x1 = sageattn(
-                q[: cu_seqlens_q[1]].unsqueeze(0),
-                k[: cu_seqlens_kv[1]].unsqueeze(0),
-                v[: cu_seqlens_kv[1]].unsqueeze(0),
-                tensor_layout="NHD",
-            )
-            x2 = sageattn(
-                q[cu_seqlens_q[1] :].unsqueeze(0),
-                k[cu_seqlens_kv[1] :].unsqueeze(0),
-                v[cu_seqlens_kv[1] :].unsqueeze(0),
-                tensor_layout="NHD",
-            )
-            x = torch.cat((x1, x2), dim=1)
-            x = x.view(max_seqlen_q, -1)
-        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe", "wan2.2_animate", "wan2.2_moe_distill", "qwen_image"]:
-            if len(q.shape) == 3:
-                x = sageattn(
-                    q.unsqueeze(0),
-                    k.unsqueeze(0),
-                    v.unsqueeze(0),
-                    tensor_layout="NHD",
-                )
-                x = x.view(max_seqlen_q, -1)
-            elif len(q.shape) == 4:
-                x = sageattn(
-                    q,
-                    k,
-                    v,
-                    tensor_layout="NHD",
-                )
-                x = x.view(q.shape[0] * max_seqlen_q, -1)
-        else:
-            raise NotImplementedError(f"Model class '{model_cls}' is not implemented in this attention implementation")
+        if len(q.shape) == 3:
+            bs = 1
+            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
+        elif len(q.shape) == 4:
+            bs = q.shape[0]
+        x = sageattn(
+            q,
+            k,
+            v,
+            tensor_layout="NHD",
+        ).view(bs * max_seqlen_q, -1)
         return x
```
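The rewrite drops the per-model branching (the `model_cls` whitelist and the special hunyuan split) in favor of a single shape-normalization path. A sketch of the pattern, using `torch.nn.functional.scaled_dot_product_attention` as a stand-in for `sageattn` (sageattn's `tensor_layout="NHD"` means batch, seq, heads, dim):

```python
import torch
import torch.nn.functional as F

def unified_attn(q, k, v, max_seqlen_q):
    # Normalize packed 3D input to a batched 4D "NHD" layout, as the new code does.
    if len(q.shape) == 3:
        bs = 1
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    elif len(q.shape) == 4:
        bs = q.shape[0]
    # SDPA expects (batch, heads, seq, dim), so transpose in and out; reshape
    # rather than view here because the transposes make the output non-contiguous.
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2).reshape(bs * max_seqlen_q, -1)

q = k = v = torch.randn(64, 8, 32)
assert unified_attn(q, k, v, 64).shape == (64, 8 * 32)
```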
lightx2v/models/networks/wan/infer/animate/transformer_infer.py

```diff
@@ -20,8 +20,8 @@ class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
         kv = phase.linear1_kv.apply(x_motion.view(-1, x_motion.shape[-1]))
         kv = kv.view(T, -1, kv.shape[-1])
         q = phase.linear1_q.apply(x_feat)
-        k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config.num_heads)
-        q = rearrange(q, "S (H D) -> S H D", H=self.config.num_heads)
+        k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config["num_heads"])
+        q = rearrange(q, "S (H D) -> S H D", H=self.config["num_heads"])
         q = phase.q_norm.apply(q).view(T, q.shape[0] // T, q.shape[1], q.shape[2])
         k = phase.k_norm.apply(k)
```
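This hunk is the heart of the config bug: the animate and vace paths were reading config values as attributes, which fails once `self.config` behaves like a plain mapping. A minimal sketch of the failure mode, assuming a dict-backed config with hypothetical values:

```python
# Hypothetical config values for illustration only.
config = {"num_heads": 40, "vace_layers": [0, 5, 10, 15]}

try:
    config.num_heads  # attribute access works on EasyDict-style wrappers, not plain dicts
except AttributeError as exc:
    print(exc)        # 'dict' object has no attribute 'num_heads'

print(config["num_heads"])  # 40 -- the subscript form this commit standardizes on

# The same pattern as the vace transformer fixes below:
vace_blocks_num = len(config["vace_layers"])
vace_blocks_mapping = {orig: seq for seq, orig in enumerate(config["vace_layers"])}
```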
lightx2v/models/networks/wan/infer/vace/transformer_infer.py

```diff
@@ -5,8 +5,8 @@ from lightx2v.utils.envs import *
 class WanVaceTransformerInfer(WanOffloadTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
-        self.vace_blocks_num = len(self.config.vace_layers)
-        self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}
+        self.vace_blocks_num = len(self.config["vace_layers"])
+        self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config["vace_layers"])}

     def infer(self, weights, pre_infer_out):
         pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
```
lightx2v/models/networks/wan/weights/vace/transformer_weights.py

```diff
@@ -14,7 +14,7 @@ class WanVaceTransformerWeights(WanTransformerWeights):
         super().__init__(config)
         self.patch_size = (1, 2, 2)
         self.vace_blocks = WeightModuleList(
-            [WanVaceTransformerAttentionBlock(self.config.vace_layers[i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config.vace_layers))]
+            [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
         )
         self.add_module("vace_blocks", self.vace_blocks)
```
lightx2v/models/runners/default_runner.py

```diff
@@ -214,7 +214,7 @@ class DefaultRunner(BaseRunner):
             [src_video],
             [src_mask],
             [None if src_ref_images is None else src_ref_images.split(",")],
-            (self.config.target_width, self.config.target_height),
+            (self.config["target_width"], self.config["target_height"]),
         )
         self.src_ref_images = src_ref_images
```
lightx2v/models/runners/wan/wan_animate_runner.py

```diff
@@ -11,7 +11,7 @@ try:
     from decord import VideoReader
 except ImportError:
     VideoReader = None
-    logger.info("If you need run animate model, please install decord.")
+    logger.info("If you want to run animate model, please install decord.")

 from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder
@@ -28,7 +28,7 @@ from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
 class WanAnimateRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        assert self.config.task == "animate"
+        assert self.config["task"] == "animate"

     def inputs_padding(self, array, target_len):
         idx = 0
@@ -161,11 +161,11 @@ class WanAnimateRunner(WanRunner):
         pose_latents = self.vae_encoder.encode(conditioning_pixel_values.unsqueeze(0))  # c t h w
         ref_latents = self.vae_encoder.encode(self.refer_pixel_values.unsqueeze(1).unsqueeze(0))  # c t h w
-        mask_ref = self.get_i2v_mask(1, self.config.lat_h, self.config.lat_w, 1)
+        mask_ref = self.get_i2v_mask(1, self.latent_h, self.latent_w, 1)
         y_ref = torch.concat([mask_ref, ref_latents])
         if self.mask_reft_len > 0:
-            if self.config.replace_flag:
+            if self.config["replace_flag"]:
                 y_reft = self.vae_encoder.encode(
                     torch.concat(
                         [
@@ -183,9 +183,9 @@ class WanAnimateRunner(WanRunner):
                 mask_pixel_values = mask_pixel_values[:, 0, :, :]
                 msk_reft = self.get_i2v_mask(
-                    self.config.lat_t,
-                    self.config.lat_h,
-                    self.config.lat_w,
+                    self.latent_t,
+                    self.latent_h,
+                    self.latent_w,
                     self.mask_reft_len,
                     mask_pixel_values=mask_pixel_values.unsqueeze(0),
                 )
@@ -198,31 +198,31 @@ class WanAnimateRunner(WanRunner):
                             size=(H, W),
                             mode="bicubic",
                         ),
-                        torch.zeros(3, self.config.target_video_length - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
+                        torch.zeros(3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE()),
                     ],
                     dim=1,
                 )
                 .cuda()
                 .unsqueeze(0)
             )
-            msk_reft = self.get_i2v_mask(self.config.lat_t, self.config.lat_h, self.config.lat_w, self.mask_reft_len)
+            msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
         else:
-            if self.config.replace_flag:
+            if self.config["replace_flag"]:
                 mask_pixel_values = 1 - mask_pixel_values
                 mask_pixel_values = mask_pixel_values.permute(1, 0, 2, 3)
                 mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode="nearest")
                 mask_pixel_values = mask_pixel_values[:, 0, :, :]
                 y_reft = self.vae_encoder.encode(bg_pixel_values.unsqueeze(0))
                 msk_reft = self.get_i2v_mask(
-                    self.config.lat_t,
-                    self.config.lat_h,
-                    self.config.lat_w,
+                    self.latent_t,
+                    self.latent_h,
+                    self.latent_w,
                     self.mask_reft_len,
                     mask_pixel_values=mask_pixel_values.unsqueeze(0),
                 )
             else:
-                y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config.target_video_length - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda"))
-                msk_reft = self.get_i2v_mask(self.config.lat_t, self.config.lat_h, self.config.lat_w, self.mask_reft_len)
+                y_reft = self.vae_encoder.encode(torch.zeros(1, 3, self.config["target_video_length"] - self.mask_reft_len, H, W, dtype=GET_DTYPE(), device="cuda"))
+                msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
             y_reft = torch.concat([msk_reft, y_reft])
         y = torch.concat([y_ref, y_reft], dim=1)
@@ -230,35 +230,39 @@ class WanAnimateRunner(WanRunner):
         return y, pose_latents

     def prepare_input(self):
-        src_pose_path = self.config.get("src_pose_path", None)
-        src_face_path = self.config.get("src_face_path", None)
-        src_ref_path = self.config.get("src_ref_images", None)
+        src_pose_path = self.config["src_pose_path"] if "src_pose_path" in self.config else None
+        src_face_path = self.config["src_face_path"] if "src_face_path" in self.config else None
+        src_ref_path = self.config["src_ref_images"] if "src_ref_images" in self.config else None
         self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path)
         self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device="cuda").permute(2, 0, 1)  # chw
+        self.latent_t = self.config["target_video_length"] // self.config["vae_stride"][0] + 1
+        self.latent_h = self.refer_pixel_values.shape[-2] // self.config["vae_stride"][1]
+        self.latent_w = self.refer_pixel_values.shape[-1] // self.config["vae_stride"][2]
+        self.input_info.latent_shape = [self.config.get("num_channels_latents", 16), self.latent_t + 1, self.latent_h, self.latent_w]
         self.real_frame_len = len(self.cond_images)
         target_len = self.get_valid_len(
             self.real_frame_len,
-            self.config.target_video_length,
-            overlap=self.config.get("refert_num", 1),
+            self.config["target_video_length"],
+            overlap=self.config["refert_num"] if "refert_num" in self.config else 1,
         )
         logger.info("real frames: {} target frames: {}".format(self.real_frame_len, target_len))
         self.cond_images = self.inputs_padding(self.cond_images, target_len)
         self.face_images = self.inputs_padding(self.face_images, target_len)
-        if self.config.get("replace_flag", False):
-            src_bg_path = self.config.get("src_bg_path")
-            src_mask_path = self.config.get("src_mask_path")
+        if self.config["replace_flag"] if "replace_flag" in self.config else False:
+            src_bg_path = self.config["src_bg_path"]
+            src_mask_path = self.config["src_mask_path"]
             self.bg_images, self.mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
             self.bg_images = self.inputs_padding(self.bg_images, target_len)
             self.mask_images = self.inputs_padding(self.mask_images, target_len)

     def get_video_segment_num(self):
         total_frames = len(self.cond_images)
-        self.move_frames = self.config.target_video_length - self.config.refert_num
-        if total_frames <= self.config.target_video_length:
+        self.move_frames = self.config["target_video_length"] - self.config["refert_num"]
+        if total_frames <= self.config["target_video_length"]:
             self.video_segment_num = 1
         else:
-            self.video_segment_num = 1 + (total_frames - self.config.target_video_length + self.move_frames - 1) // self.move_frames
+            self.video_segment_num = 1 + (total_frames - self.config["target_video_length"] + self.move_frames - 1) // self.move_frames

     def init_run(self):
         self.all_out_frames = []
@@ -267,10 +271,10 @@ class WanAnimateRunner(WanRunner):
     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             self.vae_decoder = self.load_vae_decoder()
         images = self.vae_decoder.decode(latents[:, 1:].to(GET_DTYPE()))
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             del self.vae_decoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -278,11 +282,11 @@ class WanAnimateRunner(WanRunner):
     def init_run_segment(self, segment_idx):
         start = segment_idx * self.move_frames
-        end = start + self.config.target_video_length
+        end = start + self.config["target_video_length"]
         if start == 0:
             self.mask_reft_len = 0
         else:
-            self.mask_reft_len = self.config.refert_num
+            self.mask_reft_len = self.config["refert_num"]
         conditioning_pixel_values = torch.tensor(
             np.stack(self.cond_images[start:end]) / 127.5 - 1,
@@ -300,17 +304,17 @@ class WanAnimateRunner(WanRunner):
             height, width = self.refer_images.shape[:2]
             refer_t_pixel_values = torch.zeros(
                 3,
-                self.config.refert_num,
+                self.config["refert_num"],
                 height,
                 width,
                 device="cuda",
                 dtype=GET_DTYPE(),
             )  # c t h w
         else:
-            refer_t_pixel_values = self.gen_video[0, :, -self.config.refert_num :].transpose(0, 1).clone().detach()  # c t h w
+            refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach()  # c t h w
         bg_pixel_values, mask_pixel_values = None, None
-        if self.config.replace_flag:
+        if self.config["replace_flag"] if "replace_flag" in self.config else False:
             bg_pixel_values = torch.tensor(
                 np.stack(self.bg_images[start:end]) / 127.5 - 1,
                 device="cuda",
@@ -341,24 +345,17 @@ class WanAnimateRunner(WanRunner):
             self.gen_video = self.gen_video[:, :, self.config["refert_num"] :]
         self.all_out_frames.append(self.gen_video.cpu())

-    def process_images_after_vae_decoder(self, save_video=True):
-        self.gen_video = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len]
+    def process_images_after_vae_decoder(self):
+        self.gen_video_final = torch.cat(self.all_out_frames, dim=2)[:, :, : self.real_frame_len]
         del self.all_out_frames
         gc.collect()
-        super().process_images_after_vae_decoder(save_video)
+        super().process_images_after_vae_decoder()

-    def set_target_shape(self):
-        self.config.target_video_length = self.config.target_video_length
-        self.config.lat_h = self.refer_pixel_values.shape[-2] // 8
-        self.config.lat_w = self.refer_pixel_values.shape[-1] // 8
-        self.config.lat_t = self.config.target_video_length // 4 + 1
-        self.config.target_shape = [16, self.config.lat_t + 1, self.config.lat_h, self.config.lat_w]
-
     def run_image_encoder(self, img):  # CHW
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             self.image_encoder = self.load_image_encoder()
         clip_encoder_out = self.image_encoder.visual([img.unsqueeze(0)]).squeeze(0).to(GET_DTYPE())
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             del self.image_encoder
             torch.cuda.empty_cache()
             gc.collect()
@@ -366,7 +363,7 @@ class WanAnimateRunner(WanRunner):
     def load_transformer(self):
         model = WanAnimateModel(
-            self.config.model_path,
+            self.config["model_path"],
             self.config,
             self.init_device,
         )
```
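Beyond the config-access changes, this file moves the latent sizes off the config (`self.config.lat_*` becomes `self.latent_*`, computed in `prepare_input`) and drops `set_target_shape` entirely. The sliding-window arithmetic in `get_video_segment_num` is unchanged; a worked example with illustrative numbers:

```python
# Illustrative values, not from the commit: each segment after the first
# advances by move_frames, and (x + move_frames - 1) // move_frames is a
# ceiling division on the leftover frames.
target_video_length = 77
refert_num = 5
move_frames = target_video_length - refert_num  # 72 new frames per extra segment

def video_segment_num(total_frames: int) -> int:
    if total_frames <= target_video_length:
        return 1
    return 1 + (total_frames - target_video_length + move_frames - 1) // move_frames

assert video_segment_num(77) == 1    # fits in a single window
assert video_segment_num(78) == 2    # one leftover frame still costs a segment
assert video_segment_num(149) == 2   # 77 + 72 fills exactly two windows
assert video_segment_num(150) == 3
```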
lightx2v/models/runners/wan/wan_vace_runner.py

```diff
@@ -17,13 +17,13 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
 class WanVaceRunner(WanRunner):
     def __init__(self, config):
         super().__init__(config)
-        assert self.config.task == "vace"
+        assert self.config["task"] == "vace"
         self.vid_proc = VaceVideoProcessor(
-            downsample=tuple([x * y for x, y in zip(self.config.vae_stride, self.config.patch_size)]),
+            downsample=tuple([x * y for x, y in zip(self.config["vae_stride"], self.config["patch_size"])]),
             min_area=720 * 1280,
             max_area=720 * 1280,
-            min_fps=self.config.get("fps", 16),
-            max_fps=self.config.get("fps", 16),
+            min_fps=self.config["fps"] if "fps" in self.config else 16,
+            max_fps=self.config["fps"] if "fps" in self.config else 16,
             zero_start=True,
             seq_len=75600,
             keep_last=True,
@@ -31,7 +31,7 @@ class WanVaceRunner(WanRunner):
     def load_transformer(self):
         model = WanVaceModel(
-            self.config.model_path,
+            self.config["model_path"],
             self.config,
             self.init_device,
         )
@@ -57,7 +57,7 @@ class WanVaceRunner(WanRunner):
                 src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
                 image_sizes.append(src_video[i].shape[2:])
             elif sub_src_video is None:
-                src_video[i] = torch.zeros((3, self.config.target_video_length, image_size[0], image_size[1]), device=device)
+                src_video[i] = torch.zeros((3, self.config["target_video_length"], image_size[0], image_size[1]), device=device)
                 src_mask[i] = torch.ones_like(src_video[i], device=device)
                 image_sizes.append(image_size)
             else:
@@ -89,7 +89,7 @@ class WanVaceRunner(WanRunner):
         return src_video, src_mask, src_ref_images

     def run_vae_encoder(self, frames, ref_images, masks):
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             self.vae_encoder = self.load_vae_encoder()
         if ref_images is None:
             ref_images = [None] * len(frames)
@@ -118,11 +118,11 @@ class WanVaceRunner(WanRunner):
             latent = torch.cat([*ref_latent, latent], dim=1)
             cat_latents.append(latent)
         self.latent_shape = list(cat_latents[0].shape)
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()
-        return self.get_vae_encoder_output(cat_latents, masks, ref_images)
+        return self.get_vae_encoder_output(cat_latents, masks, ref_images), self.set_input_info_latent_shape()

     def get_vae_encoder_output(self, cat_latents, masks, ref_images):
         if ref_images is None:
@@ -133,15 +133,15 @@ class WanVaceRunner(WanRunner):
         result_masks = []
         for mask, refs in zip(masks, ref_images):
             c, depth, height, width = mask.shape
-            new_depth = int((depth + 3) // self.config.vae_stride[0])
-            height = 2 * (int(height) // (self.config.vae_stride[1] * 2))
-            width = 2 * (int(width) // (self.config.vae_stride[2] * 2))
+            new_depth = int((depth + 3) // self.config["vae_stride"][0])
+            height = 2 * (int(height) // (self.config["vae_stride"][1] * 2))
+            width = 2 * (int(width) // (self.config["vae_stride"][2] * 2))
             # reshape
             mask = mask[0, :, :, :]
-            mask = mask.view(depth, height, self.config.vae_stride[1], width, self.config.vae_stride[1])  # depth, height, 8, width, 8
+            mask = mask.view(depth, height, self.config["vae_stride"][1], width, self.config["vae_stride"][1])  # depth, height, 8, width, 8
             mask = mask.permute(2, 4, 0, 1, 3)  # 8, 8, depth, height, width
-            mask = mask.reshape(self.config.vae_stride[1] * self.config.vae_stride[2], depth, height, width)  # 8*8, depth, height, width
+            mask = mask.reshape(self.config["vae_stride"][1] * self.config["vae_stride"][2], depth, height, width)  # 8*8, depth, height, width
             # interpolation
             mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)
@@ -161,7 +161,7 @@ class WanVaceRunner(WanRunner):
     @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             self.vae_decoder = self.load_vae_decoder()
         if self.src_ref_images is not None:
@@ -172,7 +172,7 @@ class WanVaceRunner(WanRunner):
         images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
-        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
+        if (self.config["lazy_load"] if "lazy_load" in self.config else False) or (self.config["unload_modules"] if "unload_modules" in self.config else False):
             del self.vae_decoder
             torch.cuda.empty_cache()
             gc.collect()
```
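The mask pipeline in `get_vae_encoder_output` only changed its config lookups; the space-to-channel trick itself is untouched. A self-contained sketch of what those view/permute/reshape lines do, with illustrative sizes (the real strides come from `self.config["vae_stride"]`):

```python
import torch
import torch.nn.functional as F

vae_stride = (4, 8, 8)  # assumed (t, h, w) strides, as in typical Wan VAE configs
depth, height, width = 21, 480, 832
mask = torch.rand(1, depth, height, width)

new_depth = int((depth + 3) // vae_stride[0])   # 6 latent frames
height = 2 * (height // (vae_stride[1] * 2))    # 60 latent rows
width = 2 * (width // (vae_stride[2] * 2))      # 104 latent cols

m = mask[0]                                                         # (depth, H, W)
m = m.view(depth, height, vae_stride[1], width, vae_stride[1])      # split each 8x8 patch
m = m.permute(2, 4, 0, 1, 3)                                        # (8, 8, depth, h, w)
m = m.reshape(vae_stride[1] * vae_stride[2], depth, height, width)  # patch -> channel axis
m = F.interpolate(m.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)
assert m.shape == (64, new_depth, height, width)
```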
tools/convert/converter.py

```diff
@@ -345,6 +345,7 @@ def quantize_model(
     weights,
     w_bit=8,
     target_keys=["attn", "ffn"],
+    adapter_keys=None,
     key_idx=2,
     ignore_key=None,
     linear_dtype=torch.int8,
@@ -375,18 +376,21 @@ def quantize_model(
         tensor = weights[key]

-        # Skip non-tensors, small tensors, and non-2D tensors
+        # Skip non-tensors and non-2D tensors
         if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2:
             if tensor.dtype != non_linear_dtype:
                 weights[key] = tensor.to(non_linear_dtype)
             continue

         # Check if key matches target modules
         parts = key.split(".")
         if len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
-            if tensor.dtype != non_linear_dtype:
-                weights[key] = tensor.to(non_linear_dtype)
-            continue
+            if adapter_keys is not None and not any(adapter_key in parts for adapter_key in adapter_keys):
+                if tensor.dtype != non_linear_dtype:
+                    weights[key] = tensor.to(non_linear_dtype)
+                continue

         try:
             # Quantize tensor and store results
@@ -511,6 +515,7 @@ def convert_weights(args):
         converted_weights,
         w_bit=args.bits,
         target_keys=args.target_keys,
+        adapter_keys=args.adapter_keys,
         key_idx=args.key_idx,
         ignore_key=args.ignore_key,
         linear_dtype=args.linear_dtype,
@@ -535,6 +540,8 @@ def convert_weights(args):
             match = block_pattern.search(key)
             if match:
                 block_idx = match.group(1)
+                if args.model_type == "wan_animate_dit" and "face_adapter" in key:
+                    block_idx = str(int(block_idx) * 5)
                 block_groups[block_idx][key] = tensor
             else:
                 non_block_weights[key] = tensor
@@ -635,7 +642,7 @@ def main():
     parser.add_argument(
         "-t",
         "--model_type",
-        choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip"],
+        choices=["wan_dit", "hunyuan_dit", "wan_t5", "wan_clip", "wan_animate_dit"],
         default="wan_dit",
         help="Model type",
     )
@@ -684,6 +691,7 @@ def main():
             "target_keys": ["self_attn", "cross_attn", "ffn"],
             "ignore_key": ["ca", "audio"],
         },
+        "wan_animate_dit": {"key_idx": 2, "target_keys": ["self_attn", "cross_attn", "ffn"], "adapter_keys": ["linear1_kv", "linear1_q", "linear2"], "ignore_key": None},
         "hunyuan_dit": {
             "key_idx": 2,
             "target_keys": [
@@ -710,6 +718,7 @@ def main():
     }
     args.target_keys = model_type_keys_map[args.model_type]["target_keys"]
+    args.adapter_keys = model_type_keys_map[args.model_type]["adapter_keys"] if "adapter_keys" in model_type_keys_map[args.model_type] else None
     args.key_idx = model_type_keys_map[args.model_type]["key_idx"]
     args.ignore_key = model_type_keys_map[args.model_type]["ignore_key"]
```
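Two behaviors ride on the new `adapter_keys` plumbing: adapter projections (`linear1_kv`, `linear1_q`, `linear2`) are no longer cast-and-skipped when they miss `target_keys`, and `face_adapter` blocks are renumbered so that adapter block i groups with main block 5*i (presumably one fuser per five DiT blocks). A small sketch of both rules, with hypothetical weight keys:

```python
target_keys = ["self_attn", "cross_attn", "ffn"]
adapter_keys = ["linear1_kv", "linear1_q", "linear2"]
key_idx = 2

def skipped(key: str) -> bool:
    """True when quantize_model would cast the tensor and skip quantization."""
    parts = key.split(".")
    if len(parts) < key_idx + 1 or parts[key_idx] not in target_keys:
        # With adapter_keys set, adapter tensors fall through to quantization.
        return adapter_keys is not None and not any(a in parts for a in adapter_keys)
    return False

assert not skipped("blocks.0.self_attn.q.weight")                    # normal target key
assert not skipped("face_adapter.fuser_blocks.3.linear1_kv.weight")  # adapter exemption
assert skipped("blocks.0.norm1.weight")                              # still cast-and-skip

# Block grouping: a hypothetical face_adapter block index 3 lands on main block 15.
block_idx = "3"
block_idx = str(int(block_idx) * 5)
assert block_idx == "15"
```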