xuwx1 / LightX2V · Commits

Commit d061ae81
authored Aug 20, 2025 by gushiqiao, committed by GitHub on Aug 20, 2025
parent fba9754a

[Fea] add approximate patch vae (#230)

Showing 5 changed files with 52 additions and 25 deletions (+52 -25)
configs/audio_driven/wan_i2v_audio_dist.json                +2   -2
lightx2v/models/runners/wan/wan_audio_runner.py             +5   -4
lightx2v/models/runners/wan/wan_runner.py                   +4   -4
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py    +2   -2
lightx2v/models/video_encoders/hf/wan/vae.py                +39  -13
configs/audio_driven/wan_i2v_audio_dist.json (mode 100644 → 100755)

@@ -17,8 +17,8 @@
     "use_31_block": false,
     "adaptive_resize": true,
     "parallel": {
+        "vae_p_size": 4,
         "seq_p_size": 4,
-        "seq_p_attn_type": "ulysses"
+        "seq_p_attn_type": "ulysses",
+        "use_patch_vae": false
     }
 }
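The new use_patch_vae flag is consumed in WanVAE.__init__ (see the vae.py hunks below). A minimal sketch of that lookup, assuming the parallel dict is forwarded verbatim from this config:

    # Sketch only: mirrors the check added in WanVAE.__init__ (vae.py below).
    parallel = {
        "vae_p_size": 4,
        "seq_p_size": 4,
        "seq_p_attn_type": "ulysses",
        "use_patch_vae": False,  # flip to True to enable the approximate patch VAE
    }

    if parallel and parallel.get("use_patch_vae", False):
        # In the real code: DistributedEnv.initialize(None) and
        # model.enable_approximate_patch() run here.
        pass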
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -273,10 +273,10 @@ class VideoGenerator:
         _, nframe, height, width = self.model.scheduler.latents.shape
         if self.config.model_cls == "wan2.2_audio":
-            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype)).to(dtype)
             _, prev_mask = self._wan22_masks_like([self.model.scheduler.latents], zero=True, prev_length=prev_latents.shape[1])
         else:
-            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)
+            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype))[0].to(dtype)
         if prev_video is not None:
             prev_token_length = (prev_frame_length - 1) // 4 + 1

@@ -370,6 +370,7 @@ class VideoGenerator:
         # Decode latents
         latents = self.model.scheduler.latents
         generator = self.model.scheduler.generator
-        gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
+        with ProfilingContext("Run VAE Decoder"):
+            gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
         gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)

@@ -667,7 +668,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
-        vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)
+        vae_encoder_out = vae_model.encode(cond_frms.to(torch.float))
         if self.config.model_cls == "wan2.2_audio":
             vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
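Takeaway for callers: encode no longer takes the runner config; CPU offload is now decided by the VAE's own cpu_offload attribute (see the WanVAE hunks below). A hedged migration sketch, using names from the hunks above:

    # Before (config threaded through so encode could inspect args.cpu_offload):
    #   out = vae_model.encode(cond_frms.to(torch.float), config)
    # After (offload is internal to the VAE; drop the second argument):
    out = vae_model.encode(cond_frms.to(torch.float))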
lightx2v/models/runners/wan/wan_runner.py

@@ -135,7 +135,7 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
             "device": vae_device,
-            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
+            "parallel": self.config.parallel,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
         }

@@ -155,7 +155,7 @@ class WanRunner(DefaultRunner):
         vae_config = {
             "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.1_VAE.pth"),
             "device": vae_device,
-            "parallel": self.config.parallel and self.config.parallel.get("vae_p_size", False) and self.config.parallel.vae_p_size > 1,
+            "parallel": self.config.parallel,
             "use_tiling": self.config.get("use_tiling_vae", False),
             "cpu_offload": vae_offload,
         }

@@ -313,7 +313,7 @@ class WanRunner(DefaultRunner):
             dim=1,
         ).cuda()
-        vae_encoder_out = self.vae_encoder.encode([vae_input], self.config)[0]
+        vae_encoder_out = self.vae_encoder.encode([vae_input])[0]
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.vae_encoder

@@ -497,5 +497,5 @@ class Wan22DenseRunner(WanRunner):
         return vae_encoder_out

     def get_vae_encoder_output(self, img):
-        z = self.vae_encoder.encode(img, self.config)
+        z = self.vae_encoder.encode(img)
         return z
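Note the "parallel" entry now forwards the whole parallel dict instead of collapsing it to a boolean, so WanVAE can read vae_p_size and use_patch_vae itself. A sketch of the intent, with hypothetical values:

    parallel_cfg = {"vae_p_size": 4, "seq_p_size": 4, "use_patch_vae": False}  # hypothetical
    old_style = bool(parallel_cfg and parallel_cfg.get("vae_p_size", 0) > 1)   # what was passed before
    new_style = parallel_cfg                                                   # what is passed now
    # A non-empty dict is truthy, so existing `if self.parallel:` checks still
    # fire, and WanVAE can additionally read parallel_cfg.get("use_patch_vae").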
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py

@@ -37,7 +37,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         config.lat_h = lat_h
         config.lat_w = lat_w
-        vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()], config)[0]
+        vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()])[0]
         vae_encoder_out = vae_encoder_out.to(GET_DTYPE())
         return vae_encoder_out

@@ -87,7 +87,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         for i in range(n_iter):
             if output_video is not None:  # i != 0
                 prefix_video = output_video[:, :, -overlap_history:].to(self.model.scheduler.device)
-                prefix_video = self.vae_model.encode(prefix_video, self.config)[0]  # [(b, c, f, h, w)]
+                prefix_video = self.vae_model.encode(prefix_video)[0]  # [(b, c, f, h, w)]
                 if prefix_video.shape[1] % causal_block_size != 0:
                     truncate_len = prefix_video.shape[1] % causal_block_size
                     # the length of prefix video is truncated for the casual block size alignment.
lightx2v/models/video_encoders/hf/wan/vae.py

@@ -7,6 +7,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from loguru import logger

+from lightx2v.models.video_encoders.hf.wan.dist.distributed_env import DistributedEnv
+from lightx2v.models.video_encoders.hf.wan.dist.split_gather import gather_forward_split_backward, split_forward_gather_backward
 from lightx2v.utils.utils import load_weights

 __all__ = [

@@ -517,6 +519,7 @@ class WanVAE_(nn.Module):
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+        self.use_approximate_patch = False

         # The minimal tile height and width for spatial tiling to be used
         self.tile_sample_min_height = 256

@@ -547,6 +550,12 @@ class WanVAE_(nn.Module):
             dropout,
         )

+    def enable_approximate_patch(self):
+        self.use_approximate_patch = True
+
+    def disable_approximate_patch(self):
+        self.use_approximate_patch = False
+
     def forward(self, x):
         mu, log_var = self.encode(x)
         z = self.reparameterize(mu, log_var)

@@ -629,6 +638,9 @@ class WanVAE_(nn.Module):
         return enc

     def tiled_decode(self, z, scale):
+        if self.use_approximate_patch:
+            z = split_forward_gather_backward(None, z, 3)
+
         if isinstance(scale[0], torch.Tensor):
             z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
         else:

@@ -678,6 +690,8 @@ class WanVAE_(nn.Module):
             result_rows.append(torch.cat(result_row, dim=-1))

         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+        if self.use_approximate_patch:
+            dec = gather_forward_split_backward(None, dec, 3)
         return dec

@@ -686,7 +700,6 @@ class WanVAE_(nn.Module):
         ## cache
         t = x.shape[2]
         iter_ = 1 + (t - 1) // 4
-        ## split the x fed to encode along time into chunks of 1, 4, 4, 4, ...
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:

@@ -707,11 +720,15 @@ class WanVAE_(nn.Module):
                 mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
             else:
                 mu = (mu - scale[0]) * scale[1]
         self.clear_cache()
         return mu

     def decode(self, z, scale):
         self.clear_cache()
+
+        if self.use_approximate_patch:
+            z = split_forward_gather_backward(None, z, 3)
+
         # z: [b,c,t,h,w]
         if isinstance(scale[0], torch.Tensor):
             z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)

@@ -734,6 +751,10 @@ class WanVAE_(nn.Module):
                     feat_idx=self._conv_idx,
                 )
                 out = torch.cat([out, out_], 2)
+
+        if self.use_approximate_patch:
+            out = gather_forward_split_backward(None, out, 3)
+
         self.clear_cache()
         return out
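The pattern here: split the latents across ranks along dim 3 of [b, c, t, h, w], decode only the local patch (hence "approximate": each rank sees no neighbor context at the seam), then gather the decoded patches back into the full tensor. A minimal self-contained sketch of that idea, assuming an even split across ranks; the real helpers live in lightx2v/models/video_encoders/hf/wan/dist/split_gather.py and may differ in detail:

    import torch
    import torch.distributed as dist

    def split_along_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
        # Keep only this rank's slice; assumes x.shape[dim] % world_size == 0.
        rank, world = dist.get_rank(), dist.get_world_size()
        return x.chunk(world, dim=dim)[rank].contiguous()

    def gather_along_dim(x: torch.Tensor, dim: int) -> torch.Tensor:
        # Reassemble the full tensor from every rank's equal-sized slice.
        world = dist.get_world_size()
        parts = [torch.empty_like(x) for _ in range(world)]
        dist.all_gather(parts, x)
        return torch.cat(parts, dim=dim)

    # decode(z, scale) with use_approximate_patch, in spirit:
    #   z   = split_along_dim(z, 3)     # z: [b, c, t, h, w] -> local spatial patch
    #   out = <run the decoder on z>    # no cross-patch context: approximate seams
    #   out = gather_along_dim(out, 3)  # full-size video on every rank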
@@ -845,6 +866,12 @@ class WanVAE:
         # init model
         self.model = _video_vae(pretrained_path=vae_pth, z_dim=z_dim, cpu_offload=cpu_offload).eval().requires_grad_(False).to(device)
+        self.use_approximate_patch = False
+        if self.parallel and self.parallel.get("use_patch_vae", False):
+            # assert not self.use_tiling
+            DistributedEnv.initialize(None)
+            self.use_approximate_patch = True
+            self.model.enable_approximate_patch()

     def current_device(self):
         return next(self.model.parameters()).device

@@ -865,11 +892,11 @@ class WanVAE:
             self.inv_std = self.inv_std.cuda()
         self.scale = [self.mean, self.inv_std]

-    def encode(self, videos, args):
+    def encode(self, videos):
         """
         videos: A list of videos each with shape [C, T, H, W].
         """
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+        if self.cpu_offload:
             self.to_cuda()
         if self.use_tiling:

@@ -877,7 +904,7 @@ class WanVAE:
         else:
             out = [self.model.encode(u.unsqueeze(0).to(self.current_device()), self.scale).float().squeeze(0) for u in videos]
-        if hasattr(args, "cpu_offload") and args.cpu_offload:
+        if self.cpu_offload:
             self.to_cpu()
         return out
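A hedged usage sketch for the new encode signature; the constructor kwargs are an assumption, mirroring the keys of the vae_config dicts built in wan_runner.py:

    import torch
    # Hypothetical kwargs; only the encode([...]) call shape is taken from the diff.
    vae = WanVAE(
        vae_pth="Wan2.1_VAE.pth",
        device="cuda",
        parallel={"vae_p_size": 4, "use_patch_vae": True},
        use_tiling=False,
        cpu_offload=False,
    )
    video = torch.randn(3, 17, 480, 832)   # [C, T, H, W]
    latents = vae.encode([video])          # no config/args argument anymore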
@@ -902,7 +929,8 @@ class WanVAE:
         elif split_dim == 3:
             zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()
-        images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+        decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
+        images = decode_func(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
         if cur_rank == 0:
             if split_dim == 2:

@@ -933,23 +961,21 @@ class WanVAE:
         if self.cpu_offload:
             self.to_cuda()
-        if self.parallel:
+        if self.parallel and not self.use_approximate_patch:
             world_size = dist.get_world_size()
             cur_rank = dist.get_rank()
             height, width = zs.shape[2], zs.shape[3]
             if width % world_size == 0:
+                split_dim = 3
-                images = self.decode_dist(zs, world_size, cur_rank, split_dim=3)
+                images = self.decode_dist(zs, world_size, cur_rank, split_dim)
             elif height % world_size == 0:
+                split_dim = 2
-                images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
+                images = self.decode_dist(zs, world_size, cur_rank, split_dim)
             else:
                 logger.info("Fall back to naive decode mode")
                 images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
-        elif self.use_tiling:
-            images = self.model.tiled_decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
         else:
-            images = self.model.decode(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
+            decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
+            images = decode_func(zs.unsqueeze(0), self.scale).float().clamp_(-1, 1)
         if self.cpu_offload:
             images = images.cpu().float()
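One reading of this last hunk, as a sketch: when use_patch_vae is on, the split/gather happens inside WanVAE_.decode / tiled_decode, so the outer decode_dist path (which slices zs into overlapping, padding_size-extended chunks) is deliberately bypassed:

    # Control flow of WanVAE decode after this commit (commented outline):
    #
    #   if self.parallel and not self.use_approximate_patch:
    #       # rank-parallel decode with padded, overlapping chunks via decode_dist
    #   else:
    #       # plain path; with use_patch_vae=True the model itself splits the
    #       # latents and gathers the output via split_forward_gather_backward /
    #       # gather_forward_split_backward
    #       decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
    #       images = decode_func(zs.unsqueeze(0), self.scale)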