Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
834b09c4
"vscode:/vscode.git/clone" did not exist on "9000d93789e5ee2a2ba86c21e4f0ddc16f2a9343"
Commit
834b09c4
authored
Jul 21, 2025
by
helloyongyang
Browse files
[Feature] Support progressive resolution
parent
3e67df1c
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
57 additions
and
48 deletions
+57
-48
configs/changing_resolution/wan_i2v.json
configs/changing_resolution/wan_i2v.json
+2
-2
configs/changing_resolution/wan_t2v.json
configs/changing_resolution/wan_t2v.json
+2
-2
lightx2v/models/networks/wan/infer/pre_infer.py
lightx2v/models/networks/wan/infer/pre_infer.py
+2
-2
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+10
-18
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
...2v/models/schedulers/wan/changing_resolution/scheduler.py
+41
-24
No files found.
configs/changing_resolution/wan_i2v.json
View file @
834b09c4
...
...
@@ -12,6 +12,6 @@
"enable_cfg"
:
true
,
"cpu_offload"
:
false
,
"changing_resolution"
:
true
,
"resolution_rate"
:
0.75
,
"changing_resolution_steps"
:
20
"resolution_rate"
:
[
1.0
,
0.75
]
,
"changing_resolution_steps"
:
[
5
,
25
]
}
configs/changing_resolution/wan_t2v.json
View file @
834b09c4
...
...
@@ -13,6 +13,6 @@
"enable_cfg"
:
true
,
"cpu_offload"
:
false
,
"changing_resolution"
:
true
,
"resolution_rate"
:
0.75
,
"changing_resolution_steps"
:
25
"resolution_rate"
:
[
1.0
,
0.75
]
,
"changing_resolution_steps"
:
[
10
,
35
]
}
lightx2v/models/networks/wan/infer/pre_infer.py
View file @
834b09c4
...
...
@@ -44,8 +44,8 @@ class WanPreInfer:
if
self
.
task
==
"i2v"
:
clip_fea
=
inputs
[
"image_encoder_output"
][
"clip_encoder_out"
]
if
self
.
config
.
get
(
"changing_resolution"
,
False
)
and
self
.
scheduler
.
step_index
>
self
.
config
.
changing_resolution_steps
-
1
:
image_encoder
=
inputs
[
"image_encoder_output"
][
"vae_encode_out_original_resolution"
]
if
self
.
config
.
get
(
"changing_resolution"
,
False
):
image_encoder
=
inputs
[
"image_encoder_output"
][
"vae_encode_out"
][
self
.
scheduler
.
changing_resolution_index
]
else
:
image_encoder
=
inputs
[
"image_encoder_output"
][
"vae_encode_out"
]
...
...
lightx2v/models/runners/wan/wan_runner.py
View file @
834b09c4
...
...
@@ -211,12 +211,12 @@ class WanRunner(DefaultRunner):
if
self
.
config
.
get
(
"changing_resolution"
,
False
):
self
.
config
.
lat_h
,
self
.
config
.
lat_w
=
lat_h
,
lat_w
vae_encode_out_original_resolution
=
self
.
get_vae_encoder_output
(
img
,
lat_h
,
lat_w
)
# get vae encode out at low resolution
lat_h
,
lat_w
=
int
(
self
.
config
.
lat_h
*
self
.
config
.
resolution_rate
)
//
2
*
2
,
int
(
self
.
config
.
lat_w
*
self
.
config
.
resolution_rate
)
//
2
*
2
vae_encode_out
=
self
.
get_vae_encoder_output
(
img
,
lat_h
,
lat_w
)
return
vae_encode_out
,
vae_encode_out_original_resolution
# low resolution, original resolution
vae_encode_out_list
=
[]
for
i
in
range
(
len
(
self
.
config
[
"resolution_rate"
])):
lat_h
,
lat_w
=
int
(
self
.
config
.
lat_h
*
self
.
config
.
resolution_rate
[
i
])
//
2
*
2
,
int
(
self
.
config
.
lat_w
*
self
.
config
.
resolution_rate
[
i
])
//
2
*
2
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
img
,
lat_h
,
lat_w
))
vae_encode_out_list
.
append
(
self
.
get_vae_encoder_output
(
img
,
self
.
config
.
lat_h
,
self
.
config
.
lat_w
)
)
return
vae_encode_out_list
else
:
self
.
config
.
lat_h
,
self
.
config
.
lat_w
=
lat_h
,
lat_w
vae_encode_out
=
self
.
get_vae_encoder_output
(
img
,
lat_h
,
lat_w
)
...
...
@@ -259,18 +259,10 @@ class WanRunner(DefaultRunner):
return
vae_encode_out
def
get_encoder_output_i2v
(
self
,
clip_encoder_out
,
vae_encode_out
,
text_encoder_output
,
img
):
if
self
.
config
.
get
(
"changing_resolution"
,
False
):
image_encoder_output
=
{
"clip_encoder_out"
:
clip_encoder_out
,
"vae_encode_out"
:
vae_encode_out
[
0
],
"vae_encode_out_original_resolution"
:
vae_encode_out
[
1
],
}
else
:
image_encoder_output
=
{
"clip_encoder_out"
:
clip_encoder_out
,
"vae_encode_out"
:
vae_encode_out
,
}
image_encoder_output
=
{
"clip_encoder_out"
:
clip_encoder_out
,
"vae_encode_out"
:
vae_encode_out
,
}
return
{
"text_encoder_output"
:
text_encoder_output
,
"image_encoder_output"
:
image_encoder_output
,
...
...
lightx2v/models/schedulers/wan/changing_resolution/scheduler.py
View file @
834b09c4
...
...
@@ -5,33 +5,48 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
class
WanScheduler4ChangingResolution
(
WanScheduler
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
self
.
resolution_rate
=
config
.
get
(
"resolution_rate"
,
0.75
)
self
.
changing_resolution_steps
=
config
.
get
(
"changing_resolution_steps"
,
config
.
infer_steps
//
2
)
if
"resolution_rate"
not
in
config
:
config
[
"resolution_rate"
]
=
[
0.75
]
if
"changing_resolution_steps"
not
in
config
:
config
[
"changing_resolution_steps"
]
=
[
config
.
infer_steps
//
2
]
assert
len
(
config
[
"resolution_rate"
])
==
len
(
config
[
"changing_resolution_steps"
])
def
prepare_latents
(
self
,
target_shape
,
dtype
=
torch
.
float32
):
self
.
latents
=
torch
.
randn
(
target_shape
[
0
],
target_shape
[
1
],
int
(
target_shape
[
2
]
*
self
.
resolution_rate
)
//
2
*
2
,
int
(
target_shape
[
3
]
*
self
.
resolution_rate
)
//
2
*
2
,
dtype
=
dtype
,
device
=
self
.
device
,
generator
=
self
.
generator
,
)
self
.
latents_list
=
[]
for
i
in
range
(
len
(
self
.
config
[
"resolution_rate"
])):
self
.
latents_list
.
append
(
torch
.
randn
(
target_shape
[
0
],
target_shape
[
1
],
int
(
target_shape
[
2
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
int
(
target_shape
[
3
]
*
self
.
config
[
"resolution_rate"
][
i
])
//
2
*
2
,
dtype
=
dtype
,
device
=
self
.
device
,
generator
=
self
.
generator
,
)
)
self
.
noise_original_resolution
=
torch
.
randn
(
target_shape
[
0
],
target_shape
[
1
],
target_shape
[
2
],
target_shape
[
3
],
dtype
=
dtype
,
device
=
self
.
device
,
generator
=
self
.
generator
,
# add original resolution latents
self
.
latents_list
.
append
(
torch
.
randn
(
target_shape
[
0
],
target_shape
[
1
],
target_shape
[
2
],
target_shape
[
3
],
dtype
=
dtype
,
device
=
self
.
device
,
generator
=
self
.
generator
,
)
)
# set initial latents
self
.
latents
=
self
.
latents_list
[
0
]
self
.
changing_resolution_index
=
0
def
step_post
(
self
):
if
self
.
step_index
==
self
.
changing_resolution_steps
-
1
:
if
self
.
step_index
+
1
in
self
.
config
[
"changing_resolution_steps"
]
:
self
.
step_post_upsample
()
self
.
changing_resolution_index
+=
1
else
:
super
().
step_post
()
...
...
@@ -45,19 +60,21 @@ class WanScheduler4ChangingResolution(WanScheduler):
# 2. upsample clean noise to target shape
denoised_sample_5d
=
denoised_sample
.
unsqueeze
(
0
)
# (C,T,H,W) -> (1,C,T,H,W)
clean_noise
=
torch
.
nn
.
functional
.
interpolate
(
denoised_sample_5d
,
size
=
(
self
.
config
.
target_shape
[
1
],
self
.
config
.
target_shape
[
2
],
self
.
config
.
target_shape
[
3
]),
mode
=
"trilinear"
)
shape_to_upsampled
=
self
.
latents_list
[
self
.
changing_resolution_index
+
1
].
shape
[
1
:]
clean_noise
=
torch
.
nn
.
functional
.
interpolate
(
denoised_sample_5d
,
size
=
shape_to_upsampled
,
mode
=
"trilinear"
)
clean_noise
=
clean_noise
.
squeeze
(
0
)
# (1,C,T,H,W) -> (C,T,H,W)
# 3. add noise to clean noise
noisy_sample
=
self
.
add_noise
(
clean_noise
,
self
.
noise_original_resolution
,
self
.
timesteps
[
self
.
step_index
+
1
])
noisy_sample
=
self
.
add_noise
(
clean_noise
,
self
.
latents_list
[
self
.
changing_resolution_index
+
1
]
,
self
.
timesteps
[
self
.
step_index
+
1
])
# 4. update latents
self
.
latents
=
noisy_sample
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + 2 for more aggressive denoising
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
device
,
shift
=
self
.
sample_shift
+
2
)
# 5. update timesteps using shift + self.changing_resolution_index + 1 for more aggressive denoising
self
.
set_timesteps
(
self
.
infer_steps
,
device
=
self
.
device
,
shift
=
self
.
sample_shift
+
self
.
changing_resolution_index
+
1
)
def
add_noise
(
self
,
original_samples
,
noise
,
timesteps
):
sigma
=
self
.
sigmas
[
self
.
step_index
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment