xuwx1 / LightX2V · Commits

Commit 27c5575f, authored Sep 04, 2025 by Yang Yong(雍洋), committed by GitHub on Sep 04, 2025

Support Multi Levels Profile Log (#290)

Parent: dd870f3f
Changes: 32. Showing 20 changed files with 97 additions and 81 deletions (+97, -81).
app/run_gradio.sh                                          +1  -1
app/run_gradio_win.bat                                     +1  -1
lightx2v/deploy/server/__main__.py                         +2  -2
lightx2v/deploy/worker/hub.py                              +2  -2
lightx2v/infer.py                                          +2  -2
lightx2v/models/runners/default_runner.py                  +13 -13
lightx2v/models/runners/graph_runner.py                    +2  -2
lightx2v/models/runners/qwen_image/qwen_image_runner.py    +10 -12
lightx2v/models/runners/wan/wan_audio_runner.py            +12 -12
lightx2v/models/runners/wan/wan_causvid_runner.py          +6  -6
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py   +6  -6
lightx2v/models/runners/wan/wan_vace_runner.py             +2  -2
lightx2v/models/vfi/rife/rife_comfyui_wrapper.py           +3  -3
lightx2v/utils/envs.py                                     +3  -3
lightx2v/utils/profiler.py                                 +21 -3
lightx2v/utils/prompt_enhancer.py                          +2  -2
scripts/base/base.sh                                       +6  -6
scripts/bench/run_lightx2v_1.sh                            +1  -1
scripts/bench/run_lightx2v_2.sh                            +1  -1
scripts/bench/run_lightx2v_3.sh                            +1  -1
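The thrust of the commit: the boolean `ENABLE_PROFILING_DEBUG` switch is replaced by a three-level `PROFILING_DEBUG_LEVEL` knob, and the single debug profiling context is split into `ProfilingContext4DebugL1` and `ProfilingContext4DebugL2`. A minimal migration sketch for launch scripts (the value 2 mirrors what the scripts in this commit set; any of 0/1/2 is valid):

```sh
# Before this commit: profiling was a single on/off switch
# export ENABLE_PROFILING_DEBUG=true

# After this commit: a verbosity level
#   0 (default) - all profiling disabled
#   1           - ProfilingContext4DebugL1 blocks are timed
#   2           - L1 plus ProfilingContext4DebugL2 blocks are timed
export PROFILING_DEBUG_LEVEL=2
```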
app/run_gradio.sh

@@ -46,7 +46,7 @@ gpu_id=0
 export CUDA_VISIBLE_DEVICES=$gpu_id
 export CUDA_LAUNCH_BLOCKING=1
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
-export ENABLE_PROFILING_DEBUG=true
+export PROFILING_DEBUG_LEVEL=2
 export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

 # ==================== Parameter Parsing ====================
app/run_gradio_win.bat

@@ -45,7 +45,7 @@ set gpu_id=0
 REM ==================== Environment Variables Setup ====================
 set CUDA_VISIBLE_DEVICES=%gpu_id%
 set PYTHONPATH=%lightx2v_path%;%PYTHONPATH%
-set ENABLE_PROFILING_DEBUG=true
+set PROFILING_DEBUG_LEVEL=2
 set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

 REM ==================== Parameter Parsing ====================
lightx2v/deploy/server/__main__.py

@@ -21,7 +21,7 @@ from lightx2v.deploy.server.auth import AuthManager
 from lightx2v.deploy.server.metrics import MetricMonitor
 from lightx2v.deploy.server.monitor import ServerMonitor, WorkerStatus
 from lightx2v.deploy.task_manager import LocalTaskManager, PostgresSQLTaskManager, TaskStatus
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.service_utils import ProcessManager

 # =========================

@@ -679,7 +679,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info(f"args: {args}")

-    with ProfilingContext("Init Server Cost"):
+    with ProfilingContext4DebugL1("Init Server Cost"):
         model_pipelines = Pipeline(args.pipeline_json)
         auth_manager = AuthManager()
         if args.task_url.startswith("/"):
lightx2v/deploy/worker/hub.py

@@ -16,14 +16,14 @@ from lightx2v.deploy.common.utils import class_try_catch_async
 from lightx2v.infer import init_runner  # noqa
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.utils.envs import CHECK_ENABLE_GRAPH_MODE
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all


 class BaseWorker:
-    @ProfilingContext("Init Worker Worker Cost:")
+    @ProfilingContext4DebugL1("Init Worker Worker Cost:")
     def __init__(self, args):
         config = set_config(args)
         config["mode"] = ""
lightx2v/infer.py

@@ -15,7 +15,7 @@ from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  #
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all

@@ -103,7 +103,7 @@ def main():
     print_config(config)

-    with ProfilingContext("Total Cost"):
+    with ProfilingContext4DebugL1("Total Cost"):
         runner = init_runner(config)
         runner.run_pipeline()
lightx2v/models/runners/default_runner.py

@@ -10,7 +10,7 @@ from requests.exceptions import RequestException
 from lightx2v.utils.envs import *
 from lightx2v.utils.generate_task_id import generate_task_id
-from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.profiler import *
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image

 from .base_runner import BaseRunner

@@ -60,7 +60,7 @@ class DefaultRunner(BaseRunner):
         else:
             raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

-    @ProfilingContext("Load models")
+    @ProfilingContext4DebugL2("Load models")
     def load_model(self):
         self.model = self.load_transformer()
         self.text_encoders = self.load_text_encoder()

@@ -116,13 +116,13 @@ class DefaultRunner(BaseRunner):
             self.check_stop()
             logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

-            with ProfilingContext4Debug("step_pre"):
+            with ProfilingContext4DebugL1("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)

-            with ProfilingContext4Debug("🚀 infer_main"):
+            with ProfilingContext4DebugL1("🚀 infer_main"):
                 self.model.infer(self.inputs)

-            with ProfilingContext4Debug("step_post"):
+            with ProfilingContext4DebugL1("step_post"):
                 self.model.scheduler.step_post()

             if self.progress_callback:

@@ -155,7 +155,7 @@ class DefaultRunner(BaseRunner):
         img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
         return img, img_ori

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_i2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         img, img_ori = self.read_image_input(self.config["image_path"])

@@ -166,7 +166,7 @@ class DefaultRunner(BaseRunner):
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         text_encoder_output = self.run_text_encoder(prompt, None)

@@ -177,7 +177,7 @@ class DefaultRunner(BaseRunner):
             "image_encoder_output": None,
         }

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_flf2v(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         first_frame, _ = self.read_image_input(self.config["image_path"])

@@ -189,7 +189,7 @@ class DefaultRunner(BaseRunner):
         gc.collect()
         return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_vace(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         src_video = self.config.get("src_video", None)

@@ -219,12 +219,12 @@ class DefaultRunner(BaseRunner):
         if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None

-    @ProfilingContext("Run DiT")
+    @ProfilingContext4DebugL2("Run DiT")
     def run_main(self, total_steps=None):
         self.init_run()
         for segment_idx in range(self.video_segment_num):
-            logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
-            with ProfilingContext(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
+            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
+            with ProfilingContext4DebugL1(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
                 self.check_stop()
                 # 1. default do nothing
                 self.init_run_segment(segment_idx)

@@ -236,7 +236,7 @@ class DefaultRunner(BaseRunner):
                 self.end_run_segment()
         self.end_run()

-    @ProfilingContext("Run VAE Decoder")
+    @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
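The level assignments in this runner suggest the intended convention: per-segment end-to-end timings, per-step scheduler/inference phases, and the VAE decode are tagged L1, while model loading, the input encoders, and the `run_main` wrapper are tagged L2 and only surface at the most verbose setting. A minimal sketch of how the two levels nest (the names mirror the diff; `encode_inputs` and `run_steps` are hypothetical stand-ins for the real work):

```python
from lightx2v.utils.profiler import ProfilingContext4DebugL1, ProfilingContext4DebugL2

def run_segment_sketch(encode_inputs, run_steps):
    # Timed when PROFILING_DEBUG_LEVEL >= 1; a no-op _NullContext at level 0.
    with ProfilingContext4DebugL1("segment end2end 1/1"):
        # Timed only when PROFILING_DEBUG_LEVEL >= 2.
        with ProfilingContext4DebugL2("Run Encoders"):
            encode_inputs()
        run_steps()
```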
lightx2v/models/runners/graph_runner.py

 from loguru import logger

-from lightx2v.utils.profiler import ProfilingContext4Debug
+from lightx2v.utils.profiler import *


 class GraphRunner:

@@ -13,7 +13,7 @@ class GraphRunner:
         logger.info("🚀 Starting Model Compilation - Please wait, this may take a while... 🚀")
         logger.info("=" * 60)

-        with ProfilingContext4Debug("compile"):
+        with ProfilingContext4DebugL2("compile"):
             self.runner.run_step()

         logger.info("=" * 60)
lightx2v/models/runners/qwen_image/qwen_image_runner.py

@@ -10,7 +10,7 @@ from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
 from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER

@@ -32,7 +32,7 @@ class QwenImageRunner(DefaultRunner):
     def __init__(self, config):
         super().__init__(config)

-    @ProfilingContext("Load models")
+    @ProfilingContext4DebugL2("Load models")
     def load_model(self):
         self.model = self.load_transformer()
         self.text_encoders = self.load_text_encoder()

@@ -69,7 +69,7 @@ class QwenImageRunner(DefaultRunner):
         else:
             assert NotImplementedError

-    @ProfilingContext("Run DiT")
+    @ProfilingContext4DebugL2("Run DiT")
     def _run_dit_local(self, total_steps=None):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()

@@ -81,7 +81,7 @@ class QwenImageRunner(DefaultRunner):
         self.end_run()
         return latents, generator

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_t2i(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         text_encoder_output = self.run_text_encoder(prompt)

@@ -92,7 +92,7 @@ class QwenImageRunner(DefaultRunner):
             "image_encoder_output": None,
         }

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_i2i(self):
         image = Image.open(self.config["image_path"])
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]

@@ -125,20 +125,18 @@ class QwenImageRunner(DefaultRunner):
         return {"image_latents": image_latents}

     def run(self, total_steps=None):
-        from lightx2v.utils.profiler import ProfilingContext4Debug
-
         if total_steps is None:
             total_steps = self.model.scheduler.infer_steps
         for step_index in range(total_steps):
             logger.info(f"==> step_index: {step_index + 1} / {total_steps}")

-            with ProfilingContext4Debug("step_pre"):
+            with ProfilingContext4DebugL1("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)

-            with ProfilingContext4Debug("🚀 infer_main"):
+            with ProfilingContext4DebugL1("🚀 infer_main"):
                 self.model.infer(self.inputs)

-            with ProfilingContext4Debug("step_post"):
+            with ProfilingContext4DebugL1("step_post"):
                 self.model.scheduler.step_post()

             if self.progress_callback:

@@ -181,7 +179,7 @@ class QwenImageRunner(DefaultRunner):
     def run_image_encoder(self):
         pass

-    @ProfilingContext("Load models")
+    @ProfilingContext4DebugL2("Load models")
     def load_model(self):
         self.model = self.load_transformer()
         self.text_encoders = self.load_text_encoder()

@@ -189,7 +187,7 @@ class QwenImageRunner(DefaultRunner):
         self.vae = self.load_vae()
         self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None

-    @ProfilingContext("Run VAE Decoder")
+    @ProfilingContext4DebugL1("Run VAE Decoder")
     def _run_vae_decoder_local(self, latents, generator):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae()
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -25,7 +25,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.audio.scheduler import EulerScheduler
 from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import find_torch_model_path, load_weights, save_to_video, vae_to_comfyui_image

@@ -368,7 +368,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         gc.collect()
         return vae_encoder_out

-    @ProfilingContext("Run Encoders")
+    @ProfilingContext4DebugL2("Run Encoders")
     def _run_input_encoder_local_r2v_audio(self):
         prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
         img = self.read_image_input(self.config["image_path"])

@@ -410,7 +410,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             self.vae_encoder = self.load_vae_encoder()
         _, nframe, height, width = self.model.scheduler.latents.shape
-        with ProfilingContext4Debug("vae_encoder in init run segment"):
+        with ProfilingContext4DebugL1("vae_encoder in init run segment"):
             if self.config.model_cls == "wan2.2_audio":
                 if prev_video is not None:
                     prev_latents = self.vae_encoder.encode(prev_frames.to(dtype))

@@ -460,7 +460,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         self.cut_audio_list = []
         self.prev_video = None

-    @ProfilingContext4Debug("Init run segment")
+    @ProfilingContext4DebugL1("Init run segment")
     def init_run_segment(self, segment_idx, audio_array=None):
         self.segment_idx = segment_idx
         if audio_array is not None:

@@ -485,7 +485,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if segment_idx > 0:
             self.model.scheduler.reset(self.inputs["previmg_encoder_output"])

-    @ProfilingContext4Debug("End run segment")
+    @ProfilingContext4DebugL1("End run segment")
     def end_run_segment(self):
         self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
         useful_length = self.segment.end_frame - self.segment.start_frame

@@ -575,7 +575,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         max_fail_count = 10
         while True:
-            with ProfilingContext4Debug(f"stream segment get audio segment {segment_idx}"):
+            with ProfilingContext4DebugL1(f"stream segment get audio segment {segment_idx}"):
                 self.check_stop()
                 audio_array = self.va_reader.get_audio_segment(timeout=fetch_timeout)
                 if audio_array is None:

@@ -585,7 +585,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
                     raise Exception(f"Failed to get audio chunk {fail_count} times, stop reader")
                 continue

-            with ProfilingContext4Debug(f"stream segment end2end {segment_idx}"):
+            with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"):
                 fail_count = 0
                 self.init_run_segment(segment_idx, audio_array)
                 latents = self.run_segment(total_steps=None)

@@ -728,12 +728,12 @@ class WanAudioRunner(WanRunner):  # type:ignore
         audio_adapter.load_state_dict(weights_dict, strict=False)
         return audio_adapter.to(dtype=GET_DTYPE())

-    @ProfilingContext("Load models")
     def load_model(self):
         super().load_model()
-        self.audio_encoder = self.load_audio_encoder()
-        self.audio_adapter = self.load_audio_adapter()
-        self.model.set_audio_adapter(self.audio_adapter)
+        with ProfilingContext4DebugL2("Load audio encoder and adapter"):
+            self.audio_encoder = self.load_audio_encoder()
+            self.audio_adapter = self.load_audio_adapter()
+            self.model.set_audio_adapter(self.audio_adapter)

     def set_target_shape(self):
         """Set target shape for generation"""
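One detail in the `load_model` hunk above: rather than decorating the whole override (which would re-time `super().load_model()`, itself already instrumented via the decorated `load_model` in `default_runner.py`), the commit wraps only the audio-specific additions in a `with` block. Cleaned up, the new method body reads roughly as:

```python
def load_model(self):
    super().load_model()  # base model/encoder loading is profiled upstream
    # Time only the audio-specific extra work at the most verbose level:
    with ProfilingContext4DebugL2("Load audio encoder and adapter"):
        self.audio_encoder = self.load_audio_encoder()
        self.audio_adapter = self.load_audio_adapter()
        self.model.set_audio_adapter(self.audio_adapter)
```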
lightx2v/models/runners/wan/wan_causvid_runner.py

@@ -9,7 +9,7 @@ from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext4Debug
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER

@@ -85,11 +85,11 @@ class WanCausVidRunner(WanRunner):
             if fragment_idx > 0:
                 logger.info("recompute the kv_cache ...")
-                with ProfilingContext4Debug("step_pre"):
+                with ProfilingContext4DebugL1("step_pre"):
                     self.model.scheduler.latents = self.model.scheduler.last_sample
                     self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)
-                with ProfilingContext4Debug("🚀 infer_main"):
+                with ProfilingContext4DebugL1("🚀 infer_main"):
                     self.model.infer(self.inputs, kv_start, kv_end)
                 kv_start += self.num_frame_per_block * self.frame_seq_length

@@ -105,13 +105,13 @@ class WanCausVidRunner(WanRunner):
             for step_index in range(self.model.scheduler.infer_steps):
                 logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

-                with ProfilingContext4Debug("step_pre"):
+                with ProfilingContext4DebugL1("step_pre"):
                     self.model.scheduler.step_pre(step_index=step_index)

-                with ProfilingContext4Debug("🚀 infer_main"):
+                with ProfilingContext4DebugL1("🚀 infer_main"):
                     self.model.infer(self.inputs, kv_start, kv_end)

-                with ProfilingContext4Debug("step_post"):
+                with ProfilingContext4DebugL1("step_post"):
                     self.model.scheduler.step_post()

                 kv_start += self.num_frame_per_block * self.frame_seq_length
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py

@@ -10,7 +10,7 @@ from loguru import logger
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER

@@ -55,9 +55,9 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
     def run_input_encoder(self):
         image_encoder_output = None
         if os.path.isfile(self.config.image_path):
-            with ProfilingContext("Run Img Encoder"):
+            with ProfilingContext4DebugL2("Run Img Encoder"):
                 image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
-        with ProfilingContext("Run Text Encoder"):
+        with ProfilingContext4DebugL2("Run Text Encoder"):
             text_encoder_output = self.run_text_encoder(self.config["prompt"], self.text_encoders, self.config, image_encoder_output)
         self.set_target_shape()
         self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}

@@ -107,13 +107,13 @@ class WanSkyreelsV2DFRunner(WanRunner):  # Diffustion foring for SkyReelsV2 DF I
             for step_index in range(self.model.scheduler.infer_steps):
                 logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

-                with ProfilingContext4Debug("step_pre"):
+                with ProfilingContext4DebugL1("step_pre"):
                     self.model.scheduler.step_pre(step_index=step_index)

-                with ProfilingContext4Debug("🚀 infer_main"):
+                with ProfilingContext4DebugL1("🚀 infer_main"):
                     self.model.infer(self.inputs)

-                with ProfilingContext4Debug("step_post"):
+                with ProfilingContext4DebugL1("step_post"):
                     self.model.scheduler.step_post()

             videos = self.run_vae(self.model.scheduler.latents, self.model.scheduler.generator)
lightx2v/models/runners/wan/wan_vace_runner.py

@@ -9,7 +9,7 @@ from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProce
 from lightx2v.models.networks.wan.vace_model import WanVaceModel
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.utils.envs import *
-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER

@@ -159,7 +159,7 @@ class WanVaceRunner(WanRunner):
         target_shape[0] = int(target_shape[0] / 2)
         self.config.target_shape = target_shape

-    @ProfilingContext("Run VAE Decoder")
+    @ProfilingContext4DebugL1("Run VAE Decoder")
     def run_vae_decoder(self, latents):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
lightx2v/models/vfi/rife/rife_comfyui_wrapper.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
 import torch
 from torch.nn import functional as F

-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *


 class RIFEWrapper:

@@ -25,12 +25,12 @@ class RIFEWrapper:
         from .train_log.RIFE_HDv3 import Model

         self.model = Model()
-        with ProfilingContext("Load RIFE model"):
+        with ProfilingContext4DebugL2("Load RIFE model"):
             self.model.load_model(model_path, -1)
         self.model.eval()
         self.model.device()

-    @ProfilingContext("Interpolate frames")
+    @ProfilingContext4DebugL2("Interpolate frames")
     def interpolate_frames(
         self,
         images: torch.Tensor,
lightx2v/utils/envs.py

@@ -17,9 +17,9 @@ DTYPE_MAP = {
 @lru_cache(maxsize=None)
-def CHECK_ENABLE_PROFILING_DEBUG():
-    ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"
-    return ENABLE_PROFILING_DEBUG
+def CHECK_PROFILING_DEBUG_LEVEL(target_level):
+    current_level = int(os.getenv("PROFILING_DEBUG_LEVEL", "0"))
+    return current_level >= target_level


 @lru_cache(maxsize=None)
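Two properties of the new check are worth spelling out: it compares the environment level against a target with `>=` (so level 2 also enables all level-1 probes), and `@lru_cache` means the environment variable is read once per target and then frozen. A small self-contained sketch (re-declaring the function here purely for illustration):

```python
import os
from functools import lru_cache

@lru_cache(maxsize=None)
def CHECK_PROFILING_DEBUG_LEVEL(target_level):
    # Reads PROFILING_DEBUG_LEVEL once per distinct target, then caches.
    current_level = int(os.getenv("PROFILING_DEBUG_LEVEL", "0"))
    return current_level >= target_level

os.environ["PROFILING_DEBUG_LEVEL"] = "1"
print(CHECK_PROFILING_DEBUG_LEVEL(1))  # True
print(CHECK_PROFILING_DEBUG_LEVEL(2))  # False

os.environ["PROFILING_DEBUG_LEVEL"] = "2"
print(CHECK_PROFILING_DEBUG_LEVEL(2))  # still False: the cached result wins
```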
lightx2v/utils/profiler.py

@@ -12,7 +12,6 @@ from lightx2v.utils.envs import *
 class _ProfilingContext:
     def __init__(self, name):
         self.name = name
-        self.rank_info = ""
         if dist.is_initialized():
             self.rank_info = f"Rank {dist.get_rank()}"
         else:

@@ -80,5 +79,24 @@ class _NullContext:
         return func


-ProfilingContext = _ProfilingContext
-ProfilingContext4Debug = _ProfilingContext if CHECK_ENABLE_PROFILING_DEBUG() else _NullContext
+class _ProfilingContextL1(_ProfilingContext):
+    """Level 1 profiling context with Level1_Log prefix."""
+
+    def __init__(self, name):
+        super().__init__(f"Level1_Log {name}")
+
+
+class _ProfilingContextL2(_ProfilingContext):
+    """Level 2 profiling context with Level2_Log prefix."""
+
+    def __init__(self, name):
+        super().__init__(f"Level2_Log {name}")
+
+
+"""
+PROFILING_DEBUG_LEVEL=0: [Default] disable all profiling
+PROFILING_DEBUG_LEVEL=1: enable ProfilingContext4DebugL1
+PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4DebugL2
+"""
+ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # if user >= 1, enable profiling
+ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # if user >= 2, enable profiling
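Because the `ProfilingContext4DebugL1/L2` aliases are bound at import time (to the leveled context or to the no-op `_NullContext`), the level must be set before any `lightx2v` module is imported. The leveled contexts then work both as context managers and as decorators, matching the two usages throughout this commit. A usage sketch (`do_work` is a hypothetical placeholder):

```python
import os
os.environ["PROFILING_DEBUG_LEVEL"] = "2"  # must be set before the import below

from lightx2v.utils.profiler import ProfilingContext4DebugL1, ProfilingContext4DebugL2

def do_work():  # hypothetical workload
    pass

# Context-manager form, as in the per-step loops:
with ProfilingContext4DebugL1("step_pre"):
    do_work()

# Decorator form, as on load_model and the encoder methods:
@ProfilingContext4DebugL2("Load models")
def load_model():
    do_work()

load_model()
```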
lightx2v/utils/prompt_enhancer.py

@@ -4,7 +4,7 @@ import torch
 from loguru import logger
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from lightx2v.utils.profiler import ProfilingContext
+from lightx2v.utils.profiler import *

 sys_prompt = """
 Transform the short prompt into a detailed video-generation caption using this structure:

@@ -40,7 +40,7 @@ class PromptEnhancer:
     def to_device(self, device):
         self.model = self.model.to(device)

-    @ProfilingContext("Run prompt enhancer")
+    @ProfilingContext4DebugL1("Run prompt enhancer")
     @torch.no_grad()
     def __call__(self, prompt):
         prompt = prompt.strip()
scripts/base/base.sh

@@ -32,12 +32,12 @@ export DTYPE=BF16
 # Note: If set to FP32, it will be slower, so we recommend set ENABLE_GRAPH_MODE to true.
 export SENSITIVE_LAYER_DTYPE=FP32

-# Performance Profiling Debug Mode (Debug Only)
+# Performance Profiling Debug Level (Debug Only)
 # Enables detailed performance analysis output, such as time cost and memory usage
-# Available options: [true, false]
-# If not set, default value: false
-# Note: This option can be set to false for production.
-export ENABLE_PROFILING_DEBUG=true
+# Available options: [0, 1, 2]
+# If not set, default value: 0
+# Note: This option can be set to 0 for production.
+export PROFILING_DEBUG_LEVEL=2

 # Graph Mode Optimization (Performance Enhancement)
 # Enables torch.compile for graph optimization, can improve inference performance

@@ -56,6 +56,6 @@ echo "model_path: ${model_path}"
 echo "-------------------------------------------------------------------------------"
 echo "Model Inference Data Type: ${DTYPE}"
 echo "Sensitive Layer Data Type: ${SENSITIVE_LAYER_DTYPE}"
-echo "Performance Profiling Debug Mode: ${ENABLE_PROFILING_DEBUG}"
+echo "Performance Profiling Debug Level: ${PROFILING_DEBUG_LEVEL}"
 echo "Graph Mode Optimization: ${ENABLE_GRAPH_MODE}"
 echo "==============================================================================="
scripts/bench/run_lightx2v_1.sh

@@ -27,7 +27,7 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
 export DTYPE=BF16
 export SENSITIVE_LAYER_DTYPE=FP32
-export ENABLE_PROFILING_DEBUG=true
+export PROFILING_DEBUG_LEVEL=2
 export ENABLE_GRAPH_MODE=false

 python -m lightx2v.infer \
scripts/bench/run_lightx2v_2.sh

@@ -26,7 +26,7 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
-export ENABLE_PROFILING_DEBUG=true
+export PROFILING_DEBUG_LEVEL=2
 export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16
scripts/bench/run_lightx2v_3.sh

@@ -26,7 +26,7 @@ export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
-export ENABLE_PROFILING_DEBUG=true
+export PROFILING_DEBUG_LEVEL=2
 export ENABLE_GRAPH_MODE=false
 export DTYPE=BF16