xuwx1 / LightX2V · Commits

Commit 3e215bad, authored Aug 06, 2025 by gushiqiao

    Support bf16/fp16 inference and mixed-precision inference with fp32 for some layers

Parent: e684202c
Changes: 38 · Showing 18 changed files with 85 additions and 60 deletions (+85 -60)
lightx2v/models/networks/wan/lora_adapter.py                                          +1   -1
lightx2v/models/networks/wan/model.py                                                 +17  -21
lightx2v/models/runners/hunyuan/hunyuan_runner.py                                     +2   -1
lightx2v/models/runners/wan/wan_audio_runner.py                                       +6   -5
lightx2v/models/runners/wan/wan_causvid_runner.py                                     +4   -3
lightx2v/models/runners/wan/wan_runner.py                                             +3   -2
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py                              +2   -1
lightx2v/models/schedulers/cogvideox/scheduler.py                                     +0   -2
lightx2v/models/schedulers/hunyuan/scheduler.py                                       +7   -6
lightx2v/models/schedulers/scheduler.py                                               +2   -4
lightx2v/models/schedulers/wan/audio/scheduler.py                                     +2   -4
lightx2v/models/schedulers/wan/df/skyreels_v2_df_scheduler.py                         +3   -1
lightx2v/models/video_encoders/hf/autoencoder_kl_causal_3d/unet_causal_3d_blocks.py   +2   -2
lightx2v/models/video_encoders/hf/cogvideox/model.py                                  +3   -2
lightx2v/utils/envs.py                                                                +25  -2
scripts/bench/run_lightx2v_1.sh                                                       +2   -1
scripts/bench/run_lightx2v_5.sh                                                       +2   -1
scripts/bench/run_lightx2v_5_distill.sh                                               +2   -1
lightx2v/models/networks/wan/lora_adapter.py  (view file @ 3e215bad)

@@ -33,7 +33,7 @@ class WanLoraWrapper:
         use_bfloat16 = self.model.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
             if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
             else:
                 tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
         return tensor_dict
lightx2v/models/networks/wan/model.py  (view file @ 3e215bad)

 import glob
 import json
 import os

 import torch
...
@@ -103,20 +101,20 @@ class WanModel:
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

-    def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
+    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         with safe_open(file_path, framework="pt") as f:
-            return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}

-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, "dit_original_ckpt", subdir="original")
         safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict

-    def _load_quant_ckpt(self, use_bf16, skip_bf16):
+    def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
         ckpt_path = self.dit_quantized_ckpt
         logger.info(f"Loading quant dit model from {ckpt_path}")
...
@@ -137,8 +135,8 @@ class WanModel:
             logger.info(f"Loading weights from {safetensor_path}")
             for k in f.keys():
                 if f.get_tensor(k).dtype == torch.float:
-                    if use_bf16 or all(s not in k for s in skip_bf16):
-                        weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                    if unified_dtype or all(s not in k for s in sensitive_layer):
+                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
                         weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                 else:
...
@@ -146,7 +144,7 @@ class WanModel:
         return weight_dict

-    def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
+    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
         lazy_load_model_path = self.dit_quantized_ckpt
         logger.info(f"Loading splited quant model from {lazy_load_model_path}")
         pre_post_weight_dict = {}
...
@@ -155,8 +153,8 @@ class WanModel:
         with safe_open(safetensor_path, framework="pt", device="cpu") as f:
             for k in f.keys():
                 if f.get_tensor(k).dtype == torch.float:
-                    if use_bf16 or all(s not in k for s in skip_bf16):
-                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                    if unified_dtype or all(s not in k for s in sensitive_layer):
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
                         pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                 else:
...
@@ -173,9 +171,9 @@ class WanModel:
         pass

     def _init_weights(self, weight_dict=None):
-        use_bf16 = GET_DTYPE() == "BF16"
+        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
         # Some layers run with float32 to achieve high accuracy
-        skip_bf16 = {
+        sensitive_layer = {
             "norm",
             "embedding",
             "modulation",
...
@@ -185,14 +183,12 @@ class WanModel:
         }
         if weight_dict is None:
             if not self.dit_quantized or self.weight_auto_quant:
-                self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
-            elif self.config.get("use_gguf", False):
-                self.original_weight_dict = self._load_gguf_ckpt()
+                self.original_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
             else:
                 if not self.config.get("lazy_load", False):
-                    self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
+                    self.original_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                 else:
-                    self.original_weight_dict = self._load_quant_split_ckpt(use_bf16, skip_bf16)
+                    self.original_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
         else:
             self.original_weight_dict = weight_dict
         # init weights
...
@@ -300,11 +296,11 @@ class WanModel:

 class Wan22MoeModel(WanModel):
-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict
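Every loader above applies the same per-tensor dtype decision. A minimal self-contained sketch of that policy, assuming the GET_DTYPE()/GET_SENSITIVE_DTYPE() helpers from lightx2v/utils/envs.py (shown later in this commit); the load_weights wrapper and the reduced SENSITIVE_LAYER set here are illustrative, not part of the commit:

    import torch
    from safetensors import safe_open

    from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE

    # Substring markers for precision-sensitive layers (a subset of the set in _init_weights).
    SENSITIVE_LAYER = {"norm", "embedding", "modulation"}

    def load_weights(file_path, device="cuda"):
        # If the running dtype equals the sensitive-layer dtype, cast everything
        # uniformly; otherwise leave tensors whose key names a sensitive layer
        # in their stored (typically fp32) precision.
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        weights = {}
        with safe_open(file_path, framework="pt") as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                if unified_dtype or all(s not in key for s in SENSITIVE_LAYER):
                    tensor = tensor.to(GET_DTYPE())
                weights[key] = tensor.pin_memory().to(device)
        return weights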
lightx2v/models/runners/hunyuan/hunyuan_runner.py  (view file @ 3e215bad)

@@ -12,6 +12,7 @@ from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
 from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
 from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
+from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import save_videos_grid
...
@@ -62,7 +63,7 @@ class HunyuanRunner(DefaultRunner):
                 text_state, attention_mask = encoder.infer(text, img, self.config)
             else:
                 text_state, attention_mask = encoder.infer(text, self.config)
-            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
+            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=GET_DTYPE())
             text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
         return text_encoder_output
lightx2v/models/runners/wan/wan_audio_runner.py  (view file @ 3e215bad)

@@ -20,6 +20,7 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
...
@@ -259,7 +260,7 @@ class VideoGenerator:
             return None
         device = torch.device("cuda")
-        dtype = torch.bfloat16
+        dtype = GET_DTYPE()
         vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
...
@@ -312,7 +313,7 @@ class VideoGenerator:
         # Prepare previous latents - ALWAYS needed, even for first segment
         device = torch.device("cuda")
-        dtype = torch.bfloat16
+        dtype = GET_DTYPE()
         vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
         max_num_frames = self.config.target_video_length
...
@@ -425,7 +426,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             else:
                 device = torch.device("cuda")
             audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-            self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, weight=1.0, cpu_offload=cpu_offload)
+            self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload)
         return self._audio_adapter_pipe
...
@@ -655,13 +656,13 @@ class WanAudioRunner(WanRunner):  # type:ignore
             cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
         # clip encoder
-        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16) if self.config.get("use_image_encoder", True) else None
+        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
         vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)
         if isinstance(vae_encoder_out, list):
-            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(torch.bfloat16)
+            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
         return vae_encoder_out, clip_encoder_out
lightx2v/models/runners/wan/wan_causvid_runner.py  (view file @ 3e215bad)

@@ -8,6 +8,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
...
@@ -65,13 +66,13 @@ class WanCausVidRunner(WanRunner):
         )

     def run(self):
-        self.model.transformer_infer._init_kv_cache(dtype=torch.bfloat16, device="cuda")
-        self.model.transformer_infer._init_crossattn_cache(dtype=torch.bfloat16, device="cuda")
+        self.model.transformer_infer._init_kv_cache(dtype=GET_DTYPE(), device="cuda")
+        self.model.transformer_infer._init_crossattn_cache(dtype=GET_DTYPE(), device="cuda")
         output_latents = torch.zeros(
             (self.model.config.target_shape[0], self.num_frames + (self.num_fragments - 1) * (self.num_frames - self.num_frame_per_block), *self.model.config.target_shape[2:]),
             device="cuda",
-            dtype=torch.bfloat16,
+            dtype=GET_DTYPE(),
         )
         start_block_idx = 0
lightx2v/models/runners/wan/wan_runner.py  (view file @ 3e215bad)

@@ -24,6 +24,7 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
+from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
 from lightx2v.utils.utils import best_output_size, cache_video
...
@@ -207,7 +208,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
...
@@ -271,7 +272,7 @@ class WanRunner(DefaultRunner):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()
-        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(torch.bfloat16)
+        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
         return vae_encoder_out

     def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py  (view file @ 3e215bad)

@@ -9,6 +9,7 @@ from loguru import logger
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
...
@@ -37,7 +38,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         config.lat_w = lat_w
         vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()], config)[0]
-        vae_encoder_out = vae_encoder_out.to(torch.bfloat16)
+        vae_encoder_out = vae_encoder_out.to(GET_DTYPE())
         return vae_encoder_out

     def set_target_shape(self):
lightx2v/models/schedulers/cogvideox/scheduler.py  (view file @ 3e215bad)

@@ -269,5 +269,3 @@ class CogvideoxXDPMScheduler(BaseScheduler):
         x_advanced = mult[0] * self.latents - mult[1] * denoised_d + mult_noise * noise
         self.latents = x_advanced
         self.old_pred_original_sample = pred_original_sample
-
-        self.latents = self.latents.to(torch.bfloat16)
lightx2v/models/schedulers/hunyuan/scheduler.py  (view file @ 3e215bad)

@@ -5,6 +5,7 @@ import torch
 from diffusers.utils.torch_utils import randn_tensor

 from lightx2v.models.schedulers.scheduler import BaseScheduler
+from lightx2v.utils.envs import *


 def _to_tuple(x, dim=2):
...
@@ -247,12 +248,12 @@ class HunyuanScheduler(BaseScheduler):
     def prepare(self, image_encoder_output):
         self.image_encoder_output = image_encoder_output
-        self.prepare_latents(shape=self.config.target_shape, dtype=torch.float16, image_encoder_output=image_encoder_output)
+        self.prepare_latents(shape=self.config.target_shape, dtype=torch.float32, image_encoder_output=image_encoder_output)
         self.prepare_guidance()
         self.prepare_rotary_pos_embedding(video_length=self.config.target_video_length, height=self.config.target_height, width=self.config.target_width)

     def prepare_guidance(self):
-        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0
+        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=GET_DTYPE(), device=torch.device("cuda")) * 1000.0

     def step_post(self):
         if self.config.task == "t2v":
...
@@ -316,8 +317,8 @@ class HunyuanScheduler(BaseScheduler):
                 use_real=True,
                 theta_rescale_factor=1,
             )
-            self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
-            self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+            self.freqs_cos = self.freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
+            self.freqs_sin = self.freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
         else:
             L_test = rope_sizes[0]  # Latent frames
...
@@ -359,5 +360,5 @@ class HunyuanScheduler(BaseScheduler):
                 theta_rescale_factor=1,
             )
-            self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
-            self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+            self.freqs_cos = freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
+            self.freqs_sin = freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
lightx2v/models/schedulers/scheduler.py  (view file @ 3e215bad)

 import torch

 from lightx2v.utils.envs import *
...
@@ -15,8 +13,8 @@ class BaseScheduler:
     def step_pre(self, step_index):
         self.step_index = step_index
-        if GET_DTYPE() == "BF16":
-            self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())

     def clear(self):
         pass
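Note the behavioral subtlety in step_pre: latents are now downcast only when the running dtype and the sensitive-layer dtype agree. A short sketch of the two cases, with illustrative environment values:

    import torch

    from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE

    # With DTYPE=FP16 and SENSITIVE_LAYER_DTYPE=FP32 the dtypes differ, so the
    # latents keep the fp32 precision they were prepared with. With
    # SENSITIVE_LAYER_DTYPE unset, both getters agree and latents are cast
    # down to the running fp16/bf16 dtype.
    latents = torch.randn(4, 8, dtype=torch.float32)
    if GET_DTYPE() == GET_SENSITIVE_DTYPE():
        latents = latents.to(GET_DTYPE())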
lightx2v/models/schedulers/wan/audio/scheduler.py  (view file @ 3e215bad)

 import gc
 import math
 import warnings

 import numpy as np
 import torch
 from torch import Tensor

 from lightx2v.models.schedulers.scheduler import BaseScheduler
 from lightx2v.utils.envs import *
...
@@ -34,8 +32,8 @@ class EulerSchedulerTimestepFix(BaseScheduler):
     def step_pre(self, step_index):
         self.step_index = step_index
-        if GET_DTYPE() == "BF16":
-            self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())

     def prepare(self, image_encoder_output=None):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
lightx2v/models/schedulers/wan/df/skyreels_v2_df_scheduler.py  (view file @ 3e215bad)

@@ -5,6 +5,7 @@ import numpy as np
 import torch

 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
+from lightx2v.utils.envs import *


 class WanSkyreelsV2DFScheduler(WanScheduler):
...
@@ -132,7 +133,8 @@ class WanSkyreelsV2DFScheduler(WanScheduler):
     def step_pre(self, step_index):
         self.step_index = step_index
-        self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())
         valid_interval_start, valid_interval_end = self.valid_interval[step_index]
         timestep = self.step_matrix[step_index][None, valid_interval_start:valid_interval_end].clone()
lightx2v/models/video_encoders/hf/autoencoder_kl_causal_3d/unet_causal_3d_blocks.py  (view file @ 3e215bad)

@@ -156,7 +156,7 @@ class UpsampleCausal3D(nn.Module):
         # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
         dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
+        if dtype in [torch.bfloat16, torch.float16]:
             hidden_states = hidden_states.to(torch.float32)

         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
...
@@ -185,7 +185,7 @@ class UpsampleCausal3D(nn.Module):
                 hidden_states = first_h

         # If the input is bfloat16, we cast back to bfloat16
-        if dtype == torch.bfloat16:
+        if dtype in [torch.bfloat16, torch.float16]:
             hidden_states = hidden_states.to(dtype)

         if self.use_conv:
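These two hunks widen an existing bf16 workaround to fp16 as well, since the nearest-neighbor upsample kernel has no half-precision implementation. A self-contained sketch of the cast-and-restore pattern (the function name is illustrative, not from the repo):

    import torch
    import torch.nn.functional as F

    def upsample_nearest_halfsafe(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
        dtype = x.dtype
        # Run the op in fp32 when the input is a half-precision type...
        if dtype in [torch.bfloat16, torch.float16]:
            x = x.to(torch.float32)
        x = F.interpolate(x, scale_factor=scale, mode="nearest")
        # ...and restore the caller's dtype on the way out.
        return x.to(dtype)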
lightx2v/models/video_encoders/hf/cogvideox/model.py  (view file @ 3e215bad)

@@ -6,6 +6,7 @@ from diffusers.video_processor import VideoProcessor  # type: ignore
 from safetensors import safe_open  # type: ignore

 from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX
+from lightx2v.utils.envs import *


 class CogvideoxVAE:
...
@@ -15,7 +16,7 @@ class CogvideoxVAE:
     def _load_safetensor_to_dict(self, file_path):
         with safe_open(file_path, framework="pt") as f:
-            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
+            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self, model_path):
...
@@ -39,7 +40,7 @@ class CogvideoxVAE:
         self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"]  # 4
         self.vae_scaling_factor_image = self.vae_config["scaling_factor"]  # 0.7
         self.model.load_state_dict(vae_ckpt)
-        self.model.to(torch.bfloat16).to(torch.device("cuda"))
+        self.model.to(GET_DTYPE()).to(torch.device("cuda"))
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

     @torch.no_grad()
lightx2v/utils/envs.py  (view file @ 3e215bad)

 import os
 from functools import lru_cache

+import torch
+
+DTYPE_MAP = {
+    "BF16": torch.bfloat16,
+    "FP16": torch.float16,
+    "FP32": torch.float32,
+    "bf16": torch.bfloat16,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "torch.bfloat16": torch.bfloat16,
+    "torch.float16": torch.float16,
+    "torch.float32": torch.float32,
+}


 @lru_cache(maxsize=None)
 def CHECK_ENABLE_PROFILING_DEBUG():
...
@@ -22,5 +36,14 @@ def GET_RUNNING_FLAG():

 @lru_cache(maxsize=None)
 def GET_DTYPE():
-    RUNNING_FLAG = os.getenv("DTYPE")
-    return RUNNING_FLAG
+    RUNNING_FLAG = os.getenv("DTYPE", "BF16")
+    assert RUNNING_FLAG in ["BF16", "FP16"]
+    return DTYPE_MAP[RUNNING_FLAG]
+
+
+@lru_cache(maxsize=None)
+def GET_SENSITIVE_DTYPE():
+    RUNNING_FLAG = os.getenv("SENSITIVE_LAYER_DTYPE", None)
+    if RUNNING_FLAG is None:
+        return GET_DTYPE()
+    return DTYPE_MAP[RUNNING_FLAG]
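These two getters are the single switch for the new precision modes. A short usage sketch; the environment values are examples, and they must be set before the first call because both getters are cached with lru_cache:

    import os

    # Running dtype for the pipeline, plus an optional higher precision for
    # sensitive layers (norm/embedding/modulation weights, latents, etc.).
    os.environ["DTYPE"] = "FP16"                  # or "BF16"
    os.environ["SENSITIVE_LAYER_DTYPE"] = "FP32"  # omit to reuse DTYPE everywhere

    import torch
    from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE

    assert GET_DTYPE() == torch.float16
    assert GET_SENSITIVE_DTYPE() == torch.float32
    # With SENSITIVE_LAYER_DTYPE unset, GET_SENSITIVE_DTYPE() falls back to GET_DTYPE().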
scripts/bench/run_lightx2v_1.sh  (view file @ 3e215bad)

@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
scripts/bench/run_lightx2v_5.sh  (view file @ 3e215bad)

@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
scripts/bench/run_lightx2v_5_distill.sh  (view file @ 3e215bad)

@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false