LightX2V (xuwx1) · Commit 131c8a46

[Fix] Fix model_io datacls (#340)

Unverified commit, authored Sep 28, 2025 by gushiqiao, committed by GitHub on Sep 28, 2025. Parent: 682037cd.
Showing 11 changed files with 19 additions and 23 deletions (+19 -23).
lightx2v/models/input_encoders/hf/animate/motion_encoder.py         +2 -2
lightx2v/models/networks/wan/audio_model.py                          +2 -2
lightx2v/models/networks/wan/infer/animate/transformer_infer.py      +2 -2
lightx2v/models/networks/wan/infer/audio/pre_infer.py                +1 -1
lightx2v/models/networks/wan/infer/audio/transformer_infer.py        +2 -2
lightx2v/models/networks/wan/infer/module_io.py                      +1 -2
lightx2v/models/networks/wan/infer/offload/transformer_infer.py      +1 -1
lightx2v/models/networks/wan/infer/pre_infer.py                      +1 -1
lightx2v/models/networks/wan/infer/transformer_infer.py              +1 -1
lightx2v/models/networks/wan/infer/vace/transformer_infer.py         +2 -2
lightx2v/models/runners/wan/wan_animate_runner.py                    +4 -7
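Taken together, the diffs below remove the model-specific motion_vec field and the adapter_output dict from the WanPreInferModuleOutput dataclass and replace them with a single generic adapter_args dict; every producer (the pre-infer modules) and consumer (the audio, animate, offload, and VACE transformer-infer paths) is updated to read and write through adapter_args. The animate runner's load_encoder helper is also renamed to load_encoders, and the motion encoder switches from gradient checkpointing to a plain forward call.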
lightx2v/models/input_encoders/hf/animate/motion_encoder.py  (mode 100644 → 100755)

@@ -293,8 +293,8 @@ class Generator(nn.Module):
         self.dec = Synthesis(motion_dim)
 
     def get_motion(self, img):
-        # motion_feat = self.enc.enc_motion(img)
-        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
+        motion_feat = self.enc.enc_motion(img)
+        # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
         with torch.amp.autocast("cuda", dtype=torch.float32):
             motion = self.dec.direction(motion_feat)
         return motion
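The hunk above drops gradient checkpointing in favor of a direct forward call, which is the natural choice for inference: checkpointing only pays off when a backward pass will recompute activations. A minimal sketch of the two call styles, using a stand-in encoder module rather than the project's enc.enc_motion:

import torch
import torch.nn as nn
import torch.utils.checkpoint

enc = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())  # stand-in for enc.enc_motion
img = torch.randn(1, 3, 64, 64, requires_grad=True)

# Training-style call: activations are recomputed during backward to save memory.
feat_ckpt = torch.utils.checkpoint.checkpoint(enc, img, use_reentrant=True)

# Inference-style call (what the diff switches to): a plain forward pass.
with torch.no_grad():
    feat = enc(img)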
lightx2v/models/networks/wan/audio_model.py

@@ -123,7 +123,7 @@ class WanAudioModel(WanModel):
     @torch.no_grad()
     def _seq_parallel_pre_process(self, pre_infer_out):
         x = pre_infer_out.x
-        person_mask_latens = pre_infer_out.adapter_output["person_mask_latens"]
+        person_mask_latens = pre_infer_out.adapter_args["person_mask_latens"]
         world_size = dist.get_world_size(self.seq_p_group)
         cur_rank = dist.get_rank(self.seq_p_group)

@@ -136,7 +136,7 @@ class WanAudioModel(WanModel):
         pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
         if person_mask_latens is not None:
-            pre_infer_out.adapter_output["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]
+            pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]
         if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
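_seq_parallel_pre_process shards the pre-infer tensors across the sequence-parallel group: each rank keeps one chunk along its token dimension. A minimal single-process sketch of that indexing pattern, with world_size and cur_rank hard-coded instead of coming from torch.distributed:

import torch

world_size, cur_rank = 4, 1          # stand-ins for dist.get_world_size / dist.get_rank
x = torch.randn(8, 16)               # tokens along dim 0
person_mask = torch.randn(2, 8)      # tokens along dim 1

x_shard = torch.chunk(x, world_size, dim=0)[cur_rank]                # shape (2, 16)
mask_shard = torch.chunk(person_mask, world_size, dim=1)[cur_rank]   # shape (2, 2)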
lightx2v/models/networks/wan/infer/animate/transformer_infer.py

@@ -14,8 +14,8 @@ class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
     def infer_post_adapter(self, phase, x, pre_infer_out):
         if phase.is_empty():
             return x
-        T = pre_infer_out.motion_vec.shape[0]
-        x_motion = phase.pre_norm_motion.apply(pre_infer_out.motion_vec)
+        T = pre_infer_out.adapter_args["motion_vec"].shape[0]
+        x_motion = phase.pre_norm_motion.apply(pre_infer_out.adapter_args["motion_vec"])
         x_feat = phase.pre_norm_feat.apply(x)
         kv = phase.linear1_kv.apply(x_motion.view(-1, x_motion.shape[-1]))
         kv = kv.view(T, -1, kv.shape[-1])
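The adapter keeps the frame count T from the motion features (now fetched from adapter_args), flattens frames and tokens into one batch for a single projection, then restores the frame dimension. A minimal sketch of that flatten-project-reshape pattern with made-up sizes and a plain nn.Linear standing in for the project's linear1_kv weight wrapper:

import torch
import torch.nn as nn

T, N, D, D_kv = 4, 77, 512, 1024     # frames, tokens per frame, feature dims (illustrative)
linear1_kv = nn.Linear(D, D_kv)
motion_vec = torch.randn(T, N, D)    # what adapter_args["motion_vec"] would hold

kv = linear1_kv(motion_vec.view(-1, motion_vec.shape[-1]))   # (T*N, D_kv)
kv = kv.view(T, -1, kv.shape[-1])                            # (T, N, D_kv)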
lightx2v/models/networks/wan/infer/audio/pre_infer.py

@@ -128,5 +128,5 @@ class WanAudioPreInfer(WanPreInfer):
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
-            adapter_output={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
+            adapter_args={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
         )
lightx2v/models/networks/wan/infer/audio/transformer_infer.py

@@ -22,8 +22,8 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
     @torch.no_grad()
     def infer_post_adapter(self, phase, x, pre_infer_out):
         grid_sizes = pre_infer_out.grid_sizes.tensor
-        audio_encoder_output = pre_infer_out.adapter_output["audio_encoder_output"]
-        person_mask_latens = pre_infer_out.adapter_output["person_mask_latens"]
+        audio_encoder_output = pre_infer_out.adapter_args["audio_encoder_output"]
+        person_mask_latens = pre_infer_out.adapter_args["person_mask_latens"]
         total_tokens = grid_sizes[0].prod()
         pre_frame_tokens = grid_sizes[0][1:].prod()
         n_tokens = total_tokens - pre_frame_tokens  # token count after removing the ref-image tokens
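The token arithmetic in that context reads the latent grid size (frames × height × width) and subtracts one frame's worth of tokens, i.e. the reference image prepended to the sequence. A small sketch with illustrative grid values:

import torch

grid_sizes = torch.tensor([[9, 30, 52]])     # (frames, height, width) of the latent grid, illustrative

total_tokens = grid_sizes[0].prod()          # 9 * 30 * 52 tokens in the full sequence
per_frame_tokens = grid_sizes[0][1:].prod()  # 30 * 52 tokens in a single frame
n_tokens = total_tokens - per_frame_tokens   # sequence length without the reference frame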
lightx2v/models/networks/wan/infer/module_io.py

@@ -19,5 +19,4 @@ class WanPreInferModuleOutput:
     seq_lens: torch.Tensor
     freqs: torch.Tensor
     context: torch.Tensor
-    motion_vec: torch.Tensor
-    adapter_output: Dict[str, Any] = field(default_factory=dict)
+    adapter_args: Dict[str, Any] = field(default_factory=dict)
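This is the dataclass the commit title refers to: the model-specific motion_vec field and the adapter_output dict are folded into one generic adapter_args mapping, so each adapter stores whatever tensors it needs under its own keys. A trimmed-down sketch of the resulting shape (only the fields visible in this diff; the real dataclass carries more):

from dataclasses import dataclass, field
from typing import Any, Dict

import torch


@dataclass
class WanPreInferModuleOutput:
    # ...other tensor fields omitted...
    seq_lens: torch.Tensor
    freqs: torch.Tensor
    context: torch.Tensor
    # One generic dict replaces the former motion_vec field and adapter_output dict,
    # e.g. out.adapter_args["motion_vec"] = motion_vec.
    adapter_args: Dict[str, Any] = field(default_factory=dict)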
lightx2v/models/networks/wan/infer/offload/transformer_infer.py

@@ -217,7 +217,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
                    self.phase_params["c_gate_msa"],
                )
                if hasattr(cur_phase, "after_proj"):
-                    pre_infer_out.adapter_output["hints"].append(cur_phase.after_proj.apply(x))
+                    pre_infer_out.adapter_args["hints"].append(cur_phase.after_proj.apply(x))
            elif cur_phase_idx == 3:
                x = self.infer_post_adapter(cur_phase, x, pre_infer_out)
        return x
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -131,5 +131,5 @@ class WanPreInfer:
             seq_lens=seq_lens,
             freqs=self.freqs,
             context=context,
-            motion_vec=motion_vec,
+            adapter_args={"motion_vec": motion_vec},
         )
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -108,7 +108,7 @@ class WanTransformerInfer(BaseTransformerInfer):
            y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
            x = self.post_process(x, y, c_gate_msa, pre_infer_out)
            if hasattr(block.compute_phases[2], "after_proj"):
-                pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x))
+                pre_infer_out.adapter_args["hints"].append(block.compute_phases[2].after_proj.apply(x))
            if self.has_post_adapter:
                x = self.infer_post_adapter(block.compute_phases[3], x, pre_infer_out)
lightx2v/models/networks/wan/infer/vace/transformer_infer.py

@@ -20,7 +20,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         return c
 
     def infer_vace_blocks(self, vace_blocks, pre_infer_out):
-        pre_infer_out.adapter_output["hints"] = []
+        pre_infer_out.adapter_args["hints"] = []
         self.infer_state = "vace"
         if hasattr(self, "weights_stream_mgr"):
             self.weights_stream_mgr.init(self.vace_blocks_num, self.phases_num, self.offload_ratio)

@@ -33,5 +33,5 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
         if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
             hint_idx = self.vace_blocks_mapping[self.block_idx]
-            x = x + pre_infer_out.adapter_output["hints"][hint_idx] * pre_infer_out.adapter_output.get("context_scale", 1.0)
+            x = x + pre_infer_out.adapter_args["hints"][hint_idx] * pre_infer_out.adapter_args.get("context_scale", 1.0)
         return x
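The VACE path is both producer and consumer of the hints entry: infer_vace_blocks resets adapter_args["hints"] and the VACE blocks append their projected outputs, then during the base pass each mapped block adds its hint back into x, scaled by an optional context_scale entry. A minimal sketch of that dict contract (shapes, the block mapping, and the scale value are illustrative, not the project's actual values):

import torch

adapter_args = {"hints": [], "context_scale": 0.8}   # carried on the pre-infer output

# Producer side: each VACE block appends its projected hint.
for _ in range(2):
    adapter_args["hints"].append(torch.randn(16, 512))

# Consumer side: a mapped base block adds its hint, scaled if context_scale is set.
vace_blocks_mapping = {5: 0, 11: 1}                  # base block index -> hint index (illustrative)
block_idx = 11
x = torch.randn(16, 512)
if block_idx in vace_blocks_mapping:
    hint_idx = vace_blocks_mapping[block_idx]
    x = x + adapter_args["hints"][hint_idx] * adapter_args.get("context_scale", 1.0)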
lightx2v/models/runners/wan/wan_animate_runner.py

@@ -363,18 +363,15 @@ class WanAnimateRunner(WanRunner):
             self.config,
             self.init_device,
         )
-        motion_encoder, face_encoder = self.load_encoder()
+        motion_encoder, face_encoder = self.load_encoders()
         model.set_animate_encoders(motion_encoder, face_encoder)
         return model
 
-    def load_encoder(self):
-        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE())
-        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE())
+    def load_encoders(self):
+        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
+        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
         motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.")
         face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.")
         motion_encoder.load_state_dict(motion_weight_dict)
         face_encoder.load_state_dict(face_weight_dict)
-        if not self.config["cpu_offload"]:
-            motion_encoder = motion_encoder.cuda()
-            face_encoder = face_encoder.cuda()
         return motion_encoder, face_encoder
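load_encoders pulls the motion and face encoder weights out of a combined checkpoint by key prefix before calling load_state_dict. A generic sketch of that filter-and-strip-prefix step using plain dict operations (a stand-in for the project's load_weights / remove_substrings_from_keys helpers, with a dummy checkpoint):

import torch

full_ckpt = {                                    # dummy combined checkpoint
    "motion_encoder.enc.weight": torch.zeros(4, 4),
    "face_encoder.proj.weight": torch.zeros(4, 4),
}

motion_weights = {
    k.replace("motion_encoder.", "", 1): v       # e.g. "enc.weight"
    for k, v in full_ckpt.items()
    if k.startswith("motion_encoder.")
}
# motion_weights now matches the sub-module's own key layout and can be passed to
# motion_encoder.load_state_dict(motion_weights).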