Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
2fb721ab
Commit
2fb721ab
authored
May 22, 2025
by
gushiqiao
Committed by
GitHub
May 22, 2025
Browse files
Merge pull request #48 from ModelTC/dev_fixbugs
Dev fixbugs
parents
f4213c00
4fd83968
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
11 deletions
+11
-11
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+9
-9
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+2
-2
No files found.
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
2fb721ab
...
...
@@ -37,10 +37,10 @@ class WanTransformerInfer:
return
cu_seqlens_q
,
cu_seqlens_k
@
torch
.
compile
(
disable
=
not
CHECK_ENABLE_GRAPH_MODE
())
def
infer
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
return
self
.
infer_func
(
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
)
def
infer
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
return
self
.
infer_func
(
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
)
def
_infer_with_offload
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
def
_infer_with_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
for
block_idx
in
range
(
self
.
blocks_num
):
if
block_idx
==
0
:
self
.
weights_stream_mgr
.
active_weights
[
0
]
=
weights
.
blocks
[
0
]
...
...
@@ -63,7 +63,7 @@ class WanTransformerInfer:
return
x
def
_infer_with_phases_offload
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
def
_infer_with_phases_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
for
block_idx
in
range
(
weights
.
blocks_num
):
weights
.
blocks
[
block_idx
].
modulation
.
to_cuda
()
...
...
@@ -114,7 +114,7 @@ class WanTransformerInfer:
return
x
def
_infer_without_offload
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
def
_infer_without_offload
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
for
block_idx
in
range
(
self
.
blocks_num
):
x
=
self
.
infer_block
(
weights
.
blocks
[
block_idx
],
...
...
@@ -249,7 +249,7 @@ class WanTransformerInfer:
x
.
add_
(
y
*
c_gate_msa
.
squeeze
(
0
))
return
x
def
infer_block
(
self
,
weights
,
grid_sizes
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
def
infer_block
(
self
,
weights
,
grid_sizes
,
embed
,
x
,
embed0
,
seq_lens
,
freqs
,
context
):
if
embed0
.
dim
()
==
3
:
modulation
=
weights
.
modulation
.
tensor
.
unsqueeze
(
2
)
embed0
=
(
modulation
+
embed0
).
chunk
(
6
,
dim
=
1
)
...
...
@@ -258,7 +258,7 @@ class WanTransformerInfer:
shift_msa
,
scale_msa
,
gate_msa
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
=
(
weights
.
modulation
.
tensor
+
embed0
).
chunk
(
6
,
dim
=
1
)
x
=
self
.
_infer_self_attn
(
weights
.
compute_phases
[
1
],
weights
.
compute_phases
[
0
],
x
,
shift_msa
,
scale_msa
,
...
...
@@ -267,6 +267,6 @@ class WanTransformerInfer:
freqs
,
seq_lens
,
)
x
=
self
.
_infer_cross_attn
(
weights
.
compute_phases
[
2
],
x
,
context
)
x
=
self
.
_infer_ffn
(
weights
.
compute_phases
[
3
],
x
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
)
x
=
self
.
_infer_cross_attn
(
weights
.
compute_phases
[
1
],
x
,
context
)
x
=
self
.
_infer_ffn
(
weights
.
compute_phases
[
2
],
x
,
c_shift_msa
,
c_scale_msa
,
c_gate_msa
)
return
x
lightx2v/models/networks/wan/model.py
View file @
2fb721ab
...
...
@@ -183,7 +183,7 @@ class WanModel:
self
.
post_weight
.
to_cuda
()
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
True
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
*
pre_infer_out
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
noise_pred_cond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
if
self
.
config
[
"feature_caching"
]
==
"Tea"
:
...
...
@@ -194,7 +194,7 @@ class WanModel:
if
self
.
config
[
"enable_cfg"
]:
embed
,
grid_sizes
,
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
inputs
,
positive
=
False
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
*
pre_infer_out
)
x
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
grid_sizes
,
embed
,
*
pre_infer_out
)
noise_pred_uncond
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
x
,
embed
,
grid_sizes
)[
0
]
if
self
.
config
[
"feature_caching"
]
==
"Tea"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment