xuwx1 / LightX2V · Commit 3e215bad
authored Aug 06, 2025 by gushiqiao
Support bf16/fp16 inference and mixed-precision inference with fp32 for some layers
parent e684202c
Changes: 38 · Showing 18 changed files with 85 additions and 60 deletions (+85 −60) on this page (1 of 2)
lightx2v/models/networks/wan/lora_adapter.py (+1 −1)
lightx2v/models/networks/wan/model.py (+17 −21)
lightx2v/models/runners/hunyuan/hunyuan_runner.py (+2 −1)
lightx2v/models/runners/wan/wan_audio_runner.py (+6 −5)
lightx2v/models/runners/wan/wan_causvid_runner.py (+4 −3)
lightx2v/models/runners/wan/wan_runner.py (+3 −2)
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py (+2 −1)
lightx2v/models/schedulers/cogvideox/scheduler.py (+0 −2)
lightx2v/models/schedulers/hunyuan/scheduler.py (+7 −6)
lightx2v/models/schedulers/scheduler.py (+2 −4)
lightx2v/models/schedulers/wan/audio/scheduler.py (+2 −4)
lightx2v/models/schedulers/wan/df/skyreels_v2_df_scheduler.py (+3 −1)
lightx2v/models/video_encoders/hf/autoencoder_kl_causal_3d/unet_causal_3d_blocks.py (+2 −2)
lightx2v/models/video_encoders/hf/cogvideox/model.py (+3 −2)
lightx2v/utils/envs.py (+25 −2)
scripts/bench/run_lightx2v_1.sh (+2 −1)
scripts/bench/run_lightx2v_5.sh (+2 −1)
scripts/bench/run_lightx2v_5_distill.sh (+2 −1)
lightx2v/models/networks/wan/lora_adapter.py

@@ -33,7 +33,7 @@ class WanLoraWrapper:
         use_bfloat16 = self.model.config.get("use_bfloat16", True)
         with safe_open(file_path, framework="pt") as f:
             if use_bfloat16:
-                tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16) for key in f.keys()}
+                tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()) for key in f.keys()}
             else:
                 tensor_dict = {key: f.get_tensor(key) for key in f.keys()}
         return tensor_dict
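For readers skimming the diff, the sketch below isolates the loading pattern this hunk touches: safetensors tensors cast to an environment-selected dtype at load time. It is a standalone illustration rather than repo code; the file path is a placeholder and the inline DTYPE_MAP stands in for GET_DTYPE() from lightx2v/utils/envs.py, diffed further down.

import os

import torch
from safetensors import safe_open

# Placeholder path and a minimal stand-in for GET_DTYPE(); see envs.py below.
DTYPE_MAP = {"BF16": torch.bfloat16, "FP16": torch.float16}
dtype = DTYPE_MAP[os.getenv("DTYPE", "BF16")]

with safe_open("weights.safetensors", framework="pt") as f:
    # Cast every tensor to the configured inference dtype as it is read.
    tensor_dict = {key: f.get_tensor(key).to(dtype) for key in f.keys()}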
lightx2v/models/networks/wan/model.py

 import glob
 import json
 import os

 import torch
 ...

@@ -103,20 +101,20 @@ class WanModel:
         else:
             raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

-    def _load_safetensor_to_dict(self, file_path, use_bf16, skip_bf16):
+    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         with safe_open(file_path, framework="pt") as f:
-            return {key: (f.get_tensor(key).to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}

-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, "dit_original_ckpt", subdir="original")
         safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict

-    def _load_quant_ckpt(self, use_bf16, skip_bf16):
+    def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
         ckpt_path = self.dit_quantized_ckpt
         logger.info(f"Loading quant dit model from {ckpt_path}")

@@ -137,8 +135,8 @@ class WanModel:
             logger.info(f"Loading weights from {safetensor_path}")
             for k in f.keys():
                 if f.get_tensor(k).dtype == torch.float:
-                    if use_bf16 or all(s not in k for s in skip_bf16):
-                        weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                    if unified_dtype or all(s not in k for s in sensitive_layer):
+                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
                         weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                 else:

@@ -146,7 +144,7 @@ class WanModel:
         return weight_dict

-    def _load_quant_split_ckpt(self, use_bf16, skip_bf16):
+    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
         lazy_load_model_path = self.dit_quantized_ckpt
         logger.info(f"Loading splited quant model from {lazy_load_model_path}")
         pre_post_weight_dict = {}

@@ -155,8 +153,8 @@ class WanModel:
         with safe_open(safetensor_path, framework="pt", device="cpu") as f:
             for k in f.keys():
                 if f.get_tensor(k).dtype == torch.float:
-                    if use_bf16 or all(s not in k for s in skip_bf16):
-                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
+                    if unified_dtype or all(s not in k for s in sensitive_layer):
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
                         pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
                 else:

@@ -173,9 +171,9 @@ class WanModel:
         pass

     def _init_weights(self, weight_dict=None):
-        use_bf16 = GET_DTYPE() == "BF16"
+        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
         # Some layers run with float32 to achieve high accuracy
-        skip_bf16 = {
+        sensitive_layer = {
             "norm",
             "embedding",
             "modulation",

@@ -185,14 +183,12 @@ class WanModel:
         }
         if weight_dict is None:
             if not self.dit_quantized or self.weight_auto_quant:
-                self.original_weight_dict = self._load_ckpt(use_bf16, skip_bf16)
-            elif self.config.get("use_gguf", False):
-                self.original_weight_dict = self._load_gguf_ckpt()
+                self.original_weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
             else:
                 if not self.config.get("lazy_load", False):
-                    self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
+                    self.original_weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
                 else:
-                    self.original_weight_dict = self._load_quant_split_ckpt(use_bf16, skip_bf16)
+                    self.original_weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
         else:
             self.original_weight_dict = weight_dict
         # init weights

@@ -300,11 +296,11 @@
 class Wan22MoeModel(WanModel):
-    def _load_ckpt(self, use_bf16, skip_bf16):
+    def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_files = glob.glob(os.path.join(self.model_path, "*.safetensors"))
         weight_dict = {}
         for file_path in safetensors_files:
-            file_weights = self._load_safetensor_to_dict(file_path, use_bf16, skip_bf16)
+            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
         return weight_dict
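The renames above encode one rule: a weight is cast to the unified inference dtype unless its key names a numerically sensitive layer and a separate sensitive dtype is configured. A minimal sketch of that predicate, assuming the three key substrings visible in the diff (the set in _init_weights is truncated here and may hold more entries; the sample keys are hypothetical):

import torch

SENSITIVE_LAYER = {"norm", "embedding", "modulation"}  # subset shown in the diff

def cast_weight(key, tensor, dtype, sensitive_dtype):
    # unified_dtype: no special-casing requested, everything runs in `dtype`.
    unified_dtype = dtype == sensitive_dtype
    if unified_dtype or all(s not in key for s in SENSITIVE_LAYER):
        return tensor.to(dtype)
    return tensor  # sensitive layer: keep the checkpoint dtype (e.g. fp32)

w = torch.zeros(4, dtype=torch.float32)
assert cast_weight("blocks.0.attn.weight", w, torch.bfloat16, torch.float32).dtype == torch.bfloat16
assert cast_weight("blocks.0.norm1.weight", w, torch.bfloat16, torch.float32).dtype == torch.float32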
lightx2v/models/runners/hunyuan/hunyuan_runner.py

@@ -12,6 +12,7 @@ from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
 from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
 from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
+from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import save_videos_grid

@@ -62,7 +63,7 @@ class HunyuanRunner(DefaultRunner):
                 text_state, attention_mask = encoder.infer(text, img, self.config)
             else:
                 text_state, attention_mask = encoder.infer(text, self.config)
-            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
+            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=GET_DTYPE())
             text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
         return text_encoder_output
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -20,6 +20,7 @@ from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudi
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

@@ -259,7 +260,7 @@ class VideoGenerator:
             return None
         device = torch.device("cuda")
-        dtype = torch.bfloat16
+        dtype = GET_DTYPE()
         vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w

@@ -312,7 +313,7 @@ class VideoGenerator:
         # Prepare previous latents - ALWAYS needed, even for first segment
         device = torch.device("cuda")
-        dtype = torch.bfloat16
+        dtype = GET_DTYPE()
         vae_dtype = torch.float
         tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
         max_num_frames = self.config.target_video_length

@@ -425,7 +426,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         else:
             device = torch.device("cuda")
             audio_encoder_repo = self.config["model_path"] + "/audio_encoder"
-            self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=torch.bfloat16, device=device, weight=1.0, cpu_offload=cpu_offload)
+            self._audio_adapter_pipe = AudioAdapterPipe(audio_adapter, audio_encoder_repo=audio_encoder_repo, dtype=GET_DTYPE(), device=device, weight=1.0, cpu_offload=cpu_offload)
         return self._audio_adapter_pipe

@@ -655,13 +656,13 @@ class WanAudioRunner(WanRunner):  # type:ignore
         cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
         # clip encoder
-        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16) if self.config.get("use_image_encoder", True) else None
+        clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None
         # vae encode
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
         vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)
         if isinstance(vae_encoder_out, list):
-            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(torch.bfloat16)
+            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())
         return vae_encoder_out, clip_encoder_out
lightx2v/models/runners/wan/wan_causvid_runner.py

@@ -8,6 +8,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER

@@ -65,13 +66,13 @@ class WanCausVidRunner(WanRunner):
         )

     def run(self):
-        self.model.transformer_infer._init_kv_cache(dtype=torch.bfloat16, device="cuda")
-        self.model.transformer_infer._init_crossattn_cache(dtype=torch.bfloat16, device="cuda")
+        self.model.transformer_infer._init_kv_cache(dtype=GET_DTYPE(), device="cuda")
+        self.model.transformer_infer._init_crossattn_cache(dtype=GET_DTYPE(), device="cuda")
         output_latents = torch.zeros(
             (self.model.config.target_shape[0], self.num_frames + (self.num_fragments - 1) * (self.num_frames - self.num_frame_per_block), *self.model.config.target_shape[2:]),
             device="cuda",
-            dtype=torch.bfloat16,
+            dtype=GET_DTYPE(),
         )
         start_block_idx = 0
lightx2v/models/runners/wan/wan_runner.py

@@ -24,6 +24,7 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.models.video_encoders.hf.wan.vae_tiny import WanVAE_tiny
+from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import *
 from lightx2v.utils.utils import best_output_size, cache_video

@@ -207,7 +208,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(GET_DTYPE())
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()

@@ -271,7 +272,7 @@ class WanRunner(DefaultRunner):
             del self.vae_encoder
             torch.cuda.empty_cache()
             gc.collect()
-        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(torch.bfloat16)
+        vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
         return vae_encoder_out

     def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py

@@ -9,6 +9,7 @@ from loguru import logger
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
+from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.registry_factory import RUNNER_REGISTER

@@ -37,7 +38,7 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
         config.lat_w = lat_w
         vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()], config)[0]
-        vae_encoder_out = vae_encoder_out.to(torch.bfloat16)
+        vae_encoder_out = vae_encoder_out.to(GET_DTYPE())
         return vae_encoder_out

     def set_target_shape(self):
lightx2v/models/schedulers/cogvideox/scheduler.py

@@ -269,5 +269,3 @@ class CogvideoxXDPMScheduler(BaseScheduler):
             x_advanced = mult[0] * self.latents - mult[1] * denoised_d + mult_noise * noise
             self.latents = x_advanced
         self.old_pred_original_sample = pred_original_sample
-
-        self.latents = self.latents.to(torch.bfloat16)
lightx2v/models/schedulers/hunyuan/scheduler.py

@@ -5,6 +5,7 @@ import torch
 from diffusers.utils.torch_utils import randn_tensor

 from lightx2v.models.schedulers.scheduler import BaseScheduler
+from lightx2v.utils.envs import *


 def _to_tuple(x, dim=2):

@@ -247,12 +248,12 @@ class HunyuanScheduler(BaseScheduler):
     def prepare(self, image_encoder_output):
         self.image_encoder_output = image_encoder_output
-        self.prepare_latents(shape=self.config.target_shape, dtype=torch.float16, image_encoder_output=image_encoder_output)
+        self.prepare_latents(shape=self.config.target_shape, dtype=torch.float32, image_encoder_output=image_encoder_output)
         self.prepare_guidance()
         self.prepare_rotary_pos_embedding(video_length=self.config.target_video_length, height=self.config.target_height, width=self.config.target_width)

     def prepare_guidance(self):
-        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=torch.bfloat16, device=torch.device("cuda")) * 1000.0
+        self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=GET_DTYPE(), device=torch.device("cuda")) * 1000.0

     def step_post(self):
         if self.config.task == "t2v":

@@ -316,8 +317,8 @@ class HunyuanScheduler(BaseScheduler):
                 use_real=True,
                 theta_rescale_factor=1,
             )
-            self.freqs_cos = self.freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
-            self.freqs_sin = self.freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+            self.freqs_cos = self.freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
+            self.freqs_sin = self.freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
         else:
             L_test = rope_sizes[0]  # Latent frames

@@ -359,5 +360,5 @@ class HunyuanScheduler(BaseScheduler):
             theta_rescale_factor=1,
         )
-        self.freqs_cos = freqs_cos.to(dtype=torch.bfloat16, device=torch.device("cuda"))
-        self.freqs_sin = freqs_sin.to(dtype=torch.bfloat16, device=torch.device("cuda"))
+        self.freqs_cos = freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
+        self.freqs_sin = freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
lightx2v/models/schedulers/scheduler.py

-import torch
-
 from lightx2v.utils.envs import *
 ...

@@ -15,8 +13,8 @@ class BaseScheduler:
     def step_pre(self, step_index):
         self.step_index = step_index
-        if GET_DTYPE() == "BF16":
-            self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())

     def clear(self):
         pass
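This guard change is the behavioral core of the scheduler edits: latents are downcast only when DTYPE and SENSITIVE_LAYER_DTYPE resolve to the same torch dtype, so with SENSITIVE_LAYER_DTYPE=FP32 they stay in float32 through the step. A small self-contained sketch of the two regimes (names mirror the getters added in lightx2v/utils/envs.py):

import torch

def step_pre_latents(latents, dtype, sensitive_dtype):
    # Mirrors BaseScheduler.step_pre after this commit.
    if dtype == sensitive_dtype:
        return latents.to(dtype)  # fully unified bf16/fp16 run
    return latents                # mixed-precision run: latents keep fp32

lat = torch.zeros(4, dtype=torch.float32)
assert step_pre_latents(lat, torch.bfloat16, torch.bfloat16).dtype == torch.bfloat16
assert step_pre_latents(lat, torch.bfloat16, torch.float32).dtype == torch.float32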
lightx2v/models/schedulers/wan/audio/scheduler.py

 import gc
 import math
-import warnings

 import numpy as np
 import torch
-from torch import Tensor

 from lightx2v.models.schedulers.scheduler import BaseScheduler
 from lightx2v.utils.envs import *
 ...

@@ -34,8 +32,8 @@ class EulerSchedulerTimestepFix(BaseScheduler):
     def step_pre(self, step_index):
         self.step_index = step_index
-        if GET_DTYPE() == "BF16":
-            self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())

     def prepare(self, image_encoder_output=None):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
lightx2v/models/schedulers/wan/df/skyreels_v2_df_scheduler.py

@@ -5,6 +5,7 @@ import numpy as np
 import torch

 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
+from lightx2v.utils.envs import *


 class WanSkyreelsV2DFScheduler(WanScheduler):

@@ -132,7 +133,8 @@ class WanSkyreelsV2DFScheduler(WanScheduler):
     def step_pre(self, step_index):
         self.step_index = step_index
-        self.latents = self.latents.to(dtype=torch.bfloat16)
+        if GET_DTYPE() == GET_SENSITIVE_DTYPE():
+            self.latents = self.latents.to(GET_DTYPE())
         valid_interval_start, valid_interval_end = self.valid_interval[step_index]
         timestep = self.step_matrix[step_index][None, valid_interval_start:valid_interval_end].clone()
lightx2v/models/video_encoders/hf/autoencoder_kl_causal_3d/unet_causal_3d_blocks.py

@@ -156,7 +156,7 @@ class UpsampleCausal3D(nn.Module):
         # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
         dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
+        if dtype in [torch.bfloat16, torch.float16]:
             hidden_states = hidden_states.to(torch.float32)

         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984

@@ -185,7 +185,7 @@ class UpsampleCausal3D(nn.Module):
             hidden_states = first_h

         # If the input is bfloat16, we cast back to bfloat16
-        if dtype == torch.bfloat16:
+        if dtype in [torch.bfloat16, torch.float16]:
             hidden_states = hidden_states.to(dtype)

         if self.use_conv:
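Both hunks generalize the existing upcast-compute-downcast workaround from bfloat16 alone to both half dtypes. A self-contained sketch of the pattern (the op and tensor shapes are illustrative, not the module's actual forward):

import torch
import torch.nn.functional as F

def upsample_half_safe(x, scale=2.0):
    # Nearest-neighbor upsampling lacks kernels for some half dtypes, so
    # compute in float32 and cast back, as UpsampleCausal3D does above.
    dtype = x.dtype
    if dtype in (torch.bfloat16, torch.float16):
        x = x.to(torch.float32)
    x = F.interpolate(x, scale_factor=scale, mode="nearest")
    return x.to(dtype)

y = upsample_half_safe(torch.randn(1, 3, 8, 8, dtype=torch.float16))
assert y.dtype == torch.float16 and y.shape == (1, 3, 16, 16)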
lightx2v/models/video_encoders/hf/cogvideox/model.py

@@ -6,6 +6,7 @@ from diffusers.video_processor import VideoProcessor  # type: ignore
 from safetensors import safe_open  # type: ignore

 from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX
+from lightx2v.utils.envs import *


 class CogvideoxVAE:

@@ -15,7 +16,7 @@ class CogvideoxVAE:
     def _load_safetensor_to_dict(self, file_path):
         with safe_open(file_path, framework="pt") as f:
-            tensor_dict = {key: f.get_tensor(key).to(torch.bfloat16).cuda() for key in f.keys()}
+            tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
         return tensor_dict

     def _load_ckpt(self, model_path):

@@ -39,7 +40,7 @@ class CogvideoxVAE:
         self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"]  # 4
         self.vae_scaling_factor_image = self.vae_config["scaling_factor"]  # 0.7
         self.model.load_state_dict(vae_ckpt)
-        self.model.to(torch.bfloat16).to(torch.device("cuda"))
+        self.model.to(GET_DTYPE()).to(torch.device("cuda"))
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

     @torch.no_grad()
lightx2v/utils/envs.py

 import os
 from functools import lru_cache

+import torch
+
+DTYPE_MAP = {
+    "BF16": torch.bfloat16,
+    "FP16": torch.float16,
+    "FP32": torch.float32,
+    "bf16": torch.bfloat16,
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "torch.bfloat16": torch.bfloat16,
+    "torch.float16": torch.float16,
+    "torch.float32": torch.float32,
+}


 @lru_cache(maxsize=None)
 def CHECK_ENABLE_PROFILING_DEBUG():
 ...

@@ -22,5 +36,14 @@ def GET_RUNNING_FLAG():

 @lru_cache(maxsize=None)
 def GET_DTYPE():
-    RUNNING_FLAG = os.getenv("DTYPE")
-    return RUNNING_FLAG
+    RUNNING_FLAG = os.getenv("DTYPE", "BF16")
+    assert RUNNING_FLAG in ["BF16", "FP16"]
+    return DTYPE_MAP[RUNNING_FLAG]
+
+
+@lru_cache(maxsize=None)
+def GET_SENSITIVE_DTYPE():
+    RUNNING_FLAG = os.getenv("SENSITIVE_LAYER_DTYPE", None)
+    if RUNNING_FLAG is None:
+        return GET_DTYPE()
+    return DTYPE_MAP[RUNNING_FLAG]
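With these additions, dtype selection is driven entirely by two environment variables. Both getters are wrapped in lru_cache, so the variables must be set before the first call; note also that the assert limits DTYPE to BF16/FP16, while SENSITIVE_LAYER_DTYPE may additionally be FP32. A usage sketch (values illustrative):

import os

# Must be set before the first (cached) call to either getter.
os.environ["DTYPE"] = "FP16"
os.environ["SENSITIVE_LAYER_DTYPE"] = "FP32"

from lightx2v.utils.envs import GET_DTYPE, GET_SENSITIVE_DTYPE

print(GET_DTYPE())            # torch.float16
print(GET_SENSITIVE_DTYPE())  # torch.float32
# Leaving SENSITIVE_LAYER_DTYPE unset makes GET_SENSITIVE_DTYPE() fall back
# to GET_DTYPE(), i.e. fully unified bf16/fp16 inference.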
scripts/bench/run_lightx2v_1.sh

@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
scripts/bench/run_lightx2v_5.sh

@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false
scripts/bench/run_lightx2v_5_distill.sh

@@ -25,7 +25,8 @@ fi
 export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export DTYPE=BF16
+export SENSITIVE_LAYER_DTYPE=FP32
 export ENABLE_PROFILING_DEBUG=true
 export ENABLE_GRAPH_MODE=false