xuwx1 / LightX2V: Commit 9b13cab2
"docs/vscode:/vscode.git/clone" did not exist on "c57b8b184bcd21723df284767f8262839d9c60d6"
Commit 9b13cab2 (Unverified), authored Nov 27, 2025 by Yang Yong (雍洋); committed by GitHub, Nov 27, 2025

Update wan infer (#524)

Parent: d242358f
Showing 2 changed files with 919 additions and 15 deletions (+919, -15):
lightx2v/models/networks/wan/infer/transformer_infer.py  (+17, -15)
lightx2v/models/networks/wan/infer/triton_ops.py  (+902, -0)
lightx2v/models/networks/wan/infer/transformer_infer.py @ 9b13cab2
@@ -5,6 +5,7 @@ import torch
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
+from .triton_ops import fuse_scale_shift_kernel
 from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch
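
The new import pulls fuse_scale_shift_kernel from the triton_ops.py module added by this commit. From the eager code it replaces below, the op's math can be read off directly: out = x * (1 + scale) + shift. A pure-PyTorch reference of what the Triton op is expected to compute (the squeeze placement mirrors the removed lines; broadcasting details inside the real kernel are an assumption):

    import torch

    def fuse_scale_shift_reference(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        # Mirrors the eager path this commit removes:
        #   weight = 1 + scale.squeeze(); bias = shift.squeeze()
        #   x.mul_(weight).add_(bias)
        return x * (1 + scale.squeeze()) + shift.squeeze()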
@@ -135,16 +136,15 @@ class WanTransformerInfer(BaseTransformerInfer):
         if hasattr(phase, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
             norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
             norm1_out = phase.norm1.apply(x)
             if self.sensitive_layer_dtype != self.infer_dtype:
                 norm1_out = norm1_out.to(self.sensitive_layer_dtype)
             norm1_out.mul_(norm1_weight).add_(norm1_bias)
         else:
-            norm1_weight = 1 + scale_msa.squeeze()
-            norm1_bias = shift_msa.squeeze()
-            norm1_out = phase.norm1.apply(x)
-            if self.sensitive_layer_dtype != self.infer_dtype:
-                norm1_out = norm1_out.to(self.sensitive_layer_dtype)
-            norm1_out.mul_(norm1_weight).add_(norm1_bias)
+            norm1_out = phase.norm1.apply(x)
+            if self.sensitive_layer_dtype != self.infer_dtype:
+                norm1_out = norm1_out.to(self.sensitive_layer_dtype)
+            norm1_out = fuse_scale_shift_kernel(norm1_out, scale=scale_msa, shift=shift_msa).squeeze(0)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.infer_dtype)
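
In the norm1 path, the else branch no longer materializes norm1_weight and norm1_bias or runs separate in-place mul_/add_ kernels; the raw scale_msa and shift_msa go straight into the Triton op, and the trailing .to(self.infer_dtype) converts back out of the numerically sensitive dtype. A quick sanity check one could run against the eager expression (shapes and dtypes here are illustrative assumptions, not taken from the diff):

    import torch
    from lightx2v.models.networks.wan.infer.triton_ops import fuse_scale_shift_kernel

    x = torch.randn(1, 4096, 1536, device="cuda", dtype=torch.bfloat16)
    scale = torch.randn(1, 1536, device="cuda", dtype=torch.bfloat16)
    shift = torch.randn(1, 1536, device="cuda", dtype=torch.bfloat16)

    fused = fuse_scale_shift_kernel(x, scale=scale, shift=shift).squeeze(0)
    eager = x.squeeze(0) * (1 + scale.squeeze()) + shift.squeeze()
    torch.testing.assert_close(fused, eager, rtol=1e-2, atol=1e-2)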
@@ -274,14 +274,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         if hasattr(phase, "smooth_norm2_weight"):
             norm2_weight = (1 + c_scale_msa.squeeze()) * phase.smooth_norm2_weight.tensor
             norm2_bias = c_shift_msa.squeeze() * phase.smooth_norm2_bias.tensor
             norm2_out = phase.norm2.apply(x)
             if self.sensitive_layer_dtype != self.infer_dtype:
                 norm2_out = norm2_out.to(self.sensitive_layer_dtype)
             norm2_out.mul_(norm2_weight).add_(norm2_bias)
         else:
-            norm2_weight = 1 + c_scale_msa.squeeze()
-            norm2_bias = c_shift_msa.squeeze()
-            norm2_out = phase.norm2.apply(x)
-            if self.sensitive_layer_dtype != self.infer_dtype:
-                norm2_out = norm2_out.to(self.sensitive_layer_dtype)
-            norm2_out.mul_(norm2_weight).add_(norm2_bias)
+            norm2_out = phase.norm2.apply(x)
+            if self.sensitive_layer_dtype != self.infer_dtype:
+                norm2_out = norm2_out.to(self.sensitive_layer_dtype)
+            norm2_out = fuse_scale_shift_kernel(norm2_out, scale=c_scale_msa, shift=c_shift_msa).squeeze(0)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.infer_dtype)
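
The norm2 modulation path gets the identical treatment with c_scale_msa and c_shift_msa. The practical win is memory traffic: the eager path makes two full read-modify-write passes over the activation tensor (one for mul_, one for add_), while the fused kernel does a single pass. A rough micro-benchmark sketch (shapes are illustrative, and the bench harness is an assumption, not part of the commit):

    import torch
    from lightx2v.models.networks.wan.infer.triton_ops import fuse_scale_shift_kernel

    def bench(fn, iters=100, warmup=10):
        # CUDA-event timing; returns average milliseconds per call.
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters

    x = torch.randn(1, 32760, 1536, device="cuda", dtype=torch.bfloat16)
    scale = torch.randn(1, 1536, device="cuda", dtype=torch.bfloat16)
    shift = torch.randn(1, 1536, device="cuda", dtype=torch.bfloat16)

    eager_ms = bench(lambda: x * (1 + scale.squeeze()) + shift.squeeze())
    fused_ms = bench(lambda: fuse_scale_shift_kernel(x, scale=scale, shift=shift))
    print(f"eager: {eager_ms:.3f} ms  fused: {fused_ms:.3f} ms")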
lightx2v/models/networks/wan/infer/triton_ops.py @ 9b13cab2 (new file, mode 0 → 100644)

This diff is collapsed on the page; the new file adds 902 lines, including fuse_scale_shift_kernel.
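
Since only the call signature fuse_scale_shift_kernel(x, scale=..., shift=...) is visible in the diff, here is a minimal sketch of what a fused scale-shift kernel of this shape typically looks like in Triton. The kernel name, block sizing, and layout handling are assumptions, not the commit's actual 902-line implementation:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _scale_shift_kernel(x_ptr, scale_ptr, shift_ptr, out_ptr,
                            n_cols, BLOCK: tl.constexpr):
        # One program per row; each program handles one full feature row.
        row = tl.program_id(0)
        cols = tl.arange(0, BLOCK)
        mask = cols < n_cols
        x = tl.load(x_ptr + row * n_cols + cols, mask=mask)
        scale = tl.load(scale_ptr + cols, mask=mask)
        shift = tl.load(shift_ptr + cols, mask=mask)
        # Single fused pass: out = x * (1 + scale) + shift
        tl.store(out_ptr + row * n_cols + cols, x * (1 + scale) + shift, mask=mask)

    def fuse_scale_shift_sketch(x, scale, shift):
        # x: (..., n_cols); scale/shift broadcast over the leading dims,
        # matching the eager code this commit removes.
        x2d = x.contiguous().reshape(-1, x.shape[-1])
        out = torch.empty_like(x2d)
        n_rows, n_cols = x2d.shape
        _scale_shift_kernel[(n_rows,)](x2d, scale.reshape(-1), shift.reshape(-1),
                                       out, n_cols, BLOCK=triton.next_power_of_2(n_cols))
        return out.reshape(x.shape)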