Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
7d109a7c
Commit
7d109a7c
authored
Jul 21, 2025
by
helloyongyang
Browse files
Simplify wan pre infer & Remove seq_len in WanScheduler
parent
68a807f1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
15 deletions
+7
-15
lightx2v/models/networks/wan/infer/pre_infer.py
lightx2v/models/networks/wan/infer/pre_infer.py
+7
-10
lightx2v/models/schedulers/wan/scheduler.py
lightx2v/models/schedulers/wan/scheduler.py
+0
-5
No files found.
lightx2v/models/networks/wan/infer/pre_infer.py
View file @
7d109a7c
...
...
@@ -27,7 +27,7 @@ class WanPreInfer:
         self.scheduler = scheduler

     def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
-        x = [self.scheduler.latents]
+        x = self.scheduler.latents
         if self.scheduler.flag_df:
             t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0)
...
...
@@ -39,7 +39,6 @@ class WanPreInfer:
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]
-        seq_len = self.scheduler.seq_len
         if self.task == "i2v":
             clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
...
...
@@ -50,16 +49,14 @@ class WanPreInfer:
             idx_s = kv_start // frame_seq_length
             idx_e = kv_end // frame_seq_length
             image_encoder = image_encoder[:, idx_s:idx_e, :, :]
-            y = [image_encoder]
-            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+            y = image_encoder
+            x = torch.cat([x, y], dim=0)

         # embeddings
-        x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
-        grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
-        x = [u.flatten(2).transpose(1, 2) for u in x]
-        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long).cuda()
-        assert seq_lens.max() <= seq_len
-        x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
+        x = weights.patch_embedding.apply(x.unsqueeze(0))
+        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
+        x = x.flatten(2).transpose(1, 2).contiguous()
+        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)

         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
...
...
lightx2v/models/schedulers/wan/scheduler.py
View file @
7d109a7c
...
...
@@ -27,11 +27,6 @@ class WanScheduler(BaseScheduler):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)

-        if self.config.task in ["t2v"]:
-            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
-        elif self.config.task in ["i2v"]:
-            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
-
         alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
         sigmas = 1.0 - alphas
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment