xuwx1 / LightX2V · Commits

Unverified commit 682037cd
[Feat] Add wan2.2 animate model (#339)
Authored by gushiqiao on Sep 27, 2025; committed by GitHub on Sep 27, 2025
Parent: e251e4dc

Changes: 28 files in the commit. This page shows 20 changed files with 759 additions and 27 deletions (+759, -27).
Changed files on this page:

configs/wan22/wan_animate.json  (+22, -0)
configs/wan22/wan_animate_replace.json  (+24, -0)
lightx2v/common/modules/weight_module.py  (+3, -0)
lightx2v/common/ops/attn/flash_attn.py  (+20, -9)
lightx2v/common/ops/attn/sage_attn.py  (+17, -8)
lightx2v/infer.py  (+3, -1)
lightx2v/models/input_encoders/hf/animate/__init__.py  (+0, -0)
lightx2v/models/input_encoders/hf/animate/face_encoder.py  (+171, -0)
lightx2v/models/input_encoders/hf/animate/motion_encoder.py  (+300, -0)
lightx2v/models/networks/wan/animate_model.py  (+22, -0)
lightx2v/models/networks/wan/infer/animate/pre_infer.py  (+31, -0)
lightx2v/models/networks/wan/infer/animate/transformer_infer.py  (+38, -0)
lightx2v/models/networks/wan/infer/module_io.py  (+1, -0)
lightx2v/models/networks/wan/infer/pre_infer.py  (+9, -2)
lightx2v/models/networks/wan/infer/transformer_infer.py  (+3, -3)
lightx2v/models/networks/wan/model.py  (+12, -1)
lightx2v/models/networks/wan/weights/animate/transformer_weights.py  (+75, -0)
lightx2v/models/networks/wan/weights/pre_weights.py  (+6, -1)
lightx2v/models/networks/wan/weights/transformer_weights.py  (+1, -1)
lightx2v/models/runners/base_runner.py  (+1, -1)
configs/wan22/wan_animate.json (new file, mode 100755)

{
    "infer_steps": 20,
    "target_video_length": 77,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "adapter_attn_type": "flash_attn3",
    "seed": 42,
    "sample_shift": 5.0,
    "sample_guide_scale": 5.0,
    "enable_cfg": false,
    "cpu_offload": false,
    "src_pose_path": "/path/to/animate/process_results/src_pose.mp4",
    "src_face_path": "/path/to/animate/process_results/src_face.mp4",
    "src_ref_images": "/path/to/animate/process_results/src_ref.png",
    "refert_num": 1,
    "replace_flag": false,
    "fps": 30
}
configs/wan22/wan_animate_replace.json (new file, mode 100755)

{
    "infer_steps": 20,
    "target_video_length": 77,
    "text_len": 512,
    "target_height": 720,
    "target_width": 1280,
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "adapter_attn_type": "flash_attn3",
    "seed": 42,
    "sample_shift": 5.0,
    "sample_guide_scale": 5.0,
    "enable_cfg": false,
    "cpu_offload": false,
    "src_pose_path": "/path/to/replace/process_results/src_pose.mp4",
    "src_face_path": "/path/to/replace/process_results/src_face.mp4",
    "src_ref_images": "/path/to/replace/process_results/src_ref.png",
    "src_bg_path": "/path/to/replace/process_results/src_bg.mp4",
    "src_mask_path": "/path/to/replace/process_results/src_mask.mp4",
    "refert_num": 1,
    "fps": 30,
    "replace_flag": true
}
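The replace-mode config is the animate config plus background/mask inputs and replace_flag set to true. A quick sketch (not part of the commit) that confirms the delta between the two JSON files when run from the repository root:

import json

with open("configs/wan22/wan_animate.json") as f:
    animate_cfg = json.load(f)
with open("configs/wan22/wan_animate_replace.json") as f:
    replace_cfg = json.load(f)

print(sorted(set(replace_cfg) - set(animate_cfg)))   # ['src_bg_path', 'src_mask_path']
print({k: (animate_cfg[k], replace_cfg[k]) for k in animate_cfg if animate_cfg[k] != replace_cfg[k]})
# the src_* paths switch to .../replace/... and replace_flag flips to true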
lightx2v/common/modules/weight_module.py

@@ -3,6 +3,9 @@ class WeightModule:
         self._modules = {}
         self._parameters = {}
 
+    def is_empty(self):
+        return len(self._modules) == 0 and len(self._parameters) == 0
+
     def add_module(self, name, module):
         self._modules[name] = module
         setattr(self, name, module)
lightx2v/common/ops/attn/flash_attn.py (mode 100644 → 100755)

@@ -62,13 +62,24 @@ class FlashAttn3Weight(AttnWeightTemplate):
         max_seqlen_kv=None,
         model_cls=None,
     ):
-        x = flash_attn_varlen_func_v3(
-            q,
-            k,
-            v,
-            cu_seqlens_q,
-            cu_seqlens_kv,
-            max_seqlen_q,
-            max_seqlen_kv,
-        ).reshape(max_seqlen_q, -1)
+        if len(q.shape) == 3:
+            x = flash_attn_varlen_func_v3(
+                q,
+                k,
+                v,
+                cu_seqlens_q,
+                cu_seqlens_kv,
+                max_seqlen_q,
+                max_seqlen_kv,
+            ).reshape(max_seqlen_q, -1)
+        elif len(q.shape) == 4:
+            x = flash_attn_varlen_func_v3(
+                q,
+                k,
+                v,
+                cu_seqlens_q,
+                cu_seqlens_kv,
+                max_seqlen_q,
+                max_seqlen_kv,
+            ).reshape(q.shape[0] * max_seqlen_q, -1)
         return x
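The new branch only changes how the attention output is flattened: 3-D token-packed inputs keep the original (max_seqlen_q, -1) reshape, while 4-D batched inputs also fold the batch dimension in. A plain-tensor shape check (flash_attn itself is not needed for the illustration; the sizes are arbitrary):

import torch

H, D, S, B = 8, 64, 16, 2                      # heads, head_dim, seq_len, batch (arbitrary)

out_3d = torch.randn(S, H, D)                  # token-packed layout: (tokens, heads, head_dim)
print(out_3d.reshape(S, -1).shape)             # torch.Size([16, 512]) -> (max_seqlen_q, H*D)

out_4d = torch.randn(B, S, H, D)               # batched layout: (batch, seq, heads, head_dim)
print(out_4d.reshape(B * S, -1).shape)         # torch.Size([32, 512]) -> (batch*max_seqlen_q, H*D)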
lightx2v/common/ops/attn/sage_attn.py

@@ -51,14 +51,23 @@ class SageAttn2Weight(AttnWeightTemplate):
             )
             x = torch.cat((x1, x2), dim=1)
             x = x.view(max_seqlen_q, -1)
-        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe", "wan2.2_moe_distill", "qwen_image"]:
+        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "seko_talk", "wan2.2", "wan2.1_vace", "wan2.2_moe", "wan2.2_animate", "wan2.2_moe_distill", "qwen_image"]:
-            x = sageattn(
-                q.unsqueeze(0),
-                k.unsqueeze(0),
-                v.unsqueeze(0),
-                tensor_layout="NHD",
-            )
-            x = x.view(max_seqlen_q, -1)
+            if len(q.shape) == 3:
+                x = sageattn(
+                    q.unsqueeze(0),
+                    k.unsqueeze(0),
+                    v.unsqueeze(0),
+                    tensor_layout="NHD",
+                )
+                x = x.view(max_seqlen_q, -1)
+            elif len(q.shape) == 4:
+                x = sageattn(
+                    q,
+                    k,
+                    v,
+                    tensor_layout="NHD",
+                )
+                x = x.view(q.shape[0] * max_seqlen_q, -1)
         else:
             raise NotImplementedError(f"Model class '{model_cls}' is not implemented in this attention implementation")
         return x
lightx2v/infer.py

@@ -8,6 +8,7 @@ from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
 from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401

@@ -50,11 +51,12 @@ def main():
             "wan2.2_audio",
             "wan2.2_moe_distill",
             "qwen_image",
+            "wan2.2_animate",
         ],
         default="wan2.1",
     )
-    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace"], default="t2v")
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--use_prompt_enhancer", action="store_true")
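Putting the new flags together, a hypothetical launch sketch for the animate task. The checkpoint directory is a placeholder and main() may require arguments outside this hunk, so treat this as orientation rather than a verified command:

import subprocess

cmd = [
    "python", "lightx2v/infer.py",
    "--model_cls", "wan2.2_animate",                        # choice added by this commit
    "--task", "animate",                                    # task added by this commit
    "--model_path", "/path/to/Wan2.2-Animate-checkpoint",   # hypothetical checkpoint dir
    "--config_json", "configs/wan22/wan_animate.json",
]
subprocess.run(cmd, check=True)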
lightx2v/models/input_encoders/hf/animate/__init__.py (new empty file, mode 100644)
lightx2v/models/input_encoders/hf/animate/face_encoder.py (new file, mode 100644)

# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

try:
    from flash_attn import flash_attn_func, flash_attn_qkvpacked_func  # noqa: F401
except ImportError:
    flash_attn_func = None

MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q.
        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv.
        max_seqlen_q (int): The maximum sequence length in the batch of q.
        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
    elif mode == "flash":
        x = flash_attn_func(
            q,
            k,
            v,
        )
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out


class CausalConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size - 1, 0)  # T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)


class FaceEncoder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads

        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)

        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        x = rearrange(x, "b t c -> b c t")
        b, c, t = x.shape

        x = self.conv1_local(x)
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)

        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm3(x)
        x = self.act(x)
        x = self.out_proj(x)

        x = rearrange(x, "(b n) t c -> b t n c", b=b)

        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        return x_local
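A shape-check sketch for FaceEncoder. The dimensions below (512-dim per-frame motion features, 5120-dim hidden size, 4 heads, 77 frames) are illustrative assumptions, not values taken from this commit:

import torch
from lightx2v.models.input_encoders.hf.animate.face_encoder import FaceEncoder

enc = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)   # assumed sizes
feats = torch.randn(1, 77, 512)   # (batch, frames, per-frame motion feature)
out = enc(feats)
print(out.shape)
# (1, 20, 5, 5120): the two stride-2 causal convs shrink 77 frames to 20 steps,
# and one learned padding token is appended next to the 4 head tokens.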
lightx2v/models/input_encoders/hf/animate/motion_encoder.py (new file, mode 100644)

# Modified from ``https://github.com/wyhsirius/LIA``
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

import torch
import torch.nn as nn
from torch.nn import functional as F


def custom_qr(input_tensor):
    original_dtype = input_tensor.dtype
    if original_dtype == torch.bfloat16:
        q, r = torch.linalg.qr(input_tensor.to(torch.float32))
        return q.to(original_dtype), r.to(original_dtype)
    return torch.linalg.qr(input_tensor)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
    return F.leaky_relu(input + bias, negative_slope) * scale


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, minor, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, minor, in_h, 1, in_w, 1)
    out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
    out = out.view(-1, minor, in_h * up_y, in_w * up_x)

    out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[
        :,
        :,
        max(-pad_y0, 0) : out.shape[2] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[3] - max(-pad_x1, 0),
    ]

    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )

    return out[:, :, ::down_y, ::down_x]


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor**2)

        self.register_buffer("kernel", kernel)

        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        return F.leaky_relu(input, negative_slope=self.negative_slope)


class EqualConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        return F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}, {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)

        return out

    def __repr__(self):
        return f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"


class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate))

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out


class EncoderApp(nn.Module):
    def __init__(self, size, w_dim=512):
        super(EncoderApp, self).__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16,
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        self.convs = nn.ModuleList()
        self.convs.append(ConvLayer(3, channels[size], 1))

        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            self.convs.append(ResBlock(in_channel, out_channel))
            in_channel = out_channel

        self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))

    def forward(self, x):
        res = []
        h = x
        for conv in self.convs:
            h = conv(h)
            res.append(h)

        return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]


class Encoder(nn.Module):
    def __init__(self, size, dim=512, dim_motion=20):
        super(Encoder, self).__init__()

        # appearance netmork
        self.net_app = EncoderApp(size, dim)

        # motion network
        fc = [EqualLinear(dim, dim)]
        for i in range(3):
            fc.append(EqualLinear(dim, dim))

        fc.append(EqualLinear(dim, dim_motion))
        self.fc = nn.Sequential(*fc)

    def enc_app(self, x):
        h_source = self.net_app(x)
        return h_source

    def enc_motion(self, x):
        h, _ = self.net_app(x)
        h_motion = self.fc(h)
        return h_motion


class Direction(nn.Module):
    def __init__(self, motion_dim):
        super(Direction, self).__init__()
        self.weight = nn.Parameter(torch.randn(512, motion_dim))

    def forward(self, input):
        weight = self.weight + 1e-8
        Q, R = custom_qr(weight)
        if input is None:
            return Q
        else:
            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
            out = torch.matmul(input_diag, Q.T)
            out = torch.sum(out, dim=1)
            return out


class Synthesis(nn.Module):
    def __init__(self, motion_dim):
        super(Synthesis, self).__init__()
        self.direction = Direction(motion_dim)


class Generator(nn.Module):
    def __init__(self, size, style_dim=512, motion_dim=20):
        super().__init__()

        self.enc = Encoder(size, style_dim, motion_dim)
        self.dec = Synthesis(motion_dim)

    def get_motion(self, img):
        # motion_feat = self.enc.enc_motion(img)
        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
        with torch.amp.autocast("cuda", dtype=torch.float32):
            motion = self.dec.direction(motion_feat)
        return motion
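A CPU-friendly smoke test of the LIA-style motion encoder. It calls the two stages that Generator.get_motion chains together (enc.enc_motion, then dec.direction) while skipping the checkpointing and CUDA autocast wrapper; the 256x256 input size and batch of 2 are illustrative assumptions:

import torch
from lightx2v.models.input_encoders.hf.animate.motion_encoder import Generator

gen = Generator(size=256, style_dim=512, motion_dim=20).eval()   # assumed sizes
faces = torch.randn(2, 3, 256, 256)          # cropped face frames
with torch.no_grad():
    alphas = gen.enc.enc_motion(faces)       # (2, 20) motion magnitudes
    motion = gen.dec.direction(alphas)       # (2, 512) vector in the QR-orthogonalized basis
print(alphas.shape, motion.shape)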
lightx2v/models/networks/wan/animate_model.py (new file, mode 100755)

from lightx2v.models.networks.wan.infer.animate.pre_infer import WanAnimatePreInfer
from lightx2v.models.networks.wan.infer.animate.transformer_infer import WanAnimateTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.animate.transformer_weights import WanAnimateTransformerWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights


class WanAnimateModel(WanModel):
    pre_weight_class = WanPreWeights
    transformer_weight_class = WanAnimateTransformerWeights

    def __init__(self, model_path, config, device):
        self.remove_keys = ["face_encoder", "motion_encoder"]
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        super()._init_infer_class()
        self.pre_infer_class = WanAnimatePreInfer
        self.transformer_infer_class = WanAnimateTransformerInfer

    def set_animate_encoders(self, motion_encoder, face_encoder):
        self.pre_infer.set_animate_encoders(motion_encoder, face_encoder)
lightx2v/models/networks/wan/infer/animate/pre_infer.py (new file, mode 100755)

import math

import torch

from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *


class WanAnimatePreInfer(WanPreInfer):
    def __init__(self, config):
        super().__init__(config)
        self.encode_bs = 8

    def set_animate_encoders(self, motion_encoder, face_encoder):
        self.motion_encoder = motion_encoder
        self.face_encoder = face_encoder

    @torch.no_grad()
    def after_patch_embedding(self, weights, x, pose_latents, face_pixel_values):
        pose_latents = weights.pose_patch_embedding.apply(pose_latents)
        x[:, :, 1:].add_(pose_latents)

        face_pixel_values_tmp = []
        for i in range(math.ceil(face_pixel_values.shape[0] / self.encode_bs)):
            face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * self.encode_bs : (i + 1) * self.encode_bs]))

        motion_vec = torch.cat(face_pixel_values_tmp)
        motion_vec = self.face_encoder(motion_vec.unsqueeze(0).to(GET_DTYPE())).squeeze(0)

        pad_face = torch.zeros(1, motion_vec.shape[1], motion_vec.shape[2], dtype=motion_vec.dtype, device="cuda")
        motion_vec = torch.cat([pad_face, motion_vec], dim=0)
        return x, motion_vec
lightx2v/models/networks/wan/infer/animate/transformer_infer.py (new file, mode 100755)

import torch
from einops import rearrange

from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer


class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
    def __init__(self, config):
        super().__init__(config)
        self.has_post_adapter = True
        self.phases_num = 4

    @torch.no_grad()
    def infer_post_adapter(self, phase, x, pre_infer_out):
        if phase.is_empty():
            return x

        T = pre_infer_out.motion_vec.shape[0]
        x_motion = phase.pre_norm_motion.apply(pre_infer_out.motion_vec)
        x_feat = phase.pre_norm_feat.apply(x)

        kv = phase.linear1_kv.apply(x_motion.view(-1, x_motion.shape[-1]))
        kv = kv.view(T, -1, kv.shape[-1])
        q = phase.linear1_q.apply(x_feat)

        k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=self.config.num_heads)
        q = rearrange(q, "S (H D) -> S H D", H=self.config.num_heads)

        q = phase.q_norm.apply(q).view(T, q.shape[0] // T, q.shape[1], q.shape[2])
        k = phase.k_norm.apply(k)

        attn = phase.adapter_attn.apply(
            q=q,
            k=k,
            v=v,
            max_seqlen_q=q.shape[1],
            model_cls=self.config["model_cls"],
        )

        output = phase.linear2.apply(attn)
        x = x.add_(output)
        return x
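Shape sketch for the adapter cross-attention above: per latent frame, the motion vector supplies the packed K/V and the hidden states supply Q, regrouped per frame before the adapter attention call. All sizes below (T latent frames, N motion tokens per frame, 8 heads of dimension 64, 1560 image tokens per frame) are illustrative assumptions:

import torch
from einops import rearrange

T, N, H, D, L_per_frame = 21, 5, 8, 64, 1560   # assumed sizes

kv = torch.randn(T, N, 2 * H * D)              # stand-in for the linear1_kv output
k, v = rearrange(kv, "L N (K H D) -> K L N H D", K=2, H=H)
q = torch.randn(T * L_per_frame, H * D)        # stand-in for the linear1_q output
q = rearrange(q, "S (H D) -> S H D", H=H).view(T, L_per_frame, H, D)
print(k.shape, v.shape, q.shape)
# torch.Size([21, 5, 8, 64]) torch.Size([21, 5, 8, 64]) torch.Size([21, 1560, 8, 64])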
lightx2v/models/networks/wan/infer/module_io.py

@@ -19,4 +19,5 @@ class WanPreInferModuleOutput:
     seq_lens: torch.Tensor
     freqs: torch.Tensor
     context: torch.Tensor
+    motion_vec: torch.Tensor
     adapter_output: Dict[str, Any] = field(default_factory=dict)
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -41,7 +41,7 @@ class WanPreInfer:
         else:
             context = inputs["text_encoder_output"]["context_null"]
 
-        if self.task in ["i2v", "flf2v"]:
+        if self.task in ["i2v", "flf2v", "animate"]:
             if self.config.get("use_image_encoder", True):
                 clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]

@@ -61,6 +61,12 @@ class WanPreInfer:
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
+        if hasattr(self, "after_patch_embedding"):
+            x, motion_vec = self.after_patch_embedding(weights, x, inputs["image_encoder_output"]["pose_latents"], inputs["image_encoder_output"]["face_pixel_values"])
+        else:
+            motion_vec = None
+
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)

@@ -94,7 +100,7 @@ class WanPreInfer:
             del out
             torch.cuda.empty_cache()
 
-        if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
             if self.task == "flf2v":
                 _, n, d = clip_fea.shape
                 clip_fea = clip_fea.view(2 * n, d)

@@ -125,4 +131,5 @@ class WanPreInfer:
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
+            motion_vec=motion_vec,
         )
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -203,7 +203,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         x.add_(y_out * gate_msa.squeeze())
 
         norm3_out = phase.norm3.apply(x)
-        if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]
         else:

@@ -211,7 +211,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.sensitive_layer_dtype != self.infer_dtype:
             context = context.to(self.infer_dtype)
-            if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+            if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
                 context_img = context_img.to(self.infer_dtype)
 
         n, d = self.num_heads, self.head_dim

@@ -234,7 +234,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             model_cls=self.config["model_cls"],
         )
 
-        if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
+        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True) and context_img is not None:
             k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
lightx2v/models/networks/wan/model.py

@@ -137,12 +137,19 @@ class WanModel(CompiledMethodsMixin):
         return False
 
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
+        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
         if self.device.type == "cuda" and dist.is_initialized():
             device = torch.device("cuda:{}".format(dist.get_rank()))
         else:
             device = self.device
         with safe_open(file_path, framework="pt", device=str(device)) as f:
-            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}
+            return {
+                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
+                for key in f.keys()
+                if not any(remove_key in key for remove_key in remove_keys)
+            }
 
     def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")

@@ -158,6 +165,7 @@ class WanModel(CompiledMethodsMixin):
         return weight_dict
 
     def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
+        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
         ckpt_path = self.dit_quantized_ckpt
         index_files = [f for f in os.listdir(ckpt_path) if f.endswith(".index.json")]
         if not index_files:

@@ -175,6 +183,9 @@ class WanModel(CompiledMethodsMixin):
         with safe_open(safetensor_path, framework="pt") as f:
             logger.info(f"Loading weights from {safetensor_path}")
             for k in f.keys():
+                if any(remove_key in k for remove_key in remove_keys):
+                    continue
                 if f.get_tensor(k).dtype in [
                     torch.float16,
                     torch.bfloat16,
lightx2v/models/networks/wan/weights/animate/transformer_weights.py (new file, mode 100755)

import os

from safetensors import safe_open

from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER


class WanAnimateTransformerWeights(WanTransformerWeights):
    def __init__(self, config):
        super().__init__(config)
        self.adapter_blocks_num = self.blocks_num // 5
        for i in range(self.blocks_num):
            if i % 5 == 0:
                self.blocks[i].compute_phases.append(WanAnimateFuserBlock(self.config, i // 5, "face_adapter.fuser_blocks", self.mm_type))
            else:
                self.blocks[i].compute_phases.append(WeightModule())


class WanAnimateFuserBlock(WeightModule):
    def __init__(self, config, block_index, block_prefix, mm_type):
        super().__init__()
        self.config = config
        lazy_load = config.get("lazy_load", False)
        if lazy_load:
            lazy_load_path = os.path.join(config.dit_quantized_ckpt, f"{block_prefix[:-1]}_{block_index}.safetensors")
            lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
        else:
            lazy_load_file = None

        self.add_module(
            "linear1_kv",
            MM_WEIGHT_REGISTER[mm_type](
                f"{block_prefix}.{block_index}.linear1_kv.weight",
                f"{block_prefix}.{block_index}.linear1_kv.bias",
                lazy_load,
                lazy_load_file,
            ),
        )
        self.add_module(
            "linear1_q",
            MM_WEIGHT_REGISTER[mm_type](
                f"{block_prefix}.{block_index}.linear1_q.weight",
                f"{block_prefix}.{block_index}.linear1_q.bias",
                lazy_load,
                lazy_load_file,
            ),
        )
        self.add_module(
            "linear2",
            MM_WEIGHT_REGISTER[mm_type](
                f"{block_prefix}.{block_index}.linear2.weight",
                f"{block_prefix}.{block_index}.linear2.bias",
                lazy_load,
                lazy_load_file,
            ),
        )
        self.add_module(
            "q_norm",
            RMS_WEIGHT_REGISTER["sgl-kernel"](
                f"{block_prefix}.{block_index}.q_norm.weight",
                lazy_load,
                lazy_load_file,
            ),
        )
        self.add_module(
            "k_norm",
            RMS_WEIGHT_REGISTER["sgl-kernel"](
                f"{block_prefix}.{block_index}.k_norm.weight",
                lazy_load,
                lazy_load_file,
            ),
        )
        self.add_module(
            "pre_norm_feat",
            LN_WEIGHT_REGISTER["Default"](),
        )
        self.add_module(
            "pre_norm_motion",
            LN_WEIGHT_REGISTER["Default"](),
        )
        self.add_module("adapter_attn", ATTN_WEIGHT_REGISTER[config["adapter_attn_type"]]())
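Only every fifth transformer block receives a WanAnimateFuserBlock phase; the rest get an empty WeightModule, which is exactly what the new WeightModule.is_empty() check in infer_post_adapter skips over. A tiny sketch of the placement (blocks_num=40 is an assumed value, not taken from this commit):

blocks_num = 40  # assumed; the real value comes from the transformer config
fuser_slots = [i for i in range(blocks_num) if i % 5 == 0]
print(fuser_slots)       # [0, 5, 10, 15, 20, 25, 30, 35]
print(len(fuser_slots))  # == blocks_num // 5 adapter blocks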
lightx2v/models/networks/wan/weights/pre_weights.py

@@ -40,7 +40,7 @@ class WanPreWeights(WeightModule):
             MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
         )
-        if config.task in ["i2v", "flf2v"] and config.get("use_image_encoder", True):
+        if config.task in ["i2v", "flf2v", "animate"] and config.get("use_image_encoder", True):
             self.add_module(
                 "proj_0",
                 LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),

@@ -73,3 +73,8 @@ class WanPreWeights(WeightModule):
                 "emb_pos",
                 TENSOR_REGISTER["Default"](f"img_emb.emb_pos"),
             )
+        if config.task == "animate":
+            self.add_module(
+                "pose_patch_embedding",
+                CONV3D_WEIGHT_REGISTER["Default"]("pose_patch_embedding.weight", "pose_patch_embedding.bias", stride=self.patch_size),
+            )
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -285,7 +285,7 @@ class WanCrossAttention(WeightModule):
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
-        if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
+        if self.config.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
             self.add_module(
                 "cross_attn_k_img",
                 MM_WEIGHT_REGISTER[self.mm_type](
lightx2v/models/runners/base_runner.py

@@ -145,7 +145,7 @@ class BaseRunner(ABC):
     def run_segment(self, total_steps=None):
         pass
 
-    def end_run_segment(self):
+    def end_run_segment(self, segment_idx=None):
         pass
 
     def end_run(self):
The remaining 8 changed files of this commit are listed on page 2 of the changes view.