Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
27c5575f
"vscode:/vscode.git/clone" did not exist on "0e238e6778b4e045ed91387a97304a7cb3ad8544"
Commit
27c5575f
authored
Sep 04, 2025
by
Yang Yong(雍洋)
Committed by
GitHub
Sep 04, 2025
Browse files
Support Multi Levels Profile Log (#290)
parent
dd870f3f
Changes
32
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
97 additions
and
81 deletions
+97
-81
app/run_gradio.sh
app/run_gradio.sh
+1
-1
app/run_gradio_win.bat
app/run_gradio_win.bat
+1
-1
lightx2v/deploy/server/__main__.py
lightx2v/deploy/server/__main__.py
+2
-2
lightx2v/deploy/worker/hub.py
lightx2v/deploy/worker/hub.py
+2
-2
lightx2v/infer.py
lightx2v/infer.py
+2
-2
lightx2v/models/runners/default_runner.py
lightx2v/models/runners/default_runner.py
+13
-13
lightx2v/models/runners/graph_runner.py
lightx2v/models/runners/graph_runner.py
+2
-2
lightx2v/models/runners/qwen_image/qwen_image_runner.py
lightx2v/models/runners/qwen_image/qwen_image_runner.py
+10
-12
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+12
-12
lightx2v/models/runners/wan/wan_causvid_runner.py
lightx2v/models/runners/wan/wan_causvid_runner.py
+6
-6
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py
+6
-6
lightx2v/models/runners/wan/wan_vace_runner.py
lightx2v/models/runners/wan/wan_vace_runner.py
+2
-2
lightx2v/models/vfi/rife/rife_comfyui_wrapper.py
lightx2v/models/vfi/rife/rife_comfyui_wrapper.py
+3
-3
lightx2v/utils/envs.py
lightx2v/utils/envs.py
+3
-3
lightx2v/utils/profiler.py
lightx2v/utils/profiler.py
+21
-3
lightx2v/utils/prompt_enhancer.py
lightx2v/utils/prompt_enhancer.py
+2
-2
scripts/base/base.sh
scripts/base/base.sh
+6
-6
scripts/bench/run_lightx2v_1.sh
scripts/bench/run_lightx2v_1.sh
+1
-1
scripts/bench/run_lightx2v_2.sh
scripts/bench/run_lightx2v_2.sh
+1
-1
scripts/bench/run_lightx2v_3.sh
scripts/bench/run_lightx2v_3.sh
+1
-1
No files found.
app/run_gradio.sh
View file @
27c5575f
...
@@ -46,7 +46,7 @@ gpu_id=0
...
@@ -46,7 +46,7 @@ gpu_id=0
export
CUDA_VISIBLE_DEVICES
=
$gpu_id
export
CUDA_VISIBLE_DEVICES
=
$gpu_id
export
CUDA_LAUNCH_BLOCKING
=
1
export
CUDA_LAUNCH_BLOCKING
=
1
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_
PROFILING_DEBUG
=
true
export
PROFILING_DEBUG
_LEVEL
=
2
export
PYTORCH_CUDA_ALLOC_CONF
=
expandable_segments:True
export
PYTORCH_CUDA_ALLOC_CONF
=
expandable_segments:True
# ==================== Parameter Parsing ====================
# ==================== Parameter Parsing ====================
...
...
app/run_gradio_win.bat
View file @
27c5575f
...
@@ -45,7 +45,7 @@ set gpu_id=0
...
@@ -45,7 +45,7 @@ set gpu_id=0
REM ==================== Environment Variables Setup ====================
REM ==================== Environment Variables Setup ====================
set
CUDA_VISIBLE_DEVICES
=
%gpu_id%
set
CUDA_VISIBLE_DEVICES
=
%gpu_id%
set
PYTHONPATH
=
%lightx2
v_path
%
;
%PYTHONPATH%
set
PYTHONPATH
=
%lightx2
v_path
%
;
%PYTHONPATH%
set
ENABLE_
PROFILING_DEBUG
=
true
set
PROFILING_DEBUG
_LEVEL
=
2
set
PYTORCH_CUDA_ALLOC_CONF
=
expandable_segments
:True
set
PYTORCH_CUDA_ALLOC_CONF
=
expandable_segments
:True
REM ==================== Parameter Parsing ====================
REM ==================== Parameter Parsing ====================
...
...
lightx2v/deploy/server/__main__.py
View file @
27c5575f
...
@@ -21,7 +21,7 @@ from lightx2v.deploy.server.auth import AuthManager
...
@@ -21,7 +21,7 @@ from lightx2v.deploy.server.auth import AuthManager
from
lightx2v.deploy.server.metrics
import
MetricMonitor
from
lightx2v.deploy.server.metrics
import
MetricMonitor
from
lightx2v.deploy.server.monitor
import
ServerMonitor
,
WorkerStatus
from
lightx2v.deploy.server.monitor
import
ServerMonitor
,
WorkerStatus
from
lightx2v.deploy.task_manager
import
LocalTaskManager
,
PostgresSQLTaskManager
,
TaskStatus
from
lightx2v.deploy.task_manager
import
LocalTaskManager
,
PostgresSQLTaskManager
,
TaskStatus
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.service_utils
import
ProcessManager
from
lightx2v.utils.service_utils
import
ProcessManager
# =========================
# =========================
...
@@ -679,7 +679,7 @@ if __name__ == "__main__":
...
@@ -679,7 +679,7 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logger
.
info
(
f
"args:
{
args
}
"
)
logger
.
info
(
f
"args:
{
args
}
"
)
with
ProfilingContext
(
"Init Server Cost"
):
with
ProfilingContext
4DebugL1
(
"Init Server Cost"
):
model_pipelines
=
Pipeline
(
args
.
pipeline_json
)
model_pipelines
=
Pipeline
(
args
.
pipeline_json
)
auth_manager
=
AuthManager
()
auth_manager
=
AuthManager
()
if
args
.
task_url
.
startswith
(
"/"
):
if
args
.
task_url
.
startswith
(
"/"
):
...
...
lightx2v/deploy/worker/hub.py
View file @
27c5575f
...
@@ -16,14 +16,14 @@ from lightx2v.deploy.common.utils import class_try_catch_async
...
@@ -16,14 +16,14 @@ from lightx2v.deploy.common.utils import class_try_catch_async
from
lightx2v.infer
import
init_runner
# noqa
from
lightx2v.infer
import
init_runner
# noqa
from
lightx2v.models.runners.graph_runner
import
GraphRunner
from
lightx2v.models.runners.graph_runner
import
GraphRunner
from
lightx2v.utils.envs
import
CHECK_ENABLE_GRAPH_MODE
from
lightx2v.utils.envs
import
CHECK_ENABLE_GRAPH_MODE
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.set_config
import
set_config
,
set_parallel_config
from
lightx2v.utils.set_config
import
set_config
,
set_parallel_config
from
lightx2v.utils.utils
import
seed_all
from
lightx2v.utils.utils
import
seed_all
class
BaseWorker
:
class
BaseWorker
:
@
ProfilingContext
(
"Init Worker Worker Cost:"
)
@
ProfilingContext
4DebugL1
(
"Init Worker Worker Cost:"
)
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
config
=
set_config
(
args
)
config
=
set_config
(
args
)
config
[
"mode"
]
=
""
config
[
"mode"
]
=
""
...
...
lightx2v/infer.py
View file @
27c5575f
...
@@ -15,7 +15,7 @@ from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner #
...
@@ -15,7 +15,7 @@ from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner #
from
lightx2v.models.runners.wan.wan_skyreels_v2_df_runner
import
WanSkyreelsV2DFRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_skyreels_v2_df_runner
import
WanSkyreelsV2DFRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_vace_runner
import
WanVaceRunner
# noqa: F401
from
lightx2v.models.runners.wan.wan_vace_runner
import
WanVaceRunner
# noqa: F401
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.set_config
import
print_config
,
set_config
,
set_parallel_config
from
lightx2v.utils.set_config
import
print_config
,
set_config
,
set_parallel_config
from
lightx2v.utils.utils
import
seed_all
from
lightx2v.utils.utils
import
seed_all
...
@@ -103,7 +103,7 @@ def main():
...
@@ -103,7 +103,7 @@ def main():
print_config
(
config
)
print_config
(
config
)
with
ProfilingContext
(
"Total Cost"
):
with
ProfilingContext
4DebugL1
(
"Total Cost"
):
runner
=
init_runner
(
config
)
runner
=
init_runner
(
config
)
runner
.
run_pipeline
()
runner
.
run_pipeline
()
...
...
lightx2v/models/runners/default_runner.py
View file @
27c5575f
...
@@ -10,7 +10,7 @@ from requests.exceptions import RequestException
...
@@ -10,7 +10,7 @@ from requests.exceptions import RequestException
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.generate_task_id
import
generate_task_id
from
lightx2v.utils.generate_task_id
import
generate_task_id
from
lightx2v.utils.profiler
import
ProfilingContext
,
ProfilingContext4Debug
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.utils
import
save_to_video
,
vae_to_comfyui_image
from
lightx2v.utils.utils
import
save_to_video
,
vae_to_comfyui_image
from
.base_runner
import
BaseRunner
from
.base_runner
import
BaseRunner
...
@@ -60,7 +60,7 @@ class DefaultRunner(BaseRunner):
...
@@ -60,7 +60,7 @@ class DefaultRunner(BaseRunner):
else
:
else
:
raise
ValueError
(
f
"Unsupported VFI model:
{
self
.
config
[
'video_frame_interpolation'
][
'algo'
]
}
"
)
raise
ValueError
(
f
"Unsupported VFI model:
{
self
.
config
[
'video_frame_interpolation'
][
'algo'
]
}
"
)
@
ProfilingContext
(
"Load models"
)
@
ProfilingContext
4DebugL2
(
"Load models"
)
def
load_model
(
self
):
def
load_model
(
self
):
self
.
model
=
self
.
load_transformer
()
self
.
model
=
self
.
load_transformer
()
self
.
text_encoders
=
self
.
load_text_encoder
()
self
.
text_encoders
=
self
.
load_text_encoder
()
...
@@ -116,13 +116,13 @@ class DefaultRunner(BaseRunner):
...
@@ -116,13 +116,13 @@ class DefaultRunner(BaseRunner):
self
.
check_stop
()
self
.
check_stop
()
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
total_steps
}
"
)
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
total_steps
}
"
)
with
ProfilingContext4Debug
(
"step_pre"
):
with
ProfilingContext4Debug
L1
(
"step_pre"
):
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
with
ProfilingContext4Debug
(
"🚀 infer_main"
):
with
ProfilingContext4Debug
L1
(
"🚀 infer_main"
):
self
.
model
.
infer
(
self
.
inputs
)
self
.
model
.
infer
(
self
.
inputs
)
with
ProfilingContext4Debug
(
"step_post"
):
with
ProfilingContext4Debug
L1
(
"step_post"
):
self
.
model
.
scheduler
.
step_post
()
self
.
model
.
scheduler
.
step_post
()
if
self
.
progress_callback
:
if
self
.
progress_callback
:
...
@@ -155,7 +155,7 @@ class DefaultRunner(BaseRunner):
...
@@ -155,7 +155,7 @@ class DefaultRunner(BaseRunner):
img
=
TF
.
to_tensor
(
img_ori
).
sub_
(
0.5
).
div_
(
0.5
).
unsqueeze
(
0
).
cuda
()
img
=
TF
.
to_tensor
(
img_ori
).
sub_
(
0.5
).
div_
(
0.5
).
unsqueeze
(
0
).
cuda
()
return
img
,
img_ori
return
img
,
img_ori
@
ProfilingContext
(
"Run Encoders"
)
@
ProfilingContext
4DebugL2
(
"Run Encoders"
)
def
_run_input_encoder_local_i2v
(
self
):
def
_run_input_encoder_local_i2v
(
self
):
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
img
,
img_ori
=
self
.
read_image_input
(
self
.
config
[
"image_path"
])
img
,
img_ori
=
self
.
read_image_input
(
self
.
config
[
"image_path"
])
...
@@ -166,7 +166,7 @@ class DefaultRunner(BaseRunner):
...
@@ -166,7 +166,7 @@ class DefaultRunner(BaseRunner):
gc
.
collect
()
gc
.
collect
()
return
self
.
get_encoder_output_i2v
(
clip_encoder_out
,
vae_encode_out
,
text_encoder_output
,
img
)
return
self
.
get_encoder_output_i2v
(
clip_encoder_out
,
vae_encode_out
,
text_encoder_output
,
img
)
@
ProfilingContext
(
"Run Encoders"
)
@
ProfilingContext
4DebugL2
(
"Run Encoders"
)
def
_run_input_encoder_local_t2v
(
self
):
def
_run_input_encoder_local_t2v
(
self
):
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
text_encoder_output
=
self
.
run_text_encoder
(
prompt
,
None
)
text_encoder_output
=
self
.
run_text_encoder
(
prompt
,
None
)
...
@@ -177,7 +177,7 @@ class DefaultRunner(BaseRunner):
...
@@ -177,7 +177,7 @@ class DefaultRunner(BaseRunner):
"image_encoder_output"
:
None
,
"image_encoder_output"
:
None
,
}
}
@
ProfilingContext
(
"Run Encoders"
)
@
ProfilingContext
4DebugL2
(
"Run Encoders"
)
def
_run_input_encoder_local_flf2v
(
self
):
def
_run_input_encoder_local_flf2v
(
self
):
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
first_frame
,
_
=
self
.
read_image_input
(
self
.
config
[
"image_path"
])
first_frame
,
_
=
self
.
read_image_input
(
self
.
config
[
"image_path"
])
...
@@ -189,7 +189,7 @@ class DefaultRunner(BaseRunner):
...
@@ -189,7 +189,7 @@ class DefaultRunner(BaseRunner):
gc
.
collect
()
gc
.
collect
()
return
self
.
get_encoder_output_i2v
(
clip_encoder_out
,
vae_encode_out
,
text_encoder_output
)
return
self
.
get_encoder_output_i2v
(
clip_encoder_out
,
vae_encode_out
,
text_encoder_output
)
@
ProfilingContext
(
"Run Encoders"
)
@
ProfilingContext
4DebugL2
(
"Run Encoders"
)
def
_run_input_encoder_local_vace
(
self
):
def
_run_input_encoder_local_vace
(
self
):
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
src_video
=
self
.
config
.
get
(
"src_video"
,
None
)
src_video
=
self
.
config
.
get
(
"src_video"
,
None
)
...
@@ -219,12 +219,12 @@ class DefaultRunner(BaseRunner):
...
@@ -219,12 +219,12 @@ class DefaultRunner(BaseRunner):
if
self
.
config
.
get
(
"model_cls"
)
==
"wan2.2"
and
self
.
config
[
"task"
]
==
"i2v"
:
if
self
.
config
.
get
(
"model_cls"
)
==
"wan2.2"
and
self
.
config
[
"task"
]
==
"i2v"
:
self
.
inputs
[
"image_encoder_output"
][
"vae_encoder_out"
]
=
None
self
.
inputs
[
"image_encoder_output"
][
"vae_encoder_out"
]
=
None
@
ProfilingContext
(
"Run DiT"
)
@
ProfilingContext
4DebugL2
(
"Run DiT"
)
def
run_main
(
self
,
total_steps
=
None
):
def
run_main
(
self
,
total_steps
=
None
):
self
.
init_run
()
self
.
init_run
()
for
segment_idx
in
range
(
self
.
video_segment_num
):
for
segment_idx
in
range
(
self
.
video_segment_num
):
logger
.
info
(
f
"🔄 segment
_idx:
{
segment_idx
+
1
}
/
{
self
.
video_segment_num
}
"
)
logger
.
info
(
f
"🔄
start
segment
{
segment_idx
+
1
}
/
{
self
.
video_segment_num
}
"
)
with
ProfilingContext
(
f
"segment end2end
{
segment_idx
+
1
}
/
{
self
.
video_segment_num
}
"
):
with
ProfilingContext
4DebugL1
(
f
"segment end2end
{
segment_idx
+
1
}
/
{
self
.
video_segment_num
}
"
):
self
.
check_stop
()
self
.
check_stop
()
# 1. default do nothing
# 1. default do nothing
self
.
init_run_segment
(
segment_idx
)
self
.
init_run_segment
(
segment_idx
)
...
@@ -236,7 +236,7 @@ class DefaultRunner(BaseRunner):
...
@@ -236,7 +236,7 @@ class DefaultRunner(BaseRunner):
self
.
end_run_segment
()
self
.
end_run_segment
()
self
.
end_run
()
self
.
end_run
()
@
ProfilingContext
(
"Run VAE Decoder"
)
@
ProfilingContext
4DebugL1
(
"Run VAE Decoder"
)
def
run_vae_decoder
(
self
,
latents
):
def
run_vae_decoder
(
self
,
latents
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
self
.
vae_decoder
=
self
.
load_vae_decoder
()
self
.
vae_decoder
=
self
.
load_vae_decoder
()
...
...
lightx2v/models/runners/graph_runner.py
View file @
27c5575f
from
loguru
import
logger
from
loguru
import
logger
from
lightx2v.utils.profiler
import
ProfilingContext4Debug
from
lightx2v.utils.profiler
import
*
class
GraphRunner
:
class
GraphRunner
:
...
@@ -13,7 +13,7 @@ class GraphRunner:
...
@@ -13,7 +13,7 @@ class GraphRunner:
logger
.
info
(
"🚀 Starting Model Compilation - Please wait, this may take a while... 🚀"
)
logger
.
info
(
"🚀 Starting Model Compilation - Please wait, this may take a while... 🚀"
)
logger
.
info
(
"="
*
60
)
logger
.
info
(
"="
*
60
)
with
ProfilingContext4Debug
(
"compile"
):
with
ProfilingContext4Debug
L2
(
"compile"
):
self
.
runner
.
run_step
()
self
.
runner
.
run_step
()
logger
.
info
(
"="
*
60
)
logger
.
info
(
"="
*
60
)
...
...
lightx2v/models/runners/qwen_image/qwen_image_runner.py
View file @
27c5575f
...
@@ -10,7 +10,7 @@ from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
...
@@ -10,7 +10,7 @@ from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from
lightx2v.models.runners.default_runner
import
DefaultRunner
from
lightx2v.models.runners.default_runner
import
DefaultRunner
from
lightx2v.models.schedulers.qwen_image.scheduler
import
QwenImageScheduler
from
lightx2v.models.schedulers.qwen_image.scheduler
import
QwenImageScheduler
from
lightx2v.models.video_encoders.hf.qwen_image.vae
import
AutoencoderKLQwenImageVAE
from
lightx2v.models.video_encoders.hf.qwen_image.vae
import
AutoencoderKLQwenImageVAE
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
...
@@ -32,7 +32,7 @@ class QwenImageRunner(DefaultRunner):
...
@@ -32,7 +32,7 @@ class QwenImageRunner(DefaultRunner):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
@
ProfilingContext
(
"Load models"
)
@
ProfilingContext
4DebugL2
(
"Load models"
)
def
load_model
(
self
):
def
load_model
(
self
):
self
.
model
=
self
.
load_transformer
()
self
.
model
=
self
.
load_transformer
()
self
.
text_encoders
=
self
.
load_text_encoder
()
self
.
text_encoders
=
self
.
load_text_encoder
()
...
@@ -69,7 +69,7 @@ class QwenImageRunner(DefaultRunner):
...
@@ -69,7 +69,7 @@ class QwenImageRunner(DefaultRunner):
else
:
else
:
assert
NotImplementedError
assert
NotImplementedError
@
ProfilingContext
(
"Run DiT"
)
@
ProfilingContext
4DebugL2
(
"Run DiT"
)
def
_run_dit_local
(
self
,
total_steps
=
None
):
def
_run_dit_local
(
self
,
total_steps
=
None
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
self
.
model
=
self
.
load_transformer
()
self
.
model
=
self
.
load_transformer
()
...
@@ -81,7 +81,7 @@ class QwenImageRunner(DefaultRunner):
...
@@ -81,7 +81,7 @@ class QwenImageRunner(DefaultRunner):
self
.
end_run
()
self
.
end_run
()
return
latents
,
generator
return
latents
,
generator
@
ProfilingContext
(
"Run Encoders"
)
@
ProfilingContext
4DebugL2
(
"Run Encoders"
)
def
_run_input_encoder_local_t2i
(
self
):
def
_run_input_encoder_local_t2i
(
self
):
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
text_encoder_output
=
self
.
run_text_encoder
(
prompt
)
text_encoder_output
=
self
.
run_text_encoder
(
prompt
)
...
@@ -92,7 +92,7 @@ class QwenImageRunner(DefaultRunner):
...
@@ -92,7 +92,7 @@ class QwenImageRunner(DefaultRunner):
"image_encoder_output"
:
None
,
"image_encoder_output"
:
None
,
}
}
@
ProfilingContext
(
"Run Encoders"
)
@
ProfilingContext
4DebugL2
(
"Run Encoders"
)
def
_run_input_encoder_local_i2i
(
self
):
def
_run_input_encoder_local_i2i
(
self
):
image
=
Image
.
open
(
self
.
config
[
"image_path"
])
image
=
Image
.
open
(
self
.
config
[
"image_path"
])
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
...
@@ -125,20 +125,18 @@ class QwenImageRunner(DefaultRunner):
...
@@ -125,20 +125,18 @@ class QwenImageRunner(DefaultRunner):
return
{
"image_latents"
:
image_latents
}
return
{
"image_latents"
:
image_latents
}
def
run
(
self
,
total_steps
=
None
):
def
run
(
self
,
total_steps
=
None
):
from
lightx2v.utils.profiler
import
ProfilingContext4Debug
if
total_steps
is
None
:
if
total_steps
is
None
:
total_steps
=
self
.
model
.
scheduler
.
infer_steps
total_steps
=
self
.
model
.
scheduler
.
infer_steps
for
step_index
in
range
(
total_steps
):
for
step_index
in
range
(
total_steps
):
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
total_steps
}
"
)
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
total_steps
}
"
)
with
ProfilingContext4Debug
(
"step_pre"
):
with
ProfilingContext4Debug
L1
(
"step_pre"
):
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
with
ProfilingContext4Debug
(
"🚀 infer_main"
):
with
ProfilingContext4Debug
L1
(
"🚀 infer_main"
):
self
.
model
.
infer
(
self
.
inputs
)
self
.
model
.
infer
(
self
.
inputs
)
with
ProfilingContext4Debug
(
"step_post"
):
with
ProfilingContext4Debug
L1
(
"step_post"
):
self
.
model
.
scheduler
.
step_post
()
self
.
model
.
scheduler
.
step_post
()
if
self
.
progress_callback
:
if
self
.
progress_callback
:
...
@@ -181,7 +179,7 @@ class QwenImageRunner(DefaultRunner):
...
@@ -181,7 +179,7 @@ class QwenImageRunner(DefaultRunner):
def
run_image_encoder
(
self
):
def
run_image_encoder
(
self
):
pass
pass
@
ProfilingContext
(
"Load models"
)
@
ProfilingContext
4DebugL2
(
"Load models"
)
def
load_model
(
self
):
def
load_model
(
self
):
self
.
model
=
self
.
load_transformer
()
self
.
model
=
self
.
load_transformer
()
self
.
text_encoders
=
self
.
load_text_encoder
()
self
.
text_encoders
=
self
.
load_text_encoder
()
...
@@ -189,7 +187,7 @@ class QwenImageRunner(DefaultRunner):
...
@@ -189,7 +187,7 @@ class QwenImageRunner(DefaultRunner):
self
.
vae
=
self
.
load_vae
()
self
.
vae
=
self
.
load_vae
()
self
.
vfi_model
=
self
.
load_vfi_model
()
if
"video_frame_interpolation"
in
self
.
config
else
None
self
.
vfi_model
=
self
.
load_vfi_model
()
if
"video_frame_interpolation"
in
self
.
config
else
None
@
ProfilingContext
(
"Run VAE Decoder"
)
@
ProfilingContext
4DebugL1
(
"Run VAE Decoder"
)
def
_run_vae_decoder_local
(
self
,
latents
,
generator
):
def
_run_vae_decoder_local
(
self
,
latents
,
generator
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
self
.
vae_decoder
=
self
.
load_vae
()
self
.
vae_decoder
=
self
.
load_vae
()
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
27c5575f
...
@@ -25,7 +25,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
...
@@ -25,7 +25,7 @@ from lightx2v.models.runners.wan.wan_runner import WanRunner
from
lightx2v.models.schedulers.wan.audio.scheduler
import
EulerScheduler
from
lightx2v.models.schedulers.wan.audio.scheduler
import
EulerScheduler
from
lightx2v.models.video_encoders.hf.wan.vae_2_2
import
Wan2_2_VAE
from
lightx2v.models.video_encoders.hf.wan.vae_2_2
import
Wan2_2_VAE
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
ProfilingContext
,
ProfilingContext4Debug
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.utils
import
find_torch_model_path
,
load_weights
,
save_to_video
,
vae_to_comfyui_image
from
lightx2v.utils.utils
import
find_torch_model_path
,
load_weights
,
save_to_video
,
vae_to_comfyui_image
...
@@ -368,7 +368,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -368,7 +368,7 @@ class WanAudioRunner(WanRunner): # type:ignore
gc
.
collect
()
gc
.
collect
()
return
vae_encoder_out
return
vae_encoder_out
@
ProfilingContext
(
"Run Encoders"
)
@
ProfilingContext
4DebugL2
(
"Run Encoders"
)
def
_run_input_encoder_local_r2v_audio
(
self
):
def
_run_input_encoder_local_r2v_audio
(
self
):
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
prompt
=
self
.
config
[
"prompt_enhanced"
]
if
self
.
config
[
"use_prompt_enhancer"
]
else
self
.
config
[
"prompt"
]
img
=
self
.
read_image_input
(
self
.
config
[
"image_path"
])
img
=
self
.
read_image_input
(
self
.
config
[
"image_path"
])
...
@@ -410,7 +410,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -410,7 +410,7 @@ class WanAudioRunner(WanRunner): # type:ignore
self
.
vae_encoder
=
self
.
load_vae_encoder
()
self
.
vae_encoder
=
self
.
load_vae_encoder
()
_
,
nframe
,
height
,
width
=
self
.
model
.
scheduler
.
latents
.
shape
_
,
nframe
,
height
,
width
=
self
.
model
.
scheduler
.
latents
.
shape
with
ProfilingContext4Debug
(
"vae_encoder in init run segment"
):
with
ProfilingContext4Debug
L1
(
"vae_encoder in init run segment"
):
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
if
prev_video
is
not
None
:
if
prev_video
is
not
None
:
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
dtype
))
prev_latents
=
self
.
vae_encoder
.
encode
(
prev_frames
.
to
(
dtype
))
...
@@ -460,7 +460,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -460,7 +460,7 @@ class WanAudioRunner(WanRunner): # type:ignore
self
.
cut_audio_list
=
[]
self
.
cut_audio_list
=
[]
self
.
prev_video
=
None
self
.
prev_video
=
None
@
ProfilingContext4Debug
(
"Init run segment"
)
@
ProfilingContext4Debug
L1
(
"Init run segment"
)
def
init_run_segment
(
self
,
segment_idx
,
audio_array
=
None
):
def
init_run_segment
(
self
,
segment_idx
,
audio_array
=
None
):
self
.
segment_idx
=
segment_idx
self
.
segment_idx
=
segment_idx
if
audio_array
is
not
None
:
if
audio_array
is
not
None
:
...
@@ -485,7 +485,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -485,7 +485,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if
segment_idx
>
0
:
if
segment_idx
>
0
:
self
.
model
.
scheduler
.
reset
(
self
.
inputs
[
"previmg_encoder_output"
])
self
.
model
.
scheduler
.
reset
(
self
.
inputs
[
"previmg_encoder_output"
])
@
ProfilingContext4Debug
(
"End run segment"
)
@
ProfilingContext4Debug
L1
(
"End run segment"
)
def
end_run_segment
(
self
):
def
end_run_segment
(
self
):
self
.
gen_video
=
torch
.
clamp
(
self
.
gen_video
,
-
1
,
1
).
to
(
torch
.
float
)
self
.
gen_video
=
torch
.
clamp
(
self
.
gen_video
,
-
1
,
1
).
to
(
torch
.
float
)
useful_length
=
self
.
segment
.
end_frame
-
self
.
segment
.
start_frame
useful_length
=
self
.
segment
.
end_frame
-
self
.
segment
.
start_frame
...
@@ -575,7 +575,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -575,7 +575,7 @@ class WanAudioRunner(WanRunner): # type:ignore
max_fail_count
=
10
max_fail_count
=
10
while
True
:
while
True
:
with
ProfilingContext4Debug
(
f
"stream segment get audio segment
{
segment_idx
}
"
):
with
ProfilingContext4Debug
L1
(
f
"stream segment get audio segment
{
segment_idx
}
"
):
self
.
check_stop
()
self
.
check_stop
()
audio_array
=
self
.
va_reader
.
get_audio_segment
(
timeout
=
fetch_timeout
)
audio_array
=
self
.
va_reader
.
get_audio_segment
(
timeout
=
fetch_timeout
)
if
audio_array
is
None
:
if
audio_array
is
None
:
...
@@ -585,7 +585,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -585,7 +585,7 @@ class WanAudioRunner(WanRunner): # type:ignore
raise
Exception
(
f
"Failed to get audio chunk
{
fail_count
}
times, stop reader"
)
raise
Exception
(
f
"Failed to get audio chunk
{
fail_count
}
times, stop reader"
)
continue
continue
with
ProfilingContext4Debug
(
f
"stream segment end2end
{
segment_idx
}
"
):
with
ProfilingContext4Debug
L1
(
f
"stream segment end2end
{
segment_idx
}
"
):
fail_count
=
0
fail_count
=
0
self
.
init_run_segment
(
segment_idx
,
audio_array
)
self
.
init_run_segment
(
segment_idx
,
audio_array
)
latents
=
self
.
run_segment
(
total_steps
=
None
)
latents
=
self
.
run_segment
(
total_steps
=
None
)
...
@@ -603,7 +603,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -603,7 +603,7 @@ class WanAudioRunner(WanRunner): # type:ignore
self
.
va_recorder
.
stop
(
wait
=
False
)
self
.
va_recorder
.
stop
(
wait
=
False
)
self
.
va_recorder
=
None
self
.
va_recorder
=
None
@
ProfilingContext4Debug
(
"Process after vae decoder"
)
@
ProfilingContext4Debug
L1
(
"Process after vae decoder"
)
def
process_images_after_vae_decoder
(
self
,
save_video
=
True
):
def
process_images_after_vae_decoder
(
self
,
save_video
=
True
):
# Merge results
# Merge results
gen_lvideo
=
torch
.
cat
(
self
.
gen_video_list
,
dim
=
2
).
float
()
gen_lvideo
=
torch
.
cat
(
self
.
gen_video_list
,
dim
=
2
).
float
()
...
@@ -728,9 +728,9 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -728,9 +728,9 @@ class WanAudioRunner(WanRunner): # type:ignore
audio_adapter
.
load_state_dict
(
weights_dict
,
strict
=
False
)
audio_adapter
.
load_state_dict
(
weights_dict
,
strict
=
False
)
return
audio_adapter
.
to
(
dtype
=
GET_DTYPE
())
return
audio_adapter
.
to
(
dtype
=
GET_DTYPE
())
@
ProfilingContext
(
"Load models"
)
def
load_model
(
self
):
def
load_model
(
self
):
super
().
load_model
()
super
().
load_model
()
with
ProfilingContext4DebugL2
(
"Load audio encoder and adapter"
):
self
.
audio_encoder
=
self
.
load_audio_encoder
()
self
.
audio_encoder
=
self
.
load_audio_encoder
()
self
.
audio_adapter
=
self
.
load_audio_adapter
()
self
.
audio_adapter
=
self
.
load_audio_adapter
()
self
.
model
.
set_audio_adapter
(
self
.
audio_adapter
)
self
.
model
.
set_audio_adapter
(
self
.
audio_adapter
)
...
...
lightx2v/models/runners/wan/wan_causvid_runner.py
View file @
27c5575f
...
@@ -9,7 +9,7 @@ from lightx2v.models.networks.wan.model import WanModel
...
@@ -9,7 +9,7 @@ from lightx2v.models.networks.wan.model import WanModel
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
from
lightx2v.models.schedulers.wan.step_distill.scheduler
import
WanStepDistillScheduler
from
lightx2v.models.schedulers.wan.step_distill.scheduler
import
WanStepDistillScheduler
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
ProfilingContext4Debug
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
...
@@ -85,11 +85,11 @@ class WanCausVidRunner(WanRunner):
...
@@ -85,11 +85,11 @@ class WanCausVidRunner(WanRunner):
if
fragment_idx
>
0
:
if
fragment_idx
>
0
:
logger
.
info
(
"recompute the kv_cache ..."
)
logger
.
info
(
"recompute the kv_cache ..."
)
with
ProfilingContext4Debug
(
"step_pre"
):
with
ProfilingContext4Debug
L1
(
"step_pre"
):
self
.
model
.
scheduler
.
latents
=
self
.
model
.
scheduler
.
last_sample
self
.
model
.
scheduler
.
latents
=
self
.
model
.
scheduler
.
last_sample
self
.
model
.
scheduler
.
step_pre
(
step_index
=
self
.
model
.
scheduler
.
infer_steps
-
1
)
self
.
model
.
scheduler
.
step_pre
(
step_index
=
self
.
model
.
scheduler
.
infer_steps
-
1
)
with
ProfilingContext4Debug
(
"🚀 infer_main"
):
with
ProfilingContext4Debug
L1
(
"🚀 infer_main"
):
self
.
model
.
infer
(
self
.
inputs
,
kv_start
,
kv_end
)
self
.
model
.
infer
(
self
.
inputs
,
kv_start
,
kv_end
)
kv_start
+=
self
.
num_frame_per_block
*
self
.
frame_seq_length
kv_start
+=
self
.
num_frame_per_block
*
self
.
frame_seq_length
...
@@ -105,13 +105,13 @@ class WanCausVidRunner(WanRunner):
...
@@ -105,13 +105,13 @@ class WanCausVidRunner(WanRunner):
for
step_index
in
range
(
self
.
model
.
scheduler
.
infer_steps
):
for
step_index
in
range
(
self
.
model
.
scheduler
.
infer_steps
):
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
self
.
model
.
scheduler
.
infer_steps
}
"
)
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
self
.
model
.
scheduler
.
infer_steps
}
"
)
with
ProfilingContext4Debug
(
"step_pre"
):
with
ProfilingContext4Debug
L1
(
"step_pre"
):
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
with
ProfilingContext4Debug
(
"🚀 infer_main"
):
with
ProfilingContext4Debug
L1
(
"🚀 infer_main"
):
self
.
model
.
infer
(
self
.
inputs
,
kv_start
,
kv_end
)
self
.
model
.
infer
(
self
.
inputs
,
kv_start
,
kv_end
)
with
ProfilingContext4Debug
(
"step_post"
):
with
ProfilingContext4Debug
L1
(
"step_post"
):
self
.
model
.
scheduler
.
step_post
()
self
.
model
.
scheduler
.
step_post
()
kv_start
+=
self
.
num_frame_per_block
*
self
.
frame_seq_length
kv_start
+=
self
.
num_frame_per_block
*
self
.
frame_seq_length
...
...
lightx2v/models/runners/wan/wan_skyreels_v2_df_runner.py
View file @
27c5575f
...
@@ -10,7 +10,7 @@ from loguru import logger
...
@@ -10,7 +10,7 @@ from loguru import logger
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
from
lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler
import
WanSkyreelsV2DFScheduler
from
lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler
import
WanSkyreelsV2DFScheduler
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
ProfilingContext
,
ProfilingContext4Debug
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
...
@@ -55,9 +55,9 @@ class WanSkyreelsV2DFRunner(WanRunner): # Diffustion foring for SkyReelsV2 DF I
...
@@ -55,9 +55,9 @@ class WanSkyreelsV2DFRunner(WanRunner): # Diffustion foring for SkyReelsV2 DF I
def
run_input_encoder
(
self
):
def
run_input_encoder
(
self
):
image_encoder_output
=
None
image_encoder_output
=
None
if
os
.
path
.
isfile
(
self
.
config
.
image_path
):
if
os
.
path
.
isfile
(
self
.
config
.
image_path
):
with
ProfilingContext
(
"Run Img Encoder"
):
with
ProfilingContext
4DebugL2
(
"Run Img Encoder"
):
image_encoder_output
=
self
.
run_image_encoder
(
self
.
config
,
self
.
image_encoder
,
self
.
vae_model
)
image_encoder_output
=
self
.
run_image_encoder
(
self
.
config
,
self
.
image_encoder
,
self
.
vae_model
)
with
ProfilingContext
(
"Run Text Encoder"
):
with
ProfilingContext
4DebugL2
(
"Run Text Encoder"
):
text_encoder_output
=
self
.
run_text_encoder
(
self
.
config
[
"prompt"
],
self
.
text_encoders
,
self
.
config
,
image_encoder_output
)
text_encoder_output
=
self
.
run_text_encoder
(
self
.
config
[
"prompt"
],
self
.
text_encoders
,
self
.
config
,
image_encoder_output
)
self
.
set_target_shape
()
self
.
set_target_shape
()
self
.
inputs
=
{
"text_encoder_output"
:
text_encoder_output
,
"image_encoder_output"
:
image_encoder_output
}
self
.
inputs
=
{
"text_encoder_output"
:
text_encoder_output
,
"image_encoder_output"
:
image_encoder_output
}
...
@@ -107,13 +107,13 @@ class WanSkyreelsV2DFRunner(WanRunner): # Diffustion foring for SkyReelsV2 DF I
...
@@ -107,13 +107,13 @@ class WanSkyreelsV2DFRunner(WanRunner): # Diffustion foring for SkyReelsV2 DF I
for
step_index
in
range
(
self
.
model
.
scheduler
.
infer_steps
):
for
step_index
in
range
(
self
.
model
.
scheduler
.
infer_steps
):
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
self
.
model
.
scheduler
.
infer_steps
}
"
)
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
self
.
model
.
scheduler
.
infer_steps
}
"
)
with
ProfilingContext4Debug
(
"step_pre"
):
with
ProfilingContext4Debug
L1
(
"step_pre"
):
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
with
ProfilingContext4Debug
(
"🚀 infer_main"
):
with
ProfilingContext4Debug
L1
(
"🚀 infer_main"
):
self
.
model
.
infer
(
self
.
inputs
)
self
.
model
.
infer
(
self
.
inputs
)
with
ProfilingContext4Debug
(
"step_post"
):
with
ProfilingContext4Debug
L1
(
"step_post"
):
self
.
model
.
scheduler
.
step_post
()
self
.
model
.
scheduler
.
step_post
()
videos
=
self
.
run_vae
(
self
.
model
.
scheduler
.
latents
,
self
.
model
.
scheduler
.
generator
)
videos
=
self
.
run_vae
(
self
.
model
.
scheduler
.
latents
,
self
.
model
.
scheduler
.
generator
)
...
...
lightx2v/models/runners/wan/wan_vace_runner.py
View file @
27c5575f
...
@@ -9,7 +9,7 @@ from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProce
...
@@ -9,7 +9,7 @@ from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProce
from
lightx2v.models.networks.wan.vace_model
import
WanVaceModel
from
lightx2v.models.networks.wan.vace_model
import
WanVaceModel
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
...
@@ -159,7 +159,7 @@ class WanVaceRunner(WanRunner):
...
@@ -159,7 +159,7 @@ class WanVaceRunner(WanRunner):
target_shape
[
0
]
=
int
(
target_shape
[
0
]
/
2
)
target_shape
[
0
]
=
int
(
target_shape
[
0
]
/
2
)
self
.
config
.
target_shape
=
target_shape
self
.
config
.
target_shape
=
target_shape
@
ProfilingContext
(
"Run VAE Decoder"
)
@
ProfilingContext
4DebugL1
(
"Run VAE Decoder"
)
def
run_vae_decoder
(
self
,
latents
):
def
run_vae_decoder
(
self
,
latents
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
self
.
vae_decoder
=
self
.
load_vae_decoder
()
self
.
vae_decoder
=
self
.
load_vae_decoder
()
...
...
lightx2v/models/vfi/rife/rife_comfyui_wrapper.py
View file @
27c5575f
...
@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
...
@@ -4,7 +4,7 @@ from typing import List, Optional, Tuple
import
torch
import
torch
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.profiler
import
*
class
RIFEWrapper
:
class
RIFEWrapper
:
...
@@ -25,12 +25,12 @@ class RIFEWrapper:
...
@@ -25,12 +25,12 @@ class RIFEWrapper:
from
.train_log.RIFE_HDv3
import
Model
from
.train_log.RIFE_HDv3
import
Model
self
.
model
=
Model
()
self
.
model
=
Model
()
with
ProfilingContext
(
"Load RIFE model"
):
with
ProfilingContext
4DebugL2
(
"Load RIFE model"
):
self
.
model
.
load_model
(
model_path
,
-
1
)
self
.
model
.
load_model
(
model_path
,
-
1
)
self
.
model
.
eval
()
self
.
model
.
eval
()
self
.
model
.
device
()
self
.
model
.
device
()
@
ProfilingContext
(
"Interpolate frames"
)
@
ProfilingContext
4DebugL2
(
"Interpolate frames"
)
def
interpolate_frames
(
def
interpolate_frames
(
self
,
self
,
images
:
torch
.
Tensor
,
images
:
torch
.
Tensor
,
...
...
lightx2v/utils/envs.py
View file @
27c5575f
...
@@ -17,9 +17,9 @@ DTYPE_MAP = {
...
@@ -17,9 +17,9 @@ DTYPE_MAP = {
@lru_cache(maxsize=None)
def CHECK_PROFILING_DEBUG_LEVEL(target_level):
    """Return True when the configured profiling level is at least *target_level*.

    The level comes from the ``PROFILING_DEBUG_LEVEL`` environment variable
    (default ``"0"``, i.e. profiling disabled). The result is cached per
    ``target_level``, so the environment is only consulted on the first call
    for each level.
    """
    configured_level = int(os.getenv("PROFILING_DEBUG_LEVEL", "0"))
    return target_level <= configured_level
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
...
...
lightx2v/utils/profiler.py
View file @
27c5575f
...
@@ -12,7 +12,6 @@ from lightx2v.utils.envs import *
...
@@ -12,7 +12,6 @@ from lightx2v.utils.envs import *
class
_ProfilingContext
:
class
_ProfilingContext
:
def
__init__
(
self
,
name
):
def
__init__
(
self
,
name
):
self
.
name
=
name
self
.
name
=
name
self
.
rank_info
=
""
if
dist
.
is_initialized
():
if
dist
.
is_initialized
():
self
.
rank_info
=
f
"Rank
{
dist
.
get_rank
()
}
"
self
.
rank_info
=
f
"Rank
{
dist
.
get_rank
()
}
"
else
:
else
:
...
@@ -80,5 +79,24 @@ class _NullContext:
...
@@ -80,5 +79,24 @@ class _NullContext:
return
func
return
func
class _ProfilingContextL1(_ProfilingContext):
    """Level 1 profiling context; prefixes the timer name with ``Level1_Log``."""

    def __init__(self, name):
        prefixed_name = f"Level1_Log {name}"
        super().__init__(prefixed_name)


class _ProfilingContextL2(_ProfilingContext):
    """Level 2 profiling context; prefixes the timer name with ``Level2_Log``."""

    def __init__(self, name):
        prefixed_name = f"Level2_Log {name}"
        super().__init__(prefixed_name)


# PROFILING_DEBUG_LEVEL=0: [Default] disable all profiling
# PROFILING_DEBUG_LEVEL=1: enable ProfilingContext4DebugL1
# PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4DebugL2
ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # active when level >= 1
ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # active when level >= 2
lightx2v/utils/prompt_enhancer.py
View file @
27c5575f
...
@@ -4,7 +4,7 @@ import torch
...
@@ -4,7 +4,7 @@ import torch
from
loguru
import
logger
from
loguru
import
logger
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.profiler
import
*
sys_prompt
=
"""
sys_prompt
=
"""
Transform the short prompt into a detailed video-generation caption using this structure:
Transform the short prompt into a detailed video-generation caption using this structure:
...
@@ -40,7 +40,7 @@ class PromptEnhancer:
...
@@ -40,7 +40,7 @@ class PromptEnhancer:
def
to_device
(
self
,
device
):
def
to_device
(
self
,
device
):
self
.
model
=
self
.
model
.
to
(
device
)
self
.
model
=
self
.
model
.
to
(
device
)
@
ProfilingContext
(
"Run prompt enhancer"
)
@
ProfilingContext
4DebugL1
(
"Run prompt enhancer"
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
__call__
(
self
,
prompt
):
def
__call__
(
self
,
prompt
):
prompt
=
prompt
.
strip
()
prompt
=
prompt
.
strip
()
...
...
scripts/base/base.sh
View file @
27c5575f
...
@@ -32,12 +32,12 @@ export DTYPE=BF16
...
@@ -32,12 +32,12 @@ export DTYPE=BF16
# Note: If set to FP32, it will be slower, so we recommend set ENABLE_GRAPH_MODE to true.
# Note: If set to FP32, it will be slower, so we recommend set ENABLE_GRAPH_MODE to true.
export
SENSITIVE_LAYER_DTYPE
=
FP32
export
SENSITIVE_LAYER_DTYPE
=
FP32
# Performance Profiling Debug
Mode
(Debug Only)
# Performance Profiling Debug
Level
(Debug Only)
# Enables detailed performance analysis output, such as time cost and memory usage
# Enables detailed performance analysis output, such as time cost and memory usage
# Available options: [
true, false
]
# Available options: [
0, 1, 2
]
# If not set, default value:
false
# If not set, default value:
0
# Note: This option can be set to
false
for production.
# Note: set this to
0
in production builds.
export
ENABLE_
PROFILING_DEBUG
=
true
export
PROFILING_DEBUG
_LEVEL
=
2
# Graph Mode Optimization (Performance Enhancement)
# Graph Mode Optimization (Performance Enhancement)
# Enables torch.compile for graph optimization, can improve inference performance
# Enables torch.compile for graph optimization, can improve inference performance
...
@@ -56,6 +56,6 @@ echo "model_path: ${model_path}"
...
@@ -56,6 +56,6 @@ echo "model_path: ${model_path}"
echo
"-------------------------------------------------------------------------------"
echo
"-------------------------------------------------------------------------------"
echo
"Model Inference Data Type:
${
DTYPE
}
"
echo
"Model Inference Data Type:
${
DTYPE
}
"
echo
"Sensitive Layer Data Type:
${
SENSITIVE_LAYER_DTYPE
}
"
echo
"Sensitive Layer Data Type:
${
SENSITIVE_LAYER_DTYPE
}
"
echo
"Performance Profiling Debug
Mode:
${
ENABLE_
PROFILING_DEBUG
}
"
echo
"Performance Profiling Debug
Level:
${
PROFILING_DEBUG
_LEVEL
}
"
echo
"Graph Mode Optimization:
${
ENABLE_GRAPH_MODE
}
"
echo
"Graph Mode Optimization:
${
ENABLE_GRAPH_MODE
}
"
echo
"==============================================================================="
echo
"==============================================================================="
scripts/bench/run_lightx2v_1.sh
View file @
27c5575f
...
@@ -27,7 +27,7 @@ export TOKENIZERS_PARALLELISM=false
...
@@ -27,7 +27,7 @@ export TOKENIZERS_PARALLELISM=false
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
DTYPE
=
BF16
export
DTYPE
=
BF16
export
SENSITIVE_LAYER_DTYPE
=
FP32
export
SENSITIVE_LAYER_DTYPE
=
FP32
export
ENABLE_
PROFILING_DEBUG
=
true
export
PROFILING_DEBUG
_LEVEL
=
2
export
ENABLE_GRAPH_MODE
=
false
export
ENABLE_GRAPH_MODE
=
false
python
-m
lightx2v.infer
\
python
-m
lightx2v.infer
\
...
...
scripts/bench/run_lightx2v_2.sh
View file @
27c5575f
...
@@ -26,7 +26,7 @@ export TOKENIZERS_PARALLELISM=false
...
@@ -26,7 +26,7 @@ export TOKENIZERS_PARALLELISM=false
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_
PROFILING_DEBUG
=
true
export
PROFILING_DEBUG
_LEVEL
=
2
export
ENABLE_GRAPH_MODE
=
false
export
ENABLE_GRAPH_MODE
=
false
export
DTYPE
=
BF16
export
DTYPE
=
BF16
...
...
scripts/bench/run_lightx2v_3.sh
View file @
27c5575f
...
@@ -26,7 +26,7 @@ export TOKENIZERS_PARALLELISM=false
...
@@ -26,7 +26,7 @@ export TOKENIZERS_PARALLELISM=false
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_
PROFILING_DEBUG
=
true
export
PROFILING_DEBUG
_LEVEL
=
2
export
ENABLE_GRAPH_MODE
=
false
export
ENABLE_GRAPH_MODE
=
false
export
DTYPE
=
BF16
export
DTYPE
=
BF16
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment