xuwx1 / LightX2V

Commit bff9bd05, authored Jun 29, 2025 by helloyongyang

update wan infer code

Parent 1c065c06
Showing 2 changed files with 40 additions and 70 deletions.
lightx2v/common/ops/norm/layer_norm_weight.py (+3, -1)
lightx2v/models/networks/wan/infer/transformer_infer.py (+37, -69)
lightx2v/common/ops/norm/layer_norm_weight.py (view file @ bff9bd05)
@@ -27,6 +27,8 @@ class LNWeightTemplate(metaclass=ABCMeta):
         self.bias = None
 
     def _calculate_size(self):
+        if self.weight is None:
+            return 0
         if self.bias is not None:
             return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
         return self.weight.numel() * self.weight.element_size()
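Note on the new guard: `_calculate_size` now returns 0 before any weight is loaded, and counts the bias only when present. As a reminder of the arithmetic, `numel() * element_size()` gives a tensor's payload in bytes; a minimal standalone sketch (the 4096/bfloat16 sizes are illustrative, not from this repo):

import torch

weight = torch.empty(4096, dtype=torch.bfloat16)  # hypothetical LayerNorm weight
bias = torch.empty(4096, dtype=torch.bfloat16)    # hypothetical LayerNorm bias

size = weight.numel() * weight.element_size()     # 4096 elements * 2 bytes each
size += bias.numel() * bias.element_size()        # bias counted only when present
print(size)  # 16384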
lightx2v/models/networks/wan/infer/transformer_infer.py (view file @ bff9bd05)
@@ -91,13 +91,7 @@ class WanTransformerInfer:
         for block_idx in range(weights.blocks_num):
-            weights.blocks[block_idx].modulation.to_cuda()
-            if embed0.dim() == 3:
-                modulation = weights.blocks[block_idx].modulation.tensor.unsqueeze(2)
-                current_embed0 = (modulation + embed0).chunk(6, dim=1)
-                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in current_embed0]
-            elif embed0.dim() == 2:
-                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.blocks[block_idx].modulation.tensor + embed0).chunk(6, dim=1)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
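The lines removed above are the adaLN-style modulation split that this commit centralizes into the new `infer_phase_1` (defined later in this diff). A standalone sketch of what `chunk(6, dim=1)` produces, with shapes assumed for illustration:

import torch

batch, dim = 1, 1536                     # hypothetical sizes
modulation = torch.randn(batch, 6, dim)  # learned per-block modulation table
embed0 = torch.randn(batch, 6, dim)      # time-conditioned embedding

parts = (modulation + embed0).chunk(6, dim=1)  # six (1, 1, dim) tensors
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [p.squeeze(1) for p in parts]
print(shift_msa.shape)  # torch.Size([1, 1536])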
@@ -108,22 +102,12 @@ class WanTransformerInfer:
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                 cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
                 if cur_phase_idx == 0:
-                    x = self._infer_self_attn(
-                        cur_phase,
-                        x,
-                        shift_msa,
-                        scale_msa,
-                        gate_msa,
-                        grid_sizes,
-                        freqs,
-                        seq_lens,
-                    )
+                    y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
                 elif cur_phase_idx == 1:
-                    x = self._infer_cross_attn(cur_phase, x, context)
+                    attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
                 elif cur_phase_idx == 2:
-                    x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)
+                    y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
+                    x = self.infer_phase_5(x, y, c_gate_msa)
             is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
             if not is_last_phase:
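The `with torch.cuda.stream(...)` wrapper is what lets this phase loop overlap compute with weight uploads: one stream runs the current phase while another copies the next phase's weights host-to-device. A minimal sketch of that pattern with two bare streams (the diff does not show `WeightStreamMgr` internals, so the manager is replaced by stand-ins; requires a CUDA device):

import torch

load_stream = torch.cuda.Stream()
compute_stream = torch.cuda.Stream()

def run_phase(w, x):
    return x + w  # stand-in for one phase (self-attn, cross-attn, or FFN)

phases_cpu = [torch.ones(1024).pin_memory() for _ in range(3)]
x = torch.zeros(1024, device="cuda")

with torch.cuda.stream(load_stream):
    cur = phases_cpu[0].to("cuda", non_blocking=True)

for i in range(len(phases_cpu)):
    compute_stream.wait_stream(load_stream)  # weights for phase i are resident
    with torch.cuda.stream(compute_stream):
        x = run_phase(cur, x)
    if i + 1 < len(phases_cpu):
        with torch.cuda.stream(load_stream):  # prefetch overlaps the compute above
            nxt = phases_cpu[i + 1].to("cuda", non_blocking=True)
        compute_stream.synchronize()          # finish phase i before swapping buffers
        cur = nxt

torch.cuda.synchronize()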
@@ -146,12 +130,7 @@ class WanTransformerInfer:
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                weights.blocks[block_idx].modulation.to_cuda()
-                if embed0.dim() == 3:
-                    modulation = weights.blocks[block_idx].modulation.tensor.unsqueeze(2)
-                    current_embed0 = (modulation + embed0).chunk(6, dim=1)
-                    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in current_embed0]
-                elif embed0.dim() == 2:
-                    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.blocks[block_idx].modulation.tensor + embed0).chunk(6, dim=1)
+                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             for phase_idx in range(self.weights_stream_mgr.phases_num):
                 if block_idx == 0 and phase_idx == 0:
@@ -170,20 +149,12 @@ class WanTransformerInfer:
                 ) = self.weights_stream_mgr.active_weights[0]
                 if cur_phase_idx == 0:
-                    x = self._infer_self_attn(
-                        cur_phase,
-                        x,
-                        shift_msa,
-                        scale_msa,
-                        gate_msa,
-                        grid_sizes,
-                        freqs,
-                        seq_lens,
-                    )
+                    y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
                 elif cur_phase_idx == 1:
-                    x = self._infer_cross_attn(cur_phase, x, context)
+                    attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
                 elif cur_phase_idx == 2:
-                    x = self._infer_ffn(cur_phase, x, c_shift_msa, c_scale_msa, c_gate_msa)
+                    y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
+                    x = self.infer_phase_5(x, y, c_gate_msa)
                 if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                     next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
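The tail of this loop picks what to prefetch: the next phase of the same block, or phase 0 of the next block once the current block's last phase is running. A pure-Python sketch of that indexing (only `next_block_idx` appears in the hunk; the `next_phase_idx` wrap-around is an assumption about the elided lines):

blocks_num, phases_num = 4, 3  # phases_num matches the three cur_phase_idx branches
for block_idx in range(blocks_num):
    for phase_idx in range(phases_num):
        is_last_phase = block_idx == blocks_num - 1 and phase_idx == phases_num - 1
        if not is_last_phase:
            next_block_idx = block_idx + 1 if phase_idx == phases_num - 1 else block_idx
            next_phase_idx = (phase_idx + 1) % phases_num  # assumed, not shown in the hunk
            print(f"compute ({block_idx}, {phase_idx}) while loading ({next_block_idx}, {next_phase_idx})")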
@@ -213,7 +184,24 @@ class WanTransformerInfer:
         )
         return x
 
-    def _infer_self_attn(self, weights, x, shift_msa, scale_msa, gate_msa, grid_sizes, freqs, seq_lens):
+    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+        y_out = self.infer_phase_2(weights.compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
+        attn_out = self.infer_phase_3(weights.compute_phases[1], x, context, y_out, gate_msa)
+        y = self.infer_phase_4(weights.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
+        x = self.infer_phase_5(x, y, c_gate_msa)
+        return x
+
+    def infer_phase_1(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+        if embed0.dim() == 3:
+            modulation = weights.modulation.tensor.unsqueeze(2)
+            embed0 = (modulation + embed0).chunk(6, dim=1)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
+        elif embed0.dim() == 2:
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+        return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
+
+    def infer_phase_2(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa) * weights.smooth_norm1_weight.tensor
             norm1_bias = shift_msa * weights.smooth_norm1_bias.tensor
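The `smooth_norm1` branch of `infer_phase_2` folds the adaLN shift/scale into the norm's affine parameters. A sketch of the folding; how the folded weight and bias are applied afterwards is assumed here (a plain non-affine LayerNorm), since that code sits outside the hunk:

import torch
import torch.nn.functional as F

dim = 1536                   # hypothetical
x = torch.randn(16, dim)
scale_msa = torch.randn(1, dim)
shift_msa = torch.randn(1, dim)
smooth_w = torch.rand(dim)   # stand-in for weights.smooth_norm1_weight.tensor
smooth_b = torch.rand(dim)   # stand-in for weights.smooth_norm1_bias.tensor

norm1_weight = (1 + scale_msa) * smooth_w  # as in the branch above
norm1_bias = shift_msa * smooth_b

out = F.layer_norm(x, (dim,)) * norm1_weight + norm1_bias  # assumed application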
@@ -269,14 +257,14 @@ class WanTransformerInfer:
         )
         y = weights.self_attn_o.apply(attn_out)
+        return y
+
+    def infer_phase_3(self, weights, x, context, y_out, gate_msa):
         if GET_DTYPE() != "BF16":
-            x = x.float() + y.float() * gate_msa.squeeze(0)
+            x = x.float() + y_out.float() * gate_msa.squeeze(0)
         else:
-            x.add_(y * gate_msa.squeeze(0))
-        return x
+            x.add_(y_out * gate_msa.squeeze(0))
-
-    def _infer_cross_attn(self, weights, x, context):
         norm3_out = weights.norm3.apply(x)
         if self.task == "i2v":
             context_img = context[:257]
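`infer_phase_3` and `infer_phase_5` share the same gating pattern: promote to float32 for an out-of-place update when the runtime dtype is not BF16, otherwise update in place. A standalone sketch (a plain flag stands in for the repo's `GET_DTYPE()` helper):

import torch

def gated_residual(x, y, gate, bf16=True):
    if not bf16:  # stand-in for GET_DTYPE() != "BF16"
        return x.float() + y.float() * gate.squeeze(0)  # fp32 accumulate, out of place
    x.add_(y * gate.squeeze(0))                         # in place, no extra buffer
    return x

x = torch.randn(16, 1536, dtype=torch.bfloat16)
y = torch.randn(16, 1536, dtype=torch.bfloat16)
gate = torch.randn(1, 1536, dtype=torch.bfloat16)
x = gated_residual(x, y, gate)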
@@ -331,10 +319,10 @@ class WanTransformerInfer:
             attn_out = attn_out + img_attn_out
         attn_out = weights.cross_attn_o.apply(attn_out)
-        x.add_(attn_out)
-        return x
+        return attn_out
 
-    def _infer_ffn(self, weights, x, c_shift_msa, c_scale_msa, c_gate_msa):
+    def infer_phase_4(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
+        x.add_(attn_out)
         if hasattr(weights, "smooth_norm2_weight"):
             norm2_weight = (1 + c_scale_msa.squeeze(0)) * weights.smooth_norm2_weight.tensor
             norm2_bias = c_shift_msa.squeeze(0) * weights.smooth_norm2_bias.tensor
@@ -352,31 +340,11 @@ class WanTransformerInfer:
         y = weights.ffn_0.apply(norm2_out)
         y = torch.nn.functional.gelu(y, approximate="tanh")
         y = weights.ffn_2.apply(y)
+        return y
+
+    def infer_phase_5(self, x, y, c_gate_msa):
         if GET_DTYPE() != "BF16":
             x = x.float() + y.float() * c_gate_msa.squeeze(0)
         else:
             x.add_(y * c_gate_msa.squeeze(0))
         return x
 
-    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        if embed0.dim() == 3:
-            modulation = weights.modulation.tensor.unsqueeze(2)
-            embed0 = (modulation + embed0).chunk(6, dim=1)
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
-        elif embed0.dim() == 2:
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
-        x = self._infer_self_attn(
-            weights.compute_phases[0],
-            x,
-            shift_msa,
-            scale_msa,
-            gate_msa,
-            grid_sizes,
-            freqs,
-            seq_lens,
-        )
-        x = self._infer_cross_attn(weights.compute_phases[1], x, context)
-        x = self._infer_ffn(weights.compute_phases[2], x, c_shift_msa, c_scale_msa, c_gate_msa)
-        return x
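Net effect of the refactor: the old monolithic `infer_block` above is deleted, and a transformer block is now expressed as five composable phases (`infer_phase_1` through `infer_phase_5`) that both the plain path and the offload/streaming paths can schedule one phase at a time. A toy, self-contained walk through the same five steps (every module below is a stand-in, not the repo's attention or quantized kernels):

import torch
import torch.nn.functional as F

dim = 64
x = torch.randn(8, dim)
context = torch.randn(4, dim)
embed0 = torch.randn(1, 6, dim)      # time embedding (shape assumed)
modulation = torch.randn(1, 6, dim)  # per-block table (shape assumed)

# phase 1: split modulation into six adaLN parameters
shift, scale, gate, c_shift, c_scale, c_gate = [p.squeeze(1) for p in (modulation + embed0).chunk(6, dim=1)]
# phase 2: "self-attention" on the modulated norm (identity stand-in)
y_out = F.layer_norm(x, (dim,)) * (1 + scale) + shift
# phase 3: gate the self-attn branch, then "cross-attention" (mean-pool stand-in)
x = x + y_out * gate.squeeze(0)
attn_out = context.mean(dim=0).expand_as(x)
# phase 4: residual add, then FFN on the modulated norm
x = x + attn_out
y = F.gelu(F.layer_norm(x, (dim,)) * (1 + c_scale) + c_shift, approximate="tanh")
# phase 5: gated FFN residual
x = x + y * c_gate.squeeze(0)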